Unverified commit ebbf1a7e authored by Ravi Theja, committed by GitHub

Add Corrective RAG LlamaPack (#10715)

* Add Corrective RAG LlamaPack

* Update with review suggestions

* Remove requirements.txt file

* Update BUILD

* Update examples

* Testing done
parent 448584c8
Showing with 5187 additions and 0 deletions
llama_index/_static
.DS_Store
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
bin/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
etc/
include/
lib/
lib64/
parts/
sdist/
share/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
.ruff_cache
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
notebooks/
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
pyvenv.cfg
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# Jetbrains
.idea
modules/
*.swp
# VsCode
.vscode
# pipenv
Pipfile
Pipfile.lock
# pyright
pyrightconfig.json
python_sources()
poetry_requirements(
    name="poetry",
)
GIT_ROOT ?= $(shell git rev-parse --show-toplevel)
help:	## Show all Makefile targets.
	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}'

format:	## Run code autoformatters (black).
	pre-commit install
	git ls-files | xargs pre-commit run black --files

lint:	## Run linters: pre-commit (black, ruff, codespell) and mypy.
	pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files

test:	## Run tests via pytest.
	pytest tests

watch-docs:	## Build and watch documentation.
	sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/
# Corrective Retrieval Augmented Generation Llama Pack
This LlamaPack implements Corrective Retrieval Augmented Generation (CRAG), as described in the [paper](https://arxiv.org/pdf/2401.15884.pdf).

Corrective Retrieval Augmented Generation (CRAG) is a method designed to enhance the robustness of language model generation by evaluating and augmenting the relevance of retrieved documents through an evaluator and large-scale web searches, ensuring more accurate and reliable information is used in generation.

This LlamaPack uses the [Tavily AI](https://app.tavily.com/home) API for web searches, so we recommend getting an API key before proceeding.
### Installation
```bash
pip install llama-index llama-index-tools-tavily-research
```
## CLI Usage
You can download llamapacks directly using `llamaindex-cli`, which comes installed with the `llama-index` Python package:
```bash
llamaindex-cli download-llamapack CorrectiveRAGPack --download-dir ./corrective_rag_pack
```
You can then inspect the files at `./corrective_rag_pack` and use them as a template for your own project.
## Code Usage
You can download the pack to the `./corrective_rag_pack` directory:
```python
from llama_index.core.llama_pack import download_llama_pack
# download and install dependencies
CorrectiveRAGPack = download_llama_pack(
    "CorrectiveRAGPack", "./corrective_rag_pack"
)
# You can use any llama-hub loader to get documents!
corrective_rag = CorrectiveRAGPack(documents, tavily_ai_api_key)
```
From here, you can use the pack, or inspect and modify the pack in `./corrective_rag_pack`.
The `run()` function contains the core logic behind the Corrective Retrieval Augmented Generation ([CRAG](https://arxiv.org/pdf/2401.15884.pdf)) paper.
```python
response = corrective_rag.run("<query>", similarity_top_k=2)
```
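The control flow inside `run()` can be sketched, framework-free, as follows. This is a minimal illustration only: `grade`, `rewrite_query`, and `web_search` are hypothetical stand-ins for the GPT-4 relevancy grader, the query-transformation LLM call, and the Tavily search used by the actual pack.

```python
def grade(chunk: str, query: str) -> str:
    # Hypothetical stand-in for the LLM relevancy grader ('yes' / 'no').
    return "yes" if "llama" in chunk.lower() else "no"


def rewrite_query(query: str) -> str:
    # Hypothetical stand-in for the query-transformation LLM call.
    return query.strip().rstrip("?") + " overview"


def web_search(query: str) -> str:
    # Hypothetical stand-in for the Tavily web search.
    return f"web results for: {query}"


def corrective_rag_run(chunks: list[str], query: str) -> str:
    # 1. Grade each retrieved chunk against the query.
    grades = [grade(c, query) for c in chunks]
    # 2. Keep only the chunks graded relevant.
    relevant = "\n".join(c for c, g in zip(chunks, grades) if g == "yes")
    # 3. Any irrelevant hit triggers a corrective web search with a rewritten query.
    search_text = ""
    if "no" in grades:
        search_text = web_search(rewrite_query(query))
    # 4. The final answer is synthesized over retrieved + searched text.
    return relevant + ("\n" + search_text if search_text else "")


context = corrective_rag_run(
    ["Llama2 was pretrained on 2T tokens.", "Unrelated text."],
    "How was Llama2 pretrained?",
)
```

In the real pack, step 4 builds a `SummaryIndex` over the combined text and queries it, rather than returning the raw context.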
python_sources()
%% Cell type:markdown id: tags:
# Corrective RAG Pack
This notebook walks through using the `CorrectiveRAGPack` based on the paper [Corrective Retrieval Augmented Generation](https://arxiv.org/abs/2401.15884).
A brief understanding of the paper:
Corrective Retrieval Augmented Generation (CRAG) is a method designed to enhance the robustness of language model generation by evaluating and augmenting the relevance of retrieved documents through an evaluator and large-scale web searches, ensuring more accurate and reliable information is used in generation.
We use `GPT-4` as the relevancy evaluator and `Tavily AI` for web searches, so we recommend getting an `OPENAI_API_KEY` and a `tavily_ai_api_key` before proceeding.
%% Cell type:markdown id: tags:
## Setup
%% Cell type:code id: tags:
``` python
%pip install llama-index-llms-openai llama-index-tools-tavily-research llama-index-embeddings-openai
%pip install llama-index-packs-corrective-rag
%pip install llama-index-readers-file
```
%% Cell type:code id: tags:
``` python
import os
os.environ["OPENAI_API_KEY"] = "YOUR OPENAI API KEY"
tavily_ai_api_key = "<tavily_ai_api_key>"
import nest_asyncio
nest_asyncio.apply()
```
%% Cell type:markdown id: tags:
## Download the `Llama2` paper.
%% Cell type:code id: tags:
``` python
!mkdir -p 'data/'
!curl 'https://arxiv.org/pdf/2307.09288.pdf' -o 'data/llama2.pdf'
```
%% Output
% Total % Received % Xferd Average Speed Time Time Time Current
Dload Upload Total Spent Left Speed
100 13.0M 100 13.0M 0 0 7397k 0 0:00:01 0:00:01 --:--:-- 7415k
%% Cell type:code id: tags:
``` python
from llama_index.core import SimpleDirectoryReader
documents = SimpleDirectoryReader("data").load_data()
```
%% Cell type:markdown id: tags:
## Run the CorrectiveRAGPack
%% Cell type:code id: tags:
``` python
from llama_index.packs.corrective_rag import CorrectiveRAGPack
corrective_rag_pack = CorrectiveRAGPack(documents, tavily_ai_apikey=tavily_ai_api_key)
```
%% Cell type:markdown id: tags:
## Test Queries
%% Cell type:code id: tags:
``` python
from IPython.display import Markdown, display
response = corrective_rag_pack.run("How was Llama2 pretrained?")
display(Markdown(str(response)))
```
%% Output
Llama2 was pretrained using an optimized auto-regressive transformer. Several changes were made to improve performance, including more robust data cleaning, updated data mixes, training on 40% more total tokens, doubling the context length, and using grouped-query attention (GQA) to improve inference scalability for larger models.
%% Cell type:code id: tags:
``` python
from IPython.display import Markdown, display
response = corrective_rag_pack.run(
"What is the functionality of latest ChatGPT memory."
)
display(Markdown(str(response)))
```
%% Output
The latest ChatGPT memory update allows the chatbot to carry what it learns between chats, enabling it to provide more relevant responses. Users can ask ChatGPT to remember specific information or let it pick up details itself. This memory feature allows ChatGPT to remember and forget things based on user input, enhancing user interaction with personalized continuity and emotional intelligence. Users can toggle this function on or off within the ChatGPT settings menu, and when disabled, ChatGPT will neither generate nor access memories.
# Required Environment Variables: OPENAI_API_KEY
# Required TavilyAI API KEY for web searches - https://tavily.com/
from llama_index.core import SimpleDirectoryReader
from llama_index.core.llama_pack import download_llama_pack
# download and install dependencies
CorrectiveRAGPack = download_llama_pack("CorrectiveRAGPack", "./corrective_rag_pack")
# load documents
documents = SimpleDirectoryReader("./data").load_data()
# create the pack with your documents and a Tavily AI API key
corrective_rag = CorrectiveRAGPack(documents, tavily_ai_apikey="<tavily_ai_apikey>")
# run the pack
response = corrective_rag.run("<Query>")
print(response)
python_sources()
from llama_index.packs.corrective_rag.base import CorrectiveRAGPack
__all__ = ["CorrectiveRAGPack"]
"""Corrective RAG LlamaPack class."""
from typing import Any, Dict, List
from llama_index.core import VectorStoreIndex, SummaryIndex
from llama_index.core.llama_pack.base import BaseLlamaPack
from llama_index.llms.openai import OpenAI
from llama_index.core.schema import Document, NodeWithScore
from llama_index.core.query_pipeline.query import QueryPipeline
from llama_index.tools.tavily_research.base import TavilyToolSpec
from llama_index.core.prompts import PromptTemplate
DEFAULT_RELEVANCY_PROMPT_TEMPLATE = PromptTemplate(
    template="""As a grader, your task is to evaluate the relevance of a document retrieved in response to a user's question.

    Retrieved Document:
    -------------------
    {context_str}

    User Question:
    --------------
    {query_str}

    Evaluation Criteria:
    - Consider whether the document contains keywords or topics related to the user's question.
    - The evaluation should not be overly stringent; the primary objective is to identify and filter out clearly irrelevant retrievals.

    Decision:
    - Assign a binary score to indicate the document's relevance.
    - Use 'yes' if the document is relevant to the question, or 'no' if it is not.

    Please provide your binary score ('yes' or 'no') below to indicate the document's relevance to the user question."""
)

DEFAULT_TRANSFORM_QUERY_TEMPLATE = PromptTemplate(
    template="""Your task is to refine a query to ensure it is highly effective for retrieving relevant search results. \n
    Analyze the given input to grasp the core semantic intent or meaning. \n
    Original Query:
    \n ------- \n
    {query_str}
    \n ------- \n
    Your goal is to rephrase or enhance this query to improve its search performance. Ensure the revised query is concise and directly aligned with the intended search objective. \n
    Respond with the optimized query only:"""
)
class CorrectiveRAGPack(BaseLlamaPack):
    def __init__(self, documents: List[Document], tavily_ai_apikey: str) -> None:
        """Init params."""
        llm = OpenAI(model="gpt-4")
        self.relevancy_pipeline = QueryPipeline(
            chain=[DEFAULT_RELEVANCY_PROMPT_TEMPLATE, llm]
        )
        self.transform_query_pipeline = QueryPipeline(
            chain=[DEFAULT_TRANSFORM_QUERY_TEMPLATE, llm]
        )
        self.llm = llm
        self.index = VectorStoreIndex.from_documents(documents)
        self.tavily_tool = TavilyToolSpec(api_key=tavily_ai_apikey)

    def get_modules(self) -> Dict[str, Any]:
        """Get modules."""
        return {"llm": self.llm, "index": self.index}

    def retrieve_nodes(self, query_str: str, **kwargs: Any) -> List[NodeWithScore]:
        """Retrieve the relevant nodes for the query."""
        retriever = self.index.as_retriever(**kwargs)
        return retriever.retrieve(query_str)

    def evaluate_relevancy(
        self, retrieved_nodes: List[NodeWithScore], query_str: str
    ) -> List[str]:
        """Evaluate the relevancy of each retrieved node with respect to the query."""
        relevancy_results = []
        for node in retrieved_nodes:
            relevancy = self.relevancy_pipeline.run(
                context_str=node.text, query_str=query_str
            )
            relevancy_results.append(relevancy.message.content.lower().strip())
        return relevancy_results

    def extract_relevant_texts(
        self, retrieved_nodes: List[NodeWithScore], relevancy_results: List[str]
    ) -> str:
        """Extract relevant texts from the retrieved nodes."""
        relevant_texts = [
            retrieved_nodes[i].text
            for i, result in enumerate(relevancy_results)
            if result == "yes"
        ]
        return "\n".join(relevant_texts)

    def search_with_transformed_query(self, query_str: str) -> str:
        """Search the transformed query with the Tavily API."""
        search_results = self.tavily_tool.search(query_str, max_results=5)
        return "\n".join([result.text for result in search_results])

    def get_result(self, relevant_text: str, search_text: str, query_str: str) -> Any:
        """Synthesize the final result over the relevant and searched text."""
        documents = [Document(text=relevant_text + "\n" + search_text)]
        index = SummaryIndex.from_documents(documents)
        query_engine = index.as_query_engine()
        return query_engine.query(query_str)

    def run(self, query_str: str, **kwargs: Any) -> Any:
        """Run the pipeline."""
        # Retrieve nodes based on the input query string.
        retrieved_nodes = self.retrieve_nodes(query_str, **kwargs)

        # Evaluate the relevancy of each retrieved node in relation to the query string.
        relevancy_results = self.evaluate_relevancy(retrieved_nodes, query_str)

        # Extract texts from nodes that are deemed relevant based on the evaluation.
        relevant_text = self.extract_relevant_texts(retrieved_nodes, relevancy_results)

        # Initialize search_text so it is defined even when no web search is run.
        search_text = ""

        # If any node is found irrelevant, transform the query string for better search results.
        if "no" in relevancy_results:
            transformed_query_str = self.transform_query_pipeline.run(
                query_str=query_str
            ).message.content
            # Conduct a search with the transformed query string and collect the results.
            search_text = self.search_with_transformed_query(transformed_query_str)

        # Compile the final result: include the web-search text when present;
        # otherwise, answer from the relevant retrieved text alone.
        return self.get_result(relevant_text, search_text, query_str)
[build-system]
build-backend = "poetry.core.masonry.api"
requires = ["poetry-core"]
[tool.codespell]
check-filenames = true
check-hidden = true
# Feel free to un-skip examples, and experimental, you will just need to
# work through many typos (--write-changes and --interactive will help)
skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb"
[tool.llamahub]
classes = ["CorrectiveRAGPack"]
contains_example = false
import_path = "llama_index.packs.corrective_rag"
[tool.mypy]
disallow_untyped_defs = true
# Remove venv skip when integrated with pre-commit
exclude = ["_static", "build", "examples", "notebooks", "venv"]
ignore_missing_imports = true
python_version = "3.8"
[tool.poetry]
authors = ["Ravi Theja <ravi03071991@gmail.com>"]
description = "llama-index packs corrective_rag paper implementation"
keywords = ["corrective", "corrective_rag", "crag", "rag", "retrieve"]
license = "MIT"
maintainers = ["ravi-theja"]
name = "llama-index-packs-corrective-rag"
packages = [{include = "llama_index/"}]
readme = "README.md"
version = "0.1.0"
[tool.poetry.dependencies]
python = ">=3.8.1,<3.12"
llama-index-core = "^0.10.0"
tavily-python = "^0.3.1"
[tool.poetry.group.dev.dependencies]
black = {extras = ["jupyter"], version = "<=23.9.1,>=23.7.0"}
codespell = {extras = ["toml"], version = ">=v2.2.6"}
ipython = "8.10.0"
jupyter = "^1.0.0"
mypy = "0.991"
pre-commit = "3.2.0"
pylint = "2.15.10"
pytest = "7.2.1"
pytest-mock = "3.11.1"
ruff = "0.0.292"
tree-sitter-languages = "^1.8.0"
types-Deprecated = ">=0.1.0"
types-PyYAML = "^6.0.12.12"
types-protobuf = "^4.24.0.4"
types-redis = "4.5.5.0"
types-requests = "2.28.11.8" # TODO: unpin when mypy>0.991
types-setuptools = "67.1.0.0"
python_tests()
from llama_index.core.llama_pack import BaseLlamaPack
from llama_index.packs.corrective_rag import CorrectiveRAGPack
def test_class():
    names_of_base_classes = [b.__name__ for b in CorrectiveRAGPack.__mro__]
    assert BaseLlamaPack.__name__ in names_of_base_classes