From a87a4d1222e28719731d5fffa0112d39bd4bdf07 Mon Sep 17 00:00:00 2001
From: Parham Saidi <parham@parha.me>
Date: Tue, 25 Jun 2024 19:51:40 +0200
Subject: [PATCH] feat: tool calling for Bedrock's Claude and General LLM Agent
 (#955)

---
 .changeset/brave-cherries-juggle.md         |   6 +
 packages/community/src/llm/bedrock/base.ts  | 179 ++++++++++++++++---
 packages/community/src/llm/bedrock/types.ts |  54 +++++-
 packages/community/src/llm/bedrock/utils.ts |  56 +++++-
 packages/llamaindex/src/agent/anthropic.ts  | 112 ++----------
 packages/llamaindex/src/agent/base.ts       |  28 +++
 packages/llamaindex/src/agent/index.ts      |   5 +
 packages/llamaindex/src/agent/llm.ts        |  45 +++++
 packages/llamaindex/src/agent/openai.ts     | 184 ++------------------
 packages/llamaindex/src/agent/utils.ts      | 147 +++++++++++++++-
 packages/llamaindex/src/llm/index.ts        |   2 +-
 11 files changed, 515 insertions(+), 303 deletions(-)
 create mode 100644 .changeset/brave-cherries-juggle.md
 create mode 100644 packages/llamaindex/src/agent/llm.ts

diff --git a/.changeset/brave-cherries-juggle.md b/.changeset/brave-cherries-juggle.md
new file mode 100644
index 000000000..138a6162b
--- /dev/null
+++ b/.changeset/brave-cherries-juggle.md
@@ -0,0 +1,6 @@
+---
+"llamaindex": patch
+"@llamaindex/community": patch
+---
+
+feat: added tool calling support for Bedrock's Claude and general LLM support for agents
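
For reference, a minimal sketch of how the new tool-calling path is meant to be used, assuming the Bedrock constructor accepts a model/region config as in the existing @llamaindex/community package and that FunctionTool.from is available from llamaindex:

```ts
// Illustrative sketch: the import paths and the `region` option are assumptions.
import { BEDROCK_MODELS, Bedrock } from "@llamaindex/community";
import { FunctionTool } from "llamaindex";

const weatherTool = FunctionTool.from(
  ({ city }: { city: string }) => `It is sunny in ${city}`,
  {
    name: "get_weather",
    description: "Get the current weather for a city",
    parameters: {
      type: "object",
      properties: { city: { type: "string" } },
      required: ["city"],
    },
  },
);

const llm = new Bedrock({
  model: BEDROCK_MODELS.ANTHROPIC_CLAUDE_3_HAIKU, // one of TOOL_CALL_MODELS
  region: "us-east-1", // assumed config key
});

const response = await llm.chat({
  messages: [{ role: "user", content: "What's the weather in Berlin?" }],
  tools: [weatherTool],
  additionalChatOptions: { toolChoice: { type: "auto" } },
});

// When Claude decides to call a tool, the calls are surfaced on message.options
if (response.message.options && "toolCall" in response.message.options) {
  console.log(response.message.options.toolCall);
}
```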
diff --git a/packages/community/src/llm/bedrock/base.ts b/packages/community/src/llm/bedrock/base.ts
index 853f74d33..a3d5b3405 100644
--- a/packages/community/src/llm/bedrock/base.ts
+++ b/packages/community/src/llm/bedrock/base.ts
@@ -2,12 +2,13 @@ import {
   BedrockRuntimeClient,
   InvokeModelCommand,
   InvokeModelWithResponseStreamCommand,
+  ResponseStream,
   type BedrockRuntimeClientConfig,
   type InvokeModelCommandInput,
   type InvokeModelWithResponseStreamCommandInput,
 } from "@aws-sdk/client-bedrock-runtime";
-
 import type {
+  BaseTool,
   ChatMessage,
   ChatResponse,
   ChatResponseChunk,
@@ -17,6 +18,8 @@ import type {
   LLMCompletionParamsNonStreaming,
   LLMCompletionParamsStreaming,
   LLMMetadata,
+  PartialToolCall,
+  ToolCall,
   ToolCallLLMMessageOptions,
 } from "llamaindex";
 import { ToolCallLLM, streamConverter, wrapLLMEvent } from "llamaindex";
@@ -24,14 +27,17 @@ import type {
   AnthropicNoneStreamingResponse,
   AnthropicTextContent,
   StreamEvent,
+  ToolBlock,
+  ToolChoice,
 } from "./types.js";
 import {
+  mapBaseToolsToAnthropicTools,
   mapChatMessagesToAnthropicMessages,
   mapMessageContentToMessageContentDetails,
   toUtf8,
 } from "./utils.js";
 
-export type BedrockAdditionalChatOptions = {};
+export type BedrockAdditionalChatOptions = { toolChoice: ToolChoice };
 
 export type BedrockChatParamsStreaming = LLMChatParamsStreaming<
   BedrockAdditionalChatOptions,
@@ -138,9 +144,39 @@ export const STREAMING_MODELS = new Set([
   BEDROCK_MODELS.MISTRAL_MIXTRAL_LARGE_2402,
 ]);
 
-abstract class Provider {
+export const TOOL_CALL_MODELS = [
+  BEDROCK_MODELS.ANTHROPIC_CLAUDE_3_SONNET,
+  BEDROCK_MODELS.ANTHROPIC_CLAUDE_3_HAIKU,
+  BEDROCK_MODELS.ANTHROPIC_CLAUDE_3_OPUS,
+  BEDROCK_MODELS.ANTHROPIC_CLAUDE_3_5_SONNET,
+];
+
+abstract class Provider<ProviderStreamEvent extends {} = {}> {
   abstract getTextFromResponse(response: Record<string, any>): string;
 
+  abstract getToolsFromResponse<T extends {} = {}>(
+    response: Record<string, any>,
+  ): T[];
+
+  getStreamingEventResponse(
+    response: Record<string, any>,
+  ): ProviderStreamEvent | undefined {
+    return response.chunk?.bytes
+      ? (JSON.parse(toUtf8(response.chunk?.bytes)) as ProviderStreamEvent)
+      : undefined;
+  }
+
+  async *reduceStream(
+    stream: AsyncIterable<ResponseStream>,
+  ): BedrockChatStreamResponse {
+    yield* streamConverter(stream, (response) => {
+      return {
+        delta: this.getTextFromStreamResponse(response),
+        raw: response,
+      };
+    });
+  }
+
   getTextFromStreamResponse(response: Record<string, any>): string {
     return this.getTextFromResponse(response);
   }
@@ -148,16 +184,27 @@ abstract class Provider {
   abstract getRequestBody<T extends ChatMessage>(
     metadata: LLMMetadata,
     messages: T[],
+    tools?: BaseTool[],
+    options?: BedrockAdditionalChatOptions,
   ): InvokeModelCommandInput | InvokeModelWithResponseStreamCommandInput;
 }
 
-class AnthropicProvider extends Provider {
+class AnthropicProvider extends Provider<StreamEvent> {
   getResultFromResponse(
     response: Record<string, any>,
   ): AnthropicNoneStreamingResponse {
     return JSON.parse(toUtf8(response.body));
   }
 
+  getToolsFromResponse<AnthropicToolContent>(
+    response: Record<string, any>,
+  ): AnthropicToolContent[] {
+    const result = this.getResultFromResponse(response);
+    return result.content
+      .filter((item) => item.type === "tool_use")
+      .map((item) => item as AnthropicToolContent);
+  }
+
   getTextFromResponse(response: Record<string, any>): string {
     const result = this.getResultFromResponse(response);
     return result.content
@@ -167,28 +214,101 @@ class AnthropicProvider extends Provider {
   }
 
   getTextFromStreamResponse(response: Record<string, any>): string {
-    const event: StreamEvent | undefined = response.chunk?.bytes
-      ? JSON.parse(toUtf8(response.chunk?.bytes))
-      : undefined;
-
-    if (event?.type === "content_block_delta") return event.delta.text;
+    const event = this.getStreamingEventResponse(response);
+    if (event?.type === "content_block_delta") {
+      if (event.delta.type === "text_delta") return event.delta.text;
+      if (event.delta.type === "input_json_delta")
+        return event.delta.partial_json;
+    }
     return "";
   }
 
-  getRequestBody<T extends ChatMessage>(
+  async *reduceStream(
+    stream: AsyncIterable<ResponseStream>,
+  ): BedrockChatStreamResponse {
+    let collecting = [];
+    let tool: ToolBlock | undefined = undefined;
+    // #TODO this should be broken down into a separate consumer
+    for await (const response of stream) {
+      const event = this.getStreamingEventResponse(response);
+      if (
+        event?.type === "content_block_start" &&
+        event.content_block.type === "tool_use"
+      ) {
+        tool = event.content_block;
+        continue;
+      }
+
+      if (
+        event?.type === "content_block_delta" &&
+        event.delta.type === "input_json_delta"
+      ) {
+        collecting.push(event.delta.partial_json);
+      }
+
+      let options: undefined | ToolCallLLMMessageOptions = undefined;
+      if (tool && collecting.length) {
+        const input = collecting.filter((item) => item).join("");
+        // We have all we need to parse the tool_use json
+        if (event?.type === "content_block_stop") {
+          options = {
+            toolCall: [
+              {
+                id: tool.id,
+                name: tool.name,
+                input: JSON.parse(input),
+              } as ToolCall,
+            ],
+          };
+          // reset the collection/tool
+          collecting = [];
+          tool = undefined;
+        } else {
+          options = {
+            toolCall: [
+              {
+                id: tool.id,
+                name: tool.name,
+                input,
+              } as PartialToolCall,
+            ],
+          };
+        }
+      }
+      const delta = this.getTextFromStreamResponse(response);
+      if (!delta && !options) continue;
+
+      yield {
+        delta,
+        options,
+        raw: response,
+      };
+    }
+  }
+
+  getRequestBody<T extends ChatMessage<ToolCallLLMMessageOptions>>(
     metadata: LLMMetadata,
     messages: T[],
+    tools?: BaseTool[],
+    options?: BedrockAdditionalChatOptions,
   ): InvokeModelCommandInput | InvokeModelWithResponseStreamCommandInput {
+    const extra: Record<string, unknown> = {};
+    if (options?.toolChoice) {
+      extra["tool_choice"] = options?.toolChoice;
+    }
+    const mapped = mapChatMessagesToAnthropicMessages(messages);
     return {
       modelId: metadata.model,
       contentType: "application/json",
       accept: "application/json",
       body: JSON.stringify({
         anthropic_version: "bedrock-2023-05-31",
-        messages: mapChatMessagesToAnthropicMessages(messages),
+        messages: mapped,
+        tools: mapBaseToolsToAnthropicTools(tools),
         max_tokens: metadata.maxTokens,
         temperature: metadata.temperature,
         top_p: metadata.topP,
+        ...extra,
       }),
     };
   }
@@ -256,7 +376,7 @@ export class Bedrock extends ToolCallLLM<BedrockAdditionalChatOptions> {
   }
 
   get supportToolCall(): boolean {
-    return false;
+    return TOOL_CALL_MODELS.includes(this.model);
   }
 
   get metadata(): LLMMetadata {
@@ -274,14 +394,24 @@ export class Bedrock extends ToolCallLLM<BedrockAdditionalChatOptions> {
   protected async nonStreamChat(
     params: BedrockChatParamsNonStreaming,
   ): Promise<BedrockChatNonStreamResponse> {
-    const input = this.provider.getRequestBody(this.metadata, params.messages);
+    const input = this.provider.getRequestBody(
+      this.metadata,
+      params.messages,
+      params.tools,
+      params.additionalChatOptions,
+    );
     const command = new InvokeModelCommand(input);
     const response = await this.client.send(command);
+    const tools = this.provider.getToolsFromResponse(response);
+    const options: ToolCallLLMMessageOptions = tools.length
+      ? { toolCall: tools }
+      : {};
     return {
       raw: response,
       message: {
-        content: this.provider.getTextFromResponse(response),
         role: "assistant",
+        content: this.provider.getTextFromResponse(response),
+        options,
       },
     };
   }
@@ -291,29 +421,30 @@ export class Bedrock extends ToolCallLLM<BedrockAdditionalChatOptions> {
   ): BedrockChatStreamResponse {
     if (!STREAMING_MODELS.has(this.model))
       throw new Error(`The model: ${this.model} does not support streaming`);
-    const input = this.provider.getRequestBody(this.metadata, params.messages);
+
+    const input = this.provider.getRequestBody(
+      this.metadata,
+      params.messages,
+      params.tools,
+      params.additionalChatOptions,
+    );
     const command = new InvokeModelWithResponseStreamCommand(input);
     const response = await this.client.send(command);
 
-    if (response.body)
-      yield* streamConverter(response.body, (response) => {
-        return {
-          delta: this.provider.getTextFromStreamResponse(response),
-          raw: response,
-        };
-      });
+    if (response.body) yield* this.provider.reduceStream(response.body);
   }
 
   chat(params: BedrockChatParamsStreaming): Promise<BedrockChatStreamResponse>;
   chat(
     params: BedrockChatParamsNonStreaming,
   ): Promise<BedrockChatNonStreamResponse>;
-
   @wrapLLMEvent
   async chat(
     params: BedrockChatParamsStreaming | BedrockChatParamsNonStreaming,
   ): Promise<BedrockChatStreamResponse | BedrockChatNonStreamResponse> {
-    if (params.stream) return this.streamChat(params);
+    if (params.stream) {
+      return this.streamChat(params);
+    }
     return this.nonStreamChat(params);
   }
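
The streaming path buffers Anthropic's input_json_delta fragments and only emits a parsed ToolCall on the chunk produced at content_block_stop; until then the partial input is passed through as a PartialToolCall. A sketch of consuming that stream, assuming the `llm` and `weatherTool` from the sketch above:

```ts
// Sketch only: `llm` and `weatherTool` are assumed to be set up as in the
// earlier example.
const stream = await llm.chat({
  stream: true,
  messages: [{ role: "user", content: "What's the weather in Berlin?" }],
  tools: [weatherTool],
});

for await (const chunk of stream) {
  if (chunk.options && "toolCall" in chunk.options) {
    // PartialToolCall (string input) while the JSON is still streaming,
    // a parsed ToolCall once the content block stops.
    console.log(chunk.options.toolCall);
  } else if (chunk.delta) {
    process.stdout.write(chunk.delta);
  }
}
```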
 
diff --git a/packages/community/src/llm/bedrock/types.ts b/packages/community/src/llm/bedrock/types.ts
index 14124dcda..8a02d5db4 100644
--- a/packages/community/src/llm/bedrock/types.ts
+++ b/packages/community/src/llm/bedrock/types.ts
@@ -14,19 +14,33 @@ type Message = {
   usage: Usage;
 };
 
+export type ToolBlock = {
+  id: string;
+  input: unknown;
+  name: string;
+  type: "tool_use";
+};
+
+export type TextBlock = {
+  type: "text";
+  text: string;
+};
+
 type ContentBlockStart = {
   type: "content_block_start";
   index: number;
-  content_block: {
-    type: string;
-    text: string;
-  };
+  content_block: ToolBlock | TextBlock;
 };
 
-type Delta = {
-  type: string;
-  text: string;
-};
+type Delta =
+  | {
+      type: "text_delta";
+      text: string;
+    }
+  | {
+      type: "input_json_delta";
+      partial_json: string;
+    };
 
 type ContentBlockDelta = {
   type: "content_block_delta";
@@ -60,6 +74,11 @@ type MessageStop = {
   "amazon-bedrock-invocationMetrics": InvocationMetrics;
 };
 
+export type ToolChoice =
+  | { type: "any" }
+  | { type: "auto" }
+  | { type: "tool"; name: string };
+
 export type StreamEvent =
   | { type: "message_start"; message: Message }
   | ContentBlockStart
@@ -68,13 +87,30 @@ export type StreamEvent =
   | MessageDelta
   | MessageStop;
 
-export type AnthropicContent = AnthropicTextContent | AnthropicImageContent;
+export type AnthropicContent =
+  | AnthropicTextContent
+  | AnthropicImageContent
+  | AnthropicToolContent
+  | AnthropicToolResultContent;
 
 export type AnthropicTextContent = {
   type: "text";
   text: string;
 };
 
+export type AnthropicToolContent = {
+  type: "tool_use";
+  id: string;
+  name: string;
+  input: Record<string, unknown>;
+};
+
+export type AnthropicToolResultContent = {
+  type: "tool_result";
+  tool_use_id: string;
+  content: string;
+};
+
 export type AnthropicMediaTypes =
   | "image/jpeg"
   | "image/png"
diff --git a/packages/community/src/llm/bedrock/utils.ts b/packages/community/src/llm/bedrock/utils.ts
index f6d9beea9..64bbdda1e 100644
--- a/packages/community/src/llm/bedrock/utils.ts
+++ b/packages/community/src/llm/bedrock/utils.ts
@@ -1,7 +1,11 @@
 import type {
+  BaseTool,
   ChatMessage,
+  JSONObject,
   MessageContent,
   MessageContentDetail,
+  ToolCallLLMMessageOptions,
+  ToolMetadata,
 } from "llamaindex";
 import type {
   AnthropicContent,
@@ -68,11 +72,61 @@ export const mapMessageContentToAnthropicContent = <T extends MessageContent>(
   );
 };
 
-export const mapChatMessagesToAnthropicMessages = <T extends ChatMessage>(
+type AnthropicTool = {
+  name: string;
+  description: string;
+  input_schema: ToolMetadata["parameters"];
+};
+
+export const mapBaseToolsToAnthropicTools = (
+  tools?: BaseTool[],
+): AnthropicTool[] => {
+  if (!tools) return [];
+  return tools.map((tool: BaseTool) => {
+    const {
+      metadata: { parameters, ...options },
+    } = tool;
+    return {
+      ...options,
+      input_schema: parameters,
+    };
+  });
+};
+
+export const mapChatMessagesToAnthropicMessages = <
+  T extends ChatMessage<ToolCallLLMMessageOptions>,
+>(
   messages: T[],
 ): AnthropicMessage[] => {
   const mapped = messages
     .flatMap((msg: T): AnthropicMessage[] => {
+      if (msg.options && "toolCall" in msg.options) {
+        return [
+          {
+            role: "assistant",
+            content: msg.options.toolCall.map((call) => ({
+              type: "tool_use",
+              id: call.id,
+              name: call.name,
+              input: call.input as JSONObject,
+            })),
+          },
+        ];
+      }
+      if (msg.options && "toolResult" in msg.options) {
+        return [
+          {
+            role: "user",
+            content: [
+              {
+                type: "tool_result",
+                tool_use_id: msg.options.toolResult.id,
+                content: msg.options.toolResult.result,
+              },
+            ],
+          },
+        ];
+      }
       return mapMessageContentToMessageContentDetails(msg.content).map(
         (detail: MessageContentDetail): AnthropicMessage => {
           const content = mapMessageContentDetailToAnthropicContent(detail);
diff --git a/packages/llamaindex/src/agent/anthropic.ts b/packages/llamaindex/src/agent/anthropic.ts
index 1f917cf3e..8f17b360d 100644
--- a/packages/llamaindex/src/agent/anthropic.ts
+++ b/packages/llamaindex/src/agent/anthropic.ts
@@ -1,116 +1,38 @@
-import { EngineResponse } from "../EngineResponse.js";
 import { Settings } from "../Settings.js";
-import {
-  type ChatEngineParamsNonStreaming,
-  type ChatEngineParamsStreaming,
-} from "../engines/chat/index.js";
-import { stringifyJSONToMessageContent } from "../internal/utils.js";
+import type {
+  ChatEngineParamsNonStreaming,
+  ChatEngineParamsStreaming,
+  EngineResponse,
+} from "../index.edge.js";
 import { Anthropic } from "../llm/anthropic.js";
-import { ObjectRetriever } from "../objects/index.js";
-import type { BaseToolWithCall } from "../types.js";
-import { AgentRunner, AgentWorker, type AgentParamsBase } from "./base.js";
-import type { TaskHandler } from "./types.js";
-import { callTool } from "./utils.js";
+import { LLMAgent, LLMAgentWorker, type LLMAgentParams } from "./llm.js";
 
-type AnthropicParamsBase = AgentParamsBase<Anthropic>;
+export type AnthropicAgentParams = LLMAgentParams;
 
-type AnthropicParamsWithTools = AnthropicParamsBase & {
-  tools: BaseToolWithCall[];
-};
+export class AnthropicAgentWorker extends LLMAgentWorker {}
 
-type AnthropicParamsWithToolRetriever = AnthropicParamsBase & {
-  toolRetriever: ObjectRetriever<BaseToolWithCall>;
-};
-
-export type AnthropicAgentParams =
-  | AnthropicParamsWithTools
-  | AnthropicParamsWithToolRetriever;
-
-export class AnthropicAgentWorker extends AgentWorker<Anthropic> {
-  taskHandler = AnthropicAgent.taskHandler;
-}
-
-export class AnthropicAgent extends AgentRunner<Anthropic> {
+export class AnthropicAgent extends LLMAgent {
   constructor(params: AnthropicAgentParams) {
+    const llm =
+      params.llm ??
+      (Settings.llm instanceof Anthropic
+        ? (Settings.llm as Anthropic)
+        : new Anthropic());
     super({
-      llm:
-        params.llm ??
-        (Settings.llm instanceof Anthropic
-          ? (Settings.llm as Anthropic)
-          : new Anthropic()),
-      chatHistory: params.chatHistory ?? [],
-      systemPrompt: params.systemPrompt ?? null,
-      runner: new AnthropicAgentWorker(),
-      tools:
-        "tools" in params
-          ? params.tools
-          : params.toolRetriever.retrieve.bind(params.toolRetriever),
-      verbose: params.verbose ?? false,
+      ...params,
+      llm,
     });
   }
 
-  createStore = AgentRunner.defaultCreateStore;
-
   async chat(params: ChatEngineParamsNonStreaming): Promise<EngineResponse>;
   async chat(params: ChatEngineParamsStreaming): Promise<never>;
   override async chat(
     params: ChatEngineParamsNonStreaming | ChatEngineParamsStreaming,
   ) {
     if (params.stream) {
+      // Anthropic does support streaming, but it doesn't look like it's supported by the LITS LLM yet
       throw new Error("Anthropic does not support streaming");
     }
     return super.chat(params);
   }
-
-  static taskHandler: TaskHandler<Anthropic> = async (step, enqueueOutput) => {
-    const { llm, getTools, stream } = step.context;
-    const lastMessage = step.context.store.messages.at(-1)!.content;
-    const tools = await getTools(lastMessage);
-    if (stream === true) {
-      throw new Error("Anthropic does not support streaming");
-    }
-    const response = await llm.chat({
-      stream,
-      tools,
-      messages: step.context.store.messages,
-    });
-    step.context.store.messages = [
-      ...step.context.store.messages,
-      response.message,
-    ];
-    const options = response.message.options ?? {};
-    enqueueOutput({
-      taskStep: step,
-      output: response,
-      isLast: !("toolCall" in options),
-    });
-    if ("toolCall" in options) {
-      const { toolCall } = options;
-      for (const call of toolCall) {
-        const targetTool = tools.find(
-          (tool) => tool.metadata.name === call.name,
-        );
-        const toolOutput = await callTool(
-          targetTool,
-          call,
-          step.context.logger,
-        );
-        step.context.store.toolOutputs.push(toolOutput);
-        step.context.store.messages = [
-          ...step.context.store.messages,
-          {
-            content: stringifyJSONToMessageContent(toolOutput.output),
-            role: "user",
-            options: {
-              toolResult: {
-                result: toolOutput.output,
-                isError: toolOutput.isError,
-                id: call.id,
-              },
-            },
-          },
-        ];
-      }
-    }
-  };
 }
diff --git a/packages/llamaindex/src/agent/base.ts b/packages/llamaindex/src/agent/base.ts
index 5775c8d33..1965dac7d 100644
--- a/packages/llamaindex/src/agent/base.ts
+++ b/packages/llamaindex/src/agent/base.ts
@@ -19,6 +19,7 @@ import type {
   TaskStep,
   TaskStepOutput,
 } from "./types.js";
+import { stepTools, stepToolsStreaming } from "./utils.js";
 
 export const MAX_TOOL_CALLS = 10;
 
@@ -214,6 +215,33 @@ export abstract class AgentRunner<
     return Object.create(null);
   }
 
+  static defaultTaskHandler: TaskHandler<LLM> = async (step, enqueueOutput) => {
+    const { llm, getTools, stream } = step.context;
+    const lastMessage = step.context.store.messages.at(-1)!.content;
+    const tools = await getTools(lastMessage);
+    const response = await llm.chat({
+      // @ts-expect-error
+      stream,
+      tools,
+      messages: [...step.context.store.messages],
+    });
+    if (!stream) {
+      await stepTools<LLM>({
+        response,
+        tools,
+        step,
+        enqueueOutput,
+      });
+    } else {
+      await stepToolsStreaming<LLM>({
+        response,
+        tools,
+        step,
+        enqueueOutput,
+      });
+    }
+  };
+
   protected constructor(
     params: AgentRunnerParams<AI, Store, AdditionalMessageOptions>,
   ) {
diff --git a/packages/llamaindex/src/agent/index.ts b/packages/llamaindex/src/agent/index.ts
index 18d6fbe95..feda11bd4 100644
--- a/packages/llamaindex/src/agent/index.ts
+++ b/packages/llamaindex/src/agent/index.ts
@@ -3,6 +3,8 @@ export {
   AnthropicAgentWorker,
   type AnthropicAgentParams,
 } from "./anthropic.js";
+export { AgentRunner, AgentWorker, type AgentParamsBase } from "./base.js";
+export { LLMAgent, LLMAgentWorker, type LLMAgentParams } from "./llm.js";
 export {
   OpenAIAgent,
   OpenAIAgentWorker,
@@ -13,6 +15,9 @@ export {
   ReActAgent,
   type ReACTAgentParams,
 } from "./react.js";
+export { type TaskHandler } from "./types.js";
+export { callTool, stepTools, stepToolsStreaming } from "./utils.js";
+
 // todo: ParallelAgent
 // todo: CustomAgent
 // todo: ReactMultiModal
diff --git a/packages/llamaindex/src/agent/llm.ts b/packages/llamaindex/src/agent/llm.ts
new file mode 100644
index 000000000..78b853649
--- /dev/null
+++ b/packages/llamaindex/src/agent/llm.ts
@@ -0,0 +1,45 @@
+import type { LLM } from "../llm/index.js";
+import { ObjectRetriever } from "../objects/index.js";
+import { Settings } from "../Settings.js";
+import type { BaseToolWithCall } from "../types.js";
+import { AgentRunner, AgentWorker, type AgentParamsBase } from "./base.js";
+
+type LLMParamsBase = AgentParamsBase<LLM>;
+
+type LLMParamsWithTools = LLMParamsBase & {
+  tools: BaseToolWithCall[];
+};
+
+type LLMParamsWithToolRetriever = LLMParamsBase & {
+  toolRetriever: ObjectRetriever<BaseToolWithCall>;
+};
+
+export type LLMAgentParams = LLMParamsWithTools | LLMParamsWithToolRetriever;
+
+export class LLMAgentWorker extends AgentWorker<LLM> {
+  taskHandler = AgentRunner.defaultTaskHandler;
+}
+
+export class LLMAgent extends AgentRunner<LLM> {
+  constructor(params: LLMAgentParams) {
+    const llm = params.llm ?? (Settings.llm ? (Settings.llm as LLM) : null);
+    if (!llm)
+      throw new Error(
+        "llm must be provided either in params or via Settings.llm",
+      );
+    super({
+      llm,
+      chatHistory: params.chatHistory ?? [],
+      systemPrompt: params.systemPrompt ?? null,
+      runner: new LLMAgentWorker(),
+      tools:
+        "tools" in params
+          ? params.tools
+          : params.toolRetriever.retrieve.bind(params.toolRetriever),
+      verbose: params.verbose ?? false,
+    });
+  }
+
+  createStore = AgentRunner.defaultCreateStore;
+  taskHandler = AgentRunner.defaultTaskHandler;
+}
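
A minimal sketch of the new generic agent in use, assuming LLMAgent and FunctionTool are re-exported from the llamaindex package root and reusing the Bedrock setup assumed above; any ToolCallLLM with supportToolCall === true should work here:

```ts
// Sketch only: export locations and the EngineResponse shape are assumptions.
import { FunctionTool, LLMAgent } from "llamaindex";
import { BEDROCK_MODELS, Bedrock } from "@llamaindex/community";

const addTool = FunctionTool.from(
  ({ a, b }: { a: number; b: number }) => `${a + b}`,
  {
    name: "add",
    description: "Add two numbers",
    parameters: {
      type: "object",
      properties: { a: { type: "number" }, b: { type: "number" } },
      required: ["a", "b"],
    },
  },
);

const agent = new LLMAgent({
  llm: new Bedrock({ model: BEDROCK_MODELS.ANTHROPIC_CLAUDE_3_HAIKU }),
  tools: [addTool],
});

const result = await agent.chat({
  message: "What is 2 + 3? Use the add tool.",
});
console.log(result.message.content);
```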
diff --git a/packages/llamaindex/src/agent/openai.ts b/packages/llamaindex/src/agent/openai.ts
index e0ca14930..a85fb4c5a 100644
--- a/packages/llamaindex/src/agent/openai.ts
+++ b/packages/llamaindex/src/agent/openai.ts
@@ -1,183 +1,23 @@
-import { ReadableStream } from "@llamaindex/env";
 import { Settings } from "../Settings.js";
-import { stringifyJSONToMessageContent } from "../internal/utils.js";
-import type {
-  ChatResponseChunk,
-  PartialToolCall,
-  ToolCall,
-  ToolCallLLMMessageOptions,
-} from "../llm/index.js";
 import { OpenAI } from "../llm/openai.js";
-import { ObjectRetriever } from "../objects/index.js";
-import type { BaseToolWithCall } from "../types.js";
-import { AgentRunner, AgentWorker, type AgentParamsBase } from "./base.js";
-import type { TaskHandler } from "./types.js";
-import { callTool } from "./utils.js";
+import { LLMAgent, LLMAgentWorker, type LLMAgentParams } from "./llm.js";
 
-type OpenAIParamsBase = AgentParamsBase<OpenAI>;
+// This is likely not necessary anymore, but leaving it here in case it's in use elsewhere
 
-type OpenAIParamsWithTools = OpenAIParamsBase & {
-  tools: BaseToolWithCall[];
-};
+export type OpenAIAgentParams = LLMAgentParams;
 
-type OpenAIParamsWithToolRetriever = OpenAIParamsBase & {
-  toolRetriever: ObjectRetriever<BaseToolWithCall>;
-};
+export class OpenAIAgentWorker extends LLMAgentWorker {}
 
-export type OpenAIAgentParams =
-  | OpenAIParamsWithTools
-  | OpenAIParamsWithToolRetriever;
-
-export class OpenAIAgentWorker extends AgentWorker<OpenAI> {
-  taskHandler = OpenAIAgent.taskHandler;
-}
-
-export class OpenAIAgent extends AgentRunner<OpenAI> {
+export class OpenAIAgent extends LLMAgent {
   constructor(params: OpenAIAgentParams) {
+    const llm =
+      params.llm ??
+      (Settings.llm instanceof OpenAI
+        ? (Settings.llm as OpenAI)
+        : new OpenAI());
     super({
-      llm:
-        params.llm ??
-        (Settings.llm instanceof OpenAI
-          ? (Settings.llm as OpenAI)
-          : new OpenAI()),
-      chatHistory: params.chatHistory ?? [],
-      runner: new OpenAIAgentWorker(),
-      systemPrompt: params.systemPrompt ?? null,
-      tools:
-        "tools" in params
-          ? params.tools
-          : params.toolRetriever.retrieve.bind(params.toolRetriever),
-      verbose: params.verbose ?? false,
+      ...params,
+      llm,
     });
   }
-
-  createStore = AgentRunner.defaultCreateStore;
-
-  static taskHandler: TaskHandler<OpenAI> = async (step, enqueueOutput) => {
-    const { llm, stream, getTools } = step.context;
-    const lastMessage = step.context.store.messages.at(-1)!.content;
-    const tools = await getTools(lastMessage);
-    const response = await llm.chat({
-      // @ts-expect-error
-      stream,
-      tools,
-      messages: [...step.context.store.messages],
-    });
-    if (!stream) {
-      step.context.store.messages = [
-        ...step.context.store.messages,
-        response.message,
-      ];
-      const options = response.message.options ?? {};
-      enqueueOutput({
-        taskStep: step,
-        output: response,
-        isLast: !("toolCall" in options),
-      });
-      if ("toolCall" in options) {
-        const { toolCall } = options;
-        for (const call of toolCall) {
-          const targetTool = tools.find(
-            (tool) => tool.metadata.name === call.name,
-          );
-          const toolOutput = await callTool(
-            targetTool,
-            call,
-            step.context.logger,
-          );
-          step.context.store.toolOutputs.push(toolOutput);
-          step.context.store.messages = [
-            ...step.context.store.messages,
-            {
-              role: "user" as const,
-              content: stringifyJSONToMessageContent(toolOutput.output),
-              options: {
-                toolResult: {
-                  result: toolOutput.output,
-                  isError: toolOutput.isError,
-                  id: call.id,
-                },
-              },
-            },
-          ];
-        }
-      }
-    } else {
-      const responseChunkStream = new ReadableStream<
-        ChatResponseChunk<ToolCallLLMMessageOptions>
-      >({
-        async start(controller) {
-          for await (const chunk of response) {
-            controller.enqueue(chunk);
-          }
-          controller.close();
-        },
-      });
-      const [pipStream, finalStream] = responseChunkStream.tee();
-      const reader = pipStream.getReader();
-      const { value } = await reader.read();
-      reader.releaseLock();
-      if (value === undefined) {
-        throw new Error(
-          "first chunk value is undefined, this should not happen",
-        );
-      }
-      // check if first chunk has tool calls, if so, this is a function call
-      // otherwise, it's a regular message
-      const hasToolCall = !!(value.options && "toolCall" in value.options);
-      enqueueOutput({
-        taskStep: step,
-        output: finalStream,
-        isLast: !hasToolCall,
-      });
-
-      if (hasToolCall) {
-        // you need to consume the response to get the full toolCalls
-        const toolCalls = new Map<string, ToolCall | PartialToolCall>();
-        for await (const chunk of pipStream) {
-          if (chunk.options && "toolCall" in chunk.options) {
-            const toolCall = chunk.options.toolCall;
-            toolCall.forEach((toolCall) => {
-              toolCalls.set(toolCall.id, toolCall);
-            });
-          }
-        }
-        step.context.store.messages = [
-          ...step.context.store.messages,
-          {
-            role: "assistant" as const,
-            content: "",
-            options: {
-              toolCall: [...toolCalls.values()],
-            },
-          },
-        ];
-        for (const toolCall of toolCalls.values()) {
-          const targetTool = tools.find(
-            (tool) => tool.metadata.name === toolCall.name,
-          );
-          const toolOutput = await callTool(
-            targetTool,
-            toolCall,
-            step.context.logger,
-          );
-          step.context.store.messages = [
-            ...step.context.store.messages,
-            {
-              role: "user" as const,
-              content: stringifyJSONToMessageContent(toolOutput.output),
-              options: {
-                toolResult: {
-                  result: toolOutput.output,
-                  isError: toolOutput.isError,
-                  id: toolCall.id,
-                },
-              },
-            },
-          ];
-          step.context.store.toolOutputs.push(toolOutput);
-        }
-      }
-    }
-  };
 }
diff --git a/packages/llamaindex/src/agent/utils.ts b/packages/llamaindex/src/agent/utils.ts
index e5d8614fe..8a3df6402 100644
--- a/packages/llamaindex/src/agent/utils.ts
+++ b/packages/llamaindex/src/agent/utils.ts
@@ -1,15 +1,160 @@
 import { ReadableStream } from "@llamaindex/env";
 import type { Logger } from "../internal/logger.js";
 import { getCallbackManager } from "../internal/settings/CallbackManager.js";
-import { isAsyncIterable, prettifyError } from "../internal/utils.js";
+import {
+  isAsyncIterable,
+  prettifyError,
+  stringifyJSONToMessageContent,
+} from "../internal/utils.js";
 import type {
   ChatMessage,
+  ChatResponse,
   ChatResponseChunk,
+  LLM,
   PartialToolCall,
   TextChatMessage,
   ToolCall,
+  ToolCallLLMMessageOptions,
 } from "../llm/index.js";
 import type { BaseTool, JSONObject, JSONValue, ToolOutput } from "../types.js";
+import type { TaskHandler } from "./types.js";
+
+type StepToolsResponseParams<Model extends LLM> = {
+  response: ChatResponse<ToolCallLLMMessageOptions>;
+  tools: BaseTool[];
+  step: Parameters<TaskHandler<Model, {}, ToolCallLLMMessageOptions>>[0];
+  enqueueOutput: Parameters<
+    TaskHandler<Model, {}, ToolCallLLMMessageOptions>
+  >[1];
+};
+
+type StepToolsStreamingResponseParams<Model extends LLM> =
+  StepToolsResponseParams<Model> & {
+    response: AsyncIterable<ChatResponseChunk<ToolCallLLMMessageOptions>>;
+  };
+
+// #TODO stepTools and stepToolsStreaming should be moved to a better abstraction
+
+export async function stepToolsStreaming<Model extends LLM>({
+  response,
+  tools,
+  step,
+  enqueueOutput,
+}: StepToolsStreamingResponseParams<Model>) {
+  const responseChunkStream = new ReadableStream<
+    ChatResponseChunk<ToolCallLLMMessageOptions>
+  >({
+    async start(controller) {
+      for await (const chunk of response) {
+        controller.enqueue(chunk);
+      }
+      controller.close();
+    },
+  });
+  const [pipStream, finalStream] = responseChunkStream.tee();
+  const reader = pipStream.getReader();
+  const { value } = await reader.read();
+  reader.releaseLock();
+  if (value === undefined) {
+    throw new Error("first chunk value is undefined, this should not happen");
+  }
+  // check if first chunk has tool calls, if so, this is a function call
+  // otherwise, it's a regular message
+  const hasToolCall = !!(value.options && "toolCall" in value.options);
+  enqueueOutput({
+    taskStep: step,
+    output: finalStream,
+    isLast: !hasToolCall,
+  });
+
+  if (hasToolCall) {
+    // you need to consume the response to get the full toolCalls
+    const toolCalls = new Map<string, ToolCall | PartialToolCall>();
+    for await (const chunk of pipStream) {
+      if (chunk.options && "toolCall" in chunk.options) {
+        const toolCall = chunk.options.toolCall;
+        toolCall.forEach((toolCall) => {
+          toolCalls.set(toolCall.id, toolCall);
+        });
+      }
+    }
+    step.context.store.messages = [
+      ...step.context.store.messages,
+      {
+        role: "assistant" as const,
+        content: "",
+        options: {
+          toolCall: [...toolCalls.values()],
+        },
+      },
+    ];
+    for (const toolCall of toolCalls.values()) {
+      const targetTool = tools.find(
+        (tool) => tool.metadata.name === toolCall.name,
+      );
+      const toolOutput = await callTool(
+        targetTool,
+        toolCall,
+        step.context.logger,
+      );
+      step.context.store.messages = [
+        ...step.context.store.messages,
+        {
+          role: "user" as const,
+          content: stringifyJSONToMessageContent(toolOutput.output),
+          options: {
+            toolResult: {
+              result: toolOutput.output,
+              isError: toolOutput.isError,
+              id: toolCall.id,
+            },
+          },
+        },
+      ];
+      step.context.store.toolOutputs.push(toolOutput);
+    }
+  }
+}
+
+export async function stepTools<Model extends LLM>({
+  response,
+  tools,
+  step,
+  enqueueOutput,
+}: StepToolsResponseParams<Model>) {
+  step.context.store.messages = [
+    ...step.context.store.messages,
+    response.message,
+  ];
+  const options = response.message.options ?? {};
+  enqueueOutput({
+    taskStep: step,
+    output: response,
+    isLast: !("toolCall" in options),
+  });
+  if ("toolCall" in options) {
+    const { toolCall } = options;
+    for (const call of toolCall) {
+      const targetTool = tools.find((tool) => tool.metadata.name === call.name);
+      const toolOutput = await callTool(targetTool, call, step.context.logger);
+      step.context.store.toolOutputs.push(toolOutput);
+      step.context.store.messages = [
+        ...step.context.store.messages,
+        {
+          content: stringifyJSONToMessageContent(toolOutput.output),
+          role: "user",
+          options: {
+            toolResult: {
+              result: toolOutput.output,
+              isError: toolOutput.isError,
+              id: call.id,
+            },
+          },
+        },
+      ];
+    }
+  }
+}
 
 export async function callTool(
   tool: BaseTool | undefined,
diff --git a/packages/llamaindex/src/llm/index.ts b/packages/llamaindex/src/llm/index.ts
index 123fb4fb4..fe069cedb 100644
--- a/packages/llamaindex/src/llm/index.ts
+++ b/packages/llamaindex/src/llm/index.ts
@@ -7,7 +7,7 @@ export {
 export { ToolCallLLM } from "./base.js";
 export { FireworksLLM } from "./fireworks.js";
 export { Gemini, GeminiSession } from "./gemini/base.js";
-export { streamConverter, wrapLLMEvent } from "./utils.js";
+export { streamConverter, streamReducer, wrapLLMEvent } from "./utils.js";
 
 export {
   GEMINI_MODEL,
-- 
GitLab