From 9655880cf0154205f6a2876d93095652ae2698eb Mon Sep 17 00:00:00 2001
From: Timothy Carambat <rambat1010@gmail.com>
Date: Wed, 17 Apr 2024 18:04:39 -0700
Subject: [PATCH] Update all vector dbs to filter duplicate source documents
 that may be pinned (#1122)

* Update all vector dbs to filter duplicate parents

* cleanup
---
 server/utils/chats/embed.js                   |  5 ++++-
 server/utils/chats/index.js                   | 14 +++++++++++++
 server/utils/chats/stream.js                  |  4 ++++
 server/utils/vectorDbProviders/astra/index.js | 14 +++++++++++--
 .../utils/vectorDbProviders/chroma/index.js   | 17 +++++++++++++--
 server/utils/vectorDbProviders/lance/index.js | 15 +++++++++++--
 .../utils/vectorDbProviders/milvus/index.js   | 15 +++++++++++--
 .../utils/vectorDbProviders/pinecone/index.js | 15 +++++++++++--
 .../utils/vectorDbProviders/qdrant/index.js   | 18 +++++++++++++---
 .../utils/vectorDbProviders/weaviate/index.js | 21 ++++++++++++++-----
 .../utils/vectorDbProviders/zilliz/index.js   | 14 +++++++++++--
 11 files changed, 131 insertions(+), 21 deletions(-)

diff --git a/server/utils/chats/embed.js b/server/utils/chats/embed.js
index 497b2c8e4..d894693e0 100644
--- a/server/utils/chats/embed.js
+++ b/server/utils/chats/embed.js
@@ -1,6 +1,6 @@
 const { v4: uuidv4 } = require("uuid");
 const { getVectorDbClass, getLLMProvider } = require("../helpers");
-const { chatPrompt } = require("./index");
+const { chatPrompt, sourceIdentifier } = require("./index");
 const { EmbedChats } = require("../../models/embedChats");
 const {
   convertToPromptHistory,
@@ -69,6 +69,7 @@ async function streamChatWithForEmbed(
   let completeText;
   let contextTexts = [];
   let sources = [];
+  let pinnedDocIdentifiers = [];
   const { rawHistory, chatHistory } = await recentEmbedChatHistory(
     sessionId,
     embed,
@@ -86,6 +87,7 @@ async function streamChatWithForEmbed(
     .then((pinnedDocs) => {
       pinnedDocs.forEach((doc) => {
         const { pageContent, ...metadata } = doc;
+        pinnedDocIdentifiers.push(sourceIdentifier(doc));
         contextTexts.push(doc.pageContent);
         sources.push({
           text:
@@ -104,6 +106,7 @@ async function streamChatWithForEmbed(
           LLMConnector,
           similarityThreshold: embed.workspace?.similarityThreshold,
           topN: embed.workspace?.topN,
+          filterIdentifiers: pinnedDocIdentifiers,
         })
       : {
           contextTexts: [],
diff --git a/server/utils/chats/index.js b/server/utils/chats/index.js
index 7e40b9a8b..760891b5f 100644
--- a/server/utils/chats/index.js
+++ b/server/utils/chats/index.js
@@ -79,6 +79,7 @@ async function chatWithWorkspace(
   // 2. Chatting in "query" mode and has at least 1 embedding
   let contextTexts = [];
   let sources = [];
+  let pinnedDocIdentifiers = [];
   const { rawHistory, chatHistory } = await recentChatHistory({
     user,
     workspace,
@@ -97,6 +98,7 @@ async function chatWithWorkspace(
     .then((pinnedDocs) => {
       pinnedDocs.forEach((doc) => {
         const { pageContent, ...metadata } = doc;
+        pinnedDocIdentifiers.push(sourceIdentifier(doc));
         contextTexts.push(doc.pageContent);
         sources.push({
           text:
@@ -115,6 +117,7 @@ async function chatWithWorkspace(
           LLMConnector,
           similarityThreshold: workspace?.similarityThreshold,
           topN: workspace?.topN,
+          filterIdentifiers: pinnedDocIdentifiers,
         })
       : {
           contextTexts: [],
@@ -227,7 +230,18 @@ function chatPrompt(workspace) {
   );
 }
 
+// We use this util function to deduplicate sources from similarity searching
+// if the document is already pinned.
+// Eg: You pin a csv, if we RAG + full-text that you will get the same data
+// points both in the full-text and possibly from RAG - result in bad results
+// even if the LLM was not even going to hallucinate.
+function sourceIdentifier(sourceDocument) {
+  if (!sourceDocument?.title || !sourceDocument?.published) return uuidv4();
+  return `title:${sourceDocument.title}-timestamp:${sourceDocument.published}`;
+}
+
 module.exports = {
+  sourceIdentifier,
   recentChatHistory,
   chatWithWorkspace,
   chatPrompt,
diff --git a/server/utils/chats/stream.js b/server/utils/chats/stream.js
index 57025da30..01244e752 100644
--- a/server/utils/chats/stream.js
+++ b/server/utils/chats/stream.js
@@ -9,6 +9,7 @@ const {
   VALID_COMMANDS,
   chatPrompt,
   recentChatHistory,
+  sourceIdentifier,
 } = require("./index");
 
 const VALID_CHAT_MODE = ["chat", "query"];
@@ -92,6 +93,7 @@ async function streamChatWithWorkspace(
   let completeText;
   let contextTexts = [];
   let sources = [];
+  let pinnedDocIdentifiers = [];
   const { rawHistory, chatHistory } = await recentChatHistory({
     user,
     workspace,
@@ -110,6 +112,7 @@ async function streamChatWithWorkspace(
     .then((pinnedDocs) => {
       pinnedDocs.forEach((doc) => {
         const { pageContent, ...metadata } = doc;
+        pinnedDocIdentifiers.push(sourceIdentifier(doc));
         contextTexts.push(doc.pageContent);
         sources.push({
           text:
@@ -128,6 +131,7 @@ async function streamChatWithWorkspace(
           LLMConnector,
           similarityThreshold: workspace?.similarityThreshold,
           topN: workspace?.topN,
+          filterIdentifiers: pinnedDocIdentifiers,
         })
       : {
           contextTexts: [],
diff --git a/server/utils/vectorDbProviders/astra/index.js b/server/utils/vectorDbProviders/astra/index.js
index 4420d19c1..5f0b086f3 100644
--- a/server/utils/vectorDbProviders/astra/index.js
+++ b/server/utils/vectorDbProviders/astra/index.js
@@ -8,6 +8,7 @@ const {
   getLLMProvider,
   getEmbeddingEngineSelection,
 } = require("../../helpers");
+const { sourceIdentifier } = require("../../chats");
 
 const AstraDB = {
   name: "AstraDB",
@@ -252,6 +253,7 @@ const AstraDB = {
     LLMConnector = null,
     similarityThreshold = 0.25,
     topN = 4,
+    filterIdentifiers = [],
   }) {
     if (!namespace || !input || !LLMConnector)
       throw new Error("Invalid request to performSimilaritySearch.");
@@ -272,7 +274,8 @@ const AstraDB = {
       namespace,
       queryVector,
       similarityThreshold,
-      topN
+      topN,
+      filterIdentifiers
     );
 
     const sources = sourceDocuments.map((metadata, i) => {
@@ -289,7 +292,8 @@ const AstraDB = {
     namespace,
     queryVector,
     similarityThreshold = 0.25,
-    topN = 4
+    topN = 4,
+    filterIdentifiers = []
   ) {
     const result = {
       contextTexts: [],
@@ -311,6 +315,12 @@ const AstraDB = {
 
     responses.forEach((response) => {
       if (response.$similarity < similarityThreshold) return;
+      if (filterIdentifiers.includes(sourceIdentifier(response.metadata))) {
+        console.log(
+          "AstraDB: A source was filtered from context as it's parent document is pinned."
+        );
+        return;
+      }
       result.contextTexts.push(response.metadata.text);
       result.sourceDocuments.push(response);
       result.scores.push(response.$similarity);
diff --git a/server/utils/vectorDbProviders/chroma/index.js b/server/utils/vectorDbProviders/chroma/index.js
index 5b5800318..d87b3aad0 100644
--- a/server/utils/vectorDbProviders/chroma/index.js
+++ b/server/utils/vectorDbProviders/chroma/index.js
@@ -9,6 +9,7 @@ const {
   getEmbeddingEngineSelection,
 } = require("../../helpers");
 const { parseAuthHeader } = require("../../http");
+const { sourceIdentifier } = require("../../chats");
 
 const Chroma = {
   name: "Chroma",
@@ -70,7 +71,8 @@ const Chroma = {
     namespace,
     queryVector,
     similarityThreshold = 0.25,
-    topN = 4
+    topN = 4,
+    filterIdentifiers = []
   ) {
     const collection = await client.getCollection({ name: namespace });
     const result = {
@@ -89,6 +91,15 @@ const Chroma = {
         similarityThreshold
       )
         return;
+
+      if (
+        filterIdentifiers.includes(sourceIdentifier(response.metadatas[0][i]))
+      ) {
+        console.log(
+          "Chroma: A source was filtered from context as it's parent document is pinned."
+        );
+        return;
+      }
       result.contextTexts.push(response.documents[0][i]);
       result.sourceDocuments.push(response.metadatas[0][i]);
       result.scores.push(this.distanceToSimilarity(response.distances[0][i]));
@@ -282,6 +293,7 @@ const Chroma = {
     LLMConnector = null,
     similarityThreshold = 0.25,
     topN = 4,
+    filterIdentifiers = [],
   }) {
     if (!namespace || !input || !LLMConnector)
       throw new Error("Invalid request to performSimilaritySearch.");
@@ -301,7 +313,8 @@ const Chroma = {
       namespace,
       queryVector,
       similarityThreshold,
-      topN
+      topN,
+      filterIdentifiers
     );
 
     const sources = sourceDocuments.map((metadata, i) => {
diff --git a/server/utils/vectorDbProviders/lance/index.js b/server/utils/vectorDbProviders/lance/index.js
index 4769c4834..d195086ad 100644
--- a/server/utils/vectorDbProviders/lance/index.js
+++ b/server/utils/vectorDbProviders/lance/index.js
@@ -9,6 +9,7 @@ const { TextSplitter } = require("../../TextSplitter");
 const { SystemSettings } = require("../../../models/systemSettings");
 const { storeVectorResult, cachedVectorInformation } = require("../../files");
 const { v4: uuidv4 } = require("uuid");
+const { sourceIdentifier } = require("../../chats");
 
 const LanceDb = {
   uri: `${
@@ -64,7 +65,8 @@ const LanceDb = {
     namespace,
     queryVector,
     similarityThreshold = 0.25,
-    topN = 4
+    topN = 4,
+    filterIdentifiers = []
   ) {
     const collection = await client.openTable(namespace);
     const result = {
@@ -82,6 +84,13 @@ const LanceDb = {
     response.forEach((item) => {
       if (this.distanceToSimilarity(item.score) < similarityThreshold) return;
       const { vector: _, ...rest } = item;
+      if (filterIdentifiers.includes(sourceIdentifier(rest))) {
+        console.log(
+          "LanceDB: A source was filtered from context as it's parent document is pinned."
+        );
+        return;
+      }
+
       result.contextTexts.push(rest.text);
       result.sourceDocuments.push(rest);
       result.scores.push(this.distanceToSimilarity(item.score));
@@ -250,6 +259,7 @@ const LanceDb = {
     LLMConnector = null,
     similarityThreshold = 0.25,
     topN = 4,
+    filterIdentifiers = [],
   }) {
     if (!namespace || !input || !LLMConnector)
       throw new Error("Invalid request to performSimilaritySearch.");
@@ -269,7 +279,8 @@ const LanceDb = {
       namespace,
       queryVector,
       similarityThreshold,
-      topN
+      topN,
+      filterIdentifiers
     );
 
     const sources = sourceDocuments.map((metadata, i) => {
diff --git a/server/utils/vectorDbProviders/milvus/index.js b/server/utils/vectorDbProviders/milvus/index.js
index 3402960aa..c4c91c22e 100644
--- a/server/utils/vectorDbProviders/milvus/index.js
+++ b/server/utils/vectorDbProviders/milvus/index.js
@@ -13,6 +13,7 @@ const {
   getLLMProvider,
   getEmbeddingEngineSelection,
 } = require("../../helpers");
+const { sourceIdentifier } = require("../../chats");
 
 const Milvus = {
   name: "Milvus",
@@ -288,6 +289,7 @@ const Milvus = {
     LLMConnector = null,
     similarityThreshold = 0.25,
     topN = 4,
+    filterIdentifiers = [],
   }) {
     if (!namespace || !input || !LLMConnector)
       throw new Error("Invalid request to performSimilaritySearch.");
@@ -307,7 +309,8 @@ const Milvus = {
       namespace,
       queryVector,
       similarityThreshold,
-      topN
+      topN,
+      filterIdentifiers
     );
 
     const sources = sourceDocuments.map((metadata, i) => {
@@ -324,7 +327,8 @@ const Milvus = {
     namespace,
     queryVector,
     similarityThreshold = 0.25,
-    topN = 4
+    topN = 4,
+    filterIdentifiers = []
   ) {
     const result = {
       contextTexts: [],
@@ -338,6 +342,13 @@ const Milvus = {
     });
     response.results.forEach((match) => {
       if (match.score < similarityThreshold) return;
+      if (filterIdentifiers.includes(sourceIdentifier(match.metadata))) {
+        console.log(
+          "Milvus: A source was filtered from context as it's parent document is pinned."
+        );
+        return;
+      }
+
       result.contextTexts.push(match.metadata.text);
       result.sourceDocuments.push(match);
       result.scores.push(match.score);
diff --git a/server/utils/vectorDbProviders/pinecone/index.js b/server/utils/vectorDbProviders/pinecone/index.js
index 3cf0e19a2..cf71c893f 100644
--- a/server/utils/vectorDbProviders/pinecone/index.js
+++ b/server/utils/vectorDbProviders/pinecone/index.js
@@ -8,6 +8,7 @@ const {
   getLLMProvider,
   getEmbeddingEngineSelection,
 } = require("../../helpers");
+const { sourceIdentifier } = require("../../chats");
 
 const PineconeDB = {
   name: "Pinecone",
@@ -44,7 +45,8 @@ const PineconeDB = {
     namespace,
     queryVector,
     similarityThreshold = 0.25,
-    topN = 4
+    topN = 4,
+    filterIdentifiers = []
   ) {
     const result = {
       contextTexts: [],
@@ -61,6 +63,13 @@ const PineconeDB = {
 
     response.matches.forEach((match) => {
       if (match.score < similarityThreshold) return;
+      if (filterIdentifiers.includes(sourceIdentifier(match.metadata))) {
+        console.log(
+          "Pinecone: A source was filtered from context as it's parent document is pinned."
+        );
+        return;
+      }
+
       result.contextTexts.push(match.metadata.text);
       result.sourceDocuments.push(match);
       result.scores.push(match.score);
@@ -233,6 +242,7 @@ const PineconeDB = {
     LLMConnector = null,
     similarityThreshold = 0.25,
     topN = 4,
+    filterIdentifiers = [],
   }) {
     if (!namespace || !input || !LLMConnector)
       throw new Error("Invalid request to performSimilaritySearch.");
@@ -249,7 +259,8 @@ const PineconeDB = {
       namespace,
       queryVector,
       similarityThreshold,
-      topN
+      topN,
+      filterIdentifiers
     );
 
     const sources = sourceDocuments.map((metadata, i) => {
diff --git a/server/utils/vectorDbProviders/qdrant/index.js b/server/utils/vectorDbProviders/qdrant/index.js
index 2a37d3d66..2497c3f30 100644
--- a/server/utils/vectorDbProviders/qdrant/index.js
+++ b/server/utils/vectorDbProviders/qdrant/index.js
@@ -8,6 +8,7 @@ const {
   getLLMProvider,
   getEmbeddingEngineSelection,
 } = require("../../helpers");
+const { sourceIdentifier } = require("../../chats");
 
 const QDrant = {
   name: "QDrant",
@@ -55,7 +56,8 @@ const QDrant = {
     namespace,
     queryVector,
     similarityThreshold = 0.25,
-    topN = 4
+    topN = 4,
+    filterIdentifiers = []
   ) {
     const { client } = await this.connect();
     const result = {
@@ -72,6 +74,13 @@ const QDrant = {
 
     responses.forEach((response) => {
       if (response.score < similarityThreshold) return;
+      if (filterIdentifiers.includes(sourceIdentifier(response?.payload))) {
+        console.log(
+          "QDrant: A source was filtered from context as it's parent document is pinned."
+        );
+        return;
+      }
+
       result.contextTexts.push(response?.payload?.text || "");
       result.sourceDocuments.push({
         ...(response?.payload || {}),
@@ -146,7 +155,8 @@ const QDrant = {
         const { client } = await this.connect();
         const { chunks } = cacheResult;
         const documentVectors = [];
-        vectorDimension = chunks[0][0].vector.length || null;
+        vectorDimension =
+          chunks[0][0]?.vector?.length ?? chunks[0][0]?.values?.length ?? null;
 
         const collection = await this.getOrCreateCollection(
           client,
@@ -311,6 +321,7 @@ const QDrant = {
     LLMConnector = null,
     similarityThreshold = 0.25,
     topN = 4,
+    filterIdentifiers = [],
   }) {
     if (!namespace || !input || !LLMConnector)
       throw new Error("Invalid request to performSimilaritySearch.");
@@ -330,7 +341,8 @@ const QDrant = {
       namespace,
       queryVector,
       similarityThreshold,
-      topN
+      topN,
+      filterIdentifiers
     );
 
     const sources = sourceDocuments.map((metadata, i) => {
diff --git a/server/utils/vectorDbProviders/weaviate/index.js b/server/utils/vectorDbProviders/weaviate/index.js
index 490d26e26..9e7849678 100644
--- a/server/utils/vectorDbProviders/weaviate/index.js
+++ b/server/utils/vectorDbProviders/weaviate/index.js
@@ -9,6 +9,7 @@ const {
   getEmbeddingEngineSelection,
 } = require("../../helpers");
 const { camelCase } = require("../../helpers/camelcase");
+const { sourceIdentifier } = require("../../chats");
 
 const Weaviate = {
   name: "Weaviate",
@@ -82,7 +83,8 @@ const Weaviate = {
     namespace,
     queryVector,
     similarityThreshold = 0.25,
-    topN = 4
+    topN = 4,
+    filterIdentifiers = []
   ) {
     const result = {
       contextTexts: [],
@@ -91,7 +93,8 @@ const Weaviate = {
     };
 
     const weaviateClass = await this.namespace(client, namespace);
-    const fields = weaviateClass.properties.map((prop) => prop.name).join(" ");
+    const fields =
+      weaviateClass.properties?.map((prop) => prop.name)?.join(" ") ?? "";
     const queryResponse = await client.graphql
       .get()
       .withClassName(camelCase(namespace))
@@ -109,6 +112,12 @@ const Weaviate = {
         ...rest
       } = response;
       if (certainty < similarityThreshold) return;
+      if (filterIdentifiers.includes(sourceIdentifier(rest))) {
+        console.log(
+          "Weaviate: A source was filtered from context as it's parent document is pinned."
+        );
+        return;
+      }
       result.contextTexts.push(rest.text);
       result.sourceDocuments.push({ ...rest, id });
       result.scores.push(certainty);
@@ -214,7 +223,7 @@ const Weaviate = {
           chunk.forEach((chunk) => {
             const id = uuidv4();
             const flattenedMetadata = this.flattenObjectForWeaviate(
-              chunk.properties
+              chunk.properties ?? chunk.metadata
             );
             documentVectors.push({ docId, vectorId: id });
             const vectorRecord = {
@@ -357,6 +366,7 @@ const Weaviate = {
     LLMConnector = null,
     similarityThreshold = 0.25,
     topN = 4,
+    filterIdentifiers = [],
   }) {
     if (!namespace || !input || !LLMConnector)
       throw new Error("Invalid request to performSimilaritySearch.");
@@ -376,7 +386,8 @@ const Weaviate = {
       namespace,
       queryVector,
       similarityThreshold,
-      topN
+      topN,
+      filterIdentifiers
     );
 
     const sources = sourceDocuments.map((metadata, i) => {
@@ -437,7 +448,7 @@ const Weaviate = {
     const flattenedObject = {};
 
     for (const key in obj) {
-      if (!Object.hasOwn(obj, key)) {
+      if (!Object.hasOwn(obj, key) || key === "id") {
         continue;
       }
       const value = obj[key];
diff --git a/server/utils/vectorDbProviders/zilliz/index.js b/server/utils/vectorDbProviders/zilliz/index.js
index f80f4dc22..0efe69967 100644
--- a/server/utils/vectorDbProviders/zilliz/index.js
+++ b/server/utils/vectorDbProviders/zilliz/index.js
@@ -13,6 +13,7 @@ const {
   getLLMProvider,
   getEmbeddingEngineSelection,
 } = require("../../helpers");
+const { sourceIdentifier } = require("../../chats");
 
 // Zilliz is basically a copy of Milvus DB class with a different constructor
 // to connect to the cloud
@@ -289,6 +290,7 @@ const Zilliz = {
     LLMConnector = null,
     similarityThreshold = 0.25,
     topN = 4,
+    filterIdentifiers = [],
   }) {
     if (!namespace || !input || !LLMConnector)
       throw new Error("Invalid request to performSimilaritySearch.");
@@ -308,7 +310,8 @@ const Zilliz = {
       namespace,
       queryVector,
       similarityThreshold,
-      topN
+      topN,
+      filterIdentifiers
     );
 
     const sources = sourceDocuments.map((metadata, i) => {
@@ -325,7 +328,8 @@ const Zilliz = {
     namespace,
     queryVector,
     similarityThreshold = 0.25,
-    topN = 4
+    topN = 4,
+    filterIdentifiers = []
   ) {
     const result = {
       contextTexts: [],
@@ -339,6 +343,12 @@ const Zilliz = {
     });
     response.results.forEach((match) => {
       if (match.score < similarityThreshold) return;
+      if (filterIdentifiers.includes(sourceIdentifier(match.metadata))) {
+        console.log(
+          "Zilliz: A source was filtered from context as it's parent document is pinned."
+        );
+        return;
+      }
       result.contextTexts.push(match.metadata.text);
       result.sourceDocuments.push(match);
       result.scores.push(match.score);
-- 
GitLab