From c4f75feb088abc734751e1b04b141d36203c7114 Mon Sep 17 00:00:00 2001
From: Timothy Carambat <rambat1010@gmail.com>
Date: Thu, 16 Jan 2025 13:49:06 -0800
Subject: [PATCH] Support historical message image inputs/attachments for n+1
 queries (#2919)

* Support historical message image inputs/attachments for n+1 queries

* Patch Gemini

* OpenRouter vision support cleanup

* xAI vision history support

* Mistral logging
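
Illustrative sketch (not part of the patch) of how the new formatChatHistory helper re-applies a provider's content formatter to earlier user turns that carried attachments. The standalone generateContent function and the attachment field name `contentString` below are assumptions standing in for each provider's private #generateContent; the real formatters live in server/utils/AiProviders/*/index.js.

    // Stand-in for a provider's #generateContent. Field names such as
    // `contentString` are assumed here for illustration only.
    function generateContent({ userPrompt, attachments = [] }) {
      if (!attachments.length) return userPrompt;
      return [
        { type: "text", text: userPrompt },
        ...attachments.map((attachment) => ({
          type: "image_url",
          image_url: { url: attachment.contentString },
        })),
      ];
    }

    // Same shape as the helper added in responses.js: user turns that carry
    // attachments are re-formatted, everything else passes through untouched.
    function formatChatHistory(chatHistory = [], formatterFunction, mode = "asProperty") {
      return chatHistory.map((message) => {
        if (message?.role !== "user" || !message?.attachments?.length)
          return message;
        const formatted = formatterFunction({
          userPrompt: message.content,
          attachments: message.attachments,
        });
        return mode === "spread"
          ? { role: message.role, ...formatted } // e.g. Ollama, Bedrock
          : { role: message.role, content: formatted }; // OpenAI-style providers
      });
    }

    // A prior user turn with an image is expanded again on the next (n+1) query.
    const history = [
      {
        role: "user",
        content: "What is in this image?",
        attachments: [{ contentString: "data:image/png;base64,..." }],
      },
      { role: "assistant", content: "A cat sitting on a couch." },
    ];
    console.log(JSON.stringify(formatChatHistory(history, generateContent), null, 2));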

---------

Co-authored-by: shatfield4 <seanhatfield5@gmail.com>
---
 server/utils/AiProviders/anthropic/index.js   |  3 +-
 server/utils/AiProviders/apipie/index.js      |  3 +-
 server/utils/AiProviders/azureOpenAi/index.js |  3 +-
 server/utils/AiProviders/bedrock/index.js     |  5 +-
 server/utils/AiProviders/gemini/index.js      | 15 ++++-
 .../utils/AiProviders/genericOpenAi/index.js  |  3 +-
 server/utils/AiProviders/groq/index.js        |  2 +
 server/utils/AiProviders/koboldCPP/index.js   |  3 +-
 server/utils/AiProviders/liteLLM/index.js     |  3 +-
 server/utils/AiProviders/lmStudio/index.js    |  3 +-
 server/utils/AiProviders/localAi/index.js     |  3 +-
 server/utils/AiProviders/mistral/index.js     |  8 ++-
 server/utils/AiProviders/novita/index.js      |  3 +-
 server/utils/AiProviders/nvidiaNim/index.js   |  3 +-
 server/utils/AiProviders/ollama/index.js      |  3 +-
 server/utils/AiProviders/openAi/index.js      |  3 +-
 server/utils/AiProviders/openRouter/index.js  |  5 +-
 .../utils/AiProviders/textGenWebUI/index.js   |  3 +-
 server/utils/AiProviders/xai/index.js         | 17 ++---
 server/utils/chats/embed.js                   |  2 +-
 server/utils/helpers/chat/responses.js        | 63 ++++++++++++++++++-
 21 files changed, 125 insertions(+), 31 deletions(-)

diff --git a/server/utils/AiProviders/anthropic/index.js b/server/utils/AiProviders/anthropic/index.js
index f9c4c91c7..0cd958959 100644
--- a/server/utils/AiProviders/anthropic/index.js
+++ b/server/utils/AiProviders/anthropic/index.js
@@ -2,6 +2,7 @@ const { v4 } = require("uuid");
 const {
   writeResponseChunk,
   clientAbortedHandler,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 const { NativeEmbedder } = require("../../EmbeddingEngines/native");
 const { MODEL_MAP } = require("../modelMap");
@@ -99,7 +100,7 @@ class AnthropicLLM {
 
     return [
       prompt,
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent),
       {
         role: "user",
         content: this.#generateContent({ userPrompt, attachments }),
diff --git a/server/utils/AiProviders/apipie/index.js b/server/utils/AiProviders/apipie/index.js
index 1f6dd68a0..bd794d38e 100644
--- a/server/utils/AiProviders/apipie/index.js
+++ b/server/utils/AiProviders/apipie/index.js
@@ -3,6 +3,7 @@ const { v4: uuidv4 } = require("uuid");
 const {
   writeResponseChunk,
   clientAbortedHandler,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 const fs = require("fs");
 const path = require("path");
@@ -177,7 +178,7 @@ class ApiPieLLM {
     };
     return [
       prompt,
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent),
       {
         role: "user",
         content: this.#generateContent({ userPrompt, attachments }),
diff --git a/server/utils/AiProviders/azureOpenAi/index.js b/server/utils/AiProviders/azureOpenAi/index.js
index 078f55eef..cbf2c2ef3 100644
--- a/server/utils/AiProviders/azureOpenAi/index.js
+++ b/server/utils/AiProviders/azureOpenAi/index.js
@@ -5,6 +5,7 @@ const {
 const {
   writeResponseChunk,
   clientAbortedHandler,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 
 class AzureOpenAiLLM {
@@ -103,7 +104,7 @@ class AzureOpenAiLLM {
     };
     return [
       prompt,
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent),
       {
         role: "user",
         content: this.#generateContent({ userPrompt, attachments }),
diff --git a/server/utils/AiProviders/bedrock/index.js b/server/utils/AiProviders/bedrock/index.js
index 171d7b459..d5f66eaea 100644
--- a/server/utils/AiProviders/bedrock/index.js
+++ b/server/utils/AiProviders/bedrock/index.js
@@ -2,6 +2,7 @@ const { StringOutputParser } = require("@langchain/core/output_parsers");
 const {
   writeResponseChunk,
   clientAbortedHandler,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 const { NativeEmbedder } = require("../../EmbeddingEngines/native");
 const {
@@ -199,7 +200,7 @@ class AWSBedrockLLM {
     // AWS Mistral models do not support system prompts
     if (this.model.startsWith("mistral"))
       return [
-        ...chatHistory,
+        ...formatChatHistory(chatHistory, this.#generateContent, "spread"),
         {
           role: "user",
           ...this.#generateContent({ userPrompt, attachments }),
@@ -212,7 +213,7 @@ class AWSBedrockLLM {
     };
     return [
       prompt,
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent, "spread"),
       {
         role: "user",
         ...this.#generateContent({ userPrompt, attachments }),
diff --git a/server/utils/AiProviders/gemini/index.js b/server/utils/AiProviders/gemini/index.js
index 9961c70d7..fd7929f4b 100644
--- a/server/utils/AiProviders/gemini/index.js
+++ b/server/utils/AiProviders/gemini/index.js
@@ -7,6 +7,7 @@ const {
 const {
   writeResponseChunk,
   clientAbortedHandler,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 const { MODEL_MAP } = require("../modelMap");
 const { defaultGeminiModels, v1BetaModels } = require("./defaultModels");
@@ -254,6 +255,7 @@ class GeminiLLM {
     const models = await this.fetchModels(process.env.GEMINI_API_KEY);
     return models.some((model) => model.id === modelName);
   }
+
   /**
    * Generates appropriate content array for a message + attachments.
    * @param {{userPrompt:string, attachments: import("../../helpers").Attachment[]}}
@@ -290,7 +292,7 @@ class GeminiLLM {
     return [
       prompt,
       { role: "assistant", content: "Okay." },
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent),
       {
         role: "USER_PROMPT",
         content: this.#generateContent({ userPrompt, attachments }),
@@ -306,8 +308,17 @@ class GeminiLLM {
       .map((message) => {
         if (message.role === "system")
           return { role: "user", parts: [{ text: message.content }] };
-        if (message.role === "user")
+
+        if (message.role === "user") {
+          // If the content is already an array, it was pre-formatted (e.g. it includes attachments), so return it directly.
+          if (Array.isArray(message.content))
+            return { role: "user", parts: message.content };
+
+          // Otherwise, this was a regular user message with no attachments
+          // so we need to format it for Gemini
           return { role: "user", parts: [{ text: message.content }] };
+        }
+
         if (message.role === "assistant")
           return { role: "model", parts: [{ text: message.content }] };
         return null;
diff --git a/server/utils/AiProviders/genericOpenAi/index.js b/server/utils/AiProviders/genericOpenAi/index.js
index 57c8f6a14..eb020298c 100644
--- a/server/utils/AiProviders/genericOpenAi/index.js
+++ b/server/utils/AiProviders/genericOpenAi/index.js
@@ -4,6 +4,7 @@ const {
 } = require("../../helpers/chat/LLMPerformanceMonitor");
 const {
   handleDefaultStreamResponseV2,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 const { toValidNumber } = require("../../http");
 
@@ -133,7 +134,7 @@ class GenericOpenAiLLM {
     };
     return [
       prompt,
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent),
       {
         role: "user",
         content: this.#generateContent({ userPrompt, attachments }),
diff --git a/server/utils/AiProviders/groq/index.js b/server/utils/AiProviders/groq/index.js
index 5793002f6..9e7e77fa1 100644
--- a/server/utils/AiProviders/groq/index.js
+++ b/server/utils/AiProviders/groq/index.js
@@ -89,6 +89,8 @@ class GroqLLM {
    * Since we can only explicitly support the current models, this is a temporary solution.
    * If the attachments are empty or the model is not a vision model, we will return the default prompt structure which will work for all models.
    * If the attachments are present and the model is a vision model - we only return the user prompt with attachments - see comment at end of function for more.
+   *
+   * Historical attachments are also omitted from prompt chat history for the reasons above. (TDC: Dec 30, 2024)
    */
   #conditionalPromptStruct({
     systemPrompt = "",
diff --git a/server/utils/AiProviders/koboldCPP/index.js b/server/utils/AiProviders/koboldCPP/index.js
index 0e5206cab..5ee58b5bf 100644
--- a/server/utils/AiProviders/koboldCPP/index.js
+++ b/server/utils/AiProviders/koboldCPP/index.js
@@ -2,6 +2,7 @@ const { NativeEmbedder } = require("../../EmbeddingEngines/native");
 const {
   clientAbortedHandler,
   writeResponseChunk,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 const {
   LLMPerformanceMonitor,
@@ -116,7 +117,7 @@ class KoboldCPPLLM {
     };
     return [
       prompt,
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent),
       {
         role: "user",
         content: this.#generateContent({ userPrompt, attachments }),
diff --git a/server/utils/AiProviders/liteLLM/index.js b/server/utils/AiProviders/liteLLM/index.js
index 63f4115bc..2017d7774 100644
--- a/server/utils/AiProviders/liteLLM/index.js
+++ b/server/utils/AiProviders/liteLLM/index.js
@@ -4,6 +4,7 @@ const {
 } = require("../../helpers/chat/LLMPerformanceMonitor");
 const {
   handleDefaultStreamResponseV2,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 
 class LiteLLM {
@@ -115,7 +116,7 @@ class LiteLLM {
     };
     return [
       prompt,
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent),
       {
         role: "user",
         content: this.#generateContent({ userPrompt, attachments }),
diff --git a/server/utils/AiProviders/lmStudio/index.js b/server/utils/AiProviders/lmStudio/index.js
index 082576c98..bde9ed486 100644
--- a/server/utils/AiProviders/lmStudio/index.js
+++ b/server/utils/AiProviders/lmStudio/index.js
@@ -1,6 +1,7 @@
 const { NativeEmbedder } = require("../../EmbeddingEngines/native");
 const {
   handleDefaultStreamResponseV2,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 const {
   LLMPerformanceMonitor,
@@ -117,7 +118,7 @@ class LMStudioLLM {
     };
     return [
       prompt,
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent),
       {
         role: "user",
         content: this.#generateContent({ userPrompt, attachments }),
diff --git a/server/utils/AiProviders/localAi/index.js b/server/utils/AiProviders/localAi/index.js
index 53da280f2..f62fe70dd 100644
--- a/server/utils/AiProviders/localAi/index.js
+++ b/server/utils/AiProviders/localAi/index.js
@@ -4,6 +4,7 @@ const {
 } = require("../../helpers/chat/LLMPerformanceMonitor");
 const {
   handleDefaultStreamResponseV2,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 
 class LocalAiLLM {
@@ -103,7 +104,7 @@ class LocalAiLLM {
     };
     return [
       prompt,
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent),
       {
         role: "user",
         content: this.#generateContent({ userPrompt, attachments }),
diff --git a/server/utils/AiProviders/mistral/index.js b/server/utils/AiProviders/mistral/index.js
index 219f6f52f..6c637857b 100644
--- a/server/utils/AiProviders/mistral/index.js
+++ b/server/utils/AiProviders/mistral/index.js
@@ -4,6 +4,7 @@ const {
 } = require("../../helpers/chat/LLMPerformanceMonitor");
 const {
   handleDefaultStreamResponseV2,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 
 class MistralLLM {
@@ -26,6 +27,11 @@ class MistralLLM {
 
     this.embedder = embedder ?? new NativeEmbedder();
     this.defaultTemp = 0.0;
+    this.log("Initialized with model:", this.model);
+  }
+
+  log(text, ...args) {
+    console.log(`\x1b[36m[${this.constructor.name}]\x1b[0m ${text}`, ...args);
   }
 
   #appendContext(contextTexts = []) {
@@ -92,7 +98,7 @@ class MistralLLM {
     };
     return [
       prompt,
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent),
       {
         role: "user",
         content: this.#generateContent({ userPrompt, attachments }),
diff --git a/server/utils/AiProviders/novita/index.js b/server/utils/AiProviders/novita/index.js
index c41f5a666..8365d2882 100644
--- a/server/utils/AiProviders/novita/index.js
+++ b/server/utils/AiProviders/novita/index.js
@@ -3,6 +3,7 @@ const { v4: uuidv4 } = require("uuid");
 const {
   writeResponseChunk,
   clientAbortedHandler,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 const fs = require("fs");
 const path = require("path");
@@ -177,7 +178,7 @@ class NovitaLLM {
     };
     return [
       prompt,
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent),
       {
         role: "user",
         content: this.#generateContent({ userPrompt, attachments }),
diff --git a/server/utils/AiProviders/nvidiaNim/index.js b/server/utils/AiProviders/nvidiaNim/index.js
index 3cf7f835f..4de408e98 100644
--- a/server/utils/AiProviders/nvidiaNim/index.js
+++ b/server/utils/AiProviders/nvidiaNim/index.js
@@ -4,6 +4,7 @@ const {
 } = require("../../helpers/chat/LLMPerformanceMonitor");
 const {
   handleDefaultStreamResponseV2,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 
 class NvidiaNimLLM {
@@ -142,7 +143,7 @@ class NvidiaNimLLM {
     };
     return [
       prompt,
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent),
       {
         role: "user",
         content: this.#generateContent({ userPrompt, attachments }),
diff --git a/server/utils/AiProviders/ollama/index.js b/server/utils/AiProviders/ollama/index.js
index faffcb234..5c53dd5f4 100644
--- a/server/utils/AiProviders/ollama/index.js
+++ b/server/utils/AiProviders/ollama/index.js
@@ -1,6 +1,7 @@
 const {
   writeResponseChunk,
   clientAbortedHandler,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 const { NativeEmbedder } = require("../../EmbeddingEngines/native");
 const {
@@ -120,7 +121,7 @@ class OllamaAILLM {
     };
     return [
       prompt,
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent, "spread"),
       {
         role: "user",
         ...this.#generateContent({ userPrompt, attachments }),
diff --git a/server/utils/AiProviders/openAi/index.js b/server/utils/AiProviders/openAi/index.js
index 4209b99ed..71a6a0edf 100644
--- a/server/utils/AiProviders/openAi/index.js
+++ b/server/utils/AiProviders/openAi/index.js
@@ -1,6 +1,7 @@
 const { NativeEmbedder } = require("../../EmbeddingEngines/native");
 const {
   handleDefaultStreamResponseV2,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 const { MODEL_MAP } = require("../modelMap");
 const {
@@ -121,7 +122,7 @@ class OpenAiLLM {
     };
     return [
       prompt,
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent),
       {
         role: "user",
         content: this.#generateContent({ userPrompt, attachments }),
diff --git a/server/utils/AiProviders/openRouter/index.js b/server/utils/AiProviders/openRouter/index.js
index 3abab7634..08f040150 100644
--- a/server/utils/AiProviders/openRouter/index.js
+++ b/server/utils/AiProviders/openRouter/index.js
@@ -3,6 +3,7 @@ const { v4: uuidv4 } = require("uuid");
 const {
   writeResponseChunk,
   clientAbortedHandler,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 const fs = require("fs");
 const path = require("path");
@@ -47,6 +48,7 @@ class OpenRouterLLM {
       fs.mkdirSync(cacheFolder, { recursive: true });
     this.cacheModelPath = path.resolve(cacheFolder, "models.json");
     this.cacheAtPath = path.resolve(cacheFolder, ".cached_at");
+    this.log("Initialized with model:", this.model);
   }
 
   log(text, ...args) {
@@ -162,7 +164,6 @@ class OpenRouterLLM {
         },
       });
     }
-    console.log(content.flat());
     return content.flat();
   }
 
@@ -179,7 +180,7 @@ class OpenRouterLLM {
     };
     return [
       prompt,
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent),
       {
         role: "user",
         content: this.#generateContent({ userPrompt, attachments }),
diff --git a/server/utils/AiProviders/textGenWebUI/index.js b/server/utils/AiProviders/textGenWebUI/index.js
index f1c3590bf..f3647c06d 100644
--- a/server/utils/AiProviders/textGenWebUI/index.js
+++ b/server/utils/AiProviders/textGenWebUI/index.js
@@ -1,6 +1,7 @@
 const { NativeEmbedder } = require("../../EmbeddingEngines/native");
 const {
   handleDefaultStreamResponseV2,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 const {
   LLMPerformanceMonitor,
@@ -113,7 +114,7 @@ class TextGenWebUILLM {
     };
     return [
       prompt,
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent),
       {
         role: "user",
         content: this.#generateContent({ userPrompt, attachments }),
diff --git a/server/utils/AiProviders/xai/index.js b/server/utils/AiProviders/xai/index.js
index b18aae98c..2319e7220 100644
--- a/server/utils/AiProviders/xai/index.js
+++ b/server/utils/AiProviders/xai/index.js
@@ -4,6 +4,7 @@ const {
 } = require("../../helpers/chat/LLMPerformanceMonitor");
 const {
   handleDefaultStreamResponseV2,
+  formatChatHistory,
 } = require("../../helpers/chat/responses");
 const { MODEL_MAP } = require("../modelMap");
 
@@ -27,6 +28,11 @@ class XAiLLM {
 
     this.embedder = embedder ?? new NativeEmbedder();
     this.defaultTemp = 0.7;
+    this.log("Initialized with model:", this.model);
+  }
+
+  log(text, ...args) {
+    console.log(`\x1b[36m[${this.constructor.name}]\x1b[0m ${text}`, ...args);
   }
 
   #appendContext(contextTexts = []) {
@@ -53,13 +59,8 @@ class XAiLLM {
     return MODEL_MAP.xai[this.model] ?? 131_072;
   }
 
-  isValidChatCompletionModel(modelName = "") {
-    switch (modelName) {
-      case "grok-beta":
-        return true;
-      default:
-        return false;
-    }
+  isValidChatCompletionModel(_modelName = "") {
+    return true;
   }
 
   /**
@@ -103,7 +104,7 @@ class XAiLLM {
     };
     return [
       prompt,
-      ...chatHistory,
+      ...formatChatHistory(chatHistory, this.#generateContent),
       {
         role: "user",
         content: this.#generateContent({ userPrompt, attachments }),
diff --git a/server/utils/chats/embed.js b/server/utils/chats/embed.js
index 550a460f8..70904a541 100644
--- a/server/utils/chats/embed.js
+++ b/server/utils/chats/embed.js
@@ -210,7 +210,7 @@ async function streamChatWithForEmbed(
  * @param {string} sessionId the session id of the user from embed widget
  * @param {Object} embed the embed config object
  * @param {Number} messageLimit the number of messages to return
- * @returns {Promise<{rawHistory: import("@prisma/client").embed_chats[], chatHistory: {role: string, content: string}[]}>
+ * @returns {Promise<{rawHistory: import("@prisma/client").embed_chats[], chatHistory: {role: string, content: string, attachments?: Object[]}[]}>}
  */
 async function recentEmbedChatHistory(sessionId, embed, messageLimit = 20) {
   const rawHistory = (
diff --git a/server/utils/helpers/chat/responses.js b/server/utils/helpers/chat/responses.js
index 9be1d224c..16a1e9af4 100644
--- a/server/utils/helpers/chat/responses.js
+++ b/server/utils/helpers/chat/responses.js
@@ -164,6 +164,11 @@ function convertToChatHistory(history = []) {
   return formattedHistory.flat();
 }
 
+/**
+ * Converts a chat history to a prompt history.
+ * @param {Object[]} history - The chat history to convert
+ * @returns {{role: string, content: string, attachments?: import("..").Attachment[]}[]}
+ */
 function convertToPromptHistory(history = []) {
   const formattedHistory = [];
   for (const record of history) {
@@ -185,8 +190,18 @@ function convertToPromptHistory(history = []) {
     }
 
     formattedHistory.push([
-      { role: "user", content: prompt },
-      { role: "assistant", content: data.text },
+      {
+        role: "user",
+        content: prompt,
+        // If there are attachments, add them as a property on the user message so they can be reused in chat history later (when the LLM supports them).
+        ...(data?.attachments?.length > 0
+          ? { attachments: data?.attachments }
+          : {}),
+      },
+      {
+        role: "assistant",
+        content: data.text,
+      },
     ]);
   }
   return formattedHistory.flat();
@@ -197,10 +212,54 @@ function writeResponseChunk(response, data) {
   return;
 }
 
+/**
+ * Re-formats the chat history so that attachments from earlier turns in the
+ * conversation are re-included in the prompt.
+ * @param {{role:string, content:string, attachments?: Object[]}[]} chatHistory
+ * @param {function} formatterFunction - The LLM provider's content formatter (its #generateContent method)
+ * @param {('asProperty'|'spread')} mode - Determines whether the formatted content is spread into the message object ("spread") or assigned to its `content` property ("asProperty").
+ * @returns {object[]}
+ */
+function formatChatHistory(
+  chatHistory = [],
+  formatterFunction,
+  mode = "asProperty"
+) {
+  return chatHistory.map((historicalMessage) => {
+    if (
+      historicalMessage?.role !== "user" || // Only user messages can have attachments
+      !historicalMessage?.attachments || // If there are no attachments, we can skip this
+      !historicalMessage.attachments.length // If there is an array but it is empty, we can skip this
+    )
+      return historicalMessage;
+
+    // Some providers, like Ollama, expect the content to be embedded in the message object.
+    if (mode === "spread") {
+      return {
+        role: historicalMessage.role,
+        ...formatterFunction({
+          userPrompt: historicalMessage.content,
+          attachments: historicalMessage.attachments,
+        }),
+      };
+    }
+
+    // Most providers expect the content to be set as a property of the message object, formatted like OpenAI's message schema.
+    return {
+      role: historicalMessage.role,
+      content: formatterFunction({
+        userPrompt: historicalMessage.content,
+        attachments: historicalMessage.attachments,
+      }),
+    };
+  });
+}
+
 module.exports = {
   handleDefaultStreamResponseV2,
   convertToChatHistory,
   convertToPromptHistory,
   writeResponseChunk,
   clientAbortedHandler,
+  formatChatHistory,
 };
-- 
GitLab