diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-mongodb/llama_index/vector_stores/mongodb/base.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-mongodb/llama_index/vector_stores/mongodb/base.py index abd3776fecbaa9a6d82fe063c2d007ca71e21043..c883a3e39ae68f5cef9e45f5a89964907340f2fb 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-mongodb/llama_index/vector_stores/mongodb/base.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-mongodb/llama_index/vector_stores/mongodb/base.py @@ -9,10 +9,11 @@ import os from importlib.metadata import version from typing import Any, Dict, List, Optional, cast +from llama_index.core.bridge.pydantic import PrivateAttr from llama_index.core.schema import BaseNode, MetadataMode, TextNode from llama_index.core.vector_stores.types import ( MetadataFilters, - VectorStore, + BasePydanticVectorStore, VectorStoreQuery, VectorStoreQueryResult, ) @@ -35,7 +36,7 @@ def _to_mongodb_filter(standard_filters: MetadataFilters) -> Dict: return filters -class MongoDBAtlasVectorSearch(VectorStore): +class MongoDBAtlasVectorSearch(BasePydanticVectorStore): """MongoDB Atlas Vector Store. To use, you should have both: @@ -48,6 +49,15 @@ class MongoDBAtlasVectorSearch(VectorStore): stores_text: bool = True flat_metadata: bool = True + _mongodb_client: Any = PrivateAttr() + _collection: Any = PrivateAttr() + _index_name: str = PrivateAttr() + _embedding_key: str = PrivateAttr() + _id_key: str = PrivateAttr() + _text_key: str = PrivateAttr() + _metadata_key: str = PrivateAttr() + _insert_kwargs: Dict = PrivateAttr() + def __init__( self, mongodb_client: Optional[Any] = None, @@ -97,6 +107,8 @@ class MongoDBAtlasVectorSearch(VectorStore): self._metadata_key = metadata_key self._insert_kwargs = insert_kwargs or {} + super().__init__() + def add( self, nodes: List[BaseNode], diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-mongodb/tests/test_vector_stores_mongodb.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-mongodb/tests/test_vector_stores_mongodb.py index 88017acc967fffcf8bcc4c69d4170edbba48a188..4066dae44415f68fccc01164ea96e60b43fe8db4 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-mongodb/tests/test_vector_stores_mongodb.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-mongodb/tests/test_vector_stores_mongodb.py @@ -1,7 +1,7 @@ -from llama_index.core.vector_stores.types import VectorStore +from llama_index.core.vector_stores.types import BasePydanticVectorStore from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch def test_class(): names_of_base_classes = [b.__name__ for b in MongoDBAtlasVectorSearch.__mro__] - assert VectorStore.__name__ in names_of_base_classes + assert BasePydanticVectorStore.__name__ in names_of_base_classes