From 448584c8cd30bab744d7629c9d1a7ee72e5af5ad Mon Sep 17 00:00:00 2001
From: Richmond Alake <richmond.alake@gmail.com>
Date: Wed, 14 Feb 2024 17:38:32 +0000
Subject: [PATCH] Updated MongoDB Vector Store to Pydantic Vector Store
 BaseClass (#10698)

---
 .../llama_index/vector_stores/mongodb/base.py    | 16 ++++++++++++++--
 .../tests/test_vector_stores_mongodb.py          |  4 ++--
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-mongodb/llama_index/vector_stores/mongodb/base.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-mongodb/llama_index/vector_stores/mongodb/base.py
index abd3776fe..c883a3e39 100644
--- a/llama-index-integrations/vector_stores/llama-index-vector-stores-mongodb/llama_index/vector_stores/mongodb/base.py
+++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-mongodb/llama_index/vector_stores/mongodb/base.py
@@ -9,10 +9,11 @@ import os
 from importlib.metadata import version
 from typing import Any, Dict, List, Optional, cast
 
+from llama_index.core.bridge.pydantic import PrivateAttr
 from llama_index.core.schema import BaseNode, MetadataMode, TextNode
 from llama_index.core.vector_stores.types import (
     MetadataFilters,
-    VectorStore,
+    BasePydanticVectorStore,
     VectorStoreQuery,
     VectorStoreQueryResult,
 )
@@ -35,7 +36,7 @@ def _to_mongodb_filter(standard_filters: MetadataFilters) -> Dict:
     return filters
 
 
-class MongoDBAtlasVectorSearch(VectorStore):
+class MongoDBAtlasVectorSearch(BasePydanticVectorStore):
     """MongoDB Atlas Vector Store.
 
     To use, you should have both:
@@ -48,6 +49,15 @@ class MongoDBAtlasVectorSearch(VectorStore):
     stores_text: bool = True
     flat_metadata: bool = True
 
+    _mongodb_client: Any = PrivateAttr()
+    _collection: Any = PrivateAttr()
+    _index_name: str = PrivateAttr()
+    _embedding_key: str = PrivateAttr()
+    _id_key: str = PrivateAttr()
+    _text_key: str = PrivateAttr()
+    _metadata_key: str = PrivateAttr()
+    _insert_kwargs: Dict = PrivateAttr()
+
     def __init__(
         self,
         mongodb_client: Optional[Any] = None,
@@ -97,6 +107,8 @@ class MongoDBAtlasVectorSearch(VectorStore):
         self._metadata_key = metadata_key
         self._insert_kwargs = insert_kwargs or {}
 
+        super().__init__()
+
     def add(
         self,
         nodes: List[BaseNode],
diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-mongodb/tests/test_vector_stores_mongodb.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-mongodb/tests/test_vector_stores_mongodb.py
index 88017acc9..4066dae44 100644
--- a/llama-index-integrations/vector_stores/llama-index-vector-stores-mongodb/tests/test_vector_stores_mongodb.py
+++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-mongodb/tests/test_vector_stores_mongodb.py
@@ -1,7 +1,7 @@
-from llama_index.core.vector_stores.types import VectorStore
+from llama_index.core.vector_stores.types import BasePydanticVectorStore
 from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch
 
 
 def test_class():
     names_of_base_classes = [b.__name__ for b in MongoDBAtlasVectorSearch.__mro__]
-    assert VectorStore.__name__ in names_of_base_classes
+    assert BasePydanticVectorStore.__name__ in names_of_base_classes
-- 
GitLab