From e85893ac0ffe5fe0ef02155ef1b12943709a5b04 Mon Sep 17 00:00:00 2001
From: Alex Yang <himself65@outlook.com>
Date: Sat, 6 Apr 2024 18:59:12 -0500
Subject: [PATCH] fix: message content type (#696)

---
 examples/jsonExtract.ts                       |  4 +---
 examples/recipes/cost-analysis.ts             |  5 ++--
 packages/core/src/ChatHistory.ts              |  4 +++-
 packages/core/src/agent/openai/worker.ts      | 11 +++++++--
 packages/core/src/agent/react/types.ts        |  7 ++++--
 packages/core/src/agent/react/worker.ts       |  9 +++----
 .../src/engines/chat/ContextChatEngine.ts     |  5 +++-
 .../core/src/engines/chat/SimpleChatEngine.ts | 10 +++++---
 packages/core/src/evaluation/Correctness.ts   |  3 ++-
 packages/core/src/llm/LLM.ts                  | 19 +++++++++------
 packages/core/src/llm/anthropic.ts            |  4 ++--
 packages/core/src/llm/base.ts                 |  7 ++++--
 packages/core/src/llm/open_ai.ts              |  2 +-
 packages/core/src/llm/types.ts                | 23 +++++++++++-------
 packages/core/src/llm/utils.ts                | 24 ++++++++++++++-----
 packages/core/src/selectors/llmSelectors.ts   |  6 ++---
 tsconfig.json                                 | 10 +-------
 17 files changed, 96 insertions(+), 57 deletions(-)

diff --git a/examples/jsonExtract.ts b/examples/jsonExtract.ts
index 496147228..68af23c1d 100644
--- a/examples/jsonExtract.ts
+++ b/examples/jsonExtract.ts
@@ -36,9 +36,7 @@ async function main() {
     ],
   });
 
-  const json = JSON.parse(response.message.content);
-
-  console.log(json);
+  console.log(response.message.content);
 }
 
 main().catch(console.error);
diff --git a/examples/recipes/cost-analysis.ts b/examples/recipes/cost-analysis.ts
index 1070118b0..cf8d102b4 100644
--- a/examples/recipes/cost-analysis.ts
+++ b/examples/recipes/cost-analysis.ts
@@ -1,6 +1,7 @@
 import { encodingForModel } from "js-tiktoken";
 import { OpenAI } from "llamaindex";
 import { Settings } from "llamaindex/Settings";
+import { extractText } from "llamaindex/llm/utils";
 
 const encoding = encodingForModel("gpt-4-0125-preview");
 
@@ -13,7 +14,7 @@ let tokenCount = 0;
 Settings.callbackManager.on("llm-start", (event) => {
   const { messages } = event.detail.payload;
   tokenCount += messages.reduce((count, message) => {
-    return count + encoding.encode(message.content).length;
+    return count + encoding.encode(extractText(message.content)).length;
   }, 0);
   console.log("Token count:", tokenCount);
   // https://openai.com/pricing
@@ -22,7 +23,7 @@ Settings.callbackManager.on("llm-start", (event) => {
 });
 Settings.callbackManager.on("llm-end", (event) => {
   const { response } = event.detail.payload;
-  tokenCount += encoding.encode(response.message.content).length;
+  tokenCount += encoding.encode(extractText(response.message.content)).length;
   console.log("Token count:", tokenCount);
   // https://openai.com/pricing
   // $30.00 / 1M tokens
diff --git a/packages/core/src/ChatHistory.ts b/packages/core/src/ChatHistory.ts
index dae76cd5d..1b2a04a95 100644
--- a/packages/core/src/ChatHistory.ts
+++ b/packages/core/src/ChatHistory.ts
@@ -3,6 +3,7 @@ import type { SummaryPrompt } from "./Prompt.js";
 import { defaultSummaryPrompt, messagesToHistoryStr } from "./Prompt.js";
 import { OpenAI } from "./llm/open_ai.js";
 import type { ChatMessage, LLM, MessageType } from "./llm/types.js";
+import { extractText } from "./llm/utils.js";
 
 /**
  * A ChatHistory is used to keep the state of back and forth chat messages
@@ -188,7 +189,8 @@ export class SummaryChatHistory extends ChatHistory {
 
     // get tokens of current request messages and the transient messages
     const tokens = requestMessages.reduce(
-      (count, message) => count + this.tokenizer(message.content).length,
+      (count, message) =>
+        count + this.tokenizer(extractText(message.content)).length,
       0,
     );
     if (tokens > this.tokensToSummarize) {
diff --git a/packages/core/src/agent/openai/worker.ts b/packages/core/src/agent/openai/worker.ts
index 2c48cc67c..d727b2482 100644
--- a/packages/core/src/agent/openai/worker.ts
+++ b/packages/core/src/agent/openai/worker.ts
@@ -15,7 +15,11 @@ import {
   type LLMChatParamsBase,
   type OpenAIAdditionalChatOptions,
 } from "../../llm/index.js";
-import { streamConverter, streamReducer } from "../../llm/utils.js";
+import {
+  extractText,
+  streamConverter,
+  streamReducer,
+} from "../../llm/utils.js";
 import { ChatMemoryBuffer } from "../../memory/ChatMemoryBuffer.js";
 import type { ObjectRetriever } from "../../objects/base.js";
 import type { ToolOutput } from "../../tools/types.js";
@@ -162,7 +166,10 @@ export class OpenAIAgentWorker
   ): AgentChatResponse {
     task.extraState.newMemory.put(aiMessage);
 
-    return new AgentChatResponse(aiMessage.content, task.extraState.sources);
+    return new AgentChatResponse(
+      extractText(aiMessage.content),
+      task.extraState.sources,
+    );
   }
 
   private async _getStreamAiResponse(
diff --git a/packages/core/src/agent/react/types.ts b/packages/core/src/agent/react/types.ts
index 185ec378d..e8849275f 100644
--- a/packages/core/src/agent/react/types.ts
+++ b/packages/core/src/agent/react/types.ts
@@ -1,4 +1,5 @@
 import type { ChatMessage } from "../../llm/index.js";
+import { extractText } from "../../llm/utils.js";
 
 export interface BaseReasoningStep {
   getContent(): string;
@@ -51,10 +52,12 @@ export abstract class BaseOutputParser {
   formatMessages(messages: ChatMessage[]): ChatMessage[] {
     if (messages) {
       if (messages[0].role === "system") {
-        messages[0].content = this.format(messages[0].content || "");
+        messages[0].content = this.format(
+          extractText(messages[0].content) || "",
+        );
       } else {
         messages[messages.length - 1].content = this.format(
-          messages[messages.length - 1].content || "",
+          extractText(messages[messages.length - 1].content) || "",
         );
       }
     }
diff --git a/packages/core/src/agent/react/worker.ts b/packages/core/src/agent/react/worker.ts
index fcd1252d3..96af95787 100644
--- a/packages/core/src/agent/react/worker.ts
+++ b/packages/core/src/agent/react/worker.ts
@@ -3,6 +3,7 @@ import type { ChatMessage } from "cohere-ai/api";
 import { Settings } from "../../Settings.js";
 import { AgentChatResponse } from "../../engines/chat/index.js";
 import { type ChatResponse, type LLM } from "../../llm/index.js";
+import { extractText } from "../../llm/utils.js";
 import { ChatMemoryBuffer } from "../../memory/ChatMemoryBuffer.js";
 import type { ObjectRetriever } from "../../objects/base.js";
 import { ToolOutput } from "../../tools/index.js";
@@ -34,7 +35,7 @@ function addUserStepToReasoning(
 ): void {
   if (step.stepState.isFirst) {
     memory.put({
-      content: step.input,
+      content: step.input ?? "",
       role: "user",
     });
     step.stepState.isFirst = false;
@@ -130,7 +131,7 @@ export class ReActAgentWorker implements AgentWorker<ChatParams> {
 
     try {
       reasoningStep = this.outputParser.parse(
-        messageContent,
+        extractText(messageContent),
         isStreaming,
       ) as ActionReasoningStep;
     } catch (e) {
@@ -144,7 +145,7 @@ export class ReActAgentWorker implements AgentWorker<ChatParams> {
     currentReasoning.push(reasoningStep);
 
     if (reasoningStep.isDone()) {
-      return [messageContent, currentReasoning, true];
+      return [extractText(messageContent), currentReasoning, true];
     }
 
     const actionReasoningStep = new ActionReasoningStep({
@@ -157,7 +158,7 @@ export class ReActAgentWorker implements AgentWorker<ChatParams> {
       throw new Error(`Expected ActionReasoningStep, got ${reasoningStep}`);
     }
 
-    return [messageContent, currentReasoning, false];
+    return [extractText(messageContent), currentReasoning, false];
   }
 
   async _processActions(
diff --git a/packages/core/src/engines/chat/ContextChatEngine.ts b/packages/core/src/engines/chat/ContextChatEngine.ts
index 9dc140017..8586836df 100644
--- a/packages/core/src/engines/chat/ContextChatEngine.ts
+++ b/packages/core/src/engines/chat/ContextChatEngine.ts
@@ -93,7 +93,10 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine {
       messages: requestMessages.messages,
     });
     chatHistory.addMessage(response.message);
-    return new Response(response.message.content, requestMessages.nodes);
+    return new Response(
+      extractText(response.message.content),
+      requestMessages.nodes,
+    );
   }
 
   reset() {
diff --git a/packages/core/src/engines/chat/SimpleChatEngine.ts b/packages/core/src/engines/chat/SimpleChatEngine.ts
index 3494186c5..e57ce7fa9 100644
--- a/packages/core/src/engines/chat/SimpleChatEngine.ts
+++ b/packages/core/src/engines/chat/SimpleChatEngine.ts
@@ -4,7 +4,11 @@ import { Response } from "../../Response.js";
 import { wrapEventCaller } from "../../internal/context/EventCaller.js";
 import type { ChatResponseChunk, LLM } from "../../llm/index.js";
 import { OpenAI } from "../../llm/index.js";
-import { streamConverter, streamReducer } from "../../llm/utils.js";
+import {
+  extractText,
+  streamConverter,
+  streamReducer,
+} from "../../llm/utils.js";
 import type {
   ChatEngine,
   ChatEngineParamsNonStreaming,
@@ -46,7 +50,7 @@ export class SimpleChatEngine implements ChatEngine {
         streamReducer({
           stream,
           initialValue: "",
-          reducer: (accumulator, part) => (accumulator += part.delta),
+          reducer: (accumulator, part) => accumulator + part.delta,
           finished: (accumulator) => {
             chatHistory.addMessage({ content: accumulator, role: "assistant" });
           },
@@ -59,7 +63,7 @@ export class SimpleChatEngine implements ChatEngine {
       messages: await chatHistory.requestMessages(),
     });
     chatHistory.addMessage(response.message);
-    return new Response(response.message.content);
+    return new Response(extractText(response.message.content));
   }
 
   reset() {
diff --git a/packages/core/src/evaluation/Correctness.ts b/packages/core/src/evaluation/Correctness.ts
index 1354e83f9..5f4269327 100644
--- a/packages/core/src/evaluation/Correctness.ts
+++ b/packages/core/src/evaluation/Correctness.ts
@@ -2,6 +2,7 @@ import { MetadataMode } from "../Node.js";
 import type { ServiceContext } from "../ServiceContext.js";
 import { llmFromSettingsOrContext } from "../Settings.js";
 import type { ChatMessage, LLM } from "../llm/types.js";
+import { extractText } from "../llm/utils.js";
 import { PromptMixin } from "../prompts/Mixin.js";
 import type { CorrectnessSystemPrompt } from "./prompts.js";
 import {
@@ -85,7 +86,7 @@ export class CorrectnessEvaluator extends PromptMixin implements BaseEvaluator {
 
     });
     const [score, reasoning] = this.parserFunction(
-      evalResponse.message.content,
+      extractText(evalResponse.message.content),
     );
 
     return {
diff --git a/packages/core/src/llm/LLM.ts b/packages/core/src/llm/LLM.ts
index 56f8bc489..0af646d2e 100644
--- a/packages/core/src/llm/LLM.ts
+++ b/packages/core/src/llm/LLM.ts
@@ -15,7 +15,7 @@ import type {
   LLMMetadata,
   MessageType,
 } from "./types.js";
-import { wrapLLMEvent } from "./utils.js";
+import { extractText, wrapLLMEvent } from "./utils.js";
 
 export const ALL_AVAILABLE_LLAMADEUCE_MODELS = {
   "Llama-2-70b-chat-old": {
@@ -215,16 +215,15 @@ If a question does not make any sense, or is not factually coherent, explain why
 
     return {
       prompt: messages.reduce((acc, message, index) => {
+        const content = extractText(message.content);
         if (index % 2 === 0) {
           return (
-            `${acc}${
-              withBos ? BOS : ""
-            }${B_INST} ${message.content.trim()} ${E_INST}` +
+            `${acc}${withBos ? BOS : ""}${B_INST} ${content.trim()} ${E_INST}` +
             (withNewlines ? "\n" : "")
           );
         } else {
           return (
-            `${acc} ${message.content.trim()}` +
+            `${acc} ${content.trim()}` +
             (withNewlines ? "\n" : " ") +
             (withBos ? EOS : "")
           ); // Yes, the EOS comes after the space. This is not a mistake.
@@ -322,7 +321,10 @@ export class Portkey extends BaseLLM {
     } else {
       const bodyParams = additionalChatOptions || {};
       const response = await this.session.portkey.chatCompletions.create({
-        messages,
+        messages: messages.map((message) => ({
+          content: extractText(message.content),
+          role: message.role,
+        })),
         ...bodyParams,
       });
 
@@ -337,7 +339,10 @@
     params?: Record<string, any>,
   ): AsyncIterable<ChatResponseChunk> {
     const chunkStream = await this.session.portkey.chatCompletions.create({
-      messages,
+      messages: messages.map((message) => ({
+        content: extractText(message.content),
+        role: message.role,
+      })),
       ...params,
       stream: true,
     });
diff --git a/packages/core/src/llm/anthropic.ts b/packages/core/src/llm/anthropic.ts
index 40ae46e13..3fffecee1 100644
--- a/packages/core/src/llm/anthropic.ts
+++ b/packages/core/src/llm/anthropic.ts
@@ -10,7 +10,7 @@ import type {
 } from "llamaindex";
 import _ from "lodash";
 import { BaseLLM } from "./base.js";
-import { wrapLLMEvent } from "./utils.js";
+import { extractText, wrapLLMEvent } from "./utils.js";
 
 export class AnthropicSession {
   anthropic: SDKAnthropic;
@@ -138,7 +138,7 @@ export class Anthropic extends BaseLLM {
       }
 
       return {
-        content: message.content,
+        content: extractText(message.content),
         role: message.role,
       };
     });
diff --git a/packages/core/src/llm/base.ts b/packages/core/src/llm/base.ts
index d67cdbb5b..8e1a4ec38 100644
--- a/packages/core/src/llm/base.ts
+++ b/packages/core/src/llm/base.ts
@@ -9,7 +9,7 @@ import type {
   LLMCompletionParamsStreaming,
   LLMMetadata,
 } from "./types.js";
-import { streamConverter } from "./utils.js";
+import { extractText, streamConverter } from "./utils.js";
 
 export abstract class BaseLLM<
   AdditionalChatOptions extends Record<string, unknown> = Record<
@@ -44,7 +44,10 @@ export abstract class BaseLLM<
     const chatResponse = await this.chat({
       messages: [{ content: prompt, role: "user" }],
     });
-    return { text: chatResponse.message.content as string };
+    return {
+      text: extractText(chatResponse.message.content),
+      raw: chatResponse.raw,
+    };
   }
 
   abstract chat(
diff --git a/packages/core/src/llm/open_ai.ts b/packages/core/src/llm/open_ai.ts
index 7805f4951..ffc5a176b 100644
--- a/packages/core/src/llm/open_ai.ts
+++ b/packages/core/src/llm/open_ai.ts
@@ -308,7 +308,7 @@ export class OpenAI extends BaseLLM<OpenAIAdditionalChatOptions> {
       stream: false,
     });
 
-    const content = response.choices[0].message?.content ?? null;
+    const content = response.choices[0].message?.content ?? "";
 
     const kwargsOutput: Record<string, any> = {};
 
diff --git a/packages/core/src/llm/types.ts b/packages/core/src/llm/types.ts
index 626183b11..8abf65480 100644
--- a/packages/core/src/llm/types.ts
+++ b/packages/core/src/llm/types.ts
@@ -75,8 +75,7 @@ export type MessageType =
   | "tool";
 
 export interface ChatMessage {
-  // TODO: use MessageContent
-  content: any;
+  content: MessageContent;
   role: MessageType;
   additionalKwargs?: Record<string, any>;
 }
@@ -137,7 +136,7 @@ export interface LLMChatParamsNonStreaming<
 }
 
 export interface LLMCompletionParamsBase {
-  prompt: any;
+  prompt: MessageContent;
 }
 
 export interface LLMCompletionParamsStreaming extends LLMCompletionParamsBase {
@@ -149,11 +148,19 @@ export interface LLMCompletionParamsNonStreaming
   stream?: false | null;
 }
 
-export interface MessageContentDetail {
-  type: "text" | "image_url";
-  text?: string;
-  image_url?: { url: string };
-}
+export type MessageContentTextDetail = {
+  type: "text";
+  text: string;
+};
+
+export type MessageContentImageDetail = {
+  type: "image_url";
+  image_url: { url: string };
+};
+
+export type MessageContentDetail =
+  | MessageContentTextDetail
+  | MessageContentImageDetail;
 
 /**
  * Extended type for the content of a message that allows for multi-modal messages.
diff --git a/packages/core/src/llm/utils.ts b/packages/core/src/llm/utils.ts
index 03725ad5e..2fb626708 100644
--- a/packages/core/src/llm/utils.ts
+++ b/packages/core/src/llm/utils.ts
@@ -1,6 +1,12 @@
 import { AsyncLocalStorage } from "@llamaindex/env";
 import { getCallbackManager } from "../internal/settings/CallbackManager.js";
-import type { ChatResponse, LLM, LLMChat, MessageContent } from "./types.js";
+import type {
+  ChatResponse,
+  LLM,
+  LLMChat,
+  MessageContent,
+  MessageContentTextDetail,
+} from "./types.js";
 
 export async function* streamConverter<S, D>(
   stream: AsyncIterable<S>,
@@ -15,7 +21,7 @@ export async function* streamReducer<S, D>(params: {
   stream: AsyncIterable<S>;
   reducer: (previousValue: D, currentValue: S) => D;
   initialValue: D;
-  finished?: (value: D | undefined) => void;
+  finished?: (value: D) => void;
 }): AsyncIterable<S> {
   let value = params.initialValue;
   for await (const data of params.stream) {
@@ -26,23 +32,29 @@ export async function* streamReducer<S, D>(params: {
     params.finished(value);
   }
 }
+
 /**
  * Extracts just the text from a multi-modal message or the message itself if it's just text.
  *
  * @param message The message to extract text from.
 * @returns The extracted text
 */
-
 export function extractText(message: MessageContent): string {
-  if (Array.isArray(message)) {
+  if (typeof message !== "string" && !Array.isArray(message)) {
+    console.warn(
+      "extractText called with non-string message, this is likely a bug.",
+    );
+    return `${message}`;
+  } else if (typeof message !== "string" && Array.isArray(message)) {
     // message is of type MessageContentDetail[] - retrieve just the text parts and concatenate them
     // so we can pass them to the context generator
     return message
-      .filter((c) => c.type === "text")
+      .filter((c): c is MessageContentTextDetail => c.type === "text")
       .map((c) => c.text)
       .join("\n\n");
+  } else {
+    return message;
   }
-  return message;
 }
 
 /**
diff --git a/packages/core/src/selectors/llmSelectors.ts b/packages/core/src/selectors/llmSelectors.ts
index 654740fb0..5f7349f26 100644
--- a/packages/core/src/selectors/llmSelectors.ts
+++ b/packages/core/src/selectors/llmSelectors.ts
@@ -48,7 +48,7 @@ export class LLMMultiSelector extends BaseSelector {
   llm: LLMPredictorType;
   prompt: MultiSelectPrompt;
   maxOutputs: number;
-  outputParser: BaseOutputParser<StructuredOutput<Answer[]>> | null;
+  outputParser: BaseOutputParser<StructuredOutput<Answer[]>>;
 
   constructor(init: {
     llm: LLMPredictorType;
@@ -118,7 +118,7 @@ export class LLMMultiSelector extends BaseSelector {
 export class LLMSingleSelector extends BaseSelector {
   llm: LLMPredictorType;
   prompt: SingleSelectPrompt;
-  outputParser: BaseOutputParser<StructuredOutput<Answer[]>> | null;
+  outputParser: BaseOutputParser<StructuredOutput<Answer[]>>;
 
   constructor(init: {
     llm: LLMPredictorType;
@@ -154,7 +154,7 @@ export class LLMSingleSelector extends BaseSelector {
 
     const prompt = this.prompt(choicesText.length, choicesText, query.queryStr);
 
-    const formattedPrompt = this.outputParser?.format(prompt);
+    const formattedPrompt = this.outputParser.format(prompt);
 
     const prediction = await this.llm.complete({
       prompt: formattedPrompt,
diff --git a/tsconfig.json b/tsconfig.json
index 84821fc77..9027b38e4 100644
--- a/tsconfig.json
+++ b/tsconfig.json
@@ -11,15 +11,7 @@
     "outDir": "./lib",
     "tsBuildInfoFile": "./lib/.tsbuildinfo",
     "incremental": true,
-    "composite": true,
-    "paths": {
-      "llamaindex": ["./packages/core/src/index.ts"],
-      "llamaindex/*": ["./packages/core/src/*.ts"],
-      "@llamaindex/env": ["./packages/env/src/index.ts"],
-      "@llamaindex/env/*": ["./packages/env/src/*.ts"],
-      "@llamaindex/experimental": ["./packages/experimental/src/index.ts"],
-      "@llamaindex/experimental/*": ["./packages/experimental/src/*.ts"]
-    }
+    "composite": true
  },
  "files": [],
  "references": [
-- 
GitLab
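
After this patch, ChatMessage.content is typed as MessageContent (a plain string or an array of MessageContentDetail parts), and extractText collapses it back to a plain string wherever downstream code needs text (token counting, Response, AgentChatResponse). A minimal usage sketch, assuming ChatMessage and MessageContent are re-exported from the package root and extractText remains importable from "llamaindex/llm/utils" as in the cost-analysis example above; the URL is illustrative only:

  import type { ChatMessage, MessageContent } from "llamaindex";
  import { extractText } from "llamaindex/llm/utils";

  // A multi-modal message: content is MessageContentDetail[] instead of a string.
  const message: ChatMessage = {
    role: "user",
    content: [
      { type: "text", text: "What is shown in this image?" },
      // hypothetical image URL for illustration
      { type: "image_url", image_url: { url: "https://example.com/photo.png" } },
    ],
  };

  // extractText keeps only the text parts and joins them with blank lines,
  // so callers always get back a plain string.
  const text: string = extractText(message.content);
  console.log(text); // "What is shown in this image?"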