From 933024d40b33a338ee074ec1e685866b7e1b3ab3 Mon Sep 17 00:00:00 2001 From: Renu Rozera <166179060+rozerarenu@users.noreply.github.com> Date: Thu, 11 Apr 2024 14:10:29 -0700 Subject: [PATCH] Add Amazon Bedrock knowledge base integration as retriever (#12737) --- .../retrievers/bedrock_retriever.ipynb | 100 ++++++++++++ .../llama-index-retrievers-bedrock/.gitignore | 153 ++++++++++++++++++ .../llama-index-retrievers-bedrock/BUILD | 3 + .../llama-index-retrievers-bedrock/Makefile | 17 ++ .../llama-index-retrievers-bedrock/README.md | 19 +++ .../__init__.py | 0 .../llama_index/retrievers/bedrock/BUILD | 1 + .../retrievers/bedrock/__init__.py | 3 + .../llama_index/retrievers/bedrock/base.py | 97 +++++++++++ .../pyproject.toml | 53 ++++++ .../tests/BUILD | 1 + .../tests/__init__.py | 0 .../tests/test_retrievers_bedrock.py | 70 ++++++++ 13 files changed, 517 insertions(+) create mode 100644 docs/docs/examples/retrievers/bedrock_retriever.ipynb create mode 100644 llama-index-integrations/retrievers/llama-index-retrievers-bedrock/.gitignore create mode 100644 llama-index-integrations/retrievers/llama-index-retrievers-bedrock/BUILD create mode 100644 llama-index-integrations/retrievers/llama-index-retrievers-bedrock/Makefile create mode 100644 llama-index-integrations/retrievers/llama-index-retrievers-bedrock/README.md create mode 100644 llama-index-integrations/retrievers/llama-index-retrievers-bedrock/__init__.py create mode 100644 llama-index-integrations/retrievers/llama-index-retrievers-bedrock/llama_index/retrievers/bedrock/BUILD create mode 100644 llama-index-integrations/retrievers/llama-index-retrievers-bedrock/llama_index/retrievers/bedrock/__init__.py create mode 100644 llama-index-integrations/retrievers/llama-index-retrievers-bedrock/llama_index/retrievers/bedrock/base.py create mode 100644 llama-index-integrations/retrievers/llama-index-retrievers-bedrock/pyproject.toml create mode 100644 
llama-index-integrations/retrievers/llama-index-retrievers-bedrock/tests/BUILD create mode 100644 llama-index-integrations/retrievers/llama-index-retrievers-bedrock/tests/__init__.py create mode 100644 llama-index-integrations/retrievers/llama-index-retrievers-bedrock/tests/test_retrievers_bedrock.py diff --git a/docs/docs/examples/retrievers/bedrock_retriever.ipynb b/docs/docs/examples/retrievers/bedrock_retriever.ipynb new file mode 100644 index 000000000..646cac49a --- /dev/null +++ b/docs/docs/examples/retrievers/bedrock_retriever.ipynb @@ -0,0 +1,100 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Bedrock (Knowledge Bases)\n", + "\n", + "> [Knowledge bases for Amazon Bedrock](https://aws.amazon.com/bedrock/knowledge-bases/) is an Amazon Web Services (AWS) offering which lets you quickly build RAG applications by using your private data to customize FM response.\n", + "\n", + "> Implementing `RAG` requires organizations to perform several cumbersome steps to convert data into embeddings (vectors), store the embeddings in a specialized vector database, and build custom integrations into the database to search and retrieve text relevant to the user’s query. This can be time-consuming and inefficient.\n", + "\n", + "> With `Knowledge Bases for Amazon Bedrock`, simply point to the location of your data in `Amazon S3`, and `Knowledge Bases for Amazon Bedrock` takes care of the entire ingestion workflow into your vector database. If you do not have an existing vector database, Amazon Bedrock creates an Amazon OpenSearch Serverless vector store for you.\n", + "\n", + "> Knowledge base can be configured through [AWS Console](https://aws.amazon.com/console/) or by using [AWS SDKs](https://aws.amazon.com/developer/tools/).\n", + "\n", + "> In this notebook, we introduce AmazonKnowledgeBasesRetriever - Amazon Bedrock integration in Llama Index via the Retrieve API to retrieve relevant results for a user query from knowledge bases." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Using the Knowledge Base Retriever" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install --upgrade --quiet boto3\n", + "%pip install llama-index" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.retrievers.bedrock import AmazonKnowledgeBasesRetriever\n", + "\n", + "retriever = AmazonKnowledgeBasesRetriever(\n", + " knowledge_base_id=\"<knowledge-base-id>\",\n", + " retrieval_config={\n", + " \"vectorSearchConfiguration\": {\n", + " \"numberOfResults\": 4,\n", + " \"overrideSearchType\": \"HYBRID\",\n", + " # You will need to set up metadata.json file for your data source for filters to work\n", + " # For more info, see: https://aws.amazon.com/blogs/machine-learning/knowledge-bases-for-amazon-bedrock-now-supports-metadata-filtering-to-improve-retrieval-accuracy/\n", + " \"filter\": {\"equals\": {\"key\": \"tag\", \"value\": \"space\"}},\n", + " }\n", + " },\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "retrieved_results = retriever._retrieve(\n", + " \"How big is Milky Way as compared to the entire universe?\"\n", + ")\n", + "\n", + "print(retrieved_results[0].get_content())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The Milky Way is a large spiral galaxy, but in the grand scheme of the universe, it's relatively small. The observable universe is estimated to be about 93 billion light-years in diameter. In comparison, the Milky Way galaxy has a diameter of about 100,000 light-years. So, the Milky Way is just a tiny speck within the vastness of the observable universe. 
Keep in mind, however, that the universe may extend beyond the observable universe, but our ability to observe it is limited by the speed of light and the age of the universe." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/.gitignore b/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/.gitignore new file mode 100644 index 000000000..990c18de2 --- /dev/null +++ b/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/.gitignore @@ -0,0 +1,153 @@ +llama_index/_static +.DS_Store +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +bin/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +etc/ +include/ +lib/ +lib64/ +parts/ +sdist/ +share/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +.ruff_cache + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints +notebooks/ + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ +pyvenv.cfg + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# Jetbrains +.idea +modules/ +*.swp + +# VsCode +.vscode + +# pipenv +Pipfile +Pipfile.lock + +# pyright +pyrightconfig.json diff --git a/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/BUILD b/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/BUILD new file mode 100644 index 000000000..0896ca890 --- /dev/null +++ b/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/BUILD @@ -0,0 +1,3 @@ +poetry_requirements( + name="poetry", +) diff --git a/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/Makefile b/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/Makefile new file mode 100644 index 000000000..b9eab05aa --- /dev/null +++ b/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/Makefile @@ -0,0 +1,17 @@ +GIT_ROOT ?= $(shell git rev-parse --show-toplevel) + +help: ## Show all Makefile targets. + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}' + +format: ## Run code autoformatters (black). + pre-commit install + git ls-files | xargs pre-commit run black --files + +lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy + pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files + +test: ## Run tests via pytest. + pytest tests + +watch-docs: ## Build and watch documentation. 
+ sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/ diff --git a/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/README.md b/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/README.md new file mode 100644 index 000000000..3c85b689a --- /dev/null +++ b/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/README.md @@ -0,0 +1,19 @@ +# LlamaIndex Retrievers Integration: Bedrock + +## Knowledge Bases + +> [Knowledge bases for Amazon Bedrock](https://aws.amazon.com/bedrock/knowledge-bases/) is an Amazon Web Services (AWS) offering which lets you quickly build RAG applications by using your private data to customize FM response. + +> Implementing `RAG` requires organizations to perform several cumbersome steps to convert data into embeddings (vectors), store the embeddings in a specialized vector database, and build custom integrations into the database to search and retrieve text relevant to the user’s query. This can be time-consuming and inefficient. + +> With `Knowledge Bases for Amazon Bedrock`, simply point to the location of your data in `Amazon S3`, and `Knowledge Bases for Amazon Bedrock` takes care of the entire ingestion workflow into your vector database. If you do not have an existing vector database, Amazon Bedrock creates an Amazon OpenSearch Serverless vector store for you. + +> Knowledge base can be configured through [AWS Console](https://aws.amazon.com/console/) or by using [AWS SDKs](https://aws.amazon.com/developer/tools/). 
+ +### Notebook + +Explore the retriever using Notebook present at: + +``` +docs/docs/examples/retrievers/bedrock_retriever.ipynb +``` diff --git a/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/__init__.py b/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/llama_index/retrievers/bedrock/BUILD b/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/llama_index/retrievers/bedrock/BUILD new file mode 100644 index 000000000..db46e8d6c --- /dev/null +++ b/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/llama_index/retrievers/bedrock/BUILD @@ -0,0 +1 @@ +python_sources() diff --git a/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/llama_index/retrievers/bedrock/__init__.py b/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/llama_index/retrievers/bedrock/__init__.py new file mode 100644 index 000000000..111877409 --- /dev/null +++ b/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/llama_index/retrievers/bedrock/__init__.py @@ -0,0 +1,3 @@ +from llama_index.retrievers.bedrock.base import AmazonKnowledgeBasesRetriever + +__all__ = ["AmazonKnowledgeBasesRetriever"] diff --git a/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/llama_index/retrievers/bedrock/base.py b/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/llama_index/retrievers/bedrock/base.py new file mode 100644 index 000000000..5a7a01848 --- /dev/null +++ b/llama-index-integrations/retrievers/llama-index-retrievers-bedrock/llama_index/retrievers/bedrock/base.py @@ -0,0 +1,97 @@ +"""Bedrock Retriever.""" +from typing import List, Optional, Dict, Any + +from llama_index.core.base.base_retriever import BaseRetriever +from llama_index.core.callbacks.base import CallbackManager +from llama_index.core.schema import 
class AmazonKnowledgeBasesRetriever(BaseRetriever):
    """`Amazon Bedrock Knowledge Bases` retrieval.

    See https://aws.amazon.com/bedrock/knowledge-bases for more info.

    Args:
        knowledge_base_id: Knowledge Base ID.
        retrieval_config: Configuration for retrieval. Passed through as
            ``retrievalConfiguration`` to the Bedrock ``retrieve`` API when set;
            omitted from the request when ``None`` (botocore rejects ``None``
            for structured parameters).
        profile_name: The name of the profile in the ~/.aws/credentials
            or ~/.aws/config files, which has either access keys or role information
            specified. If not specified, the default credential profile or, if on an
            EC2 instance, credentials from IMDS will be used.
        region_name: The aws region e.g., `us-west-2`.
            Fallback to AWS_DEFAULT_REGION env variable or region specified in
            ~/.aws/config.
        aws_access_key_id: The aws access key id.
        aws_secret_access_key: The aws secret access key.
        aws_session_token: AWS temporary session token.

    Example:
        .. code-block:: python

            from llama_index.retrievers.bedrock import AmazonKnowledgeBasesRetriever

            retriever = AmazonKnowledgeBasesRetriever(
                knowledge_base_id="<knowledge-base-id>",
                retrieval_config={
                    "vectorSearchConfiguration": {
                        "numberOfResults": 4,
                        "overrideSearchType": "SEMANTIC",
                        "filter": {
                            "equals": {
                                "key": "tag",
                                "value": "space"
                            }
                        }
                    }
                },
            )
    """

    def __init__(
        self,
        knowledge_base_id: str,
        retrieval_config: Optional[Dict[str, Any]] = None,
        profile_name: Optional[str] = None,
        region_name: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        # bedrock-agent-runtime is the service that hosts the knowledge-base
        # Retrieve API (distinct from the "bedrock-runtime" inference service).
        self._client = get_aws_service_client(
            service_name="bedrock-agent-runtime",
            profile_name=profile_name,
            region_name=region_name,
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
        )
        self.knowledge_base_id = knowledge_base_id
        self.retrieval_config = retrieval_config
        super().__init__(callback_manager)

    def _retrieve(self, query: Any) -> List[NodeWithScore]:
        """Retrieve passages relevant to *query* from the knowledge base.

        Args:
            query: Either a plain query string, or a QueryBundle-like object
                exposing ``query_str`` — ``BaseRetriever.retrieve`` wraps a raw
                string in a QueryBundle before delegating here, so both forms
                must be accepted for the public ``retrieve()`` path to work.

        Returns:
            One ``NodeWithScore`` per Bedrock ``retrievalResults`` entry, with
            ``location`` and source ``metadata`` (when present) copied into the
            node metadata and a score of 0.0 when Bedrock returns none.
        """
        # Accept both a raw string (direct _retrieve calls, as in the docs
        # notebook) and the QueryBundle the BaseRetriever framework passes.
        query_text = query if isinstance(query, str) else query.query_str

        request: Dict[str, Any] = {
            "retrievalQuery": {"text": query_text.strip()},
            "knowledgeBaseId": self.knowledge_base_id,
        }
        # botocore raises ParamValidationError for retrievalConfiguration=None,
        # so only include the key when a config was actually provided.
        if self.retrieval_config is not None:
            request["retrievalConfiguration"] = self.retrieval_config

        response = self._client.retrieve(**request)

        node_with_score: List[NodeWithScore] = []
        for result in response["retrievalResults"]:
            metadata: Dict[str, Any] = {}
            if "location" in result:
                metadata["location"] = result["location"]
            if "metadata" in result:
                metadata["sourceMetadata"] = result["metadata"]
            node_with_score.append(
                NodeWithScore(
                    node=TextNode(
                        text=result["content"]["text"],
                        metadata=metadata,
                    ),
                    # "score" may be absent from a result; default to 0.0.
                    score=result.get("score", 0.0),
                )
            )
        return node_with_score
from unittest.mock import MagicMock, patch

from llama_index.core.schema import NodeWithScore, TextNode


@patch("llama_index.core.utilities.aws_utils.get_aws_service_client")
def test_retrieve(mock_get_aws_service_client):
    """_retrieve maps Bedrock retrievalResults onto NodeWithScore objects."""
    source_metadata = {
        "x-amz-bedrock-kb-source-uri": "s3://bucket/fileName",
        "key": "value",
    }
    fake_client = MagicMock()
    fake_client.retrieve.return_value = {
        "retrievalResults": [
            {
                "content": {"text": "This is a test result."},
                "location": "test_location",
                "metadata": source_metadata,
                "score": 0.8,
            },
            # Minimal result: no location, metadata, or score.
            {"content": {"text": "Another test result."}},
        ]
    }
    mock_get_aws_service_client.return_value = fake_client

    # Imported inside the test so the patched client factory is in effect
    # when the retriever module is first loaded.
    from llama_index.retrievers.bedrock import AmazonKnowledgeBasesRetriever

    retriever = AmazonKnowledgeBasesRetriever(
        knowledge_base_id="test-knowledge-base-id",
        retrieval_config={
            "vectorSearchConfiguration": {
                "numberOfResults": 2,
                "overrideSearchType": "SEMANTIC",
                "filter": {"equals": {"key": "tag", "value": "space"}},
            }
        },
    )
    retriever._client = fake_client

    # Call the method being tested.
    result = retriever._retrieve("Test query")

    expected = [
        NodeWithScore(
            node=TextNode(
                text="This is a test result.",
                metadata={
                    "location": "test_location",
                    "sourceMetadata": source_metadata,
                },
            ),
            score=0.8,
        ),
        NodeWithScore(
            node=TextNode(text="Another test result.", metadata={}),
            score=0.0,
        ),
    ]
    for got, want in zip(result, expected):
        assert got.node.text == want.node.text
        assert got.node.metadata == want.node.metadata
        assert got.score == want.score