From 130b7992a1f27c192c1b561dccb9b00947382354 Mon Sep 17 00:00:00 2001
From: Marcus Schiesser <mail@marcusschiesser.de>
Date: Tue, 30 Apr 2024 11:51:22 +0800
Subject: [PATCH] refactor: clean gemini embedding (#781)

---
 examples/{gemini.ts => gemini/embedding.ts}   |  4 ++--
 .../core/src/embeddings/GeminiEmbedding.ts    | 24 +++++++------------
 packages/core/src/llm/gemini.ts               |  6 -----
 3 files changed, 11 insertions(+), 23 deletions(-)
 rename examples/{gemini.ts => gemini/embedding.ts} (76%)

diff --git a/examples/gemini.ts b/examples/gemini/embedding.ts
similarity index 76%
rename from examples/gemini.ts
rename to examples/gemini/embedding.ts
index 416a586f9..6ecc2692b 100644
--- a/examples/gemini.ts
+++ b/examples/gemini/embedding.ts
@@ -1,11 +1,11 @@
-import { GEMINI_MODEL, GeminiEmbedding } from "llamaindex";
+import { GEMINI_EMBEDDING_MODEL, GeminiEmbedding } from "llamaindex";
 
 async function main() {
   if (!process.env.GOOGLE_API_KEY) {
     throw new Error("Please set the GOOGLE_API_KEY environment variable.");
   }
   const embedModel = new GeminiEmbedding({
-    model: GEMINI_MODEL.GEMINI_PRO,
+    model: GEMINI_EMBEDDING_MODEL.EMBEDDING_001,
   });
   const texts = ["hello", "world"];
   const embeddings = await embedModel.getTextEmbeddingsBatch(texts);
diff --git a/packages/core/src/embeddings/GeminiEmbedding.ts b/packages/core/src/embeddings/GeminiEmbedding.ts
index fa5eb0cb5..1b4fda6bc 100644
--- a/packages/core/src/embeddings/GeminiEmbedding.ts
+++ b/packages/core/src/embeddings/GeminiEmbedding.ts
@@ -1,27 +1,21 @@
-import {
-  GEMINI_MODEL,
-  GeminiSessionStore,
-  type GeminiConfig,
-  type GeminiSession,
-} from "../llm/gemini.js";
+import { GeminiSessionStore, type GeminiSession } from "../llm/gemini.js";
 import { BaseEmbedding } from "./types.js";
 
+export enum GEMINI_EMBEDDING_MODEL {
+  EMBEDDING_001 = "embedding-001",
+  TEXT_EMBEDDING_004 = "text-embedding-004",
+}
+
 /**
  * GeminiEmbedding is an alias for Gemini that implements the BaseEmbedding interface.
  */
 export class GeminiEmbedding extends BaseEmbedding {
-  model: GEMINI_MODEL;
-  temperature: number;
-  topP: number;
-  maxTokens?: number;
+  model: GEMINI_EMBEDDING_MODEL;
   session: GeminiSession;
 
-  constructor(init?: GeminiConfig) {
+  constructor(init?: Partial<GeminiEmbedding>) {
     super();
-    this.model = init?.model ?? GEMINI_MODEL.GEMINI_PRO;
-    this.temperature = init?.temperature ?? 0.1;
-    this.topP = init?.topP ?? 1;
-    this.maxTokens = init?.maxTokens ?? undefined;
+    this.model = init?.model ?? GEMINI_EMBEDDING_MODEL.EMBEDDING_001;
     this.session = init?.session ?? GeminiSessionStore.get();
   }
 
diff --git a/packages/core/src/llm/gemini.ts b/packages/core/src/llm/gemini.ts
index 7e498c0fc..413e4e2a3 100644
--- a/packages/core/src/llm/gemini.ts
+++ b/packages/core/src/llm/gemini.ts
@@ -32,8 +32,6 @@ type GeminiSessionOptions = {
 export enum GEMINI_MODEL {
   GEMINI_PRO = "gemini-pro",
   GEMINI_PRO_VISION = "gemini-pro-vision",
-  EMBEDDING_001 = "embedding-001",
-  AQA = "aqa",
   GEMINI_PRO_LATEST = "gemini-1.5-pro-latest",
 }
 
@@ -44,16 +42,12 @@ export interface GeminiModelInfo {
 export const GEMINI_MODEL_INFO_MAP: Record<GEMINI_MODEL, GeminiModelInfo> = {
   [GEMINI_MODEL.GEMINI_PRO]: { contextWindow: 30720 },
   [GEMINI_MODEL.GEMINI_PRO_VISION]: { contextWindow: 12288 },
-  [GEMINI_MODEL.EMBEDDING_001]: { contextWindow: 2048 },
-  [GEMINI_MODEL.AQA]: { contextWindow: 7168 },
   [GEMINI_MODEL.GEMINI_PRO_LATEST]: { contextWindow: 10 ** 6 },
 };
 
 const SUPPORT_TOOL_CALL_MODELS: GEMINI_MODEL[] = [
   GEMINI_MODEL.GEMINI_PRO,
   GEMINI_MODEL.GEMINI_PRO_VISION,
-  GEMINI_MODEL.EMBEDDING_001,
-  GEMINI_MODEL.AQA,
 ];
 
 const DEFAULT_GEMINI_PARAMS = {
-- 
GitLab