From 7dce3d28d376a41d9893c73bdcc28340f3e234c2 Mon Sep 17 00:00:00 2001 From: Peron <peron_ap@icloud.com> Date: Tue, 2 Jul 2024 09:28:22 +0800 Subject: [PATCH] fix: disable External Filters for Gemini (#994) Co-authored-by: Alex Yang <himself65@outlook.com> --- .changeset/fluffy-knives-glow.md | 5 +++ packages/llamaindex/src/llm/gemini/base.ts | 11 ++++- packages/llamaindex/src/llm/gemini/utils.ts | 46 +++++++++++++++++--- packages/llamaindex/src/llm/gemini/vertex.ts | 16 +++++-- 4 files changed, 66 insertions(+), 12 deletions(-) create mode 100644 .changeset/fluffy-knives-glow.md diff --git a/.changeset/fluffy-knives-glow.md b/.changeset/fluffy-knives-glow.md new file mode 100644 index 000000000..c21cbd26f --- /dev/null +++ b/.changeset/fluffy-knives-glow.md @@ -0,0 +1,5 @@ +--- +"llamaindex": patch +--- + +fix: disable External Filters for Gemini diff --git a/packages/llamaindex/src/llm/gemini/base.ts b/packages/llamaindex/src/llm/gemini/base.ts index 9d14e82c4..458f631cf 100644 --- a/packages/llamaindex/src/llm/gemini/base.ts +++ b/packages/llamaindex/src/llm/gemini/base.ts @@ -33,6 +33,7 @@ import { type IGeminiSession, } from "./types.js"; import { + DEFAULT_SAFETY_SETTINGS, GeminiHelper, getChatContext, getPartsText, @@ -87,7 +88,10 @@ export class GeminiSession implements IGeminiSession { } getGenerativeModel(metadata: GoogleModelParams): GoogleGenerativeModel { - return this.gemini.getGenerativeModel(metadata); + return this.gemini.getGenerativeModel({ + safetySettings: DEFAULT_SAFETY_SETTINGS, + ...metadata, + }); } getResponseText(response: EnhancedGenerateContentResponse): string { @@ -143,8 +147,9 @@ export class GeminiSessionStore { }> = []; private static getSessionId(options: GeminiSessionOptions): string { - if (options.backend === GEMINI_BACKENDS.GOOGLE) + if (options.backend === GEMINI_BACKENDS.GOOGLE) { return options?.apiKey ?? ""; + } return ""; } private static sessionMatched( @@ -223,6 +228,7 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> { ), }, ], + safetySettings: DEFAULT_SAFETY_SETTINGS, }); const { response } = await chat.sendMessage(context.message); const topCandidate = response.candidates![0]; @@ -258,6 +264,7 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> { ), }, ], + safetySettings: DEFAULT_SAFETY_SETTINGS, }); const result = await chat.sendMessageStream(context.message); yield* this.session.getChatStream(result); diff --git a/packages/llamaindex/src/llm/gemini/utils.ts b/packages/llamaindex/src/llm/gemini/utils.ts index fd423fff4..fb75cbeab 100644 --- a/packages/llamaindex/src/llm/gemini/utils.ts +++ b/packages/llamaindex/src/llm/gemini/utils.ts @@ -1,6 +1,9 @@ import { type FunctionCall, type Content as GeminiMessageContent, + HarmBlockThreshold, + HarmCategory, + type SafetySetting, } from "@google/generative-ai"; import { type GenerateContentResponse } from "@google-cloud/vertexai"; @@ -53,10 +56,13 @@ const getImageParts = ( const { mimeType, base64: data } = extractDataUrlComponents( message.image_url.url, ); - if (!mimeType || !ACCEPTED_IMAGE_MIME_TYPES.includes(mimeType)) + if (!mimeType || !ACCEPTED_IMAGE_MIME_TYPES.includes(mimeType)) { throw new Error( - `Gemini only accepts the following mimeTypes: ${ACCEPTED_IMAGE_MIME_TYPES.join("\n")}`, + `Gemini only accepts the following mimeTypes: ${ACCEPTED_IMAGE_MIME_TYPES.join( + "\n", + )}`, ); + } return { inlineData: { mimeType, @@ -65,10 +71,13 @@ const getImageParts = ( }; } const mimeType = getFileURLMimeType(message.image_url.url); - if (!mimeType || !ACCEPTED_IMAGE_MIME_TYPES.includes(mimeType)) + if (!mimeType || !ACCEPTED_IMAGE_MIME_TYPES.includes(mimeType)) { throw new Error( - `Gemini only accepts the following mimeTypes: ${ACCEPTED_IMAGE_MIME_TYPES.join("\n")}`, + `Gemini only accepts the following mimeTypes: ${ACCEPTED_IMAGE_MIME_TYPES.join( + "\n", + )}`, ); + } return { fileData: { mimeType, fileUri: message.image_url.url }, }; @@ -124,10 +133,11 @@ export const getChatContext = ( // 2. Parts that have empty text const fnMap = params.messages.reduce( (result, message) => { - if (message.options && "toolCall" in message.options) + if (message.options && "toolCall" in message.options) { message.options.toolCall.forEach((call) => { result[call.id] = call.name; }); + } return result; }, @@ -224,10 +234,11 @@ export class GeminiHelper { if (options && "toolResult" in options) { if (!fnMap) throw Error("fnMap must be set"); const name = fnMap[options.toolResult.id]; - if (!name) + if (!name) { throw Error( `Could not find the name for fn call with id ${options.toolResult.id}`, ); + } return [ { @@ -299,3 +310,26 @@ export function getFunctionCalls( return undefined; } } + +/** + * Safety settings to disable external filters + * Documentation: https://ai.google.dev/gemini-api/docs/safety-settings + */ +export const DEFAULT_SAFETY_SETTINGS: SafetySetting[] = [ + { + category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold: HarmBlockThreshold.BLOCK_NONE, + }, + { + category: HarmCategory.HARM_CATEGORY_HARASSMENT, + threshold: HarmBlockThreshold.BLOCK_NONE, + }, + { + category: HarmCategory.HARM_CATEGORY_HATE_SPEECH, + threshold: HarmBlockThreshold.BLOCK_NONE, + }, + { + category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + threshold: HarmBlockThreshold.BLOCK_NONE, + }, +]; diff --git a/packages/llamaindex/src/llm/gemini/vertex.ts b/packages/llamaindex/src/llm/gemini/vertex.ts index a24e4546e..398a9b1e8 100644 --- a/packages/llamaindex/src/llm/gemini/vertex.ts +++ b/packages/llamaindex/src/llm/gemini/vertex.ts @@ -1,8 +1,8 @@ import { + type GenerateContentResponse, VertexAI, GenerativeModel as VertexGenerativeModel, GenerativeModelPreview as VertexGenerativeModelPreview, - type GenerateContentResponse, type ModelParams as VertexModelParams, type StreamGenerateContentResult as VertexStreamGenerateContentResult, } from "@google-cloud/vertexai"; @@ -21,7 +21,7 @@ import type { ToolCallLLMMessageOptions, } from "../types.js"; import { streamConverter } from "../utils.js"; -import { getFunctionCalls, getText } from "./utils.js"; +import { DEFAULT_SAFETY_SETTINGS, getFunctionCalls, getText } from "./utils.js"; /* To use Google's Vertex AI backend, it doesn't use api key authentication. * @@ -59,8 +59,16 @@ export class GeminiVertexSession implements IGeminiSession { getGenerativeModel( metadata: VertexModelParams, ): VertexGenerativeModelPreview | VertexGenerativeModel { - if (this.preview) return this.vertex.preview.getGenerativeModel(metadata); - return this.vertex.getGenerativeModel(metadata); + if (this.preview) { + return this.vertex.preview.getGenerativeModel({ + safetySettings: DEFAULT_SAFETY_SETTINGS, + ...metadata, + }); + } + return this.vertex.getGenerativeModel({ + safetySettings: DEFAULT_SAFETY_SETTINGS, + ...metadata, + }); } getResponseText(response: GenerateContentResponse): string { -- GitLab