From ac6ab9f6bb5826c95f34945c1d5d15f7b47b0d54 Mon Sep 17 00:00:00 2001
From: Simon Suo <simonsdsuo@gmail.com>
Date: Wed, 17 Apr 2024 23:28:37 -0700
Subject: [PATCH] Fix Bedrock KB retriever (#12910)

wip
---
 .../llama_index/retrievers/bedrock/base.py                  | 6 ++++--
 .../llama-index-retrievers-bedrock/pyproject.toml           | 2 +-
 .../tests/test_retrievers_bedrock.py                        | 2 +-
 3 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/llama_index/retrievers/bedrock/base.py b/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/llama_index/retrievers/bedrock/base.py
index 5a7a01848..ad97e3b84 100644
--- a/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/llama_index/retrievers/bedrock/base.py
+++ b/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/llama_index/retrievers/bedrock/base.py
@@ -3,7 +3,7 @@ from typing import List, Optional, Dict, Any
 
 from llama_index.core.base.base_retriever import BaseRetriever
 from llama_index.core.callbacks.base import CallbackManager
-from llama_index.core.schema import NodeWithScore, TextNode
+from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode
 from llama_index.core.utilities.aws_utils import get_aws_service_client
 
 
@@ -71,7 +71,9 @@ class AmazonKnowledgeBasesRetriever(BaseRetriever):
         self.retrieval_config = retrieval_config
         super().__init__(callback_manager)
 
-    def _retrieve(self, query: str) -> List[NodeWithScore]:
+    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
+        query = query_bundle.query_str
+
         response = self._client.retrieve(
             retrievalQuery={"text": query.strip()},
             knowledgeBaseId=self.knowledge_base_id,
diff --git a/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/pyproject.toml b/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/pyproject.toml
index b5720f6c5..177b04030 100644
--- a/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/pyproject.toml
+++ b/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/pyproject.toml
@@ -27,7 +27,7 @@ license = "MIT"
 name = "llama-index-retrievers-bedrock"
 packages = [{include = "llama_index/"}]
 readme = "README.md"
-version = "0.1.0"
+version = "0.1.1"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
diff --git a/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/tests/test_retrievers_bedrock.py b/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/tests/test_retrievers_bedrock.py
index 80b67aabe..f2dd150c8 100644
--- a/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/tests/test_retrievers_bedrock.py
+++ b/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/tests/test_retrievers_bedrock.py
@@ -41,7 +41,7 @@ def test_retrieve(mock_get_aws_service_client):
 
     # Call the method being tested
     query = "Test query"
-    result = retriever._retrieve(query)
+    result = retriever.retrieve(query)
 
     # Assert the expected output
     expected_result = [
-- 
GitLab