diff --git a/server/utils/AiProviders/anthropic/index.js b/server/utils/AiProviders/anthropic/index.js
index f9c4c91c7a8f1a90f28a7811d8790e2c128a3d1e..0cd9589598c6be01ca3484a8748b78c8958f1743 100644
--- a/server/utils/AiProviders/anthropic/index.js
+++ b/server/utils/AiProviders/anthropic/index.js
@@ -2,6 +2,7 @@ const { v4 } = require("uuid");
 const {
   writeResponseChunk,
   clientAbortedHandler,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 const { NativeEmbedder } = require("../../EmbeddingEngines/native");
 const { MODEL_MAP } = require("../modelMap");
@@ -99,7 +100,7 @@ class AnthropicLLM {
 
     return [
       prompt,
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent),
       {
         role: "user",
         content: this.#generateContent({ userPrompt, attachments }),
diff --git a/server/utils/AiProviders/apipie/index.js b/server/utils/AiProviders/apipie/index.js
index 1f6dd68a0fed68cc2762df8edb3b5748ac285607..bd794d38ea626ac7039fd414bdebd56522fee2bb 100644
--- a/server/utils/AiProviders/apipie/index.js
+++ b/server/utils/AiProviders/apipie/index.js
@@ -3,6 +3,7 @@ const { v4: uuidv4 } = require("uuid");
 const {
   writeResponseChunk,
   clientAbortedHandler,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 const fs = require("fs");
 const path = require("path");
@@ -177,7 +178,7 @@ class ApiPieLLM {
     };
     return [
       prompt,
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent),
       {
         role: "user",
         content: this.#generateContent({ userPrompt, attachments }),
diff --git a/server/utils/AiProviders/azureOpenAi/index.js b/server/utils/AiProviders/azureOpenAi/index.js
index 078f55eef070674a99e6552259688f9b0565e633..cbf2c2ef3cf5cb0d022b6930d1fece234a54feba 100644
--- a/server/utils/AiProviders/azureOpenAi/index.js
+++ b/server/utils/AiProviders/azureOpenAi/index.js
@@ -5,6 +5,7 @@ const {
 const {
   writeResponseChunk,
   clientAbortedHandler,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 
 class AzureOpenAiLLM {
@@ -103,7 +104,7 @@ class AzureOpenAiLLM {
     };
     return [
       prompt,
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent),
       {
         role: "user",
         content: this.#generateContent({ userPrompt, attachments }),
diff --git a/server/utils/AiProviders/bedrock/index.js b/server/utils/AiProviders/bedrock/index.js
index 171d7b4595f991dd9f09f661ecefef2996f7ea24..d5f66eaeaf1d9cca73eea6df4e3f5373c19b22bc 100644
--- a/server/utils/AiProviders/bedrock/index.js
+++ b/server/utils/AiProviders/bedrock/index.js
@@ -2,6 +2,7 @@ const { StringOutputParser } = require("@langchain/core/output_parsers");
 const {
   writeResponseChunk,
   clientAbortedHandler,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 const { NativeEmbedder } = require("../../EmbeddingEngines/native");
 const {
@@ -199,7 +200,7 @@ class AWSBedrockLLM {
     // AWS Mistral models do not support system prompts
     if (this.model.startsWith("mistral"))
       return [
-        ...chatHistory,
+        ...formatChatHistory(chatHistory, this.#generateContent, "spread"),
         {
           role: "user",
           ...this.#generateContent({ userPrompt, attachments }),
@@ -212,7 +213,7 @@
     };
     return [
       prompt,
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent, "spread"),
       {
         role: "user",
         ...this.#generateContent({ userPrompt, attachments }),
diff --git a/server/utils/AiProviders/gemini/index.js b/server/utils/AiProviders/gemini/index.js
index 9961c70d762be43afea57cd5c5103a5ba76a97f9..fd7929f4b32b006140d7c9f19904fe2e0f37f6c5 100644
--- a/server/utils/AiProviders/gemini/index.js
+++ b/server/utils/AiProviders/gemini/index.js
@@ -7,6 +7,7 @@ const {
 const {
   writeResponseChunk,
   clientAbortedHandler,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 const { MODEL_MAP } = require("../modelMap");
 const { defaultGeminiModels, v1BetaModels } = require("./defaultModels");
@@ -254,6 +255,7 @@ class GeminiLLM {
     const models = await this.fetchModels(process.env.GEMINI_API_KEY);
     return models.some((model) => model.id === modelName);
   }
+
   /**
    * Generates appropriate content array for a message + attachments.
    * @param {{userPrompt:string, attachments: import("../../helpers").Attachment[]}}
@@ -290,7 +292,7 @@ class GeminiLLM {
     return [
       prompt,
       { role: "assistant", content: "Okay." },
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent),
       {
         role: "USER_PROMPT",
         content: this.#generateContent({ userPrompt, attachments }),
@@ -306,8 +308,17 @@ class GeminiLLM {
       .map((message) => {
         if (message.role === "system")
           return { role: "user", parts: [{ text: message.content }] };
-        if (message.role === "user")
+
+        if (message.role === "user") {
+          // If the content is an array - then we have already formatted the context so return it directly.
+          if (Array.isArray(message.content))
+            return { role: "user", parts: message.content };
+
+          // Otherwise, this was a regular user message with no attachments
+          // so we need to format it for Gemini
           return { role: "user", parts: [{ text: message.content }] };
+        }
+
         if (message.role === "assistant")
           return { role: "model", parts: [{ text: message.content }] };
         return null;
diff --git a/server/utils/AiProviders/genericOpenAi/index.js b/server/utils/AiProviders/genericOpenAi/index.js
index 57c8f6a1445d373b0b00fbb86f605cce1643db6d..eb020298c424850e1c020c29f6d5dd8c8d9779d0 100644
--- a/server/utils/AiProviders/genericOpenAi/index.js
+++ b/server/utils/AiProviders/genericOpenAi/index.js
@@ -4,6 +4,7 @@ const {
 } = require("../../helpers/chat/LLMPerformanceMonitor");
 const {
   handleDefaultStreamResponseV2,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 const { toValidNumber } = require("../../http");
 
@@ -133,7 +134,7 @@ class GenericOpenAiLLM {
     };
     return [
       prompt,
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent),
       {
         role: "user",
         content: this.#generateContent({ userPrompt, attachments }),
diff --git a/server/utils/AiProviders/groq/index.js b/server/utils/AiProviders/groq/index.js
index 5793002f6ca7e3455347b3aae8f05e90d97f269b..9e7e77fa16bd14eeeb49169861b73e1717ef8cc9 100644
--- a/server/utils/AiProviders/groq/index.js
+++ b/server/utils/AiProviders/groq/index.js
@@ -89,6 +89,8 @@ class GroqLLM {
    * Since we can only explicitly support the current models, this is a temporary solution.
    * If the attachments are empty or the model is not a vision model, we will return the default prompt structure which will work for all models.
    * If the attachments are present and the model is a vision model - we only return the user prompt with attachments - see comment at end of function for more.
+   *
+   * Historical attachments are also omitted from prompt chat history for the reasons above. (TDC: Dec 30, 2024)
    */
   #conditionalPromptStruct({
     systemPrompt = "",
diff --git a/server/utils/AiProviders/koboldCPP/index.js b/server/utils/AiProviders/koboldCPP/index.js
index 0e5206cabe8afe21c555687520fec3a9f13162c4..5ee58b5bfecdc24c2dfd0242fbf7da515aa2ddca 100644
--- a/server/utils/AiProviders/koboldCPP/index.js
+++ b/server/utils/AiProviders/koboldCPP/index.js
@@ -2,6 +2,7 @@ const { NativeEmbedder } = require("../../EmbeddingEngines/native");
 const {
   clientAbortedHandler,
   writeResponseChunk,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 const {
   LLMPerformanceMonitor,
@@ -116,7 +117,7 @@ class KoboldCPPLLM {
     };
     return [
       prompt,
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent),
       {
         role: "user",
         content: this.#generateContent({ userPrompt, attachments }),
diff --git a/server/utils/AiProviders/liteLLM/index.js b/server/utils/AiProviders/liteLLM/index.js
index 63f4115bc94d312e371303c27f91afb195eb798d..2017d7774f8b0d10a48c3565da843443a02cc2ff 100644
--- a/server/utils/AiProviders/liteLLM/index.js
+++ b/server/utils/AiProviders/liteLLM/index.js
@@ -4,6 +4,7 @@ const {
 } = require("../../helpers/chat/LLMPerformanceMonitor");
 const {
   handleDefaultStreamResponseV2,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 
 class LiteLLM {
@@ -115,7 +116,7 @@ class LiteLLM {
     };
     return [
       prompt,
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent),
       {
         role: "user",
         content: this.#generateContent({ userPrompt, attachments }),
diff --git a/server/utils/AiProviders/lmStudio/index.js b/server/utils/AiProviders/lmStudio/index.js
index 082576c98ad77af59d2213111220897b5f1104ad..bde9ed486b3a56fb2db404d84bb4ebfd9d52f314 100644
--- a/server/utils/AiProviders/lmStudio/index.js
+++ b/server/utils/AiProviders/lmStudio/index.js
@@ -1,6 +1,7 @@
 const { NativeEmbedder } = require("../../EmbeddingEngines/native");
 const {
   handleDefaultStreamResponseV2,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 const {
   LLMPerformanceMonitor,
@@ -117,7 +118,7 @@ class LMStudioLLM {
     };
     return [
       prompt,
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent),
       {
         role: "user",
         content: this.#generateContent({ userPrompt, attachments }),
diff --git a/server/utils/AiProviders/localAi/index.js b/server/utils/AiProviders/localAi/index.js
index 53da280f2d0a8766ea7d68f9a2581d5137511a2d..f62fe70dd9d0566b3173b402a9b11aa5283a162e 100644
--- a/server/utils/AiProviders/localAi/index.js
+++ b/server/utils/AiProviders/localAi/index.js
@@ -4,6 +4,7 @@ const {
 } = require("../../helpers/chat/LLMPerformanceMonitor");
 const {
   handleDefaultStreamResponseV2,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 
 class LocalAiLLM {
@@ -103,7 +104,7 @@ class LocalAiLLM {
     };
     return [
       prompt,
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent),
       {
         role: "user",
         content: this.#generateContent({ userPrompt, attachments }),
diff --git a/server/utils/AiProviders/mistral/index.js b/server/utils/AiProviders/mistral/index.js
index 219f6f52f952b942aa667c727b34e21c2b84de47..6c637857b357d80fca82c710d745f5aa5ee213b5 100644
--- a/server/utils/AiProviders/mistral/index.js
+++ b/server/utils/AiProviders/mistral/index.js
@@ -4,6 +4,7 @@ const {
 } = require("../../helpers/chat/LLMPerformanceMonitor");
 const {
   handleDefaultStreamResponseV2,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 
 class MistralLLM {
@@ -26,6 +27,11 @@ class MistralLLM {
 
     this.embedder = embedder ?? new NativeEmbedder();
     this.defaultTemp = 0.0;
+    this.log("Initialized with model:", this.model);
+  }
+
+  log(text, ...args) {
+    console.log(`\x1b[36m[${this.constructor.name}]\x1b[0m ${text}`, ...args);
   }
 
   #appendContext(contextTexts = []) {
@@ -92,7 +98,7 @@
     };
     return [
       prompt,
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent),
       {
         role: "user",
         content: this.#generateContent({ userPrompt, attachments }),
diff --git a/server/utils/AiProviders/novita/index.js b/server/utils/AiProviders/novita/index.js
index c41f5a6667603cb42504acb5def8c83335f91c84..8365d288216cd751ef432883d72f47f2582d58f4 100644
--- a/server/utils/AiProviders/novita/index.js
+++ b/server/utils/AiProviders/novita/index.js
@@ -3,6 +3,7 @@ const { v4: uuidv4 } = require("uuid");
 const {
   writeResponseChunk,
   clientAbortedHandler,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 const fs = require("fs");
 const path = require("path");
@@ -177,7 +178,7 @@ class NovitaLLM {
     };
     return [
       prompt,
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent),
       {
         role: "user",
         content: this.#generateContent({ userPrompt, attachments }),
diff --git a/server/utils/AiProviders/nvidiaNim/index.js b/server/utils/AiProviders/nvidiaNim/index.js
index 3cf7f835f155e7b95ec5dfa38ebb27d45ae19ddc..4de408e98cb0cb1dd990c3a7b7fd1869ff2733fe 100644
--- a/server/utils/AiProviders/nvidiaNim/index.js
+++ b/server/utils/AiProviders/nvidiaNim/index.js
@@ -4,6 +4,7 @@ const {
 } = require("../../helpers/chat/LLMPerformanceMonitor");
 const {
   handleDefaultStreamResponseV2,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 
 class NvidiaNimLLM {
@@ -142,7 +143,7 @@ class NvidiaNimLLM {
     };
     return [
       prompt,
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent),
       {
         role: "user",
         content: this.#generateContent({ userPrompt, attachments }),
diff --git a/server/utils/AiProviders/ollama/index.js b/server/utils/AiProviders/ollama/index.js
index faffcb234b3ef3d236ab9265ab75fc97f88e8c3b..5c53dd5f4b3f97cd694c2a571b7820df75477256 100644
--- a/server/utils/AiProviders/ollama/index.js
+++ b/server/utils/AiProviders/ollama/index.js
@@ -1,6 +1,7 @@
 const {
   writeResponseChunk,
   clientAbortedHandler,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 const { NativeEmbedder } = require("../../EmbeddingEngines/native");
 const {
@@ -120,7 +121,7 @@ class OllamaAILLM {
     };
     return [
       prompt,
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent, "spread"),
       {
         role: "user",
         ...this.#generateContent({ userPrompt, attachments }),
diff --git a/server/utils/AiProviders/openAi/index.js b/server/utils/AiProviders/openAi/index.js
index 4209b99ed24478269e6988a796b40b6cb7769f02..71a6a0edf759d3e1817ca4da1e37171d7c4f6df9 100644
--- a/server/utils/AiProviders/openAi/index.js
+++ b/server/utils/AiProviders/openAi/index.js
@@ -1,6 +1,7 @@
 const { NativeEmbedder } = require("../../EmbeddingEngines/native");
 const {
   handleDefaultStreamResponseV2,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 const { MODEL_MAP } = require("../modelMap");
 const {
@@ -121,7 +122,7 @@ class OpenAiLLM {
     };
     return [
       prompt,
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent),
       {
         role: "user",
         content: this.#generateContent({ userPrompt, attachments }),
diff --git a/server/utils/AiProviders/openRouter/index.js b/server/utils/AiProviders/openRouter/index.js
index 3abab7634411f176e8aa723b85ade9cffa94af1f..08f040150f84956a5214e50dbacb635c281c980b 100644
--- a/server/utils/AiProviders/openRouter/index.js
+++ b/server/utils/AiProviders/openRouter/index.js
@@ -3,6 +3,7 @@ const { v4: uuidv4 } = require("uuid");
 const {
   writeResponseChunk,
   clientAbortedHandler,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 const fs = require("fs");
 const path = require("path");
@@ -47,6 +48,7 @@ class OpenRouterLLM {
     fs.mkdirSync(cacheFolder, { recursive: true });
     this.cacheModelPath = path.resolve(cacheFolder, "models.json");
     this.cacheAtPath = path.resolve(cacheFolder, ".cached_at");
+    this.log("Initialized with model:", this.model);
   }
 
   log(text, ...args) {
@@ -162,7 +164,6 @@ class OpenRouterLLM {
         },
       });
     }
-    console.log(content.flat());
     return content.flat();
   }
 
@@ -179,7 +180,7 @@ class OpenRouterLLM {
     };
     return [
       prompt,
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent),
       {
         role: "user",
         content: this.#generateContent({ userPrompt, attachments }),
diff --git a/server/utils/AiProviders/textGenWebUI/index.js b/server/utils/AiProviders/textGenWebUI/index.js
index f1c3590bff72e656fa36a26b99487fcf2b543fc2..f3647c06d455f230d32d5ffa79218f9557fc5ba6 100644
--- a/server/utils/AiProviders/textGenWebUI/index.js
+++ b/server/utils/AiProviders/textGenWebUI/index.js
@@ -1,6 +1,7 @@
 const { NativeEmbedder } = require("../../EmbeddingEngines/native");
 const {
   handleDefaultStreamResponseV2,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 const {
   LLMPerformanceMonitor,
@@ -113,7 +114,7 @@ class TextGenWebUILLM {
     };
     return [
       prompt,
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent),
       {
         role: "user",
         content: this.#generateContent({ userPrompt, attachments }),
diff --git a/server/utils/AiProviders/xai/index.js b/server/utils/AiProviders/xai/index.js
index b18aae98ce5a277da17cd295dc396d2412c18d25..2319e72206e558e5a5190562e710a6e5797be9a9 100644
--- a/server/utils/AiProviders/xai/index.js
+++ b/server/utils/AiProviders/xai/index.js
@@ -4,6 +4,7 @@ const {
 } = require("../../helpers/chat/LLMPerformanceMonitor");
 const {
   handleDefaultStreamResponseV2,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 const { MODEL_MAP } = require("../modelMap");
 
@@ -27,6 +28,11 @@ class XAiLLM {
 
     this.embedder = embedder ?? new NativeEmbedder();
     this.defaultTemp = 0.7;
+    this.log("Initialized with model:", this.model);
+  }
+
+  log(text, ...args) {
+    console.log(`\x1b[36m[${this.constructor.name}]\x1b[0m ${text}`, ...args);
   }
 
   #appendContext(contextTexts = []) {
@@ -53,13 +59,8 @@
     return MODEL_MAP.xai[this.model] ?? 131_072;
   }
 
-  isValidChatCompletionModel(modelName = "") {
-    switch (modelName) {
-      case "grok-beta":
-        return true;
-      default:
-        return false;
-    }
+  isValidChatCompletionModel(_modelName = "") {
+    return true;
   }
 
   /**
@@ -103,7 +104,7 @@
     };
     return [
       prompt,
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent),
       {
         role: "user",
         content: this.#generateContent({ userPrompt, attachments }),
diff --git a/server/utils/chats/embed.js b/server/utils/chats/embed.js
index 550a460f8743aadbdda37b75707bea18a78175c6..70904a541051beb82820e9960fd9a2ff5a7ec75f 100644
--- a/server/utils/chats/embed.js
+++ b/server/utils/chats/embed.js
@@ -210,7 +210,7 @@ async function streamChatWithForEmbed(
  * @param {string} sessionId the session id of the user from embed widget
  * @param {Object} embed the embed config object
  * @param {Number} messageLimit the number of messages to return
- * @returns {Promise<{rawHistory: import("@prisma/client").embed_chats[], chatHistory: {role: string, content: string}[]}>
+ * @returns {Promise<{rawHistory: import("@prisma/client").embed_chats[], chatHistory: {role: string, content: string, attachments?: Object[]}[]}>
 */
async function recentEmbedChatHistory(sessionId, embed, messageLimit = 20) {
  const rawHistory = (
diff --git a/server/utils/helpers/chat/responses.js b/server/utils/helpers/chat/responses.js
index 9be1d224cce4a717af22cd7d52272bdafb19de3c..16a1e9af4392477e73c073d4d4c0d981f09f0f03 100644
--- a/server/utils/helpers/chat/responses.js
+++ b/server/utils/helpers/chat/responses.js
@@ -164,6 +164,11 @@ function convertToChatHistory(history = []) {
   return formattedHistory.flat();
 }
 
+/**
+ * Converts a chat history to a prompt history.
+ * @param {Object[]} history - The chat history to convert
+ * @returns {{role: string, content: string, attachments?: import("..").Attachment}[]}
+ */
 function convertToPromptHistory(history = []) {
   const formattedHistory = [];
   for (const record of history) {
@@ -185,8 +190,18 @@ function convertToPromptHistory(history = []) {
     }
 
     formattedHistory.push([
-      { role: "user", content: prompt },
-      { role: "assistant", content: data.text },
+      {
+        role: "user",
+        content: prompt,
+        // if there are attachments, add them as a property to the user message so we can reuse them in chat history later if supported by the llm.
+        ...(data?.attachments?.length > 0
+          ? { attachments: data?.attachments }
+          : {}),
+      },
+      {
+        role: "assistant",
+        content: data.text,
+      },
     ]);
   }
   return formattedHistory.flat();
@@ -197,10 +212,54 @@ function writeResponseChunk(response, data) {
   return;
 }
 
+/**
+ * Formats the chat history to re-use attachments in the chat history
+ * that might have existed in the conversation earlier.
+ * @param {{role:string, content:string, attachments?: Object[]}[]} chatHistory
+ * @param {function} formatterFunction - The function to format the chat history from the llm provider
+ * @param {('asProperty'|'spread')} mode - "asProperty" or "spread". Determines how the content is formatted in the message object.
+ * @returns {object[]}
+ */
+function formatChatHistory(
+  chatHistory = [],
+  formatterFunction,
+  mode = "asProperty"
+) {
+  return chatHistory.map((historicalMessage) => {
+    if (
+      historicalMessage?.role !== "user" || // Only user messages can have attachments
+      !historicalMessage?.attachments || // If there are no attachments, we can skip this
+      !historicalMessage.attachments.length // If there is an array but it is empty, we can skip this
+    )
+      return historicalMessage;
+
+    // Some providers, like Ollama, expect the content to be embedded in the message object.
+    if (mode === "spread") {
+      return {
+        role: historicalMessage.role,
+        ...formatterFunction({
+          userPrompt: historicalMessage.content,
+          attachments: historicalMessage.attachments,
+        }),
+      };
+    }
+
+    // Most providers expect the content to be a property of the message object formatted like OpenAI models.
+    return {
+      role: historicalMessage.role,
+      content: formatterFunction({
+        userPrompt: historicalMessage.content,
+        attachments: historicalMessage.attachments,
+      }),
+    };
+  });
+}
+
 module.exports = {
   handleDefaultStreamResponseV2,
   convertToChatHistory,
   convertToPromptHistory,
   writeResponseChunk,
   clientAbortedHandler,
+  formatChatHistory,
 };
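
For reference, a minimal sketch of how the new formatChatHistory helper behaves. The helper name, signature, and the "asProperty"/"spread" modes come from the patch above; the two formatter functions and the sample history below are hypothetical stand-ins for the providers' private #generateContent helpers, not part of the patch.

// Illustration only, not part of the patch. Assumes it runs from the repo root.
const { formatChatHistory } = require("./server/utils/helpers/chat/responses");

// Hypothetical OpenAI-style formatter: returns a content array of text + image parts.
function toOpenAiContent({ userPrompt, attachments = [] }) {
  return [
    { type: "text", text: userPrompt },
    ...attachments.map((attachment) => ({
      type: "image_url",
      image_url: { url: attachment.contentString, detail: "auto" },
    })),
  ];
}

// Hypothetical Ollama-style formatter: returns an object whose keys get spread
// onto the message when "spread" mode is used.
function toOllamaMessage({ userPrompt, attachments = [] }) {
  return {
    content: userPrompt,
    images: attachments.map((attachment) => attachment.contentString),
  };
}

// Hypothetical history as produced by convertToPromptHistory: only the user
// message carrying attachments is reformatted; all other messages pass through unchanged.
const history = [
  {
    role: "user",
    content: "What is in this image?",
    attachments: [{ contentString: "data:image/png;base64,..." }],
  },
  { role: "assistant", content: "A cat sitting on a desk." },
];

// Default "asProperty" mode: the formatter's return value becomes the message `content`.
console.log(formatChatHistory(history, toOpenAiContent));

// "spread" mode (used by the Ollama and Bedrock providers in this patch): the
// formatter's keys are merged into the message object alongside `role`.
console.log(formatChatHistory(history, toOllamaMessage, "spread"));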