From b6ea2bf964632e82fb282abc19f95e42b3f62ef1 Mon Sep 17 00:00:00 2001
From: Thuc Pham <51660321+thucpn@users.noreply.github.com>
Date: Mon, 10 Feb 2025 12:57:45 +0700
Subject: [PATCH] fix(gemini): use function role for message contains
 tool-result (#1634)

---
 .changeset/itchy-tips-reflect.md       |  5 +++
 examples/gemini/agent.ts               | 42 ++++++++++++++++++++++++--
 packages/providers/google/src/types.ts |  2 +-
 packages/providers/google/src/utils.ts | 12 +++++++-
 4 files changed, 56 insertions(+), 5 deletions(-)
 create mode 100644 .changeset/itchy-tips-reflect.md

diff --git a/.changeset/itchy-tips-reflect.md b/.changeset/itchy-tips-reflect.md
new file mode 100644
index 000000000..90d198760
--- /dev/null
+++ b/.changeset/itchy-tips-reflect.md
@@ -0,0 +1,5 @@
+---
+"@llamaindex/google": patch
+---
+
+fix(gemini): use function role for message contains tool-result
diff --git a/examples/gemini/agent.ts b/examples/gemini/agent.ts
index 212f10da8..1099d89a1 100644
--- a/examples/gemini/agent.ts
+++ b/examples/gemini/agent.ts
@@ -1,4 +1,18 @@
-import { FunctionTool, Gemini, GEMINI_MODEL, LLMAgent } from "llamaindex";
+import {
+  FunctionTool,
+  Gemini,
+  GEMINI_MODEL,
+  LLMAgent,
+  Settings,
+} from "llamaindex";
+
+Settings.callbackManager.on("llm-tool-call", (event) => {
+  console.log(event.detail);
+});
+
+Settings.callbackManager.on("llm-tool-result", (event) => {
+  console.log(event.detail);
+});
 
 const sumNumbers = FunctionTool.from(
   ({ a, b }: { a: number; b: number }) => `${a + b}`,
@@ -44,17 +58,39 @@ const divideNumbers = FunctionTool.from(
   },
 );
 
+const subtractNumbers = FunctionTool.from(
+  ({ a, b }: { a: number; b: number }) => `${a - b}`,
+  {
+    name: "subtractNumbers",
+    description: "Use this function to subtract two numbers",
+    parameters: {
+      type: "object",
+      properties: {
+        a: {
+          type: "number",
+          description: "The number to subtract from",
+        },
+        b: {
+          type: "number",
+          description: "The number to subtract",
+        },
+      },
+      required: ["a", "b"],
+    },
+  },
+);
+
 async function main() {
   const gemini = new Gemini({
     model: GEMINI_MODEL.GEMINI_PRO,
   });
   const agent = new LLMAgent({
     llm: gemini,
-    tools: [sumNumbers, divideNumbers],
+    tools: [sumNumbers, divideNumbers, subtractNumbers],
   });
 
   const response = await agent.chat({
-    message: "How much is 5 + 5? then divide by 2",
+    message: "How much is 5 + 5? then divide by 2 then subtract 1",
   });
 
   console.log(response.message);
diff --git a/packages/providers/google/src/types.ts b/packages/providers/google/src/types.ts
index 6d90ec6d1..73d39fbf1 100644
--- a/packages/providers/google/src/types.ts
+++ b/packages/providers/google/src/types.ts
@@ -97,7 +97,7 @@ export type GenerativeModel =
 
 export type ChatContext = { message: Part[]; history: GeminiMessageContent[] };
 
-export type GeminiMessageRole = "user" | "model";
+export type GeminiMessageRole = "user" | "model" | "function";
 
 export type GeminiAdditionalChatOptions = object;
 
diff --git a/packages/providers/google/src/utils.ts b/packages/providers/google/src/utils.ts
index df0a44a92..a6e06253f 100644
--- a/packages/providers/google/src/utils.ts
+++ b/packages/providers/google/src/utils.ts
@@ -196,6 +196,7 @@ export class GeminiHelper {
   > = {
     user: "user",
     model: "assistant",
+    function: "user",
   };
 
   public static mergeNeighboringSameRoleMessages(
@@ -278,12 +279,21 @@ export class GeminiHelper {
     return parts;
   }
 
+  public static getGeminiMessageRole(
+    message: ChatMessage<ToolCallLLMMessageOptions>,
+  ): GeminiMessageRole {
+    if (message.options && "toolResult" in message.options) {
+      return "function";
+    }
+    return GeminiHelper.ROLES_TO_GEMINI[message.role];
+  }
+
   public static chatMessageToGemini(
     message: ChatMessage<ToolCallLLMMessageOptions>,
     fnMap: Record<string, string>, // mapping of fn call id to fn call name
   ): GeminiMessageContent {
     return {
-      role: GeminiHelper.ROLES_TO_GEMINI[message.role],
+      role: GeminiHelper.getGeminiMessageRole(message),
       parts: GeminiHelper.messageContentToGeminiParts({ ...message, fnMap }),
     };
   }
-- 
GitLab