From 11b385633495456cba4c63c926018a37d200b18c Mon Sep 17 00:00:00 2001 From: Thuc Pham <51660321+thucpn@users.noreply.github.com> Date: Thu, 5 Sep 2024 14:11:31 +0700 Subject: [PATCH] feat: implement filters for MongoDBAtlasVectorSearch (#1142) --- .changeset/red-vans-taste.md | 5 + examples/mongodb/2_load_and_index.ts | 13 ++- examples/mongodb/3_query.ts | 14 ++- .../vectorStore/MongoDBAtlasVectorStore.ts | 99 ++++++++++++++----- 4 files changed, 105 insertions(+), 26 deletions(-) create mode 100644 .changeset/red-vans-taste.md diff --git a/.changeset/red-vans-taste.md b/.changeset/red-vans-taste.md new file mode 100644 index 000000000..f59bef834 --- /dev/null +++ b/.changeset/red-vans-taste.md @@ -0,0 +1,5 @@ +--- +"llamaindex": patch +--- + +implement filters for MongoDBAtlasVectorSearch diff --git a/examples/mongodb/2_load_and_index.ts b/examples/mongodb/2_load_and_index.ts index d20d0279d..340013254 100644 --- a/examples/mongodb/2_load_and_index.ts +++ b/examples/mongodb/2_load_and_index.ts @@ -28,12 +28,23 @@ async function loadAndIndex() { "full_text", ]); + const FILTER_METADATA_FIELD = "content_type"; + + documents.forEach((document, index) => { + const contentType = ["tweet", "post", "story"][index % 3]; // assign a random content type to each document + document.metadata = { + ...document.metadata, + [FILTER_METADATA_FIELD]: contentType, + }; + }); + // create Atlas as a vector store const vectorStore = new MongoDBAtlasVectorSearch({ mongodbClient: client, dbName: databaseName, collectionName: vectorCollectionName, // this is where your embeddings will be stored indexName: indexName, // this is the name of the index you will need to create + indexedMetadataFields: [FILTER_METADATA_FIELD], // this is the field that will be used for the query }); // now create an index from all the Documents and store them in Atlas @@ -46,5 +57,3 @@ async function loadAndIndex() { } loadAndIndex().catch(console.error); - -// you can't query your index yet because you need to create a vector search index in mongodb's UI now diff --git a/examples/mongodb/3_query.ts b/examples/mongodb/3_query.ts index 1064b0036..f6bf62d1d 100644 --- a/examples/mongodb/3_query.ts +++ b/examples/mongodb/3_query.ts @@ -14,12 +14,24 @@ async function query() { dbName: process.env.MONGODB_DATABASE!, collectionName: process.env.MONGODB_VECTORS!, indexName: process.env.MONGODB_VECTOR_INDEX!, + indexedMetadataFields: ["content_type"], }); const index = await VectorStoreIndex.fromVectorStore(store); const retriever = index.asRetriever({ similarityTopK: 20 }); - const queryEngine = index.asQueryEngine({ retriever }); + const queryEngine = index.asQueryEngine({ + retriever, + preFilters: { + filters: [ + { + key: "content_type", + value: "story", // try "tweet" or "post" to see the difference + operator: "==", + }, + ], + }, + }); const result = await queryEngine.query({ query: "What does author receive when he was 11 years old?", // Isaac Asimov's "Foundation" for Christmas }); diff --git a/packages/llamaindex/src/storage/vectorStore/MongoDBAtlasVectorStore.ts b/packages/llamaindex/src/storage/vectorStore/MongoDBAtlasVectorStore.ts index 1577bd996..059528960 100644 --- a/packages/llamaindex/src/storage/vectorStore/MongoDBAtlasVectorStore.ts +++ b/packages/llamaindex/src/storage/vectorStore/MongoDBAtlasVectorStore.ts @@ -5,7 +5,10 @@ import { getEnv } from "@llamaindex/env"; import type { BulkWriteOptions, Collection } from "mongodb"; import { MongoClient } from "mongodb"; import { + FilterCondition, VectorStoreBase, + type FilterOperator, + type MetadataFilter, type MetadataFilters, type VectorStoreNoEmbedModel, type VectorStoreQuery, @@ -13,15 +16,51 @@ import { } from "./types.js"; import { metadataDictToNode, nodeToMetadata } from "./utils.js"; -// Utility function to convert metadata filters to MongoDB filter -function toMongoDBFilter( - standardFilters: MetadataFilters, -): Record<string, any> { - const filters: Record<string, any> = {}; - for (const filter of standardFilters?.filters ?? []) { - filters[filter.key] = filter.value; +// define your Atlas Search index. See detail https://www.mongodb.com/docs/atlas/atlas-search/field-types/knn-vector/ +const DEFAULT_EMBEDDING_DEFINITION = { + type: "knnVector", + dimensions: 1536, + similarity: "cosine", +}; + +function mapLcMqlFilterOperators(operator: string): string { + const operatorMap: { [key in FilterOperator]?: string } = { + "==": "$eq", + "<": "$lt", + "<=": "$lte", + ">": "$gt", + ">=": "$gte", + "!=": "$ne", + in: "$in", + nin: "$nin", + }; + const mqlOperator = operatorMap[operator as FilterOperator]; + if (!mqlOperator) throw new Error(`Unsupported operator: ${operator}`); + return mqlOperator; +} + +function toMongoDBFilter(filters?: MetadataFilters): Record<string, any> { + if (!filters) return {}; + + const createFilterObject = (mf: MetadataFilter) => ({ + [mf.key]: { + [mapLcMqlFilterOperators(mf.operator)]: mf.value, + }, + }); + + if (filters.filters.length === 1) { + return createFilterObject(filters.filters[0]); + } + + if (filters.condition === FilterCondition.AND) { + return { $and: filters.filters.map(createFilterObject) }; + } + + if (filters.condition === FilterCondition.OR) { + return { $or: filters.filters.map(createFilterObject) }; } - return filters; + + throw new Error("filters condition not recognized. Must be AND or OR"); } /** @@ -38,6 +77,8 @@ export class MongoDBAtlasVectorSearch dbName: string; collectionName: string; autoCreateIndex: boolean; + embeddingDefinition: Record<string, unknown>; + indexedMetadataFields: string[]; /** * The used MongoClient. If not given, a new MongoClient is created based on the MONGODB_URI env variable. @@ -98,26 +139,14 @@ export class MongoDBAtlasVectorSearch numCandidates: (query: VectorStoreQuery) => number; private collection?: Collection; - // define your Atlas Search index. See detail https://www.mongodb.com/docs/atlas/atlas-search/field-types/knn-vector/ - readonly SEARCH_INDEX_DEFINITION = { - mappings: { - dynamic: true, - fields: { - embedding: { - type: "knnVector", - dimensions: 1536, - similarity: "cosine", - }, - }, - }, - }; - constructor( init: Partial<MongoDBAtlasVectorSearch> & { dbName: string; collectionName: string; embedModel?: BaseEmbedding; autoCreateIndex?: boolean; + indexedMetadataFields?: string[]; + embeddingDefinition?: Record<string, unknown>; }, ) { super(init.embedModel); @@ -136,6 +165,11 @@ export class MongoDBAtlasVectorSearch this.dbName = init.dbName ?? "default_db"; this.collectionName = init.collectionName ?? "default_collection"; this.autoCreateIndex = init.autoCreateIndex ?? true; + this.indexedMetadataFields = init.indexedMetadataFields ?? []; + this.embeddingDefinition = { + ...DEFAULT_EMBEDDING_DEFINITION, + ...(init.embeddingDefinition ?? {}), + }; this.indexName = init.indexName ?? "default"; this.embeddingKey = init.embeddingKey ?? "embedding"; this.idKey = init.idKey ?? "id"; @@ -161,9 +195,21 @@ export class MongoDBAtlasVectorSearch (index) => index.name === this.indexName, ); if (!indexExists) { + const additionalDefinition: Record<string, { type: string }> = {}; + this.indexedMetadataFields.forEach((field) => { + additionalDefinition[field] = { type: "token" }; + }); await this.collection.createSearchIndex({ name: this.indexName, - definition: this.SEARCH_INDEX_DEFINITION, + definition: { + mappings: { + dynamic: true, + fields: { + embedding: this.embeddingDefinition, + ...additionalDefinition, + }, + }, + }, }); } } @@ -189,11 +235,18 @@ export class MongoDBAtlasVectorSearch this.flatMetadata, ); + // Include the specified metadata fields in the top level of the document (to help filter) + const populatedMetadata: Record<string, unknown> = {}; + for (const field of this.indexedMetadataFields) { + populatedMetadata[field] = metadata[field]; + } + return { [this.idKey]: node.id_, [this.embeddingKey]: node.getEmbedding(), [this.textKey]: node.getContent(MetadataMode.NONE) || "", [this.metadataKey]: metadata, + ...populatedMetadata, }; }); -- GitLab