From 2a0a899d6681bad8aa270dc9d406253438862311 Mon Sep 17 00:00:00 2001 From: ANKIT VARSHNEY <132201033+AVtheking@users.noreply.github.com> Date: Tue, 18 Mar 2025 21:35:47 +0530 Subject: [PATCH] chore: added saftey setting as parameter for gemini (#1760) --- .changeset/heavy-ants-turn.md | 5 +++++ packages/providers/google/src/base.ts | 13 +++++++++---- packages/providers/google/src/vertex.ts | 5 +++-- 3 files changed, 17 insertions(+), 6 deletions(-) create mode 100644 .changeset/heavy-ants-turn.md diff --git a/.changeset/heavy-ants-turn.md b/.changeset/heavy-ants-turn.md new file mode 100644 index 000000000..63824c55b --- /dev/null +++ b/.changeset/heavy-ants-turn.md @@ -0,0 +1,5 @@ +--- +"@llamaindex/google": patch +--- + +Added saftey setting parameter for gemini diff --git a/packages/providers/google/src/base.ts b/packages/providers/google/src/base.ts index ff67fb025..ea5eb4059 100644 --- a/packages/providers/google/src/base.ts +++ b/packages/providers/google/src/base.ts @@ -6,6 +6,7 @@ import { type ModelParams as GoogleModelParams, type RequestOptions as GoogleRequestOptions, type GenerateContentStreamResult as GoogleStreamGenerateContentResult, + type SafetySetting, } from "@google/generative-ai"; import { wrapLLMEvent } from "@llamaindex/core/decorator"; @@ -88,6 +89,7 @@ const DEFAULT_GEMINI_PARAMS = { export type GeminiConfig = Partial<typeof DEFAULT_GEMINI_PARAMS> & { session?: IGeminiSession; requestOptions?: GoogleRequestOptions; + safetySettings?: SafetySetting[]; }; /** @@ -112,7 +114,7 @@ export class GeminiSession implements IGeminiSession { ): GoogleGenerativeModel { return this.gemini.getGenerativeModel( { - safetySettings: DEFAULT_SAFETY_SETTINGS, + safetySettings: metadata.safetySettings ?? DEFAULT_SAFETY_SETTINGS, ...metadata, }, requestOpts, @@ -218,6 +220,7 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> { maxTokens?: number | undefined; #requestOptions?: GoogleRequestOptions | undefined; session: IGeminiSession; + safetySettings: SafetySetting[]; constructor(init?: GeminiConfig) { super(); @@ -227,13 +230,14 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> { this.maxTokens = init?.maxTokens ?? undefined; this.session = init?.session ?? GeminiSessionStore.get(); this.#requestOptions = init?.requestOptions ?? undefined; + this.safetySettings = init?.safetySettings ?? DEFAULT_SAFETY_SETTINGS; } get supportToolCall(): boolean { return SUPPORT_TOOL_CALL_MODELS.includes(this.model); } - get metadata(): LLMMetadata { + get metadata(): LLMMetadata & { safetySettings: SafetySetting[] } { return { model: this.model, temperature: this.temperature, @@ -242,6 +246,7 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> { contextWindow: GEMINI_MODEL_INFO_MAP[this.model].contextWindow, tokenizer: undefined, structuredOutput: false, + safetySettings: this.safetySettings, }; } @@ -251,7 +256,7 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> { const context = getChatContext(params); const common = { history: context.history, - safetySettings: DEFAULT_SAFETY_SETTINGS, + safetySettings: this.safetySettings, }; return params.tools?.length @@ -265,7 +270,7 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> { ), }, ], - safetySettings: DEFAULT_SAFETY_SETTINGS, + safetySettings: this.safetySettings, } : common; } diff --git a/packages/providers/google/src/vertex.ts b/packages/providers/google/src/vertex.ts index a8906c0cb..13c7ffd99 100644 --- a/packages/providers/google/src/vertex.ts +++ b/packages/providers/google/src/vertex.ts @@ -59,14 +59,15 @@ export class GeminiVertexSession implements IGeminiSession { getGenerativeModel( metadata: VertexModelParams, ): VertexGenerativeModelPreview | VertexGenerativeModel { + const safetySettings = metadata.safetySettings ?? DEFAULT_SAFETY_SETTINGS; if (this.preview) { return this.vertex.preview.getGenerativeModel({ - safetySettings: DEFAULT_SAFETY_SETTINGS, + safetySettings, ...metadata, }); } return this.vertex.getGenerativeModel({ - safetySettings: DEFAULT_SAFETY_SETTINGS, + safetySettings, ...metadata, }); } -- GitLab