From 1feb23bb83e7509a93fb0089e08b58b3e5c0fd5b Mon Sep 17 00:00:00 2001 From: Parham Saidi <parham@parha.me> Date: Wed, 26 Jun 2024 19:49:11 +0200 Subject: [PATCH] feat: added Gemini tool calling support (#973) --- .changeset/fuzzy-tigers-accept.md | 5 + apps/docs/docs/examples/agent_gemini.mdx | 6 ++ apps/docs/docs/modules/agent/index.md | 4 +- examples/gemini/agent.ts | 65 +++++++++++ packages/llamaindex/src/agent/utils.ts | 9 ++ packages/llamaindex/src/llm/gemini/base.ts | 68 ++++++++++-- packages/llamaindex/src/llm/gemini/types.ts | 16 +++ packages/llamaindex/src/llm/gemini/utils.ts | 108 +++++++++++++++++-- packages/llamaindex/src/llm/gemini/vertex.ts | 40 +++++-- 9 files changed, 298 insertions(+), 23 deletions(-) create mode 100644 .changeset/fuzzy-tigers-accept.md create mode 100644 apps/docs/docs/examples/agent_gemini.mdx create mode 100644 examples/gemini/agent.ts diff --git a/.changeset/fuzzy-tigers-accept.md b/.changeset/fuzzy-tigers-accept.md new file mode 100644 index 000000000..03e08ea44 --- /dev/null +++ b/.changeset/fuzzy-tigers-accept.md @@ -0,0 +1,5 @@ +--- +"llamaindex": patch +--- + +feat: Gemini tool calling for agent support diff --git a/apps/docs/docs/examples/agent_gemini.mdx b/apps/docs/docs/examples/agent_gemini.mdx new file mode 100644 index 000000000..7df6ecb53 --- /dev/null +++ b/apps/docs/docs/examples/agent_gemini.mdx @@ -0,0 +1,6 @@ +# Gemini Agent + +import CodeBlock from "@theme/CodeBlock"; +import CodeSourceGemini from "!raw-loader!../../../../examples/gemini/agent.ts"; + +<CodeBlock language="ts">{CodeSourceGemini}</CodeBlock> diff --git a/apps/docs/docs/modules/agent/index.md b/apps/docs/docs/modules/agent/index.md index 1941d1f50..39121cfb4 100644 --- a/apps/docs/docs/modules/agent/index.md +++ b/apps/docs/docs/modules/agent/index.md @@ -12,12 +12,14 @@ An “agent” is an automated reasoning and decision engine. It takes in a user LlamaIndex.TS comes with a few built-in agents, but you can also create your own. 
The built-in agents include:
 - OpenAI Agent
-- Anthropic Agent
+- Anthropic Agent, via both Anthropic and Bedrock (in `@llamaindex/community`)
+- Gemini Agent
 - ReACT Agent
 
 ## Examples
 
 - [OpenAI Agent](../../examples/agent.mdx)
+- [Gemini Agent](../../examples/agent_gemini.mdx)
 
 ## Api References
 
diff --git a/examples/gemini/agent.ts b/examples/gemini/agent.ts
new file mode 100644
index 000000000..212f10da8
--- /dev/null
+++ b/examples/gemini/agent.ts
@@ -0,0 +1,65 @@
+import { FunctionTool, Gemini, GEMINI_MODEL, LLMAgent } from "llamaindex";
+
+const sumNumbers = FunctionTool.from(
+  ({ a, b }: { a: number; b: number }) => `${a + b}`,
+  {
+    name: "sumNumbers",
+    description: "Use this function to sum two numbers",
+    parameters: {
+      type: "object",
+      properties: {
+        a: {
+          type: "number",
+          description: "The first number",
+        },
+        b: {
+          type: "number",
+          description: "The second number",
+        },
+      },
+      required: ["a", "b"],
+    },
+  },
+);
+
+const divideNumbers = FunctionTool.from(
+  ({ a, b }: { a: number; b: number }) => `${a / b}`,
+  {
+    name: "divideNumbers",
+    description: "Use this function to divide two numbers",
+    parameters: {
+      type: "object",
+      properties: {
+        a: {
+          type: "number",
+          description: "The dividend a to divide",
+        },
+        b: {
+          type: "number",
+          description: "The divisor b to divide by",
+        },
+      },
+      required: ["a", "b"],
+    },
+  },
+);
+
+async function main() {
+  const gemini = new Gemini({
+    model: GEMINI_MODEL.GEMINI_PRO,
+  });
+  const agent = new LLMAgent({
+    llm: gemini,
+    tools: [sumNumbers, divideNumbers],
+  });
+
+  const response = await agent.chat({
+    message: "How much is 5 + 5? Then divide by 2",
+  });
+
+  console.log(response.message);
+}
+
+void main().then(() => {
+  console.log("Done");
+});
diff --git a/packages/llamaindex/src/agent/utils.ts b/packages/llamaindex/src/agent/utils.ts
index 8a3df6402..32ba34139 100644
--- a/packages/llamaindex/src/agent/utils.ts
+++ b/packages/llamaindex/src/agent/utils.ts
@@ -61,6 +61,7 @@ export async function stepToolsStreaming<Model extends LLM>({
   // check if first chunk has tool calls, if so, this is a function call
   // otherwise, it's a regular message
   const hasToolCall = !!(value.options && "toolCall" in value.options);
+
   enqueueOutput({
     taskStep: step,
     output: finalStream,
@@ -78,6 +79,14 @@
       });
     }
   }
+
+  // If there are tool calls that weren't read from the stream (used for Gemini)
+  if (!toolCalls.size && value.options && "toolCall" in value.options) {
+    value.options.toolCall.forEach((toolCall) => {
+      toolCalls.set(toolCall.id, toolCall);
+    });
+  }
+
   step.context.store.messages = [
     ...step.context.store.messages,
     {
diff --git a/packages/llamaindex/src/llm/gemini/base.ts b/packages/llamaindex/src/llm/gemini/base.ts
index 65491fd98..9d14e82c4 100644
--- a/packages/llamaindex/src/llm/gemini/base.ts
+++ b/packages/llamaindex/src/llm/gemini/base.ts
@@ -2,17 +2,20 @@ import {
   GoogleGenerativeAI,
   GenerativeModel as GoogleGenerativeModel,
   type EnhancedGenerateContentResponse,
+  type FunctionCall,
   type ModelParams as GoogleModelParams,
   type GenerateContentStreamResult as GoogleStreamGenerateContentResult,
 } from "@google/generative-ai";
-import { getEnv } from "@llamaindex/env";
+import { getEnv, randomUUID } from "@llamaindex/env";
 import { ToolCallLLM } from "../base.js";
 import type {
   CompletionResponse,
   LLMCompletionParamsNonStreaming,
   LLMCompletionParamsStreaming,
   LLMMetadata,
+  ToolCall,
+  ToolCallLLMMessageOptions,
 } from "../types.js";
 import {
streamConverter, wrapLLMEvent } from "../utils.js"; import { @@ -29,7 +32,12 @@ import { type GoogleGeminiSessionOptions, type IGeminiSession, } from "./types.js"; -import { GeminiHelper, getChatContext, getPartsText } from "./utils.js"; +import { + GeminiHelper, + getChatContext, + getPartsText, + mapBaseToolToGeminiFunctionDeclaration, +} from "./utils.js"; export const GEMINI_MODEL_INFO_MAP: Record<GEMINI_MODEL, GeminiModelInfo> = { [GEMINI_MODEL.GEMINI_PRO]: { contextWindow: 30720 }, @@ -86,13 +94,33 @@ export class GeminiSession implements IGeminiSession { return response.text(); } + getToolsFromResponse( + response: EnhancedGenerateContentResponse, + ): ToolCall[] | undefined { + return response.functionCalls()?.map( + (call: FunctionCall) => + ({ + name: call.name, + input: call.args, + id: randomUUID(), + }) as ToolCall, + ); + } + async *getChatStream( result: GoogleStreamGenerateContentResult, ): GeminiChatStreamResponse { - yield* streamConverter(result.stream, (response) => ({ - delta: this.getResponseText(response), - raw: response, - })); + yield* streamConverter(result.stream, (response) => { + const tools = this.getToolsFromResponse(response); + const options: ToolCallLLMMessageOptions = tools?.length + ? { toolCall: tools } + : {}; + return { + delta: this.getResponseText(response), + raw: response, + options, + }; + }); } getCompletionStream( @@ -188,10 +216,22 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> { const client = this.session.getGenerativeModel(this.metadata); const chat = client.startChat({ history: context.history, + tools: params.tools && [ + { + functionDeclarations: params.tools.map( + mapBaseToolToGeminiFunctionDeclaration, + ), + }, + ], }); const { response } = await chat.sendMessage(context.message); const topCandidate = response.candidates![0]; + const tools = this.session.getToolsFromResponse(response); + const options: ToolCallLLMMessageOptions = tools?.length + ? 
{ toolCall: tools } + : {}; + return { raw: response, message: { @@ -199,6 +239,7 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> { role: GeminiHelper.ROLES_FROM_GEMINI[ topCandidate.content.role as GeminiMessageRole ], + options, }, }; } @@ -210,6 +251,13 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> { const client = this.session.getGenerativeModel(this.metadata); const chat = client.startChat({ history: context.history, + tools: params.tools && [ + { + functionDeclarations: params.tools.map( + mapBaseToolToGeminiFunctionDeclaration, + ), + }, + ], }); const result = await chat.sendMessageStream(context.message); yield* this.session.getChatStream(result); @@ -241,13 +289,17 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> { if (stream) { const result = await client.generateContentStream( - getPartsText(GeminiHelper.messageContentToGeminiParts(prompt)), + getPartsText( + GeminiHelper.messageContentToGeminiParts({ content: prompt }), + ), ); return this.session.getCompletionStream(result); } const result = await client.generateContent( - getPartsText(GeminiHelper.messageContentToGeminiParts(prompt)), + getPartsText( + GeminiHelper.messageContentToGeminiParts({ content: prompt }), + ), ); return { text: this.session.getResponseText(result.response), diff --git a/packages/llamaindex/src/llm/gemini/types.ts b/packages/llamaindex/src/llm/gemini/types.ts index 998ec12ea..f602ee83b 100644 --- a/packages/llamaindex/src/llm/gemini/types.ts +++ b/packages/llamaindex/src/llm/gemini/types.ts @@ -3,6 +3,8 @@ import { type EnhancedGenerateContentResponse, type Content as GeminiMessageContent, type FileDataPart as GoogleFileDataPart, + type FunctionDeclaration as GoogleFunctionDeclaration, + type FunctionDeclarationSchema as GoogleFunctionDeclarationSchema, type InlineDataPart as GoogleInlineFileDataPart, type ModelParams as GoogleModelParams, type Part as GooglePart, @@ -14,6 +16,8 @@ import { GenerativeModelPreview as VertexGenerativeModelPreview, type GenerateContentResponse, type FileDataPart as VertexFileDataPart, + type FunctionDeclaration as VertexFunctionDeclaration, + type FunctionDeclarationSchema as VertexFunctionDeclarationSchema, type VertexInit, type InlineDataPart as VertexInlineFileDataPart, type ModelParams as VertexModelParams, @@ -27,6 +31,7 @@ import type { CompletionResponse, LLMChatParamsNonStreaming, LLMChatParamsStreaming, + ToolCall, ToolCallLLMMessageOptions, } from "../types.js"; @@ -69,6 +74,14 @@ export type InlineDataPart = export type ModelParams = GoogleModelParams | VertexModelParams; +export type FunctionDeclaration = + | VertexFunctionDeclaration + | GoogleFunctionDeclaration; + +export type FunctionDeclarationSchema = + | GoogleFunctionDeclarationSchema + | VertexFunctionDeclarationSchema; + export type GenerativeModel = | VertexGenerativeModelPreview | VertexGenerativeModel @@ -112,4 +125,7 @@ export interface IGeminiSession { | GoogleStreamGenerateContentResult | VertexStreamGenerateContentResult, ): GeminiChatStreamResponse; + getToolsFromResponse( + response: EnhancedGenerateContentResponse | GenerateContentResponse, + ): ToolCall[] | undefined; } diff --git a/packages/llamaindex/src/llm/gemini/utils.ts b/packages/llamaindex/src/llm/gemini/utils.ts index e20b5f065..fd423fff4 100644 --- a/packages/llamaindex/src/llm/gemini/utils.ts +++ b/packages/llamaindex/src/llm/gemini/utils.ts @@ -1,17 +1,23 @@ -import { type Content as GeminiMessageContent } from "@google/generative-ai"; +import { + type 
FunctionCall, + type Content as GeminiMessageContent, +} from "@google/generative-ai"; import { type GenerateContentResponse } from "@google-cloud/vertexai"; +import type { BaseTool } from "../../types.js"; import type { ChatMessage, - MessageContent, MessageContentImageDetail, MessageContentTextDetail, MessageType, + ToolCallLLMMessageOptions, } from "../types.js"; import { extractDataUrlComponents } from "../utils.js"; import type { ChatContext, FileDataPart, + FunctionDeclaration, + FunctionDeclarationSchema, GeminiChatParamsNonStreaming, GeminiChatParamsStreaming, GeminiMessageRole, @@ -104,7 +110,8 @@ export const cleanParts = ( part.text?.trim() || part.inlineData || part.fileData || - part.functionCall, + part.functionCall || + part.functionResponse, ), }; }; @@ -115,8 +122,21 @@ export const getChatContext = ( // Gemini doesn't allow: // 1. Consecutive messages from the same role // 2. Parts that have empty text + const fnMap = params.messages.reduce( + (result, message) => { + if (message.options && "toolCall" in message.options) + message.options.toolCall.forEach((call) => { + result[call.id] = call.name; + }); + + return result; + }, + {} as Record<string, string>, + ); const messages = GeminiHelper.mergeNeighboringSameRoleMessages( - params.messages.map(GeminiHelper.chatMessageToGemini), + params.messages.map((message) => + GeminiHelper.chatMessageToGemini(message, fnMap), + ), ).map(cleanParts); const history = messages.slice(0, -1); @@ -127,6 +147,23 @@ export const getChatContext = ( }; }; +export const mapBaseToolToGeminiFunctionDeclaration = ( + tool: BaseTool, +): FunctionDeclaration => { + const parameters: FunctionDeclarationSchema = { + type: tool.metadata.parameters?.type.toUpperCase(), + properties: tool.metadata.parameters?.properties, + description: tool.metadata.parameters?.description, + required: tool.metadata.parameters?.required, + }; + + return { + name: tool.metadata.name, + description: tool.metadata.description, + parameters, + }; +}; + /** * Helper class providing utility functions for Gemini */ @@ -177,7 +214,40 @@ export class GeminiHelper { ); } - public static messageContentToGeminiParts(content: MessageContent): Part[] { + public static messageContentToGeminiParts({ + content, + options = undefined, + fnMap = undefined, + }: Pick<ChatMessage<ToolCallLLMMessageOptions>, "content" | "options"> & { + fnMap?: Record<string, string>; + }): Part[] { + if (options && "toolResult" in options) { + if (!fnMap) throw Error("fnMap must be set"); + const name = fnMap[options.toolResult.id]; + if (!name) + throw Error( + `Could not find the name for fn call with id ${options.toolResult.id}`, + ); + + return [ + { + functionResponse: { + name, + response: { + result: options.toolResult.result, + }, + }, + }, + ]; + } + if (options && "toolCall" in options) { + return options.toolCall.map((call) => ({ + functionCall: { + name: call.name, + args: call.input, + } as FunctionCall, + })); + } if (typeof content === "string") { return [{ text: content }]; } @@ -197,11 +267,35 @@ export class GeminiHelper { } public static chatMessageToGemini( - message: ChatMessage, + message: ChatMessage<ToolCallLLMMessageOptions>, + fnMap: Record<string, string>, // mapping of fn call id to fn call name ): GeminiMessageContent { return { role: GeminiHelper.ROLES_TO_GEMINI[message.role], - parts: GeminiHelper.messageContentToGeminiParts(message.content), + parts: GeminiHelper.messageContentToGeminiParts({ ...message, fnMap }), }; } } + +/** + * Returns functionCall of first candidate. 
+ * Taken from https://github.com/google-gemini/generative-ai-js/ for use with
+ * @google-cloud/vertexai, as that library doesn't include this helper
+ */
+export function getFunctionCalls(
+  response: GenerateContentResponse,
+): FunctionCall[] | undefined {
+  const functionCalls: FunctionCall[] = [];
+  if (response.candidates?.[0].content?.parts) {
+    for (const part of response.candidates?.[0].content?.parts) {
+      if (part.functionCall) {
+        functionCalls.push(part.functionCall);
+      }
+    }
+  }
+  if (functionCalls.length > 0) {
+    return functionCalls;
+  } else {
+    return undefined;
+  }
+}
diff --git a/packages/llamaindex/src/llm/gemini/vertex.ts b/packages/llamaindex/src/llm/gemini/vertex.ts
index 43c7100da..a24e4546e 100644
--- a/packages/llamaindex/src/llm/gemini/vertex.ts
+++ b/packages/llamaindex/src/llm/gemini/vertex.ts
@@ -13,10 +13,15 @@ import type {
   VertexGeminiSessionOptions,
 } from "./types.js";
 
-import { getEnv } from "@llamaindex/env";
-import type { CompletionResponse } from "../types.js";
+import type { FunctionCall } from "@google/generative-ai";
+import { getEnv, randomUUID } from "@llamaindex/env";
+import type {
+  CompletionResponse,
+  ToolCall,
+  ToolCallLLMMessageOptions,
+} from "../types.js";
 import { streamConverter } from "../utils.js";
-import { getText } from "./utils.js";
+import { getFunctionCalls, getText } from "./utils.js";
 
 /* To use Google's Vertex AI backend, it doesn't use api key authentication.
 *
@@ -62,14 +67,35 @@ export class GeminiVertexSession implements IGeminiSession {
     return getText(response);
   }
 
+  getToolsFromResponse(
+    response: GenerateContentResponse,
+  ): ToolCall[] | undefined {
+    return getFunctionCalls(response)?.map(
+      (call: FunctionCall) =>
+        ({
+          name: call.name,
+          input: call.args,
+          id: randomUUID(),
+        }) as ToolCall,
+    );
+  }
+
   async *getChatStream(
     result: VertexStreamGenerateContentResult,
   ): GeminiChatStreamResponse {
-    yield* streamConverter(result.stream, (response) => ({
-      delta: this.getResponseText(response),
-      raw: response,
-    }));
+    yield* streamConverter(result.stream, (response) => {
+      const tools = this.getToolsFromResponse(response);
+      const options: ToolCallLLMMessageOptions = tools?.length
+        ? { toolCall: tools }
+        : {};
+      return {
+        delta: this.getResponseText(response),
+        raw: response,
+        options,
+      };
+    });
   }
+
   getCompletionStream(
     result: VertexStreamGenerateContentResult,
   ): AsyncIterable<CompletionResponse> {
-- 
GitLab
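Reviewer note: below is a minimal, hypothetical sketch (not part of the patch) of driving the new tool-calling surface directly, without `LLMAgent`. It reuses the `sumNumbers` tool from `examples/gemini/agent.ts` and assumes only what the hunks above add: `Gemini.chat` now accepts `tools`, and when the model requests a function call the response message carries `options.toolCall` as a list of `{ name, input, id }` entries.

```ts
import { FunctionTool, Gemini, GEMINI_MODEL } from "llamaindex";

// Same tool definition as examples/gemini/agent.ts above.
const sumNumbers = FunctionTool.from(
  ({ a, b }: { a: number; b: number }) => `${a + b}`,
  {
    name: "sumNumbers",
    description: "Use this function to sum two numbers",
    parameters: {
      type: "object",
      properties: {
        a: { type: "number", description: "The first number" },
        b: { type: "number", description: "The second number" },
      },
      required: ["a", "b"],
    },
  },
);

async function main() {
  const gemini = new Gemini({ model: GEMINI_MODEL.GEMINI_PRO });

  // The chat() override in base.ts maps `tools` to Gemini
  // functionDeclarations via mapBaseToolToGeminiFunctionDeclaration.
  const response = await gemini.chat({
    messages: [{ role: "user", content: "How much is 5 + 5?" }],
    tools: [sumNumbers],
  });

  // When Gemini decides to call a function, getToolsFromResponse()
  // attaches { toolCall: ToolCall[] } to message.options; each entry
  // carries the function name, its input args, and a generated id.
  const { options } = response.message;
  if (options && "toolCall" in options) {
    for (const call of options.toolCall) {
      console.log(`requested tool: ${call.name}`, call.input);
    }
  }
}

void main();
```

This mirrors what `LLMAgent` consumes internally: the agent/utils.ts hunk adds a fallback that collects `options.toolCall` from the message options when Gemini surfaces tool calls there instead of in the streamed chunks.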