diff --git a/frontend/src/components/Modals/MangeWorkspace/Settings/index.jsx b/frontend/src/components/Modals/MangeWorkspace/Settings/index.jsx index da0e7b9f02a815659209a356a27b0d5c5e5bbe57..b288dc6c14b0456b8063c5d4b20de1f89328b281 100644 --- a/frontend/src/components/Modals/MangeWorkspace/Settings/index.jsx +++ b/frontend/src/components/Modals/MangeWorkspace/Settings/index.jsx @@ -21,6 +21,9 @@ function castToType(key, value) { similarityThreshold: { cast: (value) => parseFloat(value), }, + topN: { + cast: (value) => Number(value), + }, }; if (!definitions.hasOwnProperty(key)) return value; @@ -236,6 +239,38 @@ export default function WorkspaceSettings({ active, workspace, settings }) { autoComplete="off" onChange={() => setHasChanges(true)} /> + + <div className="mt-4"> + <div className="flex flex-col"> + <label + htmlFor="name" + className="block text-sm font-medium text-white" + > + Max Context Snippets + </label> + <p className="text-white text-opacity-60 text-xs font-medium py-1.5"> + This setting controls the maximum amount of context + snippets the will be sent to the LLM for per chat or + query. + <br /> + <i>Recommended: 4</i> + </p> + </div> + <input + name="topN" + type="number" + min={1} + max={12} + step={1} + onWheel={(e) => e.target.blur()} + defaultValue={workspace?.topN ?? 4} + className="bg-zinc-900 text-white text-sm rounded-lg focus:ring-blue-500 focus:border-blue-500 block w-full p-2.5" + placeholder="4" + required={true} + autoComplete="off" + onChange={() => setHasChanges(true)} + /> + </div> <div className="mt-4"> <div className="flex flex-col"> <label diff --git a/server/models/workspace.js b/server/models/workspace.js index 6de8053e9544f84810a1961b16ea17c61bf01283..9169d193dd8faa8478b4e1f92a31aa7b0311f369 100644 --- a/server/models/workspace.js +++ b/server/models/workspace.js @@ -15,6 +15,7 @@ const Workspace = { "openAiPrompt", "similarityThreshold", "chatModel", + "topN", ], new: async function (name = null, creatorId = null) { diff --git a/server/prisma/migrations/20240118201333_init/migration.sql b/server/prisma/migrations/20240118201333_init/migration.sql new file mode 100644 index 0000000000000000000000000000000000000000..aaf47f7af69ebdce6b199eecf53c82ea547631b1 --- /dev/null +++ b/server/prisma/migrations/20240118201333_init/migration.sql @@ -0,0 +1,2 @@ +-- AlterTable +ALTER TABLE "workspaces" ADD COLUMN "topN" INTEGER DEFAULT 4 CHECK ("topN" > 0); diff --git a/server/prisma/schema.prisma b/server/prisma/schema.prisma index 2f632a46ab8272cc308441053746c92d8df501b2..15ddbe23bd776d9ecf328a813bcc55eb9a15ed86 100644 --- a/server/prisma/schema.prisma +++ b/server/prisma/schema.prisma @@ -94,6 +94,7 @@ model workspaces { openAiPrompt String? similarityThreshold Float? @default(0.25) chatModel String? + topN Int? @default(4) workspace_users workspace_users[] documents workspace_documents[] } diff --git a/server/utils/chats/index.js b/server/utils/chats/index.js index 764c7795a64664f072c99e14179cf6acc844bd84..102189cbe844fb105dc319cff241863c73b8cdbf 100644 --- a/server/utils/chats/index.js +++ b/server/utils/chats/index.js @@ -129,6 +129,7 @@ async function chatWithWorkspace( input: message, LLMConnector, similarityThreshold: workspace?.similarityThreshold, + topN: workspace?.topN, }); // Failed similarity search. diff --git a/server/utils/chats/stream.js b/server/utils/chats/stream.js index cff565ed6e2bf7ad647adc3cb4a3d4d2a7c6ca1c..29386531e07720f62d57de222a500ad679d58cef 100644 --- a/server/utils/chats/stream.js +++ b/server/utils/chats/stream.js @@ -92,6 +92,7 @@ async function streamChatWithWorkspace( input: message, LLMConnector, similarityThreshold: workspace?.similarityThreshold, + topN: workspace?.topN, }); // Failed similarity search. diff --git a/server/utils/vectorDbProviders/chroma/index.js b/server/utils/vectorDbProviders/chroma/index.js index 28af39e66636907214b3ab0769cbcc72fb6a23c0..23f173ddc55057ce3b0fccd47766c84ae805251b 100644 --- a/server/utils/vectorDbProviders/chroma/index.js +++ b/server/utils/vectorDbProviders/chroma/index.js @@ -67,7 +67,8 @@ const Chroma = { client, namespace, queryVector, - similarityThreshold = 0.25 + similarityThreshold = 0.25, + topN = 4 ) { const collection = await client.getCollection({ name: namespace }); const result = { @@ -78,7 +79,7 @@ const Chroma = { const response = await collection.query({ queryEmbeddings: queryVector, - nResults: 4, + nResults: topN, }); response.ids[0].forEach((_, i) => { if ( @@ -271,6 +272,7 @@ const Chroma = { input = "", LLMConnector = null, similarityThreshold = 0.25, + topN = 4, }) { if (!namespace || !input || !LLMConnector) throw new Error("Invalid request to performSimilaritySearch."); @@ -289,7 +291,8 @@ const Chroma = { client, namespace, queryVector, - similarityThreshold + similarityThreshold, + topN ); const sources = sourceDocuments.map((metadata, i) => { diff --git a/server/utils/vectorDbProviders/lance/index.js b/server/utils/vectorDbProviders/lance/index.js index 8f243cf9b2243e60f0fec23fe4eea59cb05b6e59..67705c00e056c2ead3f2e94e9ab9227819cc77fd 100644 --- a/server/utils/vectorDbProviders/lance/index.js +++ b/server/utils/vectorDbProviders/lance/index.js @@ -62,7 +62,8 @@ const LanceDb = { client, namespace, queryVector, - similarityThreshold = 0.25 + similarityThreshold = 0.25, + topN = 4 ) { const collection = await client.openTable(namespace); const result = { @@ -74,7 +75,7 @@ const LanceDb = { const response = await collection .search(queryVector) .metricType("cosine") - .limit(5) + .limit(topN) .execute(); response.forEach((item) => { @@ -240,6 +241,7 @@ const LanceDb = { input = "", LLMConnector = null, similarityThreshold = 0.25, + topN = 4, }) { if (!namespace || !input || !LLMConnector) throw new Error("Invalid request to performSimilaritySearch."); @@ -258,7 +260,8 @@ const LanceDb = { client, namespace, queryVector, - similarityThreshold + similarityThreshold, + topN ); const sources = sourceDocuments.map((metadata, i) => { diff --git a/server/utils/vectorDbProviders/milvus/index.js b/server/utils/vectorDbProviders/milvus/index.js index 79a1324139d2bfa3d4e5ea966d34b3b611a9d51c..d1eec9f7116b646729846929381fe1026f5aefe1 100644 --- a/server/utils/vectorDbProviders/milvus/index.js +++ b/server/utils/vectorDbProviders/milvus/index.js @@ -265,6 +265,7 @@ const Milvus = { input = "", LLMConnector = null, similarityThreshold = 0.25, + topN = 4, }) { if (!namespace || !input || !LLMConnector) throw new Error("Invalid request to performSimilaritySearch."); @@ -283,7 +284,8 @@ const Milvus = { client, namespace, queryVector, - similarityThreshold + similarityThreshold, + topN ); const sources = sourceDocuments.map((metadata, i) => { @@ -299,7 +301,8 @@ const Milvus = { client, namespace, queryVector, - similarityThreshold = 0.25 + similarityThreshold = 0.25, + topN = 4 ) { const result = { contextTexts: [], @@ -309,6 +312,7 @@ const Milvus = { const response = await client.search({ collection_name: namespace, vectors: queryVector, + limit: topN, }); response.results.forEach((match) => { if (match.score < similarityThreshold) return; diff --git a/server/utils/vectorDbProviders/pinecone/index.js b/server/utils/vectorDbProviders/pinecone/index.js index 594a9aaf3dc3ff06ea11e2cbb4e4ce7de99864d8..260c0a257280cd235c21de9891ee9b54fe6ca079 100644 --- a/server/utils/vectorDbProviders/pinecone/index.js +++ b/server/utils/vectorDbProviders/pinecone/index.js @@ -44,7 +44,8 @@ const Pinecone = { index, namespace, queryVector, - similarityThreshold = 0.25 + similarityThreshold = 0.25, + topN = 4 ) { const result = { contextTexts: [], @@ -55,7 +56,7 @@ const Pinecone = { queryRequest: { namespace, vector: queryVector, - topK: 4, + topK: topN, includeMetadata: true, }, }); @@ -237,6 +238,7 @@ const Pinecone = { input = "", LLMConnector = null, similarityThreshold = 0.25, + topN = 4, }) { if (!namespace || !input || !LLMConnector) throw new Error("Invalid request to performSimilaritySearch."); @@ -252,7 +254,8 @@ const Pinecone = { pineconeIndex, namespace, queryVector, - similarityThreshold + similarityThreshold, + topN ); const sources = sourceDocuments.map((metadata, i) => { diff --git a/server/utils/vectorDbProviders/qdrant/index.js b/server/utils/vectorDbProviders/qdrant/index.js index 70c069e843308311e0fb8d93aa412478e64a9a0b..e7e00fe64d3e3e8207d91fcc2df5b8840b376161 100644 --- a/server/utils/vectorDbProviders/qdrant/index.js +++ b/server/utils/vectorDbProviders/qdrant/index.js @@ -53,7 +53,8 @@ const QDrant = { _client, namespace, queryVector, - similarityThreshold = 0.25 + similarityThreshold = 0.25, + topN = 4 ) { const { client } = await this.connect(); const result = { @@ -64,7 +65,7 @@ const QDrant = { const responses = await client.search(namespace, { vector: queryVector, - limit: 4, + limit: topN, with_payload: true, }); @@ -301,6 +302,7 @@ const QDrant = { input = "", LLMConnector = null, similarityThreshold = 0.25, + topN = 4, }) { if (!namespace || !input || !LLMConnector) throw new Error("Invalid request to performSimilaritySearch."); @@ -319,7 +321,8 @@ const QDrant = { client, namespace, queryVector, - similarityThreshold + similarityThreshold, + topN ); const sources = sourceDocuments.map((metadata, i) => { diff --git a/server/utils/vectorDbProviders/weaviate/index.js b/server/utils/vectorDbProviders/weaviate/index.js index ac89315af7f2f8b760cf5f1c3996c2f677bee3c2..13668303f9a520b613ec3dcbf23c596c01816fb8 100644 --- a/server/utils/vectorDbProviders/weaviate/index.js +++ b/server/utils/vectorDbProviders/weaviate/index.js @@ -80,7 +80,8 @@ const Weaviate = { client, namespace, queryVector, - similarityThreshold = 0.25 + similarityThreshold = 0.25, + topN = 4 ) { const result = { contextTexts: [], @@ -95,7 +96,7 @@ const Weaviate = { .withClassName(camelCase(namespace)) .withFields(`${fields} _additional { id certainty }`) .withNearVector({ vector: queryVector }) - .withLimit(4) + .withLimit(topN) .do(); const responses = queryResponse?.data?.Get?.[camelCase(namespace)]; @@ -347,6 +348,7 @@ const Weaviate = { input = "", LLMConnector = null, similarityThreshold = 0.25, + topN = 4, }) { if (!namespace || !input || !LLMConnector) throw new Error("Invalid request to performSimilaritySearch."); @@ -365,7 +367,8 @@ const Weaviate = { client, namespace, queryVector, - similarityThreshold + similarityThreshold, + topN ); const sources = sourceDocuments.map((metadata, i) => { diff --git a/server/utils/vectorDbProviders/zilliz/index.js b/server/utils/vectorDbProviders/zilliz/index.js index 31afab35a30acba66436355a88787f00bf1d5f6b..159ea809475449f84f104e22759c07fc4854865a 100644 --- a/server/utils/vectorDbProviders/zilliz/index.js +++ b/server/utils/vectorDbProviders/zilliz/index.js @@ -266,6 +266,7 @@ const Zilliz = { input = "", LLMConnector = null, similarityThreshold = 0.25, + topN = 4, }) { if (!namespace || !input || !LLMConnector) throw new Error("Invalid request to performSimilaritySearch."); @@ -284,7 +285,8 @@ const Zilliz = { client, namespace, queryVector, - similarityThreshold + similarityThreshold, + topN ); const sources = sourceDocuments.map((metadata, i) => { @@ -300,7 +302,8 @@ const Zilliz = { client, namespace, queryVector, - similarityThreshold = 0.25 + similarityThreshold = 0.25, + topN = 4 ) { const result = { contextTexts: [], @@ -310,6 +313,7 @@ const Zilliz = { const response = await client.search({ collection_name: namespace, vectors: queryVector, + limit: topN, }); response.results.forEach((match) => { if (match.score < similarityThreshold) return;