Unverified commit d740d9d8 authored by Ravi Theja, committed by GitHub

Add RAFT llamapack (#12084)

* Add RAFT llamapack

* Update readme and usage

* resolve errors

* resolve errors

* Updated node parser based on raft repository

* Add examples
parent be63bae5
Showing 655 additions and 0 deletions
llama_index/_static
.DS_Store
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
bin/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
etc/
include/
lib/
lib64/
parts/
sdist/
share/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
.ruff_cache
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
notebooks/
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
pyvenv.cfg
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# Jetbrains
.idea
modules/
*.swp
# VsCode
.vscode
# pipenv
Pipfile
Pipfile.lock
# pyright
pyrightconfig.json
poetry_requirements(
name="poetry",
)
GIT_ROOT ?= $(shell git rev-parse --show-toplevel)
help: ## Show all Makefile targets.
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}'
format: ## Run code autoformatters (black).
pre-commit install
git ls-files | xargs pre-commit run black --files
lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy
pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files
test: ## Run tests via pytest.
pytest tests
watch-docs: ## Build and watch documentation.
sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/
# RAFT: Adapting Language Model to Domain Specific RAG Llama Pack
This LlamaPack implements the paper [RAFT: Adapting Language Model to Domain Specific RAG](https://arxiv.org/abs/2403.10131).

Retrieval Augmented FineTuning (RAFT) is a training recipe introduced in the paper to improve the performance of large language models (LLMs) on open-book, in-domain question-answering tasks. Given a question and a set of retrieved documents, RAFT trains the LLM to identify and cite verbatim the most relevant sequences from the documents that help answer the question, while ignoring irrelevant or distracting information. By explicitly training the model to distinguish relevant from irrelevant information and to cite evidence from the relevant documents, RAFT encourages the LLM to develop better reasoning and explanation abilities, ultimately improving its ability to answer questions accurately in scenarios where additional context is available.
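Concretely, each row of the generated dataset pairs a question with its oracle chunk, a set of distractor chunks, and a chain-of-thought answer that quotes evidence verbatim. An illustrative datapoint (the field names match the pack's output schema in `base.py`; the values are made up):

```python
# Illustrative RAFT datapoint; values are placeholders, field names are real.
datapoint = {
    "id": "seed_task_0",
    "type": "general",
    "question": "What did the author work on before college?",
    # one oracle chunk plus `num_distract_docs` distractor chunks, shuffled
    "context": {
        "title": [["placeholder_title"] * 4],
        "sentences": [["<oracle chunk>", "<distractor 1>", "<distractor 2>", "<distractor 3>"]],
    },
    "oracle_context": "<oracle chunk>",
    # step-by-step reasoning quoting the context, ending with the final answer
    "cot_answer": "##begin_quote## ... ##end_quote## ... <ANSWER>: writing and programming",
    # all documents wrapped in <DOCUMENT> tags, followed by the question
    "instruction": "<DOCUMENT>...</DOCUMENT>\n<DOCUMENT>...</DOCUMENT>\nWhat did the author work on before college?",
}
```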
### Installation
```bash
pip install llama-index
```
## CLI Usage
You can download llamapacks directly using `llamaindex-cli`, which comes installed with the `llama-index` python package:
```bash
llamaindex-cli download-llamapack RAFTDatasetPack --download-dir ./raft_dataset_pack
```
You can then inspect the files at `./raft_dataset_pack` and use them as a template for your own project.
## Code Usage
You can download the pack to the `./raft_dataset_pack` directory:
```python
from llama_index.core.llama_pack import download_llama_pack
# download and install dependencies
RAFTDatasetPack = download_llama_pack("RAFTDatasetPack", "./raft_dataset_pack")
# You can use any llama-hub loader to get documents!
raft_dataset = RAFTDatasetPack(file_path)
```
From here, you can use the pack, or inspect and modify the pack in `./raft_dataset_pack`.
The `run()` function contains the core logic behind the [RAFT: Adapting Language Model to Domain Specific RAG](https://arxiv.org/abs/2403.10131) paper:
```python
dataset = raft_dataset.run()
```
This returns a dataset that can then be used for fine-tuning. Please refer to the [original blog](https://techcommunity.microsoft.com/t5/ai-ai-platform-blog/raft-a-new-way-to-teach-llms-to-be-better-at-rag/ba-p/4084674) for details on using the dataset for fine-tuning.
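The pack's constructor also exposes the knobs used internally (see `base.py` below), so you can swap in your own LLM or embedding model. A minimal sketch, assuming the default OpenAI integrations are installed and `OPENAI_API_KEY` is set:

```python
from llama_index.llms.openai import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding

# All keyword arguments below exist on RAFTDatasetPack; the values are illustrative.
raft_dataset = RAFTDatasetPack(
    file_path="./paul_graham_essay.txt",  # hypothetical local file
    llm=OpenAI(temperature=0, n=1, model="gpt-4"),  # the pack's default LLM
    embed_model=OpenAIEmbedding(),  # drives the semantic chunking
    num_questions_per_chunk=5,  # questions generated per chunk
    num_distract_docs=3,  # distractor chunks mixed into each datapoint
)
dataset = raft_dataset.run()
```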
%% Cell type:markdown id: tags:
# RAFT Dataset LlamaPack
This LlamaPack implements the paper [RAFT: Adapting Language Model to Domain Specific RAG](https://arxiv.org/abs/2403.10131).

Retrieval Augmented FineTuning (RAFT) is a training recipe introduced in the paper to improve the performance of large language models (LLMs) on open-book, in-domain question-answering tasks. Given a question and a set of retrieved documents, RAFT trains the LLM to identify and cite verbatim the most relevant sequences from the documents that help answer the question, while ignoring irrelevant or distracting information. By explicitly training the model to distinguish relevant from irrelevant information and to cite evidence from the relevant documents, RAFT encourages the LLM to develop better reasoning and explanation abilities, ultimately improving its ability to answer questions accurately in scenarios where additional context is available.
%% Cell type:markdown id: tags:
#### Installation
%% Cell type:code id: tags:
``` python
!pip install llama-index
```
%% Cell type:code id: tags:
``` python
import os
os.environ["OPENAI_API_KEY"] = "sk-"
```
%% Cell type:markdown id: tags:
#### Download Data
%% Cell type:code id: tags:
``` python
!wget --user-agent "Mozilla" "https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/paul_graham/paul_graham_essay.txt" -O './paul_graham_essay.txt'
```
%% Cell type:code id: tags:
``` python
from llama_index.packs.raft_dataset import RAFTDatasetPack
```
%% Cell type:code id: tags:
``` python
raft_dataset = RAFTDatasetPack("./paul_graham_essay.txt")
```
%% Cell type:code id: tags:
``` python
# Beware of the costs involved. This will use GPT-4 for dataset creation.
# It will also take a long time, depending on the file size.
dataset = raft_dataset.run()
```
%% Cell type:markdown id: tags:
The above dataset is in HuggingFace `Dataset` format. You can save it in `.arrow` or `.jsonl` format and use it for further fine-tuning.
%% Cell type:code id: tags:
``` python
output_path = "<output path>"
# Save as .arrow format
dataset.save_to_disk(output_path)
# Save as .jsonl format
dataset.to_json(output_path + ".jsonl")
```
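%% Cell type:markdown id: tags:
You can later reload the saved dataset for fine-tuning. A minimal sketch using the `datasets` library, assuming the paths above:
%% Cell type:code id: tags:
``` python
from datasets import Dataset, load_from_disk

# Reload the `.arrow` dataset written by `save_to_disk()`
dataset = load_from_disk(output_path)

# Or rebuild it from the `.jsonl` export
dataset = Dataset.from_json(output_path + ".jsonl")
```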
%% Cell type:markdown id: tags:
#### You can refer to the original implementation [here](https://github.com/ShishirPatil/gorilla/tree/main/raft)
python_sources()
from llama_index.packs.raft_dataset.base import RAFTDatasetPack
__all__ = ["RAFTDatasetPack"]
"""RAFT Dataset LlamaPack class."""
# Inspired from https://github.com/ShishirPatil/gorilla/tree/main/raft
from typing import Any, List
import random
from datasets import Dataset
from llama_index.core.llama_pack.base import BaseLlamaPack
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SemanticSplitterNodeParser
from llama_index.core.llms import ChatMessage
from llama_index.llms.openai import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding
DEFAULT_CHUNK_SIZE = 512
DEFAULT_BREAKPOINT_PERCENTILE_THRESHOLD = 95
class RAFTDatasetPack(BaseLlamaPack):
"""RAFT Dataset Generator pack."""
def __init__(
self,
file_path: str,
llm: Any = None,
embed_model: Any = None,
num_questions_per_chunk: int = 5,
num_distract_docs: int = 3,
chunk_size: int = DEFAULT_CHUNK_SIZE,
default_breakpoint_percentile_threshold=DEFAULT_BREAKPOINT_PERCENTILE_THRESHOLD,
**kwargs: Any,
):
self.file_path = file_path
self.num_questions_per_chunk = num_questions_per_chunk
self.num_distract_docs = num_distract_docs
self.chunk_size = chunk_size
self.default_breakpoint_percentile_threshold = (
default_breakpoint_percentile_threshold
)
self.ds = None
self.llm = OpenAI(temperature=0, n=1, model="gpt-4") if llm is None else llm
self.embed_model = OpenAIEmbedding() if embed_model is None else embed_model
    def strip_str(self, s) -> str:
        """
        Helper function for cleaning up strings returned by the LLM.

        Drops any leading 'assistant:' role prefix, then trims the string to
        start at its first alphabetic character and end just after its last.
        """
        if s.startswith("assistant:"):  # chat responses may carry a role prefix
            s = s.replace("assistant:", "", 1)  # drop the first occurrence only
start_index, end_index = 0, len(s) - 1
beg_found = False
for i in range(len(s)):
if s[i].isalpha():
if not beg_found:
start_index = i
beg_found = True
else:
end_index = i
end_index += 2
return s[start_index : min(end_index, len(s))]
def encode_question_gen(self, question, chunk) -> List[str]:
"""
Encode multiple prompt instructions into a single string for the general case.
"""
prompt = f"""
Question: {question}\nContext: {chunk}\n
        Answer this question using the information given in the context above. Here are some things to pay attention to:
        - First provide step-by-step reasoning on how to answer the question.
        - In the reasoning, if you need to copy-paste some sentences from the context, include them in ##begin_quote## and ##end_quote##. This would mean that things outside of ##begin_quote## and ##end_quote## are not directly copy-pasted from the context.
        - End your response with the final answer in the form <ANSWER>: $answer; the answer should be succinct.
"""
return [
ChatMessage(
role="system",
content="You are a helpful question answerer who can provide an answer given a question and relevant context.",
),
ChatMessage(role="user", content=prompt),
]
    def generate_label(self, question, context) -> str:
        """
        Generates the label / answer to `question` using `context` and the
        configured LLM (GPT-4 by default).
        """
question_messages = self.encode_question_gen(question, context)
response = self.llm.chat(question_messages)
return str(response)
def generate_instructions_gen(self, chunk, x=5) -> List[str]:
"""
Generates `x` questions / use cases for `chunk`. Used when the input document is of general types
`pdf`, `json`, or `txt`.
"""
messages = [
ChatMessage(
role="system",
content="You are a synthetic question-answer pair generator. Given a chunk of context about some topic(s), generate %s example questions a user could ask and would be answered using information from the chunk. For example, if the given context was a Wikipedia paragraph about the United States, an example question could be 'How many states are in the United States?'"
% (x),
),
ChatMessage(
role="system",
content="The questions should be able to be answered in a few words or less.",
),
ChatMessage(role="user", content=str(chunk)),
]
queries = str(self.llm.chat(messages)).split("\n")
queries = [self.strip_str(q) for q in queries]
return [q for q in queries if any(c.isalpha() for c in q)]
    def get_chunks(self, file_path: str, chunk_size: int) -> List[str]:
        """
        Takes in a `file_path`, retrieves the document, splits it into
        semantically coherent chunks, and returns their contents.

        Note: `chunk_size` is currently unused; chunk boundaries are chosen
        by the semantic splitter.
        """
        documents = SimpleDirectoryReader(input_files=[file_path]).load_data()
        splitter = SemanticSplitterNodeParser(
            buffer_size=1,
            breakpoint_percentile_threshold=self.default_breakpoint_percentile_threshold,
            embed_model=self.embed_model,
        )
nodes = splitter.get_nodes_from_documents(documents)
return [node.get_content() for node in nodes]
def add_chunk_to_dataset(
self,
chunks: List,
chunk: str,
x: int = 5,
num_distract: int = 3,
p: float = 1.0,
):
"""
Given a chunk, create {Q, A, D} triplets and add them to the dataset.
"""
i = chunks.index(chunk)
qs = self.generate_instructions_gen(chunk, x)
for q in qs:
datapt = {
"id": None,
"type": None,
"question": None,
"context": None,
"oracle_context": None,
"cot_answer": None,
}
datapt["id"] = f"seed_task_{0 if not self.ds else self.ds.num_rows}"
datapt["type"] = "general"
datapt["question"] = q
            # add `num_distract` distractor docs
docs = [chunk]
indices = list(range(len(chunks)))
indices.remove(i)
for j in random.sample(indices, num_distract):
docs.append(chunks[j])
            # with probability `p` keep the oracle document in the context;
            # otherwise replace it with a random distractor chunk
oracle = random.uniform(0, 1) < p
if not oracle:
docs[0] = chunks[random.sample(indices, 1)[0]]
random.shuffle(docs)
d = {"title": [], "sentences": []}
d["title"].append(["placeholder_title"] * (num_distract + 1))
d["sentences"].append(docs)
datapt["context"] = d
datapt["oracle_context"] = chunk
# add answer to q
datapt["cot_answer"] = self.generate_label(q, chunk)
# construct model instruction
context = ""
for doc in docs:
context += "<DOCUMENT>" + str(doc) + "</DOCUMENT>\n"
context += q
datapt["instruction"] = context
# add to dataset
if not self.ds:
# init ds
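                # Dataset.from_dict expects columnar lists, so wrap each
                # scalar field in a single-element list for the first row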
datapt["id"] = [datapt["id"]]
datapt["type"] = [datapt["type"]]
datapt["question"] = [datapt["question"]]
datapt["context"] = [datapt["context"]]
datapt["oracle_context"] = [datapt["oracle_context"]]
datapt["cot_answer"] = [datapt["cot_answer"]]
datapt["instruction"] = [datapt["instruction"]]
self.ds = Dataset.from_dict(datapt)
else:
self.ds = self.ds.add_item(datapt)
def run(self, *args: Any, **kwargs: Any) -> Any:
"""Run the pipeline."""
chunks = self.get_chunks(self.file_path, self.chunk_size)
        # distractors are sampled from the remaining chunks, so there can be
        # at most len(chunks) - 1 of them
        self.num_distract_docs = min(self.num_distract_docs, len(chunks) - 1)
for chunk in chunks:
self.add_chunk_to_dataset(
chunks, chunk, self.num_questions_per_chunk, self.num_distract_docs
)
return self.ds
[build-system]
build-backend = "poetry.core.masonry.api"
requires = ["poetry-core"]
[tool.codespell]
check-filenames = true
check-hidden = true
skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb"
[tool.llamahub]
contains_example = true
import_path = "llama_index.packs.raft_dataset"
[tool.llamahub.class_authors]
RAFTDatasetPack = "llama-index"
[tool.mypy]
disallow_untyped_defs = true
exclude = ["_static", "build", "examples", "notebooks", "venv"]
ignore_missing_imports = true
python_version = "3.8"
[tool.poetry]
authors = ["Ravi Theja <ravi03071991@gmail.com>"]
description = "llama-index packs RAFT Dataset paper implementation"
exclude = ["**/BUILD"]
keywords = ["finetuning", "raft", "raft_dataset"]
license = "MIT"
maintainers = ["ravi-theja"]
name = "llama-index-packs-raft-dataset"
readme = "README.md"
version = "0.1.1"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
llama-index-core = "^0.10.0"
datasets = "^2.18.0"
[tool.poetry.group.dev.dependencies]
ipython = "8.10.0"
jupyter = "^1.0.0"
mypy = "0.991"
pre-commit = "3.2.0"
pylint = "2.15.10"
pytest = "7.2.1"
pytest-mock = "3.11.1"
ruff = "0.0.292"
tree-sitter-languages = "^1.8.0"
types-Deprecated = ">=0.1.0"
types-PyYAML = "^6.0.12.12"
types-protobuf = "^4.24.0.4"
types-redis = "4.5.5.0"
types-requests = "2.28.11.8"
types-setuptools = "67.1.0.0"
[tool.poetry.group.dev.dependencies.black]
extras = ["jupyter"]
version = "<=23.9.1,>=23.7.0"
[tool.poetry.group.dev.dependencies.codespell]
extras = ["toml"]
version = ">=v2.2.6"
[[tool.poetry.packages]]
include = "llama_index/"
python_tests()
from llama_index.core.llama_pack import BaseLlamaPack
from llama_index.packs.raft_dataset import RAFTDatasetPack
def test_class():
names_of_base_classes = [b.__name__ for b in RAFTDatasetPack.__mro__]
assert BaseLlamaPack.__name__ in names_of_base_classes