Skip to content
Snippets Groups Projects
Unverified Commit f0704ec7 authored by Marcus Schiesser, committed by GitHub
Browse files

Add streaming for OpenAI agents (#693)

parent 4fcbdf71
No related branches found
No related tags found
No related merge requests found
---
"llamaindex": patch
---
Support streaming for OpenAI agent
......@@ -14,12 +14,15 @@ async function main() {
// Chat with the agent
const response = await agent.chat({
message: "Who was Goethe?",
stream: true,
});
console.log(response.response);
for await (const chunk of response.response) {
process.stdout.write(chunk.response);
}
}
(async function () {
await main();
console.log("Done");
console.log("\nDone");
})();
import { OpenAIAgent, WikipediaTool } from "llamaindex";
async function main() {
const wikipediaTool = new WikipediaTool();
// Create an OpenAIAgent with the function tools
const agent = new OpenAIAgent({
tools: [wikipediaTool],
verbose: true,
});
// Chat with the agent
const response = await agent.chat({
message: "Where is Ho Chi Minh City?",
});
// Print the response
console.log(response);
}
// Kick off the example; report completion once the chat round-trip finishes.
void (async () => {
  await main();
  console.log("Done");
})();
......@@ -9,6 +9,7 @@ import type {
ChatMessage,
ChatResponse,
ChatResponseChunk,
LLMChatParamsBase,
} from "../../llm/index.js";
import { OpenAI } from "../../llm/index.js";
import { streamConverter, streamReducer } from "../../llm/utils.js";
......@@ -166,8 +167,8 @@ export class OpenAIAgentWorker implements AgentWorker {
task: Task,
openaiTools: { [key: string]: any }[],
toolChoice: string | { [key: string]: any } = "auto",
): { [key: string]: any } {
const llmChatKwargs: { [key: string]: any } = {
): LLMChatParamsBase {
const llmChatKwargs: LLMChatParamsBase = {
messages: this.getAllMessages(task),
};
......@@ -179,17 +180,10 @@ export class OpenAIAgentWorker implements AgentWorker {
return llmChatKwargs;
}
/**
* Process message.
* @param task: task
* @param chatResponse: chat response
* @returns: agent chat response
*/
private _processMessage(
task: Task,
chatResponse: ChatResponse,
aiMessage: ChatMessage,
): AgentChatResponse {
const aiMessage = chatResponse.message;
task.extraState.newMemory.put(aiMessage);
return new AgentChatResponse(aiMessage.content, task.extraState.sources);
......@@ -198,16 +192,33 @@ export class OpenAIAgentWorker implements AgentWorker {
private async _getStreamAiResponse(
task: Task,
llmChatKwargs: any,
): Promise<StreamingAgentChatResponse> {
): Promise<StreamingAgentChatResponse | AgentChatResponse> {
const stream = await this.llm.chat({
stream: true,
...llmChatKwargs,
});
// read first chunk from stream to find out if we need to call tools
const iterator = stream[Symbol.asyncIterator]();
let { value } = await iterator.next();
let content = value.delta;
const hasToolCalls = value.additionalKwargs?.toolCalls.length > 0;
if (hasToolCalls) {
// consume stream until we have all the tool calls and return a non-streamed response
for await (value of stream) {
content += value.delta;
}
return this._processMessage(task, {
content,
role: "assistant",
additionalKwargs: value.additionalKwargs,
});
}
const iterator = streamConverter.bind(this)(
const newStream = streamConverter.bind(this)(
streamReducer({
stream,
initialValue: "",
initialValue: content,
reducer: (accumulator, part) => (accumulator += part.delta),
finished: (accumulator) => {
task.extraState.newMemory.put({
......@@ -219,7 +230,7 @@ export class OpenAIAgentWorker implements AgentWorker {
(r: ChatResponseChunk) => new Response(r.delta),
);
return new StreamingAgentChatResponse(iterator, task.extraState.sources);
return new StreamingAgentChatResponse(newStream, task.extraState.sources);
}
/**
......@@ -240,7 +251,10 @@ export class OpenAIAgentWorker implements AgentWorker {
...llmChatKwargs,
})) as unknown as ChatResponse;
return this._processMessage(task, chatResponse) as AgentChatResponse;
return this._processMessage(
task,
chatResponse.message,
) as AgentChatResponse;
} else if (mode === ChatResponseMode.STREAM) {
return this._getStreamAiResponse(task, llmChatKwargs);
}
......
......@@ -170,13 +170,13 @@ export class TaskStep implements ITaskStep {
* @param isLast: isLast
*/
export class TaskStepOutput {
output: any;
output: AgentChatResponse | StreamingAgentChatResponse;
taskStep: TaskStep;
nextSteps: TaskStep[];
isLast: boolean;
constructor(
output: any,
output: AgentChatResponse | StreamingAgentChatResponse,
taskStep: TaskStep,
nextSteps: TaskStep[],
isLast: boolean = false,
......
......@@ -336,7 +336,8 @@ export class OpenAI extends BaseLLM {
yield {
// add tool calls to final chunk
additionalKwargs: isDone ? { toolCalls: toolCalls } : undefined,
additionalKwargs:
toolCalls.length > 0 ? { toolCalls: toolCalls } : undefined,
delta: choice.delta.content ?? "",
};
}
......@@ -355,16 +356,19 @@ function updateToolCalls(
toolCall =
toolCall ??
({ function: { name: "", arguments: "" } } as MessageToolCall);
toolCall.id = toolCall.id ?? toolCallDelta?.id;
toolCall.type = toolCall.type ?? toolCallDelta?.type;
if (toolCallDelta?.function?.arguments) {
toolCall.function.arguments += toolCallDelta.function.arguments;
}
if (toolCallDelta?.function?.name) {
toolCall.function.name += toolCallDelta.function.name;
}
return toolCall;
}
if (toolCallDeltas) {
toolCallDeltas?.forEach((toolCall, i) => {
augmentToolCall(toolCalls[i], toolCall);
toolCalls[i] = augmentToolCall(toolCalls[i], toolCall);
});
}
}
......@@ -149,4 +149,5 @@ interface Function {
/**
 * A tool call requested by the model in a chat response; when streaming,
 * instances are accumulated from per-chunk deltas (see `updateToolCalls`).
 */
export interface MessageToolCall {
  // Identifier assigned by the API for this tool call.
  id: string;
  // NOTE(review): `Function` here refers to the locally declared interface
  // above (name + arguments), not the global `Function` type.
  function: Function;
  // Only "function"-style tool calls are modeled.
  type: "function";
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment