Unverified commit 98eebf72 authored by Jack Qian, committed by GitHub

feat: add request options for gemini (#1733)

parent 5478ba88
---
"@llamaindex/google": patch
"@llamaindex/doc": patch
---
Add RequestOptions parameter passing to support Gemini proxy calls.
Add a usage example for the RequestOptions parameter.
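For context, the new `requestOptions` value is forwarded verbatim to the second parameter of the Google SDK's `getGenerativeModel`, which is what makes calling through a proxy possible. A minimal sketch of that underlying call (the API-key source and proxy URL are placeholders):

```ts
import { GoogleGenerativeAI } from "@google/generative-ai";

// The pass-through added in this commit ultimately resolves to this SDK call:
// getGenerativeModel accepts a RequestOptions object as its second argument.
const genAI = new GoogleGenerativeAI(process.env.GOOGLE_API_KEY!); // placeholder key source
const model = genAI.getGenerativeModel(
  { model: "gemini-pro" },
  { baseUrl: "<YOUR_PROXY_URL>" }, // requests go to this base URL instead of the default endpoint
);
```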
@@ -31,6 +31,20 @@ Settings.llm = new Gemini({
});
```
## Usage with Proxy
```ts
import { Gemini, GEMINI_MODEL } from "@llamaindex/google";
import { Settings } from "llamaindex";
Settings.llm = new Gemini({
  model: GEMINI_MODEL.GEMINI_PRO,
  requestOptions: {
    baseUrl: "<YOUR_PROXY_URL>", // optional, but useful for custom endpoints
  },
});
```
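Other fields of the upstream `RequestOptions` type can be passed through in the same way. Continuing the snippet above, a sketch assuming the `timeout` and `customHeaders` fields from `@google/generative-ai` (the values are placeholders):
```ts
Settings.llm = new Gemini({
  model: GEMINI_MODEL.GEMINI_PRO,
  requestOptions: {
    baseUrl: "<YOUR_PROXY_URL>",
    timeout: 30_000, // request timeout in milliseconds
    customHeaders: { "x-proxy-token": "<YOUR_PROXY_TOKEN>" }, // hypothetical auth header for the proxy
  },
});
```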
### Usage with Vertex AI
To use Gemini via Vertex AI you can use `GeminiVertexSession`.
......
@@ -4,6 +4,7 @@ import {
type EnhancedGenerateContentResponse,
type FunctionCall,
type ModelParams as GoogleModelParams,
type RequestOptions as GoogleRequestOptions,
type GenerateContentStreamResult as GoogleStreamGenerateContentResult,
} from "@google/generative-ai";
@@ -86,6 +87,7 @@ const DEFAULT_GEMINI_PARAMS = {
export type GeminiConfig = Partial<typeof DEFAULT_GEMINI_PARAMS> & {
session?: IGeminiSession;
requestOptions?: GoogleRequestOptions;
};
/**
@@ -104,11 +106,17 @@ export class GeminiSession implements IGeminiSession {
this.gemini = new GoogleGenerativeAI(options.apiKey);
}
- getGenerativeModel(metadata: GoogleModelParams): GoogleGenerativeModel {
- return this.gemini.getGenerativeModel({
- safetySettings: DEFAULT_SAFETY_SETTINGS,
- ...metadata,
- });
+ getGenerativeModel(
+ metadata: GoogleModelParams,
+ requestOpts?: GoogleRequestOptions,
+ ): GoogleGenerativeModel {
+ return this.gemini.getGenerativeModel(
+ {
+ safetySettings: DEFAULT_SAFETY_SETTINGS,
+ ...metadata,
+ },
+ requestOpts,
+ );
}
getResponseText(response: EnhancedGenerateContentResponse): string {
@@ -173,6 +181,9 @@ export class GeminiSessionStore {
o1: GeminiSessionOptions,
o2: GeminiSessionOptions,
): boolean {
// #TODO: check if the session is matched
// Q: should we check the requestOptions?
// A: wait for confirmation from author
return (
GeminiSessionStore.getSessionId(o1) ===
GeminiSessionStore.getSessionId(o2)
@@ -205,6 +216,7 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> {
temperature: number;
topP: number;
maxTokens?: number | undefined;
#requestOptions?: GoogleRequestOptions | undefined;
session: IGeminiSession;
constructor(init?: GeminiConfig) {
@@ -214,6 +226,7 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> {
this.topP = init?.topP ?? 1;
this.maxTokens = init?.maxTokens ?? undefined;
this.session = init?.session ?? GeminiSessionStore.get();
this.#requestOptions = init?.requestOptions ?? undefined;
}
get supportToolCall(): boolean {
@@ -260,7 +273,10 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> {
params: GeminiChatParamsNonStreaming,
): Promise<GeminiChatNonStreamResponse> {
const context = getChatContext(params);
- const client = this.session.getGenerativeModel(this.metadata);
+ const client = this.session.getGenerativeModel(
+ this.metadata,
+ this.#requestOptions,
+ );
const chat = client.startChat(this.createStartChatParams(params));
const { response } = await chat.sendMessage(context.message);
const topCandidate = response.candidates![0]!;
@@ -286,7 +302,10 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> {
params: GeminiChatParamsStreaming,
): GeminiChatStreamResponse {
const context = getChatContext(params);
- const client = this.session.getGenerativeModel(this.metadata);
+ const client = this.session.getGenerativeModel(
+ this.metadata,
+ this.#requestOptions,
+ );
const chat = client.startChat(this.createStartChatParams(params));
const result = await chat.sendMessageStream(context.message);
yield* this.session.getChatStream(result);
@@ -314,8 +333,10 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> {
params: LLMCompletionParamsStreaming | LLMCompletionParamsNonStreaming,
): Promise<CompletionResponse | AsyncIterable<CompletionResponse>> {
const { prompt, stream } = params;
- const client = this.session.getGenerativeModel(this.metadata);
+ const client = this.session.getGenerativeModel(
+ this.metadata,
+ this.#requestOptions,
+ );
if (stream) {
const result = await client.generateContentStream(
getPartsText(
......
@@ -8,6 +8,7 @@ import {
type InlineDataPart as GoogleInlineFileDataPart,
type ModelParams as GoogleModelParams,
type Part as GooglePart,
type RequestOptions as GoogleRequestOptions,
type GenerateContentStreamResult as GoogleStreamGenerateContentResult,
} from "@google/generative-ai";
@@ -35,6 +36,8 @@ import type {
ToolCallLLMMessageOptions,
} from "@llamaindex/core/llms";
export { type RequestOptions as GoogleRequestOptions } from "@google/generative-ai";
export enum GEMINI_BACKENDS {
GOOGLE = "google",
VERTEX = "vertex",
@@ -121,7 +124,10 @@ export type GeminiChatNonStreamResponse =
ChatResponse<ToolCallLLMMessageOptions>;
export interface IGeminiSession {
- getGenerativeModel(metadata: ModelParams): GenerativeModel;
+ getGenerativeModel(
+ metadata: ModelParams,
+ requestOptions?: GoogleRequestOptions,
+ ): GenerativeModel;
getResponseText(
response: EnhancedGenerateContentResponse | GenerateContentResponse,
): string;
......