diff --git a/server/utils/agents/aibitat/index.js b/server/utils/agents/aibitat/index.js
index fa490edb29744a9f597beaf47a9b8ff71c4c4205..0fe6eb5105adee83794a6344538bac864bb57851 100644
--- a/server/utils/agents/aibitat/index.js
+++ b/server/utils/agents/aibitat/index.js
@@ -41,6 +41,7 @@ class AIbitat {
       ...rest,
     };
     this.provider = this.defaultProvider.provider;
+    this.model = this.defaultProvider.model;
   }

   /**
diff --git a/server/utils/agents/aibitat/plugins/summarize.js b/server/utils/agents/aibitat/plugins/summarize.js
index de1657c9f88dda6985e30c361bd9bd92dddcdcb8..bd491f9608c948987bf3def188140e436c2a8611 100644
--- a/server/utils/agents/aibitat/plugins/summarize.js
+++ b/server/utils/agents/aibitat/plugins/summarize.js
@@ -154,11 +154,12 @@ const docSummarizer = {
               this.controller.abort();
             });

-            return await summarizeContent(
-              this.super.provider,
-              this.controller.signal,
-              document.content
-            );
+            return await summarizeContent({
+              provider: this.super.provider,
+              model: this.super.model,
+              controllerSignal: this.controller.signal,
+              content: document.content,
+            });
           } catch (error) {
             this.super.handlerProps.log(
               `document-summarizer.summarizeDoc raised an error. ${error.message}`
diff --git a/server/utils/agents/aibitat/plugins/web-scraping.js b/server/utils/agents/aibitat/plugins/web-scraping.js
index f5c8d41f5180aba6680b4e4ffd3380597cd49a08..2ca69ec96079bd3d655e3fbf232ba086d5d889be 100644
--- a/server/utils/agents/aibitat/plugins/web-scraping.js
+++ b/server/utils/agents/aibitat/plugins/web-scraping.js
@@ -90,11 +90,13 @@ const webScraping = {
           );
           this.controller.abort();
         });
-        return summarizeContent(
-          this.super.provider,
-          this.controller.signal,
-          content
-        );
+
+        return summarizeContent({
+          provider: this.super.provider,
+          model: this.super.model,
+          controllerSignal: this.controller.signal,
+          content,
+        });
       },
     });
   },
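For reference, both plugins above now call summarizeContent with a single destructured options object instead of positional arguments, which lets the agent's active model thread through to summarization without another argument-order change. A minimal sketch of the new call shape from inside a plugin handler (this.super and this.controller come from the diff; the surrounding handler is assumed):

    // Inside a plugin handler, where `this.super` exposes the agent's provider/model:
    const summary = await summarizeContent({
      provider: this.super.provider, // e.g. "openai" or "groq"
      model: this.super.model, // the model the agent was booted with
      controllerSignal: this.controller.signal, // lets a closed socket abort summarization
      content: document.content, // raw text to summarize
    });
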
diff --git a/server/utils/agents/aibitat/providers/ai-provider.js b/server/utils/agents/aibitat/providers/ai-provider.js
index 91a81ebfa71f63659219cf7aada9ba5e3abd70a8..b3a8b1791beedbffc43fbc6dfb6259cb6ba03c81 100644
--- a/server/utils/agents/aibitat/providers/ai-provider.js
+++ b/server/utils/agents/aibitat/providers/ai-provider.js
@@ -2,8 +2,19 @@
  * A service that provides an AI client to create a completion.
  */

+/**
+ * @typedef {Object} LangChainModelConfig
+ * @property {(string|null)} baseURL - Override the default base URL from process.env for this provider
+ * @property {(string|null)} apiKey - Override the default API key from process.env for this provider
+ * @property {(number|null)} temperature - Override the default temperature
+ * @property {(string|null)} model - Override the default model used for this provider
+ */
+
 const { ChatOpenAI } = require("@langchain/openai");
 const { ChatAnthropic } = require("@langchain/anthropic");
+const { ChatOllama } = require("@langchain/community/chat_models/ollama");
+const { toValidNumber } = require("../../../http");
+
 const DEFAULT_WORKSPACE_PROMPT =
   "You are a helpful ai assistant who can assist the user and use tools available to help answer the users prompts and questions.";

@@ -27,8 +38,15 @@ class Provider {
     return this._client;
   }

+  /**
+   * Get a LangChain chat model instance for a given provider.
+   * @param {string} provider - The string key of the provider LLM being loaded.
+   * @param {LangChainModelConfig} config - Config used to override the default connection object.
+   * @returns {ChatOpenAI|ChatAnthropic|ChatOllama} The LangChain chat model for the provider.
+   */
   static LangChainChatModel(provider = "openai", config = {}) {
     switch (provider) {
+      // Cloud models
       case "openai":
         return new ChatOpenAI({
           apiKey: process.env.OPEN_AI_KEY,
@@ -39,11 +57,108 @@ class Provider {
           apiKey: process.env.ANTHROPIC_API_KEY,
           ...config,
         });
-      default:
+      case "groq":
         return new ChatOpenAI({
-          apiKey: process.env.OPEN_AI_KEY,
+          configuration: {
+            baseURL: "https://api.groq.com/openai/v1",
+          },
+          apiKey: process.env.GROQ_API_KEY,
+          ...config,
+        });
+      case "mistral":
+        return new ChatOpenAI({
+          configuration: {
+            baseURL: "https://api.mistral.ai/v1",
+          },
+          apiKey: process.env.MISTRAL_API_KEY ?? null,
+          ...config,
+        });
+      case "openrouter":
+        return new ChatOpenAI({
+          configuration: {
+            baseURL: "https://openrouter.ai/api/v1",
+            defaultHeaders: {
+              "HTTP-Referer": "https://useanything.com",
+              "X-Title": "AnythingLLM",
+            },
+          },
+          apiKey: process.env.OPENROUTER_API_KEY ?? null,
+          ...config,
+        });
+      case "perplexity":
+        return new ChatOpenAI({
+          configuration: {
+            baseURL: "https://api.perplexity.ai",
+          },
+          apiKey: process.env.PERPLEXITY_API_KEY ?? null,
+          ...config,
+        });
+      case "togetherai":
+        return new ChatOpenAI({
+          configuration: {
+            baseURL: "https://api.together.xyz/v1",
+          },
+          apiKey: process.env.TOGETHER_AI_API_KEY ?? null,
+          ...config,
+        });
+      case "generic-openai":
+        return new ChatOpenAI({
+          configuration: {
+            baseURL: process.env.GENERIC_OPEN_AI_BASE_PATH,
+          },
+          apiKey: process.env.GENERIC_OPEN_AI_API_KEY,
+          maxTokens: toValidNumber(
+            process.env.GENERIC_OPEN_AI_MAX_TOKENS,
+            1024
+          ),
+          ...config,
+        });
+
+      // OSS Model Runners
+      // case "anythingllm_ollama":
+      //   return new ChatOllama({
+      //     baseUrl: process.env.PLACEHOLDER,
+      //     ...config,
+      //   });
+      case "ollama":
+        return new ChatOllama({
+          baseUrl: process.env.OLLAMA_BASE_PATH,
           ...config,
         });
+      case "lmstudio":
+        return new ChatOpenAI({
+          configuration: {
+            baseURL: process.env.LMSTUDIO_BASE_PATH?.replace(/\/+$/, ""),
+          },
+          apiKey: "not-used", // Needs to be specified or else will assume OpenAI
+          ...config,
+        });
+      case "koboldcpp":
+        return new ChatOpenAI({
+          configuration: {
+            baseURL: process.env.KOBOLD_CPP_BASE_PATH,
+          },
+          apiKey: "not-used",
+          ...config,
+        });
+      case "localai":
+        return new ChatOpenAI({
+          configuration: {
+            baseURL: process.env.LOCAL_AI_BASE_PATH,
+          },
+          apiKey: process.env.LOCAL_AI_API_KEY ?? "not-used",
+          ...config,
+        });
+      case "textgenwebui":
+        return new ChatOpenAI({
+          configuration: {
+            baseURL: process.env.TEXT_GEN_WEB_UI_BASE_PATH,
+          },
+          apiKey: process.env.TEXT_GEN_WEB_UI_API_KEY ?? "not-used",
+          ...config,
+        });
+      default:
+        throw new Error(`Unsupported provider ${provider} for this task.`);
     }
   }
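The factory above now covers both cloud providers and OSS model runners, and throws for unknown keys instead of silently falling back to OpenAI. A minimal usage sketch, assuming GROQ_API_KEY is set in the environment; the model id is illustrative:

    // Build a LangChain chat model for a supported provider key.
    const llm = Provider.LangChainChatModel("groq", {
      temperature: 0,
      model: "llama3-8b-8192", // illustrative; any Groq-served model id works
    });
    const response = await llm.invoke("Reply with one word: ready?");
    console.log(response.content);
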
diff --git a/server/utils/agents/aibitat/providers/groq.js b/server/utils/agents/aibitat/providers/groq.js
index 01f69f7c110b925083de342e7a6135a82395fe15..9ca99065d88641349dd7ff7ba91fe2076c276855 100644
--- a/server/utils/agents/aibitat/providers/groq.js
+++ b/server/utils/agents/aibitat/providers/groq.js
@@ -1,28 +1,52 @@
 const OpenAI = require("openai");
 const Provider = require("./ai-provider.js");
-const { RetryError } = require("../error.js");
+const InheritMultiple = require("./helpers/classes.js");
+const UnTooled = require("./helpers/untooled.js");

 /**
- * The agent provider for the Groq provider.
- * Using OpenAI tool calling with groq really sucks right now
- * its just fast and bad. We should probably migrate this to Untooled to improve
- * coherence.
+ * The agent provider for Groq.
+ * We wrap Groq in UnTooled because its built-in tool calling is unreliable and wasteful.
  */
-class GroqProvider extends Provider {
+class GroqProvider extends InheritMultiple([Provider, UnTooled]) {
   model;

   constructor(config = {}) {
     const { model = "llama3-8b-8192" } = config;
+    super();
     const client = new OpenAI({
       baseURL: "https://api.groq.com/openai/v1",
       apiKey: process.env.GROQ_API_KEY,
       maxRetries: 3,
     });
-    super(client);
+
+    this._client = client;
     this.model = model;
     this.verbose = true;
   }

+  get client() {
+    return this._client;
+  }
+
+  async #handleFunctionCallChat({ messages = [] }) {
+    return await this.client.chat.completions
+      .create({
+        model: this.model,
+        temperature: 0,
+        messages,
+      })
+      .then((result) => {
+        if (!result.hasOwnProperty("choices"))
+          throw new Error("Groq chat: No results!");
+        if (result.choices.length === 0)
+          throw new Error("Groq chat: No results length!");
+        return result.choices[0].message.content;
+      })
+      .catch((_) => {
+        return null;
+      });
+  }
+
   /**
    * Create a completion based on the received messages.
    *
@@ -32,68 +56,49 @@ class GroqProvider extends Provider {
    */
   async complete(messages, functions = null) {
     try {
-      const response = await this.client.chat.completions.create({
-        model: this.model,
-        // stream: true,
-        messages,
-        ...(Array.isArray(functions) && functions?.length > 0
-          ? { functions }
-          : {}),
-      });
+      let completion;
+      if (Array.isArray(functions) && functions.length > 0) {
+        const { toolCall, text } = await this.functionCall(
+          messages,
+          functions,
+          this.#handleFunctionCallChat.bind(this)
+        );

-      // Right now, we only support one completion,
-      // so we just take the first one in the list
-      const completion = response.choices[0].message;
-      const cost = this.getCost(response.usage);
-      // treat function calls
-      if (completion.function_call) {
-        let functionArgs = {};
-        try {
-          functionArgs = JSON.parse(completion.function_call.arguments);
-        } catch (error) {
-          // call the complete function again in case it gets a json error
-          return this.complete(
-            [
-              ...messages,
-              {
-                role: "function",
-                name: completion.function_call.name,
-                function_call: completion.function_call,
-                content: error?.message,
-              },
-            ],
-            functions
-          );
+        if (toolCall !== null) {
+          this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
+          this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
+          return {
+            result: null,
+            functionCall: {
+              name: toolCall.name,
+              arguments: toolCall.arguments,
+            },
+            cost: 0,
+          };
         }
+        completion = { content: text };
+      }

-        // console.log(completion, { functionArgs })
-        return {
-          result: null,
-          functionCall: {
-            name: completion.function_call.name,
-            arguments: functionArgs,
-          },
-          cost,
-        };
+      if (!completion?.content) {
+        this.providerLog(
+          "Will assume chat completion without tool call inputs."
+        );
+        const response = await this.client.chat.completions.create({
+          model: this.model,
+          messages: this.cleanMsgs(messages),
+        });
+        completion = response.choices[0].message;
       }

+      // The Deduplicator inherited via UnTooled is mostly useful to prevent the agent
+      // from calling the exact same function over and over in a loop within a single chat exchange
+      // _but_ we should enable it to call previously used tools in a new chat interaction.
+      this.deduplicator.reset("runs");
       return {
         result: completion.content,
-        cost,
+        cost: 0,
       };
     } catch (error) {
-      // If invalid Auth error we need to abort because no amount of waiting
-      // will make auth better.
-      if (error instanceof OpenAI.AuthenticationError) throw error;
-
-      if (
-        error instanceof OpenAI.RateLimitError ||
-        error instanceof OpenAI.InternalServerError ||
-        error instanceof OpenAI.APIError // Also will catch AuthenticationError!!!
-      ) {
-        throw new RetryError(error.message);
-      }
-
       throw error;
     }
   }
@@ -103,7 +108,7 @@ class GroqProvider extends Provider {
    *
    * @param _usage The completion to get the cost for.
    * @returns The cost of the completion.
-   * Stubbed since Groq has no cost basis.
+   * Stubbed since Groq does not provide a cost basis.
    */
   getCost(_usage) {
     return 0;
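With UnTooled, tool selection happens through a plain chat completion whose reply is parsed for a tool call, rather than through Groq's native function-calling API. A sketch of how a caller consumes complete()'s two possible outcomes (runTool is a hypothetical stand-in for the agent's tool executor; messages and functions are assumed to exist):

    const provider = new GroqProvider({ model: "llama3-8b-8192" });
    const res = await provider.complete(messages, functions);
    if (res.functionCall) {
      // A valid tool call was parsed: { result: null, functionCall: { name, arguments }, cost: 0 }
      await runTool(res.functionCall.name, res.functionCall.arguments);
    } else {
      // Plain answer, no tool needed: { result: "<assistant text>", cost: 0 }
      console.log(res.result);
    }
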
diff --git a/server/utils/agents/aibitat/utils/summarize.js b/server/utils/agents/aibitat/utils/summarize.js
index 7f1852c023514f6ab2275f7c287aa5d197317ade..fbee20533e9da26671129e396c861c40bec933ae 100644
--- a/server/utils/agents/aibitat/utils/summarize.js
+++ b/server/utils/agents/aibitat/utils/summarize.js
@@ -3,26 +3,27 @@ const { PromptTemplate } = require("@langchain/core/prompts");
 const { RecursiveCharacterTextSplitter } = require("@langchain/textsplitters");
 const Provider = require("../providers/ai-provider");

 /**
- * Summarize content using OpenAI's GPT-3.5 model.
- *
- * @param self The context of the caller function
- * @param content The content to summarize.
- * @returns The summarized content.
+ * @typedef {Object} LCSummarizationConfig
+ * @property {string} provider The LLM provider to use for summarization (inherited from the agent)
+ * @property {string} model The model to use for summarization (inherited from the agent)
+ * @property {AbortController['signal']} controllerSignal Abort signal to stop recursive summarization
+ * @property {string} content The text content to summarize
  */

-const SUMMARY_MODEL = {
-  anthropic: "claude-3-opus-20240229", // 200,000 tokens
-  openai: "gpt-4o", // 128,000 tokens
-};
-
-async function summarizeContent(
+/**
+ * Summarize content using a LangChain summarization chain.
+ * @param {LCSummarizationConfig} config - The provider, model, abort signal, and content to use.
+ * @returns {Promise<string>} The summarized content.
+ */
+async function summarizeContent({
   provider = "openai",
+  model = null,
   controllerSignal,
-  content
-) {
+  content,
+}) {
   const llm = Provider.LangChainChatModel(provider, {
     temperature: 0,
-    modelName: SUMMARY_MODEL[provider],
+    model: model,
   });

   const textSplitter = new RecursiveCharacterTextSplitter({
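End to end, the new object signature plus the threaded abort signal make standalone summarization cancellable. A minimal sketch, assuming OLLAMA_BASE_PATH is configured; the model name, timeout policy, and longDocumentText are illustrative:

    const controller = new AbortController();
    // Illustrative policy: give summarization at most 60 seconds before aborting.
    const timer = setTimeout(() => controller.abort(), 60_000);

    const summary = await summarizeContent({
      provider: "ollama", // any key supported by Provider.LangChainChatModel
      model: "llama3", // illustrative model name served by the local runner
      controllerSignal: controller.signal,
      content: longDocumentText, // hypothetical long text to condense
    });
    clearTimeout(timer);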