From b974eea3413c1478bafb783d03acd950bb4a6adb Mon Sep 17 00:00:00 2001 From: Thuc Pham <51660321+thucpn@users.noreply.github.com> Date: Wed, 17 Jul 2024 16:21:21 +0700 Subject: [PATCH] feat: add MetadataFilter for SimpleVectorStore and Milvus (#1030) Co-authored-by: Marcus Schiesser <mail@marcusschiesser.de> --- .changeset/famous-poets-hammer.md | 7 + .../query_engines/metadata_filtering.md | 4 +- examples/chromadb/preFilters.ts | 2 +- examples/metadata-filter/milvus.ts | 40 +++ examples/metadata-filter/simple.ts | 143 ++++++++ examples/qdrantdb/preFilters.ts | 2 +- .../storage/vectorStore/MilvusVectorStore.ts | 64 +++- .../src/storage/vectorStore/PGVectorStore.ts | 5 +- .../vectorStore/PineconeVectorStore.ts | 10 +- .../storage/vectorStore/SimpleVectorStore.ts | 134 ++++++- .../src/storage/vectorStore/types.ts | 37 +- .../src/storage/vectorStore/utils.ts | 23 ++ .../tests/mocks/TestableMilvusVectorStore.ts | 24 ++ .../vectorStores/MilvusVectorStore.test.ts | 333 ++++++++++++++++++ .../vectorStores/SimpleVectorStore.test.ts | 299 ++++++++++++++++ 15 files changed, 1091 insertions(+), 36 deletions(-) create mode 100644 .changeset/famous-poets-hammer.md create mode 100644 examples/metadata-filter/milvus.ts create mode 100644 examples/metadata-filter/simple.ts create mode 100644 packages/llamaindex/tests/mocks/TestableMilvusVectorStore.ts create mode 100644 packages/llamaindex/tests/vectorStores/MilvusVectorStore.test.ts create mode 100644 packages/llamaindex/tests/vectorStores/SimpleVectorStore.test.ts diff --git a/.changeset/famous-poets-hammer.md b/.changeset/famous-poets-hammer.md new file mode 100644 index 000000000..9207213e8 --- /dev/null +++ b/.changeset/famous-poets-hammer.md @@ -0,0 +1,7 @@ +--- +"llamaindex": patch +"@llamaindex/llamaindex-test": patch +"@llamaindex/core": patch +--- + +Add support for Metadata filters diff --git a/apps/docs/docs/modules/query_engines/metadata_filtering.md b/apps/docs/docs/modules/query_engines/metadata_filtering.md index 04b9d7c6c..80d7c87b0 100644 --- a/apps/docs/docs/modules/query_engines/metadata_filtering.md +++ b/apps/docs/docs/modules/query_engines/metadata_filtering.md @@ -75,7 +75,7 @@ const queryEngine = index.asQueryEngine({ { key: "dogId", value: "2", - filterType: "ExactMatch", + operator: "==", }, ], }, @@ -135,7 +135,7 @@ async function main() { { key: "dogId", value: "2", - filterType: "ExactMatch", + operator: "==", }, ], }, diff --git a/examples/chromadb/preFilters.ts b/examples/chromadb/preFilters.ts index 5265e201d..3c1e0abb6 100644 --- a/examples/chromadb/preFilters.ts +++ b/examples/chromadb/preFilters.ts @@ -40,7 +40,7 @@ async function main() { { key: "dogId", value: "2", - filterType: "ExactMatch", + operator: "==", }, ], }, diff --git a/examples/metadata-filter/milvus.ts b/examples/metadata-filter/milvus.ts new file mode 100644 index 000000000..9415bca57 --- /dev/null +++ b/examples/metadata-filter/milvus.ts @@ -0,0 +1,40 @@ +import { MilvusVectorStore, VectorStoreIndex } from "llamaindex"; + +const collectionName = "movie_reviews"; + +async function main() { + try { + const milvus = new MilvusVectorStore({ collection: collectionName }); + const index = await VectorStoreIndex.fromVectorStore(milvus); + const retriever = index.asRetriever({ similarityTopK: 20 }); + + console.log("\n=====\nQuerying the index with filters"); + const queryEngineWithFilters = index.asQueryEngine({ + retriever, + preFilters: { + filters: [ + { + key: "document_id", + value: "./data/movie_reviews.csv_37", + operator: "==", + }, + { + key: "document_id", + value: "./data/movie_reviews.csv_37", + operator: "!=", + }, + ], + condition: "or", + }, + }); + const resultAfterFilter = await queryEngineWithFilters.query({ + query: "Get all movie titles.", + }); + console.log(`Query from ${resultAfterFilter.sourceNodes?.length} nodes`); + console.log(resultAfterFilter.response); + } catch (e) { + console.error(e); + } +} + +void main(); diff --git a/examples/metadata-filter/simple.ts b/examples/metadata-filter/simple.ts new file mode 100644 index 000000000..245e48c3d --- /dev/null +++ b/examples/metadata-filter/simple.ts @@ -0,0 +1,143 @@ +import { + Document, + Settings, + SimpleDocumentStore, + VectorStoreIndex, + storageContextFromDefaults, +} from "llamaindex"; + +Settings.callbackManager.on("retrieve-end", (event) => { + const { nodes } = event.detail; + console.log("Number of retrieved nodes:", nodes.length); +}); + +async function getDataSource() { + const docs = [ + new Document({ + text: "The dog is brown", + metadata: { + dogId: "1", + private: true, + }, + }), + new Document({ + text: "The dog is yellow", + metadata: { + dogId: "2", + private: false, + }, + }), + new Document({ + text: "The dog is red", + metadata: { + dogId: "3", + private: false, + }, + }), + ]; + const storageContext = await storageContextFromDefaults({ + persistDir: "./cache", + }); + const numberOfDocs = Object.keys( + (storageContext.docStore as SimpleDocumentStore).toDict(), + ).length; + if (numberOfDocs === 0) { + // Generate the data source if it's empty + return await VectorStoreIndex.fromDocuments(docs, { + storageContext, + }); + } + return await VectorStoreIndex.init({ + storageContext, + }); +} + +async function main() { + const index = await getDataSource(); + console.log( + "=============\nQuerying index with no filters. The output should be any color.", + ); + const queryEngineNoFilters = index.asQueryEngine({ + similarityTopK: 3, + }); + const noFilterResponse = await queryEngineNoFilters.query({ + query: "What is the color of the dog?", + }); + console.log("No filter response:", noFilterResponse.toString()); + + console.log( + "\n=============\nQuerying index with dogId 2 and private false. The output always should be red.", + ); + const queryEngineEQ = index.asQueryEngine({ + preFilters: { + filters: [ + { + key: "private", + value: "false", + operator: "==", + }, + { + key: "dogId", + value: "3", + operator: "==", + }, + ], + }, + similarityTopK: 3, + }); + const responseEQ = await queryEngineEQ.query({ + query: "What is the color of the dog?", + }); + console.log("Filter with dogId 2 response:", responseEQ.toString()); + + console.log( + "\n=============\nQuerying index with dogId IN (1, 3). The output should be brown and red.", + ); + const queryEngineIN = index.asQueryEngine({ + preFilters: { + filters: [ + { + key: "dogId", + value: ["1", "3"], + operator: "in", + }, + ], + }, + similarityTopK: 3, + }); + const responseIN = await queryEngineIN.query({ + query: "What is the color of the dog?", + }); + console.log("Filter with dogId IN (1, 3) response:", responseIN.toString()); + + console.log( + "\n=============\nQuerying index with dogId IN (1, 3). The output should be any.", + ); + const queryEngineOR = index.asQueryEngine({ + preFilters: { + filters: [ + { + key: "private", + value: "false", + operator: "==", + }, + { + key: "dogId", + value: ["1", "3"], + operator: "in", + }, + ], + condition: "or", + }, + similarityTopK: 3, + }); + const responseOR = await queryEngineOR.query({ + query: "What is the color of the dog?", + }); + console.log( + "Filter with dogId with OR operator response:", + responseOR.toString(), + ); +} + +void main(); diff --git a/examples/qdrantdb/preFilters.ts b/examples/qdrantdb/preFilters.ts index d60751d9f..2047ea517 100644 --- a/examples/qdrantdb/preFilters.ts +++ b/examples/qdrantdb/preFilters.ts @@ -64,7 +64,7 @@ async function main() { { key: "dogId", value: "2", - filterType: "ExactMatch", + operator: "==", }, ], }, diff --git a/packages/llamaindex/src/storage/vectorStore/MilvusVectorStore.ts b/packages/llamaindex/src/storage/vectorStore/MilvusVectorStore.ts index 42b7df568..15c3fe6e9 100644 --- a/packages/llamaindex/src/storage/vectorStore/MilvusVectorStore.ts +++ b/packages/llamaindex/src/storage/vectorStore/MilvusVectorStore.ts @@ -11,11 +11,66 @@ import { import { VectorStoreBase, type IEmbedModel, + type MetadataFilters, type VectorStoreNoEmbedModel, type VectorStoreQuery, type VectorStoreQueryResult, } from "./types.js"; -import { metadataDictToNode, nodeToMetadata } from "./utils.js"; +import { + metadataDictToNode, + nodeToMetadata, + parseArrayValue, + parseNumberValue, + parsePrimitiveValue, +} from "./utils.js"; + +function parseScalarFilters(scalarFilters: MetadataFilters): string { + const condition = scalarFilters.condition ?? "and"; + const filters: string[] = []; + + for (const filter of scalarFilters.filters) { + switch (filter.operator) { + case "==": + case "!=": { + filters.push( + `metadata["${filter.key}"] ${filter.operator} "${parsePrimitiveValue(filter.value)}"`, + ); + break; + } + case "in": { + const filterValue = parseArrayValue(filter.value) + .map((v) => `"${v}"`) + .join(", "); + filters.push( + `metadata["${filter.key}"] ${filter.operator} [${filterValue}]`, + ); + break; + } + case "nin": { + // Milvus does not support `nin` operator, so we need to manually check every value + // Expected: not metadata["key"] != "value1" and not metadata["key"] != "value2" + const filterStr = parseArrayValue(filter.value) + .map((v) => `metadata["${filter.key}"] != "${v}"`) + .join(" && "); + filters.push(filterStr); + break; + } + case "<": + case "<=": + case ">": + case ">=": { + filters.push( + `metadata["${filter.key}"] ${filter.operator} ${parseNumberValue(filter.value)}`, + ); + break; + } + default: + throw new Error(`Operator ${filter.operator} is not supported.`); + } + } + + return filters.join(` ${condition} `); +} export class MilvusVectorStore extends VectorStoreBase @@ -183,6 +238,12 @@ export class MilvusVectorStore }); } + public toMilvusFilter(filters?: MetadataFilters): string | undefined { + if (!filters) return undefined; + // TODO: Milvus also support standard filters, we can add it later + return parseScalarFilters(filters); + } + public async query( query: VectorStoreQuery, _options?: any, @@ -193,6 +254,7 @@ export class MilvusVectorStore collection_name: this.collectionName, limit: query.similarityTopK, vector: query.queryEmbedding, + filter: this.toMilvusFilter(query.filters), }); const nodes: BaseNode<Metadata>[] = []; diff --git a/packages/llamaindex/src/storage/vectorStore/PGVectorStore.ts b/packages/llamaindex/src/storage/vectorStore/PGVectorStore.ts index e54c64aa2..468f5966d 100644 --- a/packages/llamaindex/src/storage/vectorStore/PGVectorStore.ts +++ b/packages/llamaindex/src/storage/vectorStore/PGVectorStore.ts @@ -272,7 +272,10 @@ export class PGVectorStore query.filters?.filters.forEach((filter, index) => { const paramIndex = params.length + 1; whereClauses.push(`metadata->>'${filter.key}' = $${paramIndex}`); - params.push(filter.value); + // TODO: support filter with other operators + if (!Array.isArray(filter.value)) { + params.push(filter.value); + } }); const where = diff --git a/packages/llamaindex/src/storage/vectorStore/PineconeVectorStore.ts b/packages/llamaindex/src/storage/vectorStore/PineconeVectorStore.ts index 81d1efa6c..50a2a6241 100644 --- a/packages/llamaindex/src/storage/vectorStore/PineconeVectorStore.ts +++ b/packages/llamaindex/src/storage/vectorStore/PineconeVectorStore.ts @@ -1,7 +1,7 @@ import { VectorStoreBase, - type ExactMatchFilter, type IEmbedModel, + type MetadataFilter, type MetadataFilters, type VectorStoreNoEmbedModel, type VectorStoreQuery, @@ -199,8 +199,12 @@ export class PineconeVectorStore } toPineconeFilter(stdFilters?: MetadataFilters) { - return stdFilters?.filters?.reduce((carry: any, item: ExactMatchFilter) => { - carry[item.key] = item.value; + return stdFilters?.filters?.reduce((carry: any, item: MetadataFilter) => { + // Use MetadataFilter with EQ operator to replace ExactMatchFilter + // TODO: support filter with other operators + if (item.operator === "==") { + carry[item.key] = item.value; + } return carry; }, {}); } diff --git a/packages/llamaindex/src/storage/vectorStore/SimpleVectorStore.ts b/packages/llamaindex/src/storage/vectorStore/SimpleVectorStore.ts index 4f927075d..6d0847f3e 100644 --- a/packages/llamaindex/src/storage/vectorStore/SimpleVectorStore.ts +++ b/packages/llamaindex/src/storage/vectorStore/SimpleVectorStore.ts @@ -8,13 +8,22 @@ import { import { exists } from "../FileSystem.js"; import { DEFAULT_PERSIST_DIR } from "../constants.js"; import { + FilterOperator, VectorStoreBase, VectorStoreQueryMode, type IEmbedModel, + type MetadataFilter, + type MetadataFilters, type VectorStoreNoEmbedModel, type VectorStoreQuery, type VectorStoreQueryResult, } from "./types.js"; +import { + nodeToMetadata, + parseArrayValue, + parseNumberValue, + parsePrimitiveValue, +} from "./utils.js"; const LEARNER_MODES = new Set<VectorStoreQueryMode>([ VectorStoreQueryMode.SVM, @@ -24,9 +33,85 @@ const LEARNER_MODES = new Set<VectorStoreQueryMode>([ const MMR_MODE = VectorStoreQueryMode.MMR; +type MetadataValue = Record<string, any>; + +// Mapping of filter operators to metadata filter functions +const OPERATOR_TO_FILTER: { + [key in FilterOperator]: ( + { key, value }: MetadataFilter, + metadata: MetadataValue, + ) => boolean; +} = { + [FilterOperator.EQ]: ({ key, value }, metadata) => { + return parsePrimitiveValue(metadata[key]) === parsePrimitiveValue(value); + }, + [FilterOperator.NE]: ({ key, value }, metadata) => { + return parsePrimitiveValue(metadata[key]) !== parsePrimitiveValue(value); + }, + [FilterOperator.IN]: ({ key, value }, metadata) => { + return parseArrayValue(value).includes(parsePrimitiveValue(metadata[key])); + }, + [FilterOperator.NIN]: ({ key, value }, metadata) => { + return !parseArrayValue(value).includes(parsePrimitiveValue(metadata[key])); + }, + [FilterOperator.ANY]: ({ key, value }, metadata) => { + return parseArrayValue(value).some((v) => + parseArrayValue(metadata[key]).includes(v), + ); + }, + [FilterOperator.ALL]: ({ key, value }, metadata) => { + return parseArrayValue(value).every((v) => + parseArrayValue(metadata[key]).includes(v), + ); + }, + [FilterOperator.TEXT_MATCH]: ({ key, value }, metadata) => { + return parsePrimitiveValue(metadata[key]).includes( + parsePrimitiveValue(value), + ); + }, + [FilterOperator.CONTAINS]: ({ key, value }, metadata) => { + return parseArrayValue(metadata[key]).includes(parsePrimitiveValue(value)); + }, + [FilterOperator.GT]: ({ key, value }, metadata) => { + return parseNumberValue(metadata[key]) > parseNumberValue(value); + }, + [FilterOperator.LT]: ({ key, value }, metadata) => { + return parseNumberValue(metadata[key]) < parseNumberValue(value); + }, + [FilterOperator.GTE]: ({ key, value }, metadata) => { + return parseNumberValue(metadata[key]) >= parseNumberValue(value); + }, + [FilterOperator.LTE]: ({ key, value }, metadata) => { + return parseNumberValue(metadata[key]) <= parseNumberValue(value); + }, +}; + +// Build a filter function based on the metadata and the preFilters +const buildFilterFn = ( + metadata: MetadataValue | undefined, + preFilters: MetadataFilters | undefined, +) => { + if (!preFilters) return true; + if (!metadata) return false; + + const { filters, condition } = preFilters; + const queryCondition = condition || "and"; // default to and + + const itemFilterFn = (filter: MetadataFilter) => { + const metadataLookupFn = OPERATOR_TO_FILTER[filter.operator]; + if (!metadataLookupFn) + throw new Error(`Unsupported operator: ${filter.operator}`); + return metadataLookupFn(filter, metadata); + }; + + if (queryCondition === "and") return filters.every(itemFilterFn); + return filters.some(itemFilterFn); +}; + class SimpleVectorStoreData { embeddingDict: Record<string, number[]> = {}; textIdToRefDocId: Record<string, string> = {}; + metadataDict: Record<string, MetadataValue> = {}; } export class SimpleVectorStore @@ -67,6 +152,11 @@ export class SimpleVectorStore } this.data.textIdToRefDocId[node.id_] = node.sourceNode?.nodeId; + + // Add metadata to the metadataDict + const metadata = nodeToMetadata(node, true, undefined, false); + delete metadata["_node_content"]; + this.data.metadataDict[node.id_] = metadata; } if (this.persistPath) { @@ -83,6 +173,7 @@ export class SimpleVectorStore for (const textId of textIdsToDelete) { delete this.data.embeddingDict[textId]; delete this.data.textIdToRefDocId[textId]; + if (this.data.metadataDict) delete this.data.metadataDict[textId]; } if (this.persistPath) { await this.persist(this.persistPath); @@ -90,27 +181,33 @@ export class SimpleVectorStore return Promise.resolve(); } - async query(query: VectorStoreQuery): Promise<VectorStoreQueryResult> { - if (!(query.filters == null)) { - throw new Error( - "Metadata filters not implemented for SimpleVectorStore yet.", - ); - } - + private async filterNodes(query: VectorStoreQuery): Promise<{ + nodeIds: string[]; + embeddings: number[][]; + }> { const items = Object.entries(this.data.embeddingDict); + const queryFilterFn = (nodeId: string) => { + const metadata = this.data.metadataDict[nodeId]; + return buildFilterFn(metadata, query.filters); + }; - let nodeIds: string[], embeddings: number[][]; - if (query.docIds) { + const nodeFilterFn = (nodeId: string) => { + if (!query.docIds) return true; const availableIds = new Set(query.docIds); - const queriedItems = items.filter((item) => availableIds.has(item[0])); - nodeIds = queriedItems.map((item) => item[0]); - embeddings = queriedItems.map((item) => item[1]); - } else { - // No docIds specified, so use all available items - nodeIds = items.map((item) => item[0]); - embeddings = items.map((item) => item[1]); - } + return availableIds.has(nodeId); + }; + + const queriedItems = items.filter( + (item) => nodeFilterFn(item[0]) && queryFilterFn(item[0]), + ); + const nodeIds = queriedItems.map((item) => item[0]); + const embeddings = queriedItems.map((item) => item[1]); + return { nodeIds, embeddings }; + } + + async query(query: VectorStoreQuery): Promise<VectorStoreQueryResult> { + const { nodeIds, embeddings } = await this.filterNodes(query); const queryEmbedding = query.queryEmbedding!; let topSimilarities: number[], topIds: string[]; @@ -191,6 +288,7 @@ export class SimpleVectorStore const data = new SimpleVectorStoreData(); data.embeddingDict = dataDict.embeddingDict ?? {}; data.textIdToRefDocId = dataDict.textIdToRefDocId ?? {}; + data.metadataDict = dataDict.metadataDict ?? {}; const store = new SimpleVectorStore({ data, embedModel }); store.persistPath = persistPath; return store; @@ -203,6 +301,7 @@ export class SimpleVectorStore const data = new SimpleVectorStoreData(); data.embeddingDict = saveDict.embeddingDict; data.textIdToRefDocId = saveDict.textIdToRefDocId; + data.metadataDict = saveDict.metadataDict; return new SimpleVectorStore({ data, embedModel }); } @@ -210,6 +309,7 @@ export class SimpleVectorStore return { embeddingDict: this.data.embeddingDict, textIdToRefDocId: this.data.textIdToRefDocId, + metadataDict: this.data.metadataDict, }; } } diff --git a/packages/llamaindex/src/storage/vectorStore/types.ts b/packages/llamaindex/src/storage/vectorStore/types.ts index 1862631f9..c13378687 100644 --- a/packages/llamaindex/src/storage/vectorStore/types.ts +++ b/packages/llamaindex/src/storage/vectorStore/types.ts @@ -20,20 +20,37 @@ export enum VectorStoreQueryMode { MMR = "mmr", } -export interface ExactMatchFilter { - filterType: "ExactMatch"; - key: string; - value: string | number; +export enum FilterOperator { + EQ = "==", // default operator (string, number) + IN = "in", // In array (string or number) + GT = ">", // greater than (number) + LT = "<", // less than (number) + NE = "!=", // not equal to (string, number) + GTE = ">=", // greater than or equal to (number) + LTE = "<=", // less than or equal to (number) + NIN = "nin", // Not in array (string or number) + ANY = "any", // Contains any (array of strings) + ALL = "all", // Contains all (array of strings) + TEXT_MATCH = "text_match", // full text match (allows you to search for a specific substring, token or phrase within the text field) + CONTAINS = "contains", // metadata array contains value (string or number) } -export interface MetadataFilters { - filters: ExactMatchFilter[]; +export enum FilterCondition { + AND = "and", + OR = "or", } -export interface VectorStoreQuerySpec { - query: string; - filters: ExactMatchFilter[]; - topK?: number; +export type MetadataFilterValue = string | number | string[] | number[]; + +export interface MetadataFilter { + key: string; + value: MetadataFilterValue; + operator: `${FilterOperator}`; // ==, any, all,... +} + +export interface MetadataFilters { + filters: Array<MetadataFilter>; + condition?: `${FilterCondition}`; // and, or } export interface MetadataInfo { diff --git a/packages/llamaindex/src/storage/vectorStore/utils.ts b/packages/llamaindex/src/storage/vectorStore/utils.ts index 1a00dee27..8bdc39474 100644 --- a/packages/llamaindex/src/storage/vectorStore/utils.ts +++ b/packages/llamaindex/src/storage/vectorStore/utils.ts @@ -1,5 +1,6 @@ import type { BaseNode, Metadata } from "@llamaindex/core/schema"; import { ObjectType, jsonToNode } from "@llamaindex/core/schema"; +import type { MetadataFilterValue } from "./types.js"; const DEFAULT_TEXT_KEY = "text"; @@ -77,3 +78,25 @@ export function metadataDictToNode( return jsonToNode(nodeObj, ObjectType.TEXT); } } + +export const parseNumberValue = (value: MetadataFilterValue): number => { + if (typeof value !== "number") throw new Error("Value must be a number"); + return value; +}; + +export const parsePrimitiveValue = (value: MetadataFilterValue): string => { + if (typeof value !== "number" && typeof value !== "string") { + throw new Error("Value must be a string or number"); + } + return value.toString(); +}; + +export const parseArrayValue = (value: MetadataFilterValue): string[] => { + const isPrimitiveArray = + Array.isArray(value) && + value.every((v) => typeof v === "string" || typeof v === "number"); + if (!isPrimitiveArray) { + throw new Error("Value must be an array of strings or numbers"); + } + return value.map(String); +}; diff --git a/packages/llamaindex/tests/mocks/TestableMilvusVectorStore.ts b/packages/llamaindex/tests/mocks/TestableMilvusVectorStore.ts new file mode 100644 index 000000000..f5665c7af --- /dev/null +++ b/packages/llamaindex/tests/mocks/TestableMilvusVectorStore.ts @@ -0,0 +1,24 @@ +import type { BaseNode } from "@llamaindex/core/schema"; +import type { MilvusClient } from "@zilliz/milvus2-sdk-node"; +import { MilvusVectorStore } from "llamaindex"; +import { type Mocked } from "vitest"; + +export class TestableMilvusVectorStore extends MilvusVectorStore { + public nodes: BaseNode[] = []; + + private fakeTimeout = (ms: number) => { + return new Promise((resolve) => setTimeout(resolve, ms)); + }; + + public async add(nodes: BaseNode[]): Promise<string[]> { + this.nodes.push(...nodes); + await this.fakeTimeout(100); + return nodes.map((node) => node.id_); + } + + constructor() { + super({ + milvusClient: {} as Mocked<MilvusClient>, + }); + } +} diff --git a/packages/llamaindex/tests/vectorStores/MilvusVectorStore.test.ts b/packages/llamaindex/tests/vectorStores/MilvusVectorStore.test.ts new file mode 100644 index 000000000..7c2c5e50a --- /dev/null +++ b/packages/llamaindex/tests/vectorStores/MilvusVectorStore.test.ts @@ -0,0 +1,333 @@ +import type { BaseNode } from "@llamaindex/core/schema"; +import { TextNode } from "@llamaindex/core/schema"; +import { + MilvusVectorStore, + VectorStoreQueryMode, + type MetadataFilters, +} from "llamaindex"; +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { TestableMilvusVectorStore } from "../mocks/TestableMilvusVectorStore.js"; + +type FilterTestCase = { + title: string; + filters?: MetadataFilters; + expected: number; + expectedFilterStr: string | undefined; + mockResultIds: string[]; +}; + +describe("MilvusVectorStore", () => { + let store: MilvusVectorStore; + let nodes: BaseNode[]; + + beforeEach(() => { + store = new TestableMilvusVectorStore(); + nodes = [ + new TextNode({ + id_: "1", + embedding: [0.1, 0.2], + text: "The dog is brown", + metadata: { + name: "Anakin", + dogId: "1", + private: "true", + weight: 1.2, + type: ["husky", "puppy"], + }, + }), + new TextNode({ + id_: "2", + embedding: [0.1, 0.2], + text: "The dog is yellow", + metadata: { + name: "Luke", + dogId: "2", + private: "false", + weight: 2.3, + type: ["puppy"], + }, + }), + new TextNode({ + id_: "3", + embedding: [0.1, 0.2], + text: "The dog is red", + metadata: { + name: "Leia", + dogId: "3", + private: "false", + weight: 3.4, + type: ["husky"], + }, + }), + ]; + }); + + describe("[MilvusVectorStore] manage nodes", () => { + it("able to add nodes to store", async () => { + const ids = await store.add(nodes); + expect(ids).length(3); + }); + }); + + describe("[MilvusVectorStore] filter nodes with supported operators", () => { + const testcases: FilterTestCase[] = [ + { + title: "No filter", + expected: 3, + mockResultIds: ["1", "2", "3"], + expectedFilterStr: undefined, + }, + { + title: "Filter EQ", + filters: { + filters: [ + { + key: "private", + value: "false", + operator: "==", + }, + ], + }, + expected: 2, + mockResultIds: ["2", "3"], + expectedFilterStr: 'metadata["private"] == "false"', + }, + { + title: "Filter NE", + filters: { + filters: [ + { + key: "private", + value: "false", + operator: "!=", + }, + ], + }, + expected: 1, + mockResultIds: ["1"], + expectedFilterStr: 'metadata["private"] != "false"', + }, + { + title: "Filter GT", + filters: { + filters: [ + { + key: "weight", + value: 2.3, + operator: ">", + }, + ], + }, + expected: 1, + mockResultIds: ["3"], + expectedFilterStr: 'metadata["weight"] > 2.3', + }, + { + title: "Filter GTE", + filters: { + filters: [ + { + key: "weight", + value: 2.3, + operator: ">=", + }, + ], + }, + expected: 2, + mockResultIds: ["2", "3"], + expectedFilterStr: 'metadata["weight"] >= 2.3', + }, + { + title: "Filter LT", + filters: { + filters: [ + { + key: "weight", + value: 2.3, + operator: "<", + }, + ], + }, + expected: 1, + mockResultIds: ["1"], + expectedFilterStr: 'metadata["weight"] < 2.3', + }, + { + title: "Filter LTE", + filters: { + filters: [ + { + key: "weight", + value: 2.3, + operator: "<=", + }, + ], + }, + expected: 2, + mockResultIds: ["1", "2"], + expectedFilterStr: 'metadata["weight"] <= 2.3', + }, + { + title: "Filter IN", + filters: { + filters: [ + { + key: "dogId", + value: ["1", "3"], + operator: "in", + }, + ], + }, + expected: 2, + mockResultIds: ["1", "3"], + expectedFilterStr: 'metadata["dogId"] in ["1", "3"]', + }, + { + title: "Filter NIN", + filters: { + filters: [ + { + key: "name", + value: ["Anakin", "Leia"], + operator: "nin", + }, + ], + }, + expected: 1, + mockResultIds: ["2"], + expectedFilterStr: + 'metadata["name"] != "Anakin" && metadata["name"] != "Leia"', + }, + { + title: "Filter OR", + filters: { + filters: [ + { + key: "private", + value: "false", + operator: "==", + }, + { + key: "dogId", + value: ["1", "3"], + operator: "in", + }, + ], + condition: "or", + }, + expected: 3, + mockResultIds: ["1", "2", "3"], + expectedFilterStr: + 'metadata["private"] == "false" or metadata["dogId"] in ["1", "3"]', + }, + { + title: "Filter AND", + filters: { + filters: [ + { + key: "private", + value: "false", + operator: "==", + }, + { + key: "dogId", + value: "10", + operator: "==", + }, + ], + condition: "and", + }, + expected: 0, + mockResultIds: [], + expectedFilterStr: + 'metadata["private"] == "false" and metadata["dogId"] == "10"', + }, + ]; + + testcases.forEach((tc) => { + it(`[${tc.title}] should return ${tc.expected} nodes`, async () => { + expect(store.toMilvusFilter(tc.filters)).toBe(tc.expectedFilterStr); + + vi.spyOn(store, "query").mockResolvedValue({ + ids: tc.mockResultIds, + similarities: [0.1, 0.2, 0.3], + }); + + await store.add(nodes); + const result = await store.query({ + queryEmbedding: [0.1, 0.2], + similarityTopK: 3, + mode: VectorStoreQueryMode.DEFAULT, + filters: tc.filters, + }); + expect(result.ids).length(tc.expected); + }); + }); + }); + + describe("[MilvusVectorStore] filter nodes with unsupported operators", () => { + const testcases: Array< + Omit<FilterTestCase, "expectedFilterStr" | "mockResultIds"> + > = [ + { + title: "Filter ANY", + filters: { + filters: [ + { + key: "type", + value: ["husky", "puppy"], + operator: "any", + }, + ], + }, + expected: 3, + }, + { + title: "Filter ALL", + filters: { + filters: [ + { + key: "type", + value: ["husky", "puppy"], + operator: "all", + }, + ], + }, + expected: 1, + }, + { + title: "Filter CONTAINS", + filters: { + filters: [ + { + key: "type", + value: "puppy", + operator: "contains", + }, + ], + }, + expected: 2, + }, + { + title: "Filter TEXT_MATCH", + filters: { + filters: [ + { + key: "name", + value: "Luk", + operator: "text_match", + }, + ], + }, + expected: 1, + }, + ]; + + testcases.forEach((tc) => { + it(`[Unsupported Operator] [${tc.title}] should throw error`, async () => { + const errorMsg = `Operator ${tc.filters?.filters[0].operator} is not supported.`; + expect(() => store.toMilvusFilter(tc.filters)).toThrow(errorMsg); + }); + }); + }); +}); diff --git a/packages/llamaindex/tests/vectorStores/SimpleVectorStore.test.ts b/packages/llamaindex/tests/vectorStores/SimpleVectorStore.test.ts new file mode 100644 index 000000000..6e4fece1a --- /dev/null +++ b/packages/llamaindex/tests/vectorStores/SimpleVectorStore.test.ts @@ -0,0 +1,299 @@ +import { + BaseEmbedding, + BaseNode, + SimpleVectorStore, + TextNode, + VectorStoreQueryMode, + type Metadata, + type MetadataFilters, +} from "llamaindex"; +import { beforeEach, describe, expect, it } from "vitest"; + +type FilterTestCase = { + title: string; + filters?: MetadataFilters; + expected: number; +}; + +describe("SimpleVectorStore", () => { + let nodes: BaseNode[]; + let store: SimpleVectorStore; + + beforeEach(() => { + nodes = [ + new TextNode({ + id_: "1", + embedding: [0.1, 0.2], + text: "The dog is brown", + metadata: { + name: "Anakin", + dogId: "1", + private: "true", + weight: 1.2, + type: ["husky", "puppy"], + }, + }), + new TextNode({ + id_: "2", + embedding: [0.1, 0.2], + text: "The dog is yellow", + metadata: { + name: "Luke", + dogId: "2", + private: "false", + weight: 2.3, + type: ["puppy"], + }, + }), + new TextNode({ + id_: "3", + embedding: [0.1, 0.2], + text: "The dog is red", + metadata: { + name: "Leia", + dogId: "3", + private: "false", + weight: 3.4, + type: ["husky"], + }, + }), + ]; + store = new SimpleVectorStore({ + embedModel: {} as BaseEmbedding, // Mocking the embedModel + data: { + embeddingDict: {}, + textIdToRefDocId: {}, + metadataDict: nodes.reduce( + (acc, node) => { + acc[node.id_] = node.metadata; + return acc; + }, + {} as Record<string, Metadata>, + ), + }, + }); + }); + + describe("[SimpleVectorStore] manage nodes", () => { + it("able to add nodes to store", async () => { + const ids = await store.add(nodes); + expect(ids).length(3); + }); + }); + + describe("[SimpleVectorStore] query nodes", () => { + const testcases: FilterTestCase[] = [ + { + title: "No filter", + expected: 3, + }, + { + title: "Filter EQ", + filters: { + filters: [ + { + key: "private", + value: "false", + operator: "==", + }, + ], + }, + expected: 2, + }, + { + title: "Filter NE", + filters: { + filters: [ + { + key: "private", + value: "false", + operator: "!=", + }, + ], + }, + expected: 1, + }, + { + title: "Filter GT", + filters: { + filters: [ + { + key: "weight", + value: 2.3, + operator: ">", + }, + ], + }, + expected: 1, + }, + { + title: "Filter GTE", + filters: { + filters: [ + { + key: "weight", + value: 2.3, + operator: ">=", + }, + ], + }, + expected: 2, + }, + { + title: "Filter LT", + filters: { + filters: [ + { + key: "weight", + value: 2.3, + operator: "<", + }, + ], + }, + expected: 1, + }, + { + title: "Filter LTE", + filters: { + filters: [ + { + key: "weight", + value: 2.3, + operator: "<=", + }, + ], + }, + expected: 2, + }, + { + title: "Filter IN", + filters: { + filters: [ + { + key: "dogId", + value: ["1", "3"], + operator: "in", + }, + ], + }, + expected: 2, + }, + { + title: "Filter NIN", + filters: { + filters: [ + { + key: "name", + value: ["Anakin", "Leia"], + operator: "nin", + }, + ], + }, + expected: 1, + }, + { + title: "Filter ANY", + filters: { + filters: [ + { + key: "type", + value: ["husky", "puppy"], + operator: "any", + }, + ], + }, + expected: 3, + }, + { + title: "Filter ALL", + filters: { + filters: [ + { + key: "type", + value: ["husky", "puppy"], + operator: "all", + }, + ], + }, + expected: 1, + }, + { + title: "Filter CONTAINS", + filters: { + filters: [ + { + key: "type", + value: "puppy", + operator: "contains", + }, + ], + }, + expected: 2, + }, + { + title: "Filter TEXT_MATCH", + filters: { + filters: [ + { + key: "name", + value: "Luk", + operator: "text_match", + }, + ], + }, + expected: 1, + }, + { + title: "Filter OR", + filters: { + filters: [ + { + key: "private", + value: "false", + operator: "==", + }, + { + key: "dogId", + value: ["1", "3"], + operator: "in", + }, + ], + condition: "or", + }, + expected: 3, + }, + { + title: "Filter AND", + filters: { + filters: [ + { + key: "private", + value: "false", + operator: "==", + }, + { + key: "dogId", + value: "10", + operator: "==", + }, + ], + condition: "and", + }, + expected: 0, + }, + ]; + + testcases.forEach((tc) => { + it(`[${tc.title}] should return ${tc.expected} nodes`, async () => { + await store.add(nodes); + const result = await store.query({ + queryEmbedding: [0.1, 0.2], + similarityTopK: 3, + mode: VectorStoreQueryMode.DEFAULT, + filters: tc.filters, + }); + expect(result.ids).length(tc.expected); + }); + }); + }); +}); -- GitLab