From a3b44093c2554d5a12c35f4eca17247881af4e02 Mon Sep 17 00:00:00 2001
From: Marcus Schiesser <mail@marcusschiesser.de>
Date: Wed, 10 Apr 2024 21:38:54 +0800
Subject: [PATCH] fix: agent streaming with new OpenAI models (#706)
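
Newer OpenAI models stream tool-call deltas with an explicit `index`
field and can emit chunks whose delta carries neither content nor tool
calls. This patch makes the agent worker robust to both: the response
stream is teed so the first chunk can be inspected for tool calls
without consuming the rest, chunks with empty deltas are skipped,
tool-call deltas are accumulated under their reported index instead of
their position inside the chunk, and stream post-processing goes
through a `pipeline` helper newly exported from `@llamaindex/env`
(node:stream/promises on Node.js, readable-stream in the polyfill
build). The e2e suite gains a gpt-4-turbo agent test with a recorded
snapshot, and CI now runs the tests on a Node.js 18/20/21 matrix.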

Co-authored-by: Alex Yang <himself65@outlook.com>
---
 .changeset/nasty-buttons-confess.md           |   5 +
 .github/workflows/test.yml                    |  13 +-
 packages/core/e2e/node/openai.e2e.ts          |  40 +++++-
 .../core/e2e/node/snapshot/gpt-4-turbo.snap   | 135 ++++++++++++++++++
 packages/core/e2e/node/utils.ts               |  23 +--
 packages/core/src/agent/openai/worker.ts      |  93 +++++++-----
 packages/core/src/llm/open_ai.ts              |   9 +-
 packages/env/package.json                     |   6 +-
 packages/env/src/index.polyfill.ts            |   4 +-
 packages/env/src/index.ts                     |   3 +-
 pnpm-lock.yaml                                |  11 +-
 11 files changed, 284 insertions(+), 58 deletions(-)
 create mode 100644 .changeset/nasty-buttons-confess.md
 create mode 100644 packages/core/e2e/node/snapshot/gpt-4-turbo.snap

diff --git a/.changeset/nasty-buttons-confess.md b/.changeset/nasty-buttons-confess.md
new file mode 100644
index 000000000..69797d01d
--- /dev/null
+++ b/.changeset/nasty-buttons-confess.md
@@ -0,0 +1,5 @@
+---
+"llamaindex": patch
+---
+
+Fix agent streaming with new OpenAI models
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index f8ff42934..5e86880a5 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -1,6 +1,12 @@
 name: Run Tests
 
-on: [push, pull_request]
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+    branches:
+      - main
 
 concurrency:
   group: ${{ github.workflow }}-${{ github.ref }}
@@ -8,6 +14,11 @@ concurrency:
 
 jobs:
   e2e:
+    strategy:
+      fail-fast: false
+      matrix:
+        node-version: [18.x, 20.x, 21.x]
+    name: E2E on Node.js ${{ matrix.node-version }}
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v4
diff --git a/packages/core/e2e/node/openai.e2e.ts b/packages/core/e2e/node/openai.e2e.ts
index b5762d3ff..883ab0c5c 100644
--- a/packages/core/e2e/node/openai.e2e.ts
+++ b/packages/core/e2e/node/openai.e2e.ts
@@ -11,11 +11,11 @@ import {
   type LLM,
 } from "llamaindex";
 import { ok } from "node:assert";
-import { before, test } from "node:test";
+import { beforeEach, test } from "node:test";
 import { mockLLMEvent } from "./utils.js";
 
 let llm: LLM;
-before(async () => {
+beforeEach(async () => {
   Settings.llm = new OpenAI({
     model: "gpt-3.5-turbo",
   });
@@ -54,6 +54,41 @@ test("llm", async (t) => {
   });
 });
 
+test("gpt-4-turbo", async (t) => {
+  const llm = new OpenAI({ model: "gpt-4-turbo" });
+  Settings.llm = llm;
+  await mockLLMEvent(t, "gpt-4-turbo");
+  await t.test("agent", async () => {
+    const agent = new OpenAIAgent({
+      llm,
+      tools: [
+        {
+          call: async () => {
+            return "45 degrees and sunny in San Jose";
+          },
+          metadata: {
+            name: "Weather",
+            description: "Get the weather",
+            parameters: {
+              type: "object",
+              properties: {
+                location: { type: "string" },
+              },
+              required: ["location"],
+            },
+          },
+        },
+      ],
+    });
+    const { response } = await agent.chat({
+      message: "What is the weather in San Jose?",
+    });
+    consola.debug("response:", response);
+    ok(typeof response === "string");
+    ok(response.includes("45"));
+  });
+});
+
 test("agent", async (t) => {
   await mockLLMEvent(t, "agent");
   await t.test("chat", async () => {
@@ -82,6 +117,7 @@ test("agent", async (t) => {
     });
     consola.debug("response:", result.response);
     ok(typeof result.response === "string");
+    ok(result.response.includes("35"));
   });
 });
 
diff --git a/packages/core/e2e/node/snapshot/gpt-4-turbo.snap b/packages/core/e2e/node/snapshot/gpt-4-turbo.snap
new file mode 100644
index 000000000..8a93f3d8d
--- /dev/null
+++ b/packages/core/e2e/node/snapshot/gpt-4-turbo.snap
@@ -0,0 +1,135 @@
+{
+  "llmEventStart": [
+    {
+      "id": "3c5024e0-df1d-4a29-b491-9712324bd520",
+      "messages": [
+        {
+          "content": "What is the weather in San Jose?",
+          "role": "user"
+        }
+      ]
+    },
+    {
+      "id": "860b61c3-3c3a-4301-8200-9d6c0668cae5",
+      "messages": [
+        {
+          "content": "What is the weather in San Jose?",
+          "role": "user"
+        },
+        {
+          "content": "",
+          "role": "assistant",
+          "options": {
+            "toolCalls": [
+              {
+                "id": "call_wlpohl1FXSCU9vV2CsjTPSWE",
+                "type": "function",
+                "function": {
+                  "name": "Weather",
+                  "arguments": "{\"location\":\"San Jose\"}"
+                }
+              }
+            ]
+          }
+        },
+        {
+          "content": "45 degrees and sunny in San Jose",
+          "role": "tool",
+          "options": {
+            "name": "Weather",
+            "tool_call_id": "call_wlpohl1FXSCU9vV2CsjTPSWE"
+          }
+        }
+      ]
+    }
+  ],
+  "llmEventEnd": [
+    {
+      "id": "3c5024e0-df1d-4a29-b491-9712324bd520",
+      "response": {
+        "raw": {
+          "id": "chatcmpl-9CQt20hfgKNlrbsbu47j40GzHzFUJ",
+          "object": "chat.completion",
+          "created": 1712750316,
+          "model": "gpt-3.5-turbo-0125",
+          "choices": [
+            {
+              "index": 0,
+              "message": {
+                "role": "assistant",
+                "content": null,
+                "tool_calls": [
+                  {
+                    "id": "call_wlpohl1FXSCU9vV2CsjTPSWE",
+                    "type": "function",
+                    "function": {
+                      "name": "Weather",
+                      "arguments": "{\"location\":\"San Jose\"}"
+                    }
+                  }
+                ]
+              },
+              "logprobs": null,
+              "finish_reason": "tool_calls"
+            }
+          ],
+          "usage": {
+            "prompt_tokens": 49,
+            "completion_tokens": 15,
+            "total_tokens": 64
+          },
+          "system_fingerprint": "fp_b28b39ffa8"
+        },
+        "message": {
+          "content": "",
+          "role": "assistant",
+          "options": {
+            "toolCalls": [
+              {
+                "id": "call_wlpohl1FXSCU9vV2CsjTPSWE",
+                "type": "function",
+                "function": {
+                  "name": "Weather",
+                  "arguments": "{\"location\":\"San Jose\"}"
+                }
+              }
+            ]
+          }
+        }
+      }
+    },
+    {
+      "id": "860b61c3-3c3a-4301-8200-9d6c0668cae5",
+      "response": {
+        "raw": {
+          "id": "chatcmpl-9CQt2PPpt5qL8wl3lipBYJXLZXeQi",
+          "object": "chat.completion",
+          "created": 1712750316,
+          "model": "gpt-3.5-turbo-0125",
+          "choices": [
+            {
+              "index": 0,
+              "message": {
+                "role": "assistant",
+                "content": "The weather in San Jose is currently 45 degrees and sunny."
+              },
+              "logprobs": null,
+              "finish_reason": "stop"
+            }
+          ],
+          "usage": {
+            "prompt_tokens": 78,
+            "completion_tokens": 14,
+            "total_tokens": 92
+          },
+          "system_fingerprint": "fp_b28b39ffa8"
+        },
+        "message": {
+          "content": "The weather in San Jose is currently 45 degrees and sunny.",
+          "role": "assistant",
+          "options": {}
+        }
+      }
+    }
+  ]
+}
\ No newline at end of file
diff --git a/packages/core/e2e/node/utils.ts b/packages/core/e2e/node/utils.ts
index d4272f69e..260e86e6b 100644
--- a/packages/core/e2e/node/utils.ts
+++ b/packages/core/e2e/node/utils.ts
@@ -35,15 +35,22 @@ export async function mockLLMEvent(
 
   await readFile(join(testRootDir, "snapshot", `${snapshotName}.snap`), {
     encoding: "utf-8",
-  }).then((data) => {
-    const result = JSON.parse(data) as MockStorage;
-    result["llmEventEnd"].forEach((event) => {
-      llmCompleteMockStorage.llmEventEnd.push(event);
+  })
+    .then((data) => {
+      const result = JSON.parse(data) as MockStorage;
+      result["llmEventEnd"].forEach((event) => {
+        llmCompleteMockStorage.llmEventEnd.push(event);
+      });
+      result["llmEventStart"].forEach((event) => {
+        llmCompleteMockStorage.llmEventStart.push(event);
+      });
+    })
+    .catch((error) => {
+      if (error.code === "ENOENT") {
+        console.warn("Snapshot file not found, will create a new one");
+        return;
+      }
     });
-    result["llmEventStart"].forEach((event) => {
-      llmCompleteMockStorage.llmEventStart.push(event);
-    });
-  });
   Settings.callbackManager.on("llm-start", captureLLMStart);
   Settings.callbackManager.on("llm-end", captureLLMEnd);
 
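
The `.catch` added above turns `mockLLMEvent` into a record-or-replay
fixture: when a snapshot exists, its captured events are preloaded and
replayed; when it does not, the test runs against the live API so a
fresh snapshot can be recorded. A minimal sketch of that loading
pattern, with simplified types rather than the helper's actual ones:

    import { readFile } from "node:fs/promises";

    type MockStorage = {
      llmEventStart: unknown[];
      llmEventEnd: unknown[];
    };

    // Load a snapshot if present; a missing file signals record mode.
    async function loadSnapshot(path: string): Promise<MockStorage | null> {
      try {
        return JSON.parse(await readFile(path, { encoding: "utf-8" }));
      } catch (error) {
        if ((error as NodeJS.ErrnoException).code === "ENOENT") {
          console.warn("Snapshot file not found, will create a new one");
          return null; // record mode: capture live events instead
        }
        throw error; // a failure other than a missing file is a real error
      }
    }

Rethrowing non-ENOENT errors is this sketch's choice; the helper itself
only special-cases the missing-file path.
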
diff --git a/packages/core/src/agent/openai/worker.ts b/packages/core/src/agent/openai/worker.ts
index 404757e5b..8a229ffac 100644
--- a/packages/core/src/agent/openai/worker.ts
+++ b/packages/core/src/agent/openai/worker.ts
@@ -1,4 +1,4 @@
-import { randomUUID } from "@llamaindex/env";
+import { pipeline, randomUUID } from "@llamaindex/env";
 import type { ChatCompletionToolChoiceOption } from "openai/resources/chat/completions";
 import { Response } from "../../Response.js";
 import { Settings } from "../../Settings.js";
@@ -14,12 +14,9 @@ import {
   type ChatResponseChunk,
   type LLMChatParamsBase,
   type OpenAIAdditionalChatOptions,
+  type OpenAIAdditionalMessageOptions,
 } from "../../llm/index.js";
-import {
-  extractText,
-  streamConverter,
-  streamReducer,
-} from "../../llm/utils.js";
+import { extractText } from "../../llm/utils.js";
 import { ChatMemoryBuffer } from "../../memory/ChatMemoryBuffer.js";
 import type { ObjectRetriever } from "../../objects/base.js";
 import type { ToolOutput } from "../../tools/types.js";
@@ -181,40 +178,70 @@ export class OpenAIAgentWorker
       stream: true,
       ...llmChatParams,
     });
-    // read first chunk from stream to find out if we need to call tools
-    const iterator = stream[Symbol.asyncIterator]();
-    let { value } = await iterator.next();
-    let content = value.delta;
-    const hasToolCalls = value.options?.toolCalls.length > 0;
+
+    const responseChunkStream = new ReadableStream<
+      ChatResponseChunk<OpenAIAdditionalMessageOptions>
+    >({
+      async start(controller) {
+        for await (const chunk of stream) {
+          controller.enqueue(chunk);
+        }
+      },
+    });
+    const [pipStream, finalStream] = responseChunkStream.tee();
+    const { value } = await pipStream.getReader().read();
+    if (value === undefined) {
+      throw new Error("first chunk value is undefined, this should not happen");
+    }
+    // check if first chunk has tool calls, if so, this is a function call
+    // otherwise, it's a regular message
+    const hasToolCalls: boolean =
+      !!value.options?.toolCalls?.length &&
+      value.options?.toolCalls?.length > 0;
 
     if (hasToolCalls) {
-      // consume stream until we have all the tool calls and return a non-streamed response
-      for await (value of stream) {
-        content += value.delta;
-      }
       return this._processMessage(task, {
-        content,
+        content: await pipeline(finalStream, async (iterator) => {
+          let content = "";
+          for await (const value of iterator) {
+            content += value.delta;
+          }
+          return content;
+        }),
         role: "assistant",
         options: value.options,
       });
-    }
-
-    const newStream = streamConverter.bind(this)(
-      streamReducer({
-        stream,
-        initialValue: content,
-        reducer: (accumulator, part) => (accumulator += part.delta),
-        finished: (accumulator) => {
-          task.extraState.newMemory.put({
-            content: accumulator,
-            role: "assistant",
-          });
+    } else {
+      let content = "";
+      return pipeline(
+        finalStream.pipeThrough<Response>({
+          readable: new ReadableStream({
+            async start(controller) {
+              for await (const chunk of finalStream) {
+                controller.enqueue(new Response(chunk.delta));
+              }
+            },
+          }),
+          writable: new WritableStream({
+            write(chunk) {
+              content += chunk.delta;
+            },
+            close() {
+              task.extraState.newMemory.put({
+                content,
+                role: "assistant",
+              });
+            },
+          }),
+        }),
+        async (iterator: AsyncIterable<Response>) => {
+          return new StreamingAgentChatResponse(
+            iterator,
+            task.extraState.sources,
+          );
         },
-      }),
-      (r: ChatResponseChunk) => new Response(r.delta),
-    );
-
-    return new StreamingAgentChatResponse(newStream, task.extraState.sources);
+      );
+    }
   }
 
   private async _getAgentResponse(
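
The key move above is tee-and-peek: wrap the LLM's chunk iterator in a
ReadableStream, tee it, read a single chunk from one branch to decide
between the tool-call path and the plain streaming path, and hand the
untouched second branch to whichever path wins. A self-contained sketch
of the pattern, using a hypothetical `Chunk` type in place of the
worker's real chunk and options types:

    type Chunk = { delta: string; toolCalls?: { name: string }[] };

    async function peekForToolCalls(
      chunks: AsyncIterable<Chunk>,
    ): Promise<{ hasToolCalls: boolean; rest: ReadableStream<Chunk> }> {
      // Wrap the async iterable so it can be teed.
      const source = new ReadableStream<Chunk>({
        async start(controller) {
          for await (const chunk of chunks) controller.enqueue(chunk);
          controller.close();
        },
      });
      // tee() yields two independent branches over the same chunks, so
      // peeking at the first chunk on one branch loses nothing: the
      // other branch still sees every chunk, including the peeked one.
      const [peekBranch, rest] = source.tee();
      const { value } = await peekBranch.getReader().read();
      return { hasToolCalls: !!value?.toolCalls?.length, rest };
    }

Since tee buffers chunks for the slower branch, a long-lived caller
should eventually cancel or drain the peek branch; the sketch omits
that housekeeping.
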
diff --git a/packages/core/src/llm/open_ai.ts b/packages/core/src/llm/open_ai.ts
index 0c3b4fafb..6d6db4eb5 100644
--- a/packages/core/src/llm/open_ai.ts
+++ b/packages/core/src/llm/open_ai.ts
@@ -391,6 +391,8 @@ export class OpenAI extends BaseLLM<
     for await (const part of stream) {
       if (!part.choices.length) continue;
       const choice = part.choices[0];
+      // skip parts that don't have any content
+      if (!(choice.delta.content || choice.delta.tool_calls)) continue;
       updateToolCalls(toolCalls, choice.delta.tool_calls);
 
       const isDone: boolean = choice.finish_reason !== null;
@@ -444,8 +446,11 @@ function updateToolCalls(
     return toolCall;
   }
   if (toolCallDeltas) {
-    toolCallDeltas?.forEach((toolCall, i) => {
-      toolCalls[i] = augmentToolCall(toolCalls[i], toolCall);
+    toolCallDeltas?.forEach((toolCall) => {
+      toolCalls[toolCall.index] = augmentToolCall(
+        toolCalls[toolCall.index],
+        toolCall,
+      );
     });
   }
 }
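
The `updateToolCalls` change is the heart of the fix: with parallel
tool calls, a streamed chunk may carry a delta for, say, call 1 and
none for call 0, so deltas must be keyed by their declared `index`
rather than by their position inside the chunk. A sketch of index-based
accumulation, with simplified stand-in types rather than the library's:

    // Shape follows OpenAI's streaming tool-call deltas: every delta
    // names the index of the call it extends.
    type ToolCallDelta = {
      index: number;
      id?: string;
      function?: { name?: string; arguments?: string };
    };

    type PartialToolCall = { id: string; name: string; arguments: string };

    function accumulateToolCalls(
      calls: PartialToolCall[],
      deltas: ToolCallDelta[] | undefined,
    ): void {
      for (const delta of deltas ?? []) {
        // Key by delta.index, never by the delta's position in this chunk.
        const current =
          calls[delta.index] ?? { id: "", name: "", arguments: "" };
        current.id += delta.id ?? "";
        current.name += delta.function?.name ?? "";
        current.arguments += delta.function?.arguments ?? "";
        calls[delta.index] = current;
      }
    }
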
diff --git a/packages/env/package.json b/packages/env/package.json
index 3b25dd73b..cfaca5b9a 100644
--- a/packages/env/package.json
+++ b/packages/env/package.json
@@ -56,8 +56,9 @@
     "@aws-crypto/sha256-js": "^5.2.0",
     "@swc/cli": "^0.3.9",
     "@swc/core": "^1.4.2",
+    "concurrently": "^8.2.2",
     "pathe": "^1.1.2",
-    "concurrently": "^8.2.2"
+    "readable-stream": "^4.5.2"
   },
   "dependencies": {
     "@types/lodash": "^4.14.202",
@@ -66,6 +67,7 @@
   },
   "peerDependencies": {
     "@aws-crypto/sha256-js": "^5.2.0",
-    "pathe": "^1.1.2"
+    "pathe": "^1.1.2",
+    "readable-stream": "^4.5.2"
   }
 }
diff --git a/packages/env/src/index.polyfill.ts b/packages/env/src/index.polyfill.ts
index 0e2d49f86..f911741c9 100644
--- a/packages/env/src/index.polyfill.ts
+++ b/packages/env/src/index.polyfill.ts
@@ -1,8 +1,10 @@
 import { Sha256 } from "@aws-crypto/sha256-js";
 import pathe from "pathe";
 import { InMemoryFileSystem, type CompleteFileSystem } from "./type.js";
+// @ts-expect-error
+import { pipeline } from "readable-stream";
 
-export { pathe as path };
+export { pathe as path, pipeline };
 
 export interface SHA256 {
   update(data: string | Uint8Array): void;
diff --git a/packages/env/src/index.ts b/packages/env/src/index.ts
index 2efc3833b..72333d095 100644
--- a/packages/env/src/index.ts
+++ b/packages/env/src/index.ts
@@ -3,6 +3,7 @@ import { createHash, randomUUID } from "node:crypto";
 import fs from "node:fs/promises";
 import { EOL } from "node:os";
 import path from "node:path";
+import { pipeline } from "node:stream/promises";
 import type { SHA256 } from "./index.polyfill.js";
 import type { CompleteFileSystem } from "./type.js";
 
@@ -36,4 +37,4 @@ export const defaultFS: CompleteFileSystem = {
 
 export type * from "./type.js";
 export { AsyncLocalStorage, CustomEvent, getEnv } from "./utils.js";
-export { EOL, ok, path, randomUUID };
+export { EOL, ok, path, pipeline, randomUUID };
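
Re-exporting `pipeline` from `@llamaindex/env` gives the agent worker
one helper across runtimes: node:stream/promises on Node.js, and
readable-stream's implementation in the polyfill build. On Node.js,
when the destination is an async function the returned promise resolves
to that function's return value, which is what the worker relies on. A
sketch of that call shape, with a hypothetical chunk type:

    import { pipeline } from "@llamaindex/env";

    // Drain a chunk stream and resolve with the concatenated text.
    async function concatDeltas(
      chunks: AsyncIterable<{ delta: string }>,
    ): Promise<string> {
      return pipeline(
        chunks,
        async (source: AsyncIterable<{ delta: string }>) => {
          let text = "";
          for await (const chunk of source) text += chunk.delta;
          return text; // pipeline resolves with this value
        },
      );
    }
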
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index 7597ea14b..d64e6b9d4 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -501,6 +501,9 @@ importers:
       pathe:
         specifier: ^1.1.2
         version: 1.1.2
+      readable-stream:
+        specifier: ^4.5.2
+        version: 4.5.2
 
   packages/eslint-config-custom:
     dependencies:
@@ -5424,7 +5427,6 @@ packages:
     engines: {node: '>=6.5'}
     dependencies:
       event-target-shim: 5.0.1
-    dev: false
 
   /accepts@1.3.8:
     resolution: {integrity: sha512-PYAthTa2m2VKxuvSD3DPC/Gy+U+sOA1LAuT8mkmRuvw+NACSaeXEQ+NHcVF7rONl6qcaxV3Uuemwawk+7+SJLw==}
@@ -6104,8 +6106,6 @@ packages:
     dependencies:
       base64-js: 1.5.1
       ieee754: 1.2.1
-    dev: false
-    optional: true
 
   /busboy@1.6.0:
     resolution: {integrity: sha512-8SFQbg/0hQ9xy3UNTB0YEnsNBbWfhf7RtnzpL7TkBiTBRfrQ9Fxcnz7VJsleJpyp6rVLvXiuORqjlHi5q+PYuA==}
@@ -8271,7 +8271,6 @@ packages:
   /event-target-shim@5.0.1:
     resolution: {integrity: sha512-i/2XbnSz/uxRCU6+NdVJgKWDTM427+MqYbkQzD321DuCQJUqOuJKIA0IM2+W2xtYHdKOmZ4dR6fExsd4SXL+WQ==}
     engines: {node: '>=6'}
-    dev: false
 
   /eventemitter3@4.0.7:
     resolution: {integrity: sha512-8guHBZCwKnFhYdHr2ysuRWErTwhoN2X8XELRlrRwpmfeY2jjuUN4taQMsULKUVo1K4DvZl+0pgfyoysHxvmvEw==}
@@ -12923,8 +12922,6 @@ packages:
     resolution: {integrity: sha512-cdGef/drWFoydD1JsMzuFf8100nZl+GT+yacc2bEced5f9Rjk4z+WtFUTBu9PhOi9j/jfmBPu0mMEY4wIdAF8A==}
     engines: {node: '>= 0.6.0'}
     requiresBuild: true
-    dev: false
-    optional: true
 
   /prompts@2.4.2:
     resolution: {integrity: sha512-NxNv/kLguCA7p3jE8oL2aEBsrJWgAakBpgmgK6lpPWV+WuOmY6r2/zbAVnP+T8bQlA0nzHXSJSJW0Hq7ylaD2Q==}
@@ -13378,8 +13375,6 @@ packages:
       events: 3.3.0
       process: 0.11.10
       string_decoder: 1.3.0
-    dev: false
-    optional: true
 
   /readable-web-to-node-stream@3.0.2:
     resolution: {integrity: sha512-ePeK6cc1EcKLEhJFt/AebMCLL+GgSKhuygrZ/GLaKZYEecIgIECf4UaUuaByiGtzckwR4ain9VzUh95T1exYGw==}
-- 