From a3b44093c2554d5a12c35f4eca17247881af4e02 Mon Sep 17 00:00:00 2001 From: Marcus Schiesser <mail@marcusschiesser.de> Date: Wed, 10 Apr 2024 21:38:54 +0800 Subject: [PATCH] fix: agent streaming with new OpenAI models (#706) Co-authored-by: Alex Yang <himself65@outlook.com> --- .changeset/nasty-buttons-confess.md | 5 + .github/workflows/test.yml | 13 +- packages/core/e2e/node/openai.e2e.ts | 40 +++++- .../core/e2e/node/snapshot/gpt-4-turbo.snap | 135 ++++++++++++++++++ packages/core/e2e/node/utils.ts | 23 +-- packages/core/src/agent/openai/worker.ts | 93 +++++++----- packages/core/src/llm/open_ai.ts | 9 +- packages/env/package.json | 6 +- packages/env/src/index.polyfill.ts | 4 +- packages/env/src/index.ts | 3 +- pnpm-lock.yaml | 11 +- 11 files changed, 284 insertions(+), 58 deletions(-) create mode 100644 .changeset/nasty-buttons-confess.md create mode 100644 packages/core/e2e/node/snapshot/gpt-4-turbo.snap diff --git a/.changeset/nasty-buttons-confess.md b/.changeset/nasty-buttons-confess.md new file mode 100644 index 000000000..69797d01d --- /dev/null +++ b/.changeset/nasty-buttons-confess.md @@ -0,0 +1,5 @@ +--- +"llamaindex": patch +--- + +Fix agent streaming with new OpenAI models diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f8ff42934..5e86880a5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,6 +1,12 @@ name: Run Tests -on: [push, pull_request] +on: + push: + branches: + - main + pull_request: + branches: + - main concurrency: group: ${{ github.workflow }}-${{ github.ref }} @@ -8,6 +14,11 @@ concurrency: jobs: e2e: + strategy: + fail-fast: false + matrix: + node-version: [18.x, 20.x, 21.x] + name: E2E on Node.js ${{ matrix.node-version }} runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 diff --git a/packages/core/e2e/node/openai.e2e.ts b/packages/core/e2e/node/openai.e2e.ts index b5762d3ff..883ab0c5c 100644 --- a/packages/core/e2e/node/openai.e2e.ts +++ b/packages/core/e2e/node/openai.e2e.ts @@ -11,11 +11,11 @@ import { type LLM, } from "llamaindex"; import { ok } from "node:assert"; -import { before, test } from "node:test"; +import { beforeEach, test } from "node:test"; import { mockLLMEvent } from "./utils.js"; let llm: LLM; -before(async () => { +beforeEach(async () => { Settings.llm = new OpenAI({ model: "gpt-3.5-turbo", }); @@ -54,6 +54,41 @@ test("llm", async (t) => { }); }); +test("gpt-4-turbo", async (t) => { + const llm = new OpenAI({ model: "gpt-4-turbo" }); + Settings.llm = llm; + await mockLLMEvent(t, "gpt-4-turbo"); + await t.test("agent", async () => { + const agent = new OpenAIAgent({ + llm, + tools: [ + { + call: async () => { + return "45 degrees and sunny in San Jose"; + }, + metadata: { + name: "Weather", + description: "Get the weather", + parameters: { + type: "object", + properties: { + location: { type: "string" }, + }, + required: ["location"], + }, + }, + }, + ], + }); + const { response } = await agent.chat({ + message: "What is the weather in San Jose?", + }); + consola.debug("response:", response); + ok(typeof response === "string"); + ok(response.includes("45")); + }); +}); + test("agent", async (t) => { await mockLLMEvent(t, "agent"); await t.test("chat", async () => { @@ -82,6 +117,7 @@ test("agent", async (t) => { }); consola.debug("response:", result.response); ok(typeof result.response === "string"); + ok(result.response.includes("35")); }); }); diff --git a/packages/core/e2e/node/snapshot/gpt-4-turbo.snap b/packages/core/e2e/node/snapshot/gpt-4-turbo.snap new file mode 100644 
index 000000000..8a93f3d8d --- /dev/null +++ b/packages/core/e2e/node/snapshot/gpt-4-turbo.snap @@ -0,0 +1,135 @@ +{ + "llmEventStart": [ + { + "id": "3c5024e0-df1d-4a29-b491-9712324bd520", + "messages": [ + { + "content": "What is the weather in San Jose?", + "role": "user" + } + ] + }, + { + "id": "860b61c3-3c3a-4301-8200-9d6c0668cae5", + "messages": [ + { + "content": "What is the weather in San Jose?", + "role": "user" + }, + { + "content": "", + "role": "assistant", + "options": { + "toolCalls": [ + { + "id": "call_wlpohl1FXSCU9vV2CsjTPSWE", + "type": "function", + "function": { + "name": "Weather", + "arguments": "{\"location\":\"San Jose\"}" + } + } + ] + } + }, + { + "content": "45 degrees and sunny in San Jose", + "role": "tool", + "options": { + "name": "Weather", + "tool_call_id": "call_wlpohl1FXSCU9vV2CsjTPSWE" + } + } + ] + } + ], + "llmEventEnd": [ + { + "id": "3c5024e0-df1d-4a29-b491-9712324bd520", + "response": { + "raw": { + "id": "chatcmpl-9CQt20hfgKNlrbsbu47j40GzHzFUJ", + "object": "chat.completion", + "created": 1712750316, + "model": "gpt-3.5-turbo-0125", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "call_wlpohl1FXSCU9vV2CsjTPSWE", + "type": "function", + "function": { + "name": "Weather", + "arguments": "{\"location\":\"San Jose\"}" + } + } + ] + }, + "logprobs": null, + "finish_reason": "tool_calls" + } + ], + "usage": { + "prompt_tokens": 49, + "completion_tokens": 15, + "total_tokens": 64 + }, + "system_fingerprint": "fp_b28b39ffa8" + }, + "message": { + "content": "", + "role": "assistant", + "options": { + "toolCalls": [ + { + "id": "call_wlpohl1FXSCU9vV2CsjTPSWE", + "type": "function", + "function": { + "name": "Weather", + "arguments": "{\"location\":\"San Jose\"}" + } + } + ] + } + } + } + }, + { + "id": "860b61c3-3c3a-4301-8200-9d6c0668cae5", + "response": { + "raw": { + "id": "chatcmpl-9CQt2PPpt5qL8wl3lipBYJXLZXeQi", + "object": "chat.completion", + "created": 1712750316, + "model": "gpt-3.5-turbo-0125", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "The weather in San Jose is currently 45 degrees and sunny." 
+ }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 78, + "completion_tokens": 14, + "total_tokens": 92 + }, + "system_fingerprint": "fp_b28b39ffa8" + }, + "message": { + "content": "The weather in San Jose is currently 45 degrees and sunny.", + "role": "assistant", + "options": {} + } + } + } + ] +} \ No newline at end of file diff --git a/packages/core/e2e/node/utils.ts b/packages/core/e2e/node/utils.ts index d4272f69e..260e86e6b 100644 --- a/packages/core/e2e/node/utils.ts +++ b/packages/core/e2e/node/utils.ts @@ -35,15 +35,22 @@ export async function mockLLMEvent( await readFile(join(testRootDir, "snapshot", `${snapshotName}.snap`), { encoding: "utf-8", - }).then((data) => { - const result = JSON.parse(data) as MockStorage; - result["llmEventEnd"].forEach((event) => { - llmCompleteMockStorage.llmEventEnd.push(event); + }) + .then((data) => { + const result = JSON.parse(data) as MockStorage; + result["llmEventEnd"].forEach((event) => { + llmCompleteMockStorage.llmEventEnd.push(event); + }); + result["llmEventStart"].forEach((event) => { + llmCompleteMockStorage.llmEventStart.push(event); + }); + }) + .catch((error) => { + if (error.code === "ENOENT") { + console.warn("Snapshot file not found, will create a new one"); + return; + } }); - result["llmEventStart"].forEach((event) => { - llmCompleteMockStorage.llmEventStart.push(event); - }); - }); Settings.callbackManager.on("llm-start", captureLLMStart); Settings.callbackManager.on("llm-end", captureLLMEnd); diff --git a/packages/core/src/agent/openai/worker.ts b/packages/core/src/agent/openai/worker.ts index 404757e5b..8a229ffac 100644 --- a/packages/core/src/agent/openai/worker.ts +++ b/packages/core/src/agent/openai/worker.ts @@ -1,4 +1,4 @@ -import { randomUUID } from "@llamaindex/env"; +import { pipeline, randomUUID } from "@llamaindex/env"; import type { ChatCompletionToolChoiceOption } from "openai/resources/chat/completions"; import { Response } from "../../Response.js"; import { Settings } from "../../Settings.js"; @@ -14,12 +14,9 @@ import { type ChatResponseChunk, type LLMChatParamsBase, type OpenAIAdditionalChatOptions, + type OpenAIAdditionalMessageOptions, } from "../../llm/index.js"; -import { - extractText, - streamConverter, - streamReducer, -} from "../../llm/utils.js"; +import { extractText } from "../../llm/utils.js"; import { ChatMemoryBuffer } from "../../memory/ChatMemoryBuffer.js"; import type { ObjectRetriever } from "../../objects/base.js"; import type { ToolOutput } from "../../tools/types.js"; @@ -181,40 +178,70 @@ export class OpenAIAgentWorker stream: true, ...llmChatParams, }); - // read first chunk from stream to find out if we need to call tools - const iterator = stream[Symbol.asyncIterator](); - let { value } = await iterator.next(); - let content = value.delta; - const hasToolCalls = value.options?.toolCalls.length > 0; + + const responseChunkStream = new ReadableStream< + ChatResponseChunk<OpenAIAdditionalMessageOptions> + >({ + async start(controller) { + for await (const chunk of stream) { + controller.enqueue(chunk); + } + }, + }); + const [pipStream, finalStream] = responseChunkStream.tee(); + const { value } = await pipStream.getReader().read(); + if (value === undefined) { + throw new Error("first chunk value is undefined, this should not happen"); + } + // check if first chunk has tool calls, if so, this is a function call + // otherwise, it's a regular message + const hasToolCalls: boolean = + !!value.options?.toolCalls?.length && + 
value.options?.toolCalls?.length > 0; if (hasToolCalls) { - // consume stream until we have all the tool calls and return a non-streamed response - for await (value of stream) { - content += value.delta; - } return this._processMessage(task, { - content, + content: await pipeline(finalStream, async (iterator) => { + let content = ""; + for await (const value of iterator) { + content += value.delta; + } + return content; + }), role: "assistant", options: value.options, }); - } - - const newStream = streamConverter.bind(this)( - streamReducer({ - stream, - initialValue: content, - reducer: (accumulator, part) => (accumulator += part.delta), - finished: (accumulator) => { - task.extraState.newMemory.put({ - content: accumulator, - role: "assistant", - }); + } else { + let content = ""; + return pipeline( + finalStream.pipeThrough<Response>({ + readable: new ReadableStream({ + async start(controller) { + for await (const chunk of finalStream) { + controller.enqueue(new Response(chunk.delta)); + } + }, + }), + writable: new WritableStream({ + write(chunk) { + content += chunk.delta; + }, + close() { + task.extraState.newMemory.put({ + content, + role: "assistant", + }); + }, + }), + }), + async (iterator: AsyncIterable<Response>) => { + return new StreamingAgentChatResponse( + iterator, + task.extraState.sources, + ); }, - }), - (r: ChatResponseChunk) => new Response(r.delta), - ); - - return new StreamingAgentChatResponse(newStream, task.extraState.sources); + ); + } } private async _getAgentResponse( diff --git a/packages/core/src/llm/open_ai.ts b/packages/core/src/llm/open_ai.ts index 0c3b4fafb..6d6db4eb5 100644 --- a/packages/core/src/llm/open_ai.ts +++ b/packages/core/src/llm/open_ai.ts @@ -391,6 +391,8 @@ export class OpenAI extends BaseLLM< for await (const part of stream) { if (!part.choices.length) continue; const choice = part.choices[0]; + // skip parts that don't have any content + if (!(choice.delta.content || choice.delta.tool_calls)) continue; updateToolCalls(toolCalls, choice.delta.tool_calls); const isDone: boolean = choice.finish_reason !== null; @@ -444,8 +446,11 @@ function updateToolCalls( return toolCall; } if (toolCallDeltas) { - toolCallDeltas?.forEach((toolCall, i) => { - toolCalls[i] = augmentToolCall(toolCalls[i], toolCall); + toolCallDeltas?.forEach((toolCall) => { + toolCalls[toolCall.index] = augmentToolCall( + toolCalls[toolCall.index], + toolCall, + ); }); } } diff --git a/packages/env/package.json b/packages/env/package.json index 3b25dd73b..cfaca5b9a 100644 --- a/packages/env/package.json +++ b/packages/env/package.json @@ -56,8 +56,9 @@ "@aws-crypto/sha256-js": "^5.2.0", "@swc/cli": "^0.3.9", "@swc/core": "^1.4.2", + "concurrently": "^8.2.2", "pathe": "^1.1.2", - "concurrently": "^8.2.2" + "readable-stream": "^4.5.2" }, "dependencies": { "@types/lodash": "^4.14.202", @@ -66,6 +67,7 @@ }, "peerDependencies": { "@aws-crypto/sha256-js": "^5.2.0", - "pathe": "^1.1.2" + "pathe": "^1.1.2", + "readable-stream": "^4.5.2" } } diff --git a/packages/env/src/index.polyfill.ts b/packages/env/src/index.polyfill.ts index 0e2d49f86..f911741c9 100644 --- a/packages/env/src/index.polyfill.ts +++ b/packages/env/src/index.polyfill.ts @@ -1,8 +1,10 @@ import { Sha256 } from "@aws-crypto/sha256-js"; import pathe from "pathe"; import { InMemoryFileSystem, type CompleteFileSystem } from "./type.js"; +// @ts-expect-error +import { pipeline } from "readable-stream"; -export { pathe as path }; +export { pathe as path, pipeline }; export interface SHA256 { update(data: string | 
Uint8Array): void; diff --git a/packages/env/src/index.ts b/packages/env/src/index.ts index 2efc3833b..72333d095 100644 --- a/packages/env/src/index.ts +++ b/packages/env/src/index.ts @@ -3,6 +3,7 @@ import { createHash, randomUUID } from "node:crypto"; import fs from "node:fs/promises"; import { EOL } from "node:os"; import path from "node:path"; +import { pipeline } from "node:stream/promises"; import type { SHA256 } from "./index.polyfill.js"; import type { CompleteFileSystem } from "./type.js"; @@ -36,4 +37,4 @@ export const defaultFS: CompleteFileSystem = { export type * from "./type.js"; export { AsyncLocalStorage, CustomEvent, getEnv } from "./utils.js"; -export { EOL, ok, path, randomUUID }; +export { EOL, ok, path, pipeline, randomUUID }; diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 7597ea14b..d64e6b9d4 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -501,6 +501,9 @@ importers: pathe: specifier: ^1.1.2 version: 1.1.2 + readable-stream: + specifier: ^4.5.2 + version: 4.5.2 packages/eslint-config-custom: dependencies: @@ -5424,7 +5427,6 @@ packages: engines: {node: '>=6.5'} dependencies: event-target-shim: 5.0.1 - dev: false /accepts@1.3.8: resolution: {integrity: sha512-PYAthTa2m2VKxuvSD3DPC/Gy+U+sOA1LAuT8mkmRuvw+NACSaeXEQ+NHcVF7rONl6qcaxV3Uuemwawk+7+SJLw==} @@ -6104,8 +6106,6 @@ packages: dependencies: base64-js: 1.5.1 ieee754: 1.2.1 - dev: false - optional: true /busboy@1.6.0: resolution: {integrity: sha512-8SFQbg/0hQ9xy3UNTB0YEnsNBbWfhf7RtnzpL7TkBiTBRfrQ9Fxcnz7VJsleJpyp6rVLvXiuORqjlHi5q+PYuA==} @@ -8271,7 +8271,6 @@ packages: /event-target-shim@5.0.1: resolution: {integrity: sha512-i/2XbnSz/uxRCU6+NdVJgKWDTM427+MqYbkQzD321DuCQJUqOuJKIA0IM2+W2xtYHdKOmZ4dR6fExsd4SXL+WQ==} engines: {node: '>=6'} - dev: false /eventemitter3@4.0.7: resolution: {integrity: sha512-8guHBZCwKnFhYdHr2ysuRWErTwhoN2X8XELRlrRwpmfeY2jjuUN4taQMsULKUVo1K4DvZl+0pgfyoysHxvmvEw==} @@ -12923,8 +12922,6 @@ packages: resolution: {integrity: sha512-cdGef/drWFoydD1JsMzuFf8100nZl+GT+yacc2bEced5f9Rjk4z+WtFUTBu9PhOi9j/jfmBPu0mMEY4wIdAF8A==} engines: {node: '>= 0.6.0'} requiresBuild: true - dev: false - optional: true /prompts@2.4.2: resolution: {integrity: sha512-NxNv/kLguCA7p3jE8oL2aEBsrJWgAakBpgmgK6lpPWV+WuOmY6r2/zbAVnP+T8bQlA0nzHXSJSJW0Hq7ylaD2Q==} @@ -13378,8 +13375,6 @@ packages: events: 3.3.0 process: 0.11.10 string_decoder: 1.3.0 - dev: false - optional: true /readable-web-to-node-stream@3.0.2: resolution: {integrity: sha512-ePeK6cc1EcKLEhJFt/AebMCLL+GgSKhuygrZ/GLaKZYEecIgIECf4UaUuaByiGtzckwR4ain9VzUh95T1exYGw==} -- GitLab
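
Reviewer note: a minimal, self-contained TypeScript sketch of the stream handling this patch introduces in packages/core/src/agent/openai/worker.ts. The chunk stream is wrapped in a web ReadableStream and teed so the first chunk can be inspected for tool calls while an untouched copy is kept for the caller; both tee branches replay the first chunk. The ChatChunk shape and peekFirstChunk helper below are illustrative stand-ins, not library API.

    // Sketch only: ChatChunk stands in for the library's ChatResponseChunk.
    type ChatChunk = {
      delta: string;
      options?: { toolCalls?: unknown[] };
    };

    async function peekFirstChunk(
      source: AsyncIterable<ChatChunk>,
    ): Promise<{ first: ChatChunk; rest: ReadableStream<ChatChunk> }> {
      // Wrap the async iterable in a web ReadableStream so it can be teed.
      const stream = new ReadableStream<ChatChunk>({
        async start(controller) {
          for await (const chunk of source) controller.enqueue(chunk);
          controller.close(); // close so consumers of the tee branches can finish
        },
      });
      const [peek, rest] = stream.tee();
      const { value, done } = await peek.getReader().read();
      if (done || value === undefined) {
        throw new Error("stream produced no chunks");
      }
      // tee() yields two independent copies, so `rest` still replays the
      // first chunk; nothing is lost by peeking on the other branch.
      return { first: value, rest };
    }

    // Usage: decide from the first chunk whether the model is calling a tool,
    // mirroring the hasToolCalls check in the patch.
    // const { first, rest } = await peekFirstChunk(llmStream);
    // const hasToolCalls = !!first.options?.toolCalls?.length;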
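
Reviewer note: the packages/core/src/llm/open_ai.ts hunk is the substantive compatibility fix. Newer OpenAI models stream tool-call fragments with an explicit index field, and a single chunk may carry fragments for only some of the in-flight calls, so keying the accumulator by the fragment's position in the deltas array (the old forEach((toolCall, i) => ...)) can merge fragments into the wrong call. A reduced sketch with simplified stand-ins for the SDK's delta types; the real patch routes through its existing augmentToolCall helper, which this sketch inlines:

    // Simplified stand-ins for the OpenAI SDK's streaming tool-call types.
    interface ToolCallDelta {
      index: number; // identifies which tool call this fragment extends
      id?: string;
      function?: { name?: string; arguments?: string };
    }
    interface ToolCall {
      id: string;
      function: { name: string; arguments: string };
    }

    function updateToolCalls(toolCalls: ToolCall[], deltas?: ToolCallDelta[]) {
      for (const delta of deltas ?? []) {
        // Key by delta.index: a later chunk may contain a fragment for call 1
        // only, so its array position (0) would wrongly extend call 0.
        const call = (toolCalls[delta.index] ??= {
          id: "",
          function: { name: "", arguments: "" },
        });
        call.id += delta.id ?? "";
        call.function.name += delta.function?.name ?? "";
        call.function.arguments += delta.function?.arguments ?? "";
      }
    }

For example, a model that issues two parallel tool calls interleaves their argument fragments across chunks; accumulating by delta.index keeps each call's JSON arguments string intact, which is what the snapshot test above exercises.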