Skip to content
Snippets Groups Projects
Unverified Commit ac6ab9f6 authored by Simon Suo's avatar Simon Suo Committed by GitHub
Browse files

Fix Bedrock KB retriever (#12910)

wip
parent 6a72f3a8
No related branches found
No related tags found
No related merge requests found
...@@ -3,7 +3,7 @@ from typing import List, Optional, Dict, Any ...@@ -3,7 +3,7 @@ from typing import List, Optional, Dict, Any
from llama_index.core.base.base_retriever import BaseRetriever from llama_index.core.base.base_retriever import BaseRetriever
from llama_index.core.callbacks.base import CallbackManager from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.schema import NodeWithScore, TextNode from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode
from llama_index.core.utilities.aws_utils import get_aws_service_client from llama_index.core.utilities.aws_utils import get_aws_service_client
...@@ -71,7 +71,9 @@ class AmazonKnowledgeBasesRetriever(BaseRetriever): ...@@ -71,7 +71,9 @@ class AmazonKnowledgeBasesRetriever(BaseRetriever):
self.retrieval_config = retrieval_config self.retrieval_config = retrieval_config
super().__init__(callback_manager) super().__init__(callback_manager)
def _retrieve(self, query: str) -> List[NodeWithScore]: def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
query = query_bundle.query_str
response = self._client.retrieve( response = self._client.retrieve(
retrievalQuery={"text": query.strip()}, retrievalQuery={"text": query.strip()},
knowledgeBaseId=self.knowledge_base_id, knowledgeBaseId=self.knowledge_base_id,
......
...@@ -27,7 +27,7 @@ license = "MIT" ...@@ -27,7 +27,7 @@ license = "MIT"
name = "llama-index-retrievers-bedrock" name = "llama-index-retrievers-bedrock"
packages = [{include = "llama_index/"}] packages = [{include = "llama_index/"}]
readme = "README.md" readme = "README.md"
version = "0.1.0" version = "0.1.1"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = ">=3.8.1,<4.0" python = ">=3.8.1,<4.0"
......
...@@ -41,7 +41,7 @@ def test_retrieve(mock_get_aws_service_client): ...@@ -41,7 +41,7 @@ def test_retrieve(mock_get_aws_service_client):
# Call the method being tested # Call the method being tested
query = "Test query" query = "Test query"
result = retriever._retrieve(query) result = retriever.retrieve(query)
# Assert the expected output # Assert the expected output
expected_result = [ expected_result = [
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment