Unverified commit d18ecc7b, authored by DJC, committed by GitHub

SearChain package reproduced using Llama_index library (#11649)

llama_index/_static
.DS_Store
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
bin/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
etc/
include/
lib/
lib64/
parts/
sdist/
share/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
.ruff_cache
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
notebooks/
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
pyvenv.cfg
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# Jetbrains
.idea
modules/
*.swp
# VsCode
.vscode
# pipenv
Pipfile
Pipfile.lock
# pyright
pyrightconfig.json
poetry_requirements(
name="poetry",
)
GIT_ROOT ?= $(shell git rev-parse --show-toplevel)
help: ## Show all Makefile targets.
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}'
format: ## Run code autoformatters (black).
pre-commit install
git ls-files | xargs pre-commit run black --files
lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy
pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files
test: ## Run tests via pytest.
pytest tests
watch-docs: ## Build and watch documentation.
sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/
# LlamaIndex Packs Integration: SearChain
This LlamaPack implements SearChain, a framework that models the interaction between an LLM and information retrieval (IR) as a global reasoning chain called a Chain-of-Query (CoQ).
This follows the idea in the paper [Search-in-the-Chain: Towards Accurate, Credible and Traceable Large Language Models for Knowledge-intensive Tasks](https://arxiv.org/abs/2304.14732).
Making the content generated by large language models (LLMs) such as ChatGPT accurate, credible, and traceable is critical, especially for knowledge-intensive tasks. Introducing information retrieval (IR) to supply the LLM with external knowledge is a promising way to solve this problem; the challenge is deciding where and how to introduce IR. SearChain addresses this in three steps. First, the LLM generates a global reasoning chain called a Chain-of-Query (CoQ), where each node contains an IR-oriented query and the answer to that query. Second, IR verifies the answer at each node of the CoQ and, when it has high confidence, corrects answers that are inconsistent with the retrieved information, which improves credibility. Third, the LLM can mark its missing knowledge in the CoQ, and IR provides this knowledge to the LLM. Together, these operations improve the accuracy of the LLM on complex knowledge-intensive tasks in terms of both reasoning ability and knowledge. This Pack implements the above 🤗!
You can see its use case in the examples folder.
This implementation is adapted from the author's implementation. You can find the official code repository [here](https://github.com/xsc1234/Search-in-the-Chain).
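To make the CoQ format concrete, here is a minimal, purely illustrative sketch of how such a chain might be split into query-answer nodes for verification. The `coq` string and the `parse_coq` helper below are hypothetical and are not part of the pack's API; in the pack itself, this parsing and the verification against retrieved documents happen inside `SearChainPack._ir`.

```python
# Purely illustrative: a toy Chain-of-Query (CoQ) string in the node format described above.
coq = (
    "[Query 1]: Who is the performer of Spirit If... ?\n"
    "[Answer 1]: The performer of Spirit If... is Kevin Drew.\n"
    "[Unsolved Query]: Where was Kevin Drew born?"
)


def parse_coq(text):
    """Split a CoQ string into nodes; IR verifies each node's answer against retrieved evidence."""
    nodes = []
    for line in text.split("\n"):
        if "]:" not in line:
            continue
        tag, value = line.split("]:", 1)
        if tag.startswith("[Query"):
            nodes.append({"query": value.strip(), "answer": None, "unsolved": False})
        elif tag.startswith("[Answer") and nodes:
            nodes[-1]["answer"] = value.strip()
        elif tag == "[Unsolved Query":
            # The LLM marks knowledge it lacks; IR retrieves it and feeds it back as a new node.
            nodes.append({"query": value.strip(), "answer": None, "unsolved": True})
    return nodes


print(parse_coq(coq))
```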
## Code Usage
First, install `SearChainPack` using the following code:
```python
from llama_index.core.llama_pack import download_llama_pack
download_llama_pack("SearChainPack", "./searchain_pack")
```
Next, load and initialize a SearChain object:
```python
from searchain_pack.base import SearChainPack
searchain = SearChainPack(
data_path="data",
dprtokenizer_path="dpr_reader_multi",
dprmodel_path="dpr_reader_multi",
crossencoder_name_or_path="Quora_cross_encoder",
)
```
The relevant data can be found [here](https://www.kaggle.com/datasets/anastasiajia/searchain/data). You can then run SearChain as follows:
```python
start_idx = 0
while start_idx != -1:
    start_idx = searchain.execute(
        "/hotpotqa/hotpot_dev_fullwiki_v1_line.json", start_idx=start_idx
    )
```
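`execute` reads the dataset as JSON lines starting from `start_idx`, runs up to five rounds of LLM-IR interaction per question, prints the extracted `[Final Content]` answer for each example, and returns `-1` once the file is exhausted, which terminates the loop above.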
%% Cell type:markdown id:cef7fbcf-dcc4-4986-998f-5bd6c058b348 tags:
# An Example of a SearChain Application
%% Cell type:markdown id:ae02aac8-ddc3-481f-9b41-a3c52f2ad9b5 tags:
This LlamaPack implements a short form of the [SearChain paper by Xu et al.](https://arxiv.org/abs/2304.14732)
This implementation is adapted from the author's implementation. You can find the official code repository [here](https://github.com/xsc1234/Search-in-the-Chain).
%% Cell type:markdown id:d500e8af-c685-4f11-b176-dd534c7824e5 tags:
## Load Pack
%% Cell type:code id:a111a49c-e9c2-4a19-96dc-136fc820bbfb tags:
``` python
from llama_index.core.llama_pack import download_llama_pack
download_llama_pack("SearChainPack", "./searchain_pack")
from searchain_pack.base import SearChainPack
```
%% Cell type:markdown id:37739fc7-6df9-44e3-ac46-22471565af36 tags:
## Setup
%% Cell type:code id:d113b360-739e-4f29-b0db-273ba2d65e2a tags:
``` python
searchain = SearChainPack(
data_path="data",
dprtokenizer_path="./model/dpr_reader_multi",
dprmodel_path="./model/dpr_reader_multi",
crossencoder_name_or_path="./model/Quora_cross_encoder",
)
```
%% Cell type:markdown id:7a531d37-6832-40ab-b579-8dc007e1a1e2 tags:
## Execute
%% Cell type:code id:4dd90f1c-720a-4ef0-9a35-d1b54be8cd53 tags:
``` python
start_idx = 0
while start_idx != -1:
    start_idx = searchain.execute(
        "/hotpotqa/hotpot_dev_fullwiki_v1_line.json", start_idx=start_idx
    )
```
python_sources()
from llama_index.packs.searchain.base import SearChainPack
__all__ = ["SearChainPack"]
import regex
import string
import json
from llama_index.core.llama_pack.base import BaseLlamaPack
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
from llama_index.llms.openai import OpenAI
from llama_index.core.llms import ChatMessage
from transformers import DPRReader, DPRReaderTokenizer
import torch
from typing import Any
from sentence_transformers import CrossEncoder
import time
def _normalize_answer(s):
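    # SQuAD-style answer normalization: lowercase, strip punctuation and articles, collapse whitespace.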
def remove_articles(text):
return regex.sub(r"\b(a|an|the)\b", " ", text)
def white_space_fix(text):
return " ".join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return "".join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def _match_or_not(prediction, ground_truth):
norm_predict = _normalize_answer(prediction)
norm_answer = _normalize_answer(ground_truth)
return norm_answer in norm_predict
def _have_seen_or_not(model_cross_encoder, query_item, query_seen_list, query_type):
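    # A query counts as "seen" when the cross-encoder scores it as a near-duplicate of an earlier query;
    # unsolved queries are never skipped.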
if "Unsolved" in query_type:
return False
for query_seen in query_seen_list:
with torch.no_grad():
if model_cross_encoder.predict([(query_seen, query_item)]) > 0.5:
return True
return False
class SearChainPack(BaseLlamaPack):
"""Simple short form SearChain pack."""
def __init__(
self,
data_path: str,
dprtokenizer_path: str = "/dpr_reader_multi",
dprmodel_path: str = "/dpr_reader_multi",
crossencoder_name_or_path: str = "/Quora_cross_encoder",
device: str = "cuda",
**kwargs: Any,
) -> None:
"""Init params."""
self.device = device
self.crossencoder = CrossEncoder(crossencoder_name_or_path, device=self.device)
self.documents = SimpleDirectoryReader(data_path).load_data()
self.index = VectorStoreIndex.from_documents(self.documents)
self.query_engine = self.index.as_query_engine()
self.llm = OpenAI()
self.dprtokenizer = DPRReaderTokenizer.from_pretrained(dprtokenizer_path)
self.dprmodel = DPRReader.from_pretrained(dprmodel_path)
self.dprmodel.eval()
self.dprmodel.to(self.device)
def _get_answer(self, query, texts, title):
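        # Use the DPR reader to extract an answer span from the retrieved passage and score its relevance.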
print("texts:" + texts)
encoded_inputs = self.dprtokenizer(
questions=[query],
titles=[title],
texts=[texts],
return_tensors="pt",
max_length=510,
)
outputs = self.dprmodel(**encoded_inputs.to(self.device))
start_logits = outputs.start_logits
end_logits = outputs.end_logits
relevance_logits = outputs.relevance_logits
answer_start_index = outputs.start_logits.argmax()
answer_end_index = outputs.end_logits.argmax()
predict_answer_tokens = encoded_inputs.input_ids[
0, answer_start_index : answer_end_index + 1
]
answer = self.dprtokenizer.decode(predict_answer_tokens)
return answer, relevance_logits
def _ir(self, query, query_seen_list):
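        # Parse the LLM-generated Chain-of-Query, verify each unseen query against retrieval, and build a
        # "<SEP>"-separated feedback message when a correction or missing knowledge is found.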
flag_contibue_label = False
query_list = query.split("\n")
message = ""
for idx in range(len(query_list)):
query_item = query_list[idx]
if "Query" in query_item and "]:" in query_item:
temp = query_item.split("]")
if len(temp) < 2:
continue
query_type = temp[0]
query_item = temp[1]
if ":" in query_item:
query_item = query_item[1:]
if not _have_seen_or_not(
self.crossencoder, query_item, query_seen_list, query_type
):
now_reference = {}
query_seen_list.append(query_item)
response = str(self.query_engine.query(query_item))
answer, relevance_score = self._get_answer(
query=query_item, texts="", title=response
)
now_reference["query"] = query_item
now_reference["answer"] = answer
now_reference["reference"] = response
now_reference["ref_score"] = relevance_score
now_reference["idx"] = response
if "Unsolved" in query_type:
message = "[Unsolved Query]:{}<SEP>[Answer]:{}<SEP>[Reference]:{}<SEP>".format(
query_item, answer, response
)
flag_contibue_label = True
break
elif relevance_score > 1.5:
answer_start_idx = idx + 1
predict_answer = ""
while answer_start_idx < len(query_list):
if "Answer" in query_list[answer_start_idx]:
predict_answer = query_list[answer_start_idx]
break
answer_start_idx += 1
match_label = _match_or_not(
prediction=predict_answer, ground_truth=answer
)
if match_label:
continue
else:
message = "[Query]:{}<SEP>[Answer]:{}<SEP>[Reference]:{}<SEP>".format(
query_item, answer, response
)
flag_contibue_label = True
break
return message, flag_contibue_label, query_seen_list
def _extract(self, message_keys_list):
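        # Scan the chat history backwards for the latest assistant message containing "[Final Content]".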
text = message_keys_list
idx = len(text)
while idx > 0:
idx = idx - 1
item = text[idx]
if item.role == "assistant" and "Final Content" in item.content:
list_item = item.content.split("\n")
for sp in list_item:
if "Final Content" in sp:
return item.content
return "Sorry, I still cannot solve this question!"
def execute(self, data_path, start_idx):
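        # Iterate over the JSON-lines dataset starting at start_idx, run up to five LLM-IR interaction
        # rounds per question, print the extracted result, and return -1 when the file is exhausted.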
data = open(data_path)
for k, example in enumerate(data):
if k < start_idx:
continue
example = json.loads(example)
q = example["question"]
round_count = 0
message_keys_list = [
ChatMessage(
role="user",
content="""Construct a global reasoning chain for this complex [Question] : " {} " You should generate a query to the search engine based on what you already know at each step of the reasoning chain, starting with [Query]. If you know the answer for [Query], generate it starting with [Answer]. You can try to generate the final answer for the [Question] by referring to the [Query]-[Answer] pairs, starting with [Final Content]. If you don't know the answer, generate a query to search engine based on what you already know and do not know, starting with [Unsolved Query].
For example:
[Question]: "Where do greyhound buses that are in the birthplace of Spirit If...'s performer leave from? "
[Query 1]: Who is the performer of Spirit If... ?
If you don't know the answer:
[Unsolved Query]: Who is the performer of Spirit If... ?
If you know the answer:
[Answer 1]: The performer of Spirit If... is Kevin Drew.
[Query 2]: Where was Kevin Drew born?
If you don't know the answer:
[Unsolved Query]: Where was Kevin Drew born?
If you know the answer:
[Answer 2]: Toronto.
[Query 3]: Where do greyhound buses in Toronto leave from?
If you don't know the answer:
[Unsolved Query]: Where do greyhound buses in Toronto leave from?
If you know the answer:
[Answer 3]: Toronto Coach Terminal.
[Final Content]: The performer of Spirit If... is Kevin Drew [1]. Kevin Drew was born in Toronto [2]. Greyhound buses in Toronto leave from Toronto Coach Terminal [3]. So the final answer is Toronto Coach Terminal.
[Question]:"Which magazine was started first Arthur’s Magazine or First for Women?"
[Query 1]: When was Arthur’s Magazine started?
[Answer 1]: 1844.
[Query 2]: When was First for Women started?
[Answer 2]: 1989
[Final Content]: Arthur’s Magazine started in 1844 [1]. First for Women started in 1989 [2]. So Arthur’s Magazine was started first. So the answer is Arthur’s Magazine.
[Question]: {}""".format(
q, q
),
)
]
feedback_answer = "continue"
predict_answer = ""
query_seen_list = []
while round_count < 5 and feedback_answer != "end":
time.sleep(0.5)
rsp = self.llm.chat(message_keys_list)
round_count += 1
input_str = str(rsp.message.content)
message_keys_list.append(
ChatMessage(role="assistant", content=input_str)
)
predict_answer += input_str
message, flag_contibue_label, query_seen_list = self._ir(
input_str, query_seen_list
)
if flag_contibue_label:
feedback = message
else:
feedback = "end"
if feedback == "end":
break
# [Query]:xxxx<SEP>[Answer]:xxxx<SEP>[Reference]:xxxx<SEP>
feedback_list = feedback.split("<SEP>")
if "Unsolved Query" not in feedback:
new_prompt = """Reference: {} According to this Reference, the answer for "{}" should be "{}", you can change your answer based on the Reference and continue constructing the reasoning chain to give the final answer for [Question]:{}""".format(
                        feedback_list[2], feedback_list[0], feedback_list[1], q
)
else:
new_prompt = """Reference: {} According to this Reference, the answer for "{}" should be "{}", you can give your answer based on the Reference and continue constructing the reasoning chain to give the final answer for [Question]:{} """.format(
                        feedback_list[2], feedback_list[0], feedback_list[1], q
)
message_keys_list.append(ChatMessage(role="user", content=new_prompt))
result = self._extract(message_keys_list)
print(result)
return -1
[build-system]
build-backend = "poetry.core.masonry.api"
requires = ["poetry-core"]
[tool.codespell]
check-filenames = true
check-hidden = true
# Feel free to un-skip examples, and experimental, you will just need to
# work through many typos (--write-changes and --interactive will help)
skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb"
[tool.llamahub]
contains_example = false
import_path = "llama_index.packs.searchain"
[tool.llamahub.class_authors]
SearChainPack = "DJC-GO-SOLO"
[tool.mypy]
disallow_untyped_defs = true
# Remove venv skip when integrated with pre-commit
exclude = ["_static", "build", "examples", "notebooks", "venv"]
ignore_missing_imports = true
python_version = "3.8"
[tool.poetry]
authors = ["Your Name <you@example.com>"]
description = "llama-index packs searchain integration"
license = "MIT"
name = "llama-index-packs-searchain"
packages = [{include = "llama_index/"}]
readme = "README.md"
version = "0.1.0"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
llama-index-core = "^0.10.0"
torch = "^2.1.2"
transformers = "^4.38.1"
sentence_transformers = "^2.5.1"
[tool.poetry.group.dev.dependencies]
black = {extras = ["jupyter"], version = "<=23.9.1,>=23.7.0"}
codespell = {extras = ["toml"], version = ">=v2.2.6"}
ipython = "8.10.0"
jupyter = "^1.0.0"
mypy = "0.991"
pre-commit = "3.2.0"
pylint = "2.15.10"
pytest = "7.2.1"
pytest-mock = "3.11.1"
ruff = "0.0.292"
tree-sitter-languages = "^1.8.0"
types-Deprecated = ">=0.1.0"
types-PyYAML = "^6.0.12.12"
types-protobuf = "^4.24.0.4"
types-redis = "4.5.5.0"
types-requests = "2.28.11.8" # TODO: unpin when mypy>0.991
types-setuptools = "67.1.0.0"
python_tests()
from llama_index.core.llama_pack import BaseLlamaPack
from llama_index.packs.searchain import SearChainPack
def test_class():
names_of_base_classes = [b.__name__ for b in SearChainPack.__mro__]
assert BaseLlamaPack.__name__ in names_of_base_classes