Type vector-index pre-filters as MetadataFilters and carry them on the query.

Summary of the hunks below:
  * examples/chromadb/preFilters.ts (new file): example that indexes two
    documents in a ChromaVectorStore and queries them with an ExactMatch
    filter on the `dogId` metadata key.
  * VectorIndexRetriever.ts: `preFilters` is typed as MetadataFilters and is
    placed on the VectorStoreQuery (`filters`) by buildVectorStoreQuery. Both
    the text and the image retrieval paths now pass an empty options object
    as the second argument of `query()`.
  * VectorStoreIndex.ts: `asQueryEngine` accepts `preFilters?: MetadataFilters`.
  * ChromaVectorStore.ts / MongoDBAtlasVectorStore.ts: tolerate a missing
    `filters` list instead of dereferencing it unconditionally.

Review notes:
  * The image-path `query(q, {})` line was amended during review so that it
    agrees with the text path; the post-image blob id on the
    VectorIndexRetriever.ts `index` line therefore no longer matches.
  * Line structure was restored from a whitespace-collapsed copy. Indentation
    of context lines assumes the repository's 2-space formatting; confirm
    against the target tree (or apply with whitespace tolerance).

diff --git a/examples/chromadb/preFilters.ts b/examples/chromadb/preFilters.ts
new file mode 100644
index 0000000000000000000000000000000000000000..df5da1c869c8131da0adbecbf52d0681fe764167
--- /dev/null
+++ b/examples/chromadb/preFilters.ts
@@ -0,0 +1,57 @@
+import {
+  ChromaVectorStore,
+  Document,
+  VectorStoreIndex,
+  storageContextFromDefaults,
+} from "llamaindex";
+
+const collectionName = "dog_colors";
+
+async function main() {
+  try {
+    const docs = [
+      new Document({
+        text: "The dog is brown",
+        metadata: {
+          dogId: "1",
+        },
+      }),
+      new Document({
+        text: "The dog is red",
+        metadata: {
+          dogId: "2",
+        },
+      }),
+    ];
+
+    console.log("Creating ChromaDB vector store");
+    const chromaVS = new ChromaVectorStore({ collectionName });
+    const ctx = await storageContextFromDefaults({ vectorStore: chromaVS });
+
+    console.log("Embedding documents and adding to index");
+    const index = await VectorStoreIndex.fromDocuments(docs, {
+      storageContext: ctx,
+    });
+
+    console.log("Querying index");
+    const queryEngine = index.asQueryEngine({
+      preFilters: {
+        filters: [
+          {
+            key: "dogId",
+            value: "2",
+            filterType: "ExactMatch",
+          },
+        ],
+      },
+    });
+    const response = await queryEngine.query({
+      query: "What is the color of the dog?",
+    });
+    console.log(response.toString());
+  } catch (e) {
+    console.error(e);
+  }
+}
+
+main();
diff --git a/packages/core/src/indices/vectorStore/VectorIndexRetriever.ts b/packages/core/src/indices/vectorStore/VectorIndexRetriever.ts
index 28674acf2ff2cfa4e904e4e214859470de798897..718437178e087cd721d7af3d293bf858ce271ba7 100644
--- a/packages/core/src/indices/vectorStore/VectorIndexRetriever.ts
+++ b/packages/core/src/indices/vectorStore/VectorIndexRetriever.ts
@@ -6,6 +6,7 @@ import { Event } from "../../callbacks/CallbackManager";
 import { DEFAULT_SIMILARITY_TOP_K } from "../../constants";
 import { BaseEmbedding } from "../../embeddings";
 import {
+  MetadataFilters,
   VectorStoreQuery,
   VectorStoreQueryMode,
   VectorStoreQueryResult,
@@ -40,7 +41,7 @@ export class VectorIndexRetriever implements BaseRetriever {
   async retrieve(
     query: string,
     parentEvent?: Event,
-    preFilters?: unknown,
+    preFilters?: MetadataFilters,
   ): Promise<NodeWithScore[]> {
     let nodesWithScores = await this.textRetrieve(query, preFilters);
     nodesWithScores = nodesWithScores.concat(
@@ -52,18 +53,23 @@ export class VectorIndexRetriever implements BaseRetriever {
 
   protected async textRetrieve(
     query: string,
-    preFilters?: unknown,
+    preFilters?: MetadataFilters,
   ): Promise<NodeWithScore[]> {
+    const options = {};
     const q = await this.buildVectorStoreQuery(
       this.index.embedModel,
       query,
       this.similarityTopK,
+      preFilters,
     );
-    const result = await this.index.vectorStore.query(q, preFilters);
+    const result = await this.index.vectorStore.query(q, options);
     return this.buildNodeListFromQueryResult(result);
   }
 
-  private async textToImageRetrieve(query: string, preFilters?: unknown) {
+  private async textToImageRetrieve(
+    query: string,
+    preFilters?: MetadataFilters,
+  ) {
     if (!this.index.imageEmbedModel || !this.index.imageVectorStore) {
       // no-op if image embedding and vector store are not set
       return [];
@@ -72,6 +78,7 @@ export class VectorIndexRetriever implements BaseRetriever {
       this.index.imageEmbedModel,
       query,
       this.imageSimilarityTopK,
+      preFilters,
     );
-    const result = await this.index.imageVectorStore.query(q, preFilters);
+    const result = await this.index.imageVectorStore.query(q, {});
     return this.buildNodeListFromQueryResult(result);
@@ -98,6 +105,7 @@ export class VectorIndexRetriever implements BaseRetriever {
     embedModel: BaseEmbedding,
     query: string,
     similarityTopK: number,
+    preFilters?: MetadataFilters,
   ): Promise<VectorStoreQuery> {
     const queryEmbedding = await embedModel.getQueryEmbedding(query);
 
@@ -105,6 +113,7 @@ export class VectorIndexRetriever implements BaseRetriever {
       queryEmbedding: queryEmbedding,
       mode: VectorStoreQueryMode.DEFAULT,
       similarityTopK: similarityTopK,
+      filters: preFilters ?? undefined,
     };
   }
 
diff --git a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts
index ea52780bb91e18b2330635ce396672ddb2b031f1..64c9a90e2cadef3df6d0879cc96aa35e09a5f7fa 100644
--- a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts
+++ b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts
@@ -20,6 +20,7 @@ import {
 import { BaseNodePostprocessor } from "../../postprocessors";
 import {
   BaseIndexStore,
+  MetadataFilters,
   StorageContext,
   VectorStore,
   storageContextFromDefaults,
@@ -263,7 +264,7 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
   asQueryEngine(options?: {
     retriever?: BaseRetriever;
     responseSynthesizer?: BaseSynthesizer;
-    preFilters?: unknown;
+    preFilters?: MetadataFilters;
     nodePostprocessors?: BaseNodePostprocessor[];
   }): BaseQueryEngine {
     const { retriever, responseSynthesizer } = options ?? {};
diff --git a/packages/core/src/storage/vectorStore/ChromaVectorStore.ts b/packages/core/src/storage/vectorStore/ChromaVectorStore.ts
index 8ad38e720338aaf0b5d8091a69de7ef6e476e2c4..cc3321babf698bacd5418b9fb9f153464d1ea6fe 100644
--- a/packages/core/src/storage/vectorStore/ChromaVectorStore.ts
+++ b/packages/core/src/storage/vectorStore/ChromaVectorStore.ts
@@ -107,7 +107,7 @@ export class ChromaVectorStore implements VectorStore {
     }
 
     const chromaWhere: { [x: string]: string | number | boolean } = {};
-    if (query.filters) {
+    if (query.filters?.filters) {
       query.filters.filters.map((filter) => {
         const filterKey = filter.key;
         const filterValue = filter.value;
diff --git a/packages/core/src/storage/vectorStore/MongoDBAtlasVectorStore.ts b/packages/core/src/storage/vectorStore/MongoDBAtlasVectorStore.ts
index 6f2ec478ba3da60eee3dfbd3c8042af77c9217f8..434e0afe6502c1049ddcb80d8a01c032f8a25222 100644
--- a/packages/core/src/storage/vectorStore/MongoDBAtlasVectorStore.ts
+++ b/packages/core/src/storage/vectorStore/MongoDBAtlasVectorStore.ts
@@ -13,7 +13,7 @@ function toMongoDBFilter(
   standardFilters: MetadataFilters,
 ): Record<string, any> {
   const filters: Record<string, any> = {};
-  for (const filter of standardFilters.filters) {
+  for (const filter of standardFilters?.filters ?? []) {
     filters[filter.key] = filter.value;
   }
   return filters;