From 6f75306c1731b06bf6654b43b9f01f9cd4f04b8a Mon Sep 17 00:00:00 2001
From: Thuc Pham <51660321+thucpn@users.noreply.github.com>
Date: Mon, 14 Oct 2024 15:31:00 +0700
Subject: [PATCH] feat: support metadata filters for Astra (#1330)

Co-authored-by: Marcus Schiesser <mail@marcusschiesser.de>
---
 .changeset/clever-hats-change.md              |  5 ++
 examples/astradb/example.ts                   |  7 ++-
 .../src/vector-store/AstraDBVectorStore.ts    | 55 ++++++++++++++++---
 3 files changed, 58 insertions(+), 9 deletions(-)
 create mode 100644 .changeset/clever-hats-change.md

diff --git a/.changeset/clever-hats-change.md b/.changeset/clever-hats-change.md
new file mode 100644
index 000000000..aa9714915
--- /dev/null
+++ b/.changeset/clever-hats-change.md
@@ -0,0 +1,5 @@
+---
+"llamaindex": patch
+---
+
+feat: support metadata filters for AstraDB
diff --git a/examples/astradb/example.ts b/examples/astradb/example.ts
index 982dbf985..e576dcacc 100644
--- a/examples/astradb/example.ts
+++ b/examples/astradb/example.ts
@@ -1,6 +1,7 @@
 import {
   AstraDBVectorStore,
   Document,
+  MetadataFilters,
   storageContextFromDefaults,
   VectorStoreIndex,
 } from "llamaindex";
@@ -42,8 +43,10 @@ async function main() {
     const index = await VectorStoreIndex.fromDocuments(docs, {
       storageContext: ctx,
     });
-
-    const queryEngine = index.asQueryEngine();
+    const preFilters: MetadataFilters = {
+      filters: [{ key: "id", operator: "in", value: [123, 789] }],
+    }; // try changing the filters to see the different results
+    const queryEngine = index.asQueryEngine({ preFilters });
     const response = await queryEngine.query({
       query: "Describe AstraDB.",
     });
diff --git a/packages/llamaindex/src/vector-store/AstraDBVectorStore.ts b/packages/llamaindex/src/vector-store/AstraDBVectorStore.ts
index 6eda34a3f..6fd3fa124 100644
--- a/packages/llamaindex/src/vector-store/AstraDBVectorStore.ts
+++ b/packages/llamaindex/src/vector-store/AstraDBVectorStore.ts
@@ -2,19 +2,29 @@ import {
   Collection,
   DataAPIClient,
   Db,
+  type Filter,
   type FindOptions,
+  type SomeDoc,
 } from "@datastax/astra-db-ts";
 import type { BaseNode } from "@llamaindex/core/schema";
 import { MetadataMode } from "@llamaindex/core/schema";
 import { getEnv } from "@llamaindex/env";
 import {
+  FilterCondition,
+  FilterOperator,
   VectorStoreBase,
   type IEmbedModel,
+  type MetadataFilter,
+  type MetadataFilters,
   type VectorStoreNoEmbedModel,
   type VectorStoreQuery,
   type VectorStoreQueryResult,
 } from "./types.js";
-import { metadataDictToNode, nodeToMetadata } from "./utils.js";
+import {
+  metadataDictToNode,
+  nodeToMetadata,
+  parseArrayValue,
+} from "./utils.js";
 
 export class AstraDBVectorStore
   extends VectorStoreBase
@@ -183,12 +193,8 @@ export class AstraDBVectorStore
     }
     const collection = this.collection;
 
-    const filters: Record<string, any> = {};
-    query.filters?.filters?.forEach((f) => {
-      filters[f.key] = f.value;
-    });
-
-    const cursor = await collection.find(filters, <FindOptions>{
+    const astraFilter = this.toAstraFilter(query.filters);
+    const cursor = await collection.find(astraFilter, <FindOptions>{
       ...options,
       sort: query.queryEmbedding
         ? { $vector: query.queryEmbedding }
@@ -230,4 +236,58 @@ export class AstraDBVectorStore
       nodes,
     };
   }
+
+  /**
+   * Converts llamaindex metadata filters into an Astra DB `find` filter.
+   *
+   * @param filters - Optional metadata filters taken from the query.
+   * @returns `{}` when there is nothing to filter on; otherwise the converted
+   *   filter items combined under `$and` (the default) or `$or`.
+   * @throws Error if the filter condition is neither AND nor OR.
+   */
+  private toAstraFilter(filters?: MetadataFilters): Filter<SomeDoc> {
+    // A missing `filters` array is treated like an empty one, so the `.map`
+    // below never runs on `undefined`.
+    if (!filters?.filters?.length) return {};
+    const condition = filters.condition ?? FilterCondition.AND;
+    const listFilter = filters.filters.map((f) => this.buildFilterItem(f));
+    if (condition === FilterCondition.OR) return { $or: listFilter };
+    if (condition === FilterCondition.AND) return { $and: listFilter };
+    throw new Error(`Not supported filter condition: ${condition}`);
+  }
+
+  /**
+   * Converts one metadata filter into a single-key Astra DB filter clause.
+   *
+   * @param filter - The key, operator and value to translate.
+   * @returns An object keyed by `filter.key` holding the matching comparison.
+   * @throws Error if the operator has no case below.
+   */
+  private buildFilterItem(filter: MetadataFilter): Filter<SomeDoc> {
+    const { key, operator, value } = filter;
+    switch (operator) {
+      case FilterOperator.EQ:
+        return { [key]: value };
+      case FilterOperator.NE:
+        return { [key]: { $ne: value } };
+      case FilterOperator.GT:
+        return { [key]: { $gt: value } };
+      case FilterOperator.LT:
+        return { [key]: { $lt: value } };
+      case FilterOperator.GTE:
+        return { [key]: { $gte: value } };
+      case FilterOperator.LTE:
+        return { [key]: { $lte: value } };
+      case FilterOperator.IN:
+        return { [key]: { $in: parseArrayValue(value) } };
+      case FilterOperator.NIN:
+        return { [key]: { $nin: parseArrayValue(value) } };
+      case FilterOperator.IS_EMPTY:
+        // NOTE(review): `$size: 0` presumably matches only zero-length arrays,
+        // not missing or null fields - confirm this is the intended meaning.
+        return { [key]: { $size: 0 } };
+      default:
+        throw new Error(`Not supported filter operator: ${operator}`);
+    }
+  }
 }
-- 
GitLab