From 21af81085aeb049750942ac5f3b84775cb461693 Mon Sep 17 00:00:00 2001
From: Timothy Carambat <rambat1010@gmail.com>
Date: Mon, 13 Jan 2025 13:12:03 -0800
Subject: [PATCH] Add caching to Gemini /models (#2969)

Also renames defaultModals.js to defaultModels.js to fix a filename typo.
---
 server/storage/models/.gitignore              |   3 +-
 .../{defaultModals.js => defaultModels.js}    |   0
 server/utils/AiProviders/gemini/index.js      | 107 ++++++++++++++++--
 3 files changed, 99 insertions(+), 11 deletions(-)
 rename server/utils/AiProviders/gemini/{defaultModals.js => defaultModels.js} (100%)

diff --git a/server/storage/models/.gitignore b/server/storage/models/.gitignore
index 7a8f66d8f..71c5e891a 100644
--- a/server/storage/models/.gitignore
+++ b/server/storage/models/.gitignore
@@ -4,4 +4,5 @@ downloaded/*
 openrouter
 apipie
 novita
-mixedbread-ai*
\ No newline at end of file
+mixedbread-ai*
+gemini
\ No newline at end of file
diff --git a/server/utils/AiProviders/gemini/defaultModals.js b/server/utils/AiProviders/gemini/defaultModels.js
similarity index 100%
rename from server/utils/AiProviders/gemini/defaultModals.js
rename to server/utils/AiProviders/gemini/defaultModels.js
diff --git a/server/utils/AiProviders/gemini/index.js b/server/utils/AiProviders/gemini/index.js
index 3554a51fc..9961c70d7 100644
--- a/server/utils/AiProviders/gemini/index.js
+++ b/server/utils/AiProviders/gemini/index.js
@@ -1,3 +1,5 @@
+const fs = require("fs");
+const path = require("path");
 const { NativeEmbedder } = require("../../EmbeddingEngines/native");
 const {
   LLMPerformanceMonitor,
@@ -7,7 +9,13 @@ const {
   clientAbortedHandler,
 } = require("../../helpers/chat/responses");
 const { MODEL_MAP } = require("../modelMap");
-const { defaultGeminiModels, v1BetaModels } = require("./defaultModals");
+const { defaultGeminiModels, v1BetaModels } = require("./defaultModels");
+const { safeJsonParse } = require("../../http");
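+// On-disk location where the fetched Gemini /models response (models.json)
+// and its fetch timestamp (.cached_at) are cached.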
+const cacheFolder = path.resolve(
+  process.env.STORAGE_DIR
+    ? path.resolve(process.env.STORAGE_DIR, "models", "gemini")
+    : path.resolve(__dirname, `../../../storage/models/gemini`)
+);
 
 class GeminiLLM {
   constructor(embedder = null, modelPreference = null) {
@@ -44,13 +52,33 @@ class GeminiLLM {
     this.embedder = embedder ?? new NativeEmbedder();
     this.defaultTemp = 0.7; // not used for Gemini
     this.safetyThreshold = this.#fetchSafetyThreshold();
-    this.#log(`Initialized with model: ${this.model}`);
+
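+    // Ensure the cache directory exists and record the paths of the cached
+    // model list (models.json) and fetch timestamp (.cached_at).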
+    if (!fs.existsSync(cacheFolder))
+      fs.mkdirSync(cacheFolder, { recursive: true });
+    this.cacheModelPath = path.resolve(cacheFolder, "models.json");
+    this.cacheAtPath = path.resolve(cacheFolder, ".cached_at");
+    this.#log(
+      `Initialized with model: ${this.model} (${this.promptWindowLimit()})`
+    );
   }
 
   #log(text, ...args) {
     console.log(`\x1b[32m[GeminiLLM]\x1b[0m ${text}`, ...args);
   }
 
+  // Checks whether the timestamp stored in the .cached_at file is more than
+  // one week (in milliseconds) old. If it is, the models API is refetched so
+  // that the cached model list stays up to date.
+  static cacheIsStale() {
+    const MAX_STALE = 6.048e8; // 1 Week in MS
+    if (!fs.existsSync(path.resolve(cacheFolder, ".cached_at"))) return true;
+    const now = Number(new Date());
+    const timestampMs = Number(
+      fs.readFileSync(path.resolve(cacheFolder, ".cached_at"))
+    );
+    return now - timestampMs > MAX_STALE;
+  }
+
   #appendContext(contextTexts = []) {
     if (!contextTexts || !contextTexts.length) return "";
     return (
@@ -103,11 +131,40 @@ class GeminiLLM {
   }
 
   static promptWindowLimit(modelName) {
-    return MODEL_MAP.gemini[modelName] ?? 30_720;
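+    // Prefer the context window reported by the cached /models response;
+    // fall back to the static MODEL_MAP (or 30,720 tokens) when unavailable.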
+    try {
+      const cacheModelPath = path.resolve(cacheFolder, "models.json");
+      if (!fs.existsSync(cacheModelPath))
+        return MODEL_MAP.gemini[modelName] ?? 30_720;
+
+      const models = safeJsonParse(fs.readFileSync(cacheModelPath));
+      const model = models.find((model) => model.id === modelName);
+      if (!model)
+        throw new Error(
+          "Model not found in cache - falling back to default context window."
+        );
+      return model.contextWindow;
+    } catch (e) {
+      console.error(`GeminiLLM:promptWindowLimit`, e.message);
+      return MODEL_MAP.gemini[modelName] ?? 30_720;
+    }
   }
 
   promptWindowLimit() {
-    return MODEL_MAP.gemini[this.model] ?? 30_720;
+    try {
+      if (!fs.existsSync(this.cacheModelPath))
+        return MODEL_MAP.gemini[this.model] ?? 30_720;
+
+      const models = safeJsonParse(fs.readFileSync(this.cacheModelPath));
+      const model = models.find((model) => model.id === this.model);
+      if (!model)
+        throw new Error(
+          "Model not found in cache - falling back to default context window."
+        );
+      return model.contextWindow;
+    } catch (e) {
+      console.error(`GeminiLLM:promptWindowLimit`, e.message);
+      return MODEL_MAP.gemini[this.model] ?? 30_720;
+    }
   }
 
   /**
@@ -118,14 +175,25 @@ class GeminiLLM {
    * @returns {Promise<[{id: string, name: string, contextWindow: number, experimental: boolean}]>} A promise that resolves to an array of Gemini models
    */
   static async fetchModels(apiKey, limit = 1_000, pageToken = null) {
+    if (!apiKey) return [];
+    if (fs.existsSync(cacheFolder) && !this.cacheIsStale()) {
+      console.log(
+        `\x1b[32m[GeminiLLM]\x1b[0m Using cached models API response.`
+      );
+      return safeJsonParse(
+        fs.readFileSync(path.resolve(cacheFolder, "models.json"))
+      );
+    }
+
     const url = new URL(
       "https://generativelanguage.googleapis.com/v1beta/models"
     );
     url.searchParams.set("pageSize", limit);
     url.searchParams.set("key", apiKey);
     if (pageToken) url.searchParams.set("pageToken", pageToken);
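+    // Track whether the live API call succeeded so the fallback model list is
+    // never written to the cache.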
+    let success = false;
 
-    return fetch(url.toString(), {
+    const models = await fetch(url.toString(), {
       method: "GET",
       headers: { "Content-Type": "application/json" },
     })
@@ -134,8 +202,9 @@ class GeminiLLM {
         if (data.error) throw new Error(data.error.message);
         return data.models ?? [];
       })
-      .then((models) =>
-        models
+      .then((models) => {
+        success = true;
+        return models
           .filter(
             (model) => !model.displayName.toLowerCase().includes("tuning")
           )
@@ -149,12 +218,30 @@ class GeminiLLM {
               contextWindow: model.inputTokenLimit,
               experimental: model.name.includes("exp"),
             };
-          })
-      )
+          });
+      })
       .catch((e) => {
         console.error(`Gemini:getGeminiModels`, e.message);
+        success = false;
         return defaultGeminiModels;
       });
+
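+    // Persist the successful response and a timestamp so subsequent calls can
+    // reuse it until the cache goes stale (see cacheIsStale above).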
+    if (success) {
+      console.log(
+        `\x1b[32m[GeminiLLM]\x1b[0m Writing cached models API response to disk.`
+      );
+      if (!fs.existsSync(cacheFolder))
+        fs.mkdirSync(cacheFolder, { recursive: true });
+      fs.writeFileSync(
+        path.resolve(cacheFolder, "models.json"),
+        JSON.stringify(models)
+      );
+      fs.writeFileSync(
+        path.resolve(cacheFolder, ".cached_at"),
+        new Date().getTime().toString()
+      );
+    }
+    return models;
   }
 
   /**
@@ -164,7 +251,7 @@ class GeminiLLM {
    * @returns {Promise<boolean>} A promise that resolves to a boolean indicating if the model is valid
    */
   async isValidChatCompletionModel(modelName = "") {
-    const models = await this.fetchModels(true);
+    const models = await GeminiLLM.fetchModels(process.env.GEMINI_API_KEY);
     return models.some((model) => model.id === modelName);
   }
   /**
-- 
GitLab