From 130b7992a1f27c192c1b561dccb9b00947382354 Mon Sep 17 00:00:00 2001 From: Marcus Schiesser <mail@marcusschiesser.de> Date: Tue, 30 Apr 2024 11:51:22 +0800 Subject: [PATCH] refactor: clean gemini embedding (#781) --- examples/{gemini.ts => gemini/embedding.ts} | 4 ++-- .../core/src/embeddings/GeminiEmbedding.ts | 24 +++++++------------ packages/core/src/llm/gemini.ts | 6 ----- 3 files changed, 11 insertions(+), 23 deletions(-) rename examples/{gemini.ts => gemini/embedding.ts} (76%) diff --git a/examples/gemini.ts b/examples/gemini/embedding.ts similarity index 76% rename from examples/gemini.ts rename to examples/gemini/embedding.ts index 416a586f9..6ecc2692b 100644 --- a/examples/gemini.ts +++ b/examples/gemini/embedding.ts @@ -1,11 +1,11 @@ -import { GEMINI_MODEL, GeminiEmbedding } from "llamaindex"; +import { GEMINI_EMBEDDING_MODEL, GeminiEmbedding } from "llamaindex"; async function main() { if (!process.env.GOOGLE_API_KEY) { throw new Error("Please set the GOOGLE_API_KEY environment variable."); } const embedModel = new GeminiEmbedding({ - model: GEMINI_MODEL.GEMINI_PRO, + model: GEMINI_EMBEDDING_MODEL.EMBEDDING_001, }); const texts = ["hello", "world"]; const embeddings = await embedModel.getTextEmbeddingsBatch(texts); diff --git a/packages/core/src/embeddings/GeminiEmbedding.ts b/packages/core/src/embeddings/GeminiEmbedding.ts index fa5eb0cb5..1b4fda6bc 100644 --- a/packages/core/src/embeddings/GeminiEmbedding.ts +++ b/packages/core/src/embeddings/GeminiEmbedding.ts @@ -1,27 +1,21 @@ -import { - GEMINI_MODEL, - GeminiSessionStore, - type GeminiConfig, - type GeminiSession, -} from "../llm/gemini.js"; +import { GeminiSessionStore, type GeminiSession } from "../llm/gemini.js"; import { BaseEmbedding } from "./types.js"; +export enum GEMINI_EMBEDDING_MODEL { + EMBEDDING_001 = "embedding-001", + TEXT_EMBEDDING_004 = "text-embedding-004", +} + /** * GeminiEmbedding is an alias for Gemini that implements the BaseEmbedding interface. */ export class GeminiEmbedding extends BaseEmbedding { - model: GEMINI_MODEL; - temperature: number; - topP: number; - maxTokens?: number; + model: GEMINI_EMBEDDING_MODEL; session: GeminiSession; - constructor(init?: GeminiConfig) { + constructor(init?: Partial<GeminiEmbedding>) { super(); - this.model = init?.model ?? GEMINI_MODEL.GEMINI_PRO; - this.temperature = init?.temperature ?? 0.1; - this.topP = init?.topP ?? 1; - this.maxTokens = init?.maxTokens ?? undefined; + this.model = init?.model ?? GEMINI_EMBEDDING_MODEL.EMBEDDING_001; this.session = init?.session ?? GeminiSessionStore.get(); } diff --git a/packages/core/src/llm/gemini.ts b/packages/core/src/llm/gemini.ts index 7e498c0fc..413e4e2a3 100644 --- a/packages/core/src/llm/gemini.ts +++ b/packages/core/src/llm/gemini.ts @@ -32,8 +32,6 @@ type GeminiSessionOptions = { export enum GEMINI_MODEL { GEMINI_PRO = "gemini-pro", GEMINI_PRO_VISION = "gemini-pro-vision", - EMBEDDING_001 = "embedding-001", - AQA = "aqa", GEMINI_PRO_LATEST = "gemini-1.5-pro-latest", } @@ -44,16 +42,12 @@ export interface GeminiModelInfo { export const GEMINI_MODEL_INFO_MAP: Record<GEMINI_MODEL, GeminiModelInfo> = { [GEMINI_MODEL.GEMINI_PRO]: { contextWindow: 30720 }, [GEMINI_MODEL.GEMINI_PRO_VISION]: { contextWindow: 12288 }, - [GEMINI_MODEL.EMBEDDING_001]: { contextWindow: 2048 }, - [GEMINI_MODEL.AQA]: { contextWindow: 7168 }, [GEMINI_MODEL.GEMINI_PRO_LATEST]: { contextWindow: 10 ** 6 }, }; const SUPPORT_TOOL_CALL_MODELS: GEMINI_MODEL[] = [ GEMINI_MODEL.GEMINI_PRO, GEMINI_MODEL.GEMINI_PRO_VISION, - GEMINI_MODEL.EMBEDDING_001, - GEMINI_MODEL.AQA, ]; const DEFAULT_GEMINI_PARAMS = { -- GitLab