diff --git a/examples/agent/multi_document_agent.ts b/examples/agent/multi_document_agent.ts
index 0a70ecacb865c4eadbea5fff725ee1921ebe0653..d4d0407add2b7fd7958d6a8739e1dc0815f65f4d 100644
--- a/examples/agent/multi_document_agent.ts
+++ b/examples/agent/multi_document_agent.ts
@@ -86,7 +86,6 @@ async function main() {
     const agent = new OpenAIAgent({
       tools: queryEngineTools,
       llm: new OpenAI({ model: "gpt-4" }),
-      verbose: true,
     });
 
     documentAgents[title] = agent;
@@ -126,7 +125,6 @@ async function main() {
   const topAgent = new OpenAIAgent({
     toolRetriever: await objectIndex.asRetriever({}),
     llm: new OpenAI({ model: "gpt-4" }),
-    verbose: true,
     prefixMessages: [
       {
         content:
diff --git a/examples/agent/openai.ts b/examples/agent/openai.ts
index 62b8e35cfc55f4022152d65a8f67d6506f0b5816..3e19011995ba117db71e5adcd0c892b38560005b 100644
--- a/examples/agent/openai.ts
+++ b/examples/agent/openai.ts
@@ -59,7 +59,6 @@ async function main() {
   // Create an OpenAIAgent with the function tools
   const agent = new OpenAIAgent({
     tools: [functionTool, functionTool2],
-    verbose: true,
   });
 
   // Chat with the agent
diff --git a/examples/agent/query_openai_agent.ts b/examples/agent/query_openai_agent.ts
index 69281f4b5c3ce0f368524382433bcd3cdbe67309..4521c9921cd0514874cf294c98af35959659ebd0 100644
--- a/examples/agent/query_openai_agent.ts
+++ b/examples/agent/query_openai_agent.ts
@@ -29,7 +29,6 @@ async function main() {
   // Create an OpenAIAgent with the function tools
   const agent = new OpenAIAgent({
     tools: [queryEngineTool],
-    verbose: true,
   });
 
   // Chat with the agent
diff --git a/examples/agent/react_agent.ts b/examples/agent/react_agent.ts
index bfd8b82219e74166e42bcc1f0926965c5cb970ec..a177b3835bbe26f19ef6783259fd12ad26d200f1 100644
--- a/examples/agent/react_agent.ts
+++ b/examples/agent/react_agent.ts
@@ -65,7 +65,6 @@ async function main() {
   const agent = new ReActAgent({
     llm: anthropic,
     tools: [functionTool, functionTool2],
-    verbose: true,
   });
 
   // Chat with the agent
diff --git a/examples/agent/step_wise_openai.ts b/examples/agent/step_wise_openai.ts
index 853b577fe1ac17bf793b8a77736b7fcf51b20728..f5c50db78ef5b42d5e758c62040612f3ead5262e 100644
--- a/examples/agent/step_wise_openai.ts
+++ b/examples/agent/step_wise_openai.ts
@@ -59,7 +59,6 @@ async function main() {
   // Create an OpenAIAgent with the function tools
   const agent = new OpenAIAgent({
     tools: [functionTool, functionTool2],
-    verbose: true,
   });
 
   // Create a task to sum and divide numbers
diff --git a/examples/agent/step_wise_query_tool.ts b/examples/agent/step_wise_query_tool.ts
index 4b8d4db19263b1b66941370f2cbfb6d6be73d988..b098f1a4e2689571c43850d9186b316509e00c3d 100644
--- a/examples/agent/step_wise_query_tool.ts
+++ b/examples/agent/step_wise_query_tool.ts
@@ -29,7 +29,6 @@ async function main() {
   // Create an OpenAIAgent with the function tools
   const agent = new OpenAIAgent({
     tools: [queryEngineTool],
-    verbose: true,
   });
 
   const task = agent.createTask("What was his salary?");
diff --git a/examples/agent/step_wise_react.ts b/examples/agent/step_wise_react.ts
index 4e9280ceae4a137a729b2e288e645257acbcd5da..81f2f87b8568cfde8ac198cc2783d203870b2795 100644
--- a/examples/agent/step_wise_react.ts
+++ b/examples/agent/step_wise_react.ts
@@ -59,7 +59,6 @@ async function main() {
   // Create an OpenAIAgent with the function tools
   const agent = new ReActAgent({
     tools: [functionTool, functionTool2],
-    verbose: true,
   });
 
   const task = agent.createTask("Divide 16 by 2 then add 20");
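Note on the example changes above: `verbose` is gone from every agent constructor; debug output is now a global, environment-driven switch (see the `Settings.debug` getter added to `packages/core/src/Settings.ts` further down). A minimal before/after sketch — `sumTool` is a hypothetical stand-in for the function tools these examples build:

```ts
import { OpenAIAgent } from "llamaindex";
import type { BaseTool } from "llamaindex";

declare const sumTool: BaseTool; // hypothetical stand-in for functionTool above

// before: new OpenAIAgent({ tools: [sumTool], verbose: true });
// after: no per-agent flag; run with NODE_ENV=development DEBUG=llamaindex
// to get the tool-call logging that `verbose: true` used to print.
const agent = new OpenAIAgent({ tools: [sumTool] });
```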
diff --git a/examples/agent/stream_openai_agent.ts b/examples/agent/stream_openai_agent.ts
index 2190370640dddc6c56a0a654964f59f7a8754e2d..3d942ad3ddc5253c276f2b503665e2a538aa6b6b 100644
--- a/examples/agent/stream_openai_agent.ts
+++ b/examples/agent/stream_openai_agent.ts
@@ -59,7 +59,6 @@ async function main() {
   // Create an OpenAIAgent with the function tools
   const agent = new OpenAIAgent({
     tools: [functionTool, functionTool2],
-    verbose: false,
   });
 
   const stream = await agent.chat({
diff --git a/examples/agent/wiki.ts b/examples/agent/wiki.ts
index db308db5748bfbda04ac27e3494a1e5e2ed15b58..65c00a3fe540e6f5f1eee257d1e0b271f71b9583 100644
--- a/examples/agent/wiki.ts
+++ b/examples/agent/wiki.ts
@@ -8,7 +8,6 @@ async function main() {
   const agent = new OpenAIAgent({
     llm,
     tools: [wikiTool],
-    verbose: true,
   });
 
   // Chat with the agent
diff --git a/examples/toolsStream.ts b/examples/toolsStream.ts
index b59114dc99189b962500302fd25df1ed108a04b1..10e8400df776f17f13f8b2d670af885accb47bc1 100644
--- a/examples/toolsStream.ts
+++ b/examples/toolsStream.ts
@@ -1,9 +1,12 @@
-import { ChatResponseChunk, LLMChatParamsBase, OpenAI } from "llamaindex";
+import { ChatResponseChunk, OpenAI } from "llamaindex";
 
 async function main() {
   const llm = new OpenAI({ model: "gpt-4-turbo-preview" });
 
-  const args: LLMChatParamsBase = {
+  const args: Parameters<typeof llm.chat>[0] = {
+    additionalChatOptions: {
+      tool_choice: "auto",
+    },
     messages: [
       {
         content: "Who was Goethe?",
@@ -12,8 +15,7 @@ async function main() {
     ],
     tools: [
       {
-        type: "function",
-        function: {
+        metadata: {
           name: "wikipedia_tool",
           description: "A tool that uses a query engine to search Wikipedia.",
           parameters: {
@@ -29,7 +31,6 @@ async function main() {
         },
       },
     ],
-    toolChoice: "auto",
   };
 
   const stream = await llm.chat({ ...args, stream: true });
diff --git a/packages/core/src/Settings.ts b/packages/core/src/Settings.ts
index e2b2f4c0eeca33610e6819dd8ef93f9e2d8309cb..bf30f3f4ddadf4c909915022de076367fa83ed65 100644
--- a/packages/core/src/Settings.ts
+++ b/packages/core/src/Settings.ts
@@ -5,7 +5,7 @@ import { OpenAI } from "./llm/open_ai.js";
 import { PromptHelper } from "./PromptHelper.js";
 import { SimpleNodeParser } from "./nodeParsers/SimpleNodeParser.js";
 
-import { AsyncLocalStorage } from "@llamaindex/env";
+import { AsyncLocalStorage, getEnv } from "@llamaindex/env";
 import type { ServiceContext } from "./ServiceContext.js";
 import type { BaseEmbedding } from "./embeddings/types.js";
 import {
@@ -52,6 +52,15 @@ class GlobalSettings implements Config {
   #chunkOverlapAsyncLocalStorage = new AsyncLocalStorage<number>();
   #promptAsyncLocalStorage = new AsyncLocalStorage<PromptConfig>();
 
+  get debug() {
+    const debug = getEnv("DEBUG");
+    return (
+      getEnv("NODE_ENV") === "development" &&
+      Boolean(debug) &&
+      debug?.includes("llamaindex")
+    );
+  }
+
   get llm(): LLM {
     if (this.#llm === null) {
       this.#llm = new OpenAI();
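For context on the hunk above: `Settings.debug` is what replaces the per-instance `verbose` flags removed throughout this patch. It only reports true when `NODE_ENV` is `development` and `DEBUG` mentions `llamaindex`. A minimal sketch, assuming a Node runtime where `getEnv` reads `process.env`:

```ts
import { Settings } from "llamaindex";

// Equivalent shell invocation: NODE_ENV=development DEBUG=llamaindex node agent.js
process.env.NODE_ENV = "development";
process.env.DEBUG = "llamaindex";

// The getter re-reads the environment on every access, so this logs `true`:
console.log(Settings.debug);
```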
diff --git a/packages/core/src/agent/openai/base.ts b/packages/core/src/agent/openai/base.ts
index dee13e94fa6254c398a62a1bc8ac5ac18e90d4d8..795792158d7b2273db70beaec780be5d410eb70a 100644
--- a/packages/core/src/agent/openai/base.ts
+++ b/packages/core/src/agent/openai/base.ts
@@ -10,7 +10,6 @@ type OpenAIAgentParams = {
   llm?: OpenAI;
   memory?: any;
   prefixMessages?: ChatMessage[];
-  verbose?: boolean;
   maxFunctionCalls?: number;
   defaultToolChoice?: string;
   toolRetriever?: ObjectRetriever;
@@ -28,7 +27,6 @@ export class OpenAIAgent extends AgentRunner {
     llm,
     memory,
     prefixMessages,
-    verbose,
     maxFunctionCalls = 5,
     defaultToolChoice = "auto",
     toolRetriever,
@@ -59,7 +57,6 @@ export class OpenAIAgent extends AgentRunner {
       prefixMessages,
       maxFunctionCalls,
       toolRetriever,
-      verbose,
     });
 
     super({
diff --git a/packages/core/src/agent/openai/utils.ts b/packages/core/src/agent/openai/utils.ts
deleted file mode 100644
index 5f291f44a60c5d0dc63a68f6c4b6383a06dc8483..0000000000000000000000000000000000000000
--- a/packages/core/src/agent/openai/utils.ts
+++ /dev/null
@@ -1,27 +0,0 @@
-import type { ToolMetadata } from "../../types.js";
-
-export type OpenAIFunction = {
-  type: "function";
-  function: ToolMetadata;
-};
-
-type OpenAiTool = {
-  name: string;
-  description: string;
-  parameters: ToolMetadata["parameters"];
-};
-
-export const toOpenAiTool = ({
-  name,
-  description,
-  parameters,
-}: OpenAiTool): OpenAIFunction => {
-  return {
-    type: "function",
-    function: {
-      name: name,
-      description: description,
-      parameters,
-    },
-  };
-};
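The deleted `toOpenAiTool` helper is superseded by the static `OpenAI.toTool` added to `packages/core/src/llm/open_ai.ts` later in this patch, so conversion to the OpenAI wire format now happens inside the LLM class instead of the agent worker. A rough sketch of the equivalence (tool shape taken from `examples/toolsStream.ts` above):

```ts
import { OpenAI } from "llamaindex";

const wikipediaTool = {
  metadata: {
    name: "wikipedia_tool",
    description: "A tool that uses a query engine to search Wikipedia.",
    parameters: {
      type: "object",
      properties: { query: { type: "string" } },
      required: ["query"],
    },
  },
};

// Same output the old helper produced:
// { type: "function", function: { name, description, parameters } }
console.log(OpenAI.toTool(wikipediaTool));
```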
diff --git a/packages/core/src/agent/openai/worker.ts b/packages/core/src/agent/openai/worker.ts
index e978a0dab8d1636a1060a56d1b0fe29dfe93ac79..2c48cc67c75be2023833a8bdbfde5804b81cb78e 100644
--- a/packages/core/src/agent/openai/worker.ts
+++ b/packages/core/src/agent/openai/worker.ts
@@ -1,17 +1,20 @@
 import { randomUUID } from "@llamaindex/env";
+import type { ChatCompletionToolChoiceOption } from "openai/resources/chat/completions";
 import { Response } from "../../Response.js";
+import { Settings } from "../../Settings.js";
 import {
   AgentChatResponse,
   ChatResponseMode,
   StreamingAgentChatResponse,
 } from "../../engines/chat/types.js";
-import type {
-  ChatMessage,
-  ChatResponse,
-  ChatResponseChunk,
-  LLMChatParamsBase,
+import {
+  OpenAI,
+  isFunctionCallingModel,
+  type ChatMessage,
+  type ChatResponseChunk,
+  type LLMChatParamsBase,
+  type OpenAIAdditionalChatOptions,
 } from "../../llm/index.js";
-import { OpenAI } from "../../llm/index.js";
 import { streamConverter, streamReducer } from "../../llm/utils.js";
 import { ChatMemoryBuffer } from "../../memory/ChatMemoryBuffer.js";
 import type { ObjectRetriever } from "../../objects/base.js";
@@ -22,28 +25,17 @@ import type { AgentWorker, Task } from "../types.js";
 import { TaskStep, TaskStepOutput } from "../types.js";
 import { addUserStepToMemory, getFunctionByName } from "../utils.js";
 import type { OpenAIToolCall } from "./types/chat.js";
-import { toOpenAiTool } from "./utils.js";
-
-const DEFAULT_MAX_FUNCTION_CALLS = 5;
 
-/**
- * Call function.
- * @param tools: tools
- * @param toolCall: tool call
- * @param verbose: verbose
- * @returns: void
- */
 async function callFunction(
   tools: BaseTool[],
   toolCall: OpenAIToolCall,
-  verbose: boolean = false,
 ): Promise<[ChatMessage, ToolOutput]> {
   const id_ = toolCall.id;
   const functionCall = toolCall.function;
   const name = toolCall.function.name;
   const argumentsStr = toolCall.function.arguments;
 
-  if (verbose) {
+  if (Settings.debug) {
     console.log("=== Calling Function ===");
     console.log(`Calling function: ${name} with args: ${argumentsStr}`);
   }
@@ -55,7 +47,7 @@ async function callFunction(
   // Use default error message
   const output = await callToolWithErrorHandling(tool, argumentDict, null);
 
-  if (verbose) {
+  if (Settings.debug) {
     console.log(`Got output ${output}`);
     console.log("==========================");
   }
@@ -77,7 +69,6 @@ type OpenAIAgentWorkerParams = {
   tools?: BaseTool[];
   llm?: OpenAI;
   prefixMessages?: ChatMessage[];
-  verbose?: boolean;
   maxFunctionCalls?: number;
   toolRetriever?: ObjectRetriever;
 };
@@ -87,40 +78,40 @@ type CallFunctionOutput = {
   message: ChatMessage;
   toolOutput: ToolOutput;
 };
 
-/**
- * OpenAI agent worker.
- * This class is responsible for running the agent.
- */
-export class OpenAIAgentWorker implements AgentWorker {
+export class OpenAIAgentWorker
+  implements AgentWorker<LLMChatParamsBase<OpenAIAdditionalChatOptions>>
+{
   private llm: OpenAI;
-  private verbose: boolean;
-  private maxFunctionCalls: number;
+  private maxFunctionCalls: number = 5;
 
   public prefixMessages: ChatMessage[];
 
   private _getTools: (input: string) => Promise<BaseTool[]>;
 
-  /**
-   * Initialize.
-   */
   constructor({
     tools = [],
     llm,
     prefixMessages,
-    verbose,
-    maxFunctionCalls = DEFAULT_MAX_FUNCTION_CALLS,
+    maxFunctionCalls,
     toolRetriever,
   }: OpenAIAgentWorkerParams) {
-    this.llm = llm ?? new OpenAI({ model: "gpt-3.5-turbo-0613" });
-    this.verbose = verbose || false;
-    this.maxFunctionCalls = maxFunctionCalls;
+    this.llm =
+      llm ??
+      (isFunctionCallingModel(Settings.llm)
+        ? (Settings.llm as OpenAI)
+        : new OpenAI({
+            model: "gpt-3.5-turbo-0613",
+          }));
+    if (maxFunctionCalls) {
+      this.maxFunctionCalls = maxFunctionCalls;
+    }
     this.prefixMessages = prefixMessages || [];
 
-    if (tools.length > 0 && toolRetriever) {
+    if (Array.isArray(tools) && tools.length > 0 && toolRetriever) {
       throw new Error("Cannot specify both tools and tool_retriever");
-    } else if (tools.length > 0) {
+    } else if (Array.isArray(tools) && tools.length > 0) {
       this._getTools = async () => tools;
     } else if (toolRetriever) {
+      // fixme: this won't work, type mismatch
      this._getTools = async (message: string) =>
         toolRetriever.retrieve(message);
     } else {
@@ -128,11 +119,6 @@ export class OpenAIAgentWorker implements AgentWorker {
     }
   }
 
-  /**
-   * Get all messages.
-   * @param task: task
-   * @returns: messages
-   */
   public getAllMessages(task: Task): ChatMessage[] {
     return [
       ...this.prefixMessages,
@@ -141,11 +127,6 @@ export class OpenAIAgentWorker implements AgentWorker {
     ];
   }
 
-  /**
-   * Get latest tool calls.
-   * @param task: task
-   * @returns: tool calls
-   */
   public getLatestToolCalls(task: Task): OpenAIToolCall[] | null {
     const chatHistory: ChatMessage[] = task.extraState.newMemory.getAll();
 
@@ -156,28 +137,23 @@ export class OpenAIAgentWorker implements AgentWorker {
     return chatHistory[chatHistory.length - 1].additionalKwargs?.toolCalls;
   }
 
-  /**
-   *
-   * @param task
-   * @param openaiTools
-   * @param toolChoice
-   * @returns
-   */
-  private _getLlmChatKwargs(
+  private _getLlmChatParams(
     task: Task,
-    openaiTools: { [key: string]: any }[],
-    toolChoice: string | { [key: string]: any } = "auto",
-  ): LLMChatParamsBase {
-    const llmChatKwargs: LLMChatParamsBase = {
+    openaiTools: BaseTool[],
+    toolChoice: ChatCompletionToolChoiceOption = "auto",
+  ): LLMChatParamsBase<OpenAIAdditionalChatOptions> {
+    const llmChatParams = {
       messages: this.getAllMessages(task),
-    };
+      tools: [] as BaseTool[],
+      additionalChatOptions: {} as OpenAIAdditionalChatOptions,
+    } satisfies LLMChatParamsBase<OpenAIAdditionalChatOptions>;
 
     if (openaiTools.length > 0) {
-      llmChatKwargs.tools = openaiTools;
-      llmChatKwargs.toolChoice = toolChoice;
+      llmChatParams.tools = openaiTools;
+      llmChatParams.additionalChatOptions.tool_choice = toolChoice;
     }
 
-    return llmChatKwargs;
+    return llmChatParams;
   }
 
   private _processMessage(
@@ -191,11 +167,11 @@ export class OpenAIAgentWorker implements AgentWorker {
 
   private async _getStreamAiResponse(
     task: Task,
-    llmChatKwargs: any,
+    llmChatParams: LLMChatParamsBase<OpenAIAdditionalChatOptions>,
   ): Promise<StreamingAgentChatResponse | AgentChatResponse> {
     const stream = await this.llm.chat({
       stream: true,
-      ...llmChatKwargs,
+      ...llmChatParams,
     });
     // read first chunk from stream to find out if we need to call tools
     const iterator = stream[Symbol.asyncIterator]();
@@ -233,43 +209,28 @@ export class OpenAIAgentWorker implements AgentWorker {
     return new StreamingAgentChatResponse(newStream, task.extraState.sources);
   }
 
-  /**
-   * Get agent response.
-   * @param task: task
-   * @param mode: mode
-   * @param llmChatKwargs: llm chat kwargs
-   * @returns: agent chat response
-   */
   private async _getAgentResponse(
     task: Task,
     mode: ChatResponseMode,
-    llmChatKwargs: any,
+    llmChatParams: LLMChatParamsBase<OpenAIAdditionalChatOptions>,
   ): Promise<AgentChatResponse | StreamingAgentChatResponse> {
     if (mode === ChatResponseMode.WAIT) {
-      const chatResponse = (await this.llm.chat({
+      const chatResponse = await this.llm.chat({
         stream: false,
-        ...llmChatKwargs,
-      })) as unknown as ChatResponse;
+        ...llmChatParams,
+      });
 
       return this._processMessage(
         task,
         chatResponse.message,
       ) as AgentChatResponse;
     } else if (mode === ChatResponseMode.STREAM) {
-      return this._getStreamAiResponse(task, llmChatKwargs);
+      return this._getStreamAiResponse(task, llmChatParams);
     }
 
     throw new Error("Invalid mode");
   }
 
-  /**
-   * Call function.
-   * @param tools: tools
-   * @param toolCall: tool call
-   * @param memory: memory
-   * @param sources: sources
-   * @returns: void
-   */
   async callFunction(
     tools: BaseTool[],
     toolCall: OpenAIToolCall,
@@ -280,7 +241,7 @@ export class OpenAIAgentWorker implements AgentWorker {
       throw new Error("Invalid tool_call object");
     }
 
-    const functionMessage = await callFunction(tools, toolCall, this.verbose);
+    const functionMessage = await callFunction(tools, toolCall);
 
     const message = functionMessage[0];
     const toolOutput = functionMessage[1];
@@ -291,13 +252,7 @@ export class OpenAIAgentWorker implements AgentWorker {
     };
   }
 
-  /**
-   * Initialize step.
-   * @param task: task
-   * @param kwargs: kwargs
-   * @returns: task step
-   */
-  initializeStep(task: Task, kwargs?: any): TaskStep {
+  initializeStep(task: Task): TaskStep {
     const sources: ToolOutput[] = [];
 
     const newMemory = new ChatMemoryBuffer({
@@ -318,12 +273,6 @@ export class OpenAIAgentWorker implements AgentWorker {
     return new TaskStep(task.taskId, randomUUID(), task.input);
   }
 
-  /**
-   * Should continue.
-   * @param toolCalls: tool calls
-   * @param nFunctionCalls: number of function calls
-   * @returns: boolean
-   */
   private _shouldContinue(
     toolCalls: OpenAIToolCall[] | null,
     nFunctionCalls: number,
@@ -339,11 +288,6 @@ export class OpenAIAgentWorker implements AgentWorker {
     return true;
   }
 
-  /**
-   * Get tools.
-   * @param input: input
-   * @returns: tools
-   */
   async getTools(input: string): Promise<BaseTool[]> {
     return this._getTools(input);
   }
@@ -352,28 +296,20 @@
     step: TaskStep,
     task: Task,
     mode: ChatResponseMode = ChatResponseMode.WAIT,
-    toolChoice: string | { [key: string]: any } = "auto",
+    toolChoice: ChatCompletionToolChoiceOption = "auto",
   ): Promise<TaskStepOutput> {
     const tools = await this.getTools(task.input);
 
     if (step.input) {
-      addUserStepToMemory(step, task.extraState.newMemory, this.verbose);
+      addUserStepToMemory(step, task.extraState.newMemory);
     }
 
-    const openaiTools = tools.map((tool) =>
-      toOpenAiTool({
-        name: tool.metadata.name,
-        description: tool.metadata.description,
-        parameters: tool.metadata.parameters,
-      }),
-    );
-
-    const llmChatKwargs = this._getLlmChatKwargs(task, openaiTools, toolChoice);
+    const llmChatParams = this._getLlmChatParams(task, tools, toolChoice);
 
     const agentChatResponse = await this._getAgentResponse(
       task,
       mode,
-      llmChatKwargs,
+      llmChatParams,
     );
 
     const latestToolCalls = this.getLatestToolCalls(task) || [];
@@ -406,45 +342,25 @@ export class OpenAIAgentWorker implements AgentWorker {
     return new TaskStepOutput(agentChatResponse, step, newSteps, isDone);
   }
 
-  /**
-   * Run step.
-   * @param step: step
-   * @param task: task
-   * @param kwargs: kwargs
-   * @returns: task step output
-   */
   async runStep(
     step: TaskStep,
     task: Task,
-    kwargs?: any,
+    chatParams?: LLMChatParamsBase<OpenAIAdditionalChatOptions>,
   ): Promise<TaskStepOutput> {
-    const toolChoice = kwargs?.toolChoice || "auto";
+    const toolChoice = chatParams?.additionalChatOptions?.tool_choice ?? "auto";
     return this._runStep(step, task, ChatResponseMode.WAIT, toolChoice);
   }
 
-  /**
-   * Stream step.
-   * @param step: step
-   * @param task: task
-   * @param kwargs: kwargs
-   * @returns: task step output
-   */
   async streamStep(
     step: TaskStep,
     task: Task,
-    kwargs?: any,
+    chatParams?: LLMChatParamsBase<OpenAIAdditionalChatOptions>,
   ): Promise<TaskStepOutput> {
-    const toolChoice = kwargs?.toolChoice || "auto";
+    const toolChoice = chatParams?.additionalChatOptions?.tool_choice ?? "auto";
     return this._runStep(step, task, ChatResponseMode.STREAM, toolChoice);
   }
 
-  /**
-   * Finalize task.
-   * @param task: task
-   * @param kwargs: kwargs
-   * @returns: void
-   */
-  finalizeTask(task: Task, kwargs?: any): void {
+  finalizeTask(task: Task): void {
     task.memory.set(task.memory.get().concat(task.extraState.newMemory.get()));
     task.extraState.newMemory.reset();
   }
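With the untyped `kwargs` gone, per-step `tool_choice` now travels through `additionalChatOptions`. A sketch of driving the worker directly (import paths assumed; `messages` is required by `LLMChatParamsBase` but `runStep` rebuilds the real message list from the task's memory and only reads `tool_choice` here):

```ts
import { OpenAIAgentWorker } from "llamaindex/agent/index"; // assumed export path
import type { Task } from "llamaindex/agent/types";
import type { BaseTool } from "llamaindex";

declare const task: Task; // hypothetical task created via agent.createTask(...)
declare const sumTool: BaseTool; // hypothetical tool

const worker = new OpenAIAgentWorker({ tools: [sumTool] });
const step = worker.initializeStep(task);
// Force a plain-text answer for this step by disabling tool calls:
const output = await worker.runStep(step, task, {
  messages: [],
  additionalChatOptions: { tool_choice: "none" },
});
```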
diff --git a/packages/core/src/agent/react/base.ts b/packages/core/src/agent/react/base.ts
index 59cd0d9eb960875b524b4325e8066319c51b5d5b..57b0da518aafcc794188896d9a384056bcec69ad 100644
--- a/packages/core/src/agent/react/base.ts
+++ b/packages/core/src/agent/react/base.ts
@@ -9,7 +9,6 @@ type ReActAgentParams = {
   llm?: LLM;
   memory?: any;
   prefixMessages?: ChatMessage[];
-  verbose?: boolean;
   maxInteractions?: number;
   defaultToolChoice?: string;
   toolRetriever?: ObjectRetriever;
@@ -26,7 +25,6 @@ export class ReActAgent extends AgentRunner {
     llm,
     memory,
     prefixMessages,
-    verbose,
     maxInteractions = 10,
     defaultToolChoice = "auto",
     toolRetriever,
@@ -36,7 +34,6 @@ export class ReActAgent extends AgentRunner {
       llm,
       maxInteractions,
       toolRetriever,
-      verbose,
     });
 
     super({
diff --git a/packages/core/src/agent/react/worker.ts b/packages/core/src/agent/react/worker.ts
index 36c9ec687e3747ec2ce834a7d096afa8d14d61a5..fcd1252d37c2903f9cbc0991ae051cf2a01343c8 100644
--- a/packages/core/src/agent/react/worker.ts
+++ b/packages/core/src/agent/react/worker.ts
@@ -1,7 +1,8 @@
 import { randomUUID } from "@llamaindex/env";
+import { Settings } from "../../Settings.js";
 import { AgentChatResponse } from "../../engines/chat/index.js";
-import type { ChatResponse, LLM } from "../../llm/index.js";
-import { OpenAI } from "../../llm/index.js";
+import type { ChatMessage, ChatResponse, LLM } from "../../llm/index.js";
 import { ChatMemoryBuffer } from "../../memory/ChatMemoryBuffer.js";
 import type { ObjectRetriever } from "../../objects/base.js";
 import { ToolOutput } from "../../tools/index.js";
@@ -16,28 +17,20 @@ import {
   ObservationReasoningStep,
   ResponseReasoningStep,
 } from "./types.js";
+
 type ReActAgentWorkerParams = {
   tools: BaseTool[];
   llm?: LLM;
   maxInteractions?: number;
   reactChatFormatter?: ReActChatFormatter | undefined;
   outputParser?: ReActOutputParser | undefined;
-  verbose?: boolean | undefined;
   toolRetriever?: ObjectRetriever | undefined;
 };
 
-/**
- *
- * @param step
- * @param memory
- * @param currentReasoning
- * @param verbose
- */
 function addUserStepToReasoning(
   step: TaskStep,
   memory: ChatMemoryBuffer,
   currentReasoning: BaseReasoningStep[],
-  verbose: boolean = false,
 ): void {
   if (step.stepState.isFirst) {
     memory.put({
@@ -50,18 +43,22 @@ function addUserStepToReasoning(
       observation: step.input ?? undefined,
     });
     currentReasoning.push(reasoningStep);
-    if (verbose) {
+    if (Settings.debug) {
       console.log(`Added user message to memory: ${step.input}`);
     }
   }
 }
 
+type ChatParams = {
+  messages: ChatMessage[];
+  tools?: BaseTool[];
+};
+
 /**
  * ReAct agent worker.
 */
-export class ReActAgentWorker implements AgentWorker {
+export class ReActAgentWorker implements AgentWorker<ChatParams> {
   llm: LLM;
-  verbose: boolean;
 
   maxInteractions: number = 10;
   reactChatFormatter: ReActChatFormatter;
@@ -75,15 +72,13 @@ export class ReActAgentWorker implements AgentWorker {
     maxInteractions,
     reactChatFormatter,
     outputParser,
-    verbose,
     toolRetriever,
   }: ReActAgentWorkerParams) {
-    this.llm = llm ?? new OpenAI({ model: "gpt-3.5-turbo-0613" });
+    this.llm = llm ?? Settings.llm;
     this.maxInteractions = maxInteractions ?? 10;
     this.reactChatFormatter = reactChatFormatter ?? new ReActChatFormatter();
     this.outputParser = outputParser ?? new ReActOutputParser();
-    this.verbose = verbose || false;
 
     if (tools.length > 0 && toolRetriever) {
       throw new Error("Cannot specify both tools and tool_retriever");
@@ -97,13 +92,7 @@ export class ReActAgentWorker implements AgentWorker {
     }
   }
 
-  /**
-   * Initialize a task step.
-   * @param task - task
-   * @param kwargs - keyword arguments
-   * @returns - task step
-   */
-  initializeStep(task: Task, kwargs?: any): TaskStep {
+  initializeStep(task: Task): TaskStep {
     const sources: ToolOutput[] = [];
     const currentReasoning: BaseReasoningStep[] = [];
     const newMemory = new ChatMemoryBuffer({
@@ -126,12 +115,6 @@ export class ReActAgentWorker implements AgentWorker {
     });
   }
 
-  /**
-   * Extract reasoning step from chat response.
-   * @param output - chat response
-   * @param isStreaming - whether the chat response is streaming
-   * @returns - [message content, reasoning steps, is done]
-   */
   extractReasoningStep(
     output: ChatResponse,
     isStreaming: boolean,
@@ -154,7 +137,7 @@ export class ReActAgentWorker implements AgentWorker {
       throw new Error(`Could not parse output: ${e}`);
     }
 
-    if (this.verbose) {
+    if (Settings.debug) {
       console.log(`${reasoningStep.getContent()}\n`);
     }
 
@@ -177,14 +160,6 @@ export class ReActAgentWorker implements AgentWorker {
     return [messageContent, currentReasoning, false];
   }
 
-  /**
-   * Process actions.
-   * @param task - task
-   * @param tools - tools
-   * @param output - chat response
-   * @param isStreaming - whether the chat response is streaming
-   * @returns - [reasoning steps, is done]
-   */
   async _processActions(
     task: Task,
     tools: BaseTool[],
@@ -235,19 +210,13 @@ export class ReActAgentWorker implements AgentWorker {
 
     currentReasoning.push(observationStep);
 
-    if (this.verbose) {
+    if (Settings.debug) {
       console.log(`${observationStep.getContent()}`);
     }
 
     return [currentReasoning, false];
   }
 
-  /**
-   * Get response.
-   * @param currentReasoning - current reasoning steps
-   * @param sources - tool outputs
-   * @returns - agent chat response
-   */
   _getResponse(
     currentReasoning: BaseReasoningStep[],
     sources: ToolOutput[],
@@ -271,13 +240,6 @@ export class ReActAgentWorker implements AgentWorker {
     return new AgentChatResponse(responseStr, sources);
   }
 
-  /**
-   * Get task step response.
-   * @param agentResponse - agent chat response
-   * @param step - task step
-   * @param isDone - whether the task is done
-   * @returns - task step output
-   */
   _getTaskStepResponse(
     agentResponse: AgentChatResponse,
     step: TaskStep,
@@ -294,24 +256,12 @@ export class ReActAgentWorker implements AgentWorker {
     return new TaskStepOutput(agentResponse, step, newSteps, isDone);
   }
 
-  /**
-   * Run a task step.
-   * @param step - task step
-   * @param task - task
-   * @param kwargs - keyword arguments
-   * @returns - task step output
-   */
-  async _runStep(
-    step: TaskStep,
-    task: Task,
-    kwargs?: any,
-  ): Promise<TaskStepOutput> {
+  async _runStep(step: TaskStep, task: Task): Promise<TaskStepOutput> {
     if (step.input) {
       addUserStepToReasoning(
         step,
         task.extraState.newMemory,
         task.extraState.currentReasoning,
-        this.verbose,
       );
     }
 
@@ -350,42 +300,15 @@ export class ReActAgentWorker implements AgentWorker {
     return this._getTaskStepResponse(agentResponse, step, isDone);
   }
 
-  /**
-   * Run a task step.
-   * @param step - task step
-   * @param task - task
-   * @param kwargs - keyword arguments
-   * @returns - task step output
-   */
-  async runStep(
-    step: TaskStep,
-    task: Task,
-    kwargs?: any,
-  ): Promise<TaskStepOutput> {
+  async runStep(step: TaskStep, task: Task): Promise<TaskStepOutput> {
     return await this._runStep(step, task);
   }
 
-  /**
-   * Run a task step.
-   * @param step - task step
-   * @param task - task
-   * @param kwargs - keyword arguments
-   * @returns - task step output
-   */
-  streamStep(
-    step: TaskStep,
-    task: Task,
-    kwargs?: any,
-  ): Promise<TaskStepOutput> {
+  streamStep(): Promise<TaskStepOutput> {
     throw new Error("Method not implemented.");
   }
 
-  /**
-   * Finalize a task.
-   * @param task - task
-   * @param kwargs - keyword arguments
-   */
-  finalizeTask(task: Task, kwargs?: any): void {
-    task.memory.set(task.memory.get() + task.extraState.newMemory.get());
+  finalizeTask(task: Task): void {
+    task.memory.set(task.memory.get().concat(task.extraState.newMemory.get()));
     task.extraState.newMemory.reset();
   }
diff --git a/packages/core/src/agent/types.ts b/packages/core/src/agent/types.ts
index 6be1f8688a020638336676ad020c073a51eafc88..d81ba8ebfa7c38c3ccbcd5916ced01ca1711417e 100644
--- a/packages/core/src/agent/types.ts
+++ b/packages/core/src/agent/types.ts
@@ -6,11 +6,19 @@ import type {
 
 import type { QueryEngineParamsNonStreaming } from "../types.js";
 
-export interface AgentWorker {
-  initializeStep(task: Task, kwargs?: any): TaskStep;
-  runStep(step: TaskStep, task: Task, kwargs?: any): Promise<TaskStepOutput>;
-  streamStep(step: TaskStep, task: Task, kwargs?: any): Promise<TaskStepOutput>;
-  finalizeTask(task: Task, kwargs?: any): void;
+export interface AgentWorker<ExtraParams extends object = object> {
+  initializeStep(task: Task, params?: ExtraParams): TaskStep;
+  runStep(
+    step: TaskStep,
+    task: Task,
+    params?: ExtraParams,
+  ): Promise<TaskStepOutput>;
+  streamStep(
+    step: TaskStep,
+    task: Task,
+    params?: ExtraParams,
+  ): Promise<TaskStepOutput>;
+  finalizeTask(task: Task, params?: ExtraParams): void;
 }
 
 interface BaseChatEngine {
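The `any`-typed `kwargs` on `AgentWorker` become a per-worker `ExtraParams` generic. A minimal sketch of a custom worker against the new interface (`EchoParams` is hypothetical; the constructor shapes for `TaskStep`, `TaskStepOutput`, and `AgentChatResponse` follow the worker code in this patch, and the import paths are assumptions):

```ts
import { randomUUID } from "@llamaindex/env";
import { AgentChatResponse } from "llamaindex/engines/chat/index"; // assumed path
import type { AgentWorker, Task } from "llamaindex/agent/types"; // assumed path
import { TaskStep, TaskStepOutput } from "llamaindex/agent/types";

type EchoParams = { prefix?: string }; // hypothetical extra params

class EchoWorker implements AgentWorker<EchoParams> {
  initializeStep(task: Task): TaskStep {
    return new TaskStep(task.taskId, randomUUID(), task.input);
  }
  async runStep(step: TaskStep, task: Task, params?: EchoParams) {
    // Echo the input back; a real worker would call an LLM here.
    const response = new AgentChatResponse(`${params?.prefix ?? ""}${task.input}`, []);
    return new TaskStepOutput(response, step, [], true); // done in a single step
  }
  streamStep(): Promise<TaskStepOutput> {
    throw new Error("Method not implemented.");
  }
  finalizeTask(): void {}
}
```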
diff --git a/packages/core/src/agent/utils.ts b/packages/core/src/agent/utils.ts
index 858306856ee1b528a3d0a44ceb60230079692709..a22c882ec0286b4f281f0b7ebd858711d05e48c0 100644
--- a/packages/core/src/agent/utils.ts
+++ b/packages/core/src/agent/utils.ts
@@ -1,19 +1,12 @@
+import { Settings } from "../Settings.js";
 import type { ChatMessage } from "../llm/index.js";
 import type { ChatMemoryBuffer } from "../memory/ChatMemoryBuffer.js";
 import type { BaseTool } from "../types.js";
 import type { TaskStep } from "./types.js";
 
-/**
- * Adds the user's input to the memory.
- *
- * @param step - The step to add to the memory.
- * @param memory - The memory to add the step to.
- * @param verbose - Whether to print debug messages.
- */
 export function addUserStepToMemory(
   step: TaskStep,
   memory: ChatMemoryBuffer,
-  verbose: boolean = false,
 ): void {
   if (!step.input) {
     return;
@@ -26,26 +19,17 @@ export function addUserStepToMemory(
 
   memory.put(userMessage);
 
-  if (verbose) {
+  if (Settings.debug) {
     console.log(`Added user message to memory!: ${userMessage.content}`);
   }
 }
 
-/**
- * Get function by name.
- * @param tools: tools
- * @param name: name
- * @returns: tool
- */
 export function getFunctionByName(tools: BaseTool[], name: string): BaseTool {
-  const nameToTool: { [key: string]: BaseTool } = {};
-  tools.forEach((tool) => {
-    nameToTool[tool.metadata.name] = tool;
-  });
+  const exist = tools.find((tool) => tool.metadata.name === name);
 
-  if (!(name in nameToTool)) {
+  if (!exist) {
     throw new Error(`Tool with name ${name} not found`);
   }
 
-  return nameToTool[name];
+  return exist;
 }
diff --git a/packages/core/src/llm/LLM.ts b/packages/core/src/llm/LLM.ts
index 8b68f7466a36a0d47f13c96adbb5528cfba0270c..f88f7da15a95d7a6b8c28a22f2b2d93b02d008fa 100644
--- a/packages/core/src/llm/LLM.ts
+++ b/packages/core/src/llm/LLM.ts
@@ -484,11 +484,11 @@ export class Portkey extends BaseLLM {
   async chat(
     params: LLMChatParamsNonStreaming | LLMChatParamsStreaming,
   ): Promise<ChatResponse | AsyncIterable<ChatResponseChunk>> {
-    const { messages, stream, extraParams } = params;
+    const { messages, stream, additionalChatOptions } = params;
     if (stream) {
-      return this.streamChat(messages, extraParams);
+      return this.streamChat(messages, additionalChatOptions);
     } else {
-      const bodyParams = extraParams || {};
+      const bodyParams = additionalChatOptions || {};
       const response = await this.session.portkey.chatCompletions.create({
         messages,
         ...bodyParams,
diff --git a/packages/core/src/llm/base.ts b/packages/core/src/llm/base.ts
index 04854c2049791f9555bfcd53fbc0e989e35c265b..d67cdbb5b73c82c8adb75ab5a101a1f70402bec2 100644
--- a/packages/core/src/llm/base.ts
+++ b/packages/core/src/llm/base.ts
@@ -11,7 +11,13 @@ import type {
 } from "./types.js";
 import { streamConverter } from "./utils.js";
 
-export abstract class BaseLLM implements LLM {
+export abstract class BaseLLM<
+  AdditionalChatOptions extends Record<string, unknown> = Record<
+    string,
+    unknown
+  >,
+> implements LLM<AdditionalChatOptions>
+{
   abstract metadata: LLMMetadata;
 
   complete(
@@ -42,7 +48,9 @@ export abstract class BaseLLM implements LLM {
   }
 
   abstract chat(
-    params: LLMChatParamsStreaming,
+    params: LLMChatParamsStreaming<AdditionalChatOptions>,
   ): Promise<AsyncIterable<ChatResponseChunk>>;
-  abstract chat(params: LLMChatParamsNonStreaming): Promise<ChatResponse>;
+  abstract chat(
+    params: LLMChatParamsNonStreaming<AdditionalChatOptions>,
+  ): Promise<ChatResponse>;
 }
diff --git a/packages/core/src/llm/open_ai.ts b/packages/core/src/llm/open_ai.ts
index c5dff90e2a7a9ceebc71d04b5122fdaf9d4af022..7805f4951207c3716dcabf003792b29911f6fc1c 100644
--- a/packages/core/src/llm/open_ai.ts
+++ b/packages/core/src/llm/open_ai.ts
@@ -7,10 +7,12 @@ import type {
 } from "openai";
 import { OpenAI as OrigOpenAI } from "openai";
 
+import type { ChatCompletionTool } from "openai/resources/chat/completions";
 import type { ChatCompletionMessageParam } from "openai/resources/index.js";
 import { Tokenizers } from "../GlobalsHelper.js";
 import { wrapEventCaller } from "../internal/context/EventCaller.js";
 import { getCallbackManager } from "../internal/settings/CallbackManager.js";
+import type { BaseTool } from "../types.js";
 import type { AzureOpenAIConfig } from "./azure.js";
 import {
   getAzureBaseUrl,
@@ -23,8 +25,10 @@ import type {
   ChatMessage,
   ChatResponse,
   ChatResponseChunk,
+  LLM,
   LLMChatParamsNonStreaming,
   LLMChatParamsStreaming,
+  LLMMetadata,
   MessageToolCall,
   MessageType,
 } from "./types.js";
@@ -116,32 +120,43 @@ export const ALL_AVAILABLE_OPENAI_MODELS = {
   ...GPT35_MODELS,
 };
 
-export const isFunctionCallingModel = (model: string): boolean => {
+export function isFunctionCallingModel(llm: LLM): llm is OpenAI {
+  let model: string;
+  if (llm instanceof OpenAI) {
+    model = llm.model;
+  } else if ("model" in llm && typeof llm.model === "string") {
+    model = llm.model;
+  } else {
+    return false;
+  }
   const isChatModel = Object.keys(ALL_AVAILABLE_OPENAI_MODELS).includes(model);
   const isOld = model.includes("0314") || model.includes("0301");
   return isChatModel && !isOld;
+}
+
+export type OpenAIAdditionalChatOptions = Omit<
+  Partial<OpenAILLM.Chat.ChatCompletionCreateParams>,
+  | "max_tokens"
+  | "messages"
+  | "model"
+  | "temperature"
+  | "top_p"
+  | "stream"
+  | "tools"
+  | "toolChoice"
+>;
+
+export type OpenAIAdditionalMetadata = {
+  isFunctionCallingModel: boolean;
 };
 
-/**
- * OpenAI LLM implementation
- */
-export class OpenAI extends BaseLLM {
+export class OpenAI extends BaseLLM<OpenAIAdditionalChatOptions> {
   // Per completion OpenAI params
   model: keyof typeof ALL_AVAILABLE_OPENAI_MODELS | string;
   temperature: number;
   topP: number;
   maxTokens?: number;
-  additionalChatOptions?: Omit<
-    Partial<OpenAILLM.Chat.ChatCompletionCreateParams>,
-    | "max_tokens"
-    | "messages"
-    | "model"
-    | "temperature"
-    | "top_p"
-    | "stream"
-    | "tools"
-    | "toolChoice"
-  >;
+  additionalChatOptions?: OpenAIAdditionalChatOptions;
 
   // OpenAI session params
   apiKey?: string = undefined;
@@ -206,7 +221,7 @@ export class OpenAI extends BaseLLM {
     }
   }
 
-  get metadata() {
+  get metadata(): LLMMetadata & OpenAIAdditionalMetadata {
     const contextWindow =
       ALL_AVAILABLE_OPENAI_MODELS[
         this.model as keyof typeof ALL_AVAILABLE_OPENAI_MODELS
@@ -218,7 +233,7 @@ export class OpenAI extends BaseLLM {
       maxTokens: this.maxTokens,
       contextWindow,
       tokenizer: Tokenizers.CL100K_BASE,
-      isFunctionCallingModel: isFunctionCallingModel(this.model),
+      isFunctionCallingModel: isFunctionCallingModel(this),
     };
   }
 
@@ -259,24 +274,27 @@ export class OpenAI extends BaseLLM {
   }
 
   chat(
-    params: LLMChatParamsStreaming,
+    params: LLMChatParamsStreaming<OpenAIAdditionalChatOptions>,
   ): Promise<AsyncIterable<ChatResponseChunk>>;
-  chat(params: LLMChatParamsNonStreaming): Promise<ChatResponse>;
+  chat(
+    params: LLMChatParamsNonStreaming<OpenAIAdditionalChatOptions>,
+  ): Promise<ChatResponse>;
   @wrapEventCaller
   @wrapLLMEvent
   async chat(
-    params: LLMChatParamsNonStreaming | LLMChatParamsStreaming,
+    params:
+      | LLMChatParamsNonStreaming<OpenAIAdditionalChatOptions>
+      | LLMChatParamsStreaming<OpenAIAdditionalChatOptions>,
   ): Promise<ChatResponse | AsyncIterable<ChatResponseChunk>> {
-    const { messages, stream, tools, toolChoice } = params;
+    const { messages, stream, tools, additionalChatOptions } = params;
     const baseRequestParams: OpenAILLM.Chat.ChatCompletionCreateParams = {
       model: this.model,
       temperature: this.temperature,
       max_tokens: this.maxTokens,
-      tools: tools,
-      tool_choice: toolChoice,
+      tools: tools?.map(OpenAI.toTool),
       messages: this.toOpenAIMessage(messages) as ChatCompletionMessageParam[],
       top_p: this.topP,
-      ...this.additionalChatOptions,
+      ...Object.assign({}, this.additionalChatOptions, additionalChatOptions),
     };
 
     // Streaming
@@ -343,6 +361,17 @@ export class OpenAI extends BaseLLM {
     }
     return;
   }
+
+  static toTool(tool: BaseTool): ChatCompletionTool {
+    return {
+      type: "function",
+      function: {
+        name: tool.metadata.name,
+        description: tool.metadata.description,
+        parameters: tool.metadata.parameters,
+      },
+    };
+  }
 }
 
 function updateToolCalls(
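`isFunctionCallingModel` now narrows an `LLM` to `OpenAI` instead of testing a bare model string — this is what lets `OpenAIAgentWorker` above fall back to `Settings.llm` when it is a function-calling OpenAI model. A short sketch (the export path is an assumption):

```ts
import { Settings } from "llamaindex";
import { isFunctionCallingModel } from "llamaindex/llm/index"; // assumed export path

if (isFunctionCallingModel(Settings.llm)) {
  // Settings.llm is narrowed to OpenAI here, so `.model` is accessible:
  console.log(Settings.llm.model);
}
```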
diff --git a/packages/core/src/llm/types.ts b/packages/core/src/llm/types.ts
index c49a5e8b14bb5999b5229acf8208222463bedb41..f8ac3f77a9507dc4c20bf3f773baf301b80c31c0 100644
--- a/packages/core/src/llm/types.ts
+++ b/packages/core/src/llm/types.ts
@@ -1,4 +1,5 @@
 import type { Tokenizers } from "../GlobalsHelper.js";
+import type { BaseTool } from "../types.js";
 
 type LLMBaseEvent<
   Type extends string,
@@ -30,24 +31,35 @@ declare module "llamaindex" {
 /**
  * @internal
 */
-export interface LLMChat {
+export interface LLMChat<
+  ExtraParams extends Record<string, unknown> = Record<string, unknown>,
+> {
   chat(
-    params: LLMChatParamsStreaming | LLMChatParamsNonStreaming,
+    params:
+      | LLMChatParamsStreaming<ExtraParams>
+      | LLMChatParamsNonStreaming<ExtraParams>,
   ): Promise<ChatResponse | AsyncIterable<ChatResponseChunk>>;
 }
 
 /**
  * Unified language model interface
 */
-export interface LLM extends LLMChat {
+export interface LLM<
+  AdditionalChatOptions extends Record<string, unknown> = Record<
+    string,
+    unknown
+  >,
+> extends LLMChat<AdditionalChatOptions> {
   metadata: LLMMetadata;
   /**
    * Get a chat response from the LLM
    */
   chat(
-    params: LLMChatParamsStreaming,
+    params: LLMChatParamsStreaming<AdditionalChatOptions>,
   ): Promise<AsyncIterable<ChatResponseChunk>>;
-  chat(params: LLMChatParamsNonStreaming): Promise<ChatResponse>;
+  chat(
+    params: LLMChatParamsNonStreaming<AdditionalChatOptions>,
+  ): Promise<ChatResponse>;
 
   /**
    * Get a prompt completion from the LLM
@@ -92,29 +104,43 @@ export interface CompletionResponse {
   raw?: Record<string, any>;
 }
 
-export interface LLMMetadata {
+export type LLMMetadata = {
   model: string;
   temperature: number;
   topP: number;
   maxTokens?: number;
   contextWindow: number;
   tokenizer: Tokenizers | undefined;
-}
-
-export interface LLMChatParamsBase {
+};
+
+export interface LLMChatParamsBase<
+  AdditionalChatOptions extends Record<string, unknown> = Record<
+    string,
+    unknown
+  >,
+> {
   messages: ChatMessage[];
-  extraParams?: Record<string, any>;
-  tools?: any;
-  toolChoice?: any;
-  additionalKwargs?: Record<string, any>;
+  additionalChatOptions?: AdditionalChatOptions;
+  tools?: BaseTool[];
+  additionalKwargs?: Record<string, unknown>;
 }
 
-export interface LLMChatParamsStreaming extends LLMChatParamsBase {
+export interface LLMChatParamsStreaming<
+  AdditionalChatOptions extends Record<string, unknown> = Record<
+    string,
+    unknown
+  >,
+> extends LLMChatParamsBase<AdditionalChatOptions> {
   stream: true;
 }
 
-export interface LLMChatParamsNonStreaming extends LLMChatParamsBase {
-  stream?: false | null;
+export interface LLMChatParamsNonStreaming<
+  AdditionalChatOptions extends Record<string, unknown> = Record<
+    string,
+    unknown
+  >,
+> extends LLMChatParamsBase<AdditionalChatOptions> {
+  stream?: false;
 }
 
 export interface LLMCompletionParamsBase {
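The generics threaded through `LLMChat`, `LLM`, and the params interfaces replace the old `tools?: any` / `toolChoice?: any` fields: `tools` is now `BaseTool[]` and provider-specific options are type-checked per LLM. A small sketch with the OpenAI options type (the `llamaindex/llm/open_ai` import path is an assumption):

```ts
import { OpenAI } from "llamaindex";
import type { LLMChatParamsNonStreaming } from "llamaindex";
import type { OpenAIAdditionalChatOptions } from "llamaindex/llm/open_ai"; // assumed path

async function demo() {
  const params: LLMChatParamsNonStreaming<OpenAIAdditionalChatOptions> = {
    messages: [{ content: "Who was Goethe?", role: "user" }],
    // Misspelled or unsupported OpenAI options now fail to compile:
    additionalChatOptions: { tool_choice: "none" },
  };
  const response = await new OpenAI({ model: "gpt-4-turbo-preview" }).chat(params);
  console.log(response.message.content);
}
```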
diff --git a/packages/core/tests/CallbackManager.test.ts b/packages/core/tests/CallbackManager.test.ts
index 3e2f2749e50eb65e6f6cb1aeed2fd32825d9faec..6788c90472086d9106e4202468053519aec543d1 100644
--- a/packages/core/tests/CallbackManager.test.ts
+++ b/packages/core/tests/CallbackManager.test.ts
@@ -48,6 +48,7 @@ describe("CallbackManager: onLLMStream and onRetrieve", () => {
     const languageModel = new OpenAI({
       model: "gpt-3.5-turbo",
     });
+
     mockLlmGeneration({ languageModel, callbackManager });
 
     const embedModel = new OpenAIEmbedding();
diff --git a/packages/core/tests/MetadataExtractors.test.ts b/packages/core/tests/MetadataExtractors.test.ts
index f9337b3b0c32abacb5f63f892d1d1a315e3540e8..e71dd9674370b6ce974c63747f5228cec00cb488 100644
--- a/packages/core/tests/MetadataExtractors.test.ts
+++ b/packages/core/tests/MetadataExtractors.test.ts
@@ -1,3 +1,4 @@
+import { Settings } from "llamaindex";
 import { Document } from "llamaindex/Node";
 import type { ServiceContext } from "llamaindex/ServiceContext";
 import { serviceContextFromDefaults } from "llamaindex/ServiceContext";
@@ -25,6 +26,8 @@ describe("[MetadataExtractor]: Extractors should populate the metadata", () => {
       model: "gpt-3.5-turbo",
     });
 
+    Settings.llm = languageModel;
+
     mockLlmGeneration({ languageModel });
 
     const embedModel = new OpenAIEmbedding();
diff --git a/packages/core/tests/agent/OpenAIAgent.test.ts b/packages/core/tests/agent/OpenAIAgent.test.ts
index 8180464a3d38b667e531101a0b9f621bdb45477e..0189e0158ff27a5e2fb3fbea586213a6dcadb9cc 100644
--- a/packages/core/tests/agent/OpenAIAgent.test.ts
+++ b/packages/core/tests/agent/OpenAIAgent.test.ts
@@ -1,3 +1,4 @@
+import { Settings } from "llamaindex";
 import { OpenAIAgent } from "llamaindex/agent/index";
 import { OpenAI } from "llamaindex/llm/index";
 import { FunctionTool } from "llamaindex/tools/index";
@@ -32,6 +33,8 @@ describe("OpenAIAgent", () => {
       model: "gpt-3.5-turbo",
     });
 
+    Settings.llm = languageModel;
+
     mockLlmToolCallGeneration({
       languageModel,
     });
@@ -45,7 +48,6 @@ describe("OpenAIAgent", () => {
     openaiAgent = new OpenAIAgent({
       tools: [sumFunctionTool],
       llm: languageModel,
-      verbose: false,
     });
   });
 
diff --git a/packages/core/tests/agent/runner/AgentRunner.test.ts b/packages/core/tests/agent/runner/AgentRunner.test.ts
index 95e9430831ae0f80cfb7f6ee4eeca6f58c1274a0..4b8b4af76785aec589a5c68e9dcc2fedc9fe6b4e 100644
--- a/packages/core/tests/agent/runner/AgentRunner.test.ts
+++ b/packages/core/tests/agent/runner/AgentRunner.test.ts
@@ -3,6 +3,7 @@ import { AgentRunner } from "llamaindex/agent/runner/base";
 import { OpenAI } from "llamaindex/llm/open_ai";
 import { beforeEach, describe, expect, it } from "vitest";
 
+import { Settings } from "llamaindex";
 import {
   DEFAULT_LLM_TEXT_OUTPUT,
   mockLlmGeneration,
@@ -12,20 +13,15 @@ describe("Agent Runner", () => {
   let agentRunner: AgentRunner;
 
   beforeEach(() => {
-    const languageModel = new OpenAI({
+    Settings.llm = new OpenAI({
       model: "gpt-3.5-turbo",
     });
 
-    mockLlmGeneration({
-      languageModel,
-    });
+    mockLlmGeneration();
 
     agentRunner = new AgentRunner({
-      llm: languageModel,
       agentWorker: new OpenAIAgentWorker({
-        llm: languageModel,
         tools: [],
-        verbose: false,
       }),
     });
   });
diff --git a/packages/core/tests/utility/mockOpenAI.ts b/packages/core/tests/utility/mockOpenAI.ts
index f90de391d354fde2e1b2d9730c03812a339ea1c0..ddb066bba86d1a1b75a577103bcc7242c472a5fa 100644
--- a/packages/core/tests/utility/mockOpenAI.ts
+++ b/packages/core/tests/utility/mockOpenAI.ts
@@ -1,6 +1,7 @@
+import { Settings } from "llamaindex";
 import type { CallbackManager } from "llamaindex/callbacks/CallbackManager";
 import type { OpenAIEmbedding } from "llamaindex/embeddings/index";
-import type { OpenAI } from "llamaindex/llm/open_ai";
+import { OpenAI } from "llamaindex/llm/open_ai";
 import type { LLMChatParamsBase } from "llamaindex/llm/types";
 import { vi } from "vitest";
 
@@ -10,9 +11,16 @@ export function mockLlmGeneration({
   languageModel,
   callbackManager,
 }: {
-  languageModel: OpenAI;
+  languageModel?: OpenAI;
   callbackManager?: CallbackManager;
-}) {
+} = {}) {
+  callbackManager = callbackManager || Settings.callbackManager;
+  if (!languageModel && Settings.llm instanceof OpenAI) {
+    languageModel = Settings.llm;
+  }
+  if (!languageModel) {
+    return;
+  }
   vi.spyOn(languageModel, "chat").mockImplementation(
     async ({ messages }: LLMChatParamsBase) => {
       const text = DEFAULT_LLM_TEXT_OUTPUT;
diff --git a/packages/core/tests/vitest.setup.ts b/packages/core/tests/vitest.setup.ts
index ec1583acd20485819c186b1980ea2cfa7a98ce08..97e4d04e57b70e776902927769c2f1ac9b8e4328 100644
--- a/packages/core/tests/vitest.setup.ts
+++ b/packages/core/tests/vitest.setup.ts
@@ -12,7 +12,7 @@ globalThis.fetch = function fetch(...args: Parameters<typeof originalFetch>) {
     }
   }
   const parsedUrl = new URL(url);
-  if (parsedUrl.hostname.includes("api.openai.com")) {
+  if (parsedUrl.hostname.includes("openai.com")) {
     // todo: mock api using https://mswjs.io
     throw new Error(
       "Make sure to return a mock response for OpenAI API requests in your test.",