From c4f75feb088abc734751e1b04b141d36203c7114 Mon Sep 17 00:00:00 2001 From: Timothy Carambat <rambat1010@gmail.com> Date: Thu, 16 Jan 2025 13:49:06 -0800 Subject: [PATCH] Support historical message image inputs/attachments for n+1 queries (#2919) * Support historical message image inputs/attachments for n+1 queries * patch gemini * OpenRouter vision support cleanup * xai vision history support * Mistral logging --------- Co-authored-by: shatfield4 <seanhatfield5@gmail.com> --- server/utils/AiProviders/anthropic/index.js | 3 +- server/utils/AiProviders/apipie/index.js | 3 +- server/utils/AiProviders/azureOpenAi/index.js | 3 +- server/utils/AiProviders/bedrock/index.js | 5 +- server/utils/AiProviders/gemini/index.js | 15 ++++- .../utils/AiProviders/genericOpenAi/index.js | 3 +- server/utils/AiProviders/groq/index.js | 2 + server/utils/AiProviders/koboldCPP/index.js | 3 +- server/utils/AiProviders/liteLLM/index.js | 3 +- server/utils/AiProviders/lmStudio/index.js | 3 +- server/utils/AiProviders/localAi/index.js | 3 +- server/utils/AiProviders/mistral/index.js | 8 ++- server/utils/AiProviders/novita/index.js | 3 +- server/utils/AiProviders/nvidiaNim/index.js | 3 +- server/utils/AiProviders/ollama/index.js | 3 +- server/utils/AiProviders/openAi/index.js | 3 +- server/utils/AiProviders/openRouter/index.js | 5 +- .../utils/AiProviders/textGenWebUI/index.js | 3 +- server/utils/AiProviders/xai/index.js | 17 ++--- server/utils/chats/embed.js | 2 +- server/utils/helpers/chat/responses.js | 63 ++++++++++++++++++- 21 files changed, 125 insertions(+), 31 deletions(-) diff --git a/server/utils/AiProviders/anthropic/index.js b/server/utils/AiProviders/anthropic/index.js index f9c4c91c7..0cd958959 100644 --- a/server/utils/AiProviders/anthropic/index.js +++ b/server/utils/AiProviders/anthropic/index.js @@ -2,6 +2,7 @@ const { v4 } = require("uuid"); const { writeResponseChunk, clientAbortedHandler, + formatChatHistory, } = require("../../helpers/chat/responses"); const { NativeEmbedder } = require("../../EmbeddingEngines/native"); const { MODEL_MAP } = require("../modelMap"); @@ -99,7 +100,7 @@ class AnthropicLLM { return [ prompt, - ...chatHistory, + ...formatChatHistory(chatHistory, this.#generateContent), { role: "user", content: this.#generateContent({ userPrompt, attachments }), diff --git a/server/utils/AiProviders/apipie/index.js b/server/utils/AiProviders/apipie/index.js index 1f6dd68a0..bd794d38e 100644 --- a/server/utils/AiProviders/apipie/index.js +++ b/server/utils/AiProviders/apipie/index.js @@ -3,6 +3,7 @@ const { v4: uuidv4 } = require("uuid"); const { writeResponseChunk, clientAbortedHandler, + formatChatHistory, } = require("../../helpers/chat/responses"); const fs = require("fs"); const path = require("path"); @@ -177,7 +178,7 @@ class ApiPieLLM { }; return [ prompt, - ...chatHistory, + ...formatChatHistory(chatHistory, this.#generateContent), { role: "user", content: this.#generateContent({ userPrompt, attachments }), diff --git a/server/utils/AiProviders/azureOpenAi/index.js b/server/utils/AiProviders/azureOpenAi/index.js index 078f55eef..cbf2c2ef3 100644 --- a/server/utils/AiProviders/azureOpenAi/index.js +++ b/server/utils/AiProviders/azureOpenAi/index.js @@ -5,6 +5,7 @@ const { const { writeResponseChunk, clientAbortedHandler, + formatChatHistory, } = require("../../helpers/chat/responses"); class AzureOpenAiLLM { @@ -103,7 +104,7 @@ class AzureOpenAiLLM { }; return [ prompt, - ...chatHistory, + ...formatChatHistory(chatHistory, this.#generateContent), { role: "user", content: 
this.#generateContent({ userPrompt, attachments }), diff --git a/server/utils/AiProviders/bedrock/index.js b/server/utils/AiProviders/bedrock/index.js index 171d7b459..d5f66eaea 100644 --- a/server/utils/AiProviders/bedrock/index.js +++ b/server/utils/AiProviders/bedrock/index.js @@ -2,6 +2,7 @@ const { StringOutputParser } = require("@langchain/core/output_parsers"); const { writeResponseChunk, clientAbortedHandler, + formatChatHistory, } = require("../../helpers/chat/responses"); const { NativeEmbedder } = require("../../EmbeddingEngines/native"); const { @@ -199,7 +200,7 @@ class AWSBedrockLLM { // AWS Mistral models do not support system prompts if (this.model.startsWith("mistral")) return [ - ...chatHistory, + ...formatChatHistory(chatHistory, this.#generateContent, "spread"), { role: "user", ...this.#generateContent({ userPrompt, attachments }), @@ -212,7 +213,7 @@ class AWSBedrockLLM { }; return [ prompt, - ...chatHistory, + ...formatChatHistory(chatHistory, this.#generateContent, "spread"), { role: "user", ...this.#generateContent({ userPrompt, attachments }), diff --git a/server/utils/AiProviders/gemini/index.js b/server/utils/AiProviders/gemini/index.js index 9961c70d7..fd7929f4b 100644 --- a/server/utils/AiProviders/gemini/index.js +++ b/server/utils/AiProviders/gemini/index.js @@ -7,6 +7,7 @@ const { const { writeResponseChunk, clientAbortedHandler, + formatChatHistory, } = require("../../helpers/chat/responses"); const { MODEL_MAP } = require("../modelMap"); const { defaultGeminiModels, v1BetaModels } = require("./defaultModels"); @@ -254,6 +255,7 @@ class GeminiLLM { const models = await this.fetchModels(process.env.GEMINI_API_KEY); return models.some((model) => model.id === modelName); } + /** * Generates appropriate content array for a message + attachments. * @param {{userPrompt:string, attachments: import("../../helpers").Attachment[]}} @@ -290,7 +292,7 @@ class GeminiLLM { return [ prompt, { role: "assistant", content: "Okay." }, - ...chatHistory, + ...formatChatHistory(chatHistory, this.#generateContent), { role: "USER_PROMPT", content: this.#generateContent({ userPrompt, attachments }), @@ -306,8 +308,17 @@ class GeminiLLM { .map((message) => { if (message.role === "system") return { role: "user", parts: [{ text: message.content }] }; - if (message.role === "user") + + if (message.role === "user") { + // If the content is an array - then we have already formatted the context so return it directly. 
+ if (Array.isArray(message.content)) + return { role: "user", parts: message.content }; + + // Otherwise, this was a regular user message with no attachments + // so we need to format it for Gemini return { role: "user", parts: [{ text: message.content }] }; + } + if (message.role === "assistant") return { role: "model", parts: [{ text: message.content }] }; return null; diff --git a/server/utils/AiProviders/genericOpenAi/index.js b/server/utils/AiProviders/genericOpenAi/index.js index 57c8f6a14..eb020298c 100644 --- a/server/utils/AiProviders/genericOpenAi/index.js +++ b/server/utils/AiProviders/genericOpenAi/index.js @@ -4,6 +4,7 @@ const { } = require("../../helpers/chat/LLMPerformanceMonitor"); const { handleDefaultStreamResponseV2, + formatChatHistory, } = require("../../helpers/chat/responses"); const { toValidNumber } = require("../../http"); @@ -133,7 +134,7 @@ class GenericOpenAiLLM { }; return [ prompt, - ...chatHistory, + ...formatChatHistory(chatHistory, this.#generateContent), { role: "user", content: this.#generateContent({ userPrompt, attachments }), diff --git a/server/utils/AiProviders/groq/index.js b/server/utils/AiProviders/groq/index.js index 5793002f6..9e7e77fa1 100644 --- a/server/utils/AiProviders/groq/index.js +++ b/server/utils/AiProviders/groq/index.js @@ -89,6 +89,8 @@ class GroqLLM { * Since we can only explicitly support the current models, this is a temporary solution. * If the attachments are empty or the model is not a vision model, we will return the default prompt structure which will work for all models. * If the attachments are present and the model is a vision model - we only return the user prompt with attachments - see comment at end of function for more. + * + * Historical attachments are also omitted from prompt chat history for the reasons above. 
(TDC: Dec 30, 2024) */ #conditionalPromptStruct({ systemPrompt = "", diff --git a/server/utils/AiProviders/koboldCPP/index.js b/server/utils/AiProviders/koboldCPP/index.js index 0e5206cab..5ee58b5bf 100644 --- a/server/utils/AiProviders/koboldCPP/index.js +++ b/server/utils/AiProviders/koboldCPP/index.js @@ -2,6 +2,7 @@ const { NativeEmbedder } = require("../../EmbeddingEngines/native"); const { clientAbortedHandler, writeResponseChunk, + formatChatHistory, } = require("../../helpers/chat/responses"); const { LLMPerformanceMonitor, @@ -116,7 +117,7 @@ class KoboldCPPLLM { }; return [ prompt, - ...chatHistory, + ...formatChatHistory(chatHistory, this.#generateContent), { role: "user", content: this.#generateContent({ userPrompt, attachments }), diff --git a/server/utils/AiProviders/liteLLM/index.js b/server/utils/AiProviders/liteLLM/index.js index 63f4115bc..2017d7774 100644 --- a/server/utils/AiProviders/liteLLM/index.js +++ b/server/utils/AiProviders/liteLLM/index.js @@ -4,6 +4,7 @@ const { } = require("../../helpers/chat/LLMPerformanceMonitor"); const { handleDefaultStreamResponseV2, + formatChatHistory, } = require("../../helpers/chat/responses"); class LiteLLM { @@ -115,7 +116,7 @@ class LiteLLM { }; return [ prompt, - ...chatHistory, + ...formatChatHistory(chatHistory, this.#generateContent), { role: "user", content: this.#generateContent({ userPrompt, attachments }), diff --git a/server/utils/AiProviders/lmStudio/index.js b/server/utils/AiProviders/lmStudio/index.js index 082576c98..bde9ed486 100644 --- a/server/utils/AiProviders/lmStudio/index.js +++ b/server/utils/AiProviders/lmStudio/index.js @@ -1,6 +1,7 @@ const { NativeEmbedder } = require("../../EmbeddingEngines/native"); const { handleDefaultStreamResponseV2, + formatChatHistory, } = require("../../helpers/chat/responses"); const { LLMPerformanceMonitor, @@ -117,7 +118,7 @@ class LMStudioLLM { }; return [ prompt, - ...chatHistory, + ...formatChatHistory(chatHistory, this.#generateContent), { role: "user", content: this.#generateContent({ userPrompt, attachments }), diff --git a/server/utils/AiProviders/localAi/index.js b/server/utils/AiProviders/localAi/index.js index 53da280f2..f62fe70dd 100644 --- a/server/utils/AiProviders/localAi/index.js +++ b/server/utils/AiProviders/localAi/index.js @@ -4,6 +4,7 @@ const { } = require("../../helpers/chat/LLMPerformanceMonitor"); const { handleDefaultStreamResponseV2, + formatChatHistory, } = require("../../helpers/chat/responses"); class LocalAiLLM { @@ -103,7 +104,7 @@ class LocalAiLLM { }; return [ prompt, - ...chatHistory, + ...formatChatHistory(chatHistory, this.#generateContent), { role: "user", content: this.#generateContent({ userPrompt, attachments }), diff --git a/server/utils/AiProviders/mistral/index.js b/server/utils/AiProviders/mistral/index.js index 219f6f52f..6c637857b 100644 --- a/server/utils/AiProviders/mistral/index.js +++ b/server/utils/AiProviders/mistral/index.js @@ -4,6 +4,7 @@ const { } = require("../../helpers/chat/LLMPerformanceMonitor"); const { handleDefaultStreamResponseV2, + formatChatHistory, } = require("../../helpers/chat/responses"); class MistralLLM { @@ -26,6 +27,11 @@ class MistralLLM { this.embedder = embedder ?? 
new NativeEmbedder(); this.defaultTemp = 0.0; + this.log("Initialized with model:", this.model); + } + + log(text, ...args) { + console.log(`\x1b[36m[${this.constructor.name}]\x1b[0m ${text}`, ...args); } #appendContext(contextTexts = []) { @@ -92,7 +98,7 @@ class MistralLLM { }; return [ prompt, - ...chatHistory, + ...formatChatHistory(chatHistory, this.#generateContent), { role: "user", content: this.#generateContent({ userPrompt, attachments }), diff --git a/server/utils/AiProviders/novita/index.js b/server/utils/AiProviders/novita/index.js index c41f5a666..8365d2882 100644 --- a/server/utils/AiProviders/novita/index.js +++ b/server/utils/AiProviders/novita/index.js @@ -3,6 +3,7 @@ const { v4: uuidv4 } = require("uuid"); const { writeResponseChunk, clientAbortedHandler, + formatChatHistory, } = require("../../helpers/chat/responses"); const fs = require("fs"); const path = require("path"); @@ -177,7 +178,7 @@ class NovitaLLM { }; return [ prompt, - ...chatHistory, + ...formatChatHistory(chatHistory, this.#generateContent), { role: "user", content: this.#generateContent({ userPrompt, attachments }), diff --git a/server/utils/AiProviders/nvidiaNim/index.js b/server/utils/AiProviders/nvidiaNim/index.js index 3cf7f835f..4de408e98 100644 --- a/server/utils/AiProviders/nvidiaNim/index.js +++ b/server/utils/AiProviders/nvidiaNim/index.js @@ -4,6 +4,7 @@ const { } = require("../../helpers/chat/LLMPerformanceMonitor"); const { handleDefaultStreamResponseV2, + formatChatHistory, } = require("../../helpers/chat/responses"); class NvidiaNimLLM { @@ -142,7 +143,7 @@ class NvidiaNimLLM { }; return [ prompt, - ...chatHistory, + ...formatChatHistory(chatHistory, this.#generateContent), { role: "user", content: this.#generateContent({ userPrompt, attachments }), diff --git a/server/utils/AiProviders/ollama/index.js b/server/utils/AiProviders/ollama/index.js index faffcb234..5c53dd5f4 100644 --- a/server/utils/AiProviders/ollama/index.js +++ b/server/utils/AiProviders/ollama/index.js @@ -1,6 +1,7 @@ const { writeResponseChunk, clientAbortedHandler, + formatChatHistory, } = require("../../helpers/chat/responses"); const { NativeEmbedder } = require("../../EmbeddingEngines/native"); const { @@ -120,7 +121,7 @@ class OllamaAILLM { }; return [ prompt, - ...chatHistory, + ...formatChatHistory(chatHistory, this.#generateContent, "spread"), { role: "user", ...this.#generateContent({ userPrompt, attachments }), diff --git a/server/utils/AiProviders/openAi/index.js b/server/utils/AiProviders/openAi/index.js index 4209b99ed..71a6a0edf 100644 --- a/server/utils/AiProviders/openAi/index.js +++ b/server/utils/AiProviders/openAi/index.js @@ -1,6 +1,7 @@ const { NativeEmbedder } = require("../../EmbeddingEngines/native"); const { handleDefaultStreamResponseV2, + formatChatHistory, } = require("../../helpers/chat/responses"); const { MODEL_MAP } = require("../modelMap"); const { @@ -121,7 +122,7 @@ class OpenAiLLM { }; return [ prompt, - ...chatHistory, + ...formatChatHistory(chatHistory, this.#generateContent), { role: "user", content: this.#generateContent({ userPrompt, attachments }), diff --git a/server/utils/AiProviders/openRouter/index.js b/server/utils/AiProviders/openRouter/index.js index 3abab7634..08f040150 100644 --- a/server/utils/AiProviders/openRouter/index.js +++ b/server/utils/AiProviders/openRouter/index.js @@ -3,6 +3,7 @@ const { v4: uuidv4 } = require("uuid"); const { writeResponseChunk, clientAbortedHandler, + formatChatHistory, } = require("../../helpers/chat/responses"); const fs = require("fs"); 
const path = require("path"); @@ -47,6 +48,7 @@ class OpenRouterLLM { fs.mkdirSync(cacheFolder, { recursive: true }); this.cacheModelPath = path.resolve(cacheFolder, "models.json"); this.cacheAtPath = path.resolve(cacheFolder, ".cached_at"); + this.log("Initialized with model:", this.model); } log(text, ...args) { @@ -162,7 +164,6 @@ class OpenRouterLLM { }, }); } - console.log(content.flat()); return content.flat(); } @@ -179,7 +180,7 @@ class OpenRouterLLM { }; return [ prompt, - ...chatHistory, + ...formatChatHistory(chatHistory, this.#generateContent), { role: "user", content: this.#generateContent({ userPrompt, attachments }), diff --git a/server/utils/AiProviders/textGenWebUI/index.js b/server/utils/AiProviders/textGenWebUI/index.js index f1c3590bf..f3647c06d 100644 --- a/server/utils/AiProviders/textGenWebUI/index.js +++ b/server/utils/AiProviders/textGenWebUI/index.js @@ -1,6 +1,7 @@ const { NativeEmbedder } = require("../../EmbeddingEngines/native"); const { handleDefaultStreamResponseV2, + formatChatHistory, } = require("../../helpers/chat/responses"); const { LLMPerformanceMonitor, @@ -113,7 +114,7 @@ class TextGenWebUILLM { }; return [ prompt, - ...chatHistory, + ...formatChatHistory(chatHistory, this.#generateContent), { role: "user", content: this.#generateContent({ userPrompt, attachments }), diff --git a/server/utils/AiProviders/xai/index.js b/server/utils/AiProviders/xai/index.js index b18aae98c..2319e7220 100644 --- a/server/utils/AiProviders/xai/index.js +++ b/server/utils/AiProviders/xai/index.js @@ -4,6 +4,7 @@ const { } = require("../../helpers/chat/LLMPerformanceMonitor"); const { handleDefaultStreamResponseV2, + formatChatHistory, } = require("../../helpers/chat/responses"); const { MODEL_MAP } = require("../modelMap"); @@ -27,6 +28,11 @@ class XAiLLM { this.embedder = embedder ?? new NativeEmbedder(); this.defaultTemp = 0.7; + this.log("Initialized with model:", this.model); + } + + log(text, ...args) { + console.log(`\x1b[36m[${this.constructor.name}]\x1b[0m ${text}`, ...args); } #appendContext(contextTexts = []) { @@ -53,13 +59,8 @@ class XAiLLM { return MODEL_MAP.xai[this.model] ?? 
131_072; } - isValidChatCompletionModel(modelName = "") { - switch (modelName) { - case "grok-beta": - return true; - default: - return false; - } + isValidChatCompletionModel(_modelName = "") { + return true; } /** @@ -103,7 +104,7 @@ class XAiLLM { }; return [ prompt, - ...chatHistory, + ...formatChatHistory(chatHistory, this.#generateContent), { role: "user", content: this.#generateContent({ userPrompt, attachments }), diff --git a/server/utils/chats/embed.js b/server/utils/chats/embed.js index 550a460f8..70904a541 100644 --- a/server/utils/chats/embed.js +++ b/server/utils/chats/embed.js @@ -210,7 +210,7 @@ async function streamChatWithForEmbed( * @param {string} sessionId the session id of the user from embed widget * @param {Object} embed the embed config object * @param {Number} messageLimit the number of messages to return - * @returns {Promise<{rawHistory: import("@prisma/client").embed_chats[], chatHistory: {role: string, content: string}[]}> + * @returns {Promise<{rawHistory: import("@prisma/client").embed_chats[], chatHistory: {role: string, content: string, attachments?: Object[]}[]}> */ async function recentEmbedChatHistory(sessionId, embed, messageLimit = 20) { const rawHistory = ( diff --git a/server/utils/helpers/chat/responses.js b/server/utils/helpers/chat/responses.js index 9be1d224c..16a1e9af4 100644 --- a/server/utils/helpers/chat/responses.js +++ b/server/utils/helpers/chat/responses.js @@ -164,6 +164,11 @@ function convertToChatHistory(history = []) { return formattedHistory.flat(); } +/** + * Converts a chat history to a prompt history. + * @param {Object[]} history - The chat history to convert + * @returns {{role: string, content: string, attachments?: import("..").Attachment}[]} + */ function convertToPromptHistory(history = []) { const formattedHistory = []; for (const record of history) { @@ -185,8 +190,18 @@ function convertToPromptHistory(history = []) { } formattedHistory.push([ - { role: "user", content: prompt }, - { role: "assistant", content: data.text }, + { + role: "user", + content: prompt, + // if there are attachments, add them as a property to the user message so we can reuse them in chat history later if supported by the llm. + ...(data?.attachments?.length > 0 + ? { attachments: data?.attachments } + : {}), + }, + { + role: "assistant", + content: data.text, + }, ]); } return formattedHistory.flat(); @@ -197,10 +212,54 @@ function writeResponseChunk(response, data) { return; } +/** + * Formats the chat history to re-use attachments in the chat history + * that might have existed in the conversation earlier. + * @param {{role:string, content:string, attachments?: Object[]}[]} chatHistory + * @param {function} formatterFunction - The function to format the chat history from the llm provider + * @param {('asProperty'|'spread')} mode - "asProperty" or "spread". Determines how the content is formatted in the message object. + * @returns {object[]} + */ +function formatChatHistory( + chatHistory = [], + formatterFunction, + mode = "asProperty" +) { + return chatHistory.map((historicalMessage) => { + if ( + historicalMessage?.role !== "user" || // Only user messages can have attachments + !historicalMessage?.attachments || // If there are no attachments, we can skip this + !historicalMessage.attachments.length // If there is an array but it is empty, we can skip this + ) + return historicalMessage; + + // Some providers, like Ollama, expect the content to be embedded in the message object. 
+ if (mode === "spread") { + return { + role: historicalMessage.role, + ...formatterFunction({ + userPrompt: historicalMessage.content, + attachments: historicalMessage.attachments, + }), + }; + } + + // Most providers expect the content to be a property of the message object formatted like OpenAI models. + return { + role: historicalMessage.role, + content: formatterFunction({ + userPrompt: historicalMessage.content, + attachments: historicalMessage.attachments, + }), + }; + }); +} + module.exports = { handleDefaultStreamResponseV2, convertToChatHistory, convertToPromptHistory, writeResponseChunk, clientAbortedHandler, + formatChatHistory, }; -- GitLab