From dbb5bd9f234030ac43446e74567e9ea32938ed42 Mon Sep 17 00:00:00 2001
From: Alex Yang <himself65@outlook.com>
Date: Tue, 12 Nov 2024 12:46:57 -0800
Subject: [PATCH] feat: allow `tool_choice` for OpenAIAgent (#1472)

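Agents can now thread provider-specific chat options through to
`llm.chat()`: a new `AdditionalChatOptions` type parameter runs through
`AgentWorker`, `AgentRunner`, `TaskStep` and related types, the chat-engine
params gain an optional `chatOptions` field that is forwarded on every
step of a task, and `AgentParamsBase` declares an agent-level
`additionalChatOptions` field for defaults.

A minimal usage sketch, assuming the usual entry points (`FunctionTool`
from `llamaindex`, `OpenAIAgent` from `@llamaindex/openai`); the tool and
the message are illustrative, not part of this patch:

    import { OpenAIAgent } from "@llamaindex/openai";
    import { FunctionTool } from "llamaindex";

    // Hypothetical tool, defined only for the sake of the example.
    const getWeather = FunctionTool.from(
      ({ city }: { city: string }) => `Sunny in ${city}`,
      {
        name: "getWeather",
        description: "Get the current weather for a city",
        parameters: {
          type: "object",
          properties: { city: { type: "string" } },
          required: ["city"],
        },
      },
    );

    const agent = new OpenAIAgent({ tools: [getWeather] });

    // `chatOptions` is typed as OpenAIAdditionalChatOptions, so any
    // OpenAI chat-completion option fits here; tool_choice forces the
    // model to call the named tool on this round.
    const response = await agent.chat({
      message: "What is the weather in Tokyo?",
      chatOptions: {
        tool_choice: { type: "function", function: { name: "getWeather" } },
      },
    });
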
---
 packages/core/src/agent/base.ts           | 85 +++++++++++++++++++----
 packages/core/src/agent/llm.ts            | 52 ++++++++++++--
 packages/core/src/agent/types.ts          | 39 +++++++++--
 packages/core/src/chat-engine/base.ts     |  4 ++
 packages/providers/anthropic/src/agent.ts |  2 +-
 packages/providers/ollama/src/agent.ts    |  2 +-
 packages/providers/openai/package.json    |  3 +-
 packages/providers/openai/src/agent.ts    | 11 ++-
 pnpm-lock.yaml                            |  3 -
 9 files changed, 166 insertions(+), 35 deletions(-)

diff --git a/packages/core/src/agent/base.ts b/packages/core/src/agent/base.ts
index 15227f18a..23ed97959 100644
--- a/packages/core/src/agent/base.ts
+++ b/packages/core/src/agent/base.ts
@@ -106,11 +106,17 @@ export type AgentRunnerParams<
   >
     ? AdditionalMessageOptions
     : never,
+  AdditionalChatOptions extends object = object,
 > = {
   llm: AI;
   chatHistory: ChatMessage<AdditionalMessageOptions>[];
   systemPrompt: MessageContent | null;
-  runner: AgentWorker<AI, Store, AdditionalMessageOptions>;
+  runner: AgentWorker<
+    AI,
+    Store,
+    AdditionalMessageOptions,
+    AdditionalChatOptions
+  >;
   tools:
     | BaseToolWithCall[]
     | ((query: MessageContent) => Promise<BaseToolWithCall[]>);
@@ -125,6 +131,7 @@ export type AgentParamsBase<
   >
     ? AdditionalMessageOptions
     : never,
+  AdditionalChatOptions extends object = object,
 > =
   | {
       llm?: AI;
@@ -132,6 +139,7 @@ export type AgentParamsBase<
       systemPrompt?: MessageContent;
       verbose?: boolean;
       tools: BaseToolWithCall[];
+      additionalChatOptions?: AdditionalChatOptions;
     }
   | {
       llm?: AI;
@@ -139,6 +147,7 @@ export type AgentParamsBase<
       systemPrompt?: MessageContent;
       verbose?: boolean;
       toolRetriever: ObjectRetriever<BaseToolWithCall>;
+      additionalChatOptions?: AdditionalChatOptions;
     };
 
 /**
@@ -153,21 +162,36 @@ export abstract class AgentWorker<
   >
     ? AdditionalMessageOptions
     : never,
+  AdditionalChatOptions extends object = object,
 > {
-  #taskSet = new Set<TaskStep<AI, Store, AdditionalMessageOptions>>();
-  abstract taskHandler: TaskHandler<AI, Store, AdditionalMessageOptions>;
+  #taskSet = new Set<
+    TaskStep<AI, Store, AdditionalMessageOptions, AdditionalChatOptions>
+  >();
+  abstract taskHandler: TaskHandler<
+    AI,
+    Store,
+    AdditionalMessageOptions,
+    AdditionalChatOptions
+  >;
 
   public createTask(
     query: MessageContent,
-    context: AgentTaskContext<AI, Store, AdditionalMessageOptions>,
-  ): ReadableStream<TaskStepOutput<AI, Store, AdditionalMessageOptions>> {
+    context: AgentTaskContext<
+      AI,
+      Store,
+      AdditionalMessageOptions,
+      AdditionalChatOptions
+    >,
+  ): ReadableStream<
+    TaskStepOutput<AI, Store, AdditionalMessageOptions, AdditionalChatOptions>
+  > {
     context.store.messages.push({
       role: "user",
       content: query,
     });
     const taskOutputStream = createTaskOutputStream(this.taskHandler, context);
     return new ReadableStream<
-      TaskStepOutput<AI, Store, AdditionalMessageOptions>
+      TaskStepOutput<AI, Store, AdditionalMessageOptions, AdditionalChatOptions>
     >({
       start: async (controller) => {
         for await (const stepOutput of taskOutputStream) {
@@ -176,7 +200,8 @@ export abstract class AgentWorker<
             let currentStep: TaskStep<
               AI,
               Store,
-              AdditionalMessageOptions
+              AdditionalMessageOptions,
+              AdditionalChatOptions
             > | null = stepOutput.taskStep;
             while (currentStep) {
               this.#taskSet.delete(currentStep);
@@ -227,6 +252,7 @@ export abstract class AgentRunner<
   >
     ? AdditionalMessageOptions
     : never,
+  AdditionalChatOptions extends object = object,
 > extends BaseChatEngine {
   readonly #llm: AI;
   readonly #tools:
@@ -234,7 +260,12 @@ export abstract class AgentRunner<
     | ((query: MessageContent) => Promise<BaseToolWithCall[]>);
   readonly #systemPrompt: MessageContent | null = null;
   #chatHistory: ChatMessage<AdditionalMessageOptions>[];
-  readonly #runner: AgentWorker<AI, Store, AdditionalMessageOptions>;
+  readonly #runner: AgentWorker<
+    AI,
+    Store,
+    AdditionalMessageOptions,
+    AdditionalChatOptions
+  >;
   readonly #verbose: boolean;
 
   // create extra store
@@ -245,7 +276,7 @@ export abstract class AgentRunner<
   }
 
   static defaultTaskHandler: TaskHandler<LLM> = async (step, enqueueOutput) => {
-    const { llm, getTools, stream } = step.context;
+    const { llm, getTools, stream, additionalChatOptions } = step.context;
     const lastMessage = step.context.store.messages.at(-1)!.content;
     const tools = await getTools(lastMessage);
     if (!stream) {
@@ -253,8 +284,9 @@ export abstract class AgentRunner<
         stream,
         tools,
         messages: [...step.context.store.messages],
+        additionalChatOptions,
       });
-      await stepTools<LLM>({
+      await stepTools({
         response,
         tools,
         step,
@@ -265,6 +297,7 @@ export abstract class AgentRunner<
         stream,
         tools,
         messages: [...step.context.store.messages],
+        additionalChatOptions,
       });
       await stepToolsStreaming<LLM>({
         response,
@@ -276,7 +309,12 @@ export abstract class AgentRunner<
   };
 
   protected constructor(
-    params: AgentRunnerParams<AI, Store, AdditionalMessageOptions>,
+    params: AgentRunnerParams<
+      AI,
+      Store,
+      AdditionalMessageOptions,
+      AdditionalChatOptions
+    >,
   ) {
     super();
     const { llm, chatHistory, systemPrompt, runner, tools, verbose } = params;
@@ -330,6 +368,7 @@ export abstract class AgentRunner<
     stream: boolean = false,
     verbose: boolean | undefined = undefined,
     chatHistory?: ChatMessage<AdditionalMessageOptions>[],
+    additionalChatOptions?: AdditionalChatOptions,
   ) {
     const initialMessages = [...(chatHistory ?? this.#chatHistory)];
     if (this.#systemPrompt !== null) {
@@ -348,6 +387,7 @@ export abstract class AgentRunner<
       stream,
       toolCallCount: 0,
       llm: this.#llm,
+      additionalChatOptions: additionalChatOptions ?? {},
       getTools: (message) => this.getTools(message),
       store: {
         ...this.createStore(),
@@ -365,13 +405,29 @@ export abstract class AgentRunner<
     });
   }
 
-  async chat(params: NonStreamingChatEngineParams): Promise<EngineResponse>;
   async chat(
-    params: StreamingChatEngineParams,
+    params: NonStreamingChatEngineParams<
+      AdditionalMessageOptions,
+      AdditionalChatOptions
+    >,
+  ): Promise<EngineResponse>;
+  async chat(
+    params: StreamingChatEngineParams<
+      AdditionalMessageOptions,
+      AdditionalChatOptions
+    >,
   ): Promise<ReadableStream<EngineResponse>>;
   @wrapEventCaller
   async chat(
-    params: NonStreamingChatEngineParams | StreamingChatEngineParams,
+    params:
+      | NonStreamingChatEngineParams<
+          AdditionalMessageOptions,
+          AdditionalChatOptions
+        >
+      | StreamingChatEngineParams<
+          AdditionalMessageOptions,
+          AdditionalChatOptions
+        >,
   ): Promise<EngineResponse | ReadableStream<EngineResponse>> {
     let chatHistory: ChatMessage<AdditionalMessageOptions>[] = [];
 
@@ -388,6 +444,7 @@ export abstract class AgentRunner<
       !!params.stream,
       false,
       chatHistory,
+      params.chatOptions,
     );
     for await (const stepOutput of task) {
       // update chat history for each round
diff --git a/packages/core/src/agent/llm.ts b/packages/core/src/agent/llm.ts
index 5050ee2a8..a04604e21 100644
--- a/packages/core/src/agent/llm.ts
+++ b/packages/core/src/agent/llm.ts
@@ -4,24 +4,66 @@ import { ObjectRetriever } from "../objects";
 import { AgentRunner, AgentWorker, type AgentParamsBase } from "./base.js";
 import { validateAgentParams } from "./utils.js";
 
-type LLMParamsBase = AgentParamsBase<LLM>;
+type LLMParamsBase<
+  AI extends LLM,
+  AdditionalMessageOptions extends object = AI extends LLM<
+    object,
+    infer AdditionalMessageOptions
+  >
+    ? AdditionalMessageOptions
+    : never,
+  AdditionalChatOptions extends object = object,
+> = AgentParamsBase<AI, AdditionalMessageOptions, AdditionalChatOptions>;
 
-type LLMParamsWithTools = LLMParamsBase & {
+type LLMParamsWithTools<
+  AI extends LLM,
+  AdditionalMessageOptions extends object = AI extends LLM<
+    object,
+    infer AdditionalMessageOptions
+  >
+    ? AdditionalMessageOptions
+    : never,
+  AdditionalChatOptions extends object = object,
+> = LLMParamsBase<AI, AdditionalMessageOptions, AdditionalChatOptions> & {
   tools: BaseToolWithCall[];
 };
 
-type LLMParamsWithToolRetriever = LLMParamsBase & {
+type LLMParamsWithToolRetriever<
+  AI extends LLM,
+  AdditionalMessageOptions extends object = AI extends LLM<
+    object,
+    infer AdditionalMessageOptions
+  >
+    ? AdditionalMessageOptions
+    : never,
+  AdditionalChatOptions extends object = object,
+> = LLMParamsBase<AI, AdditionalMessageOptions, AdditionalChatOptions> & {
   toolRetriever: ObjectRetriever<BaseToolWithCall>;
 };
 
-export type LLMAgentParams = LLMParamsWithTools | LLMParamsWithToolRetriever;
+export type LLMAgentParams<
+  AI extends LLM,
+  AdditionalMessageOptions extends object = AI extends LLM<
+    object,
+    infer AdditionalMessageOptions
+  >
+    ? AdditionalMessageOptions
+    : never,
+  AdditionalChatOptions extends object = object,
+> =
+  | LLMParamsWithTools<AI, AdditionalMessageOptions, AdditionalChatOptions>
+  | LLMParamsWithToolRetriever<
+      AI,
+      AdditionalMessageOptions,
+      AdditionalChatOptions
+    >;
 
 export class LLMAgentWorker extends AgentWorker<LLM> {
   taskHandler = AgentRunner.defaultTaskHandler;
 }
 
 export class LLMAgent extends AgentRunner<LLM> {
-  constructor(params: LLMAgentParams) {
+  constructor(params: LLMAgentParams<LLM>) {
     validateAgentParams(params);
     const llm = params.llm ?? (Settings.llm ? (Settings.llm as LLM) : null);
     if (!llm)
diff --git a/packages/core/src/agent/types.ts b/packages/core/src/agent/types.ts
index c6e7a78d1..d5063c5e1 100644
--- a/packages/core/src/agent/types.ts
+++ b/packages/core/src/agent/types.ts
@@ -19,6 +19,7 @@ export type AgentTaskContext<
   >
     ? AdditionalMessageOptions
     : never,
+  AdditionalChatOptions extends object = object,
 > = {
   readonly stream: boolean;
   readonly toolCallCount: number;
@@ -26,6 +27,7 @@ export type AgentTaskContext<
   readonly getTools: (
     input: MessageContent,
   ) => BaseToolWithCall[] | Promise<BaseToolWithCall[]>;
+  readonly additionalChatOptions: Partial<AdditionalChatOptions>;
   shouldContinue: (
     taskStep: Readonly<TaskStep<Model, Store, AdditionalMessageOptions>>,
   ) => boolean;
@@ -45,13 +47,26 @@ export type TaskStep<
   >
     ? AdditionalMessageOptions
     : never,
+  AdditionalChatOptions extends object = object,
 > = {
   id: UUID;
-  context: AgentTaskContext<Model, Store, AdditionalMessageOptions>;
+  context: AgentTaskContext<
+    Model,
+    Store,
+    AdditionalMessageOptions,
+    AdditionalChatOptions
+  >;
 
   // linked list
-  prevStep: TaskStep<Model, Store, AdditionalMessageOptions> | null;
-  nextSteps: Set<TaskStep<Model, Store, AdditionalMessageOptions>>;
+  prevStep: TaskStep<
+    Model,
+    Store,
+    AdditionalMessageOptions,
+    AdditionalChatOptions
+  > | null;
+  nextSteps: Set<
+    TaskStep<Model, Store, AdditionalMessageOptions, AdditionalChatOptions>
+  >;
 };
 
 export type TaskStepOutput<
@@ -63,8 +78,14 @@ export type TaskStepOutput<
   >
     ? AdditionalMessageOptions
     : never,
+  AdditionalChatOptions extends object = object,
 > = {
-  taskStep: TaskStep<Model, Store, AdditionalMessageOptions>;
+  taskStep: TaskStep<
+    Model,
+    Store,
+    AdditionalMessageOptions,
+    AdditionalChatOptions
+  >;
   // output shows the response to the user
   output:
     | ChatResponse<AdditionalMessageOptions>
@@ -81,10 +102,16 @@ export type TaskHandler<
   >
     ? AdditionalMessageOptions
     : never,
+  AdditionalChatOptions extends object = object,
 > = (
-  step: TaskStep<Model, Store, AdditionalMessageOptions>,
+  step: TaskStep<Model, Store, AdditionalMessageOptions, AdditionalChatOptions>,
   enqueueOutput: (
-    taskOutput: TaskStepOutput<Model, Store, AdditionalMessageOptions>,
+    taskOutput: TaskStepOutput<
+      Model,
+      Store,
+      AdditionalMessageOptions,
+      AdditionalChatOptions
+    >,
   ) => void,
 ) => Promise<void>;
 
diff --git a/packages/core/src/chat-engine/base.ts b/packages/core/src/chat-engine/base.ts
index b4bd4cf3b..77bc73500 100644
--- a/packages/core/src/chat-engine/base.ts
+++ b/packages/core/src/chat-engine/base.ts
@@ -16,14 +16,18 @@ export interface BaseChatEngineParams<
 
 export interface StreamingChatEngineParams<
   AdditionalMessageOptions extends object = object,
+  AdditionalChatOptions extends object = object,
 > extends BaseChatEngineParams<AdditionalMessageOptions> {
   stream: true;
+  chatOptions?: AdditionalChatOptions;
 }
 
 export interface NonStreamingChatEngineParams<
   AdditionalMessageOptions extends object = object,
+  AdditionalChatOptions extends object = object,
 > extends BaseChatEngineParams<AdditionalMessageOptions> {
   stream?: false;
+  chatOptions?: AdditionalChatOptions;
 }
 
 export abstract class BaseChatEngine {
diff --git a/packages/providers/anthropic/src/agent.ts b/packages/providers/anthropic/src/agent.ts
index 18c32a5a6..da2472f49 100644
--- a/packages/providers/anthropic/src/agent.ts
+++ b/packages/providers/anthropic/src/agent.ts
@@ -11,7 +11,7 @@ import { Settings } from "@llamaindex/core/global";
 import type { EngineResponse } from "@llamaindex/core/schema";
 import { Anthropic } from "./llm.js";
 
-export type AnthropicAgentParams = LLMAgentParams;
+export type AnthropicAgentParams = LLMAgentParams<Anthropic>;
 
 export class AnthropicAgentWorker extends LLMAgentWorker {}
 
diff --git a/packages/providers/ollama/src/agent.ts b/packages/providers/ollama/src/agent.ts
index 69dd75a68..a3dc739a0 100644
--- a/packages/providers/ollama/src/agent.ts
+++ b/packages/providers/ollama/src/agent.ts
@@ -8,7 +8,7 @@ import { Ollama } from "./llm";
 
 // This is likely not necessary anymore but leaving it here just incase it's in use elsewhere
 
-export type OllamaAgentParams = LLMAgentParams & {
+export type OllamaAgentParams = LLMAgentParams<Ollama> & {
   model?: string;
 };
 
diff --git a/packages/providers/openai/package.json b/packages/providers/openai/package.json
index e794cc142..f830b57cb 100644
--- a/packages/providers/openai/package.json
+++ b/packages/providers/openai/package.json
@@ -35,7 +35,6 @@
   "dependencies": {
     "@llamaindex/core": "workspace:*",
     "@llamaindex/env": "workspace:*",
-    "openai": "^4.68.1",
-    "remeda": "^2.12.0"
+    "openai": "^4.68.1"
   }
 }
diff --git a/packages/providers/openai/src/agent.ts b/packages/providers/openai/src/agent.ts
index 36c6ad66c..64b4eb028 100644
--- a/packages/providers/openai/src/agent.ts
+++ b/packages/providers/openai/src/agent.ts
@@ -4,11 +4,16 @@ import {
   type LLMAgentParams,
 } from "@llamaindex/core/agent";
 import { Settings } from "@llamaindex/core/global";
-import { OpenAI } from "./llm";
+import type { ToolCallLLMMessageOptions } from "@llamaindex/core/llms";
+import { OpenAI, type OpenAIAdditionalChatOptions } from "./llm";
 
-// This is likely not necessary anymore but leaving it here just incase it's in use elsewhere
+// This is likely not necessary anymore but leaving it here just in case it's in use elsewhere
 
-export type OpenAIAgentParams = LLMAgentParams;
+export type OpenAIAgentParams = LLMAgentParams<
+  OpenAI,
+  ToolCallLLMMessageOptions,
+  OpenAIAdditionalChatOptions
+>;
 
 export class OpenAIAgentWorker extends LLMAgentWorker {}
 
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index 0c007f066..d4bd1f9a7 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -1310,9 +1310,6 @@ importers:
       openai:
         specifier: ^4.68.1
         version: 4.69.0(encoding@0.1.13)(zod@3.23.8)
-      remeda:
-        specifier: ^2.12.0
-        version: 2.16.0
     devDependencies:
       bunchee:
         specifier: 5.6.1
-- 
GitLab