diff --git a/packages/core/e2e/fixtures/llm/open_ai.ts b/packages/core/e2e/fixtures/llm/open_ai.ts
index 76d31480737daaf3211b6ff75d51889e79d6081b..045cc085570845e3049939f12401a875740cf6ab
--- a/packages/core/e2e/fixtures/llm/open_ai.ts
+++ b/packages/core/e2e/fixtures/llm/open_ai.ts
@@ -10,7 +10,6 @@ import type {
 } from "llamaindex/llm/types";
 import { extractText } from "llamaindex/llm/utils";
 import { deepStrictEqual, strictEqual } from "node:assert";
-import { inspect } from "node:util";
 import { llmCompleteMockStorage } from "../../node/utils.js";
 
 export function getOpenAISession() {
@@ -47,8 +46,6 @@ export class OpenAI implements LLM {
     if (llmCompleteMockStorage.llmEventStart.length > 0) {
       const chatMessage =
         llmCompleteMockStorage.llmEventStart.shift()!["messages"];
-      console.log(inspect(params.messages, { depth: 1 }));
-      console.log(inspect(chatMessage, { depth: 1 }));
       strictEqual(params.messages.length, chatMessage.length);
       for (let i = 0; i < chatMessage.length; i++) {
         strictEqual(params.messages[i].role, chatMessage[i].role);
diff --git a/packages/core/e2e/node/openai.e2e.ts b/packages/core/e2e/node/openai.e2e.ts
index 4f853e5383bd07de39540d3d7f926f9dc97c73dd..5bd086915197cffd7b386abd8bd6d55dff3ea595
--- a/packages/core/e2e/node/openai.e2e.ts
+++ b/packages/core/e2e/node/openai.e2e.ts
@@ -196,6 +196,40 @@ For questions about more specific sections, please use the vector_tool.`,
   ok(extractText(response.message.content).toLowerCase().includes("no"));
 });
 
+await test("agent with object function call", async (t) => {
+  await mockLLMEvent(t, "agent_with_object_function_call");
+  await t.test("basic", async () => {
+    const agent = new OpenAIAgent({
+      tools: [
+        FunctionTool.from(
+          ({ location }: { location: string }) => ({
+            location,
+            temperature: 72,
+            weather: "cloudy",
+            rain_prediction: 0.89,
+          }),
+          {
+            name: "get_weather",
+            description: "Get the weather",
+            parameters: {
+              type: "object",
+              properties: {
+                location: { type: "string" },
+              },
+              required: ["location"],
+            },
+          },
+        ),
+      ],
+    });
+    const { response } = await agent.chat({
+      message: "What is the weather in San Francisco?",
+    });
+    consola.debug("response:", response.message.content);
+    ok(extractText(response.message.content).includes("72"));
+  });
+});
+
 await test("agent", async (t) => {
   await mockLLMEvent(t, "agent");
   await t.test("chat", async () => {
diff --git a/packages/core/e2e/node/snapshot/agent_with_object_function_call.snap b/packages/core/e2e/node/snapshot/agent_with_object_function_call.snap
new file mode 100644
index 0000000000000000000000000000000000000000..1e3f61974e0c08ab3ff5223c2c5888b09a5d38b5
--- /dev/null
+++ b/packages/core/e2e/node/snapshot/agent_with_object_function_call.snap
@@ -0,0 +1,80 @@
+{
+  "llmEventStart": [
+    {
+      "id": "PRESERVE_0",
+      "messages": [
+        {
+          "role": "user",
+          "content": "What is the weather in San Francisco?"
+        }
+      ]
+    },
+    {
+      "id": "PRESERVE_1",
+      "messages": [
+        {
+          "role": "user",
+          "content": "What is the weather in San Francisco?"
+        },
+        {
+          "content": "",
+          "role": "assistant",
+          "options": {
+            "toolCall": {
+              "id": "call_lR2r0rpfqNX11jukJvEUdByv",
+              "name": "get_weather",
+              "input": "{\"location\":\"San Francisco\"}"
+            }
+          }
+        },
+        {
+          "content": "{\n  location: San Francisco,\n  temperature: 72,\n  weather: cloudy,\n  rain_prediction: 0.89\n}",
+          "role": "user",
+          "options": {
+            "toolResult": {
+              "result": {
+                "location": "San Francisco",
+                "temperature": 72,
+                "weather": "cloudy",
+                "rain_prediction": 0.89
+              },
+              "isError": false,
+              "id": "call_lR2r0rpfqNX11jukJvEUdByv"
+            }
+          }
+        }
+      ]
+    }
+  ],
+  "llmEventEnd": [
+    {
+      "id": "PRESERVE_0",
+      "response": {
+        "raw": null,
+        "message": {
+          "content": "",
+          "role": "assistant",
+          "options": {
+            "toolCall": {
+              "id": "call_lR2r0rpfqNX11jukJvEUdByv",
+              "name": "get_weather",
+              "input": "{\"location\":\"San Francisco\"}"
+            }
+          }
+        }
+      }
+    },
+    {
+      "id": "PRESERVE_1",
+      "response": {
+        "raw": null,
+        "message": {
+          "content": "The weather in San Francisco is currently cloudy with a temperature of 72°F. There is a 89% chance of rain.",
+          "role": "assistant",
+          "options": {}
+        }
+      }
+    }
+  ],
+  "llmEventStream": []
+}
\ No newline at end of file
diff --git a/packages/core/e2e/node/snapshot/react-agent.snap b/packages/core/e2e/node/snapshot/react-agent.snap
index 48884cc45adf87768e8c373699e4ba9056c606f9..290e2f8778c6cb90b72aedf190abbaf4fa70d512
--- a/packages/core/e2e/node/snapshot/react-agent.snap
+++ b/packages/core/e2e/node/snapshot/react-agent.snap
@@ -26,7 +26,7 @@
         },
         {
           "role": "assistant",
-          "content": "Thought: I need to use a tool to help me answer the question.\nAction: getWeather\nInput: {\"city\":\"San Francisco\"}"
+          "content": "Thought: I need to use a tool to help me answer the question.\nAction: getWeather\nInput: {\n  city: San Francisco\n}"
         },
         {
          "role": "user",
@@ -41,7 +41,7 @@
       "response": {
         "raw": null,
         "message": {
-          "content": "Thought: I need to use a tool to help me answer the question. \nAction: getWeather\nAction Input: {\"city\": \"San Francisco\"}",
+          "content": "Thought: I need to use a tool to help me answer the question.\nAction: getWeather\nAction Input: {\"city\": \"San Francisco\"}",
           "role": "assistant",
           "options": {}
         }
diff --git a/packages/core/src/agent/anthropic.ts b/packages/core/src/agent/anthropic.ts
index 2d05759e4efd03debc90ee35c0334eaba1a414a4..299344767b9c2811d537d5d7f71685ec9a786ac4
--- a/packages/core/src/agent/anthropic.ts
+++ b/packages/core/src/agent/anthropic.ts
@@ -3,6 +3,7 @@ import {
   type ChatEngineParamsNonStreaming,
   type ChatEngineParamsStreaming,
 } from "../engines/chat/index.js";
+import { stringifyJSONToMessageContent } from "../internal/utils.js";
 import { Anthropic } from "../llm/anthropic.js";
 import type { ToolCallLLMMessageOptions } from "../llm/index.js";
 import { ObjectRetriever } from "../objects/index.js";
@@ -99,7 +100,7 @@ export class AnthropicAgent extends AgentRunner<Anthropic> {
       output: {
         raw: response.raw,
         message: {
-          content: toolOutput.output,
+          content: stringifyJSONToMessageContent(toolOutput.output),
           role: "user",
           options: {
             toolResult: {
diff --git a/packages/core/src/agent/openai.ts b/packages/core/src/agent/openai.ts
index c0d59b076f56072ee83f9bf17dd072b5ec8741ad..47c25325b584126a110e870bd3497a016959beb7
--- a/packages/core/src/agent/openai.ts
+++ b/packages/core/src/agent/openai.ts
@@ -1,5 +1,6 @@
 import { pipeline } from "@llamaindex/env";
 import { Settings } from "../Settings.js";
+import { stringifyJSONToMessageContent } from "../internal/utils.js";
 import type {
   ChatResponseChunk,
   ToolCall,
@@ -85,7 +86,7 @@ export class OpenAIAgent extends AgentRunner<OpenAI> {
       output: {
         raw: response.raw,
         message: {
-          content: toolOutput.output,
+          content: stringifyJSONToMessageContent(toolOutput.output),
           role: "user",
           options: {
             toolResult: {
@@ -165,7 +166,7 @@ export class OpenAIAgent extends AgentRunner<OpenAI> {
         ...step.context.store.messages,
         {
           role: "user" as const,
-          content: toolOutput.output,
+          content: stringifyJSONToMessageContent(toolOutput.output),
           options: {
             toolResult: {
               result: toolOutput.output,
diff --git a/packages/core/src/agent/react.ts b/packages/core/src/agent/react.ts
index 1201c93e2321fc844e1e62ed2056d728c26e5470..8c9ec759cd66bedf343a8591b345b6044eac30a6
--- a/packages/core/src/agent/react.ts
+++ b/packages/core/src/agent/react.ts
@@ -1,7 +1,10 @@
 import { pipeline, randomUUID } from "@llamaindex/env";
 import { Settings } from "../Settings.js";
 import { getReACTAgentSystemHeader } from "../internal/prompt/react.js";
-import { isAsyncIterable } from "../internal/utils.js";
+import {
+  isAsyncIterable,
+  stringifyJSONToMessageContent,
+} from "../internal/utils.js";
 import {
   type ChatMessage,
   type ChatResponse,
@@ -10,7 +13,12 @@ import {
 } from "../llm/index.js";
 import { extractText } from "../llm/utils.js";
 import { ObjectRetriever } from "../objects/index.js";
-import type { BaseTool, BaseToolWithCall } from "../types.js";
+import type {
+  BaseTool,
+  BaseToolWithCall,
+  JSONObject,
+  JSONValue,
+} from "../types.js";
 import {
   AgentRunner,
   AgentWorker,
@@ -43,14 +51,14 @@ type BaseReason = {
 
 type ObservationReason = BaseReason & {
   type: "observation";
-  observation: string;
+  observation: JSONValue;
 };
 
 type ActionReason = BaseReason & {
   type: "action";
   thought: string;
   action: string;
-  input: Record<string, unknown>;
+  input: JSONObject;
 };
 
 type ResponseReason = BaseReason & {
@@ -64,9 +72,9 @@ type Reason = ObservationReason | ActionReason | ResponseReason;
 function reasonFormatter(reason: Reason): string | Promise<string> {
   switch (reason.type) {
     case "observation":
-      return `Observation: ${reason.observation}`;
+      return `Observation: ${stringifyJSONToMessageContent(reason.observation)}`;
     case "action":
-      return `Thought: ${reason.thought}\nAction: ${reason.action}\nInput: ${JSON.stringify(
+      return `Thought: ${reason.thought}\nAction: ${reason.action}\nInput: ${stringifyJSONToMessageContent(
         reason.input,
       )}`;
     case "response": {
@@ -133,7 +141,7 @@ function extractToolUse(
   return [thought, action, actionInput];
 }
 
-function actionInputParser(jsonStr: string): Record<string, unknown> {
+function actionInputParser(jsonStr: string): JSONObject {
   const processedString = jsonStr.replace(/(?<!\w)'|'(?!\w)/g, '"');
   const pattern = /"(\w+)":\s*"([^"]*)"/g;
   const matches = [...processedString.matchAll(pattern)];
@@ -172,7 +180,7 @@ const reACTOutputParser: ReACTOutputParser = async (
   const { content } = response;
   const [thought, action, input] = extractToolUse(content);
   const jsonStr = extractJsonStr(input);
-  let json: Record<string, unknown>;
+  let json: JSONObject;
   try {
     json = JSON.parse(jsonStr);
   } catch (e) {
@@ -230,7 +238,7 @@ const reACTOutputParser: ReACTOutputParser = async (
     case "action": {
       const [thought, action, input] = extractToolUse(content);
       const jsonStr = extractJsonStr(input);
-      let json: Record<string, unknown>;
+      let json: JSONObject;
       try {
         json = JSON.parse(jsonStr);
       } catch (e) {
diff --git a/packages/core/src/agent/utils.ts b/packages/core/src/agent/utils.ts
index 08a601edfdba9b2a31be406587c71ec27443148f..2cb6117ecdd9415b370125f46ebde703ca9afdba
--- a/packages/core/src/agent/utils.ts
+++ b/packages/core/src/agent/utils.ts
@@ -6,7 +6,7 @@ import type {
   TextChatMessage,
   ToolCall,
 } from "../llm/index.js";
-import type { BaseTool, ToolOutput } from "../types.js";
+import type { BaseTool, JSONValue, ToolOutput } from "../types.js";
 
 export async function callTool(
   tool: BaseTool | undefined,
@@ -22,7 +22,7 @@ export async function callTool(
     };
   }
   const call = tool.call;
-  let output: string;
+  let output: JSONValue;
   if (!call) {
     output = `Tool ${tool.metadata.name} (remote:${toolCall.name}) does not have a implementation.`;
     return {
diff --git a/packages/core/src/internal/utils.ts b/packages/core/src/internal/utils.ts
index d142bf1db887c6745c73490fa25550e3ea392f6f..6fe3e816671702c9e15820d8ef0a20c03744c2f0
--- a/packages/core/src/internal/utils.ts
+++ b/packages/core/src/internal/utils.ts
@@ -1,3 +1,5 @@
+import type { JSONValue } from "../types.js";
+
 export const isAsyncIterable = (
   obj: unknown,
 ): obj is AsyncIterable<unknown> => {
@@ -18,3 +20,7 @@ export function prettifyError(error: unknown): string {
     return `${error}`;
   }
 }
+
+export function stringifyJSONToMessageContent(value: JSONValue): string {
+  return JSON.stringify(value, null, 2).replace(/"([^"]*)"/g, "$1");
+}
diff --git a/packages/core/src/llm/open_ai.ts b/packages/core/src/llm/open_ai.ts
index 78590ec89b7770451ded4d791fe1e5b0c67d4276..af450826c4a3ffc4c340b04b351d373d7c76da79
--- a/packages/core/src/llm/open_ai.ts
+++ b/packages/core/src/llm/open_ai.ts
@@ -38,8 +38,8 @@ import type {
   LLMChatParamsStreaming,
   LLMMetadata,
   MessageType,
+  ToolCall,
   ToolCallLLMMessageOptions,
-  ToolCallOptions,
 } from "./types.js";
 import { extractText, wrapLLMEvent } from "./utils.js";
 
@@ -394,23 +394,45 @@ export class OpenAI extends ToolCallLLM<OpenAIAdditionalChatOptions> {
     // TODO: add callback to streamConverter and use streamConverter here
     //Indices
     let idxCounter: number = 0;
-    let toolCallOptions: ToolCallOptions | null = null;
+    // used to keep track of the current tool call; make sure its input is a valid JSON object
+    let currentToolCall:
+      | (Omit<ToolCall, "input"> & {
+          input: string;
+        })
+      | null = null;
+    const toolCallMap = new Map<
+      string,
+      Omit<ToolCall, "input"> & {
+        input: string;
+      }
+    >();
     for await (const part of stream) {
-      if (!part.choices.length) continue;
+      if (part.choices.length === 0) continue;
       const choice = part.choices[0];
       // skip parts that don't have any content
       if (!(choice.delta.content || choice.delta.tool_calls)) continue;
+
+      let shouldEmitToolCall: ToolCall | null = null;
+      if (
+        choice.delta.tool_calls?.[0].id &&
+        currentToolCall &&
+        choice.delta.tool_calls?.[0].id !== currentToolCall.id
+      ) {
+        shouldEmitToolCall = {
+          ...currentToolCall,
+          input: JSON.parse(currentToolCall.input),
+        };
+      }
       if (choice.delta.tool_calls?.[0].id) {
-        toolCallOptions = {
-          toolCall: {
-            name: choice.delta.tool_calls[0].function!.name!,
-            id: choice.delta.tool_calls[0].id,
-            input: choice.delta.tool_calls[0].function!.arguments!,
-          },
+        currentToolCall = {
+          name: choice.delta.tool_calls[0].function!.name!,
+          id: choice.delta.tool_calls[0].id,
+          input: choice.delta.tool_calls[0].function!.arguments!,
         };
+        toolCallMap.set(choice.delta.tool_calls[0].id, currentToolCall);
       } else {
         if (choice.delta.tool_calls?.[0].function?.arguments) {
-          toolCallOptions!.toolCall.input +=
+          currentToolCall!.input +=
             choice.delta.tool_calls[0].function.arguments;
         }
       }
@@ -423,12 +445,21 @@ export class OpenAI extends ToolCallLLM<OpenAIAdditionalChatOptions> {
         token: part,
       });
 
+      if (isDone && currentToolCall) {
+        // the stream is done, so emit the final accumulated tool call
+        shouldEmitToolCall = {
+          ...currentToolCall,
+          input: JSON.parse(currentToolCall.input),
+        };
+      }
+
       yield {
         raw: part,
-        options: toolCallOptions ? toolCallOptions : {},
+        options: shouldEmitToolCall ? { toolCall: shouldEmitToolCall } : {},
         delta: choice.delta.content ?? "",
       };
     }
+    toolCallMap.clear();
     return;
   }
 
diff --git a/packages/core/src/llm/types.ts b/packages/core/src/llm/types.ts
index 21126953635d96fed9fe75cf06fbee4aac5f3703..204be288f5fa242b3034fa5fdf5c2409ca47acb8
--- a/packages/core/src/llm/types.ts
+++ b/packages/core/src/llm/types.ts
@@ -1,6 +1,6 @@
 import type { Tokenizers } from "../GlobalsHelper.js";
 import type { BaseEvent } from "../internal/type.js";
-import type { BaseTool, ToolOutput, UUID } from "../types.js";
+import type { BaseTool, JSONObject, ToolOutput, UUID } from "../types.js";
 
 export type LLMStartEvent = BaseEvent<{
   id: UUID;
@@ -185,9 +185,7 @@ export type MessageContent = string | MessageContentDetail[];
 
 export type ToolCall = {
   name: string;
-  // for now, claude-3-opus will give object, gpt-3/4 will give string
-  // todo: unify this to always be an object
-  input: unknown;
+  input: JSONObject;
   id: string;
 };
 
diff --git a/packages/core/src/tools/functionTool.ts b/packages/core/src/tools/functionTool.ts
index cbd8a28e3bfdff4c9f0853b1050a31b90c75d363..128d429cd5fbe78073a7eb44df45b761a78125b9
--- a/packages/core/src/tools/functionTool.ts
+++ b/packages/core/src/tools/functionTool.ts
@@ -1,7 +1,7 @@
 import type { JSONSchemaType } from "ajv";
-import type { BaseTool, ToolMetadata } from "../types.js";
+import type { BaseTool, JSONValue, ToolMetadata } from "../types.js";
 
-export class FunctionTool<T, R extends string | Promise<string>>
+export class FunctionTool<T, R extends JSONValue | Promise<JSONValue>>
   implements BaseTool<T>
 {
   constructor(
@@ -10,10 +10,10 @@ export class FunctionTool<T, R extends string | Promise<string>>
   ) {}
 
   static from<T>(
-    fn: (input: T) => string | Promise<string>,
+    fn: (input: T) => JSONValue | Promise<JSONValue>,
     schema: ToolMetadata<JSONSchemaType<T>>,
-  ): FunctionTool<T, string | Promise<string>>;
-  static from<T, R extends string | Promise<string>>(
+  ): FunctionTool<T, JSONValue | Promise<JSONValue>>;
+  static from<T, R extends JSONValue | Promise<JSONValue>>(
     fn: (input: T) => R,
     schema: ToolMetadata<JSONSchemaType<T>>,
   ): FunctionTool<T, R> {
diff --git a/packages/core/src/types.ts b/packages/core/src/types.ts
index 17433e08a2ea02359817045d58ee59e4b4672d37..7e95ea70904da240955c0ddcd5cc2097c8ea012f
--- a/packages/core/src/types.ts
+++ b/packages/core/src/types.ts
@@ -60,9 +60,9 @@ export interface BaseTool<Input = any> {
    * This could be undefined if the implementation is not provided,
    * which might be the case when communicating with a llm.
    *
-   * @return string - the output of the tool, should be string in any case for LLM input.
+   * @return {JSONValue | Promise<JSONValue>} The output of the tool.
    */
-  call?: (input: Input) => string | Promise<string>;
+  call?: (input: Input) => JSONValue | Promise<JSONValue>;
   metadata: // if user input any, we cannot check the schema
     Input extends Known ? ToolMetadata<JSONSchemaType<Input>> : ToolMetadata;
 }
@@ -104,9 +104,18 @@ export class QueryBundle {
 
 export type UUID = `${string}-${string}-${string}-${string}-${string}`;
 
+export type JSONValue = string | number | boolean | JSONObject | JSONArray;
+
+export type JSONObject = {
+  [key: string]: JSONValue;
+};
+
+type JSONArray = Array<JSONValue>;
+
 export type ToolOutput = {
   tool: BaseTool | undefined;
-  input: unknown;
-  output: string;
+  // all existing function-calling LLMs only support object input
+  input: JSONObject;
+  output: JSONValue;
   isError: boolean;
 };