diff --git a/examples/jsonExtract.ts b/examples/jsonExtract.ts
index 496147228ed02294628950cd4bc1824f94587386..68af23c1dfaed3b38274712382ffd19b50297f29 100644
--- a/examples/jsonExtract.ts
+++ b/examples/jsonExtract.ts
@@ -36,9 +36,7 @@ async function main() {
     ],
   });
 
-  const json = JSON.parse(response.message.content);
-
-  console.log(json);
+  console.log(response.message.content);
 }
 
 main().catch(console.error);
diff --git a/examples/recipes/cost-analysis.ts b/examples/recipes/cost-analysis.ts
index 1070118b059a1a00ed5f0b501ac8efac198527d5..cf8d102b42aa02e1cbb34da30028860a6ed6cd64 100644
--- a/examples/recipes/cost-analysis.ts
+++ b/examples/recipes/cost-analysis.ts
@@ -1,6 +1,7 @@
 import { encodingForModel } from "js-tiktoken";
 import { OpenAI } from "llamaindex";
 import { Settings } from "llamaindex/Settings";
+import { extractText } from "llamaindex/llm/utils";
 
 const encoding = encodingForModel("gpt-4-0125-preview");
 
@@ -13,7 +14,7 @@ let tokenCount = 0;
 Settings.callbackManager.on("llm-start", (event) => {
   const { messages } = event.detail.payload;
   tokenCount += messages.reduce((count, message) => {
-    return count + encoding.encode(message.content).length;
+    return count + encoding.encode(extractText(message.content)).length;
   }, 0);
   console.log("Token count:", tokenCount);
   // https://openai.com/pricing
@@ -22,7 +23,7 @@ Settings.callbackManager.on("llm-start", (event) => {
 });
 Settings.callbackManager.on("llm-end", (event) => {
   const { response } = event.detail.payload;
-  tokenCount += encoding.encode(response.message.content).length;
+  tokenCount += encoding.encode(extractText(response.message.content)).length;
   console.log("Token count:", tokenCount);
   // https://openai.com/pricing
   // $30.00 / 1M tokens
diff --git a/packages/core/src/ChatHistory.ts b/packages/core/src/ChatHistory.ts
index dae76cd5dd733f5268c7098edc732f2357e75983..1b2a04a957aa01e87d2d549a875980514b4d650c 100644
--- a/packages/core/src/ChatHistory.ts
+++ b/packages/core/src/ChatHistory.ts
@@ -3,6 +3,7 @@ import type { SummaryPrompt } from "./Prompt.js";
 import { defaultSummaryPrompt, messagesToHistoryStr } from "./Prompt.js";
 import { OpenAI } from "./llm/open_ai.js";
 import type { ChatMessage, LLM, MessageType } from "./llm/types.js";
+import { extractText } from "./llm/utils.js";
 
 /**
  * A ChatHistory is used to keep the state of back and forth chat messages
@@ -188,7 +189,8 @@ export class SummaryChatHistory extends ChatHistory {
 
     // get tokens of current request messages and the transient messages
     const tokens = requestMessages.reduce(
-      (count, message) => count + this.tokenizer(message.content).length,
+      (count, message) =>
+        count + this.tokenizer(extractText(message.content)).length,
       0,
     );
     if (tokens > this.tokensToSummarize) {
diff --git a/packages/core/src/agent/openai/worker.ts b/packages/core/src/agent/openai/worker.ts
index 2c48cc67c75be2023833a8bdbfde5804b81cb78e..d727b248291c0181f253a1098a1d1ef5f09e7492 100644
--- a/packages/core/src/agent/openai/worker.ts
+++ b/packages/core/src/agent/openai/worker.ts
@@ -15,7 +15,11 @@ import {
   type LLMChatParamsBase,
   type OpenAIAdditionalChatOptions,
 } from "../../llm/index.js";
-import { streamConverter, streamReducer } from "../../llm/utils.js";
+import {
+  extractText,
+  streamConverter,
+  streamReducer,
+} from "../../llm/utils.js";
 import { ChatMemoryBuffer } from "../../memory/ChatMemoryBuffer.js";
 import type { ObjectRetriever } from "../../objects/base.js";
 import type { ToolOutput } from "../../tools/types.js";
@@ -162,7 +166,10 @@ export class OpenAIAgentWorker
   ): AgentChatResponse {
     task.extraState.newMemory.put(aiMessage);
 
-    return new AgentChatResponse(aiMessage.content, task.extraState.sources);
+    return new AgentChatResponse(
+      extractText(aiMessage.content),
+      task.extraState.sources,
+    );
   }
 
   private async _getStreamAiResponse(
diff --git a/packages/core/src/agent/react/types.ts b/packages/core/src/agent/react/types.ts
index 185ec378d439f9df29381d3430da7b5b4b238d2c..e8849275f3c5e3e9e889c98fe8668e0e232dc69c 100644
--- a/packages/core/src/agent/react/types.ts
+++ b/packages/core/src/agent/react/types.ts
@@ -1,4 +1,5 @@
 import type { ChatMessage } from "../../llm/index.js";
+import { extractText } from "../../llm/utils.js";
 
 export interface BaseReasoningStep {
   getContent(): string;
@@ -51,10 +52,12 @@ export abstract class BaseOutputParser {
   formatMessages(messages: ChatMessage[]): ChatMessage[] {
     if (messages) {
       if (messages[0].role === "system") {
-        messages[0].content = this.format(messages[0].content || "");
+        messages[0].content = this.format(
+          extractText(messages[0].content) || "",
+        );
       } else {
         messages[messages.length - 1].content = this.format(
-          messages[messages.length - 1].content || "",
+          extractText(messages[messages.length - 1].content) || "",
         );
       }
     }
diff --git a/packages/core/src/agent/react/worker.ts b/packages/core/src/agent/react/worker.ts
index fcd1252d37c2903f9cbc0991ae051cf2a01343c8..96af95787e470838a71dad7fc43928cc42ac5baa 100644
--- a/packages/core/src/agent/react/worker.ts
+++ b/packages/core/src/agent/react/worker.ts
@@ -3,6 +3,7 @@ import type { ChatMessage } from "cohere-ai/api";
 import { Settings } from "../../Settings.js";
 import { AgentChatResponse } from "../../engines/chat/index.js";
 import { type ChatResponse, type LLM } from "../../llm/index.js";
+import { extractText } from "../../llm/utils.js";
 import { ChatMemoryBuffer } from "../../memory/ChatMemoryBuffer.js";
 import type { ObjectRetriever } from "../../objects/base.js";
 import { ToolOutput } from "../../tools/index.js";
@@ -34,7 +35,7 @@ function addUserStepToReasoning(
 ): void {
   if (step.stepState.isFirst) {
     memory.put({
-      content: step.input,
+      content: step.input ?? "",
       role: "user",
     });
     step.stepState.isFirst = false;
@@ -130,7 +131,7 @@ export class ReActAgentWorker implements AgentWorker<ChatParams> {
 
     try {
       reasoningStep = this.outputParser.parse(
-        messageContent,
+        extractText(messageContent),
         isStreaming,
       ) as ActionReasoningStep;
     } catch (e) {
@@ -144,7 +145,7 @@ export class ReActAgentWorker implements AgentWorker<ChatParams> {
     currentReasoning.push(reasoningStep);
 
     if (reasoningStep.isDone()) {
-      return [messageContent, currentReasoning, true];
+      return [extractText(messageContent), currentReasoning, true];
     }
 
     const actionReasoningStep = new ActionReasoningStep({
@@ -157,7 +158,7 @@ export class ReActAgentWorker implements AgentWorker<ChatParams> {
       throw new Error(`Expected ActionReasoningStep, got ${reasoningStep}`);
     }
 
-    return [messageContent, currentReasoning, false];
+    return [extractText(messageContent), currentReasoning, false];
   }
 
   async _processActions(
diff --git a/packages/core/src/engines/chat/ContextChatEngine.ts b/packages/core/src/engines/chat/ContextChatEngine.ts
index 9dc1400179a4dcfc0f3a081a6d2d2b1fcea8f062..8586836df591470fbb2dcd8ee69420d73cf880d3 100644
--- a/packages/core/src/engines/chat/ContextChatEngine.ts
+++ b/packages/core/src/engines/chat/ContextChatEngine.ts
@@ -93,7 +93,10 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine {
       messages: requestMessages.messages,
     });
     chatHistory.addMessage(response.message);
-    return new Response(response.message.content, requestMessages.nodes);
+    return new Response(
+      extractText(response.message.content),
+      requestMessages.nodes,
+    );
   }
 
   reset() {
diff --git a/packages/core/src/engines/chat/SimpleChatEngine.ts b/packages/core/src/engines/chat/SimpleChatEngine.ts
index 3494186c5d1efe7ee2bf0d0f50718d2b047671b3..e57ce7fa97a2b97d279aefaf1d991fe8a057d67f 100644
--- a/packages/core/src/engines/chat/SimpleChatEngine.ts
+++ b/packages/core/src/engines/chat/SimpleChatEngine.ts
@@ -4,7 +4,11 @@ import { Response } from "../../Response.js";
 import { wrapEventCaller } from "../../internal/context/EventCaller.js";
 import type { ChatResponseChunk, LLM } from "../../llm/index.js";
 import { OpenAI } from "../../llm/index.js";
-import { streamConverter, streamReducer } from "../../llm/utils.js";
+import {
+  extractText,
+  streamConverter,
+  streamReducer,
+} from "../../llm/utils.js";
 import type {
   ChatEngine,
   ChatEngineParamsNonStreaming,
@@ -46,7 +50,7 @@ export class SimpleChatEngine implements ChatEngine {
         streamReducer({
           stream,
           initialValue: "",
-          reducer: (accumulator, part) => (accumulator += part.delta),
+          reducer: (accumulator, part) => accumulator + part.delta,
          finished: (accumulator) => {
            chatHistory.addMessage({ content: accumulator, role: "assistant" });
          },
@@ -59,7 +63,7 @@ export class SimpleChatEngine implements ChatEngine {
       messages: await chatHistory.requestMessages(),
     });
     chatHistory.addMessage(response.message);
-    return new Response(response.message.content);
+    return new Response(extractText(response.message.content));
   }
 
   reset() {
diff --git a/packages/core/src/evaluation/Correctness.ts b/packages/core/src/evaluation/Correctness.ts
index 1354e83f9f8008020a3bbaa4654120289a8d1f1a..5f4269327458d4f6e24306b4c8110c8410f147a9 100644
--- a/packages/core/src/evaluation/Correctness.ts
+++ b/packages/core/src/evaluation/Correctness.ts
@@ -2,6 +2,7 @@ import { MetadataMode } from "../Node.js";
 import type { ServiceContext } from "../ServiceContext.js";
 import { llmFromSettingsOrContext } from "../Settings.js";
 import type { ChatMessage, LLM } from "../llm/types.js";
+import { extractText } from "../llm/utils.js";
 import { PromptMixin } from "../prompts/Mixin.js";
 import type { CorrectnessSystemPrompt } from "./prompts.js";
 import {
@@ -85,7 +86,7 @@ export class CorrectnessEvaluator extends PromptMixin implements BaseEvaluator {
     });
 
     const [score, reasoning] = this.parserFunction(
-      evalResponse.message.content,
+      extractText(evalResponse.message.content),
     );
 
     return {
diff --git a/packages/core/src/llm/LLM.ts b/packages/core/src/llm/LLM.ts
index 56f8bc489a00faa5d50e14b26f0e7762aad53065..0af646d2e71bdf707d3c095d5dd0629ef5da7479 100644
--- a/packages/core/src/llm/LLM.ts
+++ b/packages/core/src/llm/LLM.ts
@@ -15,7 +15,7 @@ import type {
   LLMMetadata,
   MessageType,
 } from "./types.js";
-import { wrapLLMEvent } from "./utils.js";
+import { extractText, wrapLLMEvent } from "./utils.js";
 
 export const ALL_AVAILABLE_LLAMADEUCE_MODELS = {
   "Llama-2-70b-chat-old": {
@@ -215,16 +215,15 @@ If a question does not make any sense, or is not factually coherent, explain why
 
     return {
       prompt: messages.reduce((acc, message, index) => {
+        const content = extractText(message.content);
         if (index % 2 === 0) {
           return (
-            `${acc}${
-              withBos ? BOS : ""
-            }${B_INST} ${message.content.trim()} ${E_INST}` +
+            `${acc}${withBos ? BOS : ""}${B_INST} ${content.trim()} ${E_INST}` +
             (withNewlines ? "\n" : "")
           );
         } else {
           return (
-            `${acc} ${message.content.trim()}` +
+            `${acc} ${content.trim()}` +
             (withNewlines ? "\n" : " ") +
             (withBos ? EOS : "")
           ); // Yes, the EOS comes after the space. This is not a mistake.
@@ -322,7 +321,10 @@ export class Portkey extends BaseLLM {
     } else {
       const bodyParams = additionalChatOptions || {};
       const response = await this.session.portkey.chatCompletions.create({
-        messages,
+        messages: messages.map((message) => ({
+          content: extractText(message.content),
+          role: message.role,
+        })),
         ...bodyParams,
       });
 
@@ -337,7 +339,10 @@
     params?: Record<string, any>,
   ): AsyncIterable<ChatResponseChunk> {
     const chunkStream = await this.session.portkey.chatCompletions.create({
-      messages,
+      messages: messages.map((message) => ({
+        content: extractText(message.content),
+        role: message.role,
+      })),
       ...params,
       stream: true,
     });
diff --git a/packages/core/src/llm/anthropic.ts b/packages/core/src/llm/anthropic.ts
index 40ae46e133ca0efccf0aeffa2662f25f45bf48c1..3fffecee131427798d84001997870c0a62913e2e 100644
--- a/packages/core/src/llm/anthropic.ts
+++ b/packages/core/src/llm/anthropic.ts
@@ -10,7 +10,7 @@ import type {
 } from "llamaindex";
 import _ from "lodash";
 import { BaseLLM } from "./base.js";
-import { wrapLLMEvent } from "./utils.js";
+import { extractText, wrapLLMEvent } from "./utils.js";
 
 export class AnthropicSession {
   anthropic: SDKAnthropic;
@@ -138,7 +138,7 @@ export class Anthropic extends BaseLLM {
       }
 
       return {
-        content: message.content,
+        content: extractText(message.content),
         role: message.role,
       };
     });
diff --git a/packages/core/src/llm/base.ts b/packages/core/src/llm/base.ts
index d67cdbb5b73c82c8adb75ab5a101a1f70402bec2..8e1a4ec382457ff294f08003a2b89bdff4010771 100644
--- a/packages/core/src/llm/base.ts
+++ b/packages/core/src/llm/base.ts
@@ -9,7 +9,7 @@ import type {
   LLMCompletionParamsStreaming,
   LLMMetadata,
 } from "./types.js";
-import { streamConverter } from "./utils.js";
+import { extractText, streamConverter } from "./utils.js";
 
 export abstract class BaseLLM<
   AdditionalChatOptions extends Record<string, unknown> = Record<
@@ -44,7 +44,10 @@ export abstract class BaseLLM<
     const chatResponse = await this.chat({
       messages: [{ content: prompt, role: "user" }],
     });
-    return { text: chatResponse.message.content as string };
+    return {
+      text: extractText(chatResponse.message.content),
+      raw: chatResponse.raw,
+    };
   }
 
   abstract chat(
diff --git a/packages/core/src/llm/open_ai.ts b/packages/core/src/llm/open_ai.ts
index 7805f4951207c3716dcabf003792b29911f6fc1c..ffc5a176b55f32a9c98ed18f4aeaff6a283b01e9 100644
--- a/packages/core/src/llm/open_ai.ts
+++ b/packages/core/src/llm/open_ai.ts
@@ -308,7 +308,7 @@ export class OpenAI extends BaseLLM<OpenAIAdditionalChatOptions> {
       stream: false,
     });
 
-    const content = response.choices[0].message?.content ?? null;
+    const content = response.choices[0].message?.content ?? "";
 
     const kwargsOutput: Record<string, any> = {};
 
diff --git a/packages/core/src/llm/types.ts b/packages/core/src/llm/types.ts
index 626183b1122d3b40b70e480d56a0c7efdc1318cf..8abf6548041f5d98e687371fa83d7e8f2ddc3de0 100644
--- a/packages/core/src/llm/types.ts
+++ b/packages/core/src/llm/types.ts
@@ -75,8 +75,7 @@ export type MessageType =
   | "tool";
 
 export interface ChatMessage {
-  // TODO: use MessageContent
-  content: any;
+  content: MessageContent;
   role: MessageType;
   additionalKwargs?: Record<string, any>;
 }
@@ -137,7 +136,7 @@ export interface LLMChatParamsNonStreaming<
 }
 
 export interface LLMCompletionParamsBase {
-  prompt: any;
+  prompt: MessageContent;
 }
 
 export interface LLMCompletionParamsStreaming extends LLMCompletionParamsBase {
@@ -149,11 +148,19 @@ export interface LLMCompletionParamsNonStreaming
   stream?: false | null;
 }
 
-export interface MessageContentDetail {
-  type: "text" | "image_url";
-  text?: string;
-  image_url?: { url: string };
-}
+export type MessageContentTextDetail = {
+  type: "text";
+  text: string;
+};
+
+export type MessageContentImageDetail = {
+  type: "image_url";
+  image_url: { url: string };
+};
+
+export type MessageContentDetail =
+  | MessageContentTextDetail
+  | MessageContentImageDetail;
 
 /**
  * Extended type for the content of a message that allows for multi-modal messages.
diff --git a/packages/core/src/llm/utils.ts b/packages/core/src/llm/utils.ts
index 03725ad5e845592e69ef7656ca496048645d0fc5..2fb6267087149bdf1f813aec0a92ebfa905de80f 100644
--- a/packages/core/src/llm/utils.ts
+++ b/packages/core/src/llm/utils.ts
@@ -1,6 +1,12 @@
 import { AsyncLocalStorage } from "@llamaindex/env";
 import { getCallbackManager } from "../internal/settings/CallbackManager.js";
-import type { ChatResponse, LLM, LLMChat, MessageContent } from "./types.js";
+import type {
+  ChatResponse,
+  LLM,
+  LLMChat,
+  MessageContent,
+  MessageContentTextDetail,
+} from "./types.js";
 
 export async function* streamConverter<S, D>(
   stream: AsyncIterable<S>,
@@ -15,7 +21,7 @@ export async function* streamReducer<S, D>(params: {
   stream: AsyncIterable<S>;
   reducer: (previousValue: D, currentValue: S) => D;
   initialValue: D;
-  finished?: (value: D | undefined) => void;
+  finished?: (value: D) => void;
 }): AsyncIterable<S> {
   let value = params.initialValue;
   for await (const data of params.stream) {
@@ -26,23 +32,29 @@ export async function* streamReducer<S, D>(params: {
     params.finished(value);
   }
 }
+
 /**
  * Extracts just the text from a multi-modal message or the message itself if it's just text.
  *
  * @param message The message to extract text from.
  * @returns The extracted text
  */
-
 export function extractText(message: MessageContent): string {
-  if (Array.isArray(message)) {
+  if (typeof message !== "string" && !Array.isArray(message)) {
+    console.warn(
+      "extractText called with non-string message, this is likely a bug.",
+    );
+    return `${message}`;
+  } else if (typeof message !== "string" && Array.isArray(message)) {
     // message is of type MessageContentDetail[] - retrieve just the text parts and concatenate them
     // so we can pass them to the context generator
     return message
-      .filter((c) => c.type === "text")
+      .filter((c): c is MessageContentTextDetail => c.type === "text")
       .map((c) => c.text)
       .join("\n\n");
+  } else {
+    return message;
   }
-  return message;
 }
 
 /**
diff --git a/packages/core/src/selectors/llmSelectors.ts b/packages/core/src/selectors/llmSelectors.ts
index 654740fb0d763032683d30c46f781480f42f8d9a..5f7349f26a553fbe83b906c116fd94141aa0b779 100644
--- a/packages/core/src/selectors/llmSelectors.ts
+++ b/packages/core/src/selectors/llmSelectors.ts
@@ -48,7 +48,7 @@ export class LLMMultiSelector extends BaseSelector {
   llm: LLMPredictorType;
   prompt: MultiSelectPrompt;
   maxOutputs: number;
-  outputParser: BaseOutputParser<StructuredOutput<Answer[]>> | null;
+  outputParser: BaseOutputParser<StructuredOutput<Answer[]>>;
 
   constructor(init: {
     llm: LLMPredictorType;
@@ -118,7 +118,7 @@ export class LLMMultiSelector extends BaseSelector {
 export class LLMSingleSelector extends BaseSelector {
   llm: LLMPredictorType;
   prompt: SingleSelectPrompt;
-  outputParser: BaseOutputParser<StructuredOutput<Answer[]>> | null;
+  outputParser: BaseOutputParser<StructuredOutput<Answer[]>>;
 
   constructor(init: {
     llm: LLMPredictorType;
@@ -154,7 +154,7 @@ export class LLMSingleSelector extends BaseSelector {
 
     const prompt = this.prompt(choicesText.length, choicesText, query.queryStr);
 
-    const formattedPrompt = this.outputParser?.format(prompt);
+    const formattedPrompt = this.outputParser.format(prompt);
 
     const prediction = await this.llm.complete({
       prompt: formattedPrompt,
diff --git a/tsconfig.json b/tsconfig.json
index 84821fc77d8197deb533571f4aca75d8562290bb..9027b38e4351876bf661b81ef54348a9d93fe8ee 100644
--- a/tsconfig.json
+++ b/tsconfig.json
@@ -11,15 +11,7 @@
     "outDir": "./lib",
     "tsBuildInfoFile": "./lib/.tsbuildinfo",
     "incremental": true,
-    "composite": true,
-    "paths": {
-      "llamaindex": ["./packages/core/src/index.ts"],
-      "llamaindex/*": ["./packages/core/src/*.ts"],
-      "@llamaindex/env": ["./packages/env/src/index.ts"],
-      "@llamaindex/env/*": ["./packages/env/src/*.ts"],
-      "@llamaindex/experimental": ["./packages/experimental/src/index.ts"],
-      "@llamaindex/experimental/*": ["./packages/experimental/src/*.ts"]
-    }
+    "composite": true
   },
   "files": [],
   "references": [