diff --git a/.changeset/late-yaks-worry.md b/.changeset/late-yaks-worry.md
new file mode 100644
index 0000000000000000000000000000000000000000..653108fd8f9183789207836e00260eaac3d2213d
--- /dev/null
+++ b/.changeset/late-yaks-worry.md
@@ -0,0 +1,5 @@
+---
+"llamaindex": patch
+---
+
+chore: refactor openai agent utils
diff --git a/packages/core/src/agent/openai/base.ts b/packages/core/src/agent/openai/base.ts
index 6581f5de024efe44c209b50f2225a232d2545fe3..5b916ebae63208e838da400cdf66e08e9b520354 100644
--- a/packages/core/src/agent/openai/base.ts
+++ b/packages/core/src/agent/openai/base.ts
@@ -15,6 +15,7 @@ type OpenAIAgentParams = {
   defaultToolChoice?: string;
   callbackManager?: CallbackManager;
   toolRetriever?: ObjectRetriever<BaseTool>;
+  systemPrompt?: string;
 };

 /**
@@ -33,7 +34,29 @@ export class OpenAIAgent extends AgentRunner {
     defaultToolChoice = "auto",
     callbackManager,
     toolRetriever,
+    systemPrompt,
   }: OpenAIAgentParams) {
+    prefixMessages = prefixMessages || [];
+
+    llm = llm ?? new OpenAI({ model: "gpt-3.5-turbo-0613" });
+
+    if (systemPrompt) {
+      if (prefixMessages.length > 0) {
+        throw new Error("Cannot provide both systemPrompt and prefixMessages");
+      }
+
+      prefixMessages = [
+        {
+          content: systemPrompt,
+          role: "system",
+        },
+      ];
+    }
+
+    if (!llm?.metadata.isFunctionCallingModel) {
+      throw new Error("LLM model must be a function-calling model");
+    }
+
     const stepEngine = new OpenAIAgentWorker({
       tools,
       callbackManager,
diff --git a/packages/core/src/agent/openai/worker.ts b/packages/core/src/agent/openai/worker.ts
index 20176f51941fa14684b97127f24dcd58fa495463..33b6e470f40d72d512ef9df5f2366b46df41615e 100644
--- a/packages/core/src/agent/openai/worker.ts
+++ b/packages/core/src/agent/openai/worker.ts
@@ -88,9 +88,9 @@ type CallFunctionOutput = {
  * This class is responsible for running the agent.
  */
 export class OpenAIAgentWorker implements AgentWorker {
-  private _llm: OpenAI;
-  private _verbose: boolean;
-  private _maxFunctionCalls: number;
+  private llm: OpenAI;
+  private verbose: boolean;
+  private maxFunctionCalls: number;

   public prefixMessages: ChatMessage[];
   public callbackManager: CallbackManager | undefined;
@@ -109,11 +109,11 @@ export class OpenAIAgentWorker implements AgentWorker {
     callbackManager,
     toolRetriever,
   }: OpenAIAgentWorkerParams) {
-    this._llm = llm ?? new OpenAI({ model: "gpt-3.5-turbo-0613" });
-    this._verbose = verbose || false;
-    this._maxFunctionCalls = maxFunctionCalls;
+    this.llm = llm ?? new OpenAI({ model: "gpt-3.5-turbo-0613" });
+    this.verbose = verbose || false;
+    this.maxFunctionCalls = maxFunctionCalls;
     this.prefixMessages = prefixMessages || [];
-    this.callbackManager = callbackManager || this._llm.callbackManager;
+    this.callbackManager = callbackManager || this.llm.callbackManager;

     if (tools.length > 0 && toolRetriever) {
       throw new Error("Cannot specify both tools and tool_retriever");
@@ -207,7 +207,7 @@ export class OpenAIAgentWorker implements AgentWorker {
     llmChatKwargs: any,
   ): Promise<AgentChatResponse> {
     if (mode === ChatResponseMode.WAIT) {
-      const chatResponse = (await this._llm.chat({
+      const chatResponse = (await this.llm.chat({
         stream: false,
         ...llmChatKwargs,
       })) as unknown as ChatResponse;
@@ -236,7 +236,7 @@ export class OpenAIAgentWorker implements AgentWorker {
       throw new Error("Invalid tool_call object");
     }

-    const functionMessage = await callFunction(tools, toolCall, this._verbose);
+    const functionMessage = await callFunction(tools, toolCall, this.verbose);

     const message = functionMessage[0];
     const toolOutput = functionMessage[1];
@@ -282,7 +282,7 @@ export class OpenAIAgentWorker implements AgentWorker {
     toolCalls: OpenAIToolCall[] | null,
     nFunctionCalls: number,
   ): boolean {
-    if (nFunctionCalls > this._maxFunctionCalls) {
+    if (nFunctionCalls > this.maxFunctionCalls) {
       return false;
     }

@@ -311,7 +311,7 @@ export class OpenAIAgentWorker implements AgentWorker {
     const tools = await this.getTools(task.input);

     if (step.input) {
-      addUserStepToMemory(step, task.extraState.newMemory, this._verbose);
+      addUserStepToMemory(step, task.extraState.newMemory, this.verbose);
     }

     const openaiTools = tools.map((tool) =>
diff --git a/packages/core/src/llm/LLM.ts b/packages/core/src/llm/LLM.ts
index 777d24e6da903c54e42b3e4322c9f50b86a31384..bdba512d99c7249275dec3441b7ec7025e437be9 100644
--- a/packages/core/src/llm/LLM.ts
+++ b/packages/core/src/llm/LLM.ts
@@ -26,6 +26,7 @@ import {
 } from "./azure";
 import { BaseLLM } from "./base";
 import { OpenAISession, getOpenAISession } from "./open_ai";
+import { isFunctionCallingModel } from "./openai/utils";
 import { PortkeySession, getPortkeySession } from "./portkey";
 import { ReplicateSession } from "./replicate_ai";
 import {
@@ -166,6 +167,7 @@ export class OpenAI extends BaseLLM {
       maxTokens: this.maxTokens,
       contextWindow,
       tokenizer: Tokenizers.CL100K_BASE,
+      isFunctionCallingModel: isFunctionCallingModel(this.model),
     };
   }

diff --git a/packages/core/src/llm/openai/utils.ts b/packages/core/src/llm/openai/utils.ts
new file mode 100644
index 0000000000000000000000000000000000000000..0521c9a410b80fe43c16ee19d7a457022cc8f7b0
--- /dev/null
+++ b/packages/core/src/llm/openai/utils.ts
@@ -0,0 +1,7 @@
+import { ALL_AVAILABLE_OPENAI_MODELS } from "..";
+
+export const isFunctionCallingModel = (model: string): boolean => {
+  const isChatModel = Object.keys(ALL_AVAILABLE_OPENAI_MODELS).includes(model);
+  const isOld = model.includes("0314") || model.includes("0301");
+  return isChatModel && !isOld;
+};
diff --git a/packages/core/src/tests/llms/openai/utils.test.ts b/packages/core/src/tests/llms/openai/utils.test.ts
new file mode 100644
index 0000000000000000000000000000000000000000..95eb3bc03eddf74653b538f5104a180d906c5b92
--- /dev/null
+++ b/packages/core/src/tests/llms/openai/utils.test.ts
@@ -0,0 +1,27 @@
+import { ALL_AVAILABLE_OPENAI_MODELS } from "../../../llm";
+import { isFunctionCallingModel } from "../../../llm/openai/utils";
+
+describe("openai/utils", () => {
+  test("shouldn't be an old model", () => {
+    expect(isFunctionCallingModel("gpt-3.5-turbo")).toBe(true);
+    expect(isFunctionCallingModel("gpt-3.5-turbo-0314")).toBe(false);
+    expect(isFunctionCallingModel("gpt-3.5-turbo-0301")).toBe(false);
+    expect(isFunctionCallingModel("gpt-3.5-turbo-0314-0301")).toBe(false);
+    expect(isFunctionCallingModel("gpt-3.5-turbo-0314-0301-0314-0301")).toBe(
+      false,
+    );
+    expect(
+      isFunctionCallingModel("gpt-3.5-turbo-0314-0301-0314-0301-0314-0301"),
+    ).toBe(false);
+  });
+
+  test("should be an OpenAI model", () => {
+    const models = Object.keys(ALL_AVAILABLE_OPENAI_MODELS).filter(
+      (model) => !model.includes("0314") && !model.includes("0301"),
+    );
+
+    models.forEach((model) => {
+      expect(isFunctionCallingModel(model)).toBe(true);
+    });
+  });
+});
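For context, a minimal usage sketch of the new `systemPrompt` option introduced by this diff (the sketch itself is not part of the change; it assumes the public `llamaindex` package exports `OpenAI` and `OpenAIAgent`):

```ts
import { OpenAI, OpenAIAgent } from "llamaindex";

// The llm's metadata must report isFunctionCallingModel: true,
// otherwise the OpenAIAgent constructor throws.
const agent = new OpenAIAgent({
  tools: [], // or a list of BaseTool instances, or a toolRetriever
  llm: new OpenAI({ model: "gpt-3.5-turbo-0613" }),
  systemPrompt: "You are a helpful assistant.",
});

// systemPrompt is shorthand for a single system message in prefixMessages,
// so the new validation rejects supplying both:
// new OpenAIAgent({ tools: [], systemPrompt: "...", prefixMessages: [...] }); // throws
```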