From e85893ac0ffe5fe0ef02155ef1b12943709a5b04 Mon Sep 17 00:00:00 2001
From: Alex Yang <himself65@outlook.com>
Date: Sat, 6 Apr 2024 18:59:12 -0500
Subject: [PATCH] fix: message content type (#696)
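
Type `ChatMessage.content` and `LLMCompletionParamsBase.prompt` as
`MessageContent` instead of `any`, split `MessageContentDetail` into a
discriminated union (`MessageContentTextDetail` | `MessageContentImageDetail`),
and route every call site that treated message content as a plain string
through `extractText()` (chat engines, agents, evaluators, examples).

For downstream code that still needs a plain string (token counting, output
parsers, `Response` payloads), `extractText()` is the conversion used
throughout this patch. A minimal sketch of the call pattern, assuming the
`llamaindex/llm/utils` import path used in examples/recipes/cost-analysis.ts:

    import { OpenAI } from "llamaindex";
    import { extractText } from "llamaindex/llm/utils";

    async function main() {
      const llm = new OpenAI();
      const response = await llm.chat({
        messages: [{ content: "Say hello.", role: "user" }],
      });
      // response.message.content is now MessageContent
      // (string | MessageContentDetail[]) rather than `any`,
      // so convert explicitly before passing it to string-only APIs:
      const text = extractText(response.message.content);
      console.log(text);
    }

    main().catch(console.error);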

---
 examples/jsonExtract.ts                       |  4 +---
 examples/recipes/cost-analysis.ts             |  5 ++--
 packages/core/src/ChatHistory.ts              |  4 +++-
 packages/core/src/agent/openai/worker.ts      | 11 +++++++--
 packages/core/src/agent/react/types.ts        |  7 ++++--
 packages/core/src/agent/react/worker.ts       |  9 +++----
 .../src/engines/chat/ContextChatEngine.ts     |  5 +++-
 .../core/src/engines/chat/SimpleChatEngine.ts | 10 +++++---
 packages/core/src/evaluation/Correctness.ts   |  3 ++-
 packages/core/src/llm/LLM.ts                  | 19 +++++++++------
 packages/core/src/llm/anthropic.ts            |  4 ++--
 packages/core/src/llm/base.ts                 |  7 ++++--
 packages/core/src/llm/open_ai.ts              |  2 +-
 packages/core/src/llm/types.ts                | 23 +++++++++++-------
 packages/core/src/llm/utils.ts                | 24 ++++++++++++++-----
 packages/core/src/selectors/llmSelectors.ts   |  6 ++---
 tsconfig.json                                 | 10 +-------
 17 files changed, 96 insertions(+), 57 deletions(-)

diff --git a/examples/jsonExtract.ts b/examples/jsonExtract.ts
index 496147228..68af23c1d 100644
--- a/examples/jsonExtract.ts
+++ b/examples/jsonExtract.ts
@@ -36,9 +36,7 @@ async function main() {
     ],
   });
 
-  const json = JSON.parse(response.message.content);
-
-  console.log(json);
+  console.log(response.message.content);
 }
 
 main().catch(console.error);
diff --git a/examples/recipes/cost-analysis.ts b/examples/recipes/cost-analysis.ts
index 1070118b0..cf8d102b4 100644
--- a/examples/recipes/cost-analysis.ts
+++ b/examples/recipes/cost-analysis.ts
@@ -1,6 +1,7 @@
 import { encodingForModel } from "js-tiktoken";
 import { OpenAI } from "llamaindex";
 import { Settings } from "llamaindex/Settings";
+import { extractText } from "llamaindex/llm/utils";
 
 const encoding = encodingForModel("gpt-4-0125-preview");
 
@@ -13,7 +14,7 @@ let tokenCount = 0;
 Settings.callbackManager.on("llm-start", (event) => {
   const { messages } = event.detail.payload;
   tokenCount += messages.reduce((count, message) => {
-    return count + encoding.encode(message.content).length;
+    return count + encoding.encode(extractText(message.content)).length;
   }, 0);
   console.log("Token count:", tokenCount);
   // https://openai.com/pricing
@@ -22,7 +23,7 @@ Settings.callbackManager.on("llm-start", (event) => {
 });
 Settings.callbackManager.on("llm-end", (event) => {
   const { response } = event.detail.payload;
-  tokenCount += encoding.encode(response.message.content).length;
+  tokenCount += encoding.encode(extractText(response.message.content)).length;
   console.log("Token count:", tokenCount);
   // https://openai.com/pricing
   // $30.00 / 1M tokens
diff --git a/packages/core/src/ChatHistory.ts b/packages/core/src/ChatHistory.ts
index dae76cd5d..1b2a04a95 100644
--- a/packages/core/src/ChatHistory.ts
+++ b/packages/core/src/ChatHistory.ts
@@ -3,6 +3,7 @@ import type { SummaryPrompt } from "./Prompt.js";
 import { defaultSummaryPrompt, messagesToHistoryStr } from "./Prompt.js";
 import { OpenAI } from "./llm/open_ai.js";
 import type { ChatMessage, LLM, MessageType } from "./llm/types.js";
+import { extractText } from "./llm/utils.js";
 
 /**
  * A ChatHistory is used to keep the state of back and forth chat messages
@@ -188,7 +189,8 @@ export class SummaryChatHistory extends ChatHistory {
 
     // get tokens of current request messages and the transient messages
     const tokens = requestMessages.reduce(
-      (count, message) => count + this.tokenizer(message.content).length,
+      (count, message) =>
+        count + this.tokenizer(extractText(message.content)).length,
       0,
     );
     if (tokens > this.tokensToSummarize) {
diff --git a/packages/core/src/agent/openai/worker.ts b/packages/core/src/agent/openai/worker.ts
index 2c48cc67c..d727b2482 100644
--- a/packages/core/src/agent/openai/worker.ts
+++ b/packages/core/src/agent/openai/worker.ts
@@ -15,7 +15,11 @@ import {
   type LLMChatParamsBase,
   type OpenAIAdditionalChatOptions,
 } from "../../llm/index.js";
-import { streamConverter, streamReducer } from "../../llm/utils.js";
+import {
+  extractText,
+  streamConverter,
+  streamReducer,
+} from "../../llm/utils.js";
 import { ChatMemoryBuffer } from "../../memory/ChatMemoryBuffer.js";
 import type { ObjectRetriever } from "../../objects/base.js";
 import type { ToolOutput } from "../../tools/types.js";
@@ -162,7 +166,10 @@ export class OpenAIAgentWorker
   ): AgentChatResponse {
     task.extraState.newMemory.put(aiMessage);
 
-    return new AgentChatResponse(aiMessage.content, task.extraState.sources);
+    return new AgentChatResponse(
+      extractText(aiMessage.content),
+      task.extraState.sources,
+    );
   }
 
   private async _getStreamAiResponse(
diff --git a/packages/core/src/agent/react/types.ts b/packages/core/src/agent/react/types.ts
index 185ec378d..e8849275f 100644
--- a/packages/core/src/agent/react/types.ts
+++ b/packages/core/src/agent/react/types.ts
@@ -1,4 +1,5 @@
 import type { ChatMessage } from "../../llm/index.js";
+import { extractText } from "../../llm/utils.js";
 
 export interface BaseReasoningStep {
   getContent(): string;
@@ -51,10 +52,12 @@ export abstract class BaseOutputParser {
   formatMessages(messages: ChatMessage[]): ChatMessage[] {
     if (messages) {
       if (messages[0].role === "system") {
-        messages[0].content = this.format(messages[0].content || "");
+        messages[0].content = this.format(
+          extractText(messages[0].content) || "",
+        );
       } else {
         messages[messages.length - 1].content = this.format(
-          messages[messages.length - 1].content || "",
+          extractText(messages[messages.length - 1].content) || "",
         );
       }
     }
diff --git a/packages/core/src/agent/react/worker.ts b/packages/core/src/agent/react/worker.ts
index fcd1252d3..96af95787 100644
--- a/packages/core/src/agent/react/worker.ts
+++ b/packages/core/src/agent/react/worker.ts
@@ -3,6 +3,7 @@ import type { ChatMessage } from "cohere-ai/api";
 import { Settings } from "../../Settings.js";
 import { AgentChatResponse } from "../../engines/chat/index.js";
 import { type ChatResponse, type LLM } from "../../llm/index.js";
+import { extractText } from "../../llm/utils.js";
 import { ChatMemoryBuffer } from "../../memory/ChatMemoryBuffer.js";
 import type { ObjectRetriever } from "../../objects/base.js";
 import { ToolOutput } from "../../tools/index.js";
@@ -34,7 +35,7 @@ function addUserStepToReasoning(
 ): void {
   if (step.stepState.isFirst) {
     memory.put({
-      content: step.input,
+      content: step.input ?? "",
       role: "user",
     });
     step.stepState.isFirst = false;
@@ -130,7 +131,7 @@ export class ReActAgentWorker implements AgentWorker<ChatParams> {
 
     try {
       reasoningStep = this.outputParser.parse(
-        messageContent,
+        extractText(messageContent),
         isStreaming,
       ) as ActionReasoningStep;
     } catch (e) {
@@ -144,7 +145,7 @@ export class ReActAgentWorker implements AgentWorker<ChatParams> {
     currentReasoning.push(reasoningStep);
 
     if (reasoningStep.isDone()) {
-      return [messageContent, currentReasoning, true];
+      return [extractText(messageContent), currentReasoning, true];
     }
 
     const actionReasoningStep = new ActionReasoningStep({
@@ -157,7 +158,7 @@ export class ReActAgentWorker implements AgentWorker<ChatParams> {
       throw new Error(`Expected ActionReasoningStep, got ${reasoningStep}`);
     }
 
-    return [messageContent, currentReasoning, false];
+    return [extractText(messageContent), currentReasoning, false];
   }
 
   async _processActions(
diff --git a/packages/core/src/engines/chat/ContextChatEngine.ts b/packages/core/src/engines/chat/ContextChatEngine.ts
index 9dc140017..8586836df 100644
--- a/packages/core/src/engines/chat/ContextChatEngine.ts
+++ b/packages/core/src/engines/chat/ContextChatEngine.ts
@@ -93,7 +93,10 @@ export class ContextChatEngine extends PromptMixin implements ChatEngine {
       messages: requestMessages.messages,
     });
     chatHistory.addMessage(response.message);
-    return new Response(response.message.content, requestMessages.nodes);
+    return new Response(
+      extractText(response.message.content),
+      requestMessages.nodes,
+    );
   }
 
   reset() {
diff --git a/packages/core/src/engines/chat/SimpleChatEngine.ts b/packages/core/src/engines/chat/SimpleChatEngine.ts
index 3494186c5..e57ce7fa9 100644
--- a/packages/core/src/engines/chat/SimpleChatEngine.ts
+++ b/packages/core/src/engines/chat/SimpleChatEngine.ts
@@ -4,7 +4,11 @@ import { Response } from "../../Response.js";
 import { wrapEventCaller } from "../../internal/context/EventCaller.js";
 import type { ChatResponseChunk, LLM } from "../../llm/index.js";
 import { OpenAI } from "../../llm/index.js";
-import { streamConverter, streamReducer } from "../../llm/utils.js";
+import {
+  extractText,
+  streamConverter,
+  streamReducer,
+} from "../../llm/utils.js";
 import type {
   ChatEngine,
   ChatEngineParamsNonStreaming,
@@ -46,7 +50,7 @@ export class SimpleChatEngine implements ChatEngine {
         streamReducer({
           stream,
           initialValue: "",
-          reducer: (accumulator, part) => (accumulator += part.delta),
+          reducer: (accumulator, part) => accumulator + part.delta,
           finished: (accumulator) => {
             chatHistory.addMessage({ content: accumulator, role: "assistant" });
           },
@@ -59,7 +63,7 @@ export class SimpleChatEngine implements ChatEngine {
       messages: await chatHistory.requestMessages(),
     });
     chatHistory.addMessage(response.message);
-    return new Response(response.message.content);
+    return new Response(extractText(response.message.content));
   }
 
   reset() {
diff --git a/packages/core/src/evaluation/Correctness.ts b/packages/core/src/evaluation/Correctness.ts
index 1354e83f9..5f4269327 100644
--- a/packages/core/src/evaluation/Correctness.ts
+++ b/packages/core/src/evaluation/Correctness.ts
@@ -2,6 +2,7 @@ import { MetadataMode } from "../Node.js";
 import type { ServiceContext } from "../ServiceContext.js";
 import { llmFromSettingsOrContext } from "../Settings.js";
 import type { ChatMessage, LLM } from "../llm/types.js";
+import { extractText } from "../llm/utils.js";
 import { PromptMixin } from "../prompts/Mixin.js";
 import type { CorrectnessSystemPrompt } from "./prompts.js";
 import {
@@ -85,7 +86,7 @@ export class CorrectnessEvaluator extends PromptMixin implements BaseEvaluator {
     });
 
     const [score, reasoning] = this.parserFunction(
-      evalResponse.message.content,
+      extractText(evalResponse.message.content),
     );
 
     return {
diff --git a/packages/core/src/llm/LLM.ts b/packages/core/src/llm/LLM.ts
index 56f8bc489..0af646d2e 100644
--- a/packages/core/src/llm/LLM.ts
+++ b/packages/core/src/llm/LLM.ts
@@ -15,7 +15,7 @@ import type {
   LLMMetadata,
   MessageType,
 } from "./types.js";
-import { wrapLLMEvent } from "./utils.js";
+import { extractText, wrapLLMEvent } from "./utils.js";
 
 export const ALL_AVAILABLE_LLAMADEUCE_MODELS = {
   "Llama-2-70b-chat-old": {
@@ -215,16 +215,15 @@ If a question does not make any sense, or is not factually coherent, explain why
 
     return {
       prompt: messages.reduce((acc, message, index) => {
+        const content = extractText(message.content);
         if (index % 2 === 0) {
           return (
-            `${acc}${
-              withBos ? BOS : ""
-            }${B_INST} ${message.content.trim()} ${E_INST}` +
+            `${acc}${withBos ? BOS : ""}${B_INST} ${content.trim()} ${E_INST}` +
             (withNewlines ? "\n" : "")
           );
         } else {
           return (
-            `${acc} ${message.content.trim()}` +
+            `${acc} ${content.trim()}` +
             (withNewlines ? "\n" : " ") +
             (withBos ? EOS : "")
           ); // Yes, the EOS comes after the space. This is not a mistake.
@@ -322,7 +321,10 @@ export class Portkey extends BaseLLM {
     } else {
       const bodyParams = additionalChatOptions || {};
       const response = await this.session.portkey.chatCompletions.create({
-        messages,
+        messages: messages.map((message) => ({
+          content: extractText(message.content),
+          role: message.role,
+        })),
         ...bodyParams,
       });
 
@@ -337,7 +339,10 @@ export class Portkey extends BaseLLM {
     params?: Record<string, any>,
   ): AsyncIterable<ChatResponseChunk> {
     const chunkStream = await this.session.portkey.chatCompletions.create({
-      messages,
+      messages: messages.map((message) => ({
+        content: extractText(message.content),
+        role: message.role,
+      })),
       ...params,
       stream: true,
     });
diff --git a/packages/core/src/llm/anthropic.ts b/packages/core/src/llm/anthropic.ts
index 40ae46e13..3fffecee1 100644
--- a/packages/core/src/llm/anthropic.ts
+++ b/packages/core/src/llm/anthropic.ts
@@ -10,7 +10,7 @@ import type {
 } from "llamaindex";
 import _ from "lodash";
 import { BaseLLM } from "./base.js";
-import { wrapLLMEvent } from "./utils.js";
+import { extractText, wrapLLMEvent } from "./utils.js";
 
 export class AnthropicSession {
   anthropic: SDKAnthropic;
@@ -138,7 +138,7 @@ export class Anthropic extends BaseLLM {
       }
 
       return {
-        content: message.content,
+        content: extractText(message.content),
         role: message.role,
       };
     });
diff --git a/packages/core/src/llm/base.ts b/packages/core/src/llm/base.ts
index d67cdbb5b..8e1a4ec38 100644
--- a/packages/core/src/llm/base.ts
+++ b/packages/core/src/llm/base.ts
@@ -9,7 +9,7 @@ import type {
   LLMCompletionParamsStreaming,
   LLMMetadata,
 } from "./types.js";
-import { streamConverter } from "./utils.js";
+import { extractText, streamConverter } from "./utils.js";
 
 export abstract class BaseLLM<
   AdditionalChatOptions extends Record<string, unknown> = Record<
@@ -44,7 +44,10 @@ export abstract class BaseLLM<
     const chatResponse = await this.chat({
       messages: [{ content: prompt, role: "user" }],
     });
-    return { text: chatResponse.message.content as string };
+    return {
+      text: extractText(chatResponse.message.content),
+      raw: chatResponse.raw,
+    };
   }
 
   abstract chat(
diff --git a/packages/core/src/llm/open_ai.ts b/packages/core/src/llm/open_ai.ts
index 7805f4951..ffc5a176b 100644
--- a/packages/core/src/llm/open_ai.ts
+++ b/packages/core/src/llm/open_ai.ts
@@ -308,7 +308,7 @@ export class OpenAI extends BaseLLM<OpenAIAdditionalChatOptions> {
       stream: false,
     });
 
-    const content = response.choices[0].message?.content ?? null;
+    const content = response.choices[0].message?.content ?? "";
 
     const kwargsOutput: Record<string, any> = {};
 
diff --git a/packages/core/src/llm/types.ts b/packages/core/src/llm/types.ts
index 626183b11..8abf65480 100644
--- a/packages/core/src/llm/types.ts
+++ b/packages/core/src/llm/types.ts
@@ -75,8 +75,7 @@ export type MessageType =
   | "tool";
 
 export interface ChatMessage {
-  // TODO: use MessageContent
-  content: any;
+  content: MessageContent;
   role: MessageType;
   additionalKwargs?: Record<string, any>;
 }
@@ -137,7 +136,7 @@ export interface LLMChatParamsNonStreaming<
 }
 
 export interface LLMCompletionParamsBase {
-  prompt: any;
+  prompt: MessageContent;
 }
 
 export interface LLMCompletionParamsStreaming extends LLMCompletionParamsBase {
@@ -149,11 +148,19 @@ export interface LLMCompletionParamsNonStreaming
   stream?: false | null;
 }
 
-export interface MessageContentDetail {
-  type: "text" | "image_url";
-  text?: string;
-  image_url?: { url: string };
-}
+export type MessageContentTextDetail = {
+  type: "text";
+  text: string;
+};
+
+export type MessageContentImageDetail = {
+  type: "image_url";
+  image_url: { url: string };
+};
+
+export type MessageContentDetail =
+  | MessageContentTextDetail
+  | MessageContentImageDetail;
 
 /**
  * Extended type for the content of a message that allows for multi-modal messages.
diff --git a/packages/core/src/llm/utils.ts b/packages/core/src/llm/utils.ts
index 03725ad5e..2fb626708 100644
--- a/packages/core/src/llm/utils.ts
+++ b/packages/core/src/llm/utils.ts
@@ -1,6 +1,12 @@
 import { AsyncLocalStorage } from "@llamaindex/env";
 import { getCallbackManager } from "../internal/settings/CallbackManager.js";
-import type { ChatResponse, LLM, LLMChat, MessageContent } from "./types.js";
+import type {
+  ChatResponse,
+  LLM,
+  LLMChat,
+  MessageContent,
+  MessageContentTextDetail,
+} from "./types.js";
 
 export async function* streamConverter<S, D>(
   stream: AsyncIterable<S>,
@@ -15,7 +21,7 @@ export async function* streamReducer<S, D>(params: {
   stream: AsyncIterable<S>;
   reducer: (previousValue: D, currentValue: S) => D;
   initialValue: D;
-  finished?: (value: D | undefined) => void;
+  finished?: (value: D) => void;
 }): AsyncIterable<S> {
   let value = params.initialValue;
   for await (const data of params.stream) {
@@ -26,23 +32,29 @@ export async function* streamReducer<S, D>(params: {
     params.finished(value);
   }
 }
+
 /**
  * Extracts just the text from a multi-modal message or the message itself if it's just text.
  *
  * @param message The message to extract text from.
  * @returns The extracted text
  */
-
 export function extractText(message: MessageContent): string {
-  if (Array.isArray(message)) {
+  if (typeof message !== "string" && !Array.isArray(message)) {
+    console.warn(
+      "extractText called with non-string message, this is likely a bug.",
+    );
+    return `${message}`;
+  } else if (typeof message !== "string" && Array.isArray(message)) {
     // message is of type MessageContentDetail[] - retrieve just the text parts and concatenate them
     // so we can pass them to the context generator
     return message
-      .filter((c) => c.type === "text")
+      .filter((c): c is MessageContentTextDetail => c.type === "text")
       .map((c) => c.text)
       .join("\n\n");
+  } else {
+    return message;
   }
-  return message;
 }
 
 /**
diff --git a/packages/core/src/selectors/llmSelectors.ts b/packages/core/src/selectors/llmSelectors.ts
index 654740fb0..5f7349f26 100644
--- a/packages/core/src/selectors/llmSelectors.ts
+++ b/packages/core/src/selectors/llmSelectors.ts
@@ -48,7 +48,7 @@ export class LLMMultiSelector extends BaseSelector {
   llm: LLMPredictorType;
   prompt: MultiSelectPrompt;
   maxOutputs: number;
-  outputParser: BaseOutputParser<StructuredOutput<Answer[]>> | null;
+  outputParser: BaseOutputParser<StructuredOutput<Answer[]>>;
 
   constructor(init: {
     llm: LLMPredictorType;
@@ -118,7 +118,7 @@ export class LLMMultiSelector extends BaseSelector {
 export class LLMSingleSelector extends BaseSelector {
   llm: LLMPredictorType;
   prompt: SingleSelectPrompt;
-  outputParser: BaseOutputParser<StructuredOutput<Answer[]>> | null;
+  outputParser: BaseOutputParser<StructuredOutput<Answer[]>>;
 
   constructor(init: {
     llm: LLMPredictorType;
@@ -154,7 +154,7 @@ export class LLMSingleSelector extends BaseSelector {
 
     const prompt = this.prompt(choicesText.length, choicesText, query.queryStr);
 
-    const formattedPrompt = this.outputParser?.format(prompt);
+    const formattedPrompt = this.outputParser.format(prompt);
 
     const prediction = await this.llm.complete({
       prompt: formattedPrompt,
diff --git a/tsconfig.json b/tsconfig.json
index 84821fc77..9027b38e4 100644
--- a/tsconfig.json
+++ b/tsconfig.json
@@ -11,15 +11,7 @@
     "outDir": "./lib",
     "tsBuildInfoFile": "./lib/.tsbuildinfo",
     "incremental": true,
-    "composite": true,
-    "paths": {
-      "llamaindex": ["./packages/core/src/index.ts"],
-      "llamaindex/*": ["./packages/core/src/*.ts"],
-      "@llamaindex/env": ["./packages/env/src/index.ts"],
-      "@llamaindex/env/*": ["./packages/env/src/*.ts"],
-      "@llamaindex/experimental": ["./packages/experimental/src/index.ts"],
-      "@llamaindex/experimental/*": ["./packages/experimental/src/*.ts"]
-    }
+    "composite": true
   },
   "files": [],
   "references": [
-- 
GitLab