From 61103b677bd3b29c634b549aa7847a120a77e3e8 Mon Sep 17 00:00:00 2001
From: Alex Yang <himself65@outlook.com>
Date: Wed, 1 May 2024 19:26:06 -0500
Subject: [PATCH] fix: streaming for `Agent.createTask` (#788)

---
 .changeset/four-ears-nail.md         |   6 ++
 apps/docs/blog/2024-04-26-v0.3.0.md  |  26 ++-----
 examples/agent/openai-task.ts        |  87 +++++++++++++++++++++
 examples/agent/step_wise_openai.ts   |  97 -----------------------
 packages/core/src/agent/anthropic.ts |  43 ++++------
 packages/core/src/agent/base.ts      | 112 +++++++++++++--------------
 packages/core/src/agent/openai.ts    |  59 ++++++--------
 packages/core/src/agent/react.ts     |  55 ++++++-------
 packages/core/src/agent/types.ts     |  30 +++----
 9 files changed, 227 insertions(+), 288 deletions(-)
 create mode 100644 .changeset/four-ears-nail.md
 create mode 100644 examples/agent/openai-task.ts
 delete mode 100644 examples/agent/step_wise_openai.ts

diff --git a/.changeset/four-ears-nail.md b/.changeset/four-ears-nail.md
new file mode 100644
index 000000000..3825a41d9
--- /dev/null
+++ b/.changeset/four-ears-nail.md
@@ -0,0 +1,6 @@
+---
+"llamaindex": patch
+"@llamaindex/core-e2e": patch
+---
+
+fix: streaming for `Agent.createTask` API
diff --git a/apps/docs/blog/2024-04-26-v0.3.0.md b/apps/docs/blog/2024-04-26-v0.3.0.md
index 6db63f00f..a1be2a4cc 100644
--- a/apps/docs/blog/2024-04-26-v0.3.0.md
+++ b/apps/docs/blog/2024-04-26-v0.3.0.md
@@ -72,12 +72,8 @@ export class MyAgent extends AgentRunner<MyLLM> {
   // create store is a function to create a store for each task, by default it only includes `messages` and `toolOutputs`
   createStore = AgentRunner.defaultCreateStore;
 
-  static taskHandler: TaskHandler<Anthropic> = async (step) => {
-    const { input } = step;
+  static taskHandler: TaskHandler<Anthropic> = async (step, enqueueOutput) => {
     const { llm, stream } = step.context;
-    if (input) {
-      step.context.store.messages = [...step.context.store.messages, input];
-    }
     // initialize the input
     const response = await llm.chat({
       stream,
@@ -90,27 +86,21 @@ export class MyAgent extends AgentRunner<MyLLM> {
     ];
     // your logic here to decide whether to continue the task
     const shouldContinue = Math.random(); /* <-- replace with your logic here */
+    enqueueOutput({
+      taskStep: step,
+      output: response,
+      isLast: !shouldContinue,
+    });
     if (shouldContinue) {
+      const content = await someHeavyFunctionCall();
       // if you want to continue the task, you can insert your new context for the next task step
       step.context.store.messages = [
         ...step.context.store.messages,
         {
-          content: "INSERT MY NEW DATA",
+          content,
           role: "user",
         },
       ];
-      return {
-        taskStep: step,
-        output: response,
-        isLast: false,
-      };
-    } else {
-      // if you want to end the task, you can return the response with `isLast: true`
-      return {
-        taskStep: step,
-        output: response,
-        isLast: true,
-      };
     }
   };
 }
diff --git a/examples/agent/openai-task.ts b/examples/agent/openai-task.ts
new file mode 100644
index 000000000..09e632278
--- /dev/null
+++ b/examples/agent/openai-task.ts
@@ -0,0 +1,87 @@
+import { ChatResponseChunk, FunctionTool, OpenAIAgent } from "llamaindex";
+import { ReadableStream } from "node:stream/web";
+
+const functionTool = FunctionTool.from(
+  () => {
+    console.log("Getting user id...");
+    return crypto.randomUUID();
+  },
+  {
+    name: "get_user_id",
+    description: "Get a random user id",
+  },
+);
+
+const functionTool2 = FunctionTool.from(
+  ({ userId }: { userId: string }) => {
+    console.log("Getting user info...", userId);
+    return `Name: Alex; Address: 1234 Main St, CA; User ID: ${userId}`;
+  },
+  {
name: "get_user_info", + description: "Get user info", + parameters: { + type: "object", + properties: { + userId: { + type: "string", + description: "The user id", + }, + }, + required: ["userId"], + }, + }, +); + +const functionTool3 = FunctionTool.from( + ({ address }: { address: string }) => { + console.log("Getting weather...", address); + return `${address} is in a sunny location!`; + }, + { + name: "get_weather", + description: "Get the current weather for a location", + parameters: { + type: "object", + properties: { + address: { + type: "string", + description: "The address", + }, + }, + required: ["address"], + }, + }, +); + +async function main() { + // Create an OpenAIAgent with the function tools + const agent = new OpenAIAgent({ + tools: [functionTool, functionTool2, functionTool3], + }); + + const task = await agent.createTask( + "What is my current address weather based on my profile?", + true, + ); + + for await (const stepOutput of task) { + const stream = stepOutput.output as ReadableStream<ChatResponseChunk>; + if (stepOutput.isLast) { + for await (const chunk of stream) { + process.stdout.write(chunk.delta); + } + process.stdout.write("\n"); + } else { + // handing function call + console.log("handling function call..."); + for await (const chunk of stream) { + console.log("debug:", JSON.stringify(chunk.raw)); + } + } + } +} + +void main().then(() => { + console.log("Done"); +}); diff --git a/examples/agent/step_wise_openai.ts b/examples/agent/step_wise_openai.ts deleted file mode 100644 index 74c083400..000000000 --- a/examples/agent/step_wise_openai.ts +++ /dev/null @@ -1,97 +0,0 @@ -import { FunctionTool, OpenAIAgent } from "llamaindex"; -import { ReadableStream } from "node:stream/web"; - -// Define a function to sum two numbers -function sumNumbers({ a, b }: { a: number; b: number }) { - return `${a + b}`; -} - -// Define a function to divide two numbers -function divideNumbers({ a, b }: { a: number; b: number }) { - return `${a / b}`; -} - -// Define the parameters of the sum function as a JSON schema -const sumJSON = { - type: "object", - properties: { - a: { - type: "number", - description: "The first number", - }, - b: { - type: "number", - description: "The second number", - }, - }, - required: ["a", "b"], -} as const; - -const divideJSON = { - type: "object", - properties: { - a: { - type: "number", - description: "The dividend", - }, - b: { - type: "number", - description: "The divisor", - }, - }, - required: ["a", "b"], -} as const; - -async function main() { - // Create a function tool from the sum function - const functionTool = new FunctionTool(sumNumbers, { - name: "sumNumbers", - description: "Use this function to sum two numbers", - parameters: sumJSON, - }); - - // Create a function tool from the divide function - const functionTool2 = new FunctionTool(divideNumbers, { - name: "divideNumbers", - description: "Use this function to divide two numbers", - parameters: divideJSON, - }); - - // Create an OpenAIAgent with the function tools - const agent = new OpenAIAgent({ - tools: [functionTool, functionTool2], - }); - - // Create a task to sum and divide numbers - const task = await agent.createTask("How much is 5 + 5? 
then divide by 2"); - - let count = 0; - - for await (const stepOutput of task) { - console.log(`Runnning step ${count++}`); - console.log(`======== OUTPUT ==========`); - const output = stepOutput.output; - if (output instanceof ReadableStream) { - for await (const chunk of output) { - process.stdout.write(chunk.delta); - } - } else { - console.log(output); - } - console.log(`==========================`); - - if (stepOutput.isLast) { - if (stepOutput.output instanceof ReadableStream) { - for await (const chunk of stepOutput.output) { - process.stdout.write(chunk.delta); - } - } else { - console.log(stepOutput.output); - } - } - } -} - -void main().then(() => { - console.log("Done"); -}); diff --git a/packages/core/src/agent/anthropic.ts b/packages/core/src/agent/anthropic.ts index 231795d89..accc87fca 100644 --- a/packages/core/src/agent/anthropic.ts +++ b/packages/core/src/agent/anthropic.ts @@ -67,12 +67,8 @@ export class AnthropicAgent extends AgentRunner<Anthropic> { return super.chat(params); } - static taskHandler: TaskHandler<Anthropic> = async (step) => { - const { input } = step; + static taskHandler: TaskHandler<Anthropic> = async (step, enqueueOutput) => { const { llm, getTools, stream } = step.context; - if (input) { - step.context.store.messages = [...step.context.store.messages, input]; - } const lastMessage = step.context.store.messages.at(-1)!.content; const tools = await getTools(lastMessage); if (stream === true) { @@ -88,6 +84,11 @@ export class AnthropicAgent extends AgentRunner<Anthropic> { response.message, ]; const options = response.message.options ?? {}; + enqueueOutput({ + taskStep: step, + output: response, + isLast: !("toolCall" in options), + }); if ("toolCall" in options) { const { toolCall } = options; const targetTool = tools.find( @@ -95,30 +96,20 @@ export class AnthropicAgent extends AgentRunner<Anthropic> { ); const toolOutput = await callTool(targetTool, toolCall); step.context.store.toolOutputs.push(toolOutput); - return { - taskStep: step, - output: { - raw: response.raw, - message: { - content: stringifyJSONToMessageContent(toolOutput.output), - role: "user", - options: { - toolResult: { - result: toolOutput.output, - isError: toolOutput.isError, - id: toolCall.id, - }, + step.context.store.messages = [ + ...step.context.store.messages, + { + content: stringifyJSONToMessageContent(toolOutput.output), + role: "user", + options: { + toolResult: { + result: toolOutput.output, + isError: toolOutput.isError, + id: toolCall.id, }, }, }, - isLast: false, - }; - } else { - return { - taskStep: step, - output: response, - isLast: true, - }; + ]; } }; } diff --git a/packages/core/src/agent/base.ts b/packages/core/src/agent/base.ts index 57d6d9e42..e0e14f4f5 100644 --- a/packages/core/src/agent/base.ts +++ b/packages/core/src/agent/base.ts @@ -27,14 +27,10 @@ import type { TaskStep, TaskStepOutput, } from "./types.js"; -import { consumeAsyncIterable } from "./utils.js"; export const MAX_TOOL_CALLS = 10; -/** - * @internal - */ -export async function* createTaskImpl< +export function createTaskOutputStream< Model extends LLM, Store extends object = {}, AdditionalMessageOptions extends object = Model extends LLM< @@ -46,65 +42,60 @@ export async function* createTaskImpl< >( handler: TaskHandler<Model, Store, AdditionalMessageOptions>, context: AgentTaskContext<Model, Store, AdditionalMessageOptions>, - _input: ChatMessage<AdditionalMessageOptions>, -): AsyncGenerator<TaskStepOutput<Model, Store, AdditionalMessageOptions>> { - let isFirst = true; - let isDone = 
-  let input: ChatMessage<AdditionalMessageOptions> | null = _input;
-  let prevStep: TaskStep<Model, Store, AdditionalMessageOptions> | null = null;
-  while (!isDone) {
-    const step: TaskStep<Model, Store, AdditionalMessageOptions> = {
-      id: randomUUID(),
-      input,
-      context,
-      prevStep,
-      nextSteps: new Set(),
-    };
-    if (prevStep) {
-      prevStep.nextSteps.add(step);
-    }
-    const prevToolCallCount = step.context.toolCallCount;
-    if (!step.context.shouldContinue(step)) {
-      throw new Error("Tool call count exceeded limit");
-    }
-    if (isFirst) {
+): ReadableStream<TaskStepOutput<Model, Store, AdditionalMessageOptions>> {
+  const steps: TaskStep<Model, Store, AdditionalMessageOptions>[] = [];
+  return new ReadableStream<
+    TaskStepOutput<Model, Store, AdditionalMessageOptions>
+  >({
+    pull: async (controller) => {
+      const step: TaskStep<Model, Store, AdditionalMessageOptions> = {
+        id: randomUUID(),
+        context,
+        prevStep: null,
+        nextSteps: new Set(),
+      };
+      if (steps.length > 0) {
+        step.prevStep = steps[steps.length - 1];
+      }
+      const taskOutputs: TaskStepOutput<
+        Model,
+        Store,
+        AdditionalMessageOptions
+      >[] = [];
+      steps.push(step);
+      const enqueueOutput = (
+        output: TaskStepOutput<Model, Store, AdditionalMessageOptions>,
+      ) => {
+        taskOutputs.push(output);
+        controller.enqueue(output);
+      };
       getCallbackManager().dispatchEvent("agent-start", {
         payload: {
           startStep: step,
         },
       });
-      isFirst = false;
-    }
-    const taskOutput = await handler(step);
-    const { isLast, output, taskStep } = taskOutput;
-    // do not consume last output
-    if (!isLast) {
-      if (output) {
-        input = isAsyncIterable(output)
-          ? await consumeAsyncIterable(output)
-          : output.message;
-      } else {
-        input = null;
-      }
-    }
-    context = {
-      ...taskStep.context,
-      store: {
-        ...taskStep.context.store,
-      },
-      toolCallCount: prevToolCallCount + 1,
-    };
-    if (isLast) {
-      isDone = true;
-      getCallbackManager().dispatchEvent("agent-end", {
-        payload: {
-          endStep: step,
+
+      await handler(step, enqueueOutput);
+      // fixme: support multi-thread when there are multiple outputs
+      // todo: for now we pretend there is only one task output
+      const { isLast, taskStep } = taskOutputs[0];
+      context = {
+        ...taskStep.context,
+        store: {
+          ...taskStep.context.store,
        },
-      });
-    }
-    prevStep = taskStep;
-    yield taskOutput;
-  }
+        toolCallCount: 1,
+      };
+      if (isLast) {
+        getCallbackManager().dispatchEvent("agent-end", {
+          payload: {
+            endStep: step,
+          },
+        });
+        controller.close();
+      }
+    },
+  });
 }
 
 export type AgentStreamChatResponse<Options extends object> = {
@@ -170,15 +161,16 @@
     query: string,
     context: AgentTaskContext<AI, Store, AdditionalMessageOptions>,
   ): ReadableStream<TaskStepOutput<AI, Store, AdditionalMessageOptions>> {
-    const taskGenerator = createTaskImpl(this.taskHandler, context, {
+    context.store.messages.push({
       role: "user",
       content: query,
     });
+    const taskOutputStream = createTaskOutputStream(this.taskHandler, context);
     return new ReadableStream<
       TaskStepOutput<AI, Store, AdditionalMessageOptions>
     >({
       start: async (controller) => {
-        for await (const stepOutput of taskGenerator) {
+        for await (const stepOutput of taskOutputStream) {
           this.#taskSet.add(stepOutput.taskStep);
           controller.enqueue(stepOutput);
           if (stepOutput.isLast) {
diff --git a/packages/core/src/agent/openai.ts b/packages/core/src/agent/openai.ts
index 6e569c9a7..f7146665d 100644
--- a/packages/core/src/agent/openai.ts
+++ b/packages/core/src/agent/openai.ts
@@ -51,12 +51,8 @@ export class OpenAIAgent extends AgentRunner<OpenAI> {
 
   createStore = AgentRunner.defaultCreateStore;
 
-  static taskHandler: TaskHandler<OpenAI> = async (step) => {
-    const { input } = step;
+  static taskHandler: TaskHandler<OpenAI> = async (step, enqueueOutput) => {
     const { llm, stream, getTools } = step.context;
-    if (input) {
-      step.context.store.messages = [...step.context.store.messages, input];
-    }
     const lastMessage = step.context.store.messages.at(-1)!.content;
     const tools = await getTools(lastMessage);
     const response = await llm.chat({
@@ -71,6 +67,11 @@ export class OpenAIAgent extends AgentRunner<OpenAI> {
         response.message,
       ];
       const options = response.message.options ?? {};
+      enqueueOutput({
+        taskStep: step,
+        output: response,
+        isLast: !("toolCall" in options),
+      });
       if ("toolCall" in options) {
         const { toolCall } = options;
         const targetTool = tools.find(
@@ -78,30 +79,20 @@ export class OpenAIAgent extends AgentRunner<OpenAI> {
         );
         const toolOutput = await callTool(targetTool, toolCall);
         step.context.store.toolOutputs.push(toolOutput);
-        return {
-          taskStep: step,
-          output: {
-            raw: response.raw,
-            message: {
-              content: stringifyJSONToMessageContent(toolOutput.output),
-              role: "user",
-              options: {
-                toolResult: {
-                  result: toolOutput.output,
-                  isError: toolOutput.isError,
-                  id: toolCall.id,
-                },
+        step.context.store.messages = [
+          ...step.context.store.messages,
+          {
+            role: "user" as const,
+            content: stringifyJSONToMessageContent(toolOutput.output),
+            options: {
+              toolResult: {
+                result: toolOutput.output,
+                isError: toolOutput.isError,
+                id: toolCall.id,
              },
            },
          },
-          isLast: false,
-        };
-      } else {
-        return {
-          taskStep: step,
-          output: response,
-          isLast: true,
-        };
+        ];
       }
     } else {
       const responseChunkStream = new ReadableStream<
@@ -126,6 +117,11 @@ export class OpenAIAgent extends AgentRunner<OpenAI> {
 
       // check if first chunk has tool calls, if so, this is a function call
      // otherwise, it's a regular message
       const hasToolCall = !!(value.options && "toolCall" in value.options);
+      enqueueOutput({
+        taskStep: step,
+        output: finalStream,
+        isLast: !hasToolCall,
+      });
       if (hasToolCall) {
         // you need to consume the response to get the full toolCalls
@@ -175,17 +171,6 @@ export class OpenAIAgent extends AgentRunner<OpenAI> {
           ];
           step.context.store.toolOutputs.push(toolOutput);
         }
-        return {
-          taskStep: step,
-          output: null,
-          isLast: false,
-        };
-      } else {
-        return {
-          taskStep: step,
-          output: finalStream,
-          isLast: true,
-        };
       }
     }
   };
diff --git a/packages/core/src/agent/react.ts b/packages/core/src/agent/react.ts
index 651a30781..27b48d5b1 100644
--- a/packages/core/src/agent/react.ts
+++ b/packages/core/src/agent/react.ts
@@ -349,12 +349,11 @@ export class ReActAgent extends AgentRunner<LLM, ReACTAgentStore> {
     };
   }
 
-  static taskHandler: TaskHandler<LLM, ReACTAgentStore> = async (step) => {
+  static taskHandler: TaskHandler<LLM, ReACTAgentStore> = async (
+    step,
+    enqueueOutput,
+  ) => {
     const { llm, stream, getTools } = step.context;
-    const input = step.input;
-    if (input) {
-      step.context.store.messages.push(input);
-    }
     const lastMessage = step.context.store.messages.at(-1)!.content;
     const tools = await getTools(lastMessage);
     const messages = await chatFormatter(
@@ -369,33 +368,25 @@ export class ReActAgent extends AgentRunner<LLM, ReACTAgentStore> {
     });
     const reason = await reACTOutputParser(response);
     step.context.store.reasons = [...step.context.store.reasons, reason];
-    if (reason.type === "response") {
-      return {
-        isLast: true,
-        output: response,
-        taskStep: step,
-      };
-    } else {
-      if (reason.type === "action") {
-        const tool = tools.find((tool) => tool.metadata.name === reason.action);
-        const toolOutput = await callTool(tool, {
-          id: randomUUID(),
-          input: reason.input,
-          name: reason.action,
-        });
-        step.context.store.reasons = [
-          ...step.context.store.reasons,
-          {
-            type: "observation",
-            observation: toolOutput.output,
-          },
-        ];
-      }
-      return {
-        isLast: false,
-        output: null,
-        taskStep: step,
-      };
+    enqueueOutput({
+      taskStep: step,
+      output: response,
+      isLast: reason.type === "response",
+    });
+    if (reason.type === "action") {
+      const tool = tools.find((tool) => tool.metadata.name === reason.action);
+      const toolOutput = await callTool(tool, {
+        id: randomUUID(),
+        input: reason.input,
+        name: reason.action,
+      });
+      step.context.store.reasons = [
+        ...step.context.store.reasons,
+        {
+          type: "observation",
+          observation: toolOutput.output,
+        },
+      ];
     }
   };
 }
diff --git a/packages/core/src/agent/types.ts b/packages/core/src/agent/types.ts
index f523a047d..b3d48374c 100644
--- a/packages/core/src/agent/types.ts
+++ b/packages/core/src/agent/types.ts
@@ -45,7 +45,6 @@ export type TaskStep<
     : never,
 > = {
   id: UUID;
-  input: ChatMessage<AdditionalMessageOptions> | null;
   context: AgentTaskContext<Model, Store, AdditionalMessageOptions>;
 
   // linked list
@@ -62,22 +61,14 @@ export type TaskStepOutput<
   >
     ? AdditionalMessageOptions
     : never,
-> =
-  | {
-      taskStep: TaskStep<Model, Store, AdditionalMessageOptions>;
-      output:
-        | null
-        | ChatResponse<AdditionalMessageOptions>
-        | ReadableStream<ChatResponseChunk<AdditionalMessageOptions>>;
-      isLast: false;
-    }
-  | {
-      taskStep: TaskStep<Model, Store, AdditionalMessageOptions>;
-      output:
-        | ChatResponse<AdditionalMessageOptions>
-        | ReadableStream<ChatResponseChunk<AdditionalMessageOptions>>;
-      isLast: true;
-    };
+> = {
+  taskStep: TaskStep<Model, Store, AdditionalMessageOptions>;
+  // output shows the response to the user
+  output:
+    | ChatResponse<AdditionalMessageOptions>
+    | ReadableStream<ChatResponseChunk<AdditionalMessageOptions>>;
+  isLast: boolean;
+};
 
 export type TaskHandler<
   Model extends LLM,
@@ -90,7 +81,10 @@ export type TaskHandler<
     : never,
 > = (
   step: TaskStep<Model, Store, AdditionalMessageOptions>,
-) => Promise<TaskStepOutput<Model, Store, AdditionalMessageOptions>>;
+  enqueueOutput: (
+    taskOutput: TaskStepOutput<Model, Store, AdditionalMessageOptions>,
+  ) => void,
+) => Promise<void>;
 
 export type AgentStartEvent = BaseEvent<{
   startStep: TaskStep;
-- 
GitLab