diff --git a/llama-index-core/llama_index/core/indices/vector_store/retrievers/retriever.py b/llama-index-core/llama_index/core/indices/vector_store/retrievers/retriever.py index 526dcc91684be67efbf4e5874cb532b487f7ac2e..3232b9c06a973ddf8127aa755ec04355f071f301 100644 --- a/llama-index-core/llama_index/core/indices/vector_store/retrievers/retriever.py +++ b/llama-index-core/llama_index/core/indices/vector_store/retrievers/retriever.py @@ -102,15 +102,16 @@ class VectorIndexRetriever(BaseRetriever): @dispatcher.span async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: + embedding = query_bundle.embedding if self._vector_store.is_embedding_query: if query_bundle.embedding is None and len(query_bundle.embedding_strs) > 0: embed_model = self._embed_model - query_bundle.embedding = ( - await embed_model.aget_agg_embedding_from_queries( - query_bundle.embedding_strs - ) + embedding = await embed_model.aget_agg_embedding_from_queries( + query_bundle.embedding_strs ) - return await self._aget_nodes_with_embeddings(query_bundle) + return await self._aget_nodes_with_embeddings( + QueryBundle(query_str=query_bundle.query_str, embedding=embedding) + ) def _build_vector_store_query( self, query_bundle_with_embeddings: QueryBundle diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-milvus/llama_index/vector_stores/milvus/base.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-milvus/llama_index/vector_stores/milvus/base.py index 433b110e2ad596334d71575eda059ff9ab40c013..645230b34ff9707bcb545e6c7c34fc9f086c2e48 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-milvus/llama_index/vector_stores/milvus/base.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-milvus/llama_index/vector_stores/milvus/base.py @@ -8,7 +8,7 @@ import logging from typing import Any, Dict, List, Optional, Union import pymilvus # noqa -from llama_index.core.bridge.pydantic import PrivateAttr +from llama_index.core.bridge.pydantic import Field, PrivateAttr from llama_index.core.schema import BaseNode, TextNode from llama_index.core.vector_stores.types import ( BasePydanticVectorStore, @@ -115,6 +115,7 @@ class MilvusVectorStore(BasePydanticVectorStore): consistency_level: str = "Strong" overwrite: bool = False text_key: Optional[str] + output_fields: List[str] = Field(default_factory=list) index_config: Optional[dict] search_config: Optional[dict] @@ -133,6 +134,7 @@ class MilvusVectorStore(BasePydanticVectorStore): consistency_level: str = "Strong", overwrite: bool = False, text_key: Optional[str] = None, + output_fields: Optional[List[str]] = None, index_config: Optional[dict] = None, search_config: Optional[dict] = None, **kwargs: Any, @@ -146,6 +148,7 @@ class MilvusVectorStore(BasePydanticVectorStore): consistency_level=consistency_level, overwrite=overwrite, text_key=text_key, + output_fields=output_fields, index_config=index_config if index_config else {}, search_config=search_config if search_config else {}, ) @@ -336,9 +339,10 @@ class MilvusVectorStore(BasePydanticVectorStore): "The passed in text_key value does not exist " "in the retrieved entity." ) - node = TextNode( - text=text, - ) + + metadata = {key: hit["entity"].get(key) for key in self.output_fields} + node = TextNode(text=text, metadata=metadata) + nodes.append(node) similarities.append(hit["distance"]) ids.append(hit["id"]) diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-milvus/pyproject.toml b/llama-index-integrations/vector_stores/llama-index-vector-stores-milvus/pyproject.toml index e45be5c7a4ffe6da3e98899b71acc3bd58ff87cc..427c364340234c88b6dac4f604cce1f65471e969 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-milvus/pyproject.toml +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-milvus/pyproject.toml @@ -27,7 +27,7 @@ exclude = ["**/BUILD"] license = "MIT" name = "llama-index-vector-stores-milvus" readme = "README.md" -version = "0.1.7" +version = "0.1.8" [tool.poetry.dependencies] python = ">=3.8.1,<4.0"