From 7488d3c235e49f47061c4d4e700a2e0a3beedd17 Mon Sep 17 00:00:00 2001
From: Alex Yang <himself65@outlook.com>
Date: Fri, 26 Apr 2024 18:13:05 -0500
Subject: [PATCH] fix: agent callback with step infomation (#774)

---
 packages/core/src/agent/anthropic.ts          |   2 +-
 packages/core/src/agent/base.ts               | 107 ++++--------------
 packages/core/src/agent/openai.ts             |   8 +-
 packages/core/src/agent/react.ts              |   8 +-
 packages/core/src/agent/type.ts               |   4 -
 packages/core/src/agent/types.ts              |  99 ++++++++++++++++
 .../core/src/callbacks/CallbackManager.ts     |   2 +-
 packages/core/src/embeddings/index.ts         |   1 -
 packages/core/src/index.ts                    |   5 +
 packages/core/src/internal/type.ts            |   4 +-
 10 files changed, 131 insertions(+), 109 deletions(-)
 delete mode 100644 packages/core/src/agent/type.ts
 create mode 100644 packages/core/src/agent/types.ts

diff --git a/packages/core/src/agent/anthropic.ts b/packages/core/src/agent/anthropic.ts
index 299344767..231795d89 100644
--- a/packages/core/src/agent/anthropic.ts
+++ b/packages/core/src/agent/anthropic.ts
@@ -13,8 +13,8 @@ import {
   AgentWorker,
   type AgentChatResponse,
   type AgentParamsBase,
-  type TaskHandler,
 } from "./base.js";
+import type { TaskHandler } from "./types.js";
 import { callTool } from "./utils.js";
 
 type AnthropicParamsBase = AgentParamsBase<Anthropic>;
diff --git a/packages/core/src/agent/base.ts b/packages/core/src/agent/base.ts
index 0a65937e1..019518a95 100644
--- a/packages/core/src/agent/base.ts
+++ b/packages/core/src/agent/base.ts
@@ -15,94 +15,17 @@ import type {
   MessageContent,
 } from "../llm/index.js";
 import { extractText } from "../llm/utils.js";
-import type { BaseToolWithCall, ToolOutput, UUID } from "../types.js";
+import type { BaseToolWithCall, ToolOutput } from "../types.js";
+import type {
+  AgentTaskContext,
+  TaskHandler,
+  TaskStep,
+  TaskStepOutput,
+} from "./types.js";
 import { consumeAsyncIterable } from "./utils.js";
 
 export const MAX_TOOL_CALLS = 10;
 
-export type AgentTaskContext<
-  Model extends LLM,
-  Store extends object = {},
-  AdditionalMessageOptions extends object = Model extends LLM<
-    object,
-    infer AdditionalMessageOptions
-  >
-    ? AdditionalMessageOptions
-    : never,
-> = {
-  readonly stream: boolean;
-  readonly toolCallCount: number;
-  readonly llm: Model;
-  readonly getTools: (
-    input: MessageContent,
-  ) => BaseToolWithCall[] | Promise<BaseToolWithCall[]>;
-  shouldContinue: (
-    taskStep: Readonly<TaskStep<Model, Store, AdditionalMessageOptions>>,
-  ) => boolean;
-  store: {
-    toolOutputs: ToolOutput[];
-    messages: ChatMessage<AdditionalMessageOptions>[];
-  } & Store;
-};
-
-export type TaskStep<
-  Model extends LLM,
-  Store extends object = {},
-  AdditionalMessageOptions extends object = Model extends LLM<
-    object,
-    infer AdditionalMessageOptions
-  >
-    ? AdditionalMessageOptions
-    : never,
-> = {
-  id: UUID;
-  input: ChatMessage<AdditionalMessageOptions> | null;
-  context: AgentTaskContext<Model, Store, AdditionalMessageOptions>;
-
-  // linked list
-  prevStep: TaskStep<Model, Store, AdditionalMessageOptions> | null;
-  nextSteps: Set<TaskStep<Model, Store, AdditionalMessageOptions>>;
-};
-
-export type TaskStepOutput<
-  Model extends LLM,
-  Store extends object = {},
-  AdditionalMessageOptions extends object = Model extends LLM<
-    object,
-    infer AdditionalMessageOptions
-  >
-    ? AdditionalMessageOptions
-    : never,
-> =
-  | {
-      taskStep: TaskStep<Model, Store, AdditionalMessageOptions>;
-      output:
-        | null
-        | ChatResponse<AdditionalMessageOptions>
-        | ReadableStream<ChatResponseChunk<AdditionalMessageOptions>>;
-      isLast: false;
-    }
-  | {
-      taskStep: TaskStep<Model, Store, AdditionalMessageOptions>;
-      output:
-        | ChatResponse<AdditionalMessageOptions>
-        | ReadableStream<ChatResponseChunk<AdditionalMessageOptions>>;
-      isLast: true;
-    };
-
-export type TaskHandler<
-  Model extends LLM,
-  Store extends object = {},
-  AdditionalMessageOptions extends object = Model extends LLM<
-    object,
-    infer AdditionalMessageOptions
-  >
-    ? AdditionalMessageOptions
-    : never,
-> = (
-  step: TaskStep<Model, Store, AdditionalMessageOptions>,
-) => Promise<TaskStepOutput<Model, Store, AdditionalMessageOptions>>;
-
 /**
  * @internal
  */
@@ -120,6 +43,7 @@ export async function* createTaskImpl<
   context: AgentTaskContext<Model, Store, AdditionalMessageOptions>,
   _input: ChatMessage<AdditionalMessageOptions>,
 ): AsyncGenerator<TaskStepOutput<Model, Store, AdditionalMessageOptions>> {
+  let isFirst = true;
   let isDone = false;
   let input: ChatMessage<AdditionalMessageOptions> | null = _input;
   let prevStep: TaskStep<Model, Store, AdditionalMessageOptions> | null = null;
@@ -138,9 +62,14 @@ export async function* createTaskImpl<
     if (!step.context.shouldContinue(step)) {
       throw new Error("Tool call count exceeded limit");
     }
-    getCallbackManager().dispatchEvent("agent-start", {
-      payload: {},
-    });
+    if (isFirst) {
+      getCallbackManager().dispatchEvent("agent-start", {
+        payload: {
+          startStep: step,
+        },
+      });
+      isFirst = false;
+    }
     const taskOutput = await handler(step);
     const { isLast, output, taskStep } = taskOutput;
     // do not consume last output
@@ -163,7 +92,9 @@ export async function* createTaskImpl<
     if (isLast) {
       isDone = true;
       getCallbackManager().dispatchEvent("agent-end", {
-        payload: {},
+        payload: {
+          endStep: step,
+        },
       });
     }
     prevStep = taskStep;
diff --git a/packages/core/src/agent/openai.ts b/packages/core/src/agent/openai.ts
index 89a707724..bf753ad8a 100644
--- a/packages/core/src/agent/openai.ts
+++ b/packages/core/src/agent/openai.ts
@@ -9,12 +9,8 @@ import type {
 import { OpenAI } from "../llm/openai.js";
 import { ObjectRetriever } from "../objects/index.js";
 import type { BaseToolWithCall } from "../types.js";
-import {
-  AgentRunner,
-  AgentWorker,
-  type AgentParamsBase,
-  type TaskHandler,
-} from "./base.js";
+import { AgentRunner, AgentWorker, type AgentParamsBase } from "./base.js";
+import type { TaskHandler } from "./types.js";
 import { callTool } from "./utils.js";
 
 type OpenAIParamsBase = AgentParamsBase<OpenAI>;
diff --git a/packages/core/src/agent/react.ts b/packages/core/src/agent/react.ts
index 54ed4190d..651a30781 100644
--- a/packages/core/src/agent/react.ts
+++ b/packages/core/src/agent/react.ts
@@ -19,12 +19,8 @@ import type {
   JSONObject,
   JSONValue,
 } from "../types.js";
-import {
-  AgentRunner,
-  AgentWorker,
-  type AgentParamsBase,
-  type TaskHandler,
-} from "./base.js";
+import { AgentRunner, AgentWorker, type AgentParamsBase } from "./base.js";
+import type { TaskHandler } from "./types.js";
 import {
   callTool,
   consumeAsyncIterable,
diff --git a/packages/core/src/agent/type.ts b/packages/core/src/agent/type.ts
deleted file mode 100644
index 38d974cf0..000000000
--- a/packages/core/src/agent/type.ts
+++ /dev/null
@@ -1,4 +0,0 @@
-import type { BaseEvent } from "../internal/type.js";
-
-export type AgentStartEvent = BaseEvent<{}>;
-export type AgentEndEvent = BaseEvent<{}>;
diff --git a/packages/core/src/agent/types.ts b/packages/core/src/agent/types.ts
new file mode 100644
index 000000000..c49dac0c3
--- /dev/null
+++ b/packages/core/src/agent/types.ts
@@ -0,0 +1,99 @@
+import type { BaseEvent } from "../internal/type.js";
+import type {
+  ChatMessage,
+  ChatResponse,
+  ChatResponseChunk,
+  LLM,
+  MessageContent,
+} from "../llm/types.js";
+import type { BaseToolWithCall, ToolOutput, UUID } from "../types.js";
+
+export type AgentTaskContext<
+  Model extends LLM,
+  Store extends object = {},
+  AdditionalMessageOptions extends object = Model extends LLM<
+    object,
+    infer AdditionalMessageOptions
+  >
+    ? AdditionalMessageOptions
+    : never,
+> = {
+  readonly stream: boolean;
+  readonly toolCallCount: number;
+  readonly llm: Model;
+  readonly getTools: (
+    input: MessageContent,
+  ) => BaseToolWithCall[] | Promise<BaseToolWithCall[]>;
+  shouldContinue: (
+    taskStep: Readonly<TaskStep<Model, Store, AdditionalMessageOptions>>,
+  ) => boolean;
+  store: {
+    toolOutputs: ToolOutput[];
+    messages: ChatMessage<AdditionalMessageOptions>[];
+  } & Store;
+};
+
+export type TaskStep<
+  Model extends LLM = LLM,
+  Store extends object = {},
+  AdditionalMessageOptions extends object = Model extends LLM<
+    object,
+    infer AdditionalMessageOptions
+  >
+    ? AdditionalMessageOptions
+    : never,
+> = {
+  id: UUID;
+  input: ChatMessage<AdditionalMessageOptions> | null;
+  context: AgentTaskContext<Model, Store, AdditionalMessageOptions>;
+
+  // linked list
+  prevStep: TaskStep<Model, Store, AdditionalMessageOptions> | null;
+  nextSteps: Set<TaskStep<Model, Store, AdditionalMessageOptions>>;
+};
+
+export type TaskStepOutput<
+  Model extends LLM,
+  Store extends object = {},
+  AdditionalMessageOptions extends object = Model extends LLM<
+    object,
+    infer AdditionalMessageOptions
+  >
+    ? AdditionalMessageOptions
+    : never,
+> =
+  | {
+      taskStep: TaskStep<Model, Store, AdditionalMessageOptions>;
+      output:
+        | null
+        | ChatResponse<AdditionalMessageOptions>
+        | ReadableStream<ChatResponseChunk<AdditionalMessageOptions>>;
+      isLast: false;
+    }
+  | {
+      taskStep: TaskStep<Model, Store, AdditionalMessageOptions>;
+      output:
+        | ChatResponse<AdditionalMessageOptions>
+        | ReadableStream<ChatResponseChunk<AdditionalMessageOptions>>;
+      isLast: true;
+    };
+
+export type TaskHandler<
+  Model extends LLM,
+  Store extends object = {},
+  AdditionalMessageOptions extends object = Model extends LLM<
+    object,
+    infer AdditionalMessageOptions
+  >
+    ? AdditionalMessageOptions
+    : never,
+> = (
+  step: TaskStep<Model, Store, AdditionalMessageOptions>,
+) => Promise<TaskStepOutput<Model, Store, AdditionalMessageOptions>>;
+
+export type AgentStartEvent = BaseEvent<{
+  startStep: TaskStep;
+}>;
+export type AgentEndEvent = BaseEvent<{
+  endStep: TaskStep;
+}>;
diff --git a/packages/core/src/callbacks/CallbackManager.ts b/packages/core/src/callbacks/CallbackManager.ts
index 81bdc10b9..4646e5254 100644
--- a/packages/core/src/callbacks/CallbackManager.ts
+++ b/packages/core/src/callbacks/CallbackManager.ts
@@ -1,7 +1,7 @@
 import type { Anthropic } from "@anthropic-ai/sdk";
 import { CustomEvent } from "@llamaindex/env";
 import type { NodeWithScore } from "../Node.js";
-import type { AgentEndEvent, AgentStartEvent } from "../agent/type.js";
+import type { AgentEndEvent, AgentStartEvent } from "../agent/types.js";
 import {
   EventCaller,
   getEventCaller,
diff --git a/packages/core/src/embeddings/index.ts b/packages/core/src/embeddings/index.ts
index 2b0b61468..af6492ec7 100644
--- a/packages/core/src/embeddings/index.ts
+++ b/packages/core/src/embeddings/index.ts
@@ -1,6 +1,5 @@
 export * from "./ClipEmbedding.js";
 export * from "./GeminiEmbedding.js";
-export * from "./HuggingFaceEmbedding.js";
 export * from "./JinaAIEmbedding.js";
 export * from "./MistralAIEmbedding.js";
 export * from "./MultiModalEmbedding.js";
diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts
index a3bc326f9..20f043c56 100644
--- a/packages/core/src/index.ts
+++ b/packages/core/src/index.ts
@@ -1,5 +1,10 @@
 export * from "./index.edge.js";
 export * from "./readers/index.js";
 export * from "./storage/index.js";
+// Exports modules that doesn't support non-node.js runtime
 // Ollama is only compatible with the Node.js runtime
+export {
+  HuggingFaceEmbedding,
+  HuggingFaceEmbeddingModelType,
+} from "./embeddings/HuggingFaceEmbedding.js";
 export { Ollama, type OllamaParams } from "./llm/ollama.js";
diff --git a/packages/core/src/internal/type.ts b/packages/core/src/internal/type.ts
index b93af22a0..8421d5d1f 100644
--- a/packages/core/src/internal/type.ts
+++ b/packages/core/src/internal/type.ts
@@ -1,5 +1,5 @@
-import { CustomEvent } from "@llamaindex/env";
+import type { CustomEvent } from "@llamaindex/env";
 
 export type BaseEvent<Payload extends Record<string, unknown>> = CustomEvent<{
-  payload: Payload;
+  payload: Readonly<Payload>;
 }>;
-- 
GitLab