From 21af81085aeb049750942ac5f3b84775cb461693 Mon Sep 17 00:00:00 2001
From: Timothy Carambat <rambat1010@gmail.com>
Date: Mon, 13 Jan 2025 13:12:03 -0800
Subject: [PATCH] Add caching to Gemini /models (#2969)

rename file typo
---
 server/storage/models/.gitignore                 |   3 +-
 .../{defaultModals.js => defaultModels.js}       |   0
 server/utils/AiProviders/gemini/index.js         | 107 ++++++++++++++++--
 3 files changed, 99 insertions(+), 11 deletions(-)
 rename server/utils/AiProviders/gemini/{defaultModals.js => defaultModels.js} (100%)

diff --git a/server/storage/models/.gitignore b/server/storage/models/.gitignore
index 7a8f66d8f..71c5e891a 100644
--- a/server/storage/models/.gitignore
+++ b/server/storage/models/.gitignore
@@ -4,4 +4,5 @@ downloaded/*
 openrouter
 apipie
 novita
-mixedbread-ai*
\ No newline at end of file
+mixedbread-ai*
+gemini
\ No newline at end of file
diff --git a/server/utils/AiProviders/gemini/defaultModals.js b/server/utils/AiProviders/gemini/defaultModels.js
similarity index 100%
rename from server/utils/AiProviders/gemini/defaultModals.js
rename to server/utils/AiProviders/gemini/defaultModels.js
diff --git a/server/utils/AiProviders/gemini/index.js b/server/utils/AiProviders/gemini/index.js
index 3554a51fc..9961c70d7 100644
--- a/server/utils/AiProviders/gemini/index.js
+++ b/server/utils/AiProviders/gemini/index.js
@@ -1,3 +1,5 @@
+const fs = require("fs");
+const path = require("path");
 const { NativeEmbedder } = require("../../EmbeddingEngines/native");
 const {
   LLMPerformanceMonitor,
@@ -7,7 +9,13 @@ const {
   clientAbortedHandler,
 } = require("../../helpers/chat/responses");
 const { MODEL_MAP } = require("../modelMap");
-const { defaultGeminiModels, v1BetaModels } = require("./defaultModals");
+const { defaultGeminiModels, v1BetaModels } = require("./defaultModels");
+const { safeJsonParse } = require("../../http");
+const cacheFolder = path.resolve(
+  process.env.STORAGE_DIR
+    ? path.resolve(process.env.STORAGE_DIR, "models", "gemini")
+    : path.resolve(__dirname, `../../../storage/models/gemini`)
+);
 
 class GeminiLLM {
   constructor(embedder = null, modelPreference = null) {
@@ -44,13 +52,33 @@ class GeminiLLM {
     this.embedder = embedder ?? new NativeEmbedder();
     this.defaultTemp = 0.7; // not used for Gemini
     this.safetyThreshold = this.#fetchSafetyThreshold();
-    this.#log(`Initialized with model: ${this.model}`);
+
+    if (!fs.existsSync(cacheFolder))
+      fs.mkdirSync(cacheFolder, { recursive: true });
+    this.cacheModelPath = path.resolve(cacheFolder, "models.json");
+    this.cacheAtPath = path.resolve(cacheFolder, ".cached_at");
+    this.#log(
+      `Initialized with model: ${this.model} (${this.promptWindowLimit()})`
+    );
   }
 
   #log(text, ...args) {
     console.log(`\x1b[32m[GeminiLLM]\x1b[0m ${text}`, ...args);
   }
 
+  // This checks if the .cached_at file has a timestamp that is more than 1 week (in ms)
+  // from the current date. If it is, then we will refetch from the API so that all the
+  // models are up to date.
+  static cacheIsStale() {
+    const MAX_STALE = 6.048e8; // 1 week in ms
+    if (!fs.existsSync(path.resolve(cacheFolder, ".cached_at"))) return true;
+    const now = Number(new Date());
+    const timestampMs = Number(
+      fs.readFileSync(path.resolve(cacheFolder, ".cached_at"))
+    );
+    return now - timestampMs > MAX_STALE;
+  }
+
   #appendContext(contextTexts = []) {
     if (!contextTexts || !contextTexts.length) return "";
     return (
@@ -103,11 +131,40 @@ class GeminiLLM {
   }
 
   static promptWindowLimit(modelName) {
-    return MODEL_MAP.gemini[modelName] ?? 30_720;
+    try {
+      const cacheModelPath = path.resolve(cacheFolder, "models.json");
+      if (!fs.existsSync(cacheModelPath))
+        return MODEL_MAP.gemini[modelName] ?? 30_720;
+
+      const models = safeJsonParse(fs.readFileSync(cacheModelPath));
+      const model = models.find((model) => model.id === modelName);
+      if (!model)
+        throw new Error(
+          "Model not found in cache - falling back to default model."
+        );
+      return model.contextWindow;
+    } catch (e) {
+      console.error(`GeminiLLM:promptWindowLimit`, e.message);
+      return MODEL_MAP.gemini[modelName] ?? 30_720;
+    }
   }
 
   promptWindowLimit() {
-    return MODEL_MAP.gemini[this.model] ?? 30_720;
+    try {
+      if (!fs.existsSync(this.cacheModelPath))
+        return MODEL_MAP.gemini[this.model] ?? 30_720;
+
+      const models = safeJsonParse(fs.readFileSync(this.cacheModelPath));
+      const model = models.find((model) => model.id === this.model);
+      if (!model)
+        throw new Error(
+          "Model not found in cache - falling back to default model."
+        );
+      return model.contextWindow;
+    } catch (e) {
+      console.error(`GeminiLLM:promptWindowLimit`, e.message);
+      return MODEL_MAP.gemini[this.model] ?? 30_720;
+    }
   }
 
   /**
    * Fetches Gemini models from the Google Generative AI API
    * @param {string} apiKey - The API key to use for the request
    * @param {number} limit - The maximum number of models to fetch
    * @param {string} pageToken - The page token to use for pagination
    * @returns {Promise<[{id: string, name: string, contextWindow: number, experimental: boolean}]>} A promise that resolves to an array of Gemini models
    */
   static async fetchModels(apiKey, limit = 1_000, pageToken = null) {
+    if (!apiKey) return [];
+    if (fs.existsSync(cacheFolder) && !this.cacheIsStale()) {
+      console.log(
+        `\x1b[32m[GeminiLLM]\x1b[0m Using cached models API response.`
+      );
+      return safeJsonParse(
+        fs.readFileSync(path.resolve(cacheFolder, "models.json"))
+      );
+    }
+
     const url = new URL(
       "https://generativelanguage.googleapis.com/v1beta/models"
     );
     url.searchParams.set("pageSize", limit);
     url.searchParams.set("key", apiKey);
     if (pageToken) url.searchParams.set("pageToken", pageToken);
+    let success = false;
 
-    return fetch(url.toString(), {
+    const models = await fetch(url.toString(), {
       method: "GET",
       headers: { "Content-Type": "application/json" },
     })
       .then((res) => res.json())
       .then((data) => {
         if (data.error) throw new Error(data.error.message);
         return data.models ?? [];
       })
-      .then((models) =>
-        models
+      .then((models) => {
+        success = true;
+        return models
           .filter(
             (model) => !model.displayName.toLowerCase().includes("tuning")
           )
           .map((model) => {
             return {
               id: model.name.split("/").pop(),
               name: model.displayName,
               contextWindow: model.inputTokenLimit,
               experimental: model.name.includes("exp"),
             };
-          })
-      )
+          });
+      })
       .catch((e) => {
         console.error(`Gemini:getGeminiModels`, e.message);
+        success = false;
         return defaultGeminiModels;
       });
+
+    if (success) {
+      console.log(
+        `\x1b[32m[GeminiLLM]\x1b[0m Writing cached models API response to disk.`
+      );
+      if (!fs.existsSync(cacheFolder))
+        fs.mkdirSync(cacheFolder, { recursive: true });
+      fs.writeFileSync(
+        path.resolve(cacheFolder, "models.json"),
+        JSON.stringify(models)
+      );
+      fs.writeFileSync(
+        path.resolve(cacheFolder, ".cached_at"),
+        new Date().getTime().toString()
+      );
+    }
+    return models;
   }
 
   /**
    * Checks if a model is valid for chat completion (unused)
    * @deprecated
    * @param {string} modelName - The name of the model to check
    * @returns {Promise<boolean>} A promise that resolves to a boolean indicating if the model is valid
    */
   async isValidChatCompletionModel(modelName = "") {
-    const models = await this.fetchModels(true);
+    const models = await this.fetchModels(process.env.GEMINI_API_KEY);
     return models.some((model) => model.id === modelName);
   }
   /**
-- 
GitLab
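
The patch boils down to a small stale-file cache: fetchModels() serves storage/models/gemini/models.json as long as a sibling .cached_at timestamp is less than one week old, and only hits the Google API (rewriting both files) otherwise. Below is a minimal sketch of that pattern outside the class, assuming the same folder layout and one-week constant; fetchFreshModels is a hypothetical stand-in for the real Gemini API call and is not part of the patch.

    const fs = require("fs");
    const path = require("path");

    const CACHE_DIR = path.resolve(__dirname, "storage", "models", "gemini");
    const MAX_STALE_MS = 6.048e8; // 1 week in ms, same constant as the patch

    // Stale when the timestamp file is missing or older than one week.
    function cacheIsStale() {
      const stampPath = path.resolve(CACHE_DIR, ".cached_at");
      if (!fs.existsSync(stampPath)) return true;
      const cachedAt = Number(fs.readFileSync(stampPath, "utf8"));
      return Date.now() - cachedAt > MAX_STALE_MS;
    }

    // Serve models.json while fresh; otherwise refetch and rewrite both files.
    async function getModels(fetchFreshModels) {
      const modelsPath = path.resolve(CACHE_DIR, "models.json");
      if (fs.existsSync(modelsPath) && !cacheIsStale())
        return JSON.parse(fs.readFileSync(modelsPath, "utf8"));

      const models = await fetchFreshModels();
      fs.mkdirSync(CACHE_DIR, { recursive: true });
      fs.writeFileSync(modelsPath, JSON.stringify(models));
      fs.writeFileSync(
        path.resolve(CACHE_DIR, ".cached_at"),
        Date.now().toString()
      );
      return models;
    }

    // Example usage with a stubbed fetcher:
    getModels(async () => [{ id: "gemini-model", contextWindow: 30_720 }])
      .then((models) => console.log(models));

One detail the sketch omits: the patch threads a success flag through the fetch chain, so a failed API call falls back to defaultGeminiModels without writing the fallback list into the cache.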