Skip to content
Snippets Groups Projects
Unverified Commit 08e2ce85 authored by dingkun-ldk's avatar dingkun-ldk Committed by GitHub
Browse files

Add dashscope rerank for postprocessor (#13353)

parent 01df13f4
No related branches found
No related tags found
No related merge requests found
Showing
with 395 additions and 0 deletions
llama_index/_static
.DS_Store
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
bin/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
etc/
include/
lib/
lib64/
parts/
sdist/
share/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
.ruff_cache
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
notebooks/
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
pyvenv.cfg
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# Jetbrains
.idea
modules/
*.swp
# VsCode
.vscode
# pipenv
Pipfile
Pipfile.lock
# pyright
pyrightconfig.json
poetry_requirements(
name="poetry",
)
GIT_ROOT ?= $(shell git rev-parse --show-toplevel)
help: ## Show all Makefile targets.
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}'
format: ## Run code autoformatters (black).
pre-commit install
git ls-files | xargs pre-commit run black --files
lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy
pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files
test: ## Run tests via pytest.
pytest tests
watch-docs: ## Build and watch documentation.
sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/
# LlamaIndex Postprocessor Integration: DashScope-Rerank
The `llama-index-postprocessor-dashscope-rerank` package contains LlamaIndex integrations for the `gte-rerank` series models provided by Alibaba Tongyi Laboratory.
## Installation
```shell
pip install --upgrade llama-index llama-index-core llama-index-postprocessor-dashscope-rerank
```
## Setup
**Get started:**
1. Obtain the **API-KEY** from the [Alibaba Cloud ModelStudio platform](https://help.aliyun.com/document_detail/2712195.html?spm=a2c4g.2587460.0.i6).
2. Set **API-KEY**
```shell
export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY
```
**Example:**
```python
from llama_index.core.data_structs import Node
from llama_index.core.schema import NodeWithScore
from llama_index.postprocessor.dashscope_rerank import DashScopeRerank
nodes = [
NodeWithScore(node=Node(text="text1"), score=0.7),
NodeWithScore(node=Node(text="text2"), score=0.8),
]
dashscope_rerank = DashScopeRerank(top_n=5)
results = dashscope_rerank.postprocess_nodes(nodes, query_str="<user query>")
for res in results:
print("Text: ", res.node.get_content(), "Score: ", res.score)
```
**output**
```text
Text: text1 Score: 0.25589250620997755
Text: text2 Score: 0.18071043165292258
```
### Parameters
| Name | Type | Description | Default |
| :--------------: | :----: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :----------: |
| model | `str` | model name | `gte-rerank` |
| top_n | `int` | The number of top documents to be returned in the ranking; if not specified, all candidate documents will be returned. If the specified top_n value exceeds the number of input candidate documents, all documents will be returned. | `3` |
| return_documents | `bool` | Whether to return the original text for each document in the returned sorted result list, with the default value being False. | `False` |
| api_key | `str` | The DashScope api key. | `None` |
from llama_index.postprocessor.dashscope_rerank.base import DashScopeRerank
__all__ = ["DashScopeRerank"]
import os
from typing import List, Optional
from llama_index.core.bridge.pydantic import Field
from llama_index.core.callbacks import CBEventType, EventPayload
from llama_index.core.instrumentation import get_dispatcher
from llama_index.core.instrumentation.events.rerank import (
ReRankStartEvent,
ReRankEndEvent,
)
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle
dispatcher = get_dispatcher()
try:
import dashscope
except ImportError:
raise ImportError("DashScope requires `pip install dashscope`")
class DashScopeRerank(BaseNodePostprocessor):
model: str = Field(description="Dashscope rerank model name.")
top_n: int = Field(description="Top N nodes to return.")
def __init__(
self,
top_n: int = 3,
model: str = "gte-rerank",
return_documents: bool = False,
api_key: Optional[str] = None,
):
try:
api_key = api_key or os.environ["DASHSCOPE_API_KEY"]
except IndexError:
raise ValueError(
"Must pass in dashscope api key or "
"specify via DASHSCOPE_API_KEY environment variable "
)
super().__init__(top_n=top_n, model=model, return_documents=return_documents)
@classmethod
def class_name(cls) -> str:
return "DashScopeRerank"
def _postprocess_nodes(
self,
nodes: List[NodeWithScore],
query_bundle: Optional[QueryBundle] = None,
) -> List[NodeWithScore]:
dispatch_event = dispatcher.get_dispatch_event()
dispatch_event(
ReRankStartEvent(
model=self.model,
top_n=self.top_n,
query=query_bundle,
nodes=nodes,
)
)
if query_bundle is None:
raise ValueError("Missing query bundle in extra info.")
if len(nodes) == 0:
return []
with self.callback_manager.event(
CBEventType.RERANKING,
payload={
EventPayload.NODES: nodes,
EventPayload.MODEL_NAME: self.model,
EventPayload.QUERY_STR: query_bundle.query_str,
EventPayload.TOP_K: self.top_n,
},
) as event:
texts = [
node.node.get_content(metadata_mode=MetadataMode.EMBED)
for node in nodes
]
results = dashscope.TextReRank.call(
model=self.model,
top_n=self.top_n,
query=query_bundle.query_str,
documents=texts,
)
new_nodes = []
for result in results.output.results:
new_node_with_score = NodeWithScore(
node=nodes[result.index].node, score=result.relevance_score
)
new_nodes.append(new_node_with_score)
event.on_end(payload={EventPayload.NODES: new_nodes})
dispatch_event(
ReRankEndEvent(
nodes=new_nodes,
)
)
return new_nodes
[build-system]
build-backend = "poetry.core.masonry.api"
requires = ["poetry-core"]
[tool.codespell]
check-filenames = true
check-hidden = true
skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb"
[tool.llamahub]
contains_example = false
import_path = "llama_index.postprocessor.dashscope_rerank"
[tool.llamahub.class_authors]
DashScopeRerank = "llama-index"
[tool.mypy]
disallow_untyped_defs = true
# Remove venv skip when integrated with pre-commit
exclude = ["_static", "build", "examples", "notebooks", "venv"]
ignore_missing_imports = true
python_version = "3.8"
[tool.poetry]
authors = ["dingkun-ldk <dingkun.ldk@alibaba-inc.com>"]
description = "llama-index postprocessor dashscope-rerank integration"
exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-postprocessor-dashscope-rerank"
readme = "README.md"
version = "0.1.0"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
llama-index-core = "^0.10.0"
dashscope = ">=1.17.1"
[tool.poetry.group.dev.dependencies]
black = {extras = ["jupyter"], version = "<=23.9.1,>=23.7.0"}
codespell = {extras = ["toml"], version = ">=v2.2.6"}
ipython = "8.10.0"
jupyter = "^1.0.0"
mypy = "0.991"
pre-commit = "3.2.0"
pylint = "2.15.10"
pytest = "7.2.1"
pytest-mock = "3.11.1"
ruff = "0.0.292"
tree-sitter-languages = "^1.8.0"
types-Deprecated = ">=0.1.0"
types-PyYAML = "^6.0.12.12"
types-protobuf = "^4.24.0.4"
types-redis = "4.5.5.0"
types-requests = "2.28.11.8"
types-setuptools = "67.1.0.0"
python_tests()
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.postprocessor.dashscope_rerank import DashScopeRerank
def test_class():
names_of_base_classes = [b.__name__ for b in DashScopeRerank.__mro__]
assert BaseNodePostprocessor.__name__ in names_of_base_classes
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment