From dbb5bd9f234030ac43446e74567e9ea32938ed42 Mon Sep 17 00:00:00 2001 From: Alex Yang <himself65@outlook.com> Date: Tue, 12 Nov 2024 12:46:57 -0800 Subject: [PATCH] feat: allow `tool_choice` for OpenAIAgent (#1472) --- packages/core/src/agent/base.ts | 85 +++++++++++++++++++---- packages/core/src/agent/llm.ts | 52 ++++++++++++-- packages/core/src/agent/types.ts | 39 +++++++++-- packages/core/src/chat-engine/base.ts | 4 ++ packages/providers/anthropic/src/agent.ts | 2 +- packages/providers/ollama/src/agent.ts | 2 +- packages/providers/openai/package.json | 3 +- packages/providers/openai/src/agent.ts | 11 ++- pnpm-lock.yaml | 3 - 9 files changed, 166 insertions(+), 35 deletions(-) diff --git a/packages/core/src/agent/base.ts b/packages/core/src/agent/base.ts index 15227f18a..23ed97959 100644 --- a/packages/core/src/agent/base.ts +++ b/packages/core/src/agent/base.ts @@ -106,11 +106,17 @@ export type AgentRunnerParams< > ? AdditionalMessageOptions : never, + AdditionalChatOptions extends object = object, > = { llm: AI; chatHistory: ChatMessage<AdditionalMessageOptions>[]; systemPrompt: MessageContent | null; - runner: AgentWorker<AI, Store, AdditionalMessageOptions>; + runner: AgentWorker< + AI, + Store, + AdditionalMessageOptions, + AdditionalChatOptions + >; tools: | BaseToolWithCall[] | ((query: MessageContent) => Promise<BaseToolWithCall[]>); @@ -125,6 +131,7 @@ export type AgentParamsBase< > ? AdditionalMessageOptions : never, + AdditionalChatOptions extends object = object, > = | { llm?: AI; @@ -132,6 +139,7 @@ export type AgentParamsBase< systemPrompt?: MessageContent; verbose?: boolean; tools: BaseToolWithCall[]; + additionalChatOptions?: AdditionalChatOptions; } | { llm?: AI; @@ -139,6 +147,7 @@ export type AgentParamsBase< systemPrompt?: MessageContent; verbose?: boolean; toolRetriever: ObjectRetriever<BaseToolWithCall>; + additionalChatOptions?: AdditionalChatOptions; }; /** @@ -153,21 +162,36 @@ export abstract class AgentWorker< > ? AdditionalMessageOptions : never, + AdditionalChatOptions extends object = object, > { - #taskSet = new Set<TaskStep<AI, Store, AdditionalMessageOptions>>(); - abstract taskHandler: TaskHandler<AI, Store, AdditionalMessageOptions>; + #taskSet = new Set< + TaskStep<AI, Store, AdditionalMessageOptions, AdditionalChatOptions> + >(); + abstract taskHandler: TaskHandler< + AI, + Store, + AdditionalMessageOptions, + AdditionalChatOptions + >; public createTask( query: MessageContent, - context: AgentTaskContext<AI, Store, AdditionalMessageOptions>, - ): ReadableStream<TaskStepOutput<AI, Store, AdditionalMessageOptions>> { + context: AgentTaskContext< + AI, + Store, + AdditionalMessageOptions, + AdditionalChatOptions + >, + ): ReadableStream< + TaskStepOutput<AI, Store, AdditionalMessageOptions, AdditionalChatOptions> + > { context.store.messages.push({ role: "user", content: query, }); const taskOutputStream = createTaskOutputStream(this.taskHandler, context); return new ReadableStream< - TaskStepOutput<AI, Store, AdditionalMessageOptions> + TaskStepOutput<AI, Store, AdditionalMessageOptions, AdditionalChatOptions> >({ start: async (controller) => { for await (const stepOutput of taskOutputStream) { @@ -176,7 +200,8 @@ export abstract class AgentWorker< let currentStep: TaskStep< AI, Store, - AdditionalMessageOptions + AdditionalMessageOptions, + AdditionalChatOptions > | null = stepOutput.taskStep; while (currentStep) { this.#taskSet.delete(currentStep); @@ -227,6 +252,7 @@ export abstract class AgentRunner< > ? AdditionalMessageOptions : never, + AdditionalChatOptions extends object = object, > extends BaseChatEngine { readonly #llm: AI; readonly #tools: @@ -234,7 +260,12 @@ export abstract class AgentRunner< | ((query: MessageContent) => Promise<BaseToolWithCall[]>); readonly #systemPrompt: MessageContent | null = null; #chatHistory: ChatMessage<AdditionalMessageOptions>[]; - readonly #runner: AgentWorker<AI, Store, AdditionalMessageOptions>; + readonly #runner: AgentWorker< + AI, + Store, + AdditionalMessageOptions, + AdditionalChatOptions + >; readonly #verbose: boolean; // create extra store @@ -245,7 +276,7 @@ export abstract class AgentRunner< } static defaultTaskHandler: TaskHandler<LLM> = async (step, enqueueOutput) => { - const { llm, getTools, stream } = step.context; + const { llm, getTools, stream, additionalChatOptions } = step.context; const lastMessage = step.context.store.messages.at(-1)!.content; const tools = await getTools(lastMessage); if (!stream) { @@ -253,8 +284,9 @@ export abstract class AgentRunner< stream, tools, messages: [...step.context.store.messages], + additionalChatOptions, }); - await stepTools<LLM>({ + await stepTools({ response, tools, step, @@ -265,6 +297,7 @@ export abstract class AgentRunner< stream, tools, messages: [...step.context.store.messages], + additionalChatOptions, }); await stepToolsStreaming<LLM>({ response, @@ -276,7 +309,12 @@ export abstract class AgentRunner< }; protected constructor( - params: AgentRunnerParams<AI, Store, AdditionalMessageOptions>, + params: AgentRunnerParams< + AI, + Store, + AdditionalMessageOptions, + AdditionalChatOptions + >, ) { super(); const { llm, chatHistory, systemPrompt, runner, tools, verbose } = params; @@ -330,6 +368,7 @@ export abstract class AgentRunner< stream: boolean = false, verbose: boolean | undefined = undefined, chatHistory?: ChatMessage<AdditionalMessageOptions>[], + additionalChatOptions?: AdditionalChatOptions, ) { const initialMessages = [...(chatHistory ?? this.#chatHistory)]; if (this.#systemPrompt !== null) { @@ -348,6 +387,7 @@ export abstract class AgentRunner< stream, toolCallCount: 0, llm: this.#llm, + additionalChatOptions: additionalChatOptions ?? {}, getTools: (message) => this.getTools(message), store: { ...this.createStore(), @@ -365,13 +405,29 @@ export abstract class AgentRunner< }); } - async chat(params: NonStreamingChatEngineParams): Promise<EngineResponse>; async chat( - params: StreamingChatEngineParams, + params: NonStreamingChatEngineParams< + AdditionalMessageOptions, + AdditionalChatOptions + >, + ): Promise<EngineResponse>; + async chat( + params: StreamingChatEngineParams< + AdditionalMessageOptions, + AdditionalChatOptions + >, ): Promise<ReadableStream<EngineResponse>>; @wrapEventCaller async chat( - params: NonStreamingChatEngineParams | StreamingChatEngineParams, + params: + | NonStreamingChatEngineParams< + AdditionalMessageOptions, + AdditionalChatOptions + > + | StreamingChatEngineParams< + AdditionalMessageOptions, + AdditionalChatOptions + >, ): Promise<EngineResponse | ReadableStream<EngineResponse>> { let chatHistory: ChatMessage<AdditionalMessageOptions>[] = []; @@ -388,6 +444,7 @@ export abstract class AgentRunner< !!params.stream, false, chatHistory, + params.chatOptions, ); for await (const stepOutput of task) { // update chat history for each round diff --git a/packages/core/src/agent/llm.ts b/packages/core/src/agent/llm.ts index 5050ee2a8..a04604e21 100644 --- a/packages/core/src/agent/llm.ts +++ b/packages/core/src/agent/llm.ts @@ -4,24 +4,66 @@ import { ObjectRetriever } from "../objects"; import { AgentRunner, AgentWorker, type AgentParamsBase } from "./base.js"; import { validateAgentParams } from "./utils.js"; -type LLMParamsBase = AgentParamsBase<LLM>; +type LLMParamsBase< + AI extends LLM, + AdditionalMessageOptions extends object = AI extends LLM< + object, + infer AdditionalMessageOptions + > + ? AdditionalMessageOptions + : never, + AdditionalChatOptions extends object = object, +> = AgentParamsBase<AI, AdditionalMessageOptions, AdditionalChatOptions>; -type LLMParamsWithTools = LLMParamsBase & { +type LLMParamsWithTools< + AI extends LLM, + AdditionalMessageOptions extends object = AI extends LLM< + object, + infer AdditionalMessageOptions + > + ? AdditionalMessageOptions + : never, + AdditionalChatOptions extends object = object, +> = LLMParamsBase<AI, AdditionalMessageOptions, AdditionalChatOptions> & { tools: BaseToolWithCall[]; }; -type LLMParamsWithToolRetriever = LLMParamsBase & { +type LLMParamsWithToolRetriever< + AI extends LLM, + AdditionalMessageOptions extends object = AI extends LLM< + object, + infer AdditionalMessageOptions + > + ? AdditionalMessageOptions + : never, + AdditionalChatOptions extends object = object, +> = LLMParamsBase<AI, AdditionalMessageOptions, AdditionalChatOptions> & { toolRetriever: ObjectRetriever<BaseToolWithCall>; }; -export type LLMAgentParams = LLMParamsWithTools | LLMParamsWithToolRetriever; +export type LLMAgentParams< + AI extends LLM, + AdditionalMessageOptions extends object = AI extends LLM< + object, + infer AdditionalMessageOptions + > + ? AdditionalMessageOptions + : never, + AdditionalChatOptions extends object = object, +> = + | LLMParamsWithTools<AI, AdditionalMessageOptions, AdditionalChatOptions> + | LLMParamsWithToolRetriever< + AI, + AdditionalMessageOptions, + AdditionalChatOptions + >; export class LLMAgentWorker extends AgentWorker<LLM> { taskHandler = AgentRunner.defaultTaskHandler; } export class LLMAgent extends AgentRunner<LLM> { - constructor(params: LLMAgentParams) { + constructor(params: LLMAgentParams<LLM>) { validateAgentParams(params); const llm = params.llm ?? (Settings.llm ? (Settings.llm as LLM) : null); if (!llm) diff --git a/packages/core/src/agent/types.ts b/packages/core/src/agent/types.ts index c6e7a78d1..d5063c5e1 100644 --- a/packages/core/src/agent/types.ts +++ b/packages/core/src/agent/types.ts @@ -19,6 +19,7 @@ export type AgentTaskContext< > ? AdditionalMessageOptions : never, + AdditionalChatOptions extends object = object, > = { readonly stream: boolean; readonly toolCallCount: number; @@ -26,6 +27,7 @@ export type AgentTaskContext< readonly getTools: ( input: MessageContent, ) => BaseToolWithCall[] | Promise<BaseToolWithCall[]>; + readonly additionalChatOptions: Partial<AdditionalChatOptions>; shouldContinue: ( taskStep: Readonly<TaskStep<Model, Store, AdditionalMessageOptions>>, ) => boolean; @@ -45,13 +47,26 @@ export type TaskStep< > ? AdditionalMessageOptions : never, + AdditionalChatOptions extends object = object, > = { id: UUID; - context: AgentTaskContext<Model, Store, AdditionalMessageOptions>; + context: AgentTaskContext< + Model, + Store, + AdditionalMessageOptions, + AdditionalChatOptions + >; // linked list - prevStep: TaskStep<Model, Store, AdditionalMessageOptions> | null; - nextSteps: Set<TaskStep<Model, Store, AdditionalMessageOptions>>; + prevStep: TaskStep< + Model, + Store, + AdditionalMessageOptions, + AdditionalChatOptions + > | null; + nextSteps: Set< + TaskStep<Model, Store, AdditionalMessageOptions, AdditionalChatOptions> + >; }; export type TaskStepOutput< @@ -63,8 +78,14 @@ export type TaskStepOutput< > ? AdditionalMessageOptions : never, + AdditionalChatOptions extends object = object, > = { - taskStep: TaskStep<Model, Store, AdditionalMessageOptions>; + taskStep: TaskStep< + Model, + Store, + AdditionalMessageOptions, + AdditionalChatOptions + >; // output shows the response to the user output: | ChatResponse<AdditionalMessageOptions> @@ -81,10 +102,16 @@ export type TaskHandler< > ? AdditionalMessageOptions : never, + AdditionalChatOptions extends object = object, > = ( - step: TaskStep<Model, Store, AdditionalMessageOptions>, + step: TaskStep<Model, Store, AdditionalMessageOptions, AdditionalChatOptions>, enqueueOutput: ( - taskOutput: TaskStepOutput<Model, Store, AdditionalMessageOptions>, + taskOutput: TaskStepOutput< + Model, + Store, + AdditionalMessageOptions, + AdditionalChatOptions + >, ) => void, ) => Promise<void>; diff --git a/packages/core/src/chat-engine/base.ts b/packages/core/src/chat-engine/base.ts index b4bd4cf3b..77bc73500 100644 --- a/packages/core/src/chat-engine/base.ts +++ b/packages/core/src/chat-engine/base.ts @@ -16,14 +16,18 @@ export interface BaseChatEngineParams< export interface StreamingChatEngineParams< AdditionalMessageOptions extends object = object, + AdditionalChatOptions extends object = object, > extends BaseChatEngineParams<AdditionalMessageOptions> { stream: true; + chatOptions?: AdditionalChatOptions; } export interface NonStreamingChatEngineParams< AdditionalMessageOptions extends object = object, + AdditionalChatOptions extends object = object, > extends BaseChatEngineParams<AdditionalMessageOptions> { stream?: false; + chatOptions?: AdditionalChatOptions; } export abstract class BaseChatEngine { diff --git a/packages/providers/anthropic/src/agent.ts b/packages/providers/anthropic/src/agent.ts index 18c32a5a6..da2472f49 100644 --- a/packages/providers/anthropic/src/agent.ts +++ b/packages/providers/anthropic/src/agent.ts @@ -11,7 +11,7 @@ import { Settings } from "@llamaindex/core/global"; import type { EngineResponse } from "@llamaindex/core/schema"; import { Anthropic } from "./llm.js"; -export type AnthropicAgentParams = LLMAgentParams; +export type AnthropicAgentParams = LLMAgentParams<Anthropic>; export class AnthropicAgentWorker extends LLMAgentWorker {} diff --git a/packages/providers/ollama/src/agent.ts b/packages/providers/ollama/src/agent.ts index 69dd75a68..a3dc739a0 100644 --- a/packages/providers/ollama/src/agent.ts +++ b/packages/providers/ollama/src/agent.ts @@ -8,7 +8,7 @@ import { Ollama } from "./llm"; // This is likely not necessary anymore but leaving it here just incase it's in use elsewhere -export type OllamaAgentParams = LLMAgentParams & { +export type OllamaAgentParams = LLMAgentParams<Ollama> & { model?: string; }; diff --git a/packages/providers/openai/package.json b/packages/providers/openai/package.json index e794cc142..f830b57cb 100644 --- a/packages/providers/openai/package.json +++ b/packages/providers/openai/package.json @@ -35,7 +35,6 @@ "dependencies": { "@llamaindex/core": "workspace:*", "@llamaindex/env": "workspace:*", - "openai": "^4.68.1", - "remeda": "^2.12.0" + "openai": "^4.68.1" } } diff --git a/packages/providers/openai/src/agent.ts b/packages/providers/openai/src/agent.ts index 36c6ad66c..64b4eb028 100644 --- a/packages/providers/openai/src/agent.ts +++ b/packages/providers/openai/src/agent.ts @@ -4,11 +4,16 @@ import { type LLMAgentParams, } from "@llamaindex/core/agent"; import { Settings } from "@llamaindex/core/global"; -import { OpenAI } from "./llm"; +import type { ToolCallLLMMessageOptions } from "@llamaindex/core/llms"; +import { OpenAI, type OpenAIAdditionalChatOptions } from "./llm"; -// This is likely not necessary anymore but leaving it here just incase it's in use elsewhere +// This is likely not necessary anymore but leaving it here just in case it's in use elsewhere -export type OpenAIAgentParams = LLMAgentParams; +export type OpenAIAgentParams = LLMAgentParams< + OpenAI, + ToolCallLLMMessageOptions, + OpenAIAdditionalChatOptions +>; export class OpenAIAgentWorker extends LLMAgentWorker {} diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 0c007f066..d4bd1f9a7 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -1310,9 +1310,6 @@ importers: openai: specifier: ^4.68.1 version: 4.69.0(encoding@0.1.13)(zod@3.23.8) - remeda: - specifier: ^2.12.0 - version: 2.16.0 devDependencies: bunchee: specifier: 5.6.1 -- GitLab