From b6ea2bf964632e82fb282abc19f95e42b3f62ef1 Mon Sep 17 00:00:00 2001 From: Thuc Pham <51660321+thucpn@users.noreply.github.com> Date: Mon, 10 Feb 2025 12:57:45 +0700 Subject: [PATCH] fix(gemini): use function role for message contains tool-result (#1634) --- .changeset/itchy-tips-reflect.md | 5 +++ examples/gemini/agent.ts | 42 ++++++++++++++++++++++++-- packages/providers/google/src/types.ts | 2 +- packages/providers/google/src/utils.ts | 12 +++++++- 4 files changed, 56 insertions(+), 5 deletions(-) create mode 100644 .changeset/itchy-tips-reflect.md diff --git a/.changeset/itchy-tips-reflect.md b/.changeset/itchy-tips-reflect.md new file mode 100644 index 000000000..90d198760 --- /dev/null +++ b/.changeset/itchy-tips-reflect.md @@ -0,0 +1,5 @@ +--- +"@llamaindex/google": patch +--- + +fix(gemini): use function role for message contains tool-result diff --git a/examples/gemini/agent.ts b/examples/gemini/agent.ts index 212f10da8..1099d89a1 100644 --- a/examples/gemini/agent.ts +++ b/examples/gemini/agent.ts @@ -1,4 +1,18 @@ -import { FunctionTool, Gemini, GEMINI_MODEL, LLMAgent } from "llamaindex"; +import { + FunctionTool, + Gemini, + GEMINI_MODEL, + LLMAgent, + Settings, +} from "llamaindex"; + +Settings.callbackManager.on("llm-tool-call", (event) => { + console.log(event.detail); +}); + +Settings.callbackManager.on("llm-tool-result", (event) => { + console.log(event.detail); +}); const sumNumbers = FunctionTool.from( ({ a, b }: { a: number; b: number }) => `${a + b}`, @@ -44,17 +58,39 @@ const divideNumbers = FunctionTool.from( }, ); +const subtractNumbers = FunctionTool.from( + ({ a, b }: { a: number; b: number }) => `${a - b}`, + { + name: "subtractNumbers", + description: "Use this function to subtract two numbers", + parameters: { + type: "object", + properties: { + a: { + type: "number", + description: "The number to subtract from", + }, + b: { + type: "number", + description: "The number to subtract", + }, + }, + required: ["a", "b"], + }, + }, +); + async function main() { const gemini = new Gemini({ model: GEMINI_MODEL.GEMINI_PRO, }); const agent = new LLMAgent({ llm: gemini, - tools: [sumNumbers, divideNumbers], + tools: [sumNumbers, divideNumbers, subtractNumbers], }); const response = await agent.chat({ - message: "How much is 5 + 5? then divide by 2", + message: "How much is 5 + 5? then divide by 2 then subtract 1", }); console.log(response.message); diff --git a/packages/providers/google/src/types.ts b/packages/providers/google/src/types.ts index 6d90ec6d1..73d39fbf1 100644 --- a/packages/providers/google/src/types.ts +++ b/packages/providers/google/src/types.ts @@ -97,7 +97,7 @@ export type GenerativeModel = export type ChatContext = { message: Part[]; history: GeminiMessageContent[] }; -export type GeminiMessageRole = "user" | "model"; +export type GeminiMessageRole = "user" | "model" | "function"; export type GeminiAdditionalChatOptions = object; diff --git a/packages/providers/google/src/utils.ts b/packages/providers/google/src/utils.ts index df0a44a92..a6e06253f 100644 --- a/packages/providers/google/src/utils.ts +++ b/packages/providers/google/src/utils.ts @@ -196,6 +196,7 @@ export class GeminiHelper { > = { user: "user", model: "assistant", + function: "user", }; public static mergeNeighboringSameRoleMessages( @@ -278,12 +279,21 @@ export class GeminiHelper { return parts; } + public static getGeminiMessageRole( + message: ChatMessage<ToolCallLLMMessageOptions>, + ): GeminiMessageRole { + if (message.options && "toolResult" in message.options) { + return "function"; + } + return GeminiHelper.ROLES_TO_GEMINI[message.role]; + } + public static chatMessageToGemini( message: ChatMessage<ToolCallLLMMessageOptions>, fnMap: Record<string, string>, // mapping of fn call id to fn call name ): GeminiMessageContent { return { - role: GeminiHelper.ROLES_TO_GEMINI[message.role], + role: GeminiHelper.getGeminiMessageRole(message), parts: GeminiHelper.messageContentToGeminiParts({ ...message, fnMap }), }; } -- GitLab