From 56fa17caf27949818d34c367605371ceac803c74 Mon Sep 17 00:00:00 2001
From: Sean Hatfield <seanhatfield5@gmail.com>
Date: Thu, 18 Jan 2024 12:34:20 -0800
Subject: [PATCH] create configurable topN per workspace (#616)

* create configurable topN per workspace

* Update TopN UI text
Fix fallbacks for all providers
Add SQLite CHECK to TOPN value

* merge with master
Update zilliz provider for variable TopN

---------

Co-authored-by: timothycarambat <rambat1010@gmail.com>
---
 .../Modals/MangeWorkspace/Settings/index.jsx  | 35 +++++++++++++++++++
 server/models/workspace.js                    |  1 +
 .../20240118201333_init/migration.sql         |  2 ++
 server/prisma/schema.prisma                   |  1 +
 server/utils/chats/index.js                   |  1 +
 server/utils/chats/stream.js                  |  1 +
 .../utils/vectorDbProviders/chroma/index.js   |  9 +++--
 server/utils/vectorDbProviders/lance/index.js |  9 +++--
 .../utils/vectorDbProviders/milvus/index.js   |  8 +++--
 .../utils/vectorDbProviders/pinecone/index.js |  9 +++--
 .../utils/vectorDbProviders/qdrant/index.js   |  9 +++--
 .../utils/vectorDbProviders/weaviate/index.js |  9 +++--
 .../utils/vectorDbProviders/zilliz/index.js   |  8 +++--
 13 files changed, 83 insertions(+), 19 deletions(-)
 create mode 100644 server/prisma/migrations/20240118201333_init/migration.sql

diff --git a/frontend/src/components/Modals/MangeWorkspace/Settings/index.jsx b/frontend/src/components/Modals/MangeWorkspace/Settings/index.jsx
index da0e7b9f0..b288dc6c1 100644
--- a/frontend/src/components/Modals/MangeWorkspace/Settings/index.jsx
+++ b/frontend/src/components/Modals/MangeWorkspace/Settings/index.jsx
@@ -21,6 +21,9 @@ function castToType(key, value) {
     similarityThreshold: {
       cast: (value) => parseFloat(value),
     },
+    topN: {
+      cast: (value) => Number(value),
+    },
   };
 
   if (!definitions.hasOwnProperty(key)) return value;
@@ -236,6 +239,38 @@ export default function WorkspaceSettings({ active, workspace, settings }) {
                   autoComplete="off"
                   onChange={() => setHasChanges(true)}
                 />
+
+                <div className="mt-4">
+                  <div className="flex flex-col">
+                    <label
+                      htmlFor="name"
+                      className="block text-sm font-medium text-white"
+                    >
+                      Max Context Snippets
+                    </label>
+                    <p className="text-white text-opacity-60 text-xs font-medium py-1.5">
+                      This setting controls the maximum amount of context
+                      snippets the will be sent to the LLM for per chat or
+                      query.
+                      <br />
+                      <i>Recommended: 4</i>
+                    </p>
+                  </div>
+                  <input
+                    name="topN"
+                    type="number"
+                    min={1}
+                    max={12}
+                    step={1}
+                    onWheel={(e) => e.target.blur()}
+                    defaultValue={workspace?.topN ?? 4}
+                    className="bg-zinc-900 text-white text-sm rounded-lg focus:ring-blue-500 focus:border-blue-500 block w-full p-2.5"
+                    placeholder="4"
+                    required={true}
+                    autoComplete="off"
+                    onChange={() => setHasChanges(true)}
+                  />
+                </div>
                 <div className="mt-4">
                   <div className="flex flex-col">
                     <label
diff --git a/server/models/workspace.js b/server/models/workspace.js
index 6de8053e9..9169d193d 100644
--- a/server/models/workspace.js
+++ b/server/models/workspace.js
@@ -15,6 +15,7 @@ const Workspace = {
     "openAiPrompt",
     "similarityThreshold",
     "chatModel",
+    "topN",
   ],
 
   new: async function (name = null, creatorId = null) {
diff --git a/server/prisma/migrations/20240118201333_init/migration.sql b/server/prisma/migrations/20240118201333_init/migration.sql
new file mode 100644
index 000000000..aaf47f7af
--- /dev/null
+++ b/server/prisma/migrations/20240118201333_init/migration.sql
@@ -0,0 +1,2 @@
+-- AlterTable
+ALTER TABLE "workspaces" ADD COLUMN "topN" INTEGER DEFAULT 4 CHECK ("topN" > 0);
diff --git a/server/prisma/schema.prisma b/server/prisma/schema.prisma
index 2f632a46a..15ddbe23b 100644
--- a/server/prisma/schema.prisma
+++ b/server/prisma/schema.prisma
@@ -94,6 +94,7 @@ model workspaces {
   openAiPrompt        String?
   similarityThreshold Float?                @default(0.25)
   chatModel           String?
+  topN                Int?                  @default(4)
   workspace_users     workspace_users[]
   documents           workspace_documents[]
 }
diff --git a/server/utils/chats/index.js b/server/utils/chats/index.js
index 764c7795a..102189cbe 100644
--- a/server/utils/chats/index.js
+++ b/server/utils/chats/index.js
@@ -129,6 +129,7 @@ async function chatWithWorkspace(
     input: message,
     LLMConnector,
     similarityThreshold: workspace?.similarityThreshold,
+    topN: workspace?.topN,
   });
 
   // Failed similarity search.
diff --git a/server/utils/chats/stream.js b/server/utils/chats/stream.js
index cff565ed6..29386531e 100644
--- a/server/utils/chats/stream.js
+++ b/server/utils/chats/stream.js
@@ -92,6 +92,7 @@ async function streamChatWithWorkspace(
     input: message,
     LLMConnector,
     similarityThreshold: workspace?.similarityThreshold,
+    topN: workspace?.topN,
   });
 
   // Failed similarity search.
diff --git a/server/utils/vectorDbProviders/chroma/index.js b/server/utils/vectorDbProviders/chroma/index.js
index 28af39e66..23f173ddc 100644
--- a/server/utils/vectorDbProviders/chroma/index.js
+++ b/server/utils/vectorDbProviders/chroma/index.js
@@ -67,7 +67,8 @@ const Chroma = {
     client,
     namespace,
     queryVector,
-    similarityThreshold = 0.25
+    similarityThreshold = 0.25,
+    topN = 4
   ) {
     const collection = await client.getCollection({ name: namespace });
     const result = {
@@ -78,7 +79,7 @@ const Chroma = {
 
     const response = await collection.query({
       queryEmbeddings: queryVector,
-      nResults: 4,
+      nResults: topN,
     });
     response.ids[0].forEach((_, i) => {
       if (
@@ -271,6 +272,7 @@ const Chroma = {
     input = "",
     LLMConnector = null,
     similarityThreshold = 0.25,
+    topN = 4,
   }) {
     if (!namespace || !input || !LLMConnector)
       throw new Error("Invalid request to performSimilaritySearch.");
@@ -289,7 +291,8 @@ const Chroma = {
       client,
       namespace,
       queryVector,
-      similarityThreshold
+      similarityThreshold,
+      topN
     );
 
     const sources = sourceDocuments.map((metadata, i) => {
diff --git a/server/utils/vectorDbProviders/lance/index.js b/server/utils/vectorDbProviders/lance/index.js
index 8f243cf9b..67705c00e 100644
--- a/server/utils/vectorDbProviders/lance/index.js
+++ b/server/utils/vectorDbProviders/lance/index.js
@@ -62,7 +62,8 @@ const LanceDb = {
     client,
     namespace,
     queryVector,
-    similarityThreshold = 0.25
+    similarityThreshold = 0.25,
+    topN = 4
   ) {
     const collection = await client.openTable(namespace);
     const result = {
@@ -74,7 +75,7 @@ const LanceDb = {
     const response = await collection
       .search(queryVector)
       .metricType("cosine")
-      .limit(5)
+      .limit(topN)
       .execute();
 
     response.forEach((item) => {
@@ -240,6 +241,7 @@ const LanceDb = {
     input = "",
     LLMConnector = null,
     similarityThreshold = 0.25,
+    topN = 4,
   }) {
     if (!namespace || !input || !LLMConnector)
       throw new Error("Invalid request to performSimilaritySearch.");
@@ -258,7 +260,8 @@ const LanceDb = {
       client,
       namespace,
       queryVector,
-      similarityThreshold
+      similarityThreshold,
+      topN
     );
 
     const sources = sourceDocuments.map((metadata, i) => {
diff --git a/server/utils/vectorDbProviders/milvus/index.js b/server/utils/vectorDbProviders/milvus/index.js
index 79a132413..d1eec9f71 100644
--- a/server/utils/vectorDbProviders/milvus/index.js
+++ b/server/utils/vectorDbProviders/milvus/index.js
@@ -265,6 +265,7 @@ const Milvus = {
     input = "",
     LLMConnector = null,
     similarityThreshold = 0.25,
+    topN = 4,
   }) {
     if (!namespace || !input || !LLMConnector)
       throw new Error("Invalid request to performSimilaritySearch.");
@@ -283,7 +284,8 @@ const Milvus = {
       client,
       namespace,
       queryVector,
-      similarityThreshold
+      similarityThreshold,
+      topN
     );
 
     const sources = sourceDocuments.map((metadata, i) => {
@@ -299,7 +301,8 @@ const Milvus = {
     client,
     namespace,
     queryVector,
-    similarityThreshold = 0.25
+    similarityThreshold = 0.25,
+    topN = 4
   ) {
     const result = {
       contextTexts: [],
@@ -309,6 +312,7 @@ const Milvus = {
     const response = await client.search({
       collection_name: namespace,
       vectors: queryVector,
+      limit: topN,
     });
     response.results.forEach((match) => {
       if (match.score < similarityThreshold) return;
diff --git a/server/utils/vectorDbProviders/pinecone/index.js b/server/utils/vectorDbProviders/pinecone/index.js
index 594a9aaf3..260c0a257 100644
--- a/server/utils/vectorDbProviders/pinecone/index.js
+++ b/server/utils/vectorDbProviders/pinecone/index.js
@@ -44,7 +44,8 @@ const Pinecone = {
     index,
     namespace,
     queryVector,
-    similarityThreshold = 0.25
+    similarityThreshold = 0.25,
+    topN = 4
   ) {
     const result = {
       contextTexts: [],
@@ -55,7 +56,7 @@ const Pinecone = {
       queryRequest: {
         namespace,
         vector: queryVector,
-        topK: 4,
+        topK: topN,
         includeMetadata: true,
       },
     });
@@ -237,6 +238,7 @@ const Pinecone = {
     input = "",
     LLMConnector = null,
     similarityThreshold = 0.25,
+    topN = 4,
   }) {
     if (!namespace || !input || !LLMConnector)
       throw new Error("Invalid request to performSimilaritySearch.");
@@ -252,7 +254,8 @@ const Pinecone = {
       pineconeIndex,
       namespace,
       queryVector,
-      similarityThreshold
+      similarityThreshold,
+      topN
     );
 
     const sources = sourceDocuments.map((metadata, i) => {
diff --git a/server/utils/vectorDbProviders/qdrant/index.js b/server/utils/vectorDbProviders/qdrant/index.js
index 70c069e84..e7e00fe64 100644
--- a/server/utils/vectorDbProviders/qdrant/index.js
+++ b/server/utils/vectorDbProviders/qdrant/index.js
@@ -53,7 +53,8 @@ const QDrant = {
     _client,
     namespace,
     queryVector,
-    similarityThreshold = 0.25
+    similarityThreshold = 0.25,
+    topN = 4
   ) {
     const { client } = await this.connect();
     const result = {
@@ -64,7 +65,7 @@ const QDrant = {
 
     const responses = await client.search(namespace, {
       vector: queryVector,
-      limit: 4,
+      limit: topN,
       with_payload: true,
     });
 
@@ -301,6 +302,7 @@ const QDrant = {
     input = "",
     LLMConnector = null,
     similarityThreshold = 0.25,
+    topN = 4,
   }) {
     if (!namespace || !input || !LLMConnector)
       throw new Error("Invalid request to performSimilaritySearch.");
@@ -319,7 +321,8 @@ const QDrant = {
       client,
       namespace,
       queryVector,
-      similarityThreshold
+      similarityThreshold,
+      topN
     );
 
     const sources = sourceDocuments.map((metadata, i) => {
diff --git a/server/utils/vectorDbProviders/weaviate/index.js b/server/utils/vectorDbProviders/weaviate/index.js
index ac89315af..13668303f 100644
--- a/server/utils/vectorDbProviders/weaviate/index.js
+++ b/server/utils/vectorDbProviders/weaviate/index.js
@@ -80,7 +80,8 @@ const Weaviate = {
     client,
     namespace,
     queryVector,
-    similarityThreshold = 0.25
+    similarityThreshold = 0.25,
+    topN = 4
   ) {
     const result = {
       contextTexts: [],
@@ -95,7 +96,7 @@ const Weaviate = {
       .withClassName(camelCase(namespace))
       .withFields(`${fields} _additional { id certainty }`)
       .withNearVector({ vector: queryVector })
-      .withLimit(4)
+      .withLimit(topN)
       .do();
 
     const responses = queryResponse?.data?.Get?.[camelCase(namespace)];
@@ -347,6 +348,7 @@ const Weaviate = {
     input = "",
     LLMConnector = null,
     similarityThreshold = 0.25,
+    topN = 4,
   }) {
     if (!namespace || !input || !LLMConnector)
       throw new Error("Invalid request to performSimilaritySearch.");
@@ -365,7 +367,8 @@ const Weaviate = {
       client,
       namespace,
       queryVector,
-      similarityThreshold
+      similarityThreshold,
+      topN
     );
 
     const sources = sourceDocuments.map((metadata, i) => {
diff --git a/server/utils/vectorDbProviders/zilliz/index.js b/server/utils/vectorDbProviders/zilliz/index.js
index 31afab35a..159ea8094 100644
--- a/server/utils/vectorDbProviders/zilliz/index.js
+++ b/server/utils/vectorDbProviders/zilliz/index.js
@@ -266,6 +266,7 @@ const Zilliz = {
     input = "",
     LLMConnector = null,
     similarityThreshold = 0.25,
+    topN = 4,
   }) {
     if (!namespace || !input || !LLMConnector)
       throw new Error("Invalid request to performSimilaritySearch.");
@@ -284,7 +285,8 @@ const Zilliz = {
       client,
       namespace,
       queryVector,
-      similarityThreshold
+      similarityThreshold,
+      topN
     );
 
     const sources = sourceDocuments.map((metadata, i) => {
@@ -300,7 +302,8 @@ const Zilliz = {
     client,
     namespace,
     queryVector,
-    similarityThreshold = 0.25
+    similarityThreshold = 0.25,
+    topN = 4
   ) {
     const result = {
       contextTexts: [],
@@ -310,6 +313,7 @@ const Zilliz = {
     const response = await client.search({
       collection_name: namespace,
       vectors: queryVector,
+      limit: topN,
     });
     response.results.forEach((match) => {
       if (match.score < similarityThreshold) return;
-- 
GitLab