Unverified commit 21af8108 authored by Timothy Carambat, committed by GitHub

Add caching to Gemini /models (#2969)

rename file typo
parent 665e8e5b
@@ -4,4 +4,5 @@ downloaded/*
 openrouter
 apipie
 novita
-mixedbread-ai*
\ No newline at end of file
+mixedbread-ai*
+gemini
\ No newline at end of file
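The new gemini entry keeps the on-disk model cache introduced below (storage/models/gemini/ when STORAGE_DIR is unset) out of version control, alongside the existing per-provider caches like openrouter, apipie, and novita.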
@@ -1,3 +1,5 @@
+const fs = require("fs");
+const path = require("path");
 const { NativeEmbedder } = require("../../EmbeddingEngines/native");
 const {
   LLMPerformanceMonitor,
@@ -7,7 +9,13 @@ const {
   clientAbortedHandler,
 } = require("../../helpers/chat/responses");
 const { MODEL_MAP } = require("../modelMap");
-const { defaultGeminiModels, v1BetaModels } = require("./defaultModals");
+const { defaultGeminiModels, v1BetaModels } = require("./defaultModels");
+const { safeJsonParse } = require("../../http");
+const cacheFolder = path.resolve(
+  process.env.STORAGE_DIR
+    ? path.resolve(process.env.STORAGE_DIR, "models", "gemini")
+    : path.resolve(__dirname, `../../../storage/models/gemini`)
+);
 
 class GeminiLLM {
   constructor(embedder = null, modelPreference = null) {
@@ -44,13 +52,33 @@ class GeminiLLM {
     this.embedder = embedder ?? new NativeEmbedder();
     this.defaultTemp = 0.7; // not used for Gemini
     this.safetyThreshold = this.#fetchSafetyThreshold();
-    this.#log(`Initialized with model: ${this.model}`);
+
+    if (!fs.existsSync(cacheFolder))
+      fs.mkdirSync(cacheFolder, { recursive: true });
+    this.cacheModelPath = path.resolve(cacheFolder, "models.json");
+    this.cacheAtPath = path.resolve(cacheFolder, ".cached_at");
+    this.#log(
+      `Initialized with model: ${this.model} (${this.promptWindowLimit()})`
+    );
   }
 
   #log(text, ...args) {
     console.log(`\x1b[32m[GeminiLLM]\x1b[0m ${text}`, ...args);
   }
 
+  // This checks if the .cached_at file has a timestamp that is more than 1 week (in millis)
+  // from the current date. If it is, then we will refetch the API so that all the models are up
+  // to date.
+  static cacheIsStale() {
+    const MAX_STALE = 6.048e8; // 1 week in MS
+    if (!fs.existsSync(path.resolve(cacheFolder, ".cached_at"))) return true;
+    const now = Number(new Date());
+    const timestampMs = Number(
+      fs.readFileSync(path.resolve(cacheFolder, ".cached_at"))
+    );
+    return now - timestampMs > MAX_STALE;
+  }
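For reference, the one-week constant checks out, and the .cached_at format follows from how it is written later in this commit; a minimal sketch (the sample timestamp is illustrative, not real output):

// 7 days * 24 h * 3600 s * 1000 ms = 604,800,000 ms = 6.048e8
const ONE_WEEK_MS = 7 * 24 * 60 * 60 * 1000;
console.log(ONE_WEEK_MS === 6.048e8); // true
// .cached_at holds new Date().getTime().toString(), e.g. "1735689600000",
// so Number(fs.readFileSync(path)) recovers the epoch milliseconds directly.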
 
   #appendContext(contextTexts = []) {
     if (!contextTexts || !contextTexts.length) return "";
     return (
@@ -103,11 +131,40 @@ class GeminiLLM {
   }
 
   static promptWindowLimit(modelName) {
-    return MODEL_MAP.gemini[modelName] ?? 30_720;
+    try {
+      const cacheModelPath = path.resolve(cacheFolder, "models.json");
+      if (!fs.existsSync(cacheModelPath))
+        return MODEL_MAP.gemini[modelName] ?? 30_720;
+      const models = safeJsonParse(fs.readFileSync(cacheModelPath));
+      const model = models.find((model) => model.id === modelName);
+      if (!model)
+        throw new Error(
+          "Model not found in cache - falling back to default model."
+        );
+      return model.contextWindow;
+    } catch (e) {
+      console.error(`GeminiLLM:promptWindowLimit`, e.message);
+      return MODEL_MAP.gemini[modelName] ?? 30_720;
+    }
   }
 
   promptWindowLimit() {
-    return MODEL_MAP.gemini[this.model] ?? 30_720;
+    try {
+      if (!fs.existsSync(this.cacheModelPath))
+        return MODEL_MAP.gemini[this.model] ?? 30_720;
+      const models = safeJsonParse(fs.readFileSync(this.cacheModelPath));
+      const model = models.find((model) => model.id === this.model);
+      if (!model)
+        throw new Error(
+          "Model not found in cache - falling back to default model."
+        );
+      return model.contextWindow;
+    } catch (e) {
+      console.error(`GeminiLLM:promptWindowLimit`, e.message);
+      return MODEL_MAP.gemini[this.model] ?? 30_720;
+    }
   }
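Both variants now resolve the context window in the same order: the cached models.json first, then the static MODEL_MAP, then the hard 30,720-token floor. A hedged usage sketch (the model id and window size are assumptions about what the cache might contain):

// Suppose models.json holds { "id": "gemini-1.5-pro", "contextWindow": 2097152, ... }:
GeminiLLM.promptWindowLimit("gemini-1.5-pro"); // 2097152, read from cache
// With no cache file on disk, behavior is unchanged from before this commit:
GeminiLLM.promptWindowLimit("unknown-model"); // MODEL_MAP.gemini hit, else 30720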
 
   /**
@@ -118,14 +175,25 @@ class GeminiLLM {
    * @returns {Promise<[{id: string, name: string, contextWindow: number, experimental: boolean}]>} A promise that resolves to an array of Gemini models
    */
   static async fetchModels(apiKey, limit = 1_000, pageToken = null) {
+    if (!apiKey) return [];
+    if (fs.existsSync(cacheFolder) && !this.cacheIsStale()) {
+      console.log(
+        `\x1b[32m[GeminiLLM]\x1b[0m Using cached models API response.`
+      );
+      return safeJsonParse(
+        fs.readFileSync(path.resolve(cacheFolder, "models.json"))
+      );
+    }
+
     const url = new URL(
       "https://generativelanguage.googleapis.com/v1beta/models"
     );
     url.searchParams.set("pageSize", limit);
     url.searchParams.set("key", apiKey);
     if (pageToken) url.searchParams.set("pageToken", pageToken);
-    return fetch(url.toString(), {
+
+    let success = false;
+    const models = await fetch(url.toString(), {
       method: "GET",
       headers: { "Content-Type": "application/json" },
     })
@@ -134,8 +202,9 @@ class GeminiLLM {
         if (data.error) throw new Error(data.error.message);
         return data.models ?? [];
       })
-      .then((models) =>
-        models
+      .then((models) => {
+        success = true;
+        return models
           .filter(
             (model) => !model.displayName.toLowerCase().includes("tuning")
           )
@@ -149,12 +218,30 @@ class GeminiLLM {
               contextWindow: model.inputTokenLimit,
               experimental: model.name.includes("exp"),
             };
-          })
-      )
+          });
+      })
       .catch((e) => {
         console.error(`Gemini:getGeminiModels`, e.message);
+        success = false;
         return defaultGeminiModels;
       });
+
+    if (success) {
+      console.log(
+        `\x1b[32m[GeminiLLM]\x1b[0m Writing cached models API response to disk.`
+      );
+      if (!fs.existsSync(cacheFolder))
+        fs.mkdirSync(cacheFolder, { recursive: true });
+      fs.writeFileSync(
+        path.resolve(cacheFolder, "models.json"),
+        JSON.stringify(models)
+      );
+      fs.writeFileSync(
+        path.resolve(cacheFolder, ".cached_at"),
+        new Date().getTime().toString()
+      );
+    }
+
+    return models;
   }
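Taken together, the first call populates the cache and any call within the following week is served from disk; a rough usage sketch (inside an async function; the log lines mirror the console.log calls above):

// Cold start: hits the Google API, then writes models.json and .cached_at.
const models = await GeminiLLM.fetchModels(process.env.GEMINI_API_KEY);
// -> "[GeminiLLM] Writing cached models API response to disk."

// Warm: within one week, no network request is made at all.
const cached = await GeminiLLM.fetchModels(process.env.GEMINI_API_KEY);
// -> "[GeminiLLM] Using cached models API response."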
 
   /**
@@ -164,7 +251,7 @@ class GeminiLLM {
    * @returns {Promise<boolean>} A promise that resolves to a boolean indicating if the model is valid
    */
   async isValidChatCompletionModel(modelName = "") {
-    const models = await this.fetchModels(true);
+    const models = await this.fetchModels(process.env.GEMINI_API_KEY);
     return models.some((model) => model.id === modelName);
   }
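This also corrects a latent bug: the old call passed the literal true where fetchModels expects an API key string, so the upstream request could never authenticate and validation only ever saw the defaultGeminiModels fallback. Threading process.env.GEMINI_API_KEY through (and short-circuiting on a missing key) makes the check meaningful.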