Unverified commit f0704ec7 authored by Marcus Schiesser, committed by GitHub

Add streaming for OpenAI agents (#693)

parent 4fcbdf71
Tag: llamaindex@0.2.3
The changeset marking the patch release:

```md
---
"llamaindex": patch
---

Support streaming for OpenAI agent
```
The agent example switches to streaming:

```diff
@@ -14,12 +14,15 @@ async function main() {
   // Chat with the agent
   const response = await agent.chat({
     message: "Who was Goethe?",
+    stream: true,
   });
-  console.log(response.response);
+  for await (const chunk of response.response) {
+    process.stdout.write(chunk.response);
+  }
 }
 
 (async function () {
   await main();
-  console.log("Done");
+  console.log("\nDone");
 })();
```
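For reference, a runnable sketch of the streamed example after this patch. The lines above the hunk are not part of the diff, so the agent setup here is an assumption, borrowed from the WikipediaTool example shown next:

```ts
import { OpenAIAgent, WikipediaTool } from "llamaindex";

async function main() {
  // assumed setup (not visible in the hunk above): same tool/agent
  // wiring as in the WikipediaTool example below
  const agent = new OpenAIAgent({
    tools: [new WikipediaTool()],
    verbose: true,
  });

  // with stream: true, response.response is an async iterable of chunks
  const response = await agent.chat({
    message: "Who was Goethe?",
    stream: true,
  });
  for await (const chunk of response.response) {
    process.stdout.write(chunk.response);
  }
}

(async function () {
  await main();
  console.log("\nDone");
})();
```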
A second example, shown in full, drives the agent with the WikipediaTool:

```ts
import { OpenAIAgent, WikipediaTool } from "llamaindex";

async function main() {
  const wikipediaTool = new WikipediaTool();

  // Create an OpenAIAgent with the function tools
  const agent = new OpenAIAgent({
    tools: [wikipediaTool],
    verbose: true,
  });

  // Chat with the agent
  const response = await agent.chat({
    message: "Where is Ho Chi Minh City?",
  });

  // Print the response
  console.log(response);
}

void main().then(() => {
  console.log("Done");
});
```
In OpenAIAgentWorker, the new chat-params type is imported:

```diff
@@ -9,6 +9,7 @@ import type {
   ChatMessage,
   ChatResponse,
   ChatResponseChunk,
+  LLMChatParamsBase,
 } from "../../llm/index.js";
 import { OpenAI } from "../../llm/index.js";
 import { streamConverter, streamReducer } from "../../llm/utils.js";
```
The loose `{ [key: string]: any }` types are tightened to `LLMChatParamsBase`:

```diff
@@ -166,8 +167,8 @@ export class OpenAIAgentWorker implements AgentWorker {
     task: Task,
     openaiTools: { [key: string]: any }[],
     toolChoice: string | { [key: string]: any } = "auto",
-  ): { [key: string]: any } {
-    const llmChatKwargs: { [key: string]: any } = {
+  ): LLMChatParamsBase {
+    const llmChatKwargs: LLMChatParamsBase = {
       messages: this.getAllMessages(task),
     };
```
`_processMessage` now takes the message directly, so the streaming path can pass a synthesized message:

```diff
@@ -179,17 +180,10 @@ export class OpenAIAgentWorker implements AgentWorker {
     return llmChatKwargs;
   }
 
-  /**
-   * Process message.
-   * @param task: task
-   * @param chatResponse: chat response
-   * @returns: agent chat response
-   */
   private _processMessage(
     task: Task,
-    chatResponse: ChatResponse,
+    aiMessage: ChatMessage,
   ): AgentChatResponse {
-    const aiMessage = chatResponse.message;
     task.extraState.newMemory.put(aiMessage);
     return new AgentChatResponse(aiMessage.content, task.extraState.sources);
```
The core change: `_getStreamAiResponse` peeks at the first chunk to decide between streaming and a buffered tool-call response:

```diff
@@ -198,16 +192,33 @@ export class OpenAIAgentWorker implements AgentWorker {
   private async _getStreamAiResponse(
     task: Task,
     llmChatKwargs: any,
-  ): Promise<StreamingAgentChatResponse> {
+  ): Promise<StreamingAgentChatResponse | AgentChatResponse> {
     const stream = await this.llm.chat({
       stream: true,
       ...llmChatKwargs,
     });
+    // read first chunk from stream to find out if we need to call tools
+    const iterator = stream[Symbol.asyncIterator]();
+    let { value } = await iterator.next();
+    let content = value.delta;
+    const hasToolCalls = value.additionalKwargs?.toolCalls.length > 0;
+    if (hasToolCalls) {
+      // consume stream until we have all the tool calls and return a non-streamed response
+      for await (value of stream) {
+        content += value.delta;
+      }
+      return this._processMessage(task, {
+        content,
+        role: "assistant",
+        additionalKwargs: value.additionalKwargs,
+      });
+    }
 
-    const iterator = streamConverter.bind(this)(
+    const newStream = streamConverter.bind(this)(
       streamReducer({
         stream,
-        initialValue: "",
+        initialValue: content,
         reducer: (accumulator, part) => (accumulator += part.delta),
         finished: (accumulator) => {
           task.extraState.newMemory.put({
```
```diff
@@ -219,7 +230,7 @@ export class OpenAIAgentWorker implements AgentWorker {
       (r: ChatResponseChunk) => new Response(r.delta),
     );
-    return new StreamingAgentChatResponse(iterator, task.extraState.sources);
+    return new StreamingAgentChatResponse(newStream, task.extraState.sources);
   }
 
   /**
```
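The pattern above consumes exactly one chunk to check for tool calls, then relies on `streamReducer`'s `initialValue` to re-attach the consumed delta. A standalone sketch of the idea (names and the `Chunk` shape are illustrative, not the library's API; it assumes the stream is an async generator, so iterating it again continues after the peeked chunk, and that at least one chunk is yielded):

```ts
type Chunk = { delta: string; toolCalls?: unknown[] };

async function peekThenDecide(
  stream: AsyncGenerator<Chunk>,
): Promise<{ buffered: string } | { streaming: AsyncGenerator<string> }> {
  // consume exactly one chunk to inspect it
  const first = (await stream.next()).value as Chunk;

  if ((first.toolCalls?.length ?? 0) > 0) {
    // tool calls present: drain the rest and return a buffered result
    let content = first.delta;
    for await (const chunk of stream) content += chunk.delta;
    return { buffered: content };
  }

  // no tool calls: replay the peeked delta, then pass the rest through
  async function* replay() {
    yield first.delta;
    for await (const chunk of stream) yield chunk.delta;
  }
  return { streaming: replay() };
}

// usage:
//   const result = await peekThenDecide(llmStream);
//   if ("streaming" in result)
//     for await (const d of result.streaming) process.stdout.write(d);
```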
The non-streaming branch adapts to the new `_processMessage` signature:

```diff
@@ -240,7 +251,10 @@ export class OpenAIAgentWorker implements AgentWorker {
         ...llmChatKwargs,
       })) as unknown as ChatResponse;
 
-      return this._processMessage(task, chatResponse) as AgentChatResponse;
+      return this._processMessage(
+        task,
+        chatResponse.message,
+      ) as AgentChatResponse;
     } else if (mode === ChatResponseMode.STREAM) {
       return this._getStreamAiResponse(task, llmChatKwargs);
     }
```
`TaskStepOutput` is typed to carry either response kind instead of `any`:

```diff
@@ -170,13 +170,13 @@ export class TaskStep implements ITaskStep {
  * @param isLast: isLast
  */
 export class TaskStepOutput {
-  output: any;
+  output: AgentChatResponse | StreamingAgentChatResponse;
   taskStep: TaskStep;
   nextSteps: TaskStep[];
   isLast: boolean;
 
   constructor(
-    output: any,
+    output: AgentChatResponse | StreamingAgentChatResponse,
     taskStep: TaskStep,
     nextSteps: TaskStep[],
     isLast: boolean = false,
```
In the OpenAI LLM, tool calls are attached to every chunk that carries them, not only the final one:

```diff
@@ -336,7 +336,8 @@ export class OpenAI extends BaseLLM {
       yield {
         // add tool calls to final chunk
-        additionalKwargs: isDone ? { toolCalls: toolCalls } : undefined,
+        additionalKwargs:
+          toolCalls.length > 0 ? { toolCalls: toolCalls } : undefined,
         delta: choice.delta.content ?? "",
       };
     }
```
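Because `toolCalls` accumulates across deltas, `additionalKwargs` now appears as soon as the first tool-call delta arrives; the worker's first-chunk check above depends on exactly this.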
`updateToolCalls` gains `id`/`type` propagation and, crucially, writes the augmented tool call back:

```diff
@@ -355,16 +356,19 @@ function updateToolCalls(
   toolCall =
     toolCall ??
     ({ function: { name: "", arguments: "" } } as MessageToolCall);
+  toolCall.id = toolCall.id ?? toolCallDelta?.id;
+  toolCall.type = toolCall.type ?? toolCallDelta?.type;
   if (toolCallDelta?.function?.arguments) {
     toolCall.function.arguments += toolCallDelta.function.arguments;
   }
   if (toolCallDelta?.function?.name) {
     toolCall.function.name += toolCallDelta.function.name;
   }
+  return toolCall;
 }
 
 if (toolCallDeltas) {
   toolCallDeltas?.forEach((toolCall, i) => {
-    augmentToolCall(toolCalls[i], toolCall);
+    toolCalls[i] = augmentToolCall(toolCalls[i], toolCall);
   });
 }
```
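The subtle bug fixed here: when `toolCalls[i]` was `undefined`, the helper built a fresh object that the caller never stored, so accumulated tool-call deltas were silently lost. A toy illustration of the write-back (shapes simplified, delta values invented for the example):

```ts
type ToolCall = {
  id?: string;
  type?: string;
  function: { name: string; arguments: string };
};

function augment(toolCall: ToolCall | undefined, delta: Partial<ToolCall>): ToolCall {
  // create the accumulator on the first delta; it must be returned and
  // stored, otherwise the fresh object is dropped and toolCalls[i] stays undefined
  toolCall = toolCall ?? { function: { name: "", arguments: "" } };
  toolCall.id = toolCall.id ?? delta.id;
  toolCall.type = toolCall.type ?? delta.type;
  if (delta.function?.name) toolCall.function.name += delta.function.name;
  if (delta.function?.arguments) toolCall.function.arguments += delta.function.arguments;
  return toolCall;
}

const toolCalls: ToolCall[] = [];
const deltas: Partial<ToolCall>[] = [
  { id: "call_1", type: "function", function: { name: "wikipedia", arguments: '{"query": ' } },
  { function: { name: "", arguments: '"Goethe"}' } },
];
for (const d of deltas) {
  toolCalls[0] = augment(toolCalls[0], d); // the write-back is the fix
}
console.log(toolCalls[0].function.arguments); // {"query": "Goethe"}
```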
And in the message types, `MessageToolCall` gains the `type` discriminant:

```diff
@@ -149,4 +149,5 @@ interface Function {
 export interface MessageToolCall {
   id: string;
   function: Function;
+  type: "function";
 }
```