From 11b385633495456cba4c63c926018a37d200b18c Mon Sep 17 00:00:00 2001
From: Thuc Pham <51660321+thucpn@users.noreply.github.com>
Date: Thu, 5 Sep 2024 14:11:31 +0700
Subject: [PATCH] feat: implement filters for MongoDBAtlasVectorSearch (#1142)

---
 .changeset/red-vans-taste.md                  |  5 +
 examples/mongodb/2_load_and_index.ts          | 13 ++-
 examples/mongodb/3_query.ts                   | 14 ++-
 .../vectorStore/MongoDBAtlasVectorStore.ts    | 99 ++++++++++++++-----
 4 files changed, 105 insertions(+), 26 deletions(-)
 create mode 100644 .changeset/red-vans-taste.md

diff --git a/.changeset/red-vans-taste.md b/.changeset/red-vans-taste.md
new file mode 100644
index 000000000..f59bef834
--- /dev/null
+++ b/.changeset/red-vans-taste.md
@@ -0,0 +1,5 @@
+---
+"llamaindex": patch
+---
+
+implement filters for MongoDBAtlasVectorSearch
diff --git a/examples/mongodb/2_load_and_index.ts b/examples/mongodb/2_load_and_index.ts
index d20d0279d..340013254 100644
--- a/examples/mongodb/2_load_and_index.ts
+++ b/examples/mongodb/2_load_and_index.ts
@@ -28,12 +28,23 @@ async function loadAndIndex() {
     "full_text",
   ]);
 
+  const FILTER_METADATA_FIELD = "content_type";
+
+  documents.forEach((document, index) => {
+    const contentType = ["tweet", "post", "story"][index % 3]; // cycle through the content types in order (not random) so the assignment is reproducible
+    document.metadata = {
+      ...document.metadata,
+      [FILTER_METADATA_FIELD]: contentType,
+    };
+  });
+
   // create Atlas as a vector store
   const vectorStore = new MongoDBAtlasVectorSearch({
     mongodbClient: client,
     dbName: databaseName,
     collectionName: vectorCollectionName, // this is where your embeddings will be stored
     indexName: indexName, // this is the name of the index you will need to create
+    indexedMetadataFields: [FILTER_METADATA_FIELD], // metadata fields to index so they can be used as query filters
   });
 
   // now create an index from all the Documents and store them in Atlas
@@ -46,5 +57,3 @@ async function loadAndIndex() {
 }
 
 loadAndIndex().catch(console.error);
-
-// you can't query your index yet because you need to create a vector search index in mongodb's UI now
diff --git a/examples/mongodb/3_query.ts b/examples/mongodb/3_query.ts
index 1064b0036..f6bf62d1d 100644
--- a/examples/mongodb/3_query.ts
+++ b/examples/mongodb/3_query.ts
@@ -14,12 +14,24 @@ async function query() {
     dbName: process.env.MONGODB_DATABASE!,
     collectionName: process.env.MONGODB_VECTORS!,
     indexName: process.env.MONGODB_VECTOR_INDEX!,
+    indexedMetadataFields: ["content_type"],
   });
 
   const index = await VectorStoreIndex.fromVectorStore(store);
 
   const retriever = index.asRetriever({ similarityTopK: 20 });
-  const queryEngine = index.asQueryEngine({ retriever });
+  const queryEngine = index.asQueryEngine({
+    retriever,
+    preFilters: {
+      filters: [
+        {
+          key: "content_type",
+          value: "story", // try "tweet" or "post" to see the difference
+          operator: "==",
+        },
+      ],
+    },
+  });
   const result = await queryEngine.query({
     query: "What does author receive when he was 11 years old?", // Isaac Asimov's "Foundation" for Christmas
   });
diff --git a/packages/llamaindex/src/storage/vectorStore/MongoDBAtlasVectorStore.ts b/packages/llamaindex/src/storage/vectorStore/MongoDBAtlasVectorStore.ts
index 1577bd996..059528960 100644
--- a/packages/llamaindex/src/storage/vectorStore/MongoDBAtlasVectorStore.ts
+++ b/packages/llamaindex/src/storage/vectorStore/MongoDBAtlasVectorStore.ts
@@ -5,7 +5,10 @@ import { getEnv } from "@llamaindex/env";
 import type { BulkWriteOptions, Collection } from "mongodb";
 import { MongoClient } from "mongodb";
 import {
+  FilterCondition,
   VectorStoreBase,
+  type FilterOperator,
+  type MetadataFilter,
   type MetadataFilters,
   type VectorStoreNoEmbedModel,
   type VectorStoreQuery,
@@ -13,15 +16,76 @@ import {
 } from "./types.js";
 import { metadataDictToNode, nodeToMetadata } from "./utils.js";
 
-// Utility function to convert metadata filters to MongoDB filter
-function toMongoDBFilter(
-  standardFilters: MetadataFilters,
-): Record<string, any> {
-  const filters: Record<string, any> = {};
-  for (const filter of standardFilters?.filters ?? []) {
-    filters[filter.key] = filter.value;
+// Default mapping for the embedding field of the Atlas Search index. Values passed to the
+// constructor as `embeddingDefinition` are merged over these defaults.
+// See https://www.mongodb.com/docs/atlas/atlas-search/field-types/knn-vector/
+const DEFAULT_EMBEDDING_DEFINITION = {
+  type: "knnVector",
+  dimensions: 1536,
+  similarity: "cosine",
+};
+
+/**
+ * Translate a metadata filter operator into the equivalent MongoDB query operator.
+ *
+ * @param operator - Operator taken from a `MetadataFilter`, e.g. "==" or "in".
+ * @returns The MongoDB operator name, e.g. "$eq" or "$in".
+ * @throws Error when the operator has no entry in the mapping table below.
+ */
+function mapLcMqlFilterOperators(operator: string): string {
+  const operatorMap: { [key in FilterOperator]?: string } = {
+    "==": "$eq",
+    "<": "$lt",
+    "<=": "$lte",
+    ">": "$gt",
+    ">=": "$gte",
+    "!=": "$ne",
+    in: "$in",
+    nin: "$nin",
+  };
+  const mqlOperator = operatorMap[operator as FilterOperator];
+  if (!mqlOperator) throw new Error(`Unsupported operator: ${operator}`);
+  return mqlOperator;
+}
+
+/**
+ * Convert metadata filters into a MongoDB filter document.
+ *
+ * A single filter is returned as a bare `{ key: { $op: value } }` clause; several filters are
+ * wrapped in `$and` or `$or` according to `filters.condition` (AND when no condition is given).
+ *
+ * @param filters - Filters to convert; may be omitted.
+ * @returns A MongoDB filter document, or `{}` when there is nothing to filter on.
+ * @throws Error when an operator is unsupported or the condition is neither AND nor OR.
+ */
+function toMongoDBFilter(filters?: MetadataFilters): Record<string, any> {
+  // An empty filter list must not reach the branches below: it would produce an empty
+  // `$and`/`$or` array (rejected by MongoDB) or throw when no condition is set.
+  if (!filters || filters.filters.length === 0) return {};
+
+  const createFilterObject = (mf: MetadataFilter) => ({
+    [mf.key]: {
+      [mapLcMqlFilterOperators(mf.operator)]: mf.value,
+    },
+  });
+
+  if (filters.filters.length === 1) {
+    return createFilterObject(filters.filters[0]);
+  }
+
+  // Callers may leave `condition` unset; combine the filters with AND in that case
+  // instead of rejecting the query.
+  const condition = filters.condition ?? FilterCondition.AND;
+
+  if (condition === FilterCondition.AND) {
+    return { $and: filters.filters.map(createFilterObject) };
+  }
+
+  if (condition === FilterCondition.OR) {
+    return { $or: filters.filters.map(createFilterObject) };
   }
-  return filters;
+
+  throw new Error("filters condition not recognized. Must be AND or OR");
 }
 
 /**
@@ -38,6 +77,8 @@ export class MongoDBAtlasVectorSearch
   dbName: string;
   collectionName: string;
   autoCreateIndex: boolean;
+  embeddingDefinition: Record<string, unknown>;
+  indexedMetadataFields: string[];
 
   /**
    * The used MongoClient. If not given, a new MongoClient is created based on the MONGODB_URI env variable.
@@ -98,26 +139,14 @@ export class MongoDBAtlasVectorSearch
   numCandidates: (query: VectorStoreQuery) => number;
   private collection?: Collection;
 
-  // define your Atlas Search index. See detail https://www.mongodb.com/docs/atlas/atlas-search/field-types/knn-vector/
-  readonly SEARCH_INDEX_DEFINITION = {
-    mappings: {
-      dynamic: true,
-      fields: {
-        embedding: {
-          type: "knnVector",
-          dimensions: 1536,
-          similarity: "cosine",
-        },
-      },
-    },
-  };
-
   constructor(
     init: Partial<MongoDBAtlasVectorSearch> & {
       dbName: string;
       collectionName: string;
       embedModel?: BaseEmbedding;
       autoCreateIndex?: boolean;
+      indexedMetadataFields?: string[];
+      embeddingDefinition?: Record<string, unknown>;
     },
   ) {
     super(init.embedModel);
@@ -136,6 +165,11 @@ export class MongoDBAtlasVectorSearch
     this.dbName = init.dbName ?? "default_db";
     this.collectionName = init.collectionName ?? "default_collection";
     this.autoCreateIndex = init.autoCreateIndex ?? true;
+    this.indexedMetadataFields = init.indexedMetadataFields ?? [];
+    this.embeddingDefinition = {
+      ...DEFAULT_EMBEDDING_DEFINITION,
+      ...(init.embeddingDefinition ?? {}),
+    };
     this.indexName = init.indexName ?? "default";
     this.embeddingKey = init.embeddingKey ?? "embedding";
     this.idKey = init.idKey ?? "id";
@@ -161,9 +195,24 @@ export class MongoDBAtlasVectorSearch
         (index) => index.name === this.indexName,
       );
       if (!indexExists) {
+        // Register every indexed metadata field in the index mapping with the `token` type.
         const additionalDefinition: Record<string, { type: string }> = {};
         this.indexedMetadataFields.forEach((field) => {
           additionalDefinition[field] = { type: "token" };
         });
         await this.collection.createSearchIndex({
           name: this.indexName,
-          definition: this.SEARCH_INDEX_DEFINITION,
+          definition: {
+            mappings: {
+              dynamic: true,
+              fields: {
+                // Map the field the vectors are actually written to (`embeddingKey`);
+                // it is only named "embedding" by default.
+                [this.embeddingKey]: this.embeddingDefinition,
+                ...additionalDefinition,
+              },
+            },
+          },
         });
       }
     }
@@ -189,11 +235,18 @@ export class MongoDBAtlasVectorSearch
         this.flatMetadata,
       );
 
+      // Include the specified metadata fields in the top level of the document (to help filter)
+      const populatedMetadata: Record<string, unknown> = {};
+      for (const field of this.indexedMetadataFields) {
+        populatedMetadata[field] = metadata[field];
+      }
+
       return {
         [this.idKey]: node.id_,
         [this.embeddingKey]: node.getEmbedding(),
         [this.textKey]: node.getContent(MetadataMode.NONE) || "",
         [this.metadataKey]: metadata,
+        ...populatedMetadata,
       };
     });
 
-- 
GitLab