Skip to content
Snippets Groups Projects
Unverified commit ee9f3f37, authored by Emanuel Ferreira and committed via GitHub
Browse files

refactor: openai agent and utils (#542)

parent f2053585
No related branches found
No related tags found
No related merge requests found
---
"llamaindex": patch
---
chore: refactor openai agent utils
......@@ -15,6 +15,7 @@ type OpenAIAgentParams = {
defaultToolChoice?: string;
callbackManager?: CallbackManager;
toolRetriever?: ObjectRetriever<BaseTool>;
systemPrompt?: string;
};
/**
......@@ -33,7 +34,29 @@ export class OpenAIAgent extends AgentRunner {
defaultToolChoice = "auto",
callbackManager,
toolRetriever,
systemPrompt,
}: OpenAIAgentParams) {
prefixMessages = prefixMessages || [];
llm = llm ?? new OpenAI({ model: "gpt-3.5-turbo-0613" });
if (systemPrompt) {
if (prefixMessages) {
throw new Error("Cannot provide both systemPrompt and prefixMessages");
}
prefixMessages = [
{
content: systemPrompt,
role: "system",
},
];
}
if (!llm?.metadata.isFunctionCallingModel) {
throw new Error("LLM model must be a function-calling model");
}
const stepEngine = new OpenAIAgentWorker({
tools,
callbackManager,
......
......@@ -88,9 +88,9 @@ type CallFunctionOutput = {
* This class is responsible for running the agent.
*/
export class OpenAIAgentWorker implements AgentWorker {
private _llm: OpenAI;
private _verbose: boolean;
private _maxFunctionCalls: number;
private llm: OpenAI;
private verbose: boolean;
private maxFunctionCalls: number;
public prefixMessages: ChatMessage[];
public callbackManager: CallbackManager | undefined;
......@@ -109,11 +109,11 @@ export class OpenAIAgentWorker implements AgentWorker {
callbackManager,
toolRetriever,
}: OpenAIAgentWorkerParams) {
this._llm = llm ?? new OpenAI({ model: "gpt-3.5-turbo-0613" });
this._verbose = verbose || false;
this._maxFunctionCalls = maxFunctionCalls;
this.llm = llm ?? new OpenAI({ model: "gpt-3.5-turbo-0613" });
this.verbose = verbose || false;
this.maxFunctionCalls = maxFunctionCalls;
this.prefixMessages = prefixMessages || [];
this.callbackManager = callbackManager || this._llm.callbackManager;
this.callbackManager = callbackManager || this.llm.callbackManager;
if (tools.length > 0 && toolRetriever) {
throw new Error("Cannot specify both tools and tool_retriever");
......@@ -207,7 +207,7 @@ export class OpenAIAgentWorker implements AgentWorker {
llmChatKwargs: any,
): Promise<AgentChatResponse> {
if (mode === ChatResponseMode.WAIT) {
const chatResponse = (await this._llm.chat({
const chatResponse = (await this.llm.chat({
stream: false,
...llmChatKwargs,
})) as unknown as ChatResponse;
......@@ -236,7 +236,7 @@ export class OpenAIAgentWorker implements AgentWorker {
throw new Error("Invalid tool_call object");
}
const functionMessage = await callFunction(tools, toolCall, this._verbose);
const functionMessage = await callFunction(tools, toolCall, this.verbose);
const message = functionMessage[0];
const toolOutput = functionMessage[1];
......@@ -282,7 +282,7 @@ export class OpenAIAgentWorker implements AgentWorker {
toolCalls: OpenAIToolCall[] | null,
nFunctionCalls: number,
): boolean {
if (nFunctionCalls > this._maxFunctionCalls) {
if (nFunctionCalls > this.maxFunctionCalls) {
return false;
}
......@@ -311,7 +311,7 @@ export class OpenAIAgentWorker implements AgentWorker {
const tools = await this.getTools(task.input);
if (step.input) {
addUserStepToMemory(step, task.extraState.newMemory, this._verbose);
addUserStepToMemory(step, task.extraState.newMemory, this.verbose);
}
const openaiTools = tools.map((tool) =>
......
......@@ -26,6 +26,7 @@ import {
} from "./azure";
import { BaseLLM } from "./base";
import { OpenAISession, getOpenAISession } from "./open_ai";
import { isFunctionCallingModel } from "./openai/utils";
import { PortkeySession, getPortkeySession } from "./portkey";
import { ReplicateSession } from "./replicate_ai";
import {
......@@ -166,6 +167,7 @@ export class OpenAI extends BaseLLM {
maxTokens: this.maxTokens,
contextWindow,
tokenizer: Tokenizers.CL100K_BASE,
isFunctionCallingModel: isFunctionCallingModel(this.model),
};
}
......
import { ALL_AVAILABLE_OPENAI_MODELS } from "..";
/**
 * Reports whether the given OpenAI model name supports function calling.
 *
 * A model qualifies when it appears among the known chat models and is not
 * one of the legacy snapshots (names containing "0314" or "0301"), which
 * this codebase treats as non-function-calling.
 */
export const isFunctionCallingModel = (model: string): boolean => {
  const knownChatModels = Object.keys(ALL_AVAILABLE_OPENAI_MODELS);
  const isLegacySnapshot = ["0314", "0301"].some((tag) =>
    model.includes(tag),
  );
  return knownChatModels.includes(model) && !isLegacySnapshot;
};
import { ALL_AVAILABLE_OPENAI_MODELS } from "../../../llm";
import { isFunctionCallingModel } from "../../../llm/openai/utils";
describe("openai/utils", () => {
  test("shouldn't be a old model", () => {
    // The current model is function-calling capable…
    expect(isFunctionCallingModel("gpt-3.5-turbo")).toBe(true);
    // …while legacy snapshot names (any mix of 0314/0301 suffixes) are not.
    const legacyNames = [
      "gpt-3.5-turbo-0314",
      "gpt-3.5-turbo-0301",
      "gpt-3.5-turbo-0314-0301",
      "gpt-3.5-turbo-0314-0301-0314-0301",
      "gpt-3.5-turbo-0314-0301-0314-0301-0314-0301",
    ];
    for (const name of legacyNames) {
      expect(isFunctionCallingModel(name)).toBe(false);
    }
  });

  test("should be a open ai model", () => {
    // Every known model that is not a 0314/0301 snapshot must qualify.
    for (const model of Object.keys(ALL_AVAILABLE_OPENAI_MODELS)) {
      if (model.includes("0314") || model.includes("0301")) {
        continue;
      }
      expect(isFunctionCallingModel(model)).toBe(true);
    }
  });
});
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment