diff --git a/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/llama_index/retrievers/bedrock/base.py b/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/llama_index/retrievers/bedrock/base.py index 5a7a0184854cbd4eb82c5fa43b07d56a2f929c84..ad97e3b84b0022baae8b8e09f11a4354ab39c0d8 100644 --- a/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/llama_index/retrievers/bedrock/base.py +++ b/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/llama_index/retrievers/bedrock/base.py @@ -3,7 +3,7 @@ from typing import List, Optional, Dict, Any from llama_index.core.base.base_retriever import BaseRetriever from llama_index.core.callbacks.base import CallbackManager -from llama_index.core.schema import NodeWithScore, TextNode +from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode from llama_index.core.utilities.aws_utils import get_aws_service_client @@ -71,7 +71,9 @@ class AmazonKnowledgeBasesRetriever(BaseRetriever): self.retrieval_config = retrieval_config super().__init__(callback_manager) - def _retrieve(self, query: str) -> List[NodeWithScore]: + def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: + query = query_bundle.query_str + response = self._client.retrieve( retrievalQuery={"text": query.strip()}, knowledgeBaseId=self.knowledge_base_id, diff --git a/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/pyproject.toml b/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/pyproject.toml index b5720f6c54708ee280f6831af18e5df3219f659c..177b04030bfd81223ea7893ca787c21826717228 100644 --- a/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/pyproject.toml +++ b/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/pyproject.toml @@ -27,7 +27,7 @@ license = "MIT" name = "llama-index-retrievers-bedrock" packages = [{include = "llama_index/"}] readme = "README.md" -version = "0.1.0" +version = "0.1.1" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" diff --git a/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/tests/test_retrievers_bedrock.py b/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/tests/test_retrievers_bedrock.py index 80b67aabe8fd586d7f10993964f1e52b0df62487..f2dd150c8088a37d1d718ec570c315588bb90232 100644 --- a/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/tests/test_retrievers_bedrock.py +++ b/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/tests/test_retrievers_bedrock.py @@ -41,7 +41,7 @@ def test_retrieve(mock_get_aws_service_client): # Call the method being tested query = "Test query" - result = retriever._retrieve(query) + result = retriever.retrieve(query) # Assert the expected output expected_result = [