From 63e9846e979b79046bf4bcc490c875e6496600a0 Mon Sep 17 00:00:00 2001
From: Thuc Pham <51660321+thucpn@users.noreply.github.com>
Date: Fri, 4 Oct 2024 14:24:01 +0700
Subject: [PATCH] fix: preFilters does not work with asQueryEngine (#1298)

---
 .changeset/tame-squids-clean.md               |  5 ++
 examples/metadata-filter/preFilters.ts        | 51 +++++++++++++++++++
 examples/vectorIndexCustomize.ts              |  9 ++--
 examples/vectorIndexFromVectorStore.ts        |  4 +-
 .../llamaindex/src/cloud/LlamaCloudIndex.ts   |  1 -
 .../src/engines/query/RetrieverQueryEngine.ts |  3 --
 .../llamaindex/src/indices/keyword/index.ts   |  1 -
 .../llamaindex/src/indices/summary/index.ts   |  1 -
 .../src/indices/vectorStore/index.ts          |  5 +-
 9 files changed, 62 insertions(+), 18 deletions(-)
 create mode 100644 .changeset/tame-squids-clean.md
 create mode 100644 examples/metadata-filter/preFilters.ts

diff --git a/.changeset/tame-squids-clean.md b/.changeset/tame-squids-clean.md
new file mode 100644
index 000000000..b3263d6a7
--- /dev/null
+++ b/.changeset/tame-squids-clean.md
@@ -0,0 +1,5 @@
+---
+"llamaindex": patch
+---
+
+fix: preFilters does not work with asQueryEngine
diff --git a/examples/metadata-filter/preFilters.ts b/examples/metadata-filter/preFilters.ts
new file mode 100644
index 000000000..9ccf4bc8c
--- /dev/null
+++ b/examples/metadata-filter/preFilters.ts
@@ -0,0 +1,51 @@
+import {
+  Document,
+  MetadataFilters,
+  Settings,
+  SimpleDocumentStore,
+  VectorStoreIndex,
+  storageContextFromDefaults,
+} from "llamaindex";
+
+async function getDataSource() {
+  const docs = [
+    new Document({ text: "The dog is brown", metadata: { dogId: "1" } }),
+    new Document({ text: "The dog is yellow", metadata: { dogId: "2" } }),
+  ];
+  const storageContext = await storageContextFromDefaults({
+    persistDir: "./cache",
+  });
+  const numberOfDocs = Object.keys(
+    (storageContext.docStore as SimpleDocumentStore).toDict(),
+  ).length;
+  if (numberOfDocs === 0) {
+    return await VectorStoreIndex.fromDocuments(docs, { storageContext });
+  }
+  return await VectorStoreIndex.init({
+    storageContext,
+  });
+}
+
+Settings.callbackManager.on("retrieve-end", (event) => {
+  const { nodes, query } = event.detail;
+  console.log(`${query.query} - Number of retrieved nodes:`, nodes.length);
+});
+
+async function main() {
+  const index = await getDataSource();
+  const filters: MetadataFilters = {
+    filters: [{ key: "dogId", value: "2", operator: "==" }],
+  };
+
+  const retriever = index.asRetriever({ similarityTopK: 3, filters });
+  const queryEngine = index.asQueryEngine({
+    similarityTopK: 3,
+    preFilters: filters,
+  });
+
+  console.log("Retriever and query engine should only retrieve 1 node:");
+  await retriever.retrieve({ query: "Retriever: get dog" });
+  await queryEngine.query({ query: "QueryEngine: get dog" });
+}
+
+void main();
diff --git a/examples/vectorIndexCustomize.ts b/examples/vectorIndexCustomize.ts
index cf29d97ef..9f04c7048 100644
--- a/examples/vectorIndexCustomize.ts
+++ b/examples/vectorIndexCustomize.ts
@@ -25,12 +25,9 @@ async function main() {
     similarityCutoff: 0.7,
   });
   // TODO: cannot pass responseSynthesizer into retriever query engine
-  const queryEngine = new RetrieverQueryEngine(
-    retriever,
-    undefined,
-    undefined,
-    [nodePostprocessor],
-  );
+  const queryEngine = new RetrieverQueryEngine(retriever, undefined, [
+    nodePostprocessor,
+  ]);
 
   const response = await queryEngine.query({
     query: "What did the author do growing up?",
diff --git a/examples/vectorIndexFromVectorStore.ts b/examples/vectorIndexFromVectorStore.ts
index 042dab855..02da91cfa 100644
--- a/examples/vectorIndexFromVectorStore.ts
+++ b/examples/vectorIndexFromVectorStore.ts
@@ -165,9 +165,7 @@ async function main() {
     });
 
     const responseSynthesizer = getResponseSynthesizer("tree_summarize");
-    return new RetrieverQueryEngine(retriever, responseSynthesizer, {
-      filter,
-    });
+    return new RetrieverQueryEngine(retriever, responseSynthesizer);
   };
 
   // whatever is a key from your metadata
diff --git a/packages/llamaindex/src/cloud/LlamaCloudIndex.ts b/packages/llamaindex/src/cloud/LlamaCloudIndex.ts
index e3509b296..b71cf07dc 100644
--- a/packages/llamaindex/src/cloud/LlamaCloudIndex.ts
+++ b/packages/llamaindex/src/cloud/LlamaCloudIndex.ts
@@ -308,7 +308,6 @@ export class LlamaCloudIndex {
     return new RetrieverQueryEngine(
       retriever,
       params?.responseSynthesizer,
-      params?.preFilters,
       params?.nodePostprocessors,
     );
   }
diff --git a/packages/llamaindex/src/engines/query/RetrieverQueryEngine.ts b/packages/llamaindex/src/engines/query/RetrieverQueryEngine.ts
index ab1906e07..2a3d105f2 100644
--- a/packages/llamaindex/src/engines/query/RetrieverQueryEngine.ts
+++ b/packages/llamaindex/src/engines/query/RetrieverQueryEngine.ts
@@ -14,12 +14,10 @@ export class RetrieverQueryEngine extends BaseQueryEngine {
   retriever: BaseRetriever;
   responseSynthesizer: BaseSynthesizer;
   nodePostprocessors: BaseNodePostprocessor[];
-  preFilters?: unknown;
 
   constructor(
     retriever: BaseRetriever,
     responseSynthesizer?: BaseSynthesizer,
-    preFilters?: unknown,
     nodePostprocessors?: BaseNodePostprocessor[],
   ) {
     super(async (strOrQueryBundle, stream) => {
@@ -52,7 +50,6 @@ export class RetrieverQueryEngine extends BaseQueryEngine {
     this.retriever = retriever;
     this.responseSynthesizer =
       responseSynthesizer || getResponseSynthesizer("compact");
-    this.preFilters = preFilters;
     this.nodePostprocessors = nodePostprocessors || [];
   }
 
diff --git a/packages/llamaindex/src/indices/keyword/index.ts b/packages/llamaindex/src/indices/keyword/index.ts
index 911850616..369b52ed5 100644
--- a/packages/llamaindex/src/indices/keyword/index.ts
+++ b/packages/llamaindex/src/indices/keyword/index.ts
@@ -246,7 +246,6 @@ export class KeywordTableIndex extends BaseIndex<KeywordTable> {
     return new RetrieverQueryEngine(
       retriever ?? this.asRetriever(),
       responseSynthesizer,
-      options?.preFilters,
       options?.nodePostprocessors,
     );
   }
diff --git a/packages/llamaindex/src/indices/summary/index.ts b/packages/llamaindex/src/indices/summary/index.ts
index c449a1297..a36b31978 100644
--- a/packages/llamaindex/src/indices/summary/index.ts
+++ b/packages/llamaindex/src/indices/summary/index.ts
@@ -189,7 +189,6 @@ export class SummaryIndex extends BaseIndex<IndexList> {
     return new RetrieverQueryEngine(
       retriever,
       responseSynthesizer,
-      options?.preFilters,
       options?.nodePostprocessors,
     );
   }
diff --git a/packages/llamaindex/src/indices/vectorStore/index.ts b/packages/llamaindex/src/indices/vectorStore/index.ts
index d0bb647ef..e6a232b96 100644
--- a/packages/llamaindex/src/indices/vectorStore/index.ts
+++ b/packages/llamaindex/src/indices/vectorStore/index.ts
@@ -298,9 +298,8 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
       similarityTopK,
     } = options ?? {};
     return new RetrieverQueryEngine(
-      retriever ?? this.asRetriever({ similarityTopK }),
+      retriever ?? this.asRetriever({ similarityTopK, filters: preFilters }),
       responseSynthesizer,
-      preFilters,
       nodePostprocessors,
     );
   }
@@ -387,7 +386,7 @@ export type VectorIndexRetrieverOptions = {
   index: VectorStoreIndex;
   similarityTopK?: number | undefined;
   topK?: TopKMap | undefined;
-  filters?: MetadataFilters;
+  filters?: MetadataFilters | undefined;
 };
 
 export class VectorIndexRetriever extends BaseRetriever {
-- 
GitLab