From 7488d3c235e49f47061c4d4e700a2e0a3beedd17 Mon Sep 17 00:00:00 2001
From: Alex Yang <himself65@outlook.com>
Date: Fri, 26 Apr 2024 18:13:05 -0500
Subject: [PATCH] fix: agent callback with step information (#774)

---
 packages/core/src/agent/anthropic.ts          |   2 +-
 packages/core/src/agent/base.ts               | 107 ++++--------------
 packages/core/src/agent/openai.ts             |   8 +-
 packages/core/src/agent/react.ts              |   8 +-
 packages/core/src/agent/type.ts               |   4 -
 packages/core/src/agent/types.ts              |  99 ++++++++++++++++
 .../core/src/callbacks/CallbackManager.ts     |   2 +-
 packages/core/src/embeddings/index.ts         |   1 -
 packages/core/src/index.ts                    |   5 +
 packages/core/src/internal/type.ts            |   4 +-
 10 files changed, 131 insertions(+), 109 deletions(-)
 delete mode 100644 packages/core/src/agent/type.ts
 create mode 100644 packages/core/src/agent/types.ts

diff --git a/packages/core/src/agent/anthropic.ts b/packages/core/src/agent/anthropic.ts
index 299344767..231795d89 100644
--- a/packages/core/src/agent/anthropic.ts
+++ b/packages/core/src/agent/anthropic.ts
@@ -13,8 +13,8 @@ import {
   AgentWorker,
   type AgentChatResponse,
   type AgentParamsBase,
-  type TaskHandler,
 } from "./base.js";
+import type { TaskHandler } from "./types.js";
 import { callTool } from "./utils.js";
 
 type AnthropicParamsBase = AgentParamsBase<Anthropic>;
diff --git a/packages/core/src/agent/base.ts b/packages/core/src/agent/base.ts
index 0a65937e1..019518a95 100644
--- a/packages/core/src/agent/base.ts
+++ b/packages/core/src/agent/base.ts
@@ -15,94 +15,17 @@ import type {
   MessageContent,
 } from "../llm/index.js";
 import { extractText } from "../llm/utils.js";
-import type { BaseToolWithCall, ToolOutput, UUID } from "../types.js";
+import type { BaseToolWithCall, ToolOutput } from "../types.js";
+import type {
+  AgentTaskContext,
+  TaskHandler,
+  TaskStep,
+  TaskStepOutput,
+} from "./types.js";
 import { consumeAsyncIterable } from "./utils.js";
 
 export const MAX_TOOL_CALLS = 10;
 
-export type AgentTaskContext<
-  Model extends LLM,
-  Store extends object = {},
-  AdditionalMessageOptions extends object = Model extends LLM<
-    object,
-    infer AdditionalMessageOptions
-  >
-    ? AdditionalMessageOptions
-    : never,
-> = {
-  readonly stream: boolean;
-  readonly toolCallCount: number;
-  readonly llm: Model;
-  readonly getTools: (
-    input: MessageContent,
-  ) => BaseToolWithCall[] | Promise<BaseToolWithCall[]>;
-  shouldContinue: (
-    taskStep: Readonly<TaskStep<Model, Store, AdditionalMessageOptions>>,
-  ) => boolean;
-  store: {
-    toolOutputs: ToolOutput[];
-    messages: ChatMessage<AdditionalMessageOptions>[];
-  } & Store;
-};
-
-export type TaskStep<
-  Model extends LLM,
-  Store extends object = {},
-  AdditionalMessageOptions extends object = Model extends LLM<
-    object,
-    infer AdditionalMessageOptions
-  >
-    ? AdditionalMessageOptions
-    : never,
-> = {
-  id: UUID;
-  input: ChatMessage<AdditionalMessageOptions> | null;
-  context: AgentTaskContext<Model, Store, AdditionalMessageOptions>;
-
-  // linked list
-  prevStep: TaskStep<Model, Store, AdditionalMessageOptions> | null;
-  nextSteps: Set<TaskStep<Model, Store, AdditionalMessageOptions>>;
-};
-
-export type TaskStepOutput<
-  Model extends LLM,
-  Store extends object = {},
-  AdditionalMessageOptions extends object = Model extends LLM<
-    object,
-    infer AdditionalMessageOptions
-  >
-    ? AdditionalMessageOptions
-    : never,
-> =
-  | {
-      taskStep: TaskStep<Model, Store, AdditionalMessageOptions>;
-      output:
-        | null
-        | ChatResponse<AdditionalMessageOptions>
-        | ReadableStream<ChatResponseChunk<AdditionalMessageOptions>>;
-      isLast: false;
-    }
-  | {
-      taskStep: TaskStep<Model, Store, AdditionalMessageOptions>;
-      output:
-        | ChatResponse<AdditionalMessageOptions>
-        | ReadableStream<ChatResponseChunk<AdditionalMessageOptions>>;
-      isLast: true;
-    };
-
-export type TaskHandler<
-  Model extends LLM,
-  Store extends object = {},
-  AdditionalMessageOptions extends object = Model extends LLM<
-    object,
-    infer AdditionalMessageOptions
-  >
-    ? AdditionalMessageOptions
-    : never,
-> = (
-  step: TaskStep<Model, Store, AdditionalMessageOptions>,
-) => Promise<TaskStepOutput<Model, Store, AdditionalMessageOptions>>;
-
 /**
  * @internal
  */
@@ -120,6 +43,7 @@ export async function* createTaskImpl<
   context: AgentTaskContext<Model, Store, AdditionalMessageOptions>,
   _input: ChatMessage<AdditionalMessageOptions>,
 ): AsyncGenerator<TaskStepOutput<Model, Store, AdditionalMessageOptions>> {
+  let isFirst = true;
   let isDone = false;
   let input: ChatMessage<AdditionalMessageOptions> | null = _input;
   let prevStep: TaskStep<Model, Store, AdditionalMessageOptions> | null = null;
@@ -138,9 +62,14 @@ export async function* createTaskImpl<
     if (!step.context.shouldContinue(step)) {
       throw new Error("Tool call count exceeded limit");
     }
-    getCallbackManager().dispatchEvent("agent-start", {
-      payload: {},
-    });
+    if (isFirst) {
+      getCallbackManager().dispatchEvent("agent-start", {
+        payload: {
+          startStep: step,
+        },
+      });
+      isFirst = false;
+    }
     const taskOutput = await handler(step);
     const { isLast, output, taskStep } = taskOutput;
     // do not consume last output
@@ -163,7 +92,9 @@ export async function* createTaskImpl<
     if (isLast) {
       isDone = true;
       getCallbackManager().dispatchEvent("agent-end", {
-        payload: {},
+        payload: {
+          endStep: step,
+        },
       });
     }
     prevStep = taskStep;
diff --git a/packages/core/src/agent/openai.ts b/packages/core/src/agent/openai.ts
index 89a707724..bf753ad8a 100644
--- a/packages/core/src/agent/openai.ts
+++ b/packages/core/src/agent/openai.ts
@@ -9,12 +9,8 @@ import type {
 import { OpenAI } from "../llm/openai.js";
 import { ObjectRetriever } from "../objects/index.js";
 import type { BaseToolWithCall } from "../types.js";
-import {
-  AgentRunner,
-  AgentWorker,
-  type AgentParamsBase,
-  type TaskHandler,
-} from "./base.js";
+import { AgentRunner, AgentWorker, type AgentParamsBase } from "./base.js";
+import type { TaskHandler } from "./types.js";
 import { callTool } from "./utils.js";
 
 type OpenAIParamsBase = AgentParamsBase<OpenAI>;
diff --git a/packages/core/src/agent/react.ts b/packages/core/src/agent/react.ts
index 54ed4190d..651a30781 100644
--- a/packages/core/src/agent/react.ts
+++ b/packages/core/src/agent/react.ts
@@ -19,12 +19,8 @@ import type {
   JSONObject,
   JSONValue,
 } from "../types.js";
-import {
-  AgentRunner,
-  AgentWorker,
-  type AgentParamsBase,
-  type TaskHandler,
-} from "./base.js";
+import { AgentRunner, AgentWorker, type AgentParamsBase } from "./base.js";
+import type { TaskHandler } from "./types.js";
 import {
   callTool,
   consumeAsyncIterable,
diff --git a/packages/core/src/agent/type.ts b/packages/core/src/agent/type.ts
deleted file mode 100644
index 38d974cf0..000000000
--- a/packages/core/src/agent/type.ts
+++ /dev/null
@@ -1,4 +0,0 @@
-import type { BaseEvent } from "../internal/type.js";
-
-export type AgentStartEvent = BaseEvent<{}>;
-export type AgentEndEvent = BaseEvent<{}>;
diff --git a/packages/core/src/agent/types.ts b/packages/core/src/agent/types.ts
new file mode 100644
index 000000000..c49dac0c3
--- /dev/null
+++ b/packages/core/src/agent/types.ts
@@ -0,0 +1,99 @@
+import type { BaseEvent } from "../internal/type.js";
+import type {
+  ChatMessage,
+  ChatResponse,
+  ChatResponseChunk,
+  LLM,
+  MessageContent,
+} from "../llm/types.js";
+import type { BaseToolWithCall, ToolOutput, UUID } from "../types.js";
+
+export type AgentTaskContext<
+  Model extends LLM,
+  Store extends object = {},
+  AdditionalMessageOptions extends object = Model extends LLM<
+    object,
+    infer AdditionalMessageOptions
+  >
+    ? AdditionalMessageOptions
+    : never,
+> = {
+  readonly stream: boolean;
+  readonly toolCallCount: number;
+  readonly llm: Model;
+  readonly getTools: (
+    input: MessageContent,
+  ) => BaseToolWithCall[] | Promise<BaseToolWithCall[]>;
+  shouldContinue: (
+    taskStep: Readonly<TaskStep<Model, Store, AdditionalMessageOptions>>,
+  ) => boolean;
+  store: {
+    toolOutputs: ToolOutput[];
+    messages: ChatMessage<AdditionalMessageOptions>[];
+  } & Store;
+};
+
+export type TaskStep<
+  Model extends LLM = LLM,
+  Store extends object = {},
+  AdditionalMessageOptions extends object = Model extends LLM<
+    object,
+    infer AdditionalMessageOptions
+  >
+    ? AdditionalMessageOptions
+    : never,
+> = {
+  id: UUID;
+  input: ChatMessage<AdditionalMessageOptions> | null;
+  context: AgentTaskContext<Model, Store, AdditionalMessageOptions>;
+
+  // linked list
+  prevStep: TaskStep<Model, Store, AdditionalMessageOptions> | null;
+  nextSteps: Set<TaskStep<Model, Store, AdditionalMessageOptions>>;
+};
+
+export type TaskStepOutput<
+  Model extends LLM,
+  Store extends object = {},
+  AdditionalMessageOptions extends object = Model extends LLM<
+    object,
+    infer AdditionalMessageOptions
+  >
+    ? AdditionalMessageOptions
+    : never,
+> =
+  | {
+      taskStep: TaskStep<Model, Store, AdditionalMessageOptions>;
+      output:
+        | null
+        | ChatResponse<AdditionalMessageOptions>
+        | ReadableStream<ChatResponseChunk<AdditionalMessageOptions>>;
+      isLast: false;
+    }
+  | {
+      taskStep: TaskStep<Model, Store, AdditionalMessageOptions>;
+      output:
+        | ChatResponse<AdditionalMessageOptions>
+        | ReadableStream<ChatResponseChunk<AdditionalMessageOptions>>;
+      isLast: true;
+    };
+
+export type TaskHandler<
+  Model extends LLM,
+  Store extends object = {},
+  AdditionalMessageOptions extends object = Model extends LLM<
+    object,
+    infer AdditionalMessageOptions
+  >
+    ? AdditionalMessageOptions
+    : never,
+> = (
+  step: TaskStep<Model, Store, AdditionalMessageOptions>,
+) => Promise<TaskStepOutput<Model, Store, AdditionalMessageOptions>>;
+
+export type AgentStartEvent = BaseEvent<{
+  startStep: TaskStep;
+}>;
+export type AgentEndEvent = BaseEvent<{
+  endStep: TaskStep;
+}>;
diff --git a/packages/core/src/callbacks/CallbackManager.ts b/packages/core/src/callbacks/CallbackManager.ts
index 81bdc10b9..4646e5254 100644
--- a/packages/core/src/callbacks/CallbackManager.ts
+++ b/packages/core/src/callbacks/CallbackManager.ts
@@ -1,7 +1,7 @@
 import type { Anthropic } from "@anthropic-ai/sdk";
 import { CustomEvent } from "@llamaindex/env";
 import type { NodeWithScore } from "../Node.js";
-import type { AgentEndEvent, AgentStartEvent } from "../agent/type.js";
+import type { AgentEndEvent, AgentStartEvent } from "../agent/types.js";
 import {
   EventCaller,
   getEventCaller,
diff --git a/packages/core/src/embeddings/index.ts b/packages/core/src/embeddings/index.ts
index 2b0b61468..af6492ec7 100644
--- a/packages/core/src/embeddings/index.ts
+++ b/packages/core/src/embeddings/index.ts
@@ -1,6 +1,5 @@
 export * from "./ClipEmbedding.js";
 export * from "./GeminiEmbedding.js";
-export * from "./HuggingFaceEmbedding.js";
 export * from "./JinaAIEmbedding.js";
 export * from "./MistralAIEmbedding.js";
 export * from "./MultiModalEmbedding.js";
diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts
index a3bc326f9..20f043c56 100644
--- a/packages/core/src/index.ts
+++ b/packages/core/src/index.ts
@@ -1,5 +1,10 @@
 export * from "./index.edge.js";
 export * from "./readers/index.js";
 export * from "./storage/index.js";
+// Exports modules that don't support non-Node.js runtimes
 // Ollama is only compatible with the Node.js runtime
+export {
+  HuggingFaceEmbedding,
+  HuggingFaceEmbeddingModelType,
+} from "./embeddings/HuggingFaceEmbedding.js";
 export { Ollama, type OllamaParams } from "./llm/ollama.js";
diff --git a/packages/core/src/internal/type.ts b/packages/core/src/internal/type.ts
index b93af22a0..8421d5d1f 100644
--- a/packages/core/src/internal/type.ts
+++ b/packages/core/src/internal/type.ts
@@ -1,5 +1,5 @@
-import { CustomEvent } from "@llamaindex/env";
+import type { CustomEvent } from "@llamaindex/env";
 
 export type BaseEvent<Payload extends Record<string, unknown>> = CustomEvent<{
-  payload: Payload;
+  payload: Readonly<Payload>;
 }>;
--
GitLab
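
Usage note: the practical effect of this patch is that the "agent-start" and "agent-end" callback events now carry the TaskStep that opened and closed the task instead of an empty payload, so a subscriber can inspect the step id, the messages stored on the step's context, and the accumulated tool outputs. The TypeScript below is a minimal consumer sketch, not part of the patch; it assumes the public "llamaindex" entry point re-exports Settings, OpenAI, and OpenAIAgent, that Settings.callbackManager exposes the on(event, handler) registration paired with the dispatchEvent calls shown in base.ts, and the model name, tools, and log text are placeholders.

import { OpenAI, OpenAIAgent, Settings } from "llamaindex";

// Register handlers before running the agent so every task is observed.
Settings.callbackManager.on("agent-start", (event) => {
  // After this patch, detail.payload is typed as Readonly<{ startStep: TaskStep }>.
  const { startStep } = event.detail.payload;
  console.log(`agent task started: step ${startStep.id}`);
  console.log(`messages in context: ${startStep.context.store.messages.length}`);
});

Settings.callbackManager.on("agent-end", (event) => {
  // detail.payload is Readonly<{ endStep: TaskStep }>; the step carries the final store.
  const { endStep } = event.detail.payload;
  console.log(`agent task finished: step ${endStep.id}`);
  console.log(`tool outputs collected: ${endStep.context.store.toolOutputs.length}`);
});

const agent = new OpenAIAgent({
  llm: new OpenAI({ model: "gpt-4-turbo" }), // placeholder model name
  tools: [], // add BaseToolWithCall instances as needed
});

const response = await agent.chat({ message: "What is 7 * 6?" });
console.log(response);

Because createTaskImpl in base.ts now dispatches "agent-start" only while isFirst is true, the start handler fires once per task rather than once per step, and "agent-end" fires with the step whose isLast flag is set.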