From a87a4d1222e28719731d5fffa0112d39bd4bdf07 Mon Sep 17 00:00:00 2001 From: Parham Saidi <parham@parha.me> Date: Tue, 25 Jun 2024 19:51:40 +0200 Subject: [PATCH] feat: tool calling for Bedrock's Claude and General LLM Agent (#955) --- .changeset/brave-cherries-juggle.md | 6 + packages/community/src/llm/bedrock/base.ts | 179 ++++++++++++++++--- packages/community/src/llm/bedrock/types.ts | 54 +++++- packages/community/src/llm/bedrock/utils.ts | 56 +++++- packages/llamaindex/src/agent/anthropic.ts | 112 ++---------- packages/llamaindex/src/agent/base.ts | 28 +++ packages/llamaindex/src/agent/index.ts | 5 + packages/llamaindex/src/agent/llm.ts | 45 +++++ packages/llamaindex/src/agent/openai.ts | 184 ++------------------ packages/llamaindex/src/agent/utils.ts | 147 +++++++++++++++- packages/llamaindex/src/llm/index.ts | 2 +- 11 files changed, 515 insertions(+), 303 deletions(-) create mode 100644 .changeset/brave-cherries-juggle.md create mode 100644 packages/llamaindex/src/agent/llm.ts diff --git a/.changeset/brave-cherries-juggle.md b/.changeset/brave-cherries-juggle.md new file mode 100644 index 000000000..138a6162b --- /dev/null +++ b/.changeset/brave-cherries-juggle.md @@ -0,0 +1,6 @@ +--- +"llamaindex": patch +"@llamaindex/community": patch +--- + +feat: added tool support calling for Bedrock's Calude and general llm support for agents diff --git a/packages/community/src/llm/bedrock/base.ts b/packages/community/src/llm/bedrock/base.ts index 853f74d33..a3d5b3405 100644 --- a/packages/community/src/llm/bedrock/base.ts +++ b/packages/community/src/llm/bedrock/base.ts @@ -2,12 +2,13 @@ import { BedrockRuntimeClient, InvokeModelCommand, InvokeModelWithResponseStreamCommand, + ResponseStream, type BedrockRuntimeClientConfig, type InvokeModelCommandInput, type InvokeModelWithResponseStreamCommandInput, } from "@aws-sdk/client-bedrock-runtime"; - import type { + BaseTool, ChatMessage, ChatResponse, ChatResponseChunk, @@ -17,6 +18,8 @@ import type { LLMCompletionParamsNonStreaming, LLMCompletionParamsStreaming, LLMMetadata, + PartialToolCall, + ToolCall, ToolCallLLMMessageOptions, } from "llamaindex"; import { ToolCallLLM, streamConverter, wrapLLMEvent } from "llamaindex"; @@ -24,14 +27,17 @@ import type { AnthropicNoneStreamingResponse, AnthropicTextContent, StreamEvent, + ToolBlock, + ToolChoice, } from "./types.js"; import { + mapBaseToolsToAnthropicTools, mapChatMessagesToAnthropicMessages, mapMessageContentToMessageContentDetails, toUtf8, } from "./utils.js"; -export type BedrockAdditionalChatOptions = {}; +export type BedrockAdditionalChatOptions = { toolChoice: ToolChoice }; export type BedrockChatParamsStreaming = LLMChatParamsStreaming< BedrockAdditionalChatOptions, @@ -138,9 +144,39 @@ export const STREAMING_MODELS = new Set([ BEDROCK_MODELS.MISTRAL_MIXTRAL_LARGE_2402, ]); -abstract class Provider { +export const TOOL_CALL_MODELS = [ + BEDROCK_MODELS.ANTHROPIC_CLAUDE_3_SONNET, + BEDROCK_MODELS.ANTHROPIC_CLAUDE_3_HAIKU, + BEDROCK_MODELS.ANTHROPIC_CLAUDE_3_OPUS, + BEDROCK_MODELS.ANTHROPIC_CLAUDE_3_5_SONNET, +]; + +abstract class Provider<ProviderStreamEvent extends {} = {}> { abstract getTextFromResponse(response: Record<string, any>): string; + abstract getToolsFromResponse<T extends {} = {}>( + response: Record<string, any>, + ): T[]; + + getStreamingEventResponse( + response: Record<string, any>, + ): ProviderStreamEvent | undefined { + return response.chunk?.bytes + ? 
(JSON.parse(toUtf8(response.chunk?.bytes)) as ProviderStreamEvent) + : undefined; + } + + async *reduceStream( + stream: AsyncIterable<ResponseStream>, + ): BedrockChatStreamResponse { + yield* streamConverter(stream, (response) => { + return { + delta: this.getTextFromStreamResponse(response), + raw: response, + }; + }); + } + getTextFromStreamResponse(response: Record<string, any>): string { return this.getTextFromResponse(response); } @@ -148,16 +184,27 @@ abstract class Provider { abstract getRequestBody<T extends ChatMessage>( metadata: LLMMetadata, messages: T[], + tools?: BaseTool[], + options?: BedrockAdditionalChatOptions, ): InvokeModelCommandInput | InvokeModelWithResponseStreamCommandInput; } -class AnthropicProvider extends Provider { +class AnthropicProvider extends Provider<StreamEvent> { getResultFromResponse( response: Record<string, any>, ): AnthropicNoneStreamingResponse { return JSON.parse(toUtf8(response.body)); } + getToolsFromResponse<AnthropicToolContent>( + response: Record<string, any>, + ): AnthropicToolContent[] { + const result = this.getResultFromResponse(response); + return result.content + .filter((item) => item.type === "tool_use") + .map((item) => item as AnthropicToolContent); + } + getTextFromResponse(response: Record<string, any>): string { const result = this.getResultFromResponse(response); return result.content @@ -167,28 +214,101 @@ class AnthropicProvider extends Provider { } getTextFromStreamResponse(response: Record<string, any>): string { - const event: StreamEvent | undefined = response.chunk?.bytes - ? JSON.parse(toUtf8(response.chunk?.bytes)) - : undefined; - - if (event?.type === "content_block_delta") return event.delta.text; + const event = this.getStreamingEventResponse(response); + if (event?.type === "content_block_delta") { + if (event.delta.type === "text_delta") return event.delta.text; + if (event.delta.type === "input_json_delta") + return event.delta.partial_json; + } return ""; } - getRequestBody<T extends ChatMessage>( + async *reduceStream( + stream: AsyncIterable<ResponseStream>, + ): BedrockChatStreamResponse { + let collecting = []; + let tool: ToolBlock | undefined = undefined; + // #TODO this should be broken down into a separate consumer + for await (const response of stream) { + const event = this.getStreamingEventResponse(response); + if ( + event?.type === "content_block_start" && + event.content_block.type === "tool_use" + ) { + tool = event.content_block; + continue; + } + + if ( + event?.type === "content_block_delta" && + event.delta.type === "input_json_delta" + ) { + collecting.push(event.delta.partial_json); + } + + let options: undefined | ToolCallLLMMessageOptions = undefined; + if (tool && collecting.length) { + const input = collecting.filter((item) => item).join(""); + // We have all we need to parse the tool_use json + if (event?.type === "content_block_stop") { + options = { + toolCall: [ + { + id: tool.id, + name: tool.name, + input: JSON.parse(input), + } as ToolCall, + ], + }; + // reset the collection/tool + collecting = []; + tool = undefined; + } else { + options = { + toolCall: [ + { + id: tool.id, + name: tool.name, + input, + } as PartialToolCall, + ], + }; + } + } + const delta = this.getTextFromStreamResponse(response); + if (!delta && !options) continue; + + yield { + delta, + options, + raw: response, + }; + } + } + + getRequestBody<T extends ChatMessage<ToolCallLLMMessageOptions>>( metadata: LLMMetadata, messages: T[], + tools?: BaseTool[], + options?: BedrockAdditionalChatOptions, ): 
InvokeModelCommandInput | InvokeModelWithResponseStreamCommandInput { + const extra: Record<string, unknown> = {}; + if (options?.toolChoice) { + extra["tool_choice"] = options?.toolChoice; + } + const mapped = mapChatMessagesToAnthropicMessages(messages); return { modelId: metadata.model, contentType: "application/json", accept: "application/json", body: JSON.stringify({ anthropic_version: "bedrock-2023-05-31", - messages: mapChatMessagesToAnthropicMessages(messages), + messages: mapped, + tools: mapBaseToolsToAnthropicTools(tools), max_tokens: metadata.maxTokens, temperature: metadata.temperature, top_p: metadata.topP, + ...extra, }), }; } @@ -256,7 +376,7 @@ export class Bedrock extends ToolCallLLM<BedrockAdditionalChatOptions> { } get supportToolCall(): boolean { - return false; + return TOOL_CALL_MODELS.includes(this.model); } get metadata(): LLMMetadata { @@ -274,14 +394,24 @@ export class Bedrock extends ToolCallLLM<BedrockAdditionalChatOptions> { protected async nonStreamChat( params: BedrockChatParamsNonStreaming, ): Promise<BedrockChatNonStreamResponse> { - const input = this.provider.getRequestBody(this.metadata, params.messages); + const input = this.provider.getRequestBody( + this.metadata, + params.messages, + params.tools, + params.additionalChatOptions, + ); const command = new InvokeModelCommand(input); const response = await this.client.send(command); + const tools = this.provider.getToolsFromResponse(response); + const options: ToolCallLLMMessageOptions = tools.length + ? { toolCall: tools } + : {}; return { raw: response, message: { - content: this.provider.getTextFromResponse(response), role: "assistant", + content: this.provider.getTextFromResponse(response), + options, }, }; } @@ -291,29 +421,30 @@ export class Bedrock extends ToolCallLLM<BedrockAdditionalChatOptions> { ): BedrockChatStreamResponse { if (!STREAMING_MODELS.has(this.model)) throw new Error(`The model: ${this.model} does not support streaming`); - const input = this.provider.getRequestBody(this.metadata, params.messages); + + const input = this.provider.getRequestBody( + this.metadata, + params.messages, + params.tools, + params.additionalChatOptions, + ); const command = new InvokeModelWithResponseStreamCommand(input); const response = await this.client.send(command); - if (response.body) - yield* streamConverter(response.body, (response) => { - return { - delta: this.provider.getTextFromStreamResponse(response), - raw: response, - }; - }); + if (response.body) yield* this.provider.reduceStream(response.body); } chat(params: BedrockChatParamsStreaming): Promise<BedrockChatStreamResponse>; chat( params: BedrockChatParamsNonStreaming, ): Promise<BedrockChatNonStreamResponse>; - @wrapLLMEvent async chat( params: BedrockChatParamsStreaming | BedrockChatParamsNonStreaming, ): Promise<BedrockChatStreamResponse | BedrockChatNonStreamResponse> { - if (params.stream) return this.streamChat(params); + if (params.stream) { + return this.streamChat(params); + } return this.nonStreamChat(params); } diff --git a/packages/community/src/llm/bedrock/types.ts b/packages/community/src/llm/bedrock/types.ts index 14124dcda..8a02d5db4 100644 --- a/packages/community/src/llm/bedrock/types.ts +++ b/packages/community/src/llm/bedrock/types.ts @@ -14,19 +14,33 @@ type Message = { usage: Usage; }; +export type ToolBlock = { + id: string; + input: unknown; + name: string; + type: "tool_use"; +}; + +export type TextBlock = { + type: "text"; + text: string; +}; + type ContentBlockStart = { type: "content_block_start"; index: 
number; - content_block: { - type: string; - text: string; - }; + content_block: ToolBlock | TextBlock; }; -type Delta = { - type: string; - text: string; -}; +type Delta = + | { + type: "text_delta"; + text: string; + } + | { + type: "input_json_delta"; + partial_json: string; + }; type ContentBlockDelta = { type: "content_block_delta"; @@ -60,6 +74,11 @@ type MessageStop = { "amazon-bedrock-invocationMetrics": InvocationMetrics; }; +export type ToolChoice = + | { type: "any" } + | { type: "auto" } + | { type: "tool"; name: string }; + export type StreamEvent = | { type: "message_start"; message: Message } | ContentBlockStart @@ -68,13 +87,30 @@ export type StreamEvent = | MessageDelta | MessageStop; -export type AnthropicContent = AnthropicTextContent | AnthropicImageContent; +export type AnthropicContent = + | AnthropicTextContent + | AnthropicImageContent + | AnthropicToolContent + | AnthropicToolResultContent; export type AnthropicTextContent = { type: "text"; text: string; }; +export type AnthropicToolContent = { + type: "tool_use"; + id: string; + name: string; + input: Record<string, unknown>; +}; + +export type AnthropicToolResultContent = { + type: "tool_result"; + tool_use_id: string; + content: string; +}; + export type AnthropicMediaTypes = | "image/jpeg" | "image/png" diff --git a/packages/community/src/llm/bedrock/utils.ts b/packages/community/src/llm/bedrock/utils.ts index f6d9beea9..64bbdda1e 100644 --- a/packages/community/src/llm/bedrock/utils.ts +++ b/packages/community/src/llm/bedrock/utils.ts @@ -1,7 +1,11 @@ import type { + BaseTool, ChatMessage, + JSONObject, MessageContent, MessageContentDetail, + ToolCallLLMMessageOptions, + ToolMetadata, } from "llamaindex"; import type { AnthropicContent, @@ -68,11 +72,61 @@ export const mapMessageContentToAnthropicContent = <T extends MessageContent>( ); }; -export const mapChatMessagesToAnthropicMessages = <T extends ChatMessage>( +type AnthropicTool = { + name: string; + description: string; + input_schema: ToolMetadata["parameters"]; +}; + +export const mapBaseToolsToAnthropicTools = ( + tools?: BaseTool[], +): AnthropicTool[] => { + if (!tools) return []; + return tools.map((tool: BaseTool) => { + const { + metadata: { parameters, ...options }, + } = tool; + return { + ...options, + input_schema: parameters, + }; + }); +}; + +export const mapChatMessagesToAnthropicMessages = < + T extends ChatMessage<ToolCallLLMMessageOptions>, +>( messages: T[], ): AnthropicMessage[] => { const mapped = messages .flatMap((msg: T): AnthropicMessage[] => { + if (msg.options && "toolCall" in msg.options) { + return [ + { + role: "assistant", + content: msg.options.toolCall.map((call) => ({ + type: "tool_use", + id: call.id, + name: call.name, + input: call.input as JSONObject, + })), + }, + ]; + } + if (msg.options && "toolResult" in msg.options) { + return [ + { + role: "user", + content: [ + { + type: "tool_result", + tool_use_id: msg.options.toolResult.id, + content: msg.options.toolResult.result, + }, + ], + }, + ]; + } return mapMessageContentToMessageContentDetails(msg.content).map( (detail: MessageContentDetail): AnthropicMessage => { const content = mapMessageContentDetailToAnthropicContent(detail); diff --git a/packages/llamaindex/src/agent/anthropic.ts b/packages/llamaindex/src/agent/anthropic.ts index 1f917cf3e..8f17b360d 100644 --- a/packages/llamaindex/src/agent/anthropic.ts +++ b/packages/llamaindex/src/agent/anthropic.ts @@ -1,116 +1,38 @@ -import { EngineResponse } from "../EngineResponse.js"; import { Settings } from 
"../Settings.js"; -import { - type ChatEngineParamsNonStreaming, - type ChatEngineParamsStreaming, -} from "../engines/chat/index.js"; -import { stringifyJSONToMessageContent } from "../internal/utils.js"; +import type { + ChatEngineParamsNonStreaming, + ChatEngineParamsStreaming, + EngineResponse, +} from "../index.edge.js"; import { Anthropic } from "../llm/anthropic.js"; -import { ObjectRetriever } from "../objects/index.js"; -import type { BaseToolWithCall } from "../types.js"; -import { AgentRunner, AgentWorker, type AgentParamsBase } from "./base.js"; -import type { TaskHandler } from "./types.js"; -import { callTool } from "./utils.js"; +import { LLMAgent, LLMAgentWorker, type LLMAgentParams } from "./llm.js"; -type AnthropicParamsBase = AgentParamsBase<Anthropic>; +export type AnthropicAgentParams = LLMAgentParams; -type AnthropicParamsWithTools = AnthropicParamsBase & { - tools: BaseToolWithCall[]; -}; +export class AnthropicAgentWorker extends LLMAgentWorker {} -type AnthropicParamsWithToolRetriever = AnthropicParamsBase & { - toolRetriever: ObjectRetriever<BaseToolWithCall>; -}; - -export type AnthropicAgentParams = - | AnthropicParamsWithTools - | AnthropicParamsWithToolRetriever; - -export class AnthropicAgentWorker extends AgentWorker<Anthropic> { - taskHandler = AnthropicAgent.taskHandler; -} - -export class AnthropicAgent extends AgentRunner<Anthropic> { +export class AnthropicAgent extends LLMAgent { constructor(params: AnthropicAgentParams) { + const llm = + params.llm ?? + (Settings.llm instanceof Anthropic + ? (Settings.llm as Anthropic) + : new Anthropic()); super({ - llm: - params.llm ?? - (Settings.llm instanceof Anthropic - ? (Settings.llm as Anthropic) - : new Anthropic()), - chatHistory: params.chatHistory ?? [], - systemPrompt: params.systemPrompt ?? null, - runner: new AnthropicAgentWorker(), - tools: - "tools" in params - ? params.tools - : params.toolRetriever.retrieve.bind(params.toolRetriever), - verbose: params.verbose ?? false, + ...params, + llm, }); } - createStore = AgentRunner.defaultCreateStore; - async chat(params: ChatEngineParamsNonStreaming): Promise<EngineResponse>; async chat(params: ChatEngineParamsStreaming): Promise<never>; override async chat( params: ChatEngineParamsNonStreaming | ChatEngineParamsStreaming, ) { if (params.stream) { + // Anthropic does support this, but looks like it's not supported in the LITS LLM throw new Error("Anthropic does not support streaming"); } return super.chat(params); } - - static taskHandler: TaskHandler<Anthropic> = async (step, enqueueOutput) => { - const { llm, getTools, stream } = step.context; - const lastMessage = step.context.store.messages.at(-1)!.content; - const tools = await getTools(lastMessage); - if (stream === true) { - throw new Error("Anthropic does not support streaming"); - } - const response = await llm.chat({ - stream, - tools, - messages: step.context.store.messages, - }); - step.context.store.messages = [ - ...step.context.store.messages, - response.message, - ]; - const options = response.message.options ?? 
{}; - enqueueOutput({ - taskStep: step, - output: response, - isLast: !("toolCall" in options), - }); - if ("toolCall" in options) { - const { toolCall } = options; - for (const call of toolCall) { - const targetTool = tools.find( - (tool) => tool.metadata.name === call.name, - ); - const toolOutput = await callTool( - targetTool, - call, - step.context.logger, - ); - step.context.store.toolOutputs.push(toolOutput); - step.context.store.messages = [ - ...step.context.store.messages, - { - content: stringifyJSONToMessageContent(toolOutput.output), - role: "user", - options: { - toolResult: { - result: toolOutput.output, - isError: toolOutput.isError, - id: call.id, - }, - }, - }, - ]; - } - } - }; } diff --git a/packages/llamaindex/src/agent/base.ts b/packages/llamaindex/src/agent/base.ts index 5775c8d33..1965dac7d 100644 --- a/packages/llamaindex/src/agent/base.ts +++ b/packages/llamaindex/src/agent/base.ts @@ -19,6 +19,7 @@ import type { TaskStep, TaskStepOutput, } from "./types.js"; +import { stepTools, stepToolsStreaming } from "./utils.js"; export const MAX_TOOL_CALLS = 10; @@ -214,6 +215,33 @@ export abstract class AgentRunner< return Object.create(null); } + static defaultTaskHandler: TaskHandler<LLM> = async (step, enqueueOutput) => { + const { llm, getTools, stream } = step.context; + const lastMessage = step.context.store.messages.at(-1)!.content; + const tools = await getTools(lastMessage); + const response = await llm.chat({ + // @ts-expect-error + stream, + tools, + messages: [...step.context.store.messages], + }); + if (!stream) { + await stepTools<LLM>({ + response, + tools, + step, + enqueueOutput, + }); + } else { + await stepToolsStreaming<LLM>({ + response, + tools, + step, + enqueueOutput, + }); + } + }; + protected constructor( params: AgentRunnerParams<AI, Store, AdditionalMessageOptions>, ) { diff --git a/packages/llamaindex/src/agent/index.ts b/packages/llamaindex/src/agent/index.ts index 18d6fbe95..feda11bd4 100644 --- a/packages/llamaindex/src/agent/index.ts +++ b/packages/llamaindex/src/agent/index.ts @@ -3,6 +3,8 @@ export { AnthropicAgentWorker, type AnthropicAgentParams, } from "./anthropic.js"; +export { AgentRunner, AgentWorker, type AgentParamsBase } from "./base.js"; +export { LLMAgent, LLMAgentWorker, type LLMAgentParams } from "./llm.js"; export { OpenAIAgent, OpenAIAgentWorker, @@ -13,6 +15,9 @@ export { ReActAgent, type ReACTAgentParams, } from "./react.js"; +export { type TaskHandler } from "./types.js"; +export { callTool, stepTools, stepToolsStreaming } from "./utils.js"; + // todo: ParallelAgent // todo: CustomAgent // todo: ReactMultiModal diff --git a/packages/llamaindex/src/agent/llm.ts b/packages/llamaindex/src/agent/llm.ts new file mode 100644 index 000000000..78b853649 --- /dev/null +++ b/packages/llamaindex/src/agent/llm.ts @@ -0,0 +1,45 @@ +import type { LLM } from "../llm/index.js"; +import { ObjectRetriever } from "../objects/index.js"; +import { Settings } from "../Settings.js"; +import type { BaseToolWithCall } from "../types.js"; +import { AgentRunner, AgentWorker, type AgentParamsBase } from "./base.js"; + +type LLMParamsBase = AgentParamsBase<LLM>; + +type LLMParamsWithTools = LLMParamsBase & { + tools: BaseToolWithCall[]; +}; + +type LLMParamsWithToolRetriever = LLMParamsBase & { + toolRetriever: ObjectRetriever<BaseToolWithCall>; +}; + +export type LLMAgentParams = LLMParamsWithTools | LLMParamsWithToolRetriever; + +export class LLMAgentWorker extends AgentWorker<LLM> { + taskHandler = AgentRunner.defaultTaskHandler; +} + +export 
class LLMAgent extends AgentRunner<LLM> { + constructor(params: LLMAgentParams) { + const llm = params.llm ?? (Settings.llm ? (Settings.llm as LLM) : null); + if (!llm) + throw new Error( + "llm must be provided for either in params or Settings.llm", + ); + super({ + llm, + chatHistory: params.chatHistory ?? [], + systemPrompt: params.systemPrompt ?? null, + runner: new LLMAgentWorker(), + tools: + "tools" in params + ? params.tools + : params.toolRetriever.retrieve.bind(params.toolRetriever), + verbose: params.verbose ?? false, + }); + } + + createStore = AgentRunner.defaultCreateStore; + taskHandler = AgentRunner.defaultTaskHandler; +} diff --git a/packages/llamaindex/src/agent/openai.ts b/packages/llamaindex/src/agent/openai.ts index e0ca14930..a85fb4c5a 100644 --- a/packages/llamaindex/src/agent/openai.ts +++ b/packages/llamaindex/src/agent/openai.ts @@ -1,183 +1,23 @@ -import { ReadableStream } from "@llamaindex/env"; import { Settings } from "../Settings.js"; -import { stringifyJSONToMessageContent } from "../internal/utils.js"; -import type { - ChatResponseChunk, - PartialToolCall, - ToolCall, - ToolCallLLMMessageOptions, -} from "../llm/index.js"; import { OpenAI } from "../llm/openai.js"; -import { ObjectRetriever } from "../objects/index.js"; -import type { BaseToolWithCall } from "../types.js"; -import { AgentRunner, AgentWorker, type AgentParamsBase } from "./base.js"; -import type { TaskHandler } from "./types.js"; -import { callTool } from "./utils.js"; +import { LLMAgent, LLMAgentWorker, type LLMAgentParams } from "./llm.js"; -type OpenAIParamsBase = AgentParamsBase<OpenAI>; +// This is likely not necessary anymore but leaving it here just incase it's in use elsewhere -type OpenAIParamsWithTools = OpenAIParamsBase & { - tools: BaseToolWithCall[]; -}; +export type OpenAIAgentParams = LLMAgentParams; -type OpenAIParamsWithToolRetriever = OpenAIParamsBase & { - toolRetriever: ObjectRetriever<BaseToolWithCall>; -}; +export class OpenAIAgentWorker extends LLMAgentWorker {} -export type OpenAIAgentParams = - | OpenAIParamsWithTools - | OpenAIParamsWithToolRetriever; - -export class OpenAIAgentWorker extends AgentWorker<OpenAI> { - taskHandler = OpenAIAgent.taskHandler; -} - -export class OpenAIAgent extends AgentRunner<OpenAI> { +export class OpenAIAgent extends LLMAgent { constructor(params: OpenAIAgentParams) { + const llm = + params.llm ?? + (Settings.llm instanceof OpenAI + ? (Settings.llm as OpenAI) + : new OpenAI()); super({ - llm: - params.llm ?? - (Settings.llm instanceof OpenAI - ? (Settings.llm as OpenAI) - : new OpenAI()), - chatHistory: params.chatHistory ?? [], - runner: new OpenAIAgentWorker(), - systemPrompt: params.systemPrompt ?? null, - tools: - "tools" in params - ? params.tools - : params.toolRetriever.retrieve.bind(params.toolRetriever), - verbose: params.verbose ?? false, + ...params, + llm, }); } - - createStore = AgentRunner.defaultCreateStore; - - static taskHandler: TaskHandler<OpenAI> = async (step, enqueueOutput) => { - const { llm, stream, getTools } = step.context; - const lastMessage = step.context.store.messages.at(-1)!.content; - const tools = await getTools(lastMessage); - const response = await llm.chat({ - // @ts-expect-error - stream, - tools, - messages: [...step.context.store.messages], - }); - if (!stream) { - step.context.store.messages = [ - ...step.context.store.messages, - response.message, - ]; - const options = response.message.options ?? 
{}; - enqueueOutput({ - taskStep: step, - output: response, - isLast: !("toolCall" in options), - }); - if ("toolCall" in options) { - const { toolCall } = options; - for (const call of toolCall) { - const targetTool = tools.find( - (tool) => tool.metadata.name === call.name, - ); - const toolOutput = await callTool( - targetTool, - call, - step.context.logger, - ); - step.context.store.toolOutputs.push(toolOutput); - step.context.store.messages = [ - ...step.context.store.messages, - { - role: "user" as const, - content: stringifyJSONToMessageContent(toolOutput.output), - options: { - toolResult: { - result: toolOutput.output, - isError: toolOutput.isError, - id: call.id, - }, - }, - }, - ]; - } - } - } else { - const responseChunkStream = new ReadableStream< - ChatResponseChunk<ToolCallLLMMessageOptions> - >({ - async start(controller) { - for await (const chunk of response) { - controller.enqueue(chunk); - } - controller.close(); - }, - }); - const [pipStream, finalStream] = responseChunkStream.tee(); - const reader = pipStream.getReader(); - const { value } = await reader.read(); - reader.releaseLock(); - if (value === undefined) { - throw new Error( - "first chunk value is undefined, this should not happen", - ); - } - // check if first chunk has tool calls, if so, this is a function call - // otherwise, it's a regular message - const hasToolCall = !!(value.options && "toolCall" in value.options); - enqueueOutput({ - taskStep: step, - output: finalStream, - isLast: !hasToolCall, - }); - - if (hasToolCall) { - // you need to consume the response to get the full toolCalls - const toolCalls = new Map<string, ToolCall | PartialToolCall>(); - for await (const chunk of pipStream) { - if (chunk.options && "toolCall" in chunk.options) { - const toolCall = chunk.options.toolCall; - toolCall.forEach((toolCall) => { - toolCalls.set(toolCall.id, toolCall); - }); - } - } - step.context.store.messages = [ - ...step.context.store.messages, - { - role: "assistant" as const, - content: "", - options: { - toolCall: [...toolCalls.values()], - }, - }, - ]; - for (const toolCall of toolCalls.values()) { - const targetTool = tools.find( - (tool) => tool.metadata.name === toolCall.name, - ); - const toolOutput = await callTool( - targetTool, - toolCall, - step.context.logger, - ); - step.context.store.messages = [ - ...step.context.store.messages, - { - role: "user" as const, - content: stringifyJSONToMessageContent(toolOutput.output), - options: { - toolResult: { - result: toolOutput.output, - isError: toolOutput.isError, - id: toolCall.id, - }, - }, - }, - ]; - step.context.store.toolOutputs.push(toolOutput); - } - } - } - }; } diff --git a/packages/llamaindex/src/agent/utils.ts b/packages/llamaindex/src/agent/utils.ts index e5d8614fe..8a3df6402 100644 --- a/packages/llamaindex/src/agent/utils.ts +++ b/packages/llamaindex/src/agent/utils.ts @@ -1,15 +1,160 @@ import { ReadableStream } from "@llamaindex/env"; import type { Logger } from "../internal/logger.js"; import { getCallbackManager } from "../internal/settings/CallbackManager.js"; -import { isAsyncIterable, prettifyError } from "../internal/utils.js"; +import { + isAsyncIterable, + prettifyError, + stringifyJSONToMessageContent, +} from "../internal/utils.js"; import type { ChatMessage, + ChatResponse, ChatResponseChunk, + LLM, PartialToolCall, TextChatMessage, ToolCall, + ToolCallLLMMessageOptions, } from "../llm/index.js"; import type { BaseTool, JSONObject, JSONValue, ToolOutput } from "../types.js"; +import type { TaskHandler } from 
"./types.js"; + +type StepToolsResponseParams<Model extends LLM> = { + response: ChatResponse<ToolCallLLMMessageOptions>; + tools: BaseTool[]; + step: Parameters<TaskHandler<Model, {}, ToolCallLLMMessageOptions>>[0]; + enqueueOutput: Parameters< + TaskHandler<Model, {}, ToolCallLLMMessageOptions> + >[1]; +}; + +type StepToolsStreamingResponseParams<Model extends LLM> = + StepToolsResponseParams<Model> & { + response: AsyncIterable<ChatResponseChunk<ToolCallLLMMessageOptions>>; + }; + +// #TODO stepTools and stepToolsStreaming should be moved to a better abstraction + +export async function stepToolsStreaming<Model extends LLM>({ + response, + tools, + step, + enqueueOutput, +}: StepToolsStreamingResponseParams<Model>) { + const responseChunkStream = new ReadableStream< + ChatResponseChunk<ToolCallLLMMessageOptions> + >({ + async start(controller) { + for await (const chunk of response) { + controller.enqueue(chunk); + } + controller.close(); + }, + }); + const [pipStream, finalStream] = responseChunkStream.tee(); + const reader = pipStream.getReader(); + const { value } = await reader.read(); + reader.releaseLock(); + if (value === undefined) { + throw new Error("first chunk value is undefined, this should not happen"); + } + // check if first chunk has tool calls, if so, this is a function call + // otherwise, it's a regular message + const hasToolCall = !!(value.options && "toolCall" in value.options); + enqueueOutput({ + taskStep: step, + output: finalStream, + isLast: !hasToolCall, + }); + + if (hasToolCall) { + // you need to consume the response to get the full toolCalls + const toolCalls = new Map<string, ToolCall | PartialToolCall>(); + for await (const chunk of pipStream) { + if (chunk.options && "toolCall" in chunk.options) { + const toolCall = chunk.options.toolCall; + toolCall.forEach((toolCall) => { + toolCalls.set(toolCall.id, toolCall); + }); + } + } + step.context.store.messages = [ + ...step.context.store.messages, + { + role: "assistant" as const, + content: "", + options: { + toolCall: [...toolCalls.values()], + }, + }, + ]; + for (const toolCall of toolCalls.values()) { + const targetTool = tools.find( + (tool) => tool.metadata.name === toolCall.name, + ); + const toolOutput = await callTool( + targetTool, + toolCall, + step.context.logger, + ); + step.context.store.messages = [ + ...step.context.store.messages, + { + role: "user" as const, + content: stringifyJSONToMessageContent(toolOutput.output), + options: { + toolResult: { + result: toolOutput.output, + isError: toolOutput.isError, + id: toolCall.id, + }, + }, + }, + ]; + step.context.store.toolOutputs.push(toolOutput); + } + } +} + +export async function stepTools<Model extends LLM>({ + response, + tools, + step, + enqueueOutput, +}: StepToolsResponseParams<Model>) { + step.context.store.messages = [ + ...step.context.store.messages, + response.message, + ]; + const options = response.message.options ?? 
{}; + enqueueOutput({ + taskStep: step, + output: response, + isLast: !("toolCall" in options), + }); + if ("toolCall" in options) { + const { toolCall } = options; + for (const call of toolCall) { + const targetTool = tools.find((tool) => tool.metadata.name === call.name); + const toolOutput = await callTool(targetTool, call, step.context.logger); + step.context.store.toolOutputs.push(toolOutput); + step.context.store.messages = [ + ...step.context.store.messages, + { + content: stringifyJSONToMessageContent(toolOutput.output), + role: "user", + options: { + toolResult: { + result: toolOutput.output, + isError: toolOutput.isError, + id: call.id, + }, + }, + }, + ]; + } + } +} export async function callTool( tool: BaseTool | undefined, diff --git a/packages/llamaindex/src/llm/index.ts b/packages/llamaindex/src/llm/index.ts index 123fb4fb4..fe069cedb 100644 --- a/packages/llamaindex/src/llm/index.ts +++ b/packages/llamaindex/src/llm/index.ts @@ -7,7 +7,7 @@ export { export { ToolCallLLM } from "./base.js"; export { FireworksLLM } from "./fireworks.js"; export { Gemini, GeminiSession } from "./gemini/base.js"; -export { streamConverter, wrapLLMEvent } from "./utils.js"; +export { streamConverter, streamReducer, wrapLLMEvent } from "./utils.js"; export { GEMINI_MODEL, -- GitLab
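
For reference, a rough usage sketch of what this patch enables: tool calling against Bedrock's Claude models plus the new generic `LLMAgent` driving the tool loop. This is an illustration only, not part of the diff — the `sumNumbers` tool is hypothetical, the `Bedrock` constructor options (model, region, credentials) and the exact re-exports from `@llamaindex/community` and `llamaindex` are assumed from the surrounding codebase and your own AWS setup.

```typescript
import { Bedrock, BEDROCK_MODELS } from "@llamaindex/community"; // assumed re-exports
import { FunctionTool, LLMAgent } from "llamaindex";

// Hypothetical example tool; not part of this patch.
const sumNumbers = FunctionTool.from(
  ({ a, b }: { a: number; b: number }) => `${a + b}`,
  {
    name: "sumNumbers",
    description: "Adds two numbers together",
    parameters: {
      type: "object",
      properties: {
        a: { type: "number", description: "First number" },
        b: { type: "number", description: "Second number" },
      },
      required: ["a", "b"],
    },
  },
);

async function main() {
  const llm = new Bedrock({
    // Must be one of TOOL_CALL_MODELS for supportToolCall to be true.
    model: BEDROCK_MODELS.ANTHROPIC_CLAUDE_3_HAIKU,
    region: "us-east-1", // assumed; use your own AWS region/credentials
  });

  // Option 1: call the LLM directly with tools (non-streaming).
  // The toolChoice shape follows the new ToolChoice type added in types.ts.
  const response = await llm.chat({
    messages: [{ role: "user", content: "What is 2 + 2? Use a tool." }],
    tools: [sumNumbers],
    additionalChatOptions: { toolChoice: { type: "auto" } },
  });
  // When the model requests a tool, options should contain { toolCall: [...] }.
  console.log(response.message.options);

  // Option 2: let the generic LLMAgent run the tool-calling loop
  // (chat -> tool_use -> tool_result -> final answer) using the default task handler.
  const agent = new LLMAgent({ llm, tools: [sumNumbers] });
  const result = await agent.chat({ message: "What is 2 + 2? Use a tool." });
  console.log(result.response);
}

main().catch(console.error);
```

The same `LLMAgent` path is what `AnthropicAgent` and `OpenAIAgent` now delegate to, so any `ToolCallLLM` whose `supportToolCall` is true can be swapped in for `llm` above.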