From 61103b677bd3b29c634b549aa7847a120a77e3e8 Mon Sep 17 00:00:00 2001
From: Alex Yang <himself65@outlook.com>
Date: Wed, 1 May 2024 19:26:06 -0500
Subject: [PATCH] fix: streaming for `Agent.createTask` (#788)

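Previously, a `TaskHandler` returned exactly one `TaskStepOutput`, so a
step could not hand a streaming response to the consumer before its
follow-up work (such as tool calls) had finished. The handler now
receives an `enqueueOutput` callback and resolves to `void`: outputs,
including `ReadableStream` responses, are enqueued as soon as they are
available, and `createTaskOutputStream` drives the steps as a
`ReadableStream<TaskStepOutput>`. `TaskStep.input` is removed; the user
message is pushed into `context.store.messages` when the task is
created, and handlers read from the store directly.

A custom handler now looks roughly like this sketch (`MyLLM`,
`shouldContinue`, and `nextUserMessage` are illustrative placeholders,
not real exports):

    const taskHandler: TaskHandler<MyLLM> = async (step, enqueueOutput) => {
      const { llm, stream } = step.context;
      const response = await llm.chat({
        stream,
        messages: step.context.store.messages,
      });
      // enqueue before any follow-up work so consumers can start
      // reading (or streaming) the output immediately
      enqueueOutput({
        taskStep: step,
        output: response,
        isLast: !shouldContinue,
      });
      if (shouldContinue) {
        // seed the next step by updating the task store
        step.context.store.messages = [
          ...step.context.store.messages,
          nextUserMessage,
        ];
      }
    };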
---
 .changeset/four-ears-nail.md         |   6 ++
 apps/docs/blog/2024-04-26-v0.3.0.md  |  26 ++-----
 examples/agent/openai-task.ts        |  87 +++++++++++++++++++++
 examples/agent/step_wise_openai.ts   |  97 -----------------------
 packages/core/src/agent/anthropic.ts |  43 ++++------
 packages/core/src/agent/base.ts      | 112 +++++++++++++--------------
 packages/core/src/agent/openai.ts    |  59 ++++++--------
 packages/core/src/agent/react.ts     |  55 ++++++-------
 packages/core/src/agent/types.ts     |  30 +++----
 9 files changed, 227 insertions(+), 288 deletions(-)
 create mode 100644 .changeset/four-ears-nail.md
 create mode 100644 examples/agent/openai-task.ts
 delete mode 100644 examples/agent/step_wise_openai.ts

diff --git a/.changeset/four-ears-nail.md b/.changeset/four-ears-nail.md
new file mode 100644
index 000000000..3825a41d9
--- /dev/null
+++ b/.changeset/four-ears-nail.md
@@ -0,0 +1,6 @@
+---
+"llamaindex": patch
+"@llamaindex/core-e2e": patch
+---
+
+fix: streaming for `Agent.createTask` API
diff --git a/apps/docs/blog/2024-04-26-v0.3.0.md b/apps/docs/blog/2024-04-26-v0.3.0.md
index 6db63f00f..a1be2a4cc 100644
--- a/apps/docs/blog/2024-04-26-v0.3.0.md
+++ b/apps/docs/blog/2024-04-26-v0.3.0.md
@@ -72,12 +72,8 @@ export class MyAgent extends AgentRunner<MyLLM> {
  // createStore creates a per-task store; by default it only includes `messages` and `toolOutputs`
   createStore = AgentRunner.defaultCreateStore;
 
-  static taskHandler: TaskHandler<Anthropic> = async (step) => {
-    const { input } = step;
+  static taskHandler: TaskHandler<Anthropic> = async (step, enqueueOutput) => {
     const { llm, stream } = step.context;
-    if (input) {
-      step.context.store.messages = [...step.context.store.messages, input];
-    }
     // initialize the input
     const response = await llm.chat({
       stream,
@@ -90,27 +86,21 @@ export class MyAgent extends AgentRunner<MyLLM> {
     ];
     // your logic here to decide whether to continue the task
     const shouldContinue = Math.random(); /* <-- replace with your logic here */
+    enqueueOutput({
+      taskStep: step,
+      output: response,
+      isLast: !shouldContinue,
+    });
     if (shouldContinue) {
+      const content = await someHeavyFunctionCall();
       // if you want to continue the task, you can insert your new context for the next task step
       step.context.store.messages = [
         ...step.context.store.messages,
         {
-          content: "INSERT MY NEW DATA",
+          content,
           role: "user",
         },
       ];
-      return {
-        taskStep: step,
-        output: response,
-        isLast: false,
-      };
-    } else {
-      // if you want to end the task, you can return the response with `isLast: true`
-      return {
-        taskStep: step,
-        output: response,
-        isLast: true,
-      };
     }
   };
 }
diff --git a/examples/agent/openai-task.ts b/examples/agent/openai-task.ts
new file mode 100644
index 000000000..09e632278
--- /dev/null
+++ b/examples/agent/openai-task.ts
@@ -0,0 +1,87 @@
+import { ChatResponseChunk, FunctionTool, OpenAIAgent } from "llamaindex";
+import { ReadableStream } from "node:stream/web";
+
+const functionTool = FunctionTool.from(
+  () => {
+    console.log("Getting user id...");
+    return crypto.randomUUID();
+  },
+  {
+    name: "get_user_id",
+    description: "Get a random user id",
+  },
+);
+
+const functionTool2 = FunctionTool.from(
+  ({ userId }: { userId: string }) => {
+    console.log("Getting user info...", userId);
+    return `Name: Alex; Address: 1234 Main St, CA; User ID: ${userId}`;
+  },
+  {
+    name: "get_user_info",
+    description: "Get user info",
+    parameters: {
+      type: "object",
+      properties: {
+        userId: {
+          type: "string",
+          description: "The user id",
+        },
+      },
+      required: ["userId"],
+    },
+  },
+);
+
+const functionTool3 = FunctionTool.from(
+  ({ address }: { address: string }) => {
+    console.log("Getting weather...", address);
+    return `${address} is in a sunny location!`;
+  },
+  {
+    name: "get_weather",
+    description: "Get the current weather for a location",
+    parameters: {
+      type: "object",
+      properties: {
+        address: {
+          type: "string",
+          description: "The address",
+        },
+      },
+      required: ["address"],
+    },
+  },
+);
+
+async function main() {
+  // Create an OpenAIAgent with the function tools
+  const agent = new OpenAIAgent({
+    tools: [functionTool, functionTool2, functionTool3],
+  });
+
+  const task = await agent.createTask(
+    "What is my current address weather based on my profile?",
+    true,
+  );
+
+  for await (const stepOutput of task) {
+    const stream = stepOutput.output as ReadableStream<ChatResponseChunk>;
+    if (stepOutput.isLast) {
+      for await (const chunk of stream) {
+        process.stdout.write(chunk.delta);
+      }
+      process.stdout.write("\n");
+    } else {
+      // handling function call
+      console.log("handling function call...");
+      for await (const chunk of stream) {
+        console.log("debug:", JSON.stringify(chunk.raw));
+      }
+    }
+  }
+}
+
+void main().then(() => {
+  console.log("Done");
+});
diff --git a/examples/agent/step_wise_openai.ts b/examples/agent/step_wise_openai.ts
deleted file mode 100644
index 74c083400..000000000
--- a/examples/agent/step_wise_openai.ts
+++ /dev/null
@@ -1,97 +0,0 @@
-import { FunctionTool, OpenAIAgent } from "llamaindex";
-import { ReadableStream } from "node:stream/web";
-
-// Define a function to sum two numbers
-function sumNumbers({ a, b }: { a: number; b: number }) {
-  return `${a + b}`;
-}
-
-// Define a function to divide two numbers
-function divideNumbers({ a, b }: { a: number; b: number }) {
-  return `${a / b}`;
-}
-
-// Define the parameters of the sum function as a JSON schema
-const sumJSON = {
-  type: "object",
-  properties: {
-    a: {
-      type: "number",
-      description: "The first number",
-    },
-    b: {
-      type: "number",
-      description: "The second number",
-    },
-  },
-  required: ["a", "b"],
-} as const;
-
-const divideJSON = {
-  type: "object",
-  properties: {
-    a: {
-      type: "number",
-      description: "The dividend",
-    },
-    b: {
-      type: "number",
-      description: "The divisor",
-    },
-  },
-  required: ["a", "b"],
-} as const;
-
-async function main() {
-  // Create a function tool from the sum function
-  const functionTool = new FunctionTool(sumNumbers, {
-    name: "sumNumbers",
-    description: "Use this function to sum two numbers",
-    parameters: sumJSON,
-  });
-
-  // Create a function tool from the divide function
-  const functionTool2 = new FunctionTool(divideNumbers, {
-    name: "divideNumbers",
-    description: "Use this function to divide two numbers",
-    parameters: divideJSON,
-  });
-
-  // Create an OpenAIAgent with the function tools
-  const agent = new OpenAIAgent({
-    tools: [functionTool, functionTool2],
-  });
-
-  // Create a task to sum and divide numbers
-  const task = await agent.createTask("How much is 5 + 5? then divide by 2");
-
-  let count = 0;
-
-  for await (const stepOutput of task) {
-    console.log(`Runnning step ${count++}`);
-    console.log(`======== OUTPUT ==========`);
-    const output = stepOutput.output;
-    if (output instanceof ReadableStream) {
-      for await (const chunk of output) {
-        process.stdout.write(chunk.delta);
-      }
-    } else {
-      console.log(output);
-    }
-    console.log(`==========================`);
-
-    if (stepOutput.isLast) {
-      if (stepOutput.output instanceof ReadableStream) {
-        for await (const chunk of stepOutput.output) {
-          process.stdout.write(chunk.delta);
-        }
-      } else {
-        console.log(stepOutput.output);
-      }
-    }
-  }
-}
-
-void main().then(() => {
-  console.log("Done");
-});
diff --git a/packages/core/src/agent/anthropic.ts b/packages/core/src/agent/anthropic.ts
index 231795d89..accc87fca 100644
--- a/packages/core/src/agent/anthropic.ts
+++ b/packages/core/src/agent/anthropic.ts
@@ -67,12 +67,8 @@ export class AnthropicAgent extends AgentRunner<Anthropic> {
     return super.chat(params);
   }
 
-  static taskHandler: TaskHandler<Anthropic> = async (step) => {
-    const { input } = step;
+  static taskHandler: TaskHandler<Anthropic> = async (step, enqueueOutput) => {
     const { llm, getTools, stream } = step.context;
-    if (input) {
-      step.context.store.messages = [...step.context.store.messages, input];
-    }
     const lastMessage = step.context.store.messages.at(-1)!.content;
     const tools = await getTools(lastMessage);
     if (stream === true) {
@@ -88,6 +84,11 @@ export class AnthropicAgent extends AgentRunner<Anthropic> {
       response.message,
     ];
     const options = response.message.options ?? {};
+    enqueueOutput({
+      taskStep: step,
+      output: response,
+      isLast: !("toolCall" in options),
+    });
     if ("toolCall" in options) {
       const { toolCall } = options;
       const targetTool = tools.find(
@@ -95,30 +96,20 @@ export class AnthropicAgent extends AgentRunner<Anthropic> {
       );
       const toolOutput = await callTool(targetTool, toolCall);
       step.context.store.toolOutputs.push(toolOutput);
-      return {
-        taskStep: step,
-        output: {
-          raw: response.raw,
-          message: {
-            content: stringifyJSONToMessageContent(toolOutput.output),
-            role: "user",
-            options: {
-              toolResult: {
-                result: toolOutput.output,
-                isError: toolOutput.isError,
-                id: toolCall.id,
-              },
+      step.context.store.messages = [
+        ...step.context.store.messages,
+        {
+          content: stringifyJSONToMessageContent(toolOutput.output),
+          role: "user",
+          options: {
+            toolResult: {
+              result: toolOutput.output,
+              isError: toolOutput.isError,
+              id: toolCall.id,
             },
           },
         },
-        isLast: false,
-      };
-    } else {
-      return {
-        taskStep: step,
-        output: response,
-        isLast: true,
-      };
+      ];
     }
   };
 }
diff --git a/packages/core/src/agent/base.ts b/packages/core/src/agent/base.ts
index 57d6d9e42..e0e14f4f5 100644
--- a/packages/core/src/agent/base.ts
+++ b/packages/core/src/agent/base.ts
@@ -27,14 +27,10 @@ import type {
   TaskStep,
   TaskStepOutput,
 } from "./types.js";
-import { consumeAsyncIterable } from "./utils.js";
 
 export const MAX_TOOL_CALLS = 10;
 
-/**
- * @internal
- */
-export async function* createTaskImpl<
+export function createTaskOutputStream<
   Model extends LLM,
   Store extends object = {},
   AdditionalMessageOptions extends object = Model extends LLM<
@@ -46,65 +42,60 @@ export async function* createTaskImpl<
 >(
   handler: TaskHandler<Model, Store, AdditionalMessageOptions>,
   context: AgentTaskContext<Model, Store, AdditionalMessageOptions>,
-  _input: ChatMessage<AdditionalMessageOptions>,
-): AsyncGenerator<TaskStepOutput<Model, Store, AdditionalMessageOptions>> {
-  let isFirst = true;
-  let isDone = false;
-  let input: ChatMessage<AdditionalMessageOptions> | null = _input;
-  let prevStep: TaskStep<Model, Store, AdditionalMessageOptions> | null = null;
-  while (!isDone) {
-    const step: TaskStep<Model, Store, AdditionalMessageOptions> = {
-      id: randomUUID(),
-      input,
-      context,
-      prevStep,
-      nextSteps: new Set(),
-    };
-    if (prevStep) {
-      prevStep.nextSteps.add(step);
-    }
-    const prevToolCallCount = step.context.toolCallCount;
-    if (!step.context.shouldContinue(step)) {
-      throw new Error("Tool call count exceeded limit");
-    }
-    if (isFirst) {
+): ReadableStream<TaskStepOutput<Model, Store, AdditionalMessageOptions>> {
+  const steps: TaskStep<Model, Store, AdditionalMessageOptions>[] = [];
+  return new ReadableStream<
+    TaskStepOutput<Model, Store, AdditionalMessageOptions>
+  >({
+    pull: async (controller) => {
+      const step: TaskStep<Model, Store, AdditionalMessageOptions> = {
+        id: randomUUID(),
+        context,
+        prevStep: null,
+        nextSteps: new Set(),
+      };
+      if (steps.length > 0) {
+        step.prevStep = steps[steps.length - 1];
+      }
+      const taskOutputs: TaskStepOutput<
+        Model,
+        Store,
+        AdditionalMessageOptions
+      >[] = [];
+      steps.push(step);
+      const enqueueOutput = (
+        output: TaskStepOutput<Model, Store, AdditionalMessageOptions>,
+      ) => {
+        taskOutputs.push(output);
+        controller.enqueue(output);
+      };
       getCallbackManager().dispatchEvent("agent-start", {
         payload: {
           startStep: step,
         },
       });
-      isFirst = false;
-    }
-    const taskOutput = await handler(step);
-    const { isLast, output, taskStep } = taskOutput;
-    // do not consume last output
-    if (!isLast) {
-      if (output) {
-        input = isAsyncIterable(output)
-          ? await consumeAsyncIterable(output)
-          : output.message;
-      } else {
-        input = null;
-      }
-    }
-    context = {
-      ...taskStep.context,
-      store: {
-        ...taskStep.context.store,
-      },
-      toolCallCount: prevToolCallCount + 1,
-    };
-    if (isLast) {
-      isDone = true;
-      getCallbackManager().dispatchEvent("agent-end", {
-        payload: {
-          endStep: step,
+
+      await handler(step, enqueueOutput);
+      // fixme: support concurrent handling when there are multiple outputs
+      // todo: for now we assume each step enqueues exactly one task output
+      const { isLast, taskStep } = taskOutputs[0];
+      context = {
+        ...taskStep.context,
+        store: {
+          ...taskStep.context.store,
         },
-      });
-    }
-    prevStep = taskStep;
-    yield taskOutput;
-  }
+        toolCallCount: 1,
+      };
+      if (isLast) {
+        getCallbackManager().dispatchEvent("agent-end", {
+          payload: {
+            endStep: step,
+          },
+        });
+        controller.close();
+      }
+    },
+  });
 }
 
 export type AgentStreamChatResponse<Options extends object> = {
@@ -170,15 +161,16 @@ export abstract class AgentWorker<
     query: string,
     context: AgentTaskContext<AI, Store, AdditionalMessageOptions>,
   ): ReadableStream<TaskStepOutput<AI, Store, AdditionalMessageOptions>> {
-    const taskGenerator = createTaskImpl(this.taskHandler, context, {
+    context.store.messages.push({
       role: "user",
       content: query,
     });
+    const taskOutputStream = createTaskOutputStream(this.taskHandler, context);
     return new ReadableStream<
       TaskStepOutput<AI, Store, AdditionalMessageOptions>
     >({
       start: async (controller) => {
-        for await (const stepOutput of taskGenerator) {
+        for await (const stepOutput of taskOutputStream) {
           this.#taskSet.add(stepOutput.taskStep);
           controller.enqueue(stepOutput);
           if (stepOutput.isLast) {
diff --git a/packages/core/src/agent/openai.ts b/packages/core/src/agent/openai.ts
index 6e569c9a7..f7146665d 100644
--- a/packages/core/src/agent/openai.ts
+++ b/packages/core/src/agent/openai.ts
@@ -51,12 +51,8 @@ export class OpenAIAgent extends AgentRunner<OpenAI> {
 
   createStore = AgentRunner.defaultCreateStore;
 
-  static taskHandler: TaskHandler<OpenAI> = async (step) => {
-    const { input } = step;
+  static taskHandler: TaskHandler<OpenAI> = async (step, enqueueOutput) => {
     const { llm, stream, getTools } = step.context;
-    if (input) {
-      step.context.store.messages = [...step.context.store.messages, input];
-    }
     const lastMessage = step.context.store.messages.at(-1)!.content;
     const tools = await getTools(lastMessage);
     const response = await llm.chat({
@@ -71,6 +67,11 @@ export class OpenAIAgent extends AgentRunner<OpenAI> {
         response.message,
       ];
       const options = response.message.options ?? {};
+      enqueueOutput({
+        taskStep: step,
+        output: response,
+        isLast: !("toolCall" in options),
+      });
       if ("toolCall" in options) {
         const { toolCall } = options;
         const targetTool = tools.find(
@@ -78,30 +79,20 @@ export class OpenAIAgent extends AgentRunner<OpenAI> {
         );
         const toolOutput = await callTool(targetTool, toolCall);
         step.context.store.toolOutputs.push(toolOutput);
-        return {
-          taskStep: step,
-          output: {
-            raw: response.raw,
-            message: {
-              content: stringifyJSONToMessageContent(toolOutput.output),
-              role: "user",
-              options: {
-                toolResult: {
-                  result: toolOutput.output,
-                  isError: toolOutput.isError,
-                  id: toolCall.id,
-                },
+        step.context.store.messages = [
+          ...step.context.store.messages,
+          {
+            role: "user" as const,
+            content: stringifyJSONToMessageContent(toolOutput.output),
+            options: {
+              toolResult: {
+                result: toolOutput.output,
+                isError: toolOutput.isError,
+                id: toolCall.id,
               },
             },
           },
-          isLast: false,
-        };
-      } else {
-        return {
-          taskStep: step,
-          output: response,
-          isLast: true,
-        };
+        ];
       }
     } else {
       const responseChunkStream = new ReadableStream<
@@ -126,6 +117,11 @@ export class OpenAIAgent extends AgentRunner<OpenAI> {
       // check if first chunk has tool calls, if so, this is a function call
       // otherwise, it's a regular message
       const hasToolCall = !!(value.options && "toolCall" in value.options);
+      enqueueOutput({
+        taskStep: step,
+        output: finalStream,
+        isLast: !hasToolCall,
+      });
 
       if (hasToolCall) {
         // you need to consume the response to get the full toolCalls
@@ -175,17 +171,6 @@ export class OpenAIAgent extends AgentRunner<OpenAI> {
           ];
           step.context.store.toolOutputs.push(toolOutput);
         }
-        return {
-          taskStep: step,
-          output: null,
-          isLast: false,
-        };
-      } else {
-        return {
-          taskStep: step,
-          output: finalStream,
-          isLast: true,
-        };
       }
     }
   };
diff --git a/packages/core/src/agent/react.ts b/packages/core/src/agent/react.ts
index 651a30781..27b48d5b1 100644
--- a/packages/core/src/agent/react.ts
+++ b/packages/core/src/agent/react.ts
@@ -349,12 +349,11 @@ export class ReActAgent extends AgentRunner<LLM, ReACTAgentStore> {
     };
   }
 
-  static taskHandler: TaskHandler<LLM, ReACTAgentStore> = async (step) => {
+  static taskHandler: TaskHandler<LLM, ReACTAgentStore> = async (
+    step,
+    enqueueOutput,
+  ) => {
     const { llm, stream, getTools } = step.context;
-    const input = step.input;
-    if (input) {
-      step.context.store.messages.push(input);
-    }
     const lastMessage = step.context.store.messages.at(-1)!.content;
     const tools = await getTools(lastMessage);
     const messages = await chatFormatter(
@@ -369,33 +368,25 @@ export class ReActAgent extends AgentRunner<LLM, ReACTAgentStore> {
     });
     const reason = await reACTOutputParser(response);
     step.context.store.reasons = [...step.context.store.reasons, reason];
-    if (reason.type === "response") {
-      return {
-        isLast: true,
-        output: response,
-        taskStep: step,
-      };
-    } else {
-      if (reason.type === "action") {
-        const tool = tools.find((tool) => tool.metadata.name === reason.action);
-        const toolOutput = await callTool(tool, {
-          id: randomUUID(),
-          input: reason.input,
-          name: reason.action,
-        });
-        step.context.store.reasons = [
-          ...step.context.store.reasons,
-          {
-            type: "observation",
-            observation: toolOutput.output,
-          },
-        ];
-      }
-      return {
-        isLast: false,
-        output: null,
-        taskStep: step,
-      };
+    enqueueOutput({
+      taskStep: step,
+      output: response,
+      isLast: reason.type === "response",
+    });
+    if (reason.type === "action") {
+      const tool = tools.find((tool) => tool.metadata.name === reason.action);
+      const toolOutput = await callTool(tool, {
+        id: randomUUID(),
+        input: reason.input,
+        name: reason.action,
+      });
+      step.context.store.reasons = [
+        ...step.context.store.reasons,
+        {
+          type: "observation",
+          observation: toolOutput.output,
+        },
+      ];
     }
   };
 }
diff --git a/packages/core/src/agent/types.ts b/packages/core/src/agent/types.ts
index f523a047d..b3d48374c 100644
--- a/packages/core/src/agent/types.ts
+++ b/packages/core/src/agent/types.ts
@@ -45,7 +45,6 @@ export type TaskStep<
     : never,
 > = {
   id: UUID;
-  input: ChatMessage<AdditionalMessageOptions> | null;
   context: AgentTaskContext<Model, Store, AdditionalMessageOptions>;
 
   // linked list
@@ -62,22 +61,14 @@ export type TaskStepOutput<
   >
     ? AdditionalMessageOptions
     : never,
-> =
-  | {
-      taskStep: TaskStep<Model, Store, AdditionalMessageOptions>;
-      output:
-        | null
-        | ChatResponse<AdditionalMessageOptions>
-        | ReadableStream<ChatResponseChunk<AdditionalMessageOptions>>;
-      isLast: false;
-    }
-  | {
-      taskStep: TaskStep<Model, Store, AdditionalMessageOptions>;
-      output:
-        | ChatResponse<AdditionalMessageOptions>
-        | ReadableStream<ChatResponseChunk<AdditionalMessageOptions>>;
-      isLast: true;
-    };
+> = {
+  taskStep: TaskStep<Model, Store, AdditionalMessageOptions>;
+  // output is the response shown to the user
+  output:
+    | ChatResponse<AdditionalMessageOptions>
+    | ReadableStream<ChatResponseChunk<AdditionalMessageOptions>>;
+  isLast: boolean;
+};
 
 export type TaskHandler<
   Model extends LLM,
@@ -90,7 +81,10 @@ export type TaskHandler<
     : never,
 > = (
   step: TaskStep<Model, Store, AdditionalMessageOptions>,
-) => Promise<TaskStepOutput<Model, Store, AdditionalMessageOptions>>;
+  enqueueOutput: (
+    taskOutput: TaskStepOutput<Model, Store, AdditionalMessageOptions>,
+  ) => void,
+) => Promise<void>;
 
 export type AgentStartEvent = BaseEvent<{
   startStep: TaskStep;
-- 
GitLab