From 8795368bc55135cc6dd399d9fc136cd7ea49bdd3 Mon Sep 17 00:00:00 2001 From: "Chandrashekar V.T" <57014454+chandrashekarvt@users.noreply.github.com> Date: Tue, 9 Apr 2024 23:16:50 +0530 Subject: [PATCH] Added support to retrieve metadata fields from milvus (#12626) --- .../indices/vector_store/retrievers/retriever.py | 11 ++++++----- .../llama_index/vector_stores/milvus/base.py | 12 ++++++++---- .../llama-index-vector-stores-milvus/pyproject.toml | 2 +- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/llama-index-core/llama_index/core/indices/vector_store/retrievers/retriever.py b/llama-index-core/llama_index/core/indices/vector_store/retrievers/retriever.py index 526dcc9168..3232b9c06a 100644 --- a/llama-index-core/llama_index/core/indices/vector_store/retrievers/retriever.py +++ b/llama-index-core/llama_index/core/indices/vector_store/retrievers/retriever.py @@ -102,15 +102,16 @@ class VectorIndexRetriever(BaseRetriever): @dispatcher.span async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: + embedding = query_bundle.embedding if self._vector_store.is_embedding_query: if query_bundle.embedding is None and len(query_bundle.embedding_strs) > 0: embed_model = self._embed_model - query_bundle.embedding = ( - await embed_model.aget_agg_embedding_from_queries( - query_bundle.embedding_strs - ) + embedding = await embed_model.aget_agg_embedding_from_queries( + query_bundle.embedding_strs ) - return await self._aget_nodes_with_embeddings(query_bundle) + return await self._aget_nodes_with_embeddings( + QueryBundle(query_str=query_bundle.query_str, embedding=embedding) + ) def _build_vector_store_query( self, query_bundle_with_embeddings: QueryBundle diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-milvus/llama_index/vector_stores/milvus/base.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-milvus/llama_index/vector_stores/milvus/base.py index 433b110e2a..645230b34f 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-milvus/llama_index/vector_stores/milvus/base.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-milvus/llama_index/vector_stores/milvus/base.py @@ -8,7 +8,7 @@ import logging from typing import Any, Dict, List, Optional, Union import pymilvus # noqa -from llama_index.core.bridge.pydantic import PrivateAttr +from llama_index.core.bridge.pydantic import Field, PrivateAttr from llama_index.core.schema import BaseNode, TextNode from llama_index.core.vector_stores.types import ( BasePydanticVectorStore, @@ -115,6 +115,7 @@ class MilvusVectorStore(BasePydanticVectorStore): consistency_level: str = "Strong" overwrite: bool = False text_key: Optional[str] + output_fields: List[str] = Field(default_factory=list) index_config: Optional[dict] search_config: Optional[dict] @@ -133,6 +134,7 @@ class MilvusVectorStore(BasePydanticVectorStore): consistency_level: str = "Strong", overwrite: bool = False, text_key: Optional[str] = None, + output_fields: Optional[List[str]] = None, index_config: Optional[dict] = None, search_config: Optional[dict] = None, **kwargs: Any, @@ -146,6 +148,7 @@ class MilvusVectorStore(BasePydanticVectorStore): consistency_level=consistency_level, overwrite=overwrite, text_key=text_key, + output_fields=output_fields, index_config=index_config if index_config else {}, search_config=search_config if search_config else {}, ) @@ -336,9 +339,10 @@ class MilvusVectorStore(BasePydanticVectorStore): "The passed in text_key value does not exist " "in the retrieved entity." ) - node = TextNode( - text=text, - ) + + metadata = {key: hit["entity"].get(key) for key in self.output_fields} + node = TextNode(text=text, metadata=metadata) + nodes.append(node) similarities.append(hit["distance"]) ids.append(hit["id"]) diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-milvus/pyproject.toml b/llama-index-integrations/vector_stores/llama-index-vector-stores-milvus/pyproject.toml index e45be5c7a4..427c364340 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-milvus/pyproject.toml +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-milvus/pyproject.toml @@ -27,7 +27,7 @@ exclude = ["**/BUILD"] license = "MIT" name = "llama-index-vector-stores-milvus" readme = "README.md" -version = "0.1.7" +version = "0.1.8" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" -- GitLab