Skip to content
Snippets Groups Projects
Unverified Commit 2a0a899d authored by ANKIT VARSHNEY's avatar ANKIT VARSHNEY Committed by GitHub
Browse files

chore: added saftey setting as parameter for gemini (#1760)

parent 050cd534
No related branches found
No related tags found
No related merge requests found
---
"@llamaindex/google": patch
---
Added saftey setting parameter for gemini
...@@ -6,6 +6,7 @@ import { ...@@ -6,6 +6,7 @@ import {
type ModelParams as GoogleModelParams, type ModelParams as GoogleModelParams,
type RequestOptions as GoogleRequestOptions, type RequestOptions as GoogleRequestOptions,
type GenerateContentStreamResult as GoogleStreamGenerateContentResult, type GenerateContentStreamResult as GoogleStreamGenerateContentResult,
type SafetySetting,
} from "@google/generative-ai"; } from "@google/generative-ai";
import { wrapLLMEvent } from "@llamaindex/core/decorator"; import { wrapLLMEvent } from "@llamaindex/core/decorator";
...@@ -88,6 +89,7 @@ const DEFAULT_GEMINI_PARAMS = { ...@@ -88,6 +89,7 @@ const DEFAULT_GEMINI_PARAMS = {
export type GeminiConfig = Partial<typeof DEFAULT_GEMINI_PARAMS> & { export type GeminiConfig = Partial<typeof DEFAULT_GEMINI_PARAMS> & {
session?: IGeminiSession; session?: IGeminiSession;
requestOptions?: GoogleRequestOptions; requestOptions?: GoogleRequestOptions;
safetySettings?: SafetySetting[];
}; };
/** /**
...@@ -112,7 +114,7 @@ export class GeminiSession implements IGeminiSession { ...@@ -112,7 +114,7 @@ export class GeminiSession implements IGeminiSession {
): GoogleGenerativeModel { ): GoogleGenerativeModel {
return this.gemini.getGenerativeModel( return this.gemini.getGenerativeModel(
{ {
safetySettings: DEFAULT_SAFETY_SETTINGS, safetySettings: metadata.safetySettings ?? DEFAULT_SAFETY_SETTINGS,
...metadata, ...metadata,
}, },
requestOpts, requestOpts,
...@@ -218,6 +220,7 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> { ...@@ -218,6 +220,7 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> {
maxTokens?: number | undefined; maxTokens?: number | undefined;
#requestOptions?: GoogleRequestOptions | undefined; #requestOptions?: GoogleRequestOptions | undefined;
session: IGeminiSession; session: IGeminiSession;
safetySettings: SafetySetting[];
constructor(init?: GeminiConfig) { constructor(init?: GeminiConfig) {
super(); super();
...@@ -227,13 +230,14 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> { ...@@ -227,13 +230,14 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> {
this.maxTokens = init?.maxTokens ?? undefined; this.maxTokens = init?.maxTokens ?? undefined;
this.session = init?.session ?? GeminiSessionStore.get(); this.session = init?.session ?? GeminiSessionStore.get();
this.#requestOptions = init?.requestOptions ?? undefined; this.#requestOptions = init?.requestOptions ?? undefined;
this.safetySettings = init?.safetySettings ?? DEFAULT_SAFETY_SETTINGS;
} }
get supportToolCall(): boolean { get supportToolCall(): boolean {
return SUPPORT_TOOL_CALL_MODELS.includes(this.model); return SUPPORT_TOOL_CALL_MODELS.includes(this.model);
} }
get metadata(): LLMMetadata { get metadata(): LLMMetadata & { safetySettings: SafetySetting[] } {
return { return {
model: this.model, model: this.model,
temperature: this.temperature, temperature: this.temperature,
...@@ -242,6 +246,7 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> { ...@@ -242,6 +246,7 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> {
contextWindow: GEMINI_MODEL_INFO_MAP[this.model].contextWindow, contextWindow: GEMINI_MODEL_INFO_MAP[this.model].contextWindow,
tokenizer: undefined, tokenizer: undefined,
structuredOutput: false, structuredOutput: false,
safetySettings: this.safetySettings,
}; };
} }
...@@ -251,7 +256,7 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> { ...@@ -251,7 +256,7 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> {
const context = getChatContext(params); const context = getChatContext(params);
const common = { const common = {
history: context.history, history: context.history,
safetySettings: DEFAULT_SAFETY_SETTINGS, safetySettings: this.safetySettings,
}; };
return params.tools?.length return params.tools?.length
...@@ -265,7 +270,7 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> { ...@@ -265,7 +270,7 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> {
), ),
}, },
], ],
safetySettings: DEFAULT_SAFETY_SETTINGS, safetySettings: this.safetySettings,
} }
: common; : common;
} }
......
...@@ -59,14 +59,15 @@ export class GeminiVertexSession implements IGeminiSession { ...@@ -59,14 +59,15 @@ export class GeminiVertexSession implements IGeminiSession {
getGenerativeModel( getGenerativeModel(
metadata: VertexModelParams, metadata: VertexModelParams,
): VertexGenerativeModelPreview | VertexGenerativeModel { ): VertexGenerativeModelPreview | VertexGenerativeModel {
const safetySettings = metadata.safetySettings ?? DEFAULT_SAFETY_SETTINGS;
if (this.preview) { if (this.preview) {
return this.vertex.preview.getGenerativeModel({ return this.vertex.preview.getGenerativeModel({
safetySettings: DEFAULT_SAFETY_SETTINGS, safetySettings,
...metadata, ...metadata,
}); });
} }
return this.vertex.getGenerativeModel({ return this.vertex.getGenerativeModel({
safetySettings: DEFAULT_SAFETY_SETTINGS, safetySettings,
...metadata, ...metadata,
}); });
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment