From f0704ec705b2d22d217beb03ba737cb372971c34 Mon Sep 17 00:00:00 2001 From: Marcus Schiesser <mail@marcusschiesser.de> Date: Fri, 5 Apr 2024 13:53:26 +0800 Subject: [PATCH] Add streaming for OpenAI agents (#693) --- .changeset/long-cobras-play.md | 5 +++ examples/agent/wiki.ts | 7 ++-- examples/agent/wikipedia-tool.ts | 23 ------------- packages/core/src/agent/openai/worker.ts | 44 ++++++++++++++++-------- packages/core/src/agent/types.ts | 4 +-- packages/core/src/llm/open_ai.ts | 8 +++-- packages/core/src/llm/types.ts | 1 + 7 files changed, 48 insertions(+), 44 deletions(-) create mode 100644 .changeset/long-cobras-play.md delete mode 100644 examples/agent/wikipedia-tool.ts diff --git a/.changeset/long-cobras-play.md b/.changeset/long-cobras-play.md new file mode 100644 index 000000000..6f736784a --- /dev/null +++ b/.changeset/long-cobras-play.md @@ -0,0 +1,5 @@ +--- +"llamaindex": patch +--- + +Support streaming for OpenAI agent diff --git a/examples/agent/wiki.ts b/examples/agent/wiki.ts index 739f834cb..db308db57 100644 --- a/examples/agent/wiki.ts +++ b/examples/agent/wiki.ts @@ -14,12 +14,15 @@ async function main() { // Chat with the agent const response = await agent.chat({ message: "Who was Goethe?", + stream: true, }); - console.log(response.response); + for await (const chunk of response.response) { + process.stdout.write(chunk.response); + } } (async function () { await main(); - console.log("Done"); + console.log("\nDone"); })(); diff --git a/examples/agent/wikipedia-tool.ts b/examples/agent/wikipedia-tool.ts deleted file mode 100644 index 1a5e82fc1..000000000 --- a/examples/agent/wikipedia-tool.ts +++ /dev/null @@ -1,23 +0,0 @@ -import { OpenAIAgent, WikipediaTool } from "llamaindex"; - -async function main() { - const wikipediaTool = new WikipediaTool(); - - // Create an OpenAIAgent with the function tools - const agent = new OpenAIAgent({ - tools: [wikipediaTool], - verbose: true, - }); - - // Chat with the agent - const response = await agent.chat({ - message: "Where is Ho Chi Minh City?", - }); - - // Print the response - console.log(response); -} - -void main().then(() => { - console.log("Done"); -}); diff --git a/packages/core/src/agent/openai/worker.ts b/packages/core/src/agent/openai/worker.ts index b4f80407b..e978a0dab 100644 --- a/packages/core/src/agent/openai/worker.ts +++ b/packages/core/src/agent/openai/worker.ts @@ -9,6 +9,7 @@ import type { ChatMessage, ChatResponse, ChatResponseChunk, + LLMChatParamsBase, } from "../../llm/index.js"; import { OpenAI } from "../../llm/index.js"; import { streamConverter, streamReducer } from "../../llm/utils.js"; @@ -166,8 +167,8 @@ export class OpenAIAgentWorker implements AgentWorker { task: Task, openaiTools: { [key: string]: any }[], toolChoice: string | { [key: string]: any } = "auto", - ): { [key: string]: any } { - const llmChatKwargs: { [key: string]: any } = { + ): LLMChatParamsBase { + const llmChatKwargs: LLMChatParamsBase = { messages: this.getAllMessages(task), }; @@ -179,17 +180,10 @@ export class OpenAIAgentWorker implements AgentWorker { return llmChatKwargs; } - /** - * Process message. - * @param task: task - * @param chatResponse: chat response - * @returns: agent chat response - */ private _processMessage( task: Task, - chatResponse: ChatResponse, + aiMessage: ChatMessage, ): AgentChatResponse { - const aiMessage = chatResponse.message; task.extraState.newMemory.put(aiMessage); return new AgentChatResponse(aiMessage.content, task.extraState.sources); @@ -198,16 +192,33 @@ export class OpenAIAgentWorker implements AgentWorker { private async _getStreamAiResponse( task: Task, llmChatKwargs: any, - ): Promise<StreamingAgentChatResponse> { + ): Promise<StreamingAgentChatResponse | AgentChatResponse> { const stream = await this.llm.chat({ stream: true, ...llmChatKwargs, }); + // read first chunk from stream to find out if we need to call tools + const iterator = stream[Symbol.asyncIterator](); + let { value } = await iterator.next(); + let content = value.delta; + const hasToolCalls = value.additionalKwargs?.toolCalls.length > 0; + + if (hasToolCalls) { + // consume stream until we have all the tool calls and return a non-streamed response + for await (value of stream) { + content += value.delta; + } + return this._processMessage(task, { + content, + role: "assistant", + additionalKwargs: value.additionalKwargs, + }); + } - const iterator = streamConverter.bind(this)( + const newStream = streamConverter.bind(this)( streamReducer({ stream, - initialValue: "", + initialValue: content, reducer: (accumulator, part) => (accumulator += part.delta), finished: (accumulator) => { task.extraState.newMemory.put({ @@ -219,7 +230,7 @@ export class OpenAIAgentWorker implements AgentWorker { (r: ChatResponseChunk) => new Response(r.delta), ); - return new StreamingAgentChatResponse(iterator, task.extraState.sources); + return new StreamingAgentChatResponse(newStream, task.extraState.sources); } /** @@ -240,7 +251,10 @@ export class OpenAIAgentWorker implements AgentWorker { ...llmChatKwargs, })) as unknown as ChatResponse; - return this._processMessage(task, chatResponse) as AgentChatResponse; + return this._processMessage( + task, + chatResponse.message, + ) as AgentChatResponse; } else if (mode === ChatResponseMode.STREAM) { return this._getStreamAiResponse(task, llmChatKwargs); } diff --git a/packages/core/src/agent/types.ts b/packages/core/src/agent/types.ts index 185c26b43..6be1f8688 100644 --- a/packages/core/src/agent/types.ts +++ b/packages/core/src/agent/types.ts @@ -170,13 +170,13 @@ export class TaskStep implements ITaskStep { * @param isLast: isLast */ export class TaskStepOutput { - output: any; + output: AgentChatResponse | StreamingAgentChatResponse; taskStep: TaskStep; nextSteps: TaskStep[]; isLast: boolean; constructor( - output: any, + output: AgentChatResponse | StreamingAgentChatResponse, taskStep: TaskStep, nextSteps: TaskStep[], isLast: boolean = false, diff --git a/packages/core/src/llm/open_ai.ts b/packages/core/src/llm/open_ai.ts index b9987a73f..c5dff90e2 100644 --- a/packages/core/src/llm/open_ai.ts +++ b/packages/core/src/llm/open_ai.ts @@ -336,7 +336,8 @@ export class OpenAI extends BaseLLM { yield { // add tool calls to final chunk - additionalKwargs: isDone ? { toolCalls: toolCalls } : undefined, + additionalKwargs: + toolCalls.length > 0 ? { toolCalls: toolCalls } : undefined, delta: choice.delta.content ?? "", }; } @@ -355,16 +356,19 @@ function updateToolCalls( toolCall = toolCall ?? ({ function: { name: "", arguments: "" } } as MessageToolCall); + toolCall.id = toolCall.id ?? toolCallDelta?.id; + toolCall.type = toolCall.type ?? toolCallDelta?.type; if (toolCallDelta?.function?.arguments) { toolCall.function.arguments += toolCallDelta.function.arguments; } if (toolCallDelta?.function?.name) { toolCall.function.name += toolCallDelta.function.name; } + return toolCall; } if (toolCallDeltas) { toolCallDeltas?.forEach((toolCall, i) => { - augmentToolCall(toolCalls[i], toolCall); + toolCalls[i] = augmentToolCall(toolCalls[i], toolCall); }); } } diff --git a/packages/core/src/llm/types.ts b/packages/core/src/llm/types.ts index 1131473d9..c49a5e8b1 100644 --- a/packages/core/src/llm/types.ts +++ b/packages/core/src/llm/types.ts @@ -149,4 +149,5 @@ interface Function { export interface MessageToolCall { id: string; function: Function; + type: "function"; } -- GitLab