From ac6ab9f6bb5826c95f34945c1d5d15f7b47b0d54 Mon Sep 17 00:00:00 2001 From: Simon Suo <simonsdsuo@gmail.com> Date: Wed, 17 Apr 2024 23:28:37 -0700 Subject: [PATCH] Fix Bedrock KB retriever (#12910) wip --- .../llama_index/retrievers/bedrock/base.py | 6 ++++-- .../llama-index-retrievers-bedrock/pyproject.toml | 2 +- .../tests/test_retrievers_bedrock.py | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/llama_index/retrievers/bedrock/base.py b/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/llama_index/retrievers/bedrock/base.py index 5a7a01848..ad97e3b84 100644 --- a/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/llama_index/retrievers/bedrock/base.py +++ b/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/llama_index/retrievers/bedrock/base.py @@ -3,7 +3,7 @@ from typing import List, Optional, Dict, Any from llama_index.core.base.base_retriever import BaseRetriever from llama_index.core.callbacks.base import CallbackManager -from llama_index.core.schema import NodeWithScore, TextNode +from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode from llama_index.core.utilities.aws_utils import get_aws_service_client @@ -71,7 +71,9 @@ class AmazonKnowledgeBasesRetriever(BaseRetriever): self.retrieval_config = retrieval_config super().__init__(callback_manager) - def _retrieve(self, query: str) -> List[NodeWithScore]: + def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: + query = query_bundle.query_str + response = self._client.retrieve( retrievalQuery={"text": query.strip()}, knowledgeBaseId=self.knowledge_base_id, diff --git a/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/pyproject.toml b/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/pyproject.toml index b5720f6c5..177b04030 100644 --- a/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/pyproject.toml +++ b/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/pyproject.toml @@ -27,7 +27,7 @@ license = "MIT" name = "llama-index-retrievers-bedrock" packages = [{include = "llama_index/"}] readme = "README.md" -version = "0.1.0" +version = "0.1.1" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" diff --git a/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/tests/test_retrievers_bedrock.py b/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/tests/test_retrievers_bedrock.py index 80b67aabe..f2dd150c8 100644 --- a/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/tests/test_retrievers_bedrock.py +++ b/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/tests/test_retrievers_bedrock.py @@ -41,7 +41,7 @@ def test_retrieve(mock_get_aws_service_client): # Call the method being tested query = "Test query" - result = retriever._retrieve(query) + result = retriever.retrieve(query) # Assert the expected output expected_result = [ -- GitLab