From 1feb23bb83e7509a93fb0089e08b58b3e5c0fd5b Mon Sep 17 00:00:00 2001
From: Parham Saidi <parham@parha.me>
Date: Wed, 26 Jun 2024 19:49:11 +0200
Subject: [PATCH] feat: added Gemini tool calling support (#973)

---
 .changeset/fuzzy-tigers-accept.md            |   5 +
 apps/docs/docs/examples/agent_gemini.mdx     |   6 ++
 apps/docs/docs/modules/agent/index.md        |   4 +-
 examples/gemini/agent.ts                     |  65 +++++++++++
 packages/llamaindex/src/agent/utils.ts       |   9 ++
 packages/llamaindex/src/llm/gemini/base.ts   |  68 ++++++++++--
 packages/llamaindex/src/llm/gemini/types.ts  |  16 +++
 packages/llamaindex/src/llm/gemini/utils.ts  | 108 +++++++++++++++++--
 packages/llamaindex/src/llm/gemini/vertex.ts |  40 +++++--
 9 files changed, 298 insertions(+), 23 deletions(-)
 create mode 100644 .changeset/fuzzy-tigers-accept.md
 create mode 100644 apps/docs/docs/examples/agent_gemini.mdx
 create mode 100644 examples/gemini/agent.ts

diff --git a/.changeset/fuzzy-tigers-accept.md b/.changeset/fuzzy-tigers-accept.md
new file mode 100644
index 000000000..03e08ea44
--- /dev/null
+++ b/.changeset/fuzzy-tigers-accept.md
@@ -0,0 +1,5 @@
+---
+"llamaindex": patch
+---
+
+feat: Gemini tool calling support for agents
diff --git a/apps/docs/docs/examples/agent_gemini.mdx b/apps/docs/docs/examples/agent_gemini.mdx
new file mode 100644
index 000000000..7df6ecb53
--- /dev/null
+++ b/apps/docs/docs/examples/agent_gemini.mdx
@@ -0,0 +1,6 @@
+# Gemini Agent
+
+import CodeBlock from "@theme/CodeBlock";
+import CodeSourceGemini from "!raw-loader!../../../../examples/gemini/agent.ts";
+
+<CodeBlock language="ts">{CodeSourceGemini}</CodeBlock>
diff --git a/apps/docs/docs/modules/agent/index.md b/apps/docs/docs/modules/agent/index.md
index 1941d1f50..39121cfb4 100644
--- a/apps/docs/docs/modules/agent/index.md
+++ b/apps/docs/docs/modules/agent/index.md
@@ -12,12 +12,14 @@ An “agent” is an automated reasoning and decision engine. It takes in a user
 LlamaIndex.TS comes with a few built-in agents, but you can also create your own. The built-in agents include:
 
 - OpenAI Agent
-- Anthropic Agent
+- Anthropic Agent, available via both Anthropic and Bedrock (the latter in `@llamaindex/community`)
+- Gemini Agent
 - ReACT Agent
 
 ## Examples
 
 - [OpenAI Agent](../../examples/agent.mdx)
+- [Gemini Agent](../../examples/agent_gemini.mdx)
 
 ## Api References
 
diff --git a/examples/gemini/agent.ts b/examples/gemini/agent.ts
new file mode 100644
index 000000000..212f10da8
--- /dev/null
+++ b/examples/gemini/agent.ts
@@ -0,0 +1,65 @@
+import { FunctionTool, Gemini, GEMINI_MODEL, LLMAgent } from "llamaindex";
+
+const sumNumbers = FunctionTool.from(
+  ({ a, b }: { a: number; b: number }) => `${a + b}`,
+  {
+    name: "sumNumbers",
+    description: "Use this function to sum two numbers",
+    parameters: {
+      type: "object",
+      properties: {
+        a: {
+          type: "number",
+          description: "The first number",
+        },
+        b: {
+          type: "number",
+          description: "The second number",
+        },
+      },
+      required: ["a", "b"],
+    },
+  },
+);
+
+const divideNumbers = FunctionTool.from(
+  ({ a, b }: { a: number; b: number }) => `${a / b}`,
+  {
+    name: "divideNumbers",
+    description: "Use this function to divide two numbers",
+    parameters: {
+      type: "object",
+      properties: {
+        a: {
+          type: "number",
+          description: "The dividend a to divide",
+        },
+        b: {
+          type: "number",
+          description: "The divisor b to divide by",
+        },
+      },
+      required: ["a", "b"],
+    },
+  },
+);
+
+async function main() {
+  const gemini = new Gemini({
+    model: GEMINI_MODEL.GEMINI_PRO,
+  });
+  const agent = new LLMAgent({
+    llm: gemini,
+    tools: [sumNumbers, divideNumbers],
+  });
+
+  const response = await agent.chat({
+    message: "How much is 5 + 5? then divide by 2",
+  });
+
+  console.log(response.message);
+}
+
+void main().then(() => {
+  console.log("Done");
+});
diff --git a/packages/llamaindex/src/agent/utils.ts b/packages/llamaindex/src/agent/utils.ts
index 8a3df6402..32ba34139 100644
--- a/packages/llamaindex/src/agent/utils.ts
+++ b/packages/llamaindex/src/agent/utils.ts
@@ -61,6 +61,7 @@ export async function stepToolsStreaming<Model extends LLM>({
   // check if first chunk has tool calls, if so, this is a function call
   // otherwise, it's a regular message
   const hasToolCall = !!(value.options && "toolCall" in value.options);
+
   enqueueOutput({
     taskStep: step,
     output: finalStream,
@@ -78,6 +79,14 @@ export async function stepToolsStreaming<Model extends LLM>({
         });
       }
     }
+
+    // If tool calls are present in the options but were not read from the stream (the case for Gemini), collect them here
+    if (!toolCalls.size && value.options && "toolCall" in value.options) {
+      value.options.toolCall.forEach((toolCall) => {
+        toolCalls.set(toolCall.id, toolCall);
+      });
+    }
+
     step.context.store.messages = [
       ...step.context.store.messages,
       {
diff --git a/packages/llamaindex/src/llm/gemini/base.ts b/packages/llamaindex/src/llm/gemini/base.ts
index 65491fd98..9d14e82c4 100644
--- a/packages/llamaindex/src/llm/gemini/base.ts
+++ b/packages/llamaindex/src/llm/gemini/base.ts
@@ -2,17 +2,20 @@ import {
   GoogleGenerativeAI,
   GenerativeModel as GoogleGenerativeModel,
   type EnhancedGenerateContentResponse,
+  type FunctionCall,
   type ModelParams as GoogleModelParams,
   type GenerateContentStreamResult as GoogleStreamGenerateContentResult,
 } from "@google/generative-ai";
 
-import { getEnv } from "@llamaindex/env";
+import { getEnv, randomUUID } from "@llamaindex/env";
 import { ToolCallLLM } from "../base.js";
 import type {
   CompletionResponse,
   LLMCompletionParamsNonStreaming,
   LLMCompletionParamsStreaming,
   LLMMetadata,
+  ToolCall,
+  ToolCallLLMMessageOptions,
 } from "../types.js";
 import { streamConverter, wrapLLMEvent } from "../utils.js";
 import {
@@ -29,7 +32,12 @@ import {
   type GoogleGeminiSessionOptions,
   type IGeminiSession,
 } from "./types.js";
-import { GeminiHelper, getChatContext, getPartsText } from "./utils.js";
+import {
+  GeminiHelper,
+  getChatContext,
+  getPartsText,
+  mapBaseToolToGeminiFunctionDeclaration,
+} from "./utils.js";
 
 export const GEMINI_MODEL_INFO_MAP: Record<GEMINI_MODEL, GeminiModelInfo> = {
   [GEMINI_MODEL.GEMINI_PRO]: { contextWindow: 30720 },
@@ -86,13 +94,33 @@ export class GeminiSession implements IGeminiSession {
     return response.text();
   }
 
+  getToolsFromResponse(
+    response: EnhancedGenerateContentResponse,
+  ): ToolCall[] | undefined {
+    return response.functionCalls()?.map(
+      (call: FunctionCall) =>
+        ({
+          name: call.name,
+          input: call.args,
+          id: randomUUID(),
+        }) as ToolCall,
+    );
+  }
+
   async *getChatStream(
     result: GoogleStreamGenerateContentResult,
   ): GeminiChatStreamResponse {
-    yield* streamConverter(result.stream, (response) => ({
-      delta: this.getResponseText(response),
-      raw: response,
-    }));
+    yield* streamConverter(result.stream, (response) => {
+      const tools = this.getToolsFromResponse(response);
+      const options: ToolCallLLMMessageOptions = tools?.length
+        ? { toolCall: tools }
+        : {};
+      return {
+        delta: this.getResponseText(response),
+        raw: response,
+        options,
+      };
+    });
   }
 
   getCompletionStream(
@@ -188,10 +216,22 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> {
     const client = this.session.getGenerativeModel(this.metadata);
     const chat = client.startChat({
       history: context.history,
+      tools: params.tools && [
+        {
+          functionDeclarations: params.tools.map(
+            mapBaseToolToGeminiFunctionDeclaration,
+          ),
+        },
+      ],
     });
     const { response } = await chat.sendMessage(context.message);
     const topCandidate = response.candidates![0];
 
+    const tools = this.session.getToolsFromResponse(response);
+    const options: ToolCallLLMMessageOptions = tools?.length
+      ? { toolCall: tools }
+      : {};
+
     return {
       raw: response,
       message: {
@@ -199,6 +239,7 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> {
         role: GeminiHelper.ROLES_FROM_GEMINI[
           topCandidate.content.role as GeminiMessageRole
         ],
+        options,
       },
     };
   }
@@ -210,6 +251,13 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> {
     const client = this.session.getGenerativeModel(this.metadata);
     const chat = client.startChat({
       history: context.history,
+      tools: params.tools && [
+        {
+          functionDeclarations: params.tools.map(
+            mapBaseToolToGeminiFunctionDeclaration,
+          ),
+        },
+      ],
     });
     const result = await chat.sendMessageStream(context.message);
     yield* this.session.getChatStream(result);
@@ -241,13 +289,17 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> {
 
     if (stream) {
       const result = await client.generateContentStream(
-        getPartsText(GeminiHelper.messageContentToGeminiParts(prompt)),
+        getPartsText(
+          GeminiHelper.messageContentToGeminiParts({ content: prompt }),
+        ),
       );
       return this.session.getCompletionStream(result);
     }
 
     const result = await client.generateContent(
-      getPartsText(GeminiHelper.messageContentToGeminiParts(prompt)),
+      getPartsText(
+        GeminiHelper.messageContentToGeminiParts({ content: prompt }),
+      ),
     );
     return {
       text: this.session.getResponseText(result.response),
diff --git a/packages/llamaindex/src/llm/gemini/types.ts b/packages/llamaindex/src/llm/gemini/types.ts
index 998ec12ea..f602ee83b 100644
--- a/packages/llamaindex/src/llm/gemini/types.ts
+++ b/packages/llamaindex/src/llm/gemini/types.ts
@@ -3,6 +3,8 @@ import {
   type EnhancedGenerateContentResponse,
   type Content as GeminiMessageContent,
   type FileDataPart as GoogleFileDataPart,
+  type FunctionDeclaration as GoogleFunctionDeclaration,
+  type FunctionDeclarationSchema as GoogleFunctionDeclarationSchema,
   type InlineDataPart as GoogleInlineFileDataPart,
   type ModelParams as GoogleModelParams,
   type Part as GooglePart,
@@ -14,6 +16,8 @@ import {
   GenerativeModelPreview as VertexGenerativeModelPreview,
   type GenerateContentResponse,
   type FileDataPart as VertexFileDataPart,
+  type FunctionDeclaration as VertexFunctionDeclaration,
+  type FunctionDeclarationSchema as VertexFunctionDeclarationSchema,
   type VertexInit,
   type InlineDataPart as VertexInlineFileDataPart,
   type ModelParams as VertexModelParams,
@@ -27,6 +31,7 @@ import type {
   CompletionResponse,
   LLMChatParamsNonStreaming,
   LLMChatParamsStreaming,
+  ToolCall,
   ToolCallLLMMessageOptions,
 } from "../types.js";
 
@@ -69,6 +74,14 @@ export type InlineDataPart =
 
 export type ModelParams = GoogleModelParams | VertexModelParams;
 
+export type FunctionDeclaration =
+  | VertexFunctionDeclaration
+  | GoogleFunctionDeclaration;
+
+export type FunctionDeclarationSchema =
+  | GoogleFunctionDeclarationSchema
+  | VertexFunctionDeclarationSchema;
+
 export type GenerativeModel =
   | VertexGenerativeModelPreview
   | VertexGenerativeModel
@@ -112,4 +125,7 @@ export interface IGeminiSession {
       | GoogleStreamGenerateContentResult
       | VertexStreamGenerateContentResult,
   ): GeminiChatStreamResponse;
+  getToolsFromResponse(
+    response: EnhancedGenerateContentResponse | GenerateContentResponse,
+  ): ToolCall[] | undefined;
 }
diff --git a/packages/llamaindex/src/llm/gemini/utils.ts b/packages/llamaindex/src/llm/gemini/utils.ts
index e20b5f065..fd423fff4 100644
--- a/packages/llamaindex/src/llm/gemini/utils.ts
+++ b/packages/llamaindex/src/llm/gemini/utils.ts
@@ -1,17 +1,23 @@
-import { type Content as GeminiMessageContent } from "@google/generative-ai";
+import {
+  type FunctionCall,
+  type Content as GeminiMessageContent,
+} from "@google/generative-ai";
 
 import { type GenerateContentResponse } from "@google-cloud/vertexai";
+import type { BaseTool } from "../../types.js";
 import type {
   ChatMessage,
-  MessageContent,
   MessageContentImageDetail,
   MessageContentTextDetail,
   MessageType,
+  ToolCallLLMMessageOptions,
 } from "../types.js";
 import { extractDataUrlComponents } from "../utils.js";
 import type {
   ChatContext,
   FileDataPart,
+  FunctionDeclaration,
+  FunctionDeclarationSchema,
   GeminiChatParamsNonStreaming,
   GeminiChatParamsStreaming,
   GeminiMessageRole,
@@ -104,7 +110,8 @@ export const cleanParts = (
         part.text?.trim() ||
         part.inlineData ||
         part.fileData ||
-        part.functionCall,
+        part.functionCall ||
+        part.functionResponse,
     ),
   };
 };
@@ -115,8 +122,21 @@ export const getChatContext = (
   // Gemini doesn't allow:
   // 1. Consecutive messages from the same role
   // 2. Parts that have empty text
+  const fnMap = params.messages.reduce(
+    (result, message) => {
+      if (message.options && "toolCall" in message.options)
+        message.options.toolCall.forEach((call) => {
+          result[call.id] = call.name;
+        });
+
+      return result;
+    },
+    {} as Record<string, string>,
+  );
   const messages = GeminiHelper.mergeNeighboringSameRoleMessages(
-    params.messages.map(GeminiHelper.chatMessageToGemini),
+    params.messages.map((message) =>
+      GeminiHelper.chatMessageToGemini(message, fnMap),
+    ),
   ).map(cleanParts);
 
   const history = messages.slice(0, -1);
@@ -127,6 +147,23 @@ export const getChatContext = (
   };
 };
 
+export const mapBaseToolToGeminiFunctionDeclaration = (
+  tool: BaseTool,
+): FunctionDeclaration => {
+  const parameters: FunctionDeclarationSchema = {
+    type: tool.metadata.parameters?.type.toUpperCase(),
+    properties: tool.metadata.parameters?.properties,
+    description: tool.metadata.parameters?.description,
+    required: tool.metadata.parameters?.required,
+  };
+
+  return {
+    name: tool.metadata.name,
+    description: tool.metadata.description,
+    parameters,
+  };
+};
+
 /**
  * Helper class providing utility functions for Gemini
  */
@@ -177,7 +214,40 @@ export class GeminiHelper {
       );
   }
 
-  public static messageContentToGeminiParts(content: MessageContent): Part[] {
+  public static messageContentToGeminiParts({
+    content,
+    options = undefined,
+    fnMap = undefined,
+  }: Pick<ChatMessage<ToolCallLLMMessageOptions>, "content" | "options"> & {
+    fnMap?: Record<string, string>;
+  }): Part[] {
+    if (options && "toolResult" in options) {
+      if (!fnMap) throw Error("fnMap must be set");
+      const name = fnMap[options.toolResult.id];
+      if (!name)
+        throw Error(
+          `Could not find the name for fn call with id ${options.toolResult.id}`,
+        );
+
+      return [
+        {
+          functionResponse: {
+            name,
+            response: {
+              result: options.toolResult.result,
+            },
+          },
+        },
+      ];
+    }
+    if (options && "toolCall" in options) {
+      return options.toolCall.map((call) => ({
+        functionCall: {
+          name: call.name,
+          args: call.input,
+        } as FunctionCall,
+      }));
+    }
     if (typeof content === "string") {
       return [{ text: content }];
     }
@@ -197,11 +267,35 @@ export class GeminiHelper {
   }
 
   public static chatMessageToGemini(
-    message: ChatMessage,
+    message: ChatMessage<ToolCallLLMMessageOptions>,
+    fnMap: Record<string, string>, // mapping of fn call id to fn call name
   ): GeminiMessageContent {
     return {
       role: GeminiHelper.ROLES_TO_GEMINI[message.role],
-      parts: GeminiHelper.messageContentToGeminiParts(message.content),
+      parts: GeminiHelper.messageContentToGeminiParts({ ...message, fnMap }),
     };
   }
 }
+
+/**
+ * Returns the function calls of the first candidate.
+ * Adapted from https://github.com/google-gemini/generative-ai-js/ for use with
+ * the Vertex AI SDK (@google-cloud/vertexai), which does not include this helper.
+ */
+export function getFunctionCalls(
+  response: GenerateContentResponse,
+): FunctionCall[] | undefined {
+  const functionCalls: FunctionCall[] = [];
+  if (response.candidates?.[0].content?.parts) {
+    for (const part of response.candidates?.[0].content?.parts) {
+      if (part.functionCall) {
+        functionCalls.push(part.functionCall);
+      }
+    }
+  }
+  if (functionCalls.length > 0) {
+    return functionCalls;
+  } else {
+    return undefined;
+  }
+}
diff --git a/packages/llamaindex/src/llm/gemini/vertex.ts b/packages/llamaindex/src/llm/gemini/vertex.ts
index 43c7100da..a24e4546e 100644
--- a/packages/llamaindex/src/llm/gemini/vertex.ts
+++ b/packages/llamaindex/src/llm/gemini/vertex.ts
@@ -13,10 +13,15 @@ import type {
   VertexGeminiSessionOptions,
 } from "./types.js";
 
-import { getEnv } from "@llamaindex/env";
-import type { CompletionResponse } from "../types.js";
+import type { FunctionCall } from "@google/generative-ai";
+import { getEnv, randomUUID } from "@llamaindex/env";
+import type {
+  CompletionResponse,
+  ToolCall,
+  ToolCallLLMMessageOptions,
+} from "../types.js";
 import { streamConverter } from "../utils.js";
-import { getText } from "./utils.js";
+import { getFunctionCalls, getText } from "./utils.js";
 
 /* To use Google's Vertex AI backend, it doesn't use api key authentication.
  *
@@ -62,14 +67,35 @@ export class GeminiVertexSession implements IGeminiSession {
     return getText(response);
   }
 
+  getToolsFromResponse(
+    response: GenerateContentResponse,
+  ): ToolCall[] | undefined {
+    return getFunctionCalls(response)?.map(
+      (call: FunctionCall) =>
+        ({
+          name: call.name,
+          input: call.args,
+          id: randomUUID(),
+        }) as ToolCall,
+    );
+  }
+
   async *getChatStream(
     result: VertexStreamGenerateContentResult,
   ): GeminiChatStreamResponse {
-    yield* streamConverter(result.stream, (response) => ({
-      delta: this.getResponseText(response),
-      raw: response,
-    }));
+    yield* streamConverter(result.stream, (response) => {
+      const tools = this.getToolsFromResponse(response);
+      const options: ToolCallLLMMessageOptions = tools?.length
+        ? { toolCall: tools }
+        : {};
+      return {
+        delta: this.getResponseText(response),
+        raw: response,
+        options,
+      };
+    });
   }
+
   getCompletionStream(
     result: VertexStreamGenerateContentResult,
   ): AsyncIterable<CompletionResponse> {
-- 
GitLab