From 8795368bc55135cc6dd399d9fc136cd7ea49bdd3 Mon Sep 17 00:00:00 2001
From: "Chandrashekar V.T" <57014454+chandrashekarvt@users.noreply.github.com>
Date: Tue, 9 Apr 2024 23:16:50 +0530
Subject: [PATCH] Added support to retrieve metadata fields from milvus
 (#12626)

---
 .../indices/vector_store/retrievers/retriever.py     | 11 ++++++-----
 .../llama_index/vector_stores/milvus/base.py         | 12 ++++++++----
 .../llama-index-vector-stores-milvus/pyproject.toml  |  2 +-
 3 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/llama-index-core/llama_index/core/indices/vector_store/retrievers/retriever.py b/llama-index-core/llama_index/core/indices/vector_store/retrievers/retriever.py
index 526dcc9168..3232b9c06a 100644
--- a/llama-index-core/llama_index/core/indices/vector_store/retrievers/retriever.py
+++ b/llama-index-core/llama_index/core/indices/vector_store/retrievers/retriever.py
@@ -102,15 +102,16 @@ class VectorIndexRetriever(BaseRetriever):
 
     @dispatcher.span
     async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
+        embedding = query_bundle.embedding
         if self._vector_store.is_embedding_query:
             if query_bundle.embedding is None and len(query_bundle.embedding_strs) > 0:
                 embed_model = self._embed_model
-                query_bundle.embedding = (
-                    await embed_model.aget_agg_embedding_from_queries(
-                        query_bundle.embedding_strs
-                    )
+                embedding = await embed_model.aget_agg_embedding_from_queries(
+                    query_bundle.embedding_strs
                 )
-        return await self._aget_nodes_with_embeddings(query_bundle)
+        return await self._aget_nodes_with_embeddings(
+            QueryBundle(query_str=query_bundle.query_str, embedding=embedding)
+        )
 
     def _build_vector_store_query(
         self, query_bundle_with_embeddings: QueryBundle
diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-milvus/llama_index/vector_stores/milvus/base.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-milvus/llama_index/vector_stores/milvus/base.py
index 433b110e2a..645230b34f 100644
--- a/llama-index-integrations/vector_stores/llama-index-vector-stores-milvus/llama_index/vector_stores/milvus/base.py
+++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-milvus/llama_index/vector_stores/milvus/base.py
@@ -8,7 +8,7 @@ import logging
 from typing import Any, Dict, List, Optional, Union
 
 import pymilvus  # noqa
-from llama_index.core.bridge.pydantic import PrivateAttr
+from llama_index.core.bridge.pydantic import Field, PrivateAttr
 from llama_index.core.schema import BaseNode, TextNode
 from llama_index.core.vector_stores.types import (
     BasePydanticVectorStore,
@@ -115,6 +115,7 @@ class MilvusVectorStore(BasePydanticVectorStore):
     consistency_level: str = "Strong"
     overwrite: bool = False
     text_key: Optional[str]
+    output_fields: List[str] = Field(default_factory=list)
     index_config: Optional[dict]
     search_config: Optional[dict]
 
@@ -133,6 +134,7 @@ class MilvusVectorStore(BasePydanticVectorStore):
         consistency_level: str = "Strong",
         overwrite: bool = False,
         text_key: Optional[str] = None,
+        output_fields: Optional[List[str]] = None,
         index_config: Optional[dict] = None,
         search_config: Optional[dict] = None,
         **kwargs: Any,
@@ -146,6 +148,7 @@ class MilvusVectorStore(BasePydanticVectorStore):
             consistency_level=consistency_level,
             overwrite=overwrite,
             text_key=text_key,
+            output_fields=output_fields,
             index_config=index_config if index_config else {},
             search_config=search_config if search_config else {},
         )
@@ -336,9 +339,10 @@ class MilvusVectorStore(BasePydanticVectorStore):
                         "The passed in text_key value does not exist "
                         "in the retrieved entity."
                     )
-                node = TextNode(
-                    text=text,
-                )
+
+                metadata = {key: hit["entity"].get(key) for key in self.output_fields}
+                node = TextNode(text=text, metadata=metadata)
+
             nodes.append(node)
             similarities.append(hit["distance"])
             ids.append(hit["id"])
diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-milvus/pyproject.toml b/llama-index-integrations/vector_stores/llama-index-vector-stores-milvus/pyproject.toml
index e45be5c7a4..427c364340 100644
--- a/llama-index-integrations/vector_stores/llama-index-vector-stores-milvus/pyproject.toml
+++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-milvus/pyproject.toml
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-vector-stores-milvus"
 readme = "README.md"
-version = "0.1.7"
+version = "0.1.8"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
-- 
GitLab