diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-milvus/llama_index/vector_stores/milvus/base.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-milvus/llama_index/vector_stores/milvus/base.py index 570ed49fc092797dfa1a6276ffcbd2ceff9c8e46..6358c077dc5be9b8379427f940bf4f7e1a2c6246 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-milvus/llama_index/vector_stores/milvus/base.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-milvus/llama_index/vector_stores/milvus/base.py @@ -7,10 +7,11 @@ import logging from typing import Any, Dict, List, Optional, Union import pymilvus # noqa +from llama_index.core.bridge.pydantic import PrivateAttr from llama_index.core.schema import BaseNode, TextNode from llama_index.core.vector_stores.types import ( + BasePydanticVectorStore, MetadataFilters, - VectorStore, VectorStoreQuery, VectorStoreQueryMode, VectorStoreQueryResult, @@ -39,7 +40,7 @@ def _to_milvus_filter(standard_filters: MetadataFilters) -> List[str]: return filters -class MilvusVectorStore(VectorStore): +class MilvusVectorStore(BasePydanticVectorStore): """The Milvus Vector Store. In this vector store we store the text, its embedding and @@ -87,11 +88,27 @@ class MilvusVectorStore(VectorStore): stores_text: bool = True stores_node: bool = True + uri: str = "http://localhost:19530" + token: str = "" + collection_name: str = "llamacollection" + dim: Optional[int] + embedding_field: str = DEFAULT_EMBEDDING_KEY + doc_id_field: str = DEFAULT_DOC_ID_KEY + similarity_metric: str = "IP" + consistency_level: str = "Strong" + overwrite: bool = False + text_key: Optional[str] + index_config: Optional[dict] + search_config: Optional[dict] + + _milvusclient: MilvusClient = PrivateAttr() + _collection: Any = PrivateAttr() + def __init__( self, uri: str = "http://localhost:19530", token: str = "", - collection_name: str = "llamalection", + collection_name: str = "llamacollection", dim: Optional[int] = None, embedding_field: str = DEFAULT_EMBEDDING_KEY, doc_id_field: str = DEFAULT_DOC_ID_KEY, @@ -104,56 +121,48 @@ class MilvusVectorStore(VectorStore): **kwargs: Any, ) -> None: """Init params.""" - self.collection_name = collection_name - self.dim = dim - self.embedding_field = embedding_field - self.doc_id_field = doc_id_field - self.consistency_level = consistency_level - self.overwrite = overwrite - self.text_key = text_key - self.index_config: Dict[str, Any] = index_config.copy() if index_config else {} - # Note: The search configuration is set at construction to avoid having - # to change the API for usage of the vector store (i.e. to pass the - # search config along with the rest of the query). - self.search_config: Dict[str, Any] = ( - search_config.copy() if search_config else {} + super().__init__( + collection_name=collection_name, + dim=dim, + embedding_field=embedding_field, + doc_id_field=doc_id_field, + consistency_level=consistency_level, + overwrite=overwrite, + text_key=text_key, + index_config=index_config if index_config else {}, + search_config=search_config if search_config else {}, ) # Select the similarity metric - if similarity_metric.lower() in ("ip"): - self.similarity_metric = "IP" - elif similarity_metric.lower() in ("l2", "euclidean"): - self.similarity_metric = "L2" + similarity_metrics_map = {"ip": "IP", "l2": "L2", "euclidean": "L2"} + similarity_metric = similarity_metrics_map.get(similarity_metric.lower(), "L2") # Connect to Milvus instance - self.milvusclient = MilvusClient( + self._milvusclient = MilvusClient( uri=uri, token=token, **kwargs, # pass additional arguments such as server_pem_path ) - # Delete previous collection if overwriting - if self.overwrite and self.collection_name in self.client.list_collections(): - self.milvusclient.drop_collection(self.collection_name) + if overwrite and collection_name in self.client.list_collections(): + self._milvusclient.drop_collection(collection_name) # Create the collection if it does not exist - if self.collection_name not in self.client.list_collections(): - if self.dim is None: + if collection_name not in self.client.list_collections(): + if dim is None: raise ValueError("Dim argument required for collection creation.") - self.milvusclient.create_collection( - collection_name=self.collection_name, - dimension=self.dim, + self._milvusclient.create_collection( + collection_name=collection_name, + dimension=dim, primary_field_name=MILVUS_ID_FIELD, - vector_field_name=self.embedding_field, + vector_field_name=embedding_field, id_type="string", - metric_type=self.similarity_metric, + metric_type=similarity_metric, max_length=65_535, - consistency_level=self.consistency_level, + consistency_level=consistency_level, ) - self.collection = Collection( - self.collection_name, using=self.milvusclient._using - ) + self._collection = Collection(collection_name, using=self._milvusclient._using) self._create_index_if_required() logger.debug(f"Successfully created a new collection: {self.collection_name}") @@ -161,7 +170,7 @@ class MilvusVectorStore(VectorStore): @property def client(self) -> Any: """Get client.""" - return self.milvusclient + return self._milvusclient def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]: """Add the embeddings and their nodes into Milvus. @@ -189,8 +198,8 @@ class MilvusVectorStore(VectorStore): insert_list.append(entry) # Insert the data into milvus - self.collection.insert(insert_list) - self.collection.flush() + self._collection.insert(insert_list) + self._collection.flush() self._create_index_if_required() logger.debug( f"Successfully inserted embeddings into: {self.collection_name} " @@ -217,13 +226,13 @@ class MilvusVectorStore(VectorStore): # Begin by querying for the primary keys to delete doc_ids = ['"' + entry + '"' for entry in doc_ids] - entries = self.milvusclient.query( + entries = self._milvusclient.query( collection_name=self.collection_name, filter=f"{self.doc_id_field} in [{','.join(doc_ids)}]", ) if len(entries) > 0: ids = [entry["id"] for entry in entries] - self.milvusclient.delete(collection_name=self.collection_name, pks=ids) + self._milvusclient.delete(collection_name=self.collection_name, pks=ids) logger.debug(f"Successfully deleted embedding with doc_id: {doc_ids}") def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult: @@ -267,7 +276,7 @@ class MilvusVectorStore(VectorStore): string_expr = " and ".join(expr) # Perform the search - res = self.milvusclient.search( + res = self._milvusclient.search( collection_name=self.collection_name, data=[query.query_embedding], filter=string_expr, @@ -317,9 +326,9 @@ class MilvusVectorStore(VectorStore): # provided to ensure that the index is created in the constructor even # if self.overwrite is false. In the `add` method, the index is # recreated only if self.overwrite is true. - if (self.collection.has_index() and self.overwrite) or force: - self.collection.release() - self.collection.drop_index() + if (self._collection.has_index() and self.overwrite) or force: + self._collection.release() + self._collection.drop_index() base_params: Dict[str, Any] = self.index_config.copy() index_type: str = base_params.pop("index_type", "FLAT") index_params: Dict[str, Union[str, Dict[str, Any]]] = { @@ -327,7 +336,7 @@ class MilvusVectorStore(VectorStore): "metric_type": self.similarity_metric, "index_type": index_type, } - self.collection.create_index( + self._collection.create_index( self.embedding_field, index_params=index_params ) - self.collection.load() + self._collection.load() diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-milvus/tests/test_vector_stores_milvus.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-milvus/tests/test_vector_stores_milvus.py index a5720405d550c63484641603768d8bbd6a661e52..325cf75d5281ee1f144af6fc8401452496fbed6a 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-milvus/tests/test_vector_stores_milvus.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-milvus/tests/test_vector_stores_milvus.py @@ -1,7 +1,7 @@ -from llama_index.core.vector_stores.types import VectorStore +from llama_index.core.vector_stores.types import BasePydanticVectorStore from llama_index.vector_stores.milvus import MilvusVectorStore def test_class(): names_of_base_classes = [b.__name__ for b in MilvusVectorStore.__mro__] - assert VectorStore.__name__ in names_of_base_classes + assert BasePydanticVectorStore.__name__ in names_of_base_classes