diff --git a/frontend/src/pages/WorkspaceSettings/VectorDatabase/VectorSearchMode/index.jsx b/frontend/src/pages/WorkspaceSettings/VectorDatabase/VectorSearchMode/index.jsx
new file mode 100644
index 0000000000000000000000000000000000000000..5e5816cda8d2d3c7dc2097d4aa2da6a4ffadaef2
--- /dev/null
+++ b/frontend/src/pages/WorkspaceSettings/VectorDatabase/VectorSearchMode/index.jsx
@@ -0,0 +1,51 @@
+import { useState } from "react";
+
+// We don't support reranking for all vector DBs yet due to differences in how each provider
+// returns information. We need to normalize the response data so the reranker can be used with each provider.
+const supportedVectorDBs = ["lancedb"];
+const hint = {
+  default: {
+    title: "Default",
+    description:
+      "This is the fastest option, but it may not return the most relevant results, which can lead to model hallucinations.",
+  },
+  rerank: {
+    title: "Accuracy Optimized",
+    description:
+      "LLM responses may take longer to generate, but they will be more accurate and relevant.",
+  },
+};
+
+export default function VectorSearchMode({ workspace, setHasChanges }) {
+  const [selection, setSelection] = useState(
+    workspace?.vectorSearchMode ?? "default"
+  );
+  if (!workspace?.vectorDB || !supportedVectorDBs.includes(workspace?.vectorDB))
+    return null;
+
+  return (
+    <div>
+      <div className="flex flex-col">
+        <label htmlFor="vectorSearchMode" className="block input-label">
+          Search Preference
+        </label>
+      </div>
+      <select
+        name="vectorSearchMode"
+        value={selection}
+        className="border-none bg-theme-settings-input-bg text-white text-sm mt-2 rounded-lg focus:outline-primary-button active:outline-primary-button outline-none block w-full p-2.5"
+        onChange={(e) => {
+          setSelection(e.target.value);
+          setHasChanges(true);
+        }}
+        required={true}
+      >
+        <option value="default">Default</option>
+        <option value="rerank">Accuracy Optimized</option>
+      </select>
+      <p className="text-white text-opacity-60 text-xs font-medium py-1.5">
+        {hint[selection]?.description}
+      </p>
+    </div>
+  );
+}
diff --git a/frontend/src/pages/WorkspaceSettings/VectorDatabase/index.jsx b/frontend/src/pages/WorkspaceSettings/VectorDatabase/index.jsx
index 97d63291cdad0e96b5e3443a4783ceb62f48f628..7d7d44e8f44a92e423bb494b5b954d8b0b4e387c 100644
--- a/frontend/src/pages/WorkspaceSettings/VectorDatabase/index.jsx
+++ b/frontend/src/pages/WorkspaceSettings/VectorDatabase/index.jsx
@@ -7,6 +7,7 @@ import MaxContextSnippets from "./MaxContextSnippets";
 import DocumentSimilarityThreshold from "./DocumentSimilarityThreshold";
 import ResetDatabase from "./ResetDatabase";
 import VectorCount from "./VectorCount";
+import VectorSearchMode from "./VectorSearchMode";
 
 export default function VectorDatabase({ workspace }) {
   const [hasChanges, setHasChanges] = useState(false);
@@ -43,6 +44,7 @@ export default function VectorDatabase({ workspace }) {
         <VectorDBIdentifier workspace={workspace} />
         <VectorCount reload={true} workspace={workspace} />
       </div>
+      <VectorSearchMode workspace={workspace} setHasChanges={setHasChanges} />
       <MaxContextSnippets workspace={workspace} setHasChanges={setHasChanges} />
       <DocumentSimilarityThreshold
         workspace={workspace}
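A note on the wiring above: the component only flips local state and marks the form dirty; persistence rides on the `name="vectorSearchMode"` attribute. A minimal sketch of the save path, assuming the enclosing settings form serializes its named inputs into the frontend `Workspace.update` model call (the same pattern the sibling fields such as MaxContextSnippets rely on; the handler name is hypothetical):

import Workspace from "@/models/workspace";

// Hypothetical save handler: collect the named <select> value from the form
// element and persist it alongside the other workspace settings.
async function saveSearchPreference(slug, formEl) {
  const form = new FormData(formEl);
  await Workspace.update(slug, {
    vectorSearchMode: form.get("vectorSearchMode") ?? "default",
  });
}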
diff --git a/frontend/src/pages/WorkspaceSettings/index.jsx b/frontend/src/pages/WorkspaceSettings/index.jsx
index 9d441ae27e7dab9668e8b9a4e683b37af2b7a186..c4ea012577cbc8772cd46c9a7e79d76a1e8941a8 100644
--- a/frontend/src/pages/WorkspaceSettings/index.jsx
+++ b/frontend/src/pages/WorkspaceSettings/index.jsx
@@ -23,6 +23,7 @@ import Members from "./Members";
 import WorkspaceAgentConfiguration from "./AgentConfig";
 import useUser from "@/hooks/useUser";
 import { useTranslation } from "react-i18next";
+import System from "@/models/system";
 
 const TABS = {
   "general-appearance": GeneralAppearance,
@@ -59,9 +60,11 @@ function ShowWorkspaceChat() {
       return;
     }
 
+    const _settings = await System.keys();
     const suggestedMessages = await Workspace.getSuggestedMessages(slug);
     setWorkspace({
       ..._workspace,
+      vectorDB: _settings?.VectorDB,
       suggestedMessages,
     });
     setLoading(false);
diff --git a/server/endpoints/api/workspace/index.js b/server/endpoints/api/workspace/index.js
index d816fae51a8b48f1762408acf8e73bb4f70a688a..069d695914d5937402cd4b3584d6f23527d761dd 100644
--- a/server/endpoints/api/workspace/index.js
+++ b/server/endpoints/api/workspace/index.js
@@ -961,6 +961,7 @@ function apiWorkspaceEndpoints(app) {
         LLMConnector: getLLMProvider(),
         similarityThreshold: parseSimilarityThreshold(),
         topN: parseTopN(),
+        rerank: workspace?.vectorSearchMode === "rerank",
       });
 
       response.status(200).json({
diff --git a/server/models/workspace.js b/server/models/workspace.js
index 4471f67e7f5f86c41250a8deb674b62ad4659a8d..91d659306a107d8b4958c6b13c55598f877c2ff7 100644
--- a/server/models/workspace.js
+++ b/server/models/workspace.js
@@ -34,6 +34,7 @@ const Workspace = {
     "agentProvider",
     "agentModel",
     "queryRefusalResponse",
+    "vectorSearchMode",
   ],
 
   validations: {
@@ -99,6 +100,15 @@ const Workspace = {
       if (!value || typeof value !== "string") return null;
       return String(value);
     },
+    vectorSearchMode: (value) => {
+      if (
+        !value ||
+        typeof value !== "string" ||
+        !["default", "rerank"].includes(value)
+      )
+        return "default";
+      return value;
+    },
   },
 
   /**
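The `vectorSearchMode` validator above coerces any unexpected write back to `"default"` rather than rejecting it. A standalone copy of the validator to make the coercion concrete (illustration only):

const vectorSearchMode = (value) => {
  if (
    !value ||
    typeof value !== "string" ||
    !["default", "rerank"].includes(value)
  )
    return "default";
  return value;
};

console.log(vectorSearchMode("rerank")); // "rerank"
console.log(vectorSearchMode("hybrid")); // "default" -- unknown mode is coerced
console.log(vectorSearchMode(null));     // "default"
console.log(vectorSearchMode(42));       // "default" -- non-string input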
@default("default") workspace_users workspace_users[] documents workspace_documents[] workspace_suggested_messages workspace_suggested_messages[] diff --git a/server/storage/models/.gitignore b/server/storage/models/.gitignore index e669b51b277a7fcfbef1dab2ba71832c4b72097a..7a8f66d8ff40e02838ae333d767810bd1f4764b6 100644 --- a/server/storage/models/.gitignore +++ b/server/storage/models/.gitignore @@ -3,4 +3,5 @@ downloaded/* !downloaded/.placeholder openrouter apipie -novita \ No newline at end of file +novita +mixedbread-ai* \ No newline at end of file diff --git a/server/utils/EmbeddingRerankers/native/index.js b/server/utils/EmbeddingRerankers/native/index.js new file mode 100644 index 0000000000000000000000000000000000000000..f3d468cd7b2b1e02b6339a835a80464ebbd36db5 --- /dev/null +++ b/server/utils/EmbeddingRerankers/native/index.js @@ -0,0 +1,153 @@ +const path = require("path"); +const fs = require("fs"); + +class NativeEmbeddingReranker { + static #model = null; + static #tokenizer = null; + static #transformers = null; + + constructor() { + // An alternative model to the mixedbread-ai/mxbai-rerank-xsmall-v1 model (speed on CPU is much slower for this model @ 18docs = 6s) + // Model Card: https://huggingface.co/Xenova/ms-marco-MiniLM-L-6-v2 (speed on CPU is much faster @ 18docs = 1.6s) + this.model = "Xenova/ms-marco-MiniLM-L-6-v2"; + this.cacheDir = path.resolve( + process.env.STORAGE_DIR + ? path.resolve(process.env.STORAGE_DIR, `models`) + : path.resolve(__dirname, `../../../storage/models`) + ); + this.modelPath = path.resolve(this.cacheDir, ...this.model.split("/")); + // Make directory when it does not exist in existing installations + if (!fs.existsSync(this.cacheDir)) fs.mkdirSync(this.cacheDir); + this.log("Initialized"); + } + + log(text, ...args) { + console.log(`\x1b[36m[NativeEmbeddingReranker]\x1b[0m ${text}`, ...args); + } + + /** + * This function will preload the reranker suite and tokenizer. + * This is useful for reducing the latency of the first rerank call and pre-downloading the models and such + * to avoid having to wait for the models to download on the first rerank call. + */ + async preload() { + try { + this.log(`Preloading reranker suite...`); + await this.initClient(); + this.log( + `Preloaded reranker suite. Reranking is available as a service now.` + ); + return; + } catch (e) { + console.error(e); + this.log( + `Failed to preload reranker suite. Reranking will be available on the first rerank call.` + ); + return; + } + } + + async initClient() { + if (NativeEmbeddingReranker.#transformers) { + this.log(`Reranker suite already initialized - reusing.`); + return; + } + + await import("@xenova/transformers").then( + async ({ AutoModelForSequenceClassification, AutoTokenizer }) => { + this.log(`Loading reranker suite...`); + NativeEmbeddingReranker.#transformers = { + AutoModelForSequenceClassification, + AutoTokenizer, + }; + await this.#getPreTrainedModel(); + await this.#getPreTrainedTokenizer(); + } + ); + return; + } + + async #getPreTrainedModel() { + if (NativeEmbeddingReranker.#model) { + this.log(`Loading model from singleton...`); + return NativeEmbeddingReranker.#model; + } + + const model = + await NativeEmbeddingReranker.#transformers.AutoModelForSequenceClassification.from_pretrained( + this.model, + { + progress_callback: (p) => + p.status === "progress" && + this.log(`Loading model ${this.model}... 
diff --git a/server/utils/agents/aibitat/plugins/memory.js b/server/utils/agents/aibitat/plugins/memory.js
index 4f43d0ec460beff523025752285758a38e23890c..df52843015fe22bbf3eb5d46839c7b07c80ad1d4 100644
--- a/server/utils/agents/aibitat/plugins/memory.js
+++ b/server/utils/agents/aibitat/plugins/memory.js
@@ -95,6 +95,7 @@ const memory = {
             input: query,
             LLMConnector,
             topN: workspace?.topN ?? 4,
+            rerank: workspace?.vectorSearchMode === "rerank",
           });
 
           if (contextTexts.length === 0) {
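The flag derivation `workspace?.vectorSearchMode === "rerank"` now appears here and in each of the chat handlers below. If the condition ever grows beyond a string comparison, a shared helper would keep the call sites in sync; a hypothetical sketch, not part of this patch (location assumed):

// e.g. in server/utils/helpers/index.js (hypothetical)
function shouldRerank(workspace) {
  return workspace?.vectorSearchMode === "rerank";
}

module.exports = { shouldRerank };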
diff --git a/server/utils/chats/apiChatHandler.js b/server/utils/chats/apiChatHandler.js
index 7ba45fed62d3498d1bed8c638ca8febb5690028d..11421ea128ed7c5d987fd1342eca11abc7aabf35 100644
--- a/server/utils/chats/apiChatHandler.js
+++ b/server/utils/chats/apiChatHandler.js
@@ -180,6 +180,7 @@ async function chatSync({
         similarityThreshold: workspace?.similarityThreshold,
         topN: workspace?.topN,
         filterIdentifiers: pinnedDocIdentifiers,
+        rerank: workspace?.vectorSearchMode === "rerank",
       })
     : {
         contextTexts: [],
@@ -480,6 +481,7 @@ async function streamChat({
         similarityThreshold: workspace?.similarityThreshold,
         topN: workspace?.topN,
         filterIdentifiers: pinnedDocIdentifiers,
+        rerank: workspace?.vectorSearchMode === "rerank",
       })
     : {
         contextTexts: [],
diff --git a/server/utils/chats/embed.js b/server/utils/chats/embed.js
index 7196d161e2b50f2ac340e4f693d693ab4b5b2e98..550a460f8743aadbdda37b75707bea18a78175c6 100644
--- a/server/utils/chats/embed.js
+++ b/server/utils/chats/embed.js
@@ -93,6 +93,7 @@ async function streamChatWithForEmbed(
         similarityThreshold: embed.workspace?.similarityThreshold,
         topN: embed.workspace?.topN,
         filterIdentifiers: pinnedDocIdentifiers,
+        rerank: embed.workspace?.vectorSearchMode === "rerank",
       })
     : {
         contextTexts: [],
diff --git a/server/utils/chats/openaiCompatible.js b/server/utils/chats/openaiCompatible.js
index a76347bf71e1b41d3cbd52553a1418208888f640..fcae9782767a9fcb5b00a0786f214e4810e62b09 100644
--- a/server/utils/chats/openaiCompatible.js
+++ b/server/utils/chats/openaiCompatible.js
@@ -89,6 +89,7 @@ async function chatSync({
         similarityThreshold: workspace?.similarityThreshold,
         topN: workspace?.topN,
         filterIdentifiers: pinnedDocIdentifiers,
+        rerank: workspace?.vectorSearchMode === "rerank",
       })
     : {
         contextTexts: [],
@@ -304,6 +305,7 @@ async function streamChat({
         similarityThreshold: workspace?.similarityThreshold,
         topN: workspace?.topN,
         filterIdentifiers: pinnedDocIdentifiers,
+        rerank: workspace?.vectorSearchMode === "rerank",
       })
     : {
         contextTexts: [],
diff --git a/server/utils/chats/stream.js b/server/utils/chats/stream.js
index 35b0c191e6b1830e5515544ce060567b9454b4b4..bd81f130898d0d6abd62c33cd5104ae772ab6fa4 100644
--- a/server/utils/chats/stream.js
+++ b/server/utils/chats/stream.js
@@ -139,6 +139,7 @@ async function streamChatWithWorkspace(
         similarityThreshold: workspace?.similarityThreshold,
         topN: workspace?.topN,
         filterIdentifiers: pinnedDocIdentifiers,
+        rerank: workspace?.vectorSearchMode === "rerank",
       })
     : {
         contextTexts: [],
diff --git a/server/utils/helpers/index.js b/server/utils/helpers/index.js
index fa47f9cf78e0a2c4b215ac984dfaabf99d95f2de..544bd36ff102adf3c6dc781b3f33646c0b4f3011 100644
--- a/server/utils/helpers/index.js
+++ b/server/utils/helpers/index.js
@@ -56,6 +56,7 @@
  * @property {Function} totalVectors - Returns the total number of vectors in the database.
  * @property {Function} namespaceCount - Returns the count of vectors in a given namespace.
  * @property {Function} similarityResponse - Performs a similarity search on a given namespace.
+ * @property {Function} rerankedSimilarityResponse - Performs a similarity search on a given namespace with reranking (if supported by the provider).
  * @property {Function} namespace - Retrieves the specified namespace collection.
  * @property {Function} hasNamespace - Checks if a namespace exists.
  * @property {Function} namespaceExists - Verifies if a namespace exists in the client.
diff --git a/server/utils/vectorDbProviders/lance/index.js b/server/utils/vectorDbProviders/lance/index.js
index 78a32b80c780eb2fe55c2884f60386b84cca0214..e3f285478b7aaf66359bb2c5c6f2678ab4c9ec96 100644
--- a/server/utils/vectorDbProviders/lance/index.js
+++ b/server/utils/vectorDbProviders/lance/index.js
@@ -5,6 +5,7 @@ const { SystemSettings } = require("../../../models/systemSettings");
 const { storeVectorResult, cachedVectorInformation } = require("../../files");
 const { v4: uuidv4 } = require("uuid");
 const { sourceIdentifier } = require("../../chats");
+const { NativeEmbeddingReranker } = require("../../EmbeddingRerankers/native");
 
 /**
  * LancedDB Client connection object
@@ -57,6 +58,91 @@ const LanceDb = {
     const table = await client.openTable(_namespace);
     return (await table.countRows()) || 0;
   },
+  /**
+   * Performs a SimilaritySearch + Reranking on a namespace.
+   * @param {Object} params - The parameters for the rerankedSimilarityResponse.
+   * @param {Object} params.client - The vectorDB client.
+   * @param {string} params.namespace - The namespace to search in.
+   * @param {string} params.query - The query to search for (plain text).
+   * @param {number[]} params.queryVector - The vector of the query.
+   * @param {number} params.similarityThreshold - The threshold for similarity.
+   * @param {number} params.topN - The number of results to return from this process.
+   * @param {string[]} params.filterIdentifiers - The identifiers of the documents to filter out.
+   * @returns {Promise<{contextTexts: string[], sourceDocuments: object[], scores: number[]}>}
+   */
+  rerankedSimilarityResponse: async function ({
+    client,
+    namespace,
+    query,
+    queryVector,
+    topN = 4,
+    similarityThreshold = 0.25,
+    filterIdentifiers = [],
+  }) {
+    const reranker = new NativeEmbeddingReranker();
+    const collection = await client.openTable(namespace);
+    const totalEmbeddings = await this.namespaceCount(namespace);
+    const result = {
+      contextTexts: [],
+      sourceDocuments: [],
+      scores: [],
+    };
+
+    /**
+     * For reranking, we want to work with a larger pool of results than the topN,
+     * because the reranker can only rerank the results it is given and we do not
+     * auto-expand the result set.
+     *
+     * However, we cannot make this boundless, as reranking is expensive and time-consuming,
+     * so we clamp the number of results to a maximum of 50 and a minimum of 10.
+     * This is a good balance between the number of results to rerank and the cost of reranking,
+     * and ensures workspaces with 10K embeddings will still rerank within a reasonable timeframe on base-level hardware.
+     *
+     * Benchmarks:
+     * On Intel Mac: 2.6 GHz 6-Core Intel Core i7 - 20 docs reranked in ~5.2 sec
+     */
+    const searchLimit = Math.max(
+      10,
+      Math.min(50, Math.ceil(totalEmbeddings * 0.1))
+    );
+    const vectorSearchResults = await collection
+      .vectorSearch(queryVector)
+      .distanceType("cosine")
+      .limit(searchLimit)
+      .toArray();
+
+    await reranker
+      .rerank(query, vectorSearchResults, { topK: topN })
+      .then((rerankResults) => {
+        rerankResults.forEach((item) => {
+          if (this.distanceToSimilarity(item._distance) < similarityThreshold)
+            return;
+          const { vector: _, ...rest } = item;
+          if (filterIdentifiers.includes(sourceIdentifier(rest))) {
+            console.log(
+              "LanceDB: A source was filtered from context as its parent document is pinned."
+            );
+            return;
+          }
+          const score =
+            item?.rerank_score || this.distanceToSimilarity(item._distance);
+
+          result.contextTexts.push(rest.text);
+          result.sourceDocuments.push({
+            ...rest,
+            score,
+          });
+          result.scores.push(score);
+        });
+      })
+      .catch((e) => {
+        console.error(e);
+        console.error("LanceDB::rerankedSimilarityResponse", e.message);
+      });
+
+    return result;
+  },
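To make the clamp above concrete, `searchLimit = max(10, min(50, ceil(totalEmbeddings * 0.1)))`, i.e. 10% of the namespace bounded to the range [10, 50]. A standalone illustration:

const searchLimitFor = (totalEmbeddings) =>
  Math.max(10, Math.min(50, Math.ceil(totalEmbeddings * 0.1)));

console.log(searchLimitFor(30));    // 10 -- small workspaces still get 10 candidates
console.log(searchLimitFor(200));   // 20 -- mid range uses 10% of the embeddings
console.log(searchLimitFor(10000)); // 50 -- cap keeps rerank latency bounded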
+
   /**
    * Performs a SimilaritySearch on a given LanceDB namespace.
    * @param {Object} params
@@ -300,6 +386,7 @@ const LanceDb = {
     similarityThreshold = 0.25,
     topN = 4,
     filterIdentifiers = [],
+    rerank = false,
   }) {
     if (!namespace || !input || !LLMConnector)
       throw new Error("Invalid request to performSimilaritySearch.");
@@ -314,15 +401,26 @@ const LanceDb = {
     }
 
     const queryVector = await LLMConnector.embedTextInput(input);
-    const { contextTexts, sourceDocuments } = await this.similarityResponse({
-      client,
-      namespace,
-      queryVector,
-      similarityThreshold,
-      topN,
-      filterIdentifiers,
-    });
+    const result = rerank
+      ? await this.rerankedSimilarityResponse({
+          client,
+          namespace,
+          query: input,
+          queryVector,
+          similarityThreshold,
+          topN,
+          filterIdentifiers,
+        })
+      : await this.similarityResponse({
+          client,
+          namespace,
+          queryVector,
+          similarityThreshold,
+          topN,
+          filterIdentifiers,
+        });
+    const { contextTexts, sourceDocuments } = result;
 
     const sources = sourceDocuments.map((metadata, i) => {
       return { metadata: { ...metadata, text: contextTexts[i] } };
     });
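End to end, callers only pass the new `rerank` flag; the provider decides whether to take the reranked path. A hedged sketch of the call shape used by the chat handlers above, assuming a LanceDB-backed workspace and the existing `getVectorDbClass` provider lookup from server/utils/helpers (the wrapper function name is hypothetical):

const { getVectorDbClass } = require("./utils/helpers");

async function searchWorkspace(workspace, input, LLMConnector) {
  const VectorDb = getVectorDbClass();
  return await VectorDb.performSimilaritySearch({
    namespace: workspace.slug,
    input,
    LLMConnector,
    similarityThreshold: workspace?.similarityThreshold,
    topN: workspace?.topN,
    filterIdentifiers: [],
    rerank: workspace?.vectorSearchMode === "rerank", // the flag added in this patch
  });
}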