From 99f2c25b1ccefd73740d1325a635eda82ca05a76 Mon Sep 17 00:00:00 2001 From: Timothy Carambat <rambat1010@gmail.com> Date: Thu, 15 Aug 2024 12:13:28 -0700 Subject: [PATCH] Agent Context window + context window refactor. (#2126) * Enable agent context windows to be accurate per provider:model * Refactor model mapping to external file Add token count to document length instead of char-count refernce promptWindowLimit from AIProvider in central location * remove unused imports --- server/utils/AiProviders/anthropic/index.js | 24 ++---- server/utils/AiProviders/azureOpenAi/index.js | 6 ++ server/utils/AiProviders/bedrock/index.js | 7 ++ server/utils/AiProviders/cohere/index.js | 22 ++---- server/utils/AiProviders/gemini/index.js | 20 ++--- .../utils/AiProviders/genericOpenAi/index.js | 7 ++ server/utils/AiProviders/groq/index.js | 20 ++--- server/utils/AiProviders/huggingface/index.js | 7 ++ server/utils/AiProviders/koboldCPP/index.js | 7 ++ server/utils/AiProviders/liteLLM/index.js | 7 ++ server/utils/AiProviders/lmStudio/index.js | 7 ++ server/utils/AiProviders/localAi/index.js | 7 ++ server/utils/AiProviders/mistral/index.js | 4 + server/utils/AiProviders/modelMap.js | 55 +++++++++++++ server/utils/AiProviders/native/index.js | 7 ++ server/utils/AiProviders/ollama/index.js | 7 ++ server/utils/AiProviders/openAi/index.js | 26 ++----- server/utils/AiProviders/openRouter/index.js | 11 +++ server/utils/AiProviders/perplexity/index.js | 5 ++ .../utils/AiProviders/textGenWebUI/index.js | 7 ++ server/utils/AiProviders/togetherAi/index.js | 5 ++ .../utils/agents/aibitat/plugins/summarize.js | 6 +- .../agents/aibitat/plugins/web-scraping.js | 6 +- .../agents/aibitat/providers/ai-provider.js | 20 ++--- server/utils/helpers/index.js | 78 +++++++++++++++++++ 25 files changed, 284 insertions(+), 94 deletions(-) create mode 100644 server/utils/AiProviders/modelMap.js diff --git a/server/utils/AiProviders/anthropic/index.js b/server/utils/AiProviders/anthropic/index.js index 5702fc839..386e84a53 100644 --- a/server/utils/AiProviders/anthropic/index.js +++ b/server/utils/AiProviders/anthropic/index.js @@ -4,6 +4,7 @@ const { clientAbortedHandler, } = require("../../helpers/chat/responses"); const { NativeEmbedder } = require("../../EmbeddingEngines/native"); +const { MODEL_MAP } = require("../modelMap"); class AnthropicLLM { constructor(embedder = null, modelPreference = null) { @@ -32,25 +33,12 @@ class AnthropicLLM { return "streamGetChatCompletion" in this; } + static promptWindowLimit(modelName) { + return MODEL_MAP.anthropic[modelName] ?? 100_000; + } + promptWindowLimit() { - switch (this.model) { - case "claude-instant-1.2": - return 100_000; - case "claude-2.0": - return 100_000; - case "claude-2.1": - return 200_000; - case "claude-3-opus-20240229": - return 200_000; - case "claude-3-sonnet-20240229": - return 200_000; - case "claude-3-haiku-20240307": - return 200_000; - case "claude-3-5-sonnet-20240620": - return 200_000; - default: - return 100_000; // assume a claude-instant-1.2 model - } + return MODEL_MAP.anthropic[this.model] ?? 
100_000; } isValidChatCompletionModel(modelName = "") { diff --git a/server/utils/AiProviders/azureOpenAi/index.js b/server/utils/AiProviders/azureOpenAi/index.js index 231d9c04c..feb6f0a1b 100644 --- a/server/utils/AiProviders/azureOpenAi/index.js +++ b/server/utils/AiProviders/azureOpenAi/index.js @@ -43,6 +43,12 @@ class AzureOpenAiLLM { return "streamGetChatCompletion" in this; } + static promptWindowLimit(_modelName) { + return !!process.env.AZURE_OPENAI_TOKEN_LIMIT + ? Number(process.env.AZURE_OPENAI_TOKEN_LIMIT) + : 4096; + } + // Sure the user selected a proper value for the token limit // could be any of these https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-models // and if undefined - assume it is the lowest end. diff --git a/server/utils/AiProviders/bedrock/index.js b/server/utils/AiProviders/bedrock/index.js index f579c0331..ebff7ea29 100644 --- a/server/utils/AiProviders/bedrock/index.js +++ b/server/utils/AiProviders/bedrock/index.js @@ -82,6 +82,13 @@ class AWSBedrockLLM { return "streamGetChatCompletion" in this; } + static promptWindowLimit(_modelName) { + const limit = process.env.AWS_BEDROCK_LLM_MODEL_TOKEN_LIMIT || 8191; + if (!limit || isNaN(Number(limit))) + throw new Error("No valid token context limit was set."); + return Number(limit); + } + // Ensure the user set a value for the token limit // and if undefined - assume 4096 window. promptWindowLimit() { diff --git a/server/utils/AiProviders/cohere/index.js b/server/utils/AiProviders/cohere/index.js index f57d2bc5c..f61d43f62 100644 --- a/server/utils/AiProviders/cohere/index.js +++ b/server/utils/AiProviders/cohere/index.js @@ -1,6 +1,7 @@ const { v4 } = require("uuid"); const { writeResponseChunk } = require("../../helpers/chat/responses"); const { NativeEmbedder } = require("../../EmbeddingEngines/native"); +const { MODEL_MAP } = require("../modelMap"); class CohereLLM { constructor(embedder = null) { @@ -58,23 +59,12 @@ class CohereLLM { return "streamGetChatCompletion" in this; } + static promptWindowLimit(modelName) { + return MODEL_MAP.cohere[modelName] ?? 4_096; + } + promptWindowLimit() { - switch (this.model) { - case "command-r": - return 128_000; - case "command-r-plus": - return 128_000; - case "command": - return 4_096; - case "command-light": - return 4_096; - case "command-nightly": - return 8_192; - case "command-light-nightly": - return 8_192; - default: - return 4_096; - } + return MODEL_MAP.cohere[this.model] ?? 4_096; } async isValidChatCompletionModel(model = "") { diff --git a/server/utils/AiProviders/gemini/index.js b/server/utils/AiProviders/gemini/index.js index f29d73e35..7acc924cc 100644 --- a/server/utils/AiProviders/gemini/index.js +++ b/server/utils/AiProviders/gemini/index.js @@ -3,6 +3,7 @@ const { writeResponseChunk, clientAbortedHandler, } = require("../../helpers/chat/responses"); +const { MODEL_MAP } = require("../modelMap"); class GeminiLLM { constructor(embedder = null, modelPreference = null) { @@ -89,21 +90,12 @@ class GeminiLLM { return "streamGetChatCompletion" in this; } + static promptWindowLimit(modelName) { + return MODEL_MAP.gemini[modelName] ?? 
30_720; + } + promptWindowLimit() { - switch (this.model) { - case "gemini-pro": - return 30_720; - case "gemini-1.0-pro": - return 30_720; - case "gemini-1.5-flash-latest": - return 1_048_576; - case "gemini-1.5-pro-latest": - return 2_097_152; - case "gemini-1.5-pro-exp-0801": - return 2_097_152; - default: - return 30_720; // assume a gemini-pro model - } + return MODEL_MAP.gemini[this.model] ?? 30_720; } isValidChatCompletionModel(modelName = "") { diff --git a/server/utils/AiProviders/genericOpenAi/index.js b/server/utils/AiProviders/genericOpenAi/index.js index 7c027c434..fe2902300 100644 --- a/server/utils/AiProviders/genericOpenAi/index.js +++ b/server/utils/AiProviders/genericOpenAi/index.js @@ -55,6 +55,13 @@ class GenericOpenAiLLM { return "streamGetChatCompletion" in this; } + static promptWindowLimit(_modelName) { + const limit = process.env.GENERIC_OPEN_AI_MODEL_TOKEN_LIMIT || 4096; + if (!limit || isNaN(Number(limit))) + throw new Error("No token context limit was set."); + return Number(limit); + } + // Ensure the user set a value for the token limit // and if undefined - assume 4096 window. promptWindowLimit() { diff --git a/server/utils/AiProviders/groq/index.js b/server/utils/AiProviders/groq/index.js index d76bddcc4..c176f1dca 100644 --- a/server/utils/AiProviders/groq/index.js +++ b/server/utils/AiProviders/groq/index.js @@ -2,6 +2,7 @@ const { NativeEmbedder } = require("../../EmbeddingEngines/native"); const { handleDefaultStreamResponseV2, } = require("../../helpers/chat/responses"); +const { MODEL_MAP } = require("../modelMap"); class GroqLLM { constructor(embedder = null, modelPreference = null) { @@ -40,21 +41,12 @@ class GroqLLM { return "streamGetChatCompletion" in this; } + static promptWindowLimit(modelName) { + return MODEL_MAP.groq[modelName] ?? 8192; + } + promptWindowLimit() { - switch (this.model) { - case "gemma2-9b-it": - case "gemma-7b-it": - case "llama3-70b-8192": - case "llama3-8b-8192": - return 8192; - case "llama-3.1-70b-versatile": - case "llama-3.1-8b-instant": - return 8000; - case "mixtral-8x7b-32768": - return 32768; - default: - return 8192; - } + return MODEL_MAP.groq[this.model] ?? 
8192; } async isValidChatCompletionModel(modelName = "") { diff --git a/server/utils/AiProviders/huggingface/index.js b/server/utils/AiProviders/huggingface/index.js index ddb1f6c42..021a636b3 100644 --- a/server/utils/AiProviders/huggingface/index.js +++ b/server/utils/AiProviders/huggingface/index.js @@ -45,6 +45,13 @@ class HuggingFaceLLM { return "streamGetChatCompletion" in this; } + static promptWindowLimit(_modelName) { + const limit = process.env.HUGGING_FACE_LLM_TOKEN_LIMIT || 4096; + if (!limit || isNaN(Number(limit))) + throw new Error("No HuggingFace token context limit was set."); + return Number(limit); + } + promptWindowLimit() { const limit = process.env.HUGGING_FACE_LLM_TOKEN_LIMIT || 4096; if (!limit || isNaN(Number(limit))) diff --git a/server/utils/AiProviders/koboldCPP/index.js b/server/utils/AiProviders/koboldCPP/index.js index 5c67103d3..9a700793d 100644 --- a/server/utils/AiProviders/koboldCPP/index.js +++ b/server/utils/AiProviders/koboldCPP/index.js @@ -51,6 +51,13 @@ class KoboldCPPLLM { return "streamGetChatCompletion" in this; } + static promptWindowLimit(_modelName) { + const limit = process.env.KOBOLD_CPP_MODEL_TOKEN_LIMIT || 4096; + if (!limit || isNaN(Number(limit))) + throw new Error("No token context limit was set."); + return Number(limit); + } + // Ensure the user set a value for the token limit // and if undefined - assume 4096 window. promptWindowLimit() { diff --git a/server/utils/AiProviders/liteLLM/index.js b/server/utils/AiProviders/liteLLM/index.js index 897a484dd..d8907e7a9 100644 --- a/server/utils/AiProviders/liteLLM/index.js +++ b/server/utils/AiProviders/liteLLM/index.js @@ -50,6 +50,13 @@ class LiteLLM { return "streamGetChatCompletion" in this; } + static promptWindowLimit(_modelName) { + const limit = process.env.LITE_LLM_MODEL_TOKEN_LIMIT || 4096; + if (!limit || isNaN(Number(limit))) + throw new Error("No token context limit was set."); + return Number(limit); + } + // Ensure the user set a value for the token limit // and if undefined - assume 4096 window. promptWindowLimit() { diff --git a/server/utils/AiProviders/lmStudio/index.js b/server/utils/AiProviders/lmStudio/index.js index 6ff025884..6f0593b8c 100644 --- a/server/utils/AiProviders/lmStudio/index.js +++ b/server/utils/AiProviders/lmStudio/index.js @@ -48,6 +48,13 @@ class LMStudioLLM { return "streamGetChatCompletion" in this; } + static promptWindowLimit(_modelName) { + const limit = process.env.LMSTUDIO_MODEL_TOKEN_LIMIT || 4096; + if (!limit || isNaN(Number(limit))) + throw new Error("No LMStudio token context limit was set."); + return Number(limit); + } + // Ensure the user set a value for the token limit // and if undefined - assume 4096 window. promptWindowLimit() { diff --git a/server/utils/AiProviders/localAi/index.js b/server/utils/AiProviders/localAi/index.js index 2275e1e8d..2d5e8b1f4 100644 --- a/server/utils/AiProviders/localAi/index.js +++ b/server/utils/AiProviders/localAi/index.js @@ -40,6 +40,13 @@ class LocalAiLLM { return "streamGetChatCompletion" in this; } + static promptWindowLimit(_modelName) { + const limit = process.env.LOCAL_AI_MODEL_TOKEN_LIMIT || 4096; + if (!limit || isNaN(Number(limit))) + throw new Error("No LocalAi token context limit was set."); + return Number(limit); + } + // Ensure the user set a value for the token limit // and if undefined - assume 4096 window. 
promptWindowLimit() { diff --git a/server/utils/AiProviders/mistral/index.js b/server/utils/AiProviders/mistral/index.js index 92cc63f5a..7dfe74196 100644 --- a/server/utils/AiProviders/mistral/index.js +++ b/server/utils/AiProviders/mistral/index.js @@ -41,6 +41,10 @@ class MistralLLM { return "streamGetChatCompletion" in this; } + static promptWindowLimit() { + return 32000; + } + promptWindowLimit() { return 32000; } diff --git a/server/utils/AiProviders/modelMap.js b/server/utils/AiProviders/modelMap.js new file mode 100644 index 000000000..151bd7cd8 --- /dev/null +++ b/server/utils/AiProviders/modelMap.js @@ -0,0 +1,55 @@ +/** + * The model name and context window for all know model windows + * that are available through providers which has discrete model options. + */ +const MODEL_MAP = { + anthropic: { + "claude-instant-1.2": 100_000, + "claude-2.0": 100_000, + "claude-2.1": 200_000, + "claude-3-opus-20240229": 200_000, + "claude-3-sonnet-20240229": 200_000, + "claude-3-haiku-20240307": 200_000, + "claude-3-5-sonnet-20240620": 200_000, + }, + cohere: { + "command-r": 128_000, + "command-r-plus": 128_000, + command: 4_096, + "command-light": 4_096, + "command-nightly": 8_192, + "command-light-nightly": 8_192, + }, + gemini: { + "gemini-pro": 30_720, + "gemini-1.0-pro": 30_720, + "gemini-1.5-flash-latest": 1_048_576, + "gemini-1.5-pro-latest": 2_097_152, + "gemini-1.5-pro-exp-0801": 2_097_152, + }, + groq: { + "gemma2-9b-it": 8192, + "gemma-7b-it": 8192, + "llama3-70b-8192": 8192, + "llama3-8b-8192": 8192, + "llama-3.1-70b-versatile": 8000, + "llama-3.1-8b-instant": 8000, + "mixtral-8x7b-32768": 32768, + }, + openai: { + "gpt-3.5-turbo": 16_385, + "gpt-3.5-turbo-1106": 16_385, + "gpt-4o": 128_000, + "gpt-4o-2024-08-06": 128_000, + "gpt-4o-2024-05-13": 128_000, + "gpt-4o-mini": 128_000, + "gpt-4o-mini-2024-07-18": 128_000, + "gpt-4-turbo": 128_000, + "gpt-4-1106-preview": 128_000, + "gpt-4-turbo-preview": 128_000, + "gpt-4": 8_192, + "gpt-4-32k": 32_000, + }, +}; + +module.exports = { MODEL_MAP }; diff --git a/server/utils/AiProviders/native/index.js b/server/utils/AiProviders/native/index.js index 630cc9ea1..4d15cdac0 100644 --- a/server/utils/AiProviders/native/index.js +++ b/server/utils/AiProviders/native/index.js @@ -96,6 +96,13 @@ class NativeLLM { return "streamGetChatCompletion" in this; } + static promptWindowLimit(_modelName) { + const limit = process.env.NATIVE_LLM_MODEL_TOKEN_LIMIT || 4096; + if (!limit || isNaN(Number(limit))) + throw new Error("No NativeAI token context limit was set."); + return Number(limit); + } + // Ensure the user set a value for the token limit promptWindowLimit() { const limit = process.env.NATIVE_LLM_MODEL_TOKEN_LIMIT || 4096; diff --git a/server/utils/AiProviders/ollama/index.js b/server/utils/AiProviders/ollama/index.js index 02e780777..eb18ee6f3 100644 --- a/server/utils/AiProviders/ollama/index.js +++ b/server/utils/AiProviders/ollama/index.js @@ -82,6 +82,13 @@ class OllamaAILLM { return "streamGetChatCompletion" in this; } + static promptWindowLimit(_modelName) { + const limit = process.env.OLLAMA_MODEL_TOKEN_LIMIT || 4096; + if (!limit || isNaN(Number(limit))) + throw new Error("No Ollama token context limit was set."); + return Number(limit); + } + // Ensure the user set a value for the token limit // and if undefined - assume 4096 window. 
promptWindowLimit() { diff --git a/server/utils/AiProviders/openAi/index.js b/server/utils/AiProviders/openAi/index.js index 57ea28897..b0e52dc2b 100644 --- a/server/utils/AiProviders/openAi/index.js +++ b/server/utils/AiProviders/openAi/index.js @@ -2,6 +2,7 @@ const { NativeEmbedder } = require("../../EmbeddingEngines/native"); const { handleDefaultStreamResponseV2, } = require("../../helpers/chat/responses"); +const { MODEL_MAP } = require("../modelMap"); class OpenAiLLM { constructor(embedder = null, modelPreference = null) { @@ -38,27 +39,12 @@ class OpenAiLLM { return "streamGetChatCompletion" in this; } + static promptWindowLimit(modelName) { + return MODEL_MAP.openai[modelName] ?? 4_096; + } + promptWindowLimit() { - switch (this.model) { - case "gpt-3.5-turbo": - case "gpt-3.5-turbo-1106": - return 16_385; - case "gpt-4o": - case "gpt-4o-2024-08-06": - case "gpt-4o-2024-05-13": - case "gpt-4o-mini": - case "gpt-4o-mini-2024-07-18": - case "gpt-4-turbo": - case "gpt-4-1106-preview": - case "gpt-4-turbo-preview": - return 128_000; - case "gpt-4": - return 8_192; - case "gpt-4-32k": - return 32_000; - default: - return 4_096; // assume a fine-tune 3.5? - } + return MODEL_MAP.openai[this.model] ?? 4_096; } // Short circuit if name has 'gpt' since we now fetch models from OpenAI API diff --git a/server/utils/AiProviders/openRouter/index.js b/server/utils/AiProviders/openRouter/index.js index 00a176e1b..3ec813423 100644 --- a/server/utils/AiProviders/openRouter/index.js +++ b/server/utils/AiProviders/openRouter/index.js @@ -117,6 +117,17 @@ class OpenRouterLLM { return "streamGetChatCompletion" in this; } + static promptWindowLimit(modelName) { + const cacheModelPath = path.resolve(cacheFolder, "models.json"); + const availableModels = fs.existsSync(cacheModelPath) + ? safeJsonParse( + fs.readFileSync(cacheModelPath, { encoding: "utf-8" }), + {} + ) + : {}; + return availableModels[modelName]?.maxLength || 4096; + } + promptWindowLimit() { const availableModels = this.models(); return availableModels[this.model]?.maxLength || 4096; diff --git a/server/utils/AiProviders/perplexity/index.js b/server/utils/AiProviders/perplexity/index.js index 712605f0b..93639f9f1 100644 --- a/server/utils/AiProviders/perplexity/index.js +++ b/server/utils/AiProviders/perplexity/index.js @@ -52,6 +52,11 @@ class PerplexityLLM { return "streamGetChatCompletion" in this; } + static promptWindowLimit(modelName) { + const availableModels = perplexityModels(); + return availableModels[modelName]?.maxLength || 4096; + } + promptWindowLimit() { const availableModels = this.allModelInformation(); return availableModels[this.model]?.maxLength || 4096; diff --git a/server/utils/AiProviders/textGenWebUI/index.js b/server/utils/AiProviders/textGenWebUI/index.js index 9400a12f4..68d7a6ac8 100644 --- a/server/utils/AiProviders/textGenWebUI/index.js +++ b/server/utils/AiProviders/textGenWebUI/index.js @@ -48,6 +48,13 @@ class TextGenWebUILLM { return "streamGetChatCompletion" in this; } + static promptWindowLimit(_modelName) { + const limit = process.env.TEXT_GEN_WEB_UI_MODEL_TOKEN_LIMIT || 4096; + if (!limit || isNaN(Number(limit))) + throw new Error("No token context limit was set."); + return Number(limit); + } + // Ensure the user set a value for the token limit // and if undefined - assume 4096 window. 
promptWindowLimit() { diff --git a/server/utils/AiProviders/togetherAi/index.js b/server/utils/AiProviders/togetherAi/index.js index 5d25edf9e..8c9f8831c 100644 --- a/server/utils/AiProviders/togetherAi/index.js +++ b/server/utils/AiProviders/togetherAi/index.js @@ -48,6 +48,11 @@ class TogetherAiLLM { return "streamGetChatCompletion" in this; } + static promptWindowLimit(modelName) { + const availableModels = togetherAiModels(); + return availableModels[modelName]?.maxLength || 4096; + } + // Ensure the user set a value for the token limit // and if undefined - assume 4096 window. promptWindowLimit() { diff --git a/server/utils/agents/aibitat/plugins/summarize.js b/server/utils/agents/aibitat/plugins/summarize.js index bd491f960..d532a0715 100644 --- a/server/utils/agents/aibitat/plugins/summarize.js +++ b/server/utils/agents/aibitat/plugins/summarize.js @@ -136,9 +136,11 @@ const docSummarizer = { ); } + const { TokenManager } = require("../../../helpers/tiktoken"); if ( - document.content?.length < - Provider.contextLimit(this.super.provider) + new TokenManager(this.super.model).countFromString( + document.content + ) < Provider.contextLimit(this.super.provider, this.super.model) ) { return document.content; } diff --git a/server/utils/agents/aibitat/plugins/web-scraping.js b/server/utils/agents/aibitat/plugins/web-scraping.js index df26caf81..a7dc7a3c7 100644 --- a/server/utils/agents/aibitat/plugins/web-scraping.js +++ b/server/utils/agents/aibitat/plugins/web-scraping.js @@ -77,7 +77,11 @@ const webScraping = { throw new Error("There was no content to be collected or read."); } - if (content.length < Provider.contextLimit(this.super.provider)) { + const { TokenManager } = require("../../../helpers/tiktoken"); + if ( + new TokenManager(this.super.model).countFromString(content) < + Provider.contextLimit(this.super.provider, this.super.model) + ) { return content; } diff --git a/server/utils/agents/aibitat/providers/ai-provider.js b/server/utils/agents/aibitat/providers/ai-provider.js index 472f72be2..034c67ad0 100644 --- a/server/utils/agents/aibitat/providers/ai-provider.js +++ b/server/utils/agents/aibitat/providers/ai-provider.js @@ -15,6 +15,7 @@ const { ChatAnthropic } = require("@langchain/anthropic"); const { ChatBedrockConverse } = require("@langchain/aws"); const { ChatOllama } = require("@langchain/community/chat_models/ollama"); const { toValidNumber } = require("../../../http"); +const { getLLMProviderClass } = require("../../../helpers"); const DEFAULT_WORKSPACE_PROMPT = "You are a helpful ai assistant who can assist the user and use tools available to help answer the users prompts and questions."; @@ -173,15 +174,16 @@ class Provider { } } - static contextLimit(provider = "openai") { - switch (provider) { - case "openai": - return 8_000; - case "anthropic": - return 100_000; - default: - return 8_000; - } + /** + * Get the context limit for a provider/model combination using static method in AIProvider class. + * @param {string} provider + * @param {string} modelName + * @returns {number} + */ + static contextLimit(provider = "openai", modelName) { + const llm = getLLMProviderClass({ provider }); + if (!llm || !llm.hasOwnProperty("promptWindowLimit")) return 8_000; + return llm.promptWindowLimit(modelName); } // For some providers we may want to override the system prompt to be more verbose. 
diff --git a/server/utils/helpers/index.js b/server/utils/helpers/index.js index 765e7226f..6ec0b2a31 100644 --- a/server/utils/helpers/index.js +++ b/server/utils/helpers/index.js @@ -20,6 +20,11 @@ * @property {Function} compressMessages - Compresses chat messages to fit within the token limit. */ +/** + * @typedef {Object} BaseLLMProviderClass - Class method of provider - not instantiated + * @property {function(string): number} promptWindowLimit - Returns the token limit for the provided model. + */ + /** * @typedef {Object} BaseVectorDatabaseProvider * @property {string} name - The name of the Vector Database instance. @@ -204,6 +209,78 @@ function getEmbeddingEngineSelection() { } } +/** + * Returns the LLMProviderClass - this is a helper method to access static methods on a class + * @param {{provider: string | null} | null} params - Initialize params for LLMs provider + * @returns {BaseLLMProviderClass} + */ +function getLLMProviderClass({ provider = null } = {}) { + switch (provider) { + case "openai": + const { OpenAiLLM } = require("../AiProviders/openAi"); + return OpenAiLLM; + case "azure": + const { AzureOpenAiLLM } = require("../AiProviders/azureOpenAi"); + return AzureOpenAiLLM; + case "anthropic": + const { AnthropicLLM } = require("../AiProviders/anthropic"); + return AnthropicLLM; + case "gemini": + const { GeminiLLM } = require("../AiProviders/gemini"); + return GeminiLLM; + case "lmstudio": + const { LMStudioLLM } = require("../AiProviders/lmStudio"); + return LMStudioLLM; + case "localai": + const { LocalAiLLM } = require("../AiProviders/localAi"); + return LocalAiLLM; + case "ollama": + const { OllamaAILLM } = require("../AiProviders/ollama"); + return OllamaAILLM; + case "togetherai": + const { TogetherAiLLM } = require("../AiProviders/togetherAi"); + return TogetherAiLLM; + case "perplexity": + const { PerplexityLLM } = require("../AiProviders/perplexity"); + return PerplexityLLM; + case "openrouter": + const { OpenRouterLLM } = require("../AiProviders/openRouter"); + return OpenRouterLLM; + case "mistral": + const { MistralLLM } = require("../AiProviders/mistral"); + return MistralLLM; + case "native": + const { NativeLLM } = require("../AiProviders/native"); + return NativeLLM; + case "huggingface": + const { HuggingFaceLLM } = require("../AiProviders/huggingface"); + return HuggingFaceLLM; + case "groq": + const { GroqLLM } = require("../AiProviders/groq"); + return GroqLLM; + case "koboldcpp": + const { KoboldCPPLLM } = require("../AiProviders/koboldCPP"); + return KoboldCPPLLM; + case "textgenwebui": + const { TextGenWebUILLM } = require("../AiProviders/textGenWebUI"); + return TextGenWebUILLM; + case "cohere": + const { CohereLLM } = require("../AiProviders/cohere"); + return CohereLLM; + case "litellm": + const { LiteLLM } = require("../AiProviders/liteLLM"); + return LiteLLM; + case "generic-openai": + const { GenericOpenAiLLM } = require("../AiProviders/genericOpenAi"); + return GenericOpenAiLLM; + case "bedrock": + const { AWSBedrockLLM } = require("../AiProviders/bedrock"); + return AWSBedrockLLM; + default: + return null; + } +} + // Some models have lower restrictions on chars that can be encoded in a single pass // and by default we assume it can handle 1,000 chars, but some models use work with smaller // chars so here we can override that value when embedding information. @@ -228,6 +305,7 @@ module.exports = { getEmbeddingEngineSelection, maximumChunkLength, getVectorDbClass, + getLLMProviderClass, getLLMProvider, toChunks, }; -- GitLab
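
Illustrative usage sketch (not part of the commit): the snippet below shows how the
pieces introduced in this patch are meant to compose -- a static, per-model context
window lookup via getLLMProviderClass() and the provider classes' static
promptWindowLimit(), combined with token-based length checks via TokenManager,
mirroring the updated summarize and web-scraping agent plugins. The require paths
(assumed relative to server/) and the helper names contextLimitFor and
fitsInContextWindow are invented here for illustration only.

  // Sketch only -- module paths assumed relative to server/, per the files touched in this patch.
  const { getLLMProviderClass } = require("./utils/helpers");
  const { TokenManager } = require("./utils/helpers/tiktoken");

  // Resolve a provider:model context window without instantiating the provider
  // (no API key or client needed) -- this is what Provider.contextLimit() now does internally.
  function contextLimitFor(provider, modelName) {
    const LLMClass = getLLMProviderClass({ provider });
    if (!LLMClass || !LLMClass.hasOwnProperty("promptWindowLimit")) return 8_000;
    return LLMClass.promptWindowLimit(modelName);
  }

  // Compare a token count (not a raw character count) against that limit, as the
  // summarize and web-scraping agent plugins now do before deciding to chunk content.
  function fitsInContextWindow(provider, modelName, content) {
    const tokens = new TokenManager(modelName).countFromString(content);
    return tokens < contextLimitFor(provider, modelName);
  }

  // e.g. fitsInContextWindow("anthropic", "claude-3-5-sonnet-20240620", documentContent);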