From 7dce3d28d376a41d9893c73bdcc28340f3e234c2 Mon Sep 17 00:00:00 2001
From: Peron <peron_ap@icloud.com>
Date: Tue, 2 Jul 2024 09:28:22 +0800
Subject: [PATCH] fix: disable External Filters for Gemini (#994)

Co-authored-by: Alex Yang <himself65@outlook.com>
---
 .changeset/fluffy-knives-glow.md             |  5 +++
 packages/llamaindex/src/llm/gemini/base.ts   | 11 ++++-
 packages/llamaindex/src/llm/gemini/utils.ts  | 46 +++++++++++++++++---
 packages/llamaindex/src/llm/gemini/vertex.ts | 16 +++++--
 4 files changed, 66 insertions(+), 12 deletions(-)
 create mode 100644 .changeset/fluffy-knives-glow.md

diff --git a/.changeset/fluffy-knives-glow.md b/.changeset/fluffy-knives-glow.md
new file mode 100644
index 000000000..c21cbd26f
--- /dev/null
+++ b/.changeset/fluffy-knives-glow.md
@@ -0,0 +1,5 @@
+---
+"llamaindex": patch
+---
+
+fix: disable External Filters for Gemini
diff --git a/packages/llamaindex/src/llm/gemini/base.ts b/packages/llamaindex/src/llm/gemini/base.ts
index 9d14e82c4..458f631cf 100644
--- a/packages/llamaindex/src/llm/gemini/base.ts
+++ b/packages/llamaindex/src/llm/gemini/base.ts
@@ -33,6 +33,7 @@ import {
   type IGeminiSession,
 } from "./types.js";
 import {
+  DEFAULT_SAFETY_SETTINGS,
   GeminiHelper,
   getChatContext,
   getPartsText,
@@ -87,7 +88,10 @@ export class GeminiSession implements IGeminiSession {
   }
 
   getGenerativeModel(metadata: GoogleModelParams): GoogleGenerativeModel {
-    return this.gemini.getGenerativeModel(metadata);
+    return this.gemini.getGenerativeModel({
+      safetySettings: DEFAULT_SAFETY_SETTINGS,
+      ...metadata,
+    });
   }
 
   getResponseText(response: EnhancedGenerateContentResponse): string {
@@ -143,8 +147,9 @@ export class GeminiSessionStore {
   }> = [];
 
   private static getSessionId(options: GeminiSessionOptions): string {
-    if (options.backend === GEMINI_BACKENDS.GOOGLE)
+    if (options.backend === GEMINI_BACKENDS.GOOGLE) {
       return options?.apiKey ?? "";
+    }
     return "";
   }
   private static sessionMatched(
@@ -223,6 +228,7 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> {
           ),
         },
       ],
+      safetySettings: DEFAULT_SAFETY_SETTINGS,
     });
     const { response } = await chat.sendMessage(context.message);
     const topCandidate = response.candidates![0];
@@ -258,6 +264,7 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> {
           ),
         },
       ],
+      safetySettings: DEFAULT_SAFETY_SETTINGS,
     });
     const result = await chat.sendMessageStream(context.message);
     yield* this.session.getChatStream(result);
diff --git a/packages/llamaindex/src/llm/gemini/utils.ts b/packages/llamaindex/src/llm/gemini/utils.ts
index fd423fff4..fb75cbeab 100644
--- a/packages/llamaindex/src/llm/gemini/utils.ts
+++ b/packages/llamaindex/src/llm/gemini/utils.ts
@@ -1,6 +1,9 @@
 import {
   type FunctionCall,
   type Content as GeminiMessageContent,
+  HarmBlockThreshold,
+  HarmCategory,
+  type SafetySetting,
 } from "@google/generative-ai";
 
 import { type GenerateContentResponse } from "@google-cloud/vertexai";
@@ -53,10 +56,13 @@ const getImageParts = (
     const { mimeType, base64: data } = extractDataUrlComponents(
       message.image_url.url,
     );
-    if (!mimeType || !ACCEPTED_IMAGE_MIME_TYPES.includes(mimeType))
+    if (!mimeType || !ACCEPTED_IMAGE_MIME_TYPES.includes(mimeType)) {
       throw new Error(
-        `Gemini only accepts the following mimeTypes: ${ACCEPTED_IMAGE_MIME_TYPES.join("\n")}`,
+        `Gemini only accepts the following mimeTypes: ${ACCEPTED_IMAGE_MIME_TYPES.join(
+          "\n",
+        )}`,
       );
+    }
     return {
       inlineData: {
         mimeType,
@@ -65,10 +71,13 @@ const getImageParts = (
     };
   }
   const mimeType = getFileURLMimeType(message.image_url.url);
-  if (!mimeType || !ACCEPTED_IMAGE_MIME_TYPES.includes(mimeType))
+  if (!mimeType || !ACCEPTED_IMAGE_MIME_TYPES.includes(mimeType)) {
     throw new Error(
-      `Gemini only accepts the following mimeTypes: ${ACCEPTED_IMAGE_MIME_TYPES.join("\n")}`,
+      `Gemini only accepts the following mimeTypes: ${ACCEPTED_IMAGE_MIME_TYPES.join(
+        "\n",
+      )}`,
     );
+  }
   return {
     fileData: { mimeType, fileUri: message.image_url.url },
   };
@@ -124,10 +133,11 @@ export const getChatContext = (
   // 2. Parts that have empty text
   const fnMap = params.messages.reduce(
     (result, message) => {
-      if (message.options && "toolCall" in message.options)
+      if (message.options && "toolCall" in message.options) {
         message.options.toolCall.forEach((call) => {
           result[call.id] = call.name;
         });
+      }
 
       return result;
     },
@@ -224,10 +234,11 @@ export class GeminiHelper {
     if (options && "toolResult" in options) {
       if (!fnMap) throw Error("fnMap must be set");
       const name = fnMap[options.toolResult.id];
-      if (!name)
+      if (!name) {
         throw Error(
           `Could not find the name for fn call with id ${options.toolResult.id}`,
         );
+      }
 
       return [
         {
@@ -299,3 +310,26 @@ export function getFunctionCalls(
     return undefined;
   }
 }
+
+/**
+ * Safety settings to disable external filters
+ * Documentation: https://ai.google.dev/gemini-api/docs/safety-settings
+ */
+export const DEFAULT_SAFETY_SETTINGS: SafetySetting[] = [
+  {
+    category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+    threshold: HarmBlockThreshold.BLOCK_NONE,
+  },
+  {
+    category: HarmCategory.HARM_CATEGORY_HARASSMENT,
+    threshold: HarmBlockThreshold.BLOCK_NONE,
+  },
+  {
+    category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
+    threshold: HarmBlockThreshold.BLOCK_NONE,
+  },
+  {
+    category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
+    threshold: HarmBlockThreshold.BLOCK_NONE,
+  },
+];
diff --git a/packages/llamaindex/src/llm/gemini/vertex.ts b/packages/llamaindex/src/llm/gemini/vertex.ts
index a24e4546e..398a9b1e8 100644
--- a/packages/llamaindex/src/llm/gemini/vertex.ts
+++ b/packages/llamaindex/src/llm/gemini/vertex.ts
@@ -1,8 +1,8 @@
 import {
+  type GenerateContentResponse,
   VertexAI,
   GenerativeModel as VertexGenerativeModel,
   GenerativeModelPreview as VertexGenerativeModelPreview,
-  type GenerateContentResponse,
   type ModelParams as VertexModelParams,
   type StreamGenerateContentResult as VertexStreamGenerateContentResult,
 } from "@google-cloud/vertexai";
@@ -21,7 +21,7 @@ import type {
   ToolCallLLMMessageOptions,
 } from "../types.js";
 import { streamConverter } from "../utils.js";
-import { getFunctionCalls, getText } from "./utils.js";
+import { DEFAULT_SAFETY_SETTINGS, getFunctionCalls, getText } from "./utils.js";
 
 /* To use Google's Vertex AI backend, it doesn't use api key authentication.
  *
@@ -59,8 +59,16 @@ export class GeminiVertexSession implements IGeminiSession {
   getGenerativeModel(
     metadata: VertexModelParams,
   ): VertexGenerativeModelPreview | VertexGenerativeModel {
-    if (this.preview) return this.vertex.preview.getGenerativeModel(metadata);
-    return this.vertex.getGenerativeModel(metadata);
+    if (this.preview) {
+      return this.vertex.preview.getGenerativeModel({
+        safetySettings: DEFAULT_SAFETY_SETTINGS,
+        ...metadata,
+      });
+    }
+    return this.vertex.getGenerativeModel({
+      safetySettings: DEFAULT_SAFETY_SETTINGS,
+      ...metadata,
+    });
   }
 
   getResponseText(response: GenerateContentResponse): string {
-- 
GitLab