Skip to content
Snippets Groups Projects
Unverified Commit 448584c8 authored by Richmond Alake's avatar Richmond Alake Committed by GitHub
Browse files

Updated MongoDB Vector Store to Pydantic Vector Store BaseClass (#10698)

parent 3c18d4ba
No related branches found
No related tags found
No related merge requests found
...@@ -9,10 +9,11 @@ import os ...@@ -9,10 +9,11 @@ import os
from importlib.metadata import version from importlib.metadata import version
from typing import Any, Dict, List, Optional, cast from typing import Any, Dict, List, Optional, cast
from llama_index.core.bridge.pydantic import PrivateAttr
from llama_index.core.schema import BaseNode, MetadataMode, TextNode from llama_index.core.schema import BaseNode, MetadataMode, TextNode
from llama_index.core.vector_stores.types import ( from llama_index.core.vector_stores.types import (
MetadataFilters, MetadataFilters,
VectorStore, BasePydanticVectorStore,
VectorStoreQuery, VectorStoreQuery,
VectorStoreQueryResult, VectorStoreQueryResult,
) )
...@@ -35,7 +36,7 @@ def _to_mongodb_filter(standard_filters: MetadataFilters) -> Dict: ...@@ -35,7 +36,7 @@ def _to_mongodb_filter(standard_filters: MetadataFilters) -> Dict:
return filters return filters
class MongoDBAtlasVectorSearch(VectorStore): class MongoDBAtlasVectorSearch(BasePydanticVectorStore):
"""MongoDB Atlas Vector Store. """MongoDB Atlas Vector Store.
To use, you should have both: To use, you should have both:
...@@ -48,6 +49,15 @@ class MongoDBAtlasVectorSearch(VectorStore): ...@@ -48,6 +49,15 @@ class MongoDBAtlasVectorSearch(VectorStore):
stores_text: bool = True stores_text: bool = True
flat_metadata: bool = True flat_metadata: bool = True
_mongodb_client: Any = PrivateAttr()
_collection: Any = PrivateAttr()
_index_name: str = PrivateAttr()
_embedding_key: str = PrivateAttr()
_id_key: str = PrivateAttr()
_text_key: str = PrivateAttr()
_metadata_key: str = PrivateAttr()
_insert_kwargs: Dict = PrivateAttr()
def __init__( def __init__(
self, self,
mongodb_client: Optional[Any] = None, mongodb_client: Optional[Any] = None,
...@@ -97,6 +107,8 @@ class MongoDBAtlasVectorSearch(VectorStore): ...@@ -97,6 +107,8 @@ class MongoDBAtlasVectorSearch(VectorStore):
self._metadata_key = metadata_key self._metadata_key = metadata_key
self._insert_kwargs = insert_kwargs or {} self._insert_kwargs = insert_kwargs or {}
super().__init__()
def add( def add(
self, self,
nodes: List[BaseNode], nodes: List[BaseNode],
......
from llama_index.core.vector_stores.types import VectorStore from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch
def test_class(): def test_class():
names_of_base_classes = [b.__name__ for b in MongoDBAtlasVectorSearch.__mro__] names_of_base_classes = [b.__name__ for b in MongoDBAtlasVectorSearch.__mro__]
assert VectorStore.__name__ in names_of_base_classes assert BasePydanticVectorStore.__name__ in names_of_base_classes
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment