diff --git a/.changeset/khaki-rivers-unite.md b/.changeset/khaki-rivers-unite.md new file mode 100644 index 0000000000000000000000000000000000000000..d09ba4b1276803a4846ae45ccd89f783b9f9f86b --- /dev/null +++ b/.changeset/khaki-rivers-unite.md @@ -0,0 +1,5 @@ +--- +"llamaindex": patch +--- + +Add streaming to agents diff --git a/examples/agent/stream_openai_agent.ts b/examples/agent/stream_openai_agent.ts new file mode 100644 index 0000000000000000000000000000000000000000..3693381f32d7fc75c313a42c94f06ae69ed9316b --- /dev/null +++ b/examples/agent/stream_openai_agent.ts @@ -0,0 +1,77 @@ +import { FunctionTool, OpenAIAgent } from "llamaindex"; + +// Define a function to sum two numbers +function sumNumbers({ a, b }: { a: number; b: number }): number { + return a + b; +} + +// Define a function to divide two numbers +function divideNumbers({ a, b }: { a: number; b: number }): number { + return a / b; +} + +// Define the parameters of the sum function as a JSON schema +const sumJSON = { + type: "object", + properties: { + a: { + type: "number", + description: "The first number", + }, + b: { + type: "number", + description: "The second number", + }, + }, + required: ["a", "b"], +}; + +const divideJSON = { + type: "object", + properties: { + a: { + type: "number", + description: "The dividend", + }, + b: { + type: "number", + description: "The divisor", + }, + }, + required: ["a", "b"], +}; + +async function main() { + // Create a function tool from the sum function + const functionTool = new FunctionTool(sumNumbers, { + name: "sumNumbers", + description: "Use this function to sum two numbers", + parameters: sumJSON, + }); + + // Create a function tool from the divide function + const functionTool2 = new FunctionTool(divideNumbers, { + name: "divideNumbers", + description: "Use this function to divide two numbers", + parameters: divideJSON, + }); + + // Create an OpenAIAgent with the function tools + const agent = new OpenAIAgent({ + tools: [functionTool, functionTool2], + verbose: false, + }); + + const stream = await agent.chat({ + message: "Divide 16 by 2 then add 20", + stream: true, + }); + + for await (const chunk of stream.response) { + process.stdout.write(chunk.response); + } +} + +main().then(() => { + console.log("\nDone"); +}); diff --git a/packages/core/src/agent/openai/worker.ts b/packages/core/src/agent/openai/worker.ts index 84fd81bfa7fa33046188f4431e4125e54ac778ac..a4793cf82060e7fd2a6b40eb67548679782b03c9 100644 --- a/packages/core/src/agent/openai/worker.ts +++ b/packages/core/src/agent/openai/worker.ts @@ -1,10 +1,12 @@ // Assuming that the necessary interfaces and classes (like BaseTool, OpenAI, ChatMessage, CallbackManager, etc.) are defined elsewhere import { randomUUID } from "@llamaindex/env"; +import { Response } from "../../Response.js"; import type { CallbackManager } from "../../callbacks/CallbackManager.js"; import { AgentChatResponse, ChatResponseMode, + StreamingAgentChatResponse, } from "../../engines/chat/types.js"; import type { ChatMessage, @@ -12,6 +14,7 @@ import type { ChatResponseChunk, } from "../../llm/index.js"; import { OpenAI } from "../../llm/index.js"; +import { streamConverter, streamReducer } from "../../llm/utils.js"; import { ChatMemoryBuffer } from "../../memory/ChatMemoryBuffer.js"; import type { ObjectRetriever } from "../../objects/base.js"; import type { ToolOutput } from "../../tools/types.js"; @@ -192,13 +195,40 @@ export class OpenAIAgentWorker implements AgentWorker { private _processMessage( task: Task, chatResponse: ChatResponse, - ): AgentChatResponse | AsyncIterable<ChatResponseChunk> { + ): AgentChatResponse { const aiMessage = chatResponse.message; task.extraState.newMemory.put(aiMessage); return new AgentChatResponse(aiMessage.content, task.extraState.sources); } + private async _getStreamAiResponse( + task: Task, + llmChatKwargs: any, + ): Promise<StreamingAgentChatResponse> { + const stream = await this.llm.chat({ + stream: true, + ...llmChatKwargs, + }); + + const iterator = streamConverter( + streamReducer({ + stream, + initialValue: "", + reducer: (accumulator, part) => (accumulator += part.delta), + finished: (accumulator) => { + task.extraState.newMemory.put({ + content: accumulator, + role: "assistant", + }); + }, + }), + (r: ChatResponseChunk) => new Response(r.delta), + ); + + return new StreamingAgentChatResponse(iterator, task.extraState.sources); + } + /** * Get agent response. * @param task: task @@ -210,7 +240,7 @@ export class OpenAIAgentWorker implements AgentWorker { task: Task, mode: ChatResponseMode, llmChatKwargs: any, - ): Promise<AgentChatResponse> { + ): Promise<AgentChatResponse | StreamingAgentChatResponse> { if (mode === ChatResponseMode.WAIT) { const chatResponse = (await this.llm.chat({ stream: false, @@ -218,9 +248,11 @@ export class OpenAIAgentWorker implements AgentWorker { })) as unknown as ChatResponse; return this._processMessage(task, chatResponse) as AgentChatResponse; - } else { - throw new Error("Not implemented"); + } else if (mode === ChatResponseMode.STREAM) { + return this._getStreamAiResponse(task, llmChatKwargs); } + + throw new Error("Invalid mode"); } /** diff --git a/packages/core/src/agent/runner/base.ts b/packages/core/src/agent/runner/base.ts index 532390e0df92496880a428b76b61e55f905bdbe7..cce07b892d009021d1e20668bf311c56e4054aad 100644 --- a/packages/core/src/agent/runner/base.ts +++ b/packages/core/src/agent/runner/base.ts @@ -4,6 +4,7 @@ import type { ChatEngineAgentParams } from "../../engines/chat/index.js"; import { AgentChatResponse, ChatResponseMode, + StreamingAgentChatResponse, } from "../../engines/chat/index.js"; import type { ChatMessage, LLM } from "../../llm/index.js"; import { ChatMemoryBuffer } from "../../memory/ChatMemoryBuffer.js"; @@ -231,23 +232,26 @@ export class AgentRunner extends BaseAgentRunner { taskId: string, stepOutput: TaskStepOutput, kwargs?: any, - ): Promise<AgentChatResponse> { + ): Promise<AgentChatResponse | StreamingAgentChatResponse> { if (!stepOutput) { stepOutput = this.getCompletedSteps(taskId)[ this.getCompletedSteps(taskId).length - 1 ]; } + if (!stepOutput.isLast) { throw new Error( "finalizeResponse can only be called on the last step output", ); } - if (!(stepOutput.output instanceof AgentChatResponse)) { - throw new Error( - `When \`isLast\` is True, cur_step_output.output must be AGENT_CHAT_RESPONSE_TYPE: ${stepOutput.output}`, - ); + if (!(stepOutput.output instanceof StreamingAgentChatResponse)) { + if (!(stepOutput.output instanceof AgentChatResponse)) { + throw new Error( + `When \`isLast\` is True, cur_step_output.output must be AGENT_CHAT_RESPONSE_TYPE: ${stepOutput.output}`, + ); + } } this.agentWorker.finalizeTask(this.getTask(taskId), kwargs); @@ -262,20 +266,32 @@ export class AgentRunner extends BaseAgentRunner { protected async _chat({ message, toolChoice, - }: ChatEngineAgentParams & { mode: ChatResponseMode }) { + stream, + }: ChatEngineAgentParams): Promise<AgentChatResponse>; + protected async _chat({ + message, + toolChoice, + stream, + }: ChatEngineAgentParams & { + stream: true; + }): Promise<StreamingAgentChatResponse>; + protected async _chat({ + message, + toolChoice, + stream, + }: ChatEngineAgentParams): Promise< + AgentChatResponse | StreamingAgentChatResponse + > { const task = this.createTask(message as string); let resultOutput; + const mode = stream ? ChatResponseMode.STREAM : ChatResponseMode.WAIT; + while (true) { - const curStepOutput = await this._runStep( - task.taskId, - undefined, - ChatResponseMode.WAIT, - { - toolChoice, - }, - ); + const curStepOutput = await this._runStep(task.taskId, undefined, mode, { + toolChoice, + }); if (curStepOutput.isLast) { resultOutput = curStepOutput; @@ -299,7 +315,26 @@ export class AgentRunner extends BaseAgentRunner { message, chatHistory, toolChoice, - }: ChatEngineAgentParams): Promise<AgentChatResponse> { + stream, + }: ChatEngineAgentParams & { + stream?: false; + }): Promise<AgentChatResponse>; + public async chat({ + message, + chatHistory, + toolChoice, + stream, + }: ChatEngineAgentParams & { + stream: true; + }): Promise<StreamingAgentChatResponse>; + public async chat({ + message, + chatHistory, + toolChoice, + stream, + }: ChatEngineAgentParams): Promise< + AgentChatResponse | StreamingAgentChatResponse + > { if (!toolChoice) { toolChoice = this.defaultToolChoice; } @@ -308,7 +343,7 @@ export class AgentRunner extends BaseAgentRunner { message, chatHistory, toolChoice, - mode: ChatResponseMode.WAIT, + stream, }); return chatResponse; diff --git a/packages/core/src/agent/runner/types.ts b/packages/core/src/agent/runner/types.ts index 9fc4c5a8b645bff325b34146ac5fef5a083dcca5..c6fb3b8b6231aa9d2f33f700425c9d1c15a7ffb4 100644 --- a/packages/core/src/agent/runner/types.ts +++ b/packages/core/src/agent/runner/types.ts @@ -1,4 +1,7 @@ -import type { AgentChatResponse } from "../../engines/chat/index.js"; +import type { + AgentChatResponse, + StreamingAgentChatResponse, +} from "../../engines/chat/index.js"; import type { Task, TaskStep, TaskStepOutput } from "../types.js"; import { BaseAgent } from "../types.js"; @@ -57,7 +60,7 @@ export abstract class BaseAgentRunner extends BaseAgent { taskId: string, stepOutput: TaskStepOutput, kwargs?: any, - ): Promise<AgentChatResponse>; + ): Promise<AgentChatResponse | StreamingAgentChatResponse>; abstract undoStep(taskId: string): void; } diff --git a/packages/core/src/agent/types.ts b/packages/core/src/agent/types.ts index 8adc96340db5bb7f89e0d041bc64f388aa88d530..e277d7af0c7d6dab24dfbec980473eeeb3e393f4 100644 --- a/packages/core/src/agent/types.ts +++ b/packages/core/src/agent/types.ts @@ -1,6 +1,7 @@ import type { AgentChatResponse, ChatEngineAgentParams, + StreamingAgentChatResponse, } from "../engines/chat/index.js"; import type { QueryEngineParamsNonStreaming } from "../types.js"; @@ -12,11 +13,15 @@ export interface AgentWorker { } interface BaseChatEngine { - chat(params: ChatEngineAgentParams): Promise<AgentChatResponse>; + chat( + params: ChatEngineAgentParams, + ): Promise<AgentChatResponse | StreamingAgentChatResponse>; } interface BaseQueryEngine { - query(params: QueryEngineParamsNonStreaming): Promise<AgentChatResponse>; + query( + params: QueryEngineParamsNonStreaming, + ): Promise<AgentChatResponse | StreamingAgentChatResponse>; } /** @@ -31,7 +36,10 @@ export abstract class BaseAgent implements BaseChatEngine, BaseQueryEngine { return []; } - abstract chat(params: ChatEngineAgentParams): Promise<AgentChatResponse>; + abstract chat( + params: ChatEngineAgentParams, + ): Promise<AgentChatResponse | StreamingAgentChatResponse>; + abstract reset(): void; /** @@ -41,7 +49,7 @@ export abstract class BaseAgent implements BaseChatEngine, BaseQueryEngine { */ async query( params: QueryEngineParamsNonStreaming, - ): Promise<AgentChatResponse> { + ): Promise<AgentChatResponse | StreamingAgentChatResponse> { // Handle non-streaming query const agentResponse = await this.chat({ message: params.query, diff --git a/packages/core/src/engines/chat/types.ts b/packages/core/src/engines/chat/types.ts index 3a8478d0c95a804fb1f8d7e490aae51162334f0e..397d2cf51fa7d92ac91e415185fadc3605e72d63 100644 --- a/packages/core/src/engines/chat/types.ts +++ b/packages/core/src/engines/chat/types.ts @@ -27,6 +27,7 @@ export interface ChatEngineParamsNonStreaming extends ChatEngineParamsBase { export interface ChatEngineAgentParams extends ChatEngineParamsBase { toolChoice?: string | Record<string, any>; + stream?: boolean; } /** @@ -86,3 +87,20 @@ export class AgentChatResponse { return this.response ?? ""; } } + +export class StreamingAgentChatResponse { + response: AsyncIterable<Response>; + + sources: ToolOutput[]; + sourceNodes?: BaseNode[]; + + constructor( + response: AsyncIterable<Response>, + sources?: ToolOutput[], + sourceNodes?: BaseNode[], + ) { + this.response = response; + this.sources = sources ?? []; + this.sourceNodes = sourceNodes ?? []; + } +}