From 63e9846e979b79046bf4bcc490c875e6496600a0 Mon Sep 17 00:00:00 2001 From: Thuc Pham <51660321+thucpn@users.noreply.github.com> Date: Fri, 4 Oct 2024 14:24:01 +0700 Subject: [PATCH] fix: preFilters does not work with asQueryEngine (#1298) --- .changeset/tame-squids-clean.md | 5 ++ examples/metadata-filter/preFilters.ts | 51 +++++++++++++++++++ examples/vectorIndexCustomize.ts | 9 ++-- examples/vectorIndexFromVectorStore.ts | 4 +- .../llamaindex/src/cloud/LlamaCloudIndex.ts | 1 - .../src/engines/query/RetrieverQueryEngine.ts | 3 -- .../llamaindex/src/indices/keyword/index.ts | 1 - .../llamaindex/src/indices/summary/index.ts | 1 - .../src/indices/vectorStore/index.ts | 5 +- 9 files changed, 62 insertions(+), 18 deletions(-) create mode 100644 .changeset/tame-squids-clean.md create mode 100644 examples/metadata-filter/preFilters.ts diff --git a/.changeset/tame-squids-clean.md b/.changeset/tame-squids-clean.md new file mode 100644 index 000000000..b3263d6a7 --- /dev/null +++ b/.changeset/tame-squids-clean.md @@ -0,0 +1,5 @@ +--- +"llamaindex": patch +--- + +fix: preFilters does not work with asQueryEngine diff --git a/examples/metadata-filter/preFilters.ts b/examples/metadata-filter/preFilters.ts new file mode 100644 index 000000000..9ccf4bc8c --- /dev/null +++ b/examples/metadata-filter/preFilters.ts @@ -0,0 +1,51 @@ +import { + Document, + MetadataFilters, + Settings, + SimpleDocumentStore, + VectorStoreIndex, + storageContextFromDefaults, +} from "llamaindex"; + +async function getDataSource() { + const docs = [ + new Document({ text: "The dog is brown", metadata: { dogId: "1" } }), + new Document({ text: "The dog is yellow", metadata: { dogId: "2" } }), + ]; + const storageContext = await storageContextFromDefaults({ + persistDir: "./cache", + }); + const numberOfDocs = Object.keys( + (storageContext.docStore as SimpleDocumentStore).toDict(), + ).length; + if (numberOfDocs === 0) { + return await VectorStoreIndex.fromDocuments(docs, { storageContext }); + } +
return await VectorStoreIndex.init({ + storageContext, + }); +} + +Settings.callbackManager.on("retrieve-end", (event) => { + const { nodes, query } = event.detail; + console.log(`${query.query} - Number of retrieved nodes:`, nodes.length); +}); + +async function main() { + const index = await getDataSource(); + const filters: MetadataFilters = { + filters: [{ key: "dogId", value: "2", operator: "==" }], + }; + + const retriever = index.asRetriever({ similarityTopK: 3, filters }); + const queryEngine = index.asQueryEngine({ + similarityTopK: 3, + preFilters: filters, + }); + + console.log("Retriever and query engine should only retrieve 1 node:"); + await retriever.retrieve({ query: "Retriever: get dog" }); + await queryEngine.query({ query: "QueryEngine: get dog" }); +} + +void main(); diff --git a/examples/vectorIndexCustomize.ts b/examples/vectorIndexCustomize.ts index cf29d97ef..9f04c7048 100644 --- a/examples/vectorIndexCustomize.ts +++ b/examples/vectorIndexCustomize.ts @@ -25,12 +25,9 @@ async function main() { similarityCutoff: 0.7, }); // TODO: cannot pass responseSynthesizer into retriever query engine - const queryEngine = new RetrieverQueryEngine( - retriever, - undefined, - undefined, - [nodePostprocessor], - ); + const queryEngine = new RetrieverQueryEngine(retriever, undefined, [ + nodePostprocessor, + ]); const response = await queryEngine.query({ query: "What did the author do growing up?", diff --git a/examples/vectorIndexFromVectorStore.ts b/examples/vectorIndexFromVectorStore.ts index 042dab855..02da91cfa 100644 --- a/examples/vectorIndexFromVectorStore.ts +++ b/examples/vectorIndexFromVectorStore.ts @@ -165,9 +165,7 @@ async function main() { }); const responseSynthesizer = getResponseSynthesizer("tree_summarize"); - return new RetrieverQueryEngine(retriever, responseSynthesizer, { - filter, - }); + return new RetrieverQueryEngine(retriever, responseSynthesizer); }; // whatever is a key from your metadata diff --git 
a/packages/llamaindex/src/cloud/LlamaCloudIndex.ts b/packages/llamaindex/src/cloud/LlamaCloudIndex.ts index e3509b296..b71cf07dc 100644 --- a/packages/llamaindex/src/cloud/LlamaCloudIndex.ts +++ b/packages/llamaindex/src/cloud/LlamaCloudIndex.ts @@ -308,7 +308,6 @@ export class LlamaCloudIndex { return new RetrieverQueryEngine( retriever, params?.responseSynthesizer, - params?.preFilters, params?.nodePostprocessors, ); } diff --git a/packages/llamaindex/src/engines/query/RetrieverQueryEngine.ts b/packages/llamaindex/src/engines/query/RetrieverQueryEngine.ts index ab1906e07..2a3d105f2 100644 --- a/packages/llamaindex/src/engines/query/RetrieverQueryEngine.ts +++ b/packages/llamaindex/src/engines/query/RetrieverQueryEngine.ts @@ -14,12 +14,10 @@ export class RetrieverQueryEngine extends BaseQueryEngine { retriever: BaseRetriever; responseSynthesizer: BaseSynthesizer; nodePostprocessors: BaseNodePostprocessor[]; - preFilters?: unknown; constructor( retriever: BaseRetriever, responseSynthesizer?: BaseSynthesizer, - preFilters?: unknown, nodePostprocessors?: BaseNodePostprocessor[], ) { super(async (strOrQueryBundle, stream) => { @@ -52,7 +50,6 @@ export class RetrieverQueryEngine extends BaseQueryEngine { this.retriever = retriever; this.responseSynthesizer = responseSynthesizer || getResponseSynthesizer("compact"); - this.preFilters = preFilters; this.nodePostprocessors = nodePostprocessors || []; } diff --git a/packages/llamaindex/src/indices/keyword/index.ts b/packages/llamaindex/src/indices/keyword/index.ts index 911850616..369b52ed5 100644 --- a/packages/llamaindex/src/indices/keyword/index.ts +++ b/packages/llamaindex/src/indices/keyword/index.ts @@ -246,7 +246,6 @@ export class KeywordTableIndex extends BaseIndex<KeywordTable> { return new RetrieverQueryEngine( retriever ?? 
this.asRetriever(), responseSynthesizer, - options?.preFilters, options?.nodePostprocessors, ); } diff --git a/packages/llamaindex/src/indices/summary/index.ts b/packages/llamaindex/src/indices/summary/index.ts index c449a1297..a36b31978 100644 --- a/packages/llamaindex/src/indices/summary/index.ts +++ b/packages/llamaindex/src/indices/summary/index.ts @@ -189,7 +189,6 @@ export class SummaryIndex extends BaseIndex<IndexList> { return new RetrieverQueryEngine( retriever, responseSynthesizer, - options?.preFilters, options?.nodePostprocessors, ); } diff --git a/packages/llamaindex/src/indices/vectorStore/index.ts b/packages/llamaindex/src/indices/vectorStore/index.ts index d0bb647ef..e6a232b96 100644 --- a/packages/llamaindex/src/indices/vectorStore/index.ts +++ b/packages/llamaindex/src/indices/vectorStore/index.ts @@ -298,9 +298,8 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> { similarityTopK, } = options ?? {}; return new RetrieverQueryEngine( - retriever ?? this.asRetriever({ similarityTopK }), + retriever ?? this.asRetriever({ similarityTopK, filters: preFilters }), responseSynthesizer, - preFilters, nodePostprocessors, ); } @@ -387,7 +386,7 @@ export type VectorIndexRetrieverOptions = { index: VectorStoreIndex; similarityTopK?: number | undefined; topK?: TopKMap | undefined; - filters?: MetadataFilters; + filters?: MetadataFilters | undefined; }; export class VectorIndexRetriever extends BaseRetriever { -- GitLab