diff --git a/llama-index-integrations/postprocessor/llama-index-postprocessor-dashscope-rerank/.gitignore b/llama-index-integrations/postprocessor/llama-index-postprocessor-dashscope-rerank/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..990c18de229088f55c6c514fd0f2d49981d1b0e7 --- /dev/null +++ b/llama-index-integrations/postprocessor/llama-index-postprocessor-dashscope-rerank/.gitignore @@ -0,0 +1,153 @@ +llama_index/_static +.DS_Store +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +bin/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +etc/ +include/ +lib/ +lib64/ +parts/ +sdist/ +share/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +.ruff_cache + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints +notebooks/ + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. 
+#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ +pyvenv.cfg + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# Jetbrains +.idea +modules/ +*.swp + +# VsCode +.vscode + +# pipenv +Pipfile +Pipfile.lock + +# pyright +pyrightconfig.json diff --git a/llama-index-integrations/postprocessor/llama-index-postprocessor-dashscope-rerank/BUILD b/llama-index-integrations/postprocessor/llama-index-postprocessor-dashscope-rerank/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..0896ca890d8bffd60a44fa824f8d57fecd73ee53 --- /dev/null +++ b/llama-index-integrations/postprocessor/llama-index-postprocessor-dashscope-rerank/BUILD @@ -0,0 +1,3 @@ +poetry_requirements( + name="poetry", +) diff --git a/llama-index-integrations/postprocessor/llama-index-postprocessor-dashscope-rerank/Makefile b/llama-index-integrations/postprocessor/llama-index-postprocessor-dashscope-rerank/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..b9eab05aa370629a4a3de75df3ff64cd53887b68 --- /dev/null +++ b/llama-index-integrations/postprocessor/llama-index-postprocessor-dashscope-rerank/Makefile @@ -0,0 +1,17 @@ +GIT_ROOT ?= $(shell git rev-parse --show-toplevel) + +help: ## Show all Makefile targets. + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}' + +format: ## Run code autoformatters (black). 
+ pre-commit install + git ls-files | xargs pre-commit run black --files + +lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy + pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files + +test: ## Run tests via pytest. + pytest tests + +watch-docs: ## Build and watch documentation. + sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/ diff --git a/llama-index-integrations/postprocessor/llama-index-postprocessor-dashscope-rerank/README.md b/llama-index-integrations/postprocessor/llama-index-postprocessor-dashscope-rerank/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8d4ad74f03969ebd3a4100743ff405df9548c460 --- /dev/null +++ b/llama-index-integrations/postprocessor/llama-index-postprocessor-dashscope-rerank/README.md @@ -0,0 +1,54 @@ +# LlamaIndex Postprocessor Integration: DashScope-Rerank + +The `llama-index-postprocessor-dashscope-rerank` package contains LlamaIndex integrations for the `gte-rerank` series models provided by Alibaba Tongyi Laboratory. + +## Installation + +```shell +pip install --upgrade llama-index llama-index-core llama-index-postprocessor-dashscope-rerank +``` + +## Setup + +**Get started:** + +1. Obtain the **API-KEY** from the [Alibaba Cloud ModelStudio platform](https://help.aliyun.com/document_detail/2712195.html?spm=a2c4g.2587460.0.i6). +2. 
Set **API-KEY** + +```shell +export DASHSCOPE_API_KEY=YOUR_DASHSCOPE_API_KEY +``` + +**Example:** + +```python +from llama_index.core.data_structs import Node +from llama_index.core.schema import NodeWithScore +from llama_index.postprocessor.dashscope_rerank import DashScopeRerank + +nodes = [ + NodeWithScore(node=Node(text="text1"), score=0.7), + NodeWithScore(node=Node(text="text2"), score=0.8), +] + +dashscope_rerank = DashScopeRerank(top_n=5) +results = dashscope_rerank.postprocess_nodes(nodes, query_str="<user query>") +for res in results: + print("Text: ", res.node.get_content(), "Score: ", res.score) +``` + +**output** + +```text +Text: text1 Score: 0.25589250620997755 +Text: text2 Score: 0.18071043165292258 +``` + +### Parameters + +| Name | Type | Description | Default | +| :--------------: | :----: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :----------: | +| model | `str` | model name | `gte-rerank` | +| top_n | `int` | The number of top documents to be returned in the ranking; if not specified, all candidate documents will be returned. If the specified top_n value exceeds the number of input candidate documents, all documents will be returned. | `3` | +| return_documents | `bool` | Whether to return the original text for each document in the returned sorted result list, with the default value being False. | `False` | +| api_key | `str` | The DashScope api key. 
| `None` | diff --git a/llama-index-integrations/postprocessor/llama-index-postprocessor-dashscope-rerank/llama_index/postprocessor/dashscope_rerank/BUILD b/llama-index-integrations/postprocessor/llama-index-postprocessor-dashscope-rerank/llama_index/postprocessor/dashscope_rerank/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..db46e8d6c978c67e301dd6c47bee08c1b3fd141c --- /dev/null +++ b/llama-index-integrations/postprocessor/llama-index-postprocessor-dashscope-rerank/llama_index/postprocessor/dashscope_rerank/BUILD @@ -0,0 +1 @@ +python_sources() diff --git a/llama-index-integrations/postprocessor/llama-index-postprocessor-dashscope-rerank/llama_index/postprocessor/dashscope_rerank/__init__.py b/llama-index-integrations/postprocessor/llama-index-postprocessor-dashscope-rerank/llama_index/postprocessor/dashscope_rerank/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cf2ff60244fe542aa8dfd44af32aa289563c2be9 --- /dev/null +++ b/llama-index-integrations/postprocessor/llama-index-postprocessor-dashscope-rerank/llama_index/postprocessor/dashscope_rerank/__init__.py @@ -0,0 +1,4 @@ +from llama_index.postprocessor.dashscope_rerank.base import DashScopeRerank + + +__all__ = ["DashScopeRerank"] diff --git a/llama-index-integrations/postprocessor/llama-index-postprocessor-dashscope-rerank/llama_index/postprocessor/dashscope_rerank/base.py b/llama-index-integrations/postprocessor/llama-index-postprocessor-dashscope-rerank/llama_index/postprocessor/dashscope_rerank/base.py new file mode 100644 index 0000000000000000000000000000000000000000..54b93a8b82ba90e3c7f4e4068deb289c61b093d7 --- /dev/null +++ b/llama-index-integrations/postprocessor/llama-index-postprocessor-dashscope-rerank/llama_index/postprocessor/dashscope_rerank/base.py @@ -0,0 +1,100 @@ +import os +from typing import List, Optional + +from llama_index.core.bridge.pydantic import Field +from llama_index.core.callbacks import CBEventType, EventPayload +from 
class DashScopeRerank(BaseNodePostprocessor):
    """Node postprocessor that reranks nodes with a DashScope rerank model.

    The content of every candidate node is sent, together with the query
    string, to ``dashscope.TextReRank.call``. The nodes referenced by the
    response are returned, each carrying the relevance score reported in the
    response.
    """

    model: str = Field(description="Dashscope rerank model name.")
    top_n: int = Field(description="Top N nodes to return.")
    return_documents: bool = Field(
        default=False,
        description="Whether the service should include document text in its results.",
    )
    api_key: Optional[str] = Field(
        default=None,
        description="The DashScope API key.",
        exclude=True,  # keep the secret out of serialized representations
    )

    def __init__(
        self,
        top_n: int = 3,
        model: str = "gte-rerank",
        return_documents: bool = False,
        api_key: Optional[str] = None,
    ) -> None:
        """Create the reranker.

        Args:
            top_n: Number of results requested from the rerank service.
            model: Name of the DashScope rerank model.
            return_documents: Forwarded to the rerank service as
                ``return_documents``.
            api_key: DashScope API key. When falsy, the value of the
                ``DASHSCOPE_API_KEY`` environment variable is used instead.

        Raises:
            ValueError: If no key is passed and ``DASHSCOPE_API_KEY`` is not
                set in the environment.
        """
        try:
            api_key = api_key or os.environ["DASHSCOPE_API_KEY"]
        except KeyError as err:
            # A missing environment variable raises KeyError (not IndexError);
            # translate it into the actionable error message below.
            raise ValueError(
                "Must pass in dashscope api key or "
                "specify via DASHSCOPE_API_KEY environment variable "
            ) from err

        super().__init__(
            top_n=top_n,
            model=model,
            return_documents=return_documents,
            api_key=api_key,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the name used to identify this component class."""
        return "DashScopeRerank"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Rerank ``nodes`` against the query using the DashScope service.

        Args:
            nodes: Candidate nodes to rerank.
            query_bundle: Query whose ``query_str`` is sent to the service.

        Returns:
            One ``NodeWithScore`` per result in the service response, in
            response order, scored with the result's ``relevance_score``.
            An empty list when ``nodes`` is empty.

        Raises:
            ValueError: If ``query_bundle`` is ``None``.
            RuntimeError: If the service response carries no output.
        """
        dispatch_event = dispatcher.get_dispatch_event()

        dispatch_event(
            ReRankStartEvent(
                model=self.model,
                top_n=self.top_n,
                query=query_bundle,
                nodes=nodes,
            )
        )

        if query_bundle is None:
            raise ValueError("Missing query bundle in extra info.")
        if len(nodes) == 0:
            return []

        with self.callback_manager.event(
            CBEventType.RERANKING,
            payload={
                EventPayload.NODES: nodes,
                EventPayload.MODEL_NAME: self.model,
                EventPayload.QUERY_STR: query_bundle.query_str,
                EventPayload.TOP_K: self.top_n,
            },
        ) as event:
            texts = [
                node.node.get_content(metadata_mode=MetadataMode.EMBED)
                for node in nodes
            ]
            results = dashscope.TextReRank.call(
                model=self.model,
                top_n=self.top_n,
                query=query_bundle.query_str,
                documents=texts,
                return_documents=self.return_documents,
                # An empty key is sent as None so the SDK's own key lookup
                # still applies, as it did before the key was forwarded.
                api_key=self.api_key or None,
            )
            if results.output is None:
                # Fail with the response attached instead of an opaque
                # AttributeError on ``None.results`` below.
                raise RuntimeError(
                    f"DashScope rerank request returned no output: {results!r}"
                )

            # ``result.index`` refers to a position in ``texts``, which is
            # aligned one-to-one with ``nodes``.
            new_nodes = [
                NodeWithScore(
                    node=nodes[result.index].node, score=result.relevance_score
                )
                for result in results.output.results
            ]
            event.on_end(payload={EventPayload.NODES: new_nodes})

        dispatch_event(
            ReRankEndEvent(
                nodes=new_nodes,
            )
        )
        return new_nodes
def test_class():
    """DashScopeRerank must list BaseNodePostprocessor among its MRO class names."""
    ancestor_names = {klass.__name__ for klass in DashScopeRerank.__mro__}
    expected_name = BaseNodePostprocessor.__name__
    assert expected_name in ancestor_names