Skip to content
Snippets Groups Projects
Unverified commit c79f36da, authored by Ravi Theja and committed by GitHub
Browse files

Update MilvusVectorStore to Pydantic (#11432)

parent bf2c8a4b
No related branches found
No related tags found
No related merge requests found
......@@ -7,10 +7,11 @@ import logging
from typing import Any, Dict, List, Optional, Union
import pymilvus # noqa
from llama_index.core.bridge.pydantic import PrivateAttr
from llama_index.core.schema import BaseNode, TextNode
from llama_index.core.vector_stores.types import (
BasePydanticVectorStore,
MetadataFilters,
VectorStore,
VectorStoreQuery,
VectorStoreQueryMode,
VectorStoreQueryResult,
......@@ -39,7 +40,7 @@ def _to_milvus_filter(standard_filters: MetadataFilters) -> List[str]:
return filters
class MilvusVectorStore(VectorStore):
class MilvusVectorStore(BasePydanticVectorStore):
"""The Milvus Vector Store.
In this vector store we store the text, its embedding and
......@@ -87,11 +88,27 @@ class MilvusVectorStore(VectorStore):
stores_text: bool = True
stores_node: bool = True
uri: str = "http://localhost:19530"
token: str = ""
collection_name: str = "llamacollection"
dim: Optional[int]
embedding_field: str = DEFAULT_EMBEDDING_KEY
doc_id_field: str = DEFAULT_DOC_ID_KEY
similarity_metric: str = "IP"
consistency_level: str = "Strong"
overwrite: bool = False
text_key: Optional[str]
index_config: Optional[dict]
search_config: Optional[dict]
_milvusclient: MilvusClient = PrivateAttr()
_collection: Any = PrivateAttr()
def __init__(
self,
uri: str = "http://localhost:19530",
token: str = "",
collection_name: str = "llamalection",
collection_name: str = "llamacollection",
dim: Optional[int] = None,
embedding_field: str = DEFAULT_EMBEDDING_KEY,
doc_id_field: str = DEFAULT_DOC_ID_KEY,
......@@ -104,56 +121,48 @@ class MilvusVectorStore(VectorStore):
**kwargs: Any,
) -> None:
"""Init params."""
self.collection_name = collection_name
self.dim = dim
self.embedding_field = embedding_field
self.doc_id_field = doc_id_field
self.consistency_level = consistency_level
self.overwrite = overwrite
self.text_key = text_key
self.index_config: Dict[str, Any] = index_config.copy() if index_config else {}
# Note: The search configuration is set at construction to avoid having
# to change the API for usage of the vector store (i.e. to pass the
# search config along with the rest of the query).
self.search_config: Dict[str, Any] = (
search_config.copy() if search_config else {}
super().__init__(
collection_name=collection_name,
dim=dim,
embedding_field=embedding_field,
doc_id_field=doc_id_field,
consistency_level=consistency_level,
overwrite=overwrite,
text_key=text_key,
index_config=index_config if index_config else {},
search_config=search_config if search_config else {},
)
# Select the similarity metric
if similarity_metric.lower() in ("ip"):
self.similarity_metric = "IP"
elif similarity_metric.lower() in ("l2", "euclidean"):
self.similarity_metric = "L2"
similarity_metrics_map = {"ip": "IP", "l2": "L2", "euclidean": "L2"}
similarity_metric = similarity_metrics_map.get(similarity_metric.lower(), "L2")
# Connect to Milvus instance
self.milvusclient = MilvusClient(
self._milvusclient = MilvusClient(
uri=uri,
token=token,
**kwargs, # pass additional arguments such as server_pem_path
)
# Delete previous collection if overwriting
if self.overwrite and self.collection_name in self.client.list_collections():
self.milvusclient.drop_collection(self.collection_name)
if overwrite and collection_name in self.client.list_collections():
self._milvusclient.drop_collection(collection_name)
# Create the collection if it does not exist
if self.collection_name not in self.client.list_collections():
if self.dim is None:
if collection_name not in self.client.list_collections():
if dim is None:
raise ValueError("Dim argument required for collection creation.")
self.milvusclient.create_collection(
collection_name=self.collection_name,
dimension=self.dim,
self._milvusclient.create_collection(
collection_name=collection_name,
dimension=dim,
primary_field_name=MILVUS_ID_FIELD,
vector_field_name=self.embedding_field,
vector_field_name=embedding_field,
id_type="string",
metric_type=self.similarity_metric,
metric_type=similarity_metric,
max_length=65_535,
consistency_level=self.consistency_level,
consistency_level=consistency_level,
)
self.collection = Collection(
self.collection_name, using=self.milvusclient._using
)
self._collection = Collection(collection_name, using=self._milvusclient._using)
self._create_index_if_required()
logger.debug(f"Successfully created a new collection: {self.collection_name}")
......@@ -161,7 +170,7 @@ class MilvusVectorStore(VectorStore):
@property
def client(self) -> Any:
"""Get client."""
return self.milvusclient
return self._milvusclient
def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
"""Add the embeddings and their nodes into Milvus.
......@@ -189,8 +198,8 @@ class MilvusVectorStore(VectorStore):
insert_list.append(entry)
# Insert the data into milvus
self.collection.insert(insert_list)
self.collection.flush()
self._collection.insert(insert_list)
self._collection.flush()
self._create_index_if_required()
logger.debug(
f"Successfully inserted embeddings into: {self.collection_name} "
......@@ -217,13 +226,13 @@ class MilvusVectorStore(VectorStore):
# Begin by querying for the primary keys to delete
doc_ids = ['"' + entry + '"' for entry in doc_ids]
entries = self.milvusclient.query(
entries = self._milvusclient.query(
collection_name=self.collection_name,
filter=f"{self.doc_id_field} in [{','.join(doc_ids)}]",
)
if len(entries) > 0:
ids = [entry["id"] for entry in entries]
self.milvusclient.delete(collection_name=self.collection_name, pks=ids)
self._milvusclient.delete(collection_name=self.collection_name, pks=ids)
logger.debug(f"Successfully deleted embedding with doc_id: {doc_ids}")
def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
......@@ -267,7 +276,7 @@ class MilvusVectorStore(VectorStore):
string_expr = " and ".join(expr)
# Perform the search
res = self.milvusclient.search(
res = self._milvusclient.search(
collection_name=self.collection_name,
data=[query.query_embedding],
filter=string_expr,
......@@ -317,9 +326,9 @@ class MilvusVectorStore(VectorStore):
# provided to ensure that the index is created in the constructor even
# if self.overwrite is false. In the `add` method, the index is
# recreated only if self.overwrite is true.
if (self.collection.has_index() and self.overwrite) or force:
self.collection.release()
self.collection.drop_index()
if (self._collection.has_index() and self.overwrite) or force:
self._collection.release()
self._collection.drop_index()
base_params: Dict[str, Any] = self.index_config.copy()
index_type: str = base_params.pop("index_type", "FLAT")
index_params: Dict[str, Union[str, Dict[str, Any]]] = {
......@@ -327,7 +336,7 @@ class MilvusVectorStore(VectorStore):
"metric_type": self.similarity_metric,
"index_type": index_type,
}
self.collection.create_index(
self._collection.create_index(
self.embedding_field, index_params=index_params
)
self.collection.load()
self._collection.load()
from llama_index.core.vector_stores.types import VectorStore
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.milvus import MilvusVectorStore
def test_class():
names_of_base_classes = [b.__name__ for b in MilvusVectorStore.__mro__]
assert VectorStore.__name__ in names_of_base_classes
assert BasePydanticVectorStore.__name__ in names_of_base_classes
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment