Commit f805e344 authored by Haotian Zhang, committed by GitHub

Init Colbert reranker (#11057)

* Init colbert reranker

* cr

* cr

* cr

* fix

* cr

* cr

* cr

* cr

* cr

* cr
parent 826588cc
Showing 5712 additions and 0 deletions
%% Cell type:markdown id: tags:
<a href="https://colab.research.google.com/github/run-llama/llama_index/blob/main/docs/examples/node_postprocessor/ColbertRerank.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
%% Cell type:markdown id: tags:
# Colbert Rerank
%% Cell type:markdown id: tags:
If you're opening this notebook on Colab, you will probably need to install LlamaIndex 🦙.
[ColBERT](https://github.com/stanford-futuredata/ColBERT) is a fast and accurate retrieval model that enables scalable BERT-based search over large text collections in tens of milliseconds.
This example shows how to use the ColBERT-v2 model as a reranker.
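%% Cell type:markdown id: tags:
As a rough sketch of the idea behind ColBERT's late-interaction ("MaxSim") scoring (illustrative only, not the exact code this integration runs): every query token embedding is compared against every document token embedding, the best match per query token is kept, and those maxima are aggregated into a single relevance score. The tensors below are random placeholders.
%% Cell type:code id: tags:
``` python
import torch

# Toy token embeddings (random placeholders): [num_tokens, embedding_dim]
query_emb = torch.randn(8, 128)  # 8 query tokens
doc_emb = torch.randn(200, 128)  # 200 document tokens

# Cosine similarity between every query token and every document token.
query_norm = torch.nn.functional.normalize(query_emb, dim=-1)
doc_norm = torch.nn.functional.normalize(doc_emb, dim=-1)
sim = query_norm @ doc_norm.T  # shape: [8, 200]

# MaxSim: best-matching document token per query token, aggregated to one score.
maxsim_score = sim.max(dim=1).values.sum()
print(float(maxsim_score))
```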
%% Cell type:code id: tags:
``` python
!pip install llama-index
!pip install llama-index-core
!pip install --quiet transformers torch
!pip install llama-index-embeddings-openai
!pip install llama-index-llms-openai
!pip install llama-index-postprocessor-colbert-rerank
```
%% Cell type:code id: tags:
``` python
from llama_index.core import (
VectorStoreIndex,
SimpleDirectoryReader,
)
```
%% Cell type:markdown id: tags:
Download Data
%% Cell type:code id: tags:
``` python
!mkdir -p 'data/paul_graham/'
!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/paul_graham/paul_graham_essay.txt' -O 'data/paul_graham/paul_graham_essay.txt'
```
%% Cell type:code id: tags:
``` python
import os
os.environ["OPENAI_API_KEY"] = "sk-"
```
%% Cell type:code id: tags:
``` python
# load documents
documents = SimpleDirectoryReader("./data/paul_graham/").load_data()
# build index
index = VectorStoreIndex.from_documents(documents=documents)
```
%% Cell type:markdown id: tags:
#### Retrieve top 10 most relevant nodes, then filter with Colbert Rerank
%% Cell type:code id: tags:
``` python
from llama_index.postprocessor.colbert_rerank import ColbertRerank

colbert_reranker = ColbertRerank(
top_n=5,
model="colbert-ir/colbertv2.0",
tokenizer="colbert-ir/colbertv2.0",
keep_retrieval_score=True,
)
query_engine = index.as_query_engine(
    similarity_top_k=10,
node_postprocessors=[colbert_reranker],
)
response = query_engine.query(
"What did Sam Altman do in this essay?",
)
```
%% Cell type:code id: tags:
``` python
for node in response.source_nodes:
print(node.id_)
print(node.node.get_content()[:120])
print("reranking score: ", node.score)
print("retrieval score: ", node.node.metadata["retrieval_score"])
print("**********")
```
%% Output
bd5a8323-41bb-4cde-8b2b-2ac69b1e519a
When I was dealing with some urgent problem during YC, there was about a 60% chance it had to do with HN, and a 40% chan
reranking score: 0.6470144987106323
retrieval score: 0.8309415059604792
**********
24c6c722-bfd0-42e0-9e44-663253b79aa2
Now that I could write essays again, I wrote a bunch about topics I'd had stacked up. I kept writing essays through 2020
reranking score: 0.6377773284912109
retrieval score: 0.8053894057548092
**********
e572465c-d48c-48ce-9664-99ddf09cdae6
Much to my surprise, the time I spent working on this stuff was not wasted after all. After we started Y Combinator, I w
reranking score: 0.6206888556480408
retrieval score: 0.8091076626532405
**********
576168dd-98ce-43ee-91d4-fef0fb4368d2
[15] We got 225 applications for the Summer Founders Program, and we were surprised to find that a lot of them were from
reranking score: 0.6143158674240112
retrieval score: 0.8069205604148549
**********
d0f00ad3-b162-49d7-a01f-c513c6c068ad
Up till that point YC had been controlled by the original LLC we four had started. But we wanted YC to last for a long t
reranking score: 0.5917402505874634
retrieval score: 0.8230686425302381
**********
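%% Cell type:markdown id: tags:
The reranker can also be used outside a query engine via the standard `postprocess_nodes` interface of node postprocessors. The cell below is a sketch that reuses the `index` and `colbert_reranker` objects created above.
%% Cell type:code id: tags:
``` python
from llama_index.core.schema import QueryBundle

query_str = "What did Sam Altman do in this essay?"

# Retrieve candidate nodes, then rerank them directly with ColBERT.
retriever = index.as_retriever(similarity_top_k=10)
retrieved_nodes = retriever.retrieve(query_str)

reranked_nodes = colbert_reranker.postprocess_nodes(
    retrieved_nodes, query_bundle=QueryBundle(query_str)
)
for node in reranked_nodes:
    print(node.id_, node.score)
```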
llama_index/_static
.DS_Store
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
bin/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
etc/
include/
lib/
lib64/
parts/
sdist/
share/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
.ruff_cache
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
notebooks/
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
pyvenv.cfg
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# Jetbrains
.idea
modules/
*.swp
# VsCode
.vscode
# pipenv
Pipfile
Pipfile.lock
# pyright
pyrightconfig.json
poetry_requirements(
name="poetry",
)
GIT_ROOT ?= $(shell git rev-parse --show-toplevel)
help: ## Show all Makefile targets.
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}'
format: ## Run code autoformatters (black).
pre-commit install
git ls-files | xargs pre-commit run black --files
lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy
pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files
test: ## Run tests via pytest.
pytest tests
watch-docs: ## Build and watch documentation.
sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/
# LlamaIndex Postprocessor Integration: Colbert Rerank
[ColBERT](https://github.com/stanford-futuredata/ColBERT) is a fast and accurate retrieval model that enables scalable BERT-based search over large text collections in tens of milliseconds.
Run `pip install llama-index-postprocessor-colbert-rerank` to install the Colbert Rerank package.
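A minimal usage sketch (mirroring the example notebook; it assumes an existing `VectorStoreIndex` called `index`):

``` python
from llama_index.postprocessor.colbert_rerank import ColbertRerank

colbert_reranker = ColbertRerank(
    top_n=5,
    model="colbert-ir/colbertv2.0",
    tokenizer="colbert-ir/colbertv2.0",
    keep_retrieval_score=True,
)

query_engine = index.as_query_engine(
    similarity_top_k=10,
    node_postprocessors=[colbert_reranker],
)
response = query_engine.query("Your question here")
```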
from llama_index.postprocessor.colbert_rerank.base import ColbertRerank
__all__ = ["ColbertRerank"]
from typing import Any, List, Optional
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks import CBEventType, EventPayload
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle
from llama_index.core.utils import infer_torch_device
import torch
from transformers import AutoTokenizer, AutoModel
DEFAULT_COLBERT_MAX_LENGTH = 512
class ColbertRerank(BaseNodePostprocessor):
model: str = Field(description="Colbert model name.")
top_n: int = Field(description="Number of nodes to return sorted by score.")
device: str = Field(
default="cpu",
description="Device to use for sentence transformer.",
)
keep_retrieval_score: bool = Field(
default=False,
description="Whether to keep the retrieval score in metadata.",
)
_model: Any = PrivateAttr()
_tokenizer: Any = PrivateAttr()
def __init__(
self,
top_n: int = 5,
model: str = "colbert-ir/colbertv2.0",
tokenizer: str = "colbert-ir/colbertv2.0",
device: Optional[str] = None,
keep_retrieval_score: Optional[bool] = False,
):
device = infer_torch_device() if device is None else device
self._tokenizer = AutoTokenizer.from_pretrained(tokenizer)
self._model = AutoModel.from_pretrained(model)
super().__init__(
top_n=top_n,
model=model,
tokenizer=tokenizer,
device=device,
keep_retrieval_score=keep_retrieval_score,
)
@classmethod
def class_name(cls) -> str:
return "ColbertRerank"
def _calculate_sim(self, query: str, documents_text_list: List[str]) -> List[float]:
        # Encode the query and mean-pool its token embeddings into a single
        # query vector of shape [1, 1, embedding_size]; document token
        # embeddings keep shape [1, doc_length, embedding_size].
query_encoding = self._tokenizer(query, return_tensors="pt")
query_embedding = (
self._model(**query_encoding).last_hidden_state.mean(dim=1).unsqueeze(0)
)
rerank_score_list = []
for document_text in documents_text_list:
document_encoding = self._tokenizer(
document_text, return_tensors="pt", truncation=True, max_length=512
)
document_embedding = self._model(**document_encoding).last_hidden_state
sim_matrix = torch.nn.functional.cosine_similarity(
query_embedding.unsqueeze(2), document_embedding.unsqueeze(1), dim=-1
)
            # Keep the highest cosine similarity between the pooled query
            # vector and any document token (sim_matrix shape: [1, 1, doc_length]).
max_sim_scores, _ = torch.max(sim_matrix, dim=2)
rerank_score_list.append(torch.mean(max_sim_scores, dim=1))
return rerank_score_list
def _postprocess_nodes(
self,
nodes: List[NodeWithScore],
query_bundle: Optional[QueryBundle] = None,
) -> List[NodeWithScore]:
if query_bundle is None:
raise ValueError("Missing query bundle in extra info.")
if len(nodes) == 0:
return []
nodes_text_list = [
str(node.node.get_content(metadata_mode=MetadataMode.EMBED))
for node in nodes
]
with self.callback_manager.event(
CBEventType.RERANKING,
payload={
EventPayload.NODES: nodes,
EventPayload.MODEL_NAME: self.model,
EventPayload.QUERY_STR: query_bundle.query_str,
EventPayload.TOP_K: self.top_n,
},
) as event:
scores = self._calculate_sim(query_bundle.query_str, nodes_text_list)
assert len(scores) == len(nodes)
for node, score in zip(nodes, scores):
if self.keep_retrieval_score:
# keep the retrieval score in metadata
node.node.metadata["retrieval_score"] = node.score
node.score = float(score)
reranked_nodes = sorted(nodes, key=lambda x: -x.score if x.score else 0)[
: self.top_n
]
event.on_end(payload={EventPayload.NODES: reranked_nodes})
return reranked_nodes
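A minimal standalone check of the class above (a sketch; the first run downloads the `colbert-ir/colbertv2.0` weights from Hugging Face, and the node texts are made up for illustration):

``` python
from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode
from llama_index.postprocessor.colbert_rerank import ColbertRerank

# Made-up candidate nodes with placeholder retrieval scores.
nodes = [
    NodeWithScore(node=TextNode(text="ColBERT is a late-interaction retrieval model."), score=0.70),
    NodeWithScore(node=TextNode(text="Paul Graham co-founded Y Combinator."), score=0.65),
    NodeWithScore(node=TextNode(text="The essay touches on painting and programming."), score=0.60),
]

reranker = ColbertRerank(top_n=2, keep_retrieval_score=True)
reranked = reranker.postprocess_nodes(nodes, query_bundle=QueryBundle("What is ColBERT?"))
for n in reranked:
    print(f"{n.score:.4f}  {n.node.get_content()[:60]}")
```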
[build-system]
build-backend = "poetry.core.masonry.api"
requires = ["poetry-core"]
[tool.codespell]
check-filenames = true
check-hidden = true
# Feel free to un-skip examples, and experimental, you will just need to
# work through many typos (--write-changes and --interactive will help)
skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb"
[tool.llamahub]
contains_example = true
import_path = "llama_index.postprocessor.colbert_rerank"
[tool.llamahub.class_authors]
ColbertRerank = "hatiangzhang"
[tool.mypy]
disallow_untyped_defs = true
# Remove venv skip when integrated with pre-commit
exclude = ["_static", "build", "examples", "notebooks", "venv"]
ignore_missing_imports = true
python_version = "3.8"
[tool.poetry]
authors = ["Your Name <you@example.com>"]
description = "llama-index postprocessor colbert-rerank integration"
license = "MIT"
name = "llama-index-postprocessor-colbert-rerank"
packages = [{include = "llama_index/"}]
readme = "README.md"
version = "0.1.0"
[tool.poetry.dependencies]
python = ">=3.8.1,<3.12"
llama-index-core = "^0.10.0"
torch = "^2.2.0"
transformers = "^4.37.2"
[tool.poetry.group.dev.dependencies]
black = {extras = ["jupyter"], version = "<=23.9.1,>=23.7.0"}
codespell = {extras = ["toml"], version = ">=v2.2.6"}
ipython = "8.10.0"
jupyter = "^1.0.0"
mypy = "0.991"
pre-commit = "3.2.0"
pylint = "2.15.10"
pytest = "7.2.1"
pytest-mock = "3.11.1"
ruff = "0.0.292"
tree-sitter-languages = "^1.8.0"
types-Deprecated = ">=0.1.0"
types-PyYAML = "^6.0.12.12"
types-protobuf = "^4.24.0.4"
types-redis = "4.5.5.0"
types-requests = "2.28.11.8" # TODO: unpin when mypy>0.991
types-setuptools = "67.1.0.0"