Skip to content
Snippets Groups Projects
Unverified Commit 11b38563 authored by Thuc Pham's avatar Thuc Pham Committed by GitHub
Browse files

feat: implement filters for MongoDBAtlasVectorSearch (#1142)

parent e8f229cd
No related branches found
No related tags found
No related merge requests found
---
"llamaindex": patch
---
implement filters for MongoDBAtlasVectorSearch
......@@ -28,12 +28,23 @@ async function loadAndIndex() {
"full_text",
]);
const FILTER_METADATA_FIELD = "content_type";
documents.forEach((document, index) => {
const contentType = ["tweet", "post", "story"][index % 3]; // assign a random content type to each document
document.metadata = {
...document.metadata,
[FILTER_METADATA_FIELD]: contentType,
};
});
// create Atlas as a vector store
const vectorStore = new MongoDBAtlasVectorSearch({
mongodbClient: client,
dbName: databaseName,
collectionName: vectorCollectionName, // this is where your embeddings will be stored
indexName: indexName, // this is the name of the index you will need to create
indexedMetadataFields: [FILTER_METADATA_FIELD], // this is the field that will be used for the query
});
// now create an index from all the Documents and store them in Atlas
......@@ -46,5 +57,3 @@ async function loadAndIndex() {
}
loadAndIndex().catch(console.error);
// you can't query your index yet because you need to create a vector search index in mongodb's UI now
......@@ -14,12 +14,24 @@ async function query() {
dbName: process.env.MONGODB_DATABASE!,
collectionName: process.env.MONGODB_VECTORS!,
indexName: process.env.MONGODB_VECTOR_INDEX!,
indexedMetadataFields: ["content_type"],
});
const index = await VectorStoreIndex.fromVectorStore(store);
const retriever = index.asRetriever({ similarityTopK: 20 });
const queryEngine = index.asQueryEngine({ retriever });
const queryEngine = index.asQueryEngine({
retriever,
preFilters: {
filters: [
{
key: "content_type",
value: "story", // try "tweet" or "post" to see the difference
operator: "==",
},
],
},
});
const result = await queryEngine.query({
query: "What does author receive when he was 11 years old?", // Isaac Asimov's "Foundation" for Christmas
});
......
......@@ -5,7 +5,10 @@ import { getEnv } from "@llamaindex/env";
import type { BulkWriteOptions, Collection } from "mongodb";
import { MongoClient } from "mongodb";
import {
FilterCondition,
VectorStoreBase,
type FilterOperator,
type MetadataFilter,
type MetadataFilters,
type VectorStoreNoEmbedModel,
type VectorStoreQuery,
......@@ -13,15 +16,51 @@ import {
} from "./types.js";
import { metadataDictToNode, nodeToMetadata } from "./utils.js";
// Utility function to convert metadata filters to MongoDB filter
function toMongoDBFilter(
standardFilters: MetadataFilters,
): Record<string, any> {
const filters: Record<string, any> = {};
for (const filter of standardFilters?.filters ?? []) {
filters[filter.key] = filter.value;
// define your Atlas Search index. See detail https://www.mongodb.com/docs/atlas/atlas-search/field-types/knn-vector/
const DEFAULT_EMBEDDING_DEFINITION = {
type: "knnVector",
dimensions: 1536,
similarity: "cosine",
};
function mapLcMqlFilterOperators(operator: string): string {
const operatorMap: { [key in FilterOperator]?: string } = {
"==": "$eq",
"<": "$lt",
"<=": "$lte",
">": "$gt",
">=": "$gte",
"!=": "$ne",
in: "$in",
nin: "$nin",
};
const mqlOperator = operatorMap[operator as FilterOperator];
if (!mqlOperator) throw new Error(`Unsupported operator: ${operator}`);
return mqlOperator;
}
function toMongoDBFilter(filters?: MetadataFilters): Record<string, any> {
if (!filters) return {};
const createFilterObject = (mf: MetadataFilter) => ({
[mf.key]: {
[mapLcMqlFilterOperators(mf.operator)]: mf.value,
},
});
if (filters.filters.length === 1) {
return createFilterObject(filters.filters[0]);
}
if (filters.condition === FilterCondition.AND) {
return { $and: filters.filters.map(createFilterObject) };
}
if (filters.condition === FilterCondition.OR) {
return { $or: filters.filters.map(createFilterObject) };
}
return filters;
throw new Error("filters condition not recognized. Must be AND or OR");
}
/**
......@@ -38,6 +77,8 @@ export class MongoDBAtlasVectorSearch
dbName: string;
collectionName: string;
autoCreateIndex: boolean;
embeddingDefinition: Record<string, unknown>;
indexedMetadataFields: string[];
/**
* The used MongoClient. If not given, a new MongoClient is created based on the MONGODB_URI env variable.
......@@ -98,26 +139,14 @@ export class MongoDBAtlasVectorSearch
numCandidates: (query: VectorStoreQuery) => number;
private collection?: Collection;
// define your Atlas Search index. See detail https://www.mongodb.com/docs/atlas/atlas-search/field-types/knn-vector/
readonly SEARCH_INDEX_DEFINITION = {
mappings: {
dynamic: true,
fields: {
embedding: {
type: "knnVector",
dimensions: 1536,
similarity: "cosine",
},
},
},
};
constructor(
init: Partial<MongoDBAtlasVectorSearch> & {
dbName: string;
collectionName: string;
embedModel?: BaseEmbedding;
autoCreateIndex?: boolean;
indexedMetadataFields?: string[];
embeddingDefinition?: Record<string, unknown>;
},
) {
super(init.embedModel);
......@@ -136,6 +165,11 @@ export class MongoDBAtlasVectorSearch
this.dbName = init.dbName ?? "default_db";
this.collectionName = init.collectionName ?? "default_collection";
this.autoCreateIndex = init.autoCreateIndex ?? true;
this.indexedMetadataFields = init.indexedMetadataFields ?? [];
this.embeddingDefinition = {
...DEFAULT_EMBEDDING_DEFINITION,
...(init.embeddingDefinition ?? {}),
};
this.indexName = init.indexName ?? "default";
this.embeddingKey = init.embeddingKey ?? "embedding";
this.idKey = init.idKey ?? "id";
......@@ -161,9 +195,21 @@ export class MongoDBAtlasVectorSearch
(index) => index.name === this.indexName,
);
if (!indexExists) {
const additionalDefinition: Record<string, { type: string }> = {};
this.indexedMetadataFields.forEach((field) => {
additionalDefinition[field] = { type: "token" };
});
await this.collection.createSearchIndex({
name: this.indexName,
definition: this.SEARCH_INDEX_DEFINITION,
definition: {
mappings: {
dynamic: true,
fields: {
embedding: this.embeddingDefinition,
...additionalDefinition,
},
},
},
});
}
}
......@@ -189,11 +235,18 @@ export class MongoDBAtlasVectorSearch
this.flatMetadata,
);
// Include the specified metadata fields in the top level of the document (to help filter)
const populatedMetadata: Record<string, unknown> = {};
for (const field of this.indexedMetadataFields) {
populatedMetadata[field] = metadata[field];
}
return {
[this.idKey]: node.id_,
[this.embeddingKey]: node.getEmbedding(),
[this.textKey]: node.getContent(MetadataMode.NONE) || "",
[this.metadataKey]: metadata,
...populatedMetadata,
};
});
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment