diff --git a/packages/core/src/agent/base.ts b/packages/core/src/agent/base.ts index 15227f18adc34279bee24bf155d61c5b592e6335..23ed9795965ecc69f9db120ed3df4b2b5b207de2 100644 --- a/packages/core/src/agent/base.ts +++ b/packages/core/src/agent/base.ts @@ -106,11 +106,17 @@ export type AgentRunnerParams< > ? AdditionalMessageOptions : never, + AdditionalChatOptions extends object = object, > = { llm: AI; chatHistory: ChatMessage<AdditionalMessageOptions>[]; systemPrompt: MessageContent | null; - runner: AgentWorker<AI, Store, AdditionalMessageOptions>; + runner: AgentWorker< + AI, + Store, + AdditionalMessageOptions, + AdditionalChatOptions + >; tools: | BaseToolWithCall[] | ((query: MessageContent) => Promise<BaseToolWithCall[]>); @@ -125,6 +131,7 @@ export type AgentParamsBase< > ? AdditionalMessageOptions : never, + AdditionalChatOptions extends object = object, > = | { llm?: AI; @@ -132,6 +139,7 @@ export type AgentParamsBase< systemPrompt?: MessageContent; verbose?: boolean; tools: BaseToolWithCall[]; + additionalChatOptions?: AdditionalChatOptions; } | { llm?: AI; @@ -139,6 +147,7 @@ export type AgentParamsBase< systemPrompt?: MessageContent; verbose?: boolean; toolRetriever: ObjectRetriever<BaseToolWithCall>; + additionalChatOptions?: AdditionalChatOptions; }; /** @@ -153,21 +162,36 @@ export abstract class AgentWorker< > ? AdditionalMessageOptions : never, + AdditionalChatOptions extends object = object, > { - #taskSet = new Set<TaskStep<AI, Store, AdditionalMessageOptions>>(); - abstract taskHandler: TaskHandler<AI, Store, AdditionalMessageOptions>; + #taskSet = new Set< + TaskStep<AI, Store, AdditionalMessageOptions, AdditionalChatOptions> + >(); + abstract taskHandler: TaskHandler< + AI, + Store, + AdditionalMessageOptions, + AdditionalChatOptions + >; public createTask( query: MessageContent, - context: AgentTaskContext<AI, Store, AdditionalMessageOptions>, - ): ReadableStream<TaskStepOutput<AI, Store, AdditionalMessageOptions>> { + context: AgentTaskContext< + AI, + Store, + AdditionalMessageOptions, + AdditionalChatOptions + >, + ): ReadableStream< + TaskStepOutput<AI, Store, AdditionalMessageOptions, AdditionalChatOptions> + > { context.store.messages.push({ role: "user", content: query, }); const taskOutputStream = createTaskOutputStream(this.taskHandler, context); return new ReadableStream< - TaskStepOutput<AI, Store, AdditionalMessageOptions> + TaskStepOutput<AI, Store, AdditionalMessageOptions, AdditionalChatOptions> >({ start: async (controller) => { for await (const stepOutput of taskOutputStream) { @@ -176,7 +200,8 @@ export abstract class AgentWorker< let currentStep: TaskStep< AI, Store, - AdditionalMessageOptions + AdditionalMessageOptions, + AdditionalChatOptions > | null = stepOutput.taskStep; while (currentStep) { this.#taskSet.delete(currentStep); @@ -227,6 +252,7 @@ export abstract class AgentRunner< > ? AdditionalMessageOptions : never, + AdditionalChatOptions extends object = object, > extends BaseChatEngine { readonly #llm: AI; readonly #tools: @@ -234,7 +260,12 @@ export abstract class AgentRunner< | ((query: MessageContent) => Promise<BaseToolWithCall[]>); readonly #systemPrompt: MessageContent | null = null; #chatHistory: ChatMessage<AdditionalMessageOptions>[]; - readonly #runner: AgentWorker<AI, Store, AdditionalMessageOptions>; + readonly #runner: AgentWorker< + AI, + Store, + AdditionalMessageOptions, + AdditionalChatOptions + >; readonly #verbose: boolean; // create extra store @@ -245,7 +276,7 @@ export abstract class AgentRunner< } static defaultTaskHandler: TaskHandler<LLM> = async (step, enqueueOutput) => { - const { llm, getTools, stream } = step.context; + const { llm, getTools, stream, additionalChatOptions } = step.context; const lastMessage = step.context.store.messages.at(-1)!.content; const tools = await getTools(lastMessage); if (!stream) { @@ -253,8 +284,9 @@ export abstract class AgentRunner< stream, tools, messages: [...step.context.store.messages], + additionalChatOptions, }); - await stepTools<LLM>({ + await stepTools({ response, tools, step, @@ -265,6 +297,7 @@ export abstract class AgentRunner< stream, tools, messages: [...step.context.store.messages], + additionalChatOptions, }); await stepToolsStreaming<LLM>({ response, @@ -276,7 +309,12 @@ export abstract class AgentRunner< }; protected constructor( - params: AgentRunnerParams<AI, Store, AdditionalMessageOptions>, + params: AgentRunnerParams< + AI, + Store, + AdditionalMessageOptions, + AdditionalChatOptions + >, ) { super(); const { llm, chatHistory, systemPrompt, runner, tools, verbose } = params; @@ -330,6 +368,7 @@ export abstract class AgentRunner< stream: boolean = false, verbose: boolean | undefined = undefined, chatHistory?: ChatMessage<AdditionalMessageOptions>[], + additionalChatOptions?: AdditionalChatOptions, ) { const initialMessages = [...(chatHistory ?? this.#chatHistory)]; if (this.#systemPrompt !== null) { @@ -348,6 +387,7 @@ export abstract class AgentRunner< stream, toolCallCount: 0, llm: this.#llm, + additionalChatOptions: additionalChatOptions ?? {}, getTools: (message) => this.getTools(message), store: { ...this.createStore(), @@ -365,13 +405,29 @@ export abstract class AgentRunner< }); } - async chat(params: NonStreamingChatEngineParams): Promise<EngineResponse>; async chat( - params: StreamingChatEngineParams, + params: NonStreamingChatEngineParams< + AdditionalMessageOptions, + AdditionalChatOptions + >, + ): Promise<EngineResponse>; + async chat( + params: StreamingChatEngineParams< + AdditionalMessageOptions, + AdditionalChatOptions + >, ): Promise<ReadableStream<EngineResponse>>; @wrapEventCaller async chat( - params: NonStreamingChatEngineParams | StreamingChatEngineParams, + params: + | NonStreamingChatEngineParams< + AdditionalMessageOptions, + AdditionalChatOptions + > + | StreamingChatEngineParams< + AdditionalMessageOptions, + AdditionalChatOptions + >, ): Promise<EngineResponse | ReadableStream<EngineResponse>> { let chatHistory: ChatMessage<AdditionalMessageOptions>[] = []; @@ -388,6 +444,7 @@ export abstract class AgentRunner< !!params.stream, false, chatHistory, + params.chatOptions, ); for await (const stepOutput of task) { // update chat history for each round diff --git a/packages/core/src/agent/llm.ts b/packages/core/src/agent/llm.ts index 5050ee2a8fb180b3d63372f72a91bec3826489b1..a04604e2175adcd35b3c3edec67533aa1bd8fc90 100644 --- a/packages/core/src/agent/llm.ts +++ b/packages/core/src/agent/llm.ts @@ -4,24 +4,66 @@ import { ObjectRetriever } from "../objects"; import { AgentRunner, AgentWorker, type AgentParamsBase } from "./base.js"; import { validateAgentParams } from "./utils.js"; -type LLMParamsBase = AgentParamsBase<LLM>; +type LLMParamsBase< + AI extends LLM, + AdditionalMessageOptions extends object = AI extends LLM< + object, + infer AdditionalMessageOptions + > + ? AdditionalMessageOptions + : never, + AdditionalChatOptions extends object = object, +> = AgentParamsBase<AI, AdditionalMessageOptions, AdditionalChatOptions>; -type LLMParamsWithTools = LLMParamsBase & { +type LLMParamsWithTools< + AI extends LLM, + AdditionalMessageOptions extends object = AI extends LLM< + object, + infer AdditionalMessageOptions + > + ? AdditionalMessageOptions + : never, + AdditionalChatOptions extends object = object, +> = LLMParamsBase<AI, AdditionalMessageOptions, AdditionalChatOptions> & { tools: BaseToolWithCall[]; }; -type LLMParamsWithToolRetriever = LLMParamsBase & { +type LLMParamsWithToolRetriever< + AI extends LLM, + AdditionalMessageOptions extends object = AI extends LLM< + object, + infer AdditionalMessageOptions + > + ? AdditionalMessageOptions + : never, + AdditionalChatOptions extends object = object, +> = LLMParamsBase<AI, AdditionalMessageOptions, AdditionalChatOptions> & { toolRetriever: ObjectRetriever<BaseToolWithCall>; }; -export type LLMAgentParams = LLMParamsWithTools | LLMParamsWithToolRetriever; +export type LLMAgentParams< + AI extends LLM, + AdditionalMessageOptions extends object = AI extends LLM< + object, + infer AdditionalMessageOptions + > + ? AdditionalMessageOptions + : never, + AdditionalChatOptions extends object = object, +> = + | LLMParamsWithTools<AI, AdditionalMessageOptions, AdditionalChatOptions> + | LLMParamsWithToolRetriever< + AI, + AdditionalMessageOptions, + AdditionalChatOptions + >; export class LLMAgentWorker extends AgentWorker<LLM> { taskHandler = AgentRunner.defaultTaskHandler; } export class LLMAgent extends AgentRunner<LLM> { - constructor(params: LLMAgentParams) { + constructor(params: LLMAgentParams<LLM>) { validateAgentParams(params); const llm = params.llm ?? (Settings.llm ? (Settings.llm as LLM) : null); if (!llm) diff --git a/packages/core/src/agent/types.ts b/packages/core/src/agent/types.ts index c6e7a78d149865f2232650925c63619a2f8c7a28..d5063c5e1684e1cc5bdfde98ad81c46a3118eef1 100644 --- a/packages/core/src/agent/types.ts +++ b/packages/core/src/agent/types.ts @@ -19,6 +19,7 @@ export type AgentTaskContext< > ? AdditionalMessageOptions : never, + AdditionalChatOptions extends object = object, > = { readonly stream: boolean; readonly toolCallCount: number; @@ -26,6 +27,7 @@ export type AgentTaskContext< readonly getTools: ( input: MessageContent, ) => BaseToolWithCall[] | Promise<BaseToolWithCall[]>; + readonly additionalChatOptions: Partial<AdditionalChatOptions>; shouldContinue: ( taskStep: Readonly<TaskStep<Model, Store, AdditionalMessageOptions>>, ) => boolean; @@ -45,13 +47,26 @@ export type TaskStep< > ? AdditionalMessageOptions : never, + AdditionalChatOptions extends object = object, > = { id: UUID; - context: AgentTaskContext<Model, Store, AdditionalMessageOptions>; + context: AgentTaskContext< + Model, + Store, + AdditionalMessageOptions, + AdditionalChatOptions + >; // linked list - prevStep: TaskStep<Model, Store, AdditionalMessageOptions> | null; - nextSteps: Set<TaskStep<Model, Store, AdditionalMessageOptions>>; + prevStep: TaskStep< + Model, + Store, + AdditionalMessageOptions, + AdditionalChatOptions + > | null; + nextSteps: Set< + TaskStep<Model, Store, AdditionalMessageOptions, AdditionalChatOptions> + >; }; export type TaskStepOutput< @@ -63,8 +78,14 @@ export type TaskStepOutput< > ? AdditionalMessageOptions : never, + AdditionalChatOptions extends object = object, > = { - taskStep: TaskStep<Model, Store, AdditionalMessageOptions>; + taskStep: TaskStep< + Model, + Store, + AdditionalMessageOptions, + AdditionalChatOptions + >; // output shows the response to the user output: | ChatResponse<AdditionalMessageOptions> @@ -81,10 +102,16 @@ export type TaskHandler< > ? AdditionalMessageOptions : never, + AdditionalChatOptions extends object = object, > = ( - step: TaskStep<Model, Store, AdditionalMessageOptions>, + step: TaskStep<Model, Store, AdditionalMessageOptions, AdditionalChatOptions>, enqueueOutput: ( - taskOutput: TaskStepOutput<Model, Store, AdditionalMessageOptions>, + taskOutput: TaskStepOutput< + Model, + Store, + AdditionalMessageOptions, + AdditionalChatOptions + >, ) => void, ) => Promise<void>; diff --git a/packages/core/src/chat-engine/base.ts b/packages/core/src/chat-engine/base.ts index b4bd4cf3b1a4cba3ddccd7f33417222bcce15d7b..77bc735001dd20e7bb04c7237576f90c1c4f72ab 100644 --- a/packages/core/src/chat-engine/base.ts +++ b/packages/core/src/chat-engine/base.ts @@ -16,14 +16,18 @@ export interface BaseChatEngineParams< export interface StreamingChatEngineParams< AdditionalMessageOptions extends object = object, + AdditionalChatOptions extends object = object, > extends BaseChatEngineParams<AdditionalMessageOptions> { stream: true; + chatOptions?: AdditionalChatOptions; } export interface NonStreamingChatEngineParams< AdditionalMessageOptions extends object = object, + AdditionalChatOptions extends object = object, > extends BaseChatEngineParams<AdditionalMessageOptions> { stream?: false; + chatOptions?: AdditionalChatOptions; } export abstract class BaseChatEngine { diff --git a/packages/providers/anthropic/src/agent.ts b/packages/providers/anthropic/src/agent.ts index 18c32a5a6600eb62fac4b8c439b97138850232b2..da2472f49425fcef885b3e977dde9770b09e3bfe 100644 --- a/packages/providers/anthropic/src/agent.ts +++ b/packages/providers/anthropic/src/agent.ts @@ -11,7 +11,7 @@ import { Settings } from "@llamaindex/core/global"; import type { EngineResponse } from "@llamaindex/core/schema"; import { Anthropic } from "./llm.js"; -export type AnthropicAgentParams = LLMAgentParams; +export type AnthropicAgentParams = LLMAgentParams<Anthropic>; export class AnthropicAgentWorker extends LLMAgentWorker {} diff --git a/packages/providers/ollama/src/agent.ts b/packages/providers/ollama/src/agent.ts index 69dd75a68c4b0f5301461c6ca7f9760590833746..a3dc739a05008df4e443767cc5617d0867eb76d5 100644 --- a/packages/providers/ollama/src/agent.ts +++ b/packages/providers/ollama/src/agent.ts @@ -8,7 +8,7 @@ import { Ollama } from "./llm"; // This is likely not necessary anymore but leaving it here just incase it's in use elsewhere -export type OllamaAgentParams = LLMAgentParams & { +export type OllamaAgentParams = LLMAgentParams<Ollama> & { model?: string; }; diff --git a/packages/providers/openai/package.json b/packages/providers/openai/package.json index e794cc142a89e284a320a91c9ca4cffe99513168..f830b57cb18c44ec7a3c53fb3eb141b0e3c10bdc 100644 --- a/packages/providers/openai/package.json +++ b/packages/providers/openai/package.json @@ -35,7 +35,6 @@ "dependencies": { "@llamaindex/core": "workspace:*", "@llamaindex/env": "workspace:*", - "openai": "^4.68.1", - "remeda": "^2.12.0" + "openai": "^4.68.1" } } diff --git a/packages/providers/openai/src/agent.ts b/packages/providers/openai/src/agent.ts index 36c6ad66c7f057d93fb3524e7cd7b0a727e44de2..64b4eb028cac75ec3b36ce209792db6c00c29aa5 100644 --- a/packages/providers/openai/src/agent.ts +++ b/packages/providers/openai/src/agent.ts @@ -4,11 +4,16 @@ import { type LLMAgentParams, } from "@llamaindex/core/agent"; import { Settings } from "@llamaindex/core/global"; -import { OpenAI } from "./llm"; +import type { ToolCallLLMMessageOptions } from "@llamaindex/core/llms"; +import { OpenAI, type OpenAIAdditionalChatOptions } from "./llm"; -// This is likely not necessary anymore but leaving it here just incase it's in use elsewhere +// This is likely not necessary anymore but leaving it here just in case it's in use elsewhere -export type OpenAIAgentParams = LLMAgentParams; +export type OpenAIAgentParams = LLMAgentParams< + OpenAI, + ToolCallLLMMessageOptions, + OpenAIAdditionalChatOptions +>; export class OpenAIAgentWorker extends LLMAgentWorker {} diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 0c007f066d3cf817cb25906bc300e547f6eb3e21..d4bd1f9a7c6dc67ee3b70bf33b7b103dd3e28553 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -1310,9 +1310,6 @@ importers: openai: specifier: ^4.68.1 version: 4.69.0(encoding@0.1.13)(zod@3.23.8) - remeda: - specifier: ^2.12.0 - version: 2.16.0 devDependencies: bunchee: specifier: 5.6.1