diff --git a/.changeset/itchy-tips-reflect.md b/.changeset/itchy-tips-reflect.md new file mode 100644 index 0000000000000000000000000000000000000000..90d1987609cb1d00a25e0d27b333e0dad45604db --- /dev/null +++ b/.changeset/itchy-tips-reflect.md @@ -0,0 +1,5 @@ +--- +"@llamaindex/google": patch +--- + +fix(gemini): use function role for message contains tool-result diff --git a/examples/gemini/agent.ts b/examples/gemini/agent.ts index 212f10da8e03c72c0a1554432fc5114fe2d4ba08..1099d89a19bca89a4428fc8c15977d2c9b1ec06c 100644 --- a/examples/gemini/agent.ts +++ b/examples/gemini/agent.ts @@ -1,4 +1,18 @@ -import { FunctionTool, Gemini, GEMINI_MODEL, LLMAgent } from "llamaindex"; +import { + FunctionTool, + Gemini, + GEMINI_MODEL, + LLMAgent, + Settings, +} from "llamaindex"; + +Settings.callbackManager.on("llm-tool-call", (event) => { + console.log(event.detail); +}); + +Settings.callbackManager.on("llm-tool-result", (event) => { + console.log(event.detail); +}); const sumNumbers = FunctionTool.from( ({ a, b }: { a: number; b: number }) => `${a + b}`, @@ -44,17 +58,39 @@ const divideNumbers = FunctionTool.from( }, ); +const subtractNumbers = FunctionTool.from( + ({ a, b }: { a: number; b: number }) => `${a - b}`, + { + name: "subtractNumbers", + description: "Use this function to subtract two numbers", + parameters: { + type: "object", + properties: { + a: { + type: "number", + description: "The number to subtract from", + }, + b: { + type: "number", + description: "The number to subtract", + }, + }, + required: ["a", "b"], + }, + }, +); + async function main() { const gemini = new Gemini({ model: GEMINI_MODEL.GEMINI_PRO, }); const agent = new LLMAgent({ llm: gemini, - tools: [sumNumbers, divideNumbers], + tools: [sumNumbers, divideNumbers, subtractNumbers], }); const response = await agent.chat({ - message: "How much is 5 + 5? then divide by 2", + message: "How much is 5 + 5? then divide by 2 then subtract 1", }); console.log(response.message); diff --git a/packages/providers/google/src/types.ts b/packages/providers/google/src/types.ts index 6d90ec6d17c1b02ae373de24a510abc17156b769..73d39fbf1884858f9683543b223eb0157de3f516 100644 --- a/packages/providers/google/src/types.ts +++ b/packages/providers/google/src/types.ts @@ -97,7 +97,7 @@ export type GenerativeModel = export type ChatContext = { message: Part[]; history: GeminiMessageContent[] }; -export type GeminiMessageRole = "user" | "model"; +export type GeminiMessageRole = "user" | "model" | "function"; export type GeminiAdditionalChatOptions = object; diff --git a/packages/providers/google/src/utils.ts b/packages/providers/google/src/utils.ts index df0a44a925d9e06f8024f5038dd22b9386fade67..a6e06253f7541f2a6d02276ba6137c3ec380386d 100644 --- a/packages/providers/google/src/utils.ts +++ b/packages/providers/google/src/utils.ts @@ -196,6 +196,7 @@ export class GeminiHelper { > = { user: "user", model: "assistant", + function: "user", }; public static mergeNeighboringSameRoleMessages( @@ -278,12 +279,21 @@ export class GeminiHelper { return parts; } + public static getGeminiMessageRole( + message: ChatMessage<ToolCallLLMMessageOptions>, + ): GeminiMessageRole { + if (message.options && "toolResult" in message.options) { + return "function"; + } + return GeminiHelper.ROLES_TO_GEMINI[message.role]; + } + public static chatMessageToGemini( message: ChatMessage<ToolCallLLMMessageOptions>, fnMap: Record<string, string>, // mapping of fn call id to fn call name ): GeminiMessageContent { return { - role: GeminiHelper.ROLES_TO_GEMINI[message.role], + role: GeminiHelper.getGeminiMessageRole(message), parts: GeminiHelper.messageContentToGeminiParts({ ...message, fnMap }), }; }