diff --git a/.changeset/heavy-ants-turn.md b/.changeset/heavy-ants-turn.md new file mode 100644 index 0000000000000000000000000000000000000000..63824c55b5b13038526ee808ae7ea2424edda436 --- /dev/null +++ b/.changeset/heavy-ants-turn.md @@ -0,0 +1,5 @@ +--- +"@llamaindex/google": patch +--- + +Added saftey setting parameter for gemini diff --git a/packages/providers/google/src/base.ts b/packages/providers/google/src/base.ts index ff67fb0251bb1b71d739d9a3851d65e19cb649d4..ea5eb40599307f5ec4cf3429cefde0799d39f261 100644 --- a/packages/providers/google/src/base.ts +++ b/packages/providers/google/src/base.ts @@ -6,6 +6,7 @@ import { type ModelParams as GoogleModelParams, type RequestOptions as GoogleRequestOptions, type GenerateContentStreamResult as GoogleStreamGenerateContentResult, + type SafetySetting, } from "@google/generative-ai"; import { wrapLLMEvent } from "@llamaindex/core/decorator"; @@ -88,6 +89,7 @@ const DEFAULT_GEMINI_PARAMS = { export type GeminiConfig = Partial<typeof DEFAULT_GEMINI_PARAMS> & { session?: IGeminiSession; requestOptions?: GoogleRequestOptions; + safetySettings?: SafetySetting[]; }; /** @@ -112,7 +114,7 @@ export class GeminiSession implements IGeminiSession { ): GoogleGenerativeModel { return this.gemini.getGenerativeModel( { - safetySettings: DEFAULT_SAFETY_SETTINGS, + safetySettings: metadata.safetySettings ?? DEFAULT_SAFETY_SETTINGS, ...metadata, }, requestOpts, @@ -218,6 +220,7 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> { maxTokens?: number | undefined; #requestOptions?: GoogleRequestOptions | undefined; session: IGeminiSession; + safetySettings: SafetySetting[]; constructor(init?: GeminiConfig) { super(); @@ -227,13 +230,14 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> { this.maxTokens = init?.maxTokens ?? undefined; this.session = init?.session ?? GeminiSessionStore.get(); this.#requestOptions = init?.requestOptions ?? undefined; + this.safetySettings = init?.safetySettings ?? DEFAULT_SAFETY_SETTINGS; } get supportToolCall(): boolean { return SUPPORT_TOOL_CALL_MODELS.includes(this.model); } - get metadata(): LLMMetadata { + get metadata(): LLMMetadata & { safetySettings: SafetySetting[] } { return { model: this.model, temperature: this.temperature, @@ -242,6 +246,7 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> { contextWindow: GEMINI_MODEL_INFO_MAP[this.model].contextWindow, tokenizer: undefined, structuredOutput: false, + safetySettings: this.safetySettings, }; } @@ -251,7 +256,7 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> { const context = getChatContext(params); const common = { history: context.history, - safetySettings: DEFAULT_SAFETY_SETTINGS, + safetySettings: this.safetySettings, }; return params.tools?.length @@ -265,7 +270,7 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> { ), }, ], - safetySettings: DEFAULT_SAFETY_SETTINGS, + safetySettings: this.safetySettings, } : common; } diff --git a/packages/providers/google/src/vertex.ts b/packages/providers/google/src/vertex.ts index a8906c0cb9e63fa76a3fab4d2bb44e5275827388..13c7ffd99add939742cf299fe97f662bd492b184 100644 --- a/packages/providers/google/src/vertex.ts +++ b/packages/providers/google/src/vertex.ts @@ -59,14 +59,15 @@ export class GeminiVertexSession implements IGeminiSession { getGenerativeModel( metadata: VertexModelParams, ): VertexGenerativeModelPreview | VertexGenerativeModel { + const safetySettings = metadata.safetySettings ?? DEFAULT_SAFETY_SETTINGS; if (this.preview) { return this.vertex.preview.getGenerativeModel({ - safetySettings: DEFAULT_SAFETY_SETTINGS, + safetySettings, ...metadata, }); } return this.vertex.getGenerativeModel({ - safetySettings: DEFAULT_SAFETY_SETTINGS, + safetySettings, ...metadata, }); }