diff --git a/.changeset/long-cobras-play.md b/.changeset/long-cobras-play.md
new file mode 100644
index 0000000000000000000000000000000000000000..6f736784a94d782e972142f3688882c3839f3fb3
--- /dev/null
+++ b/.changeset/long-cobras-play.md
@@ -0,0 +1,5 @@
+---
+"llamaindex": patch
+---
+
+Support streaming for OpenAI agent
diff --git a/examples/agent/wiki.ts b/examples/agent/wiki.ts
index 739f834cba104eb1b9241dac4efa426c1d9a9052..db308db5748bfbda04ac27e3494a1e5e2ed15b58 100644
--- a/examples/agent/wiki.ts
+++ b/examples/agent/wiki.ts
@@ -14,12 +14,15 @@ async function main() {
   // Chat with the agent
   const response = await agent.chat({
     message: "Who was Goethe?",
+    stream: true,
   });
 
-  console.log(response.response);
+  for await (const chunk of response.response) {
+    process.stdout.write(chunk.response);
+  }
 }
 
 (async function () {
   await main();
-  console.log("Done");
+  console.log("\nDone");
 })();
diff --git a/examples/agent/wikipedia-tool.ts b/examples/agent/wikipedia-tool.ts
deleted file mode 100644
index 1a5e82fc116bf37b583b60187c1871b93fd051ff..0000000000000000000000000000000000000000
--- a/examples/agent/wikipedia-tool.ts
+++ /dev/null
@@ -1,23 +0,0 @@
-import { OpenAIAgent, WikipediaTool } from "llamaindex";
-
-async function main() {
-  const wikipediaTool = new WikipediaTool();
-
-  // Create an OpenAIAgent with the function tools
-  const agent = new OpenAIAgent({
-    tools: [wikipediaTool],
-    verbose: true,
-  });
-
-  // Chat with the agent
-  const response = await agent.chat({
-    message: "Where is Ho Chi Minh City?",
-  });
-
-  // Print the response
-  console.log(response);
-}
-
-void main().then(() => {
-  console.log("Done");
-});
diff --git a/packages/core/src/agent/openai/worker.ts b/packages/core/src/agent/openai/worker.ts
index b4f80407b589b457aecb187d05c8338739605871..e978a0dab8d1636a1060a56d1b0fe29dfe93ac79 100644
--- a/packages/core/src/agent/openai/worker.ts
+++ b/packages/core/src/agent/openai/worker.ts
@@ -9,6 +9,7 @@ import type {
   ChatMessage,
   ChatResponse,
   ChatResponseChunk,
+  LLMChatParamsBase,
 } from "../../llm/index.js";
 import { OpenAI } from "../../llm/index.js";
 import { streamConverter, streamReducer } from "../../llm/utils.js";
@@ -166,8 +167,8 @@
     task: Task,
     openaiTools: { [key: string]: any }[],
     toolChoice: string | { [key: string]: any } = "auto",
-  ): { [key: string]: any } {
-    const llmChatKwargs: { [key: string]: any } = {
+  ): LLMChatParamsBase {
+    const llmChatKwargs: LLMChatParamsBase = {
       messages: this.getAllMessages(task),
     };
 
@@ -179,17 +180,10 @@
     return llmChatKwargs;
   }
 
-  /**
-   * Process message.
-   * @param task: task
-   * @param chatResponse: chat response
-   * @returns: agent chat response
-   */
   private _processMessage(
     task: Task,
-    chatResponse: ChatResponse,
+    aiMessage: ChatMessage,
   ): AgentChatResponse {
-    const aiMessage = chatResponse.message;
     task.extraState.newMemory.put(aiMessage);
 
     return new AgentChatResponse(aiMessage.content, task.extraState.sources);
@@ -198,16 +192,33 @@
   private async _getStreamAiResponse(
     task: Task,
     llmChatKwargs: any,
-  ): Promise<StreamingAgentChatResponse> {
+  ): Promise<StreamingAgentChatResponse | AgentChatResponse> {
     const stream = await this.llm.chat({
       stream: true,
       ...llmChatKwargs,
     });
+    // read first chunk from stream to find out if we need to call tools
+    const iterator = stream[Symbol.asyncIterator]();
+    let { value } = await iterator.next();
+    let content = value.delta;
+    const hasToolCalls = value.additionalKwargs?.toolCalls.length > 0;
+
+    if (hasToolCalls) {
+      // consume stream until we have all the tool calls and return a non-streamed response
+      for await (value of stream) {
+        content += value.delta;
+      }
+      return this._processMessage(task, {
+        content,
+        role: "assistant",
+        additionalKwargs: value.additionalKwargs,
+      });
+    }
 
-    const iterator = streamConverter.bind(this)(
+    const newStream = streamConverter.bind(this)(
       streamReducer({
         stream,
-        initialValue: "",
+        initialValue: content,
         reducer: (accumulator, part) => (accumulator += part.delta),
         finished: (accumulator) => {
           task.extraState.newMemory.put({
@@ -219,7 +230,7 @@
       (r: ChatResponseChunk) => new Response(r.delta),
     );
 
-    return new StreamingAgentChatResponse(iterator, task.extraState.sources);
+    return new StreamingAgentChatResponse(newStream, task.extraState.sources);
   }
 
   /**
@@ -240,7 +251,10 @@
         ...llmChatKwargs,
       })) as unknown as ChatResponse;
 
-      return this._processMessage(task, chatResponse) as AgentChatResponse;
+      return this._processMessage(
+        task,
+        chatResponse.message,
+      ) as AgentChatResponse;
     } else if (mode === ChatResponseMode.STREAM) {
       return this._getStreamAiResponse(task, llmChatKwargs);
     }
diff --git a/packages/core/src/agent/types.ts b/packages/core/src/agent/types.ts
index 185c26b43ff59a2630a2e357f480fea72b4fa1b9..6be1f8688a020638336676ad020c073a51eafc88 100644
--- a/packages/core/src/agent/types.ts
+++ b/packages/core/src/agent/types.ts
@@ -170,13 +170,13 @@ export class TaskStep implements ITaskStep {
  * @param isLast: isLast
  */
 export class TaskStepOutput {
-  output: any;
+  output: AgentChatResponse | StreamingAgentChatResponse;
   taskStep: TaskStep;
   nextSteps: TaskStep[];
   isLast: boolean;
 
   constructor(
-    output: any,
+    output: AgentChatResponse | StreamingAgentChatResponse,
     taskStep: TaskStep,
     nextSteps: TaskStep[],
     isLast: boolean = false,
diff --git a/packages/core/src/llm/open_ai.ts b/packages/core/src/llm/open_ai.ts
index b9987a73fa758b41878b7e3a5abd006a358493d7..c5dff90e2a7a9ceebc71d04b5122fdaf9d4af022 100644
--- a/packages/core/src/llm/open_ai.ts
+++ b/packages/core/src/llm/open_ai.ts
@@ -336,7 +336,8 @@
 
       yield {
         // add tool calls to final chunk
-        additionalKwargs: isDone ? { toolCalls: toolCalls } : undefined,
+        additionalKwargs:
+          toolCalls.length > 0 ? { toolCalls: toolCalls } : undefined,
         delta: choice.delta.content ?? "",
       };
     }
@@ -355,16 +356,19 @@ function updateToolCalls(
     toolCall =
       toolCall ??
       ({ function: { name: "", arguments: "" } } as MessageToolCall);
+    toolCall.id = toolCall.id ?? toolCallDelta?.id;
+    toolCall.type = toolCall.type ?? toolCallDelta?.type;
     if (toolCallDelta?.function?.arguments) {
       toolCall.function.arguments += toolCallDelta.function.arguments;
     }
     if (toolCallDelta?.function?.name) {
       toolCall.function.name += toolCallDelta.function.name;
     }
+    return toolCall;
   }
   if (toolCallDeltas) {
     toolCallDeltas?.forEach((toolCall, i) => {
-      augmentToolCall(toolCalls[i], toolCall);
+      toolCalls[i] = augmentToolCall(toolCalls[i], toolCall);
     });
   }
 }
diff --git a/packages/core/src/llm/types.ts b/packages/core/src/llm/types.ts
index 1131473d95105a687fcc01940a1cee55bbe5d211..c49a5e8b14bb5999b5229acf8208222463bedb41 100644
--- a/packages/core/src/llm/types.ts
+++ b/packages/core/src/llm/types.ts
@@ -149,4 +149,5 @@ interface Function {
 export interface MessageToolCall {
   id: string;
   function: Function;
+  type: "function";
 }
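
For reference, a minimal sketch of how the new streaming option is consumed, adapted from the updated examples/agent/wiki.ts above. The OpenAIAgent/WikipediaTool setup and the chunk shape (chunk.response) come from this diff; the fullText accumulation is illustrative only.

import { OpenAIAgent, WikipediaTool } from "llamaindex";

async function main() {
  // Create an agent with the Wikipedia tool, as in the existing examples
  const agent = new OpenAIAgent({ tools: [new WikipediaTool()] });

  // With stream: true, response.response is an async iterable of chunks
  // instead of a plain string
  const response = await agent.chat({
    message: "Who was Goethe?",
    stream: true,
  });

  let fullText = "";
  for await (const chunk of response.response) {
    process.stdout.write(chunk.response); // print tokens as they arrive
    fullText += chunk.response; // illustrative: keep the full answer
  }
}

void main().then(() => console.log("\nDone"));

Note that when the model decides to call a tool, the worker consumes the stream internally and falls back to a non-streamed AgentChatResponse, so a consumer only ever streams the final natural-language answer.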