From 2a0a899d6681bad8aa270dc9d406253438862311 Mon Sep 17 00:00:00 2001
From: ANKIT VARSHNEY <132201033+AVtheking@users.noreply.github.com>
Date: Tue, 18 Mar 2025 21:35:47 +0530
Subject: [PATCH] chore: added saftey setting as parameter for gemini (#1760)

---
 .changeset/heavy-ants-turn.md           |  5 +++++
 packages/providers/google/src/base.ts   | 13 +++++++++----
 packages/providers/google/src/vertex.ts |  5 +++--
 3 files changed, 17 insertions(+), 6 deletions(-)
 create mode 100644 .changeset/heavy-ants-turn.md

diff --git a/.changeset/heavy-ants-turn.md b/.changeset/heavy-ants-turn.md
new file mode 100644
index 000000000..63824c55b
--- /dev/null
+++ b/.changeset/heavy-ants-turn.md
@@ -0,0 +1,5 @@
+---
+"@llamaindex/google": patch
+---
+
+Added saftey setting parameter for gemini
diff --git a/packages/providers/google/src/base.ts b/packages/providers/google/src/base.ts
index ff67fb025..ea5eb4059 100644
--- a/packages/providers/google/src/base.ts
+++ b/packages/providers/google/src/base.ts
@@ -6,6 +6,7 @@ import {
   type ModelParams as GoogleModelParams,
   type RequestOptions as GoogleRequestOptions,
   type GenerateContentStreamResult as GoogleStreamGenerateContentResult,
+  type SafetySetting,
 } from "@google/generative-ai";
 
 import { wrapLLMEvent } from "@llamaindex/core/decorator";
@@ -88,6 +89,7 @@ const DEFAULT_GEMINI_PARAMS = {
 export type GeminiConfig = Partial<typeof DEFAULT_GEMINI_PARAMS> & {
   session?: IGeminiSession;
   requestOptions?: GoogleRequestOptions;
+  safetySettings?: SafetySetting[];
 };
 
 /**
@@ -112,7 +114,7 @@ export class GeminiSession implements IGeminiSession {
   ): GoogleGenerativeModel {
     return this.gemini.getGenerativeModel(
       {
-        safetySettings: DEFAULT_SAFETY_SETTINGS,
+        safetySettings: metadata.safetySettings ?? DEFAULT_SAFETY_SETTINGS,
         ...metadata,
       },
       requestOpts,
@@ -218,6 +220,7 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> {
   maxTokens?: number | undefined;
   #requestOptions?: GoogleRequestOptions | undefined;
   session: IGeminiSession;
+  safetySettings: SafetySetting[];
 
   constructor(init?: GeminiConfig) {
     super();
@@ -227,13 +230,14 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> {
     this.maxTokens = init?.maxTokens ?? undefined;
     this.session = init?.session ?? GeminiSessionStore.get();
     this.#requestOptions = init?.requestOptions ?? undefined;
+    this.safetySettings = init?.safetySettings ?? DEFAULT_SAFETY_SETTINGS;
   }
 
   get supportToolCall(): boolean {
     return SUPPORT_TOOL_CALL_MODELS.includes(this.model);
   }
 
-  get metadata(): LLMMetadata {
+  get metadata(): LLMMetadata & { safetySettings: SafetySetting[] } {
     return {
       model: this.model,
       temperature: this.temperature,
@@ -242,6 +246,7 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> {
       contextWindow: GEMINI_MODEL_INFO_MAP[this.model].contextWindow,
       tokenizer: undefined,
       structuredOutput: false,
+      safetySettings: this.safetySettings,
     };
   }
 
@@ -251,7 +256,7 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> {
     const context = getChatContext(params);
     const common = {
       history: context.history,
-      safetySettings: DEFAULT_SAFETY_SETTINGS,
+      safetySettings: this.safetySettings,
     };
 
     return params.tools?.length
@@ -265,7 +270,7 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> {
               ),
             },
           ],
-          safetySettings: DEFAULT_SAFETY_SETTINGS,
+          safetySettings: this.safetySettings,
         }
       : common;
   }
diff --git a/packages/providers/google/src/vertex.ts b/packages/providers/google/src/vertex.ts
index a8906c0cb..13c7ffd99 100644
--- a/packages/providers/google/src/vertex.ts
+++ b/packages/providers/google/src/vertex.ts
@@ -59,14 +59,15 @@ export class GeminiVertexSession implements IGeminiSession {
   getGenerativeModel(
     metadata: VertexModelParams,
   ): VertexGenerativeModelPreview | VertexGenerativeModel {
+    const safetySettings = metadata.safetySettings ?? DEFAULT_SAFETY_SETTINGS;
     if (this.preview) {
       return this.vertex.preview.getGenerativeModel({
-        safetySettings: DEFAULT_SAFETY_SETTINGS,
+        safetySettings,
         ...metadata,
       });
     }
     return this.vertex.getGenerativeModel({
-      safetySettings: DEFAULT_SAFETY_SETTINGS,
+      safetySettings,
       ...metadata,
     });
   }
-- 
GitLab