From d18ecc7b63246939b569f82525bfb68eda419e5f Mon Sep 17 00:00:00 2001 From: DJC <85006613+DJC-GO-SOLO@users.noreply.github.com> Date: Sat, 9 Mar 2024 14:25:15 +0800 Subject: [PATCH] =?UTF-8?q?SearChain=20package=20reproduced=20using=20Llam?= =?UTF-8?q?a=5Findex=20library=20=E2=80=8B=20(#11649)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../llama-index-packs-searchain/.gitignore | 153 +++++++++++ .../llama-index-packs-searchain/BUILD | 3 + .../llama-index-packs-searchain/Makefile | 17 ++ .../llama-index-packs-searchain/README.md | 44 ++++ .../examples/searchain.ipynb | 108 ++++++++ .../llama_index/packs/searchain/BUILD | 1 + .../llama_index/packs/searchain/__init__.py | 4 + .../llama_index/packs/searchain/base.py | 244 ++++++++++++++++++ .../pyproject.toml | 59 +++++ .../llama-index-packs-searchain/tests/BUILD | 1 + .../tests/__init__.py | 0 .../tests/test_packs_searchain.py | 7 + 12 files changed, 641 insertions(+) create mode 100644 llama-index-packs/llama-index-packs-searchain/.gitignore create mode 100644 llama-index-packs/llama-index-packs-searchain/BUILD create mode 100644 llama-index-packs/llama-index-packs-searchain/Makefile create mode 100644 llama-index-packs/llama-index-packs-searchain/README.md create mode 100644 llama-index-packs/llama-index-packs-searchain/examples/searchain.ipynb create mode 100644 llama-index-packs/llama-index-packs-searchain/llama_index/packs/searchain/BUILD create mode 100644 llama-index-packs/llama-index-packs-searchain/llama_index/packs/searchain/__init__.py create mode 100644 llama-index-packs/llama-index-packs-searchain/llama_index/packs/searchain/base.py create mode 100644 llama-index-packs/llama-index-packs-searchain/pyproject.toml create mode 100644 llama-index-packs/llama-index-packs-searchain/tests/BUILD create mode 100644 llama-index-packs/llama-index-packs-searchain/tests/__init__.py create mode 100644 llama-index-packs/llama-index-packs-searchain/tests/test_packs_searchain.py diff --git a/llama-index-packs/llama-index-packs-searchain/.gitignore b/llama-index-packs/llama-index-packs-searchain/.gitignore new file mode 100644 index 0000000000..990c18de22 --- /dev/null +++ b/llama-index-packs/llama-index-packs-searchain/.gitignore @@ -0,0 +1,153 @@ +llama_index/_static +.DS_Store +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +bin/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +etc/ +include/ +lib/ +lib64/ +parts/ +sdist/ +share/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+.ruff_cache
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+notebooks/
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+pyvenv.cfg
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# Jetbrains
+.idea
+modules/
+*.swp
+
+# VsCode
+.vscode
+
+# pipenv
+Pipfile
+Pipfile.lock
+
+# pyright
+pyrightconfig.json
diff --git a/llama-index-packs/llama-index-packs-searchain/BUILD b/llama-index-packs/llama-index-packs-searchain/BUILD
new file mode 100644
index 0000000000..0896ca890d
--- /dev/null
+++ b/llama-index-packs/llama-index-packs-searchain/BUILD
@@ -0,0 +1,3 @@
+poetry_requirements(
+    name="poetry",
+)
diff --git a/llama-index-packs/llama-index-packs-searchain/Makefile b/llama-index-packs/llama-index-packs-searchain/Makefile
new file mode 100644
index 0000000000..b9eab05aa3
--- /dev/null
+++ b/llama-index-packs/llama-index-packs-searchain/Makefile
@@ -0,0 +1,17 @@
+GIT_ROOT ?= $(shell git rev-parse --show-toplevel)
+
+help:	## Show all Makefile targets.
+	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}'
+
+format:	## Run code autoformatters (black).
+	pre-commit install
+	git ls-files | xargs pre-commit run black --files
+
+lint:	## Run linters: pre-commit (black, ruff, codespell) and mypy
+	pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files
+
+test:	## Run tests via pytest.
+	pytest tests
+
+watch-docs:	## Build and watch documentation.
+	sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/
diff --git a/llama-index-packs/llama-index-packs-searchain/README.md b/llama-index-packs/llama-index-packs-searchain/README.md
new file mode 100644
index 0000000000..daaaabe9dd
--- /dev/null
+++ b/llama-index-packs/llama-index-packs-searchain/README.md
@@ -0,0 +1,44 @@
+# LlamaIndex Packs Integration: SearChain
+
+This LlamaPack implements SearChain, a framework that organizes the interaction between the LLM and information retrieval (IR) as a global reasoning chain called a Chain-of-Query (CoQ).
+
+This follows the idea in the paper [Search-in-the-Chain: Towards Accurate, Credible and Traceable Large Language Models for Knowledge-intensive Tasks](https://arxiv.org/abs/2304.14732).
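+
+At a glance, the interaction can be pictured as the loop sketched below. The sketch is purely illustrative and is not part of the package API: `ask_llm` and `retrieve` are hypothetical stand-ins for the LLM call and the retriever, whereas the pack itself drives this loop with a llama-index query engine, a DPR reader and a cross-encoder (see `base.py`).
+
+```python
+def ask_llm(prompt: str) -> str:
+    """Placeholder LLM call: returns a Chain-of-Query, one node per line."""
+    return "[Query 1]: Who is the performer of Spirit If...?\n[Answer 1]: Kevin Drew."
+
+
+def retrieve(query: str) -> str:
+    """Placeholder retriever: returns evidence text for a query."""
+    return "Spirit If... is an album by Kevin Drew."
+
+
+def searchain_round(question: str, max_rounds: int = 5) -> str:
+    """Simplified SearChain loop: generate a CoQ, let IR verify each query node,
+    and feed corrections or missing knowledge back to the LLM."""
+    prompt = f"Construct a reasoning chain for: {question}"
+    chain = ""
+    for _ in range(max_rounds):
+        chain = ask_llm(prompt)  # CoQ: one "[Query]/[Answer]" node per line
+        feedback = None
+        for node in chain.splitlines():
+            if node.startswith(("[Query", "[Unsolved Query")):
+                evidence = retrieve(node)  # IR checks this node of the chain
+                if evidence and evidence not in chain:
+                    # correction (or missing knowledge) handed back to the LLM
+                    feedback = f"Reference: {evidence}"
+                    break
+        if feedback is None:  # every node verified -> the chain is final
+            break
+        prompt = f"{feedback}\nRevise the reasoning chain for: {question}"
+    return chain
+```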
+
+Making content generated by large language models (LLMs) such as ChatGPT accurate, trustworthy, and traceable is critical, especially for knowledge-intensive tasks. Introducing information retrieval (IR) to provide the LLM with external knowledge is a promising way to solve this problem; however, where and how to introduce IR is a big challenge. First, the SearChain framework generates a global reasoning chain called a Chain-of-Query (CoQ) for the LLM, where each node contains an IR-oriented query and the answer to that query. Second, IR verifies the answer at each node of the CoQ and corrects any answer that is inconsistent with the retrieved information when IR is highly confident, which improves credibility. Third, the LLM can mark its missing knowledge in the CoQ, and IR can provide this knowledge to the LLM. These three operations improve the accuracy of the LLM on complex knowledge-intensive tasks in terms of both reasoning ability and knowledge. This Pack implements the above🤗!
+
+You can see its use case in the examples folder.
+
+This implementation is adapted from the author's implementation. You can find the official code repository [here](https://github.com/xsc1234/Search-in-the-Chain).
+
+## Code Usage
+
+First, you need to download the SearChainPack using the following code:
+
+```python
+from llama_index.core.llama_pack import download_llama_pack
+
+download_llama_pack("SearChainPack", "./searchain_pack")
+```
+
+Next, you can load and initialize a SearChain object:
+
+```python
+from searchain_pack.base import SearChainPack
+
+searchain = SearChainPack(
+    data_path="data",
+    dprtokenizer_path="dpr_reader_multi",
+    dprmodel_path="dpr_reader_multi",
+    crossencoder_name_or_path="Quora_cross_encoder",
+)
+```
+
+Relevant data can be found [here](https://www.kaggle.com/datasets/anastasiajia/searchain/data). You can then run SearChain as follows:
+
+```python
+start_idx = 0
+while start_idx != -1:
+    start_idx = searchain.execute(
+        "/hotpotqa/hotpot_dev_fullwiki_v1_line.json", start_idx=start_idx
+    )
+```
diff --git a/llama-index-packs/llama-index-packs-searchain/examples/searchain.ipynb b/llama-index-packs/llama-index-packs-searchain/examples/searchain.ipynb
new file mode 100644
index 0000000000..c5ceea0813
--- /dev/null
+++ b/llama-index-packs/llama-index-packs-searchain/examples/searchain.ipynb
@@ -0,0 +1,108 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "cef7fbcf-dcc4-4986-998f-5bd6c058b348",
+   "metadata": {},
+   "source": [
+    "# An Example of a SearChain Application"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "ae02aac8-ddc3-481f-9b41-a3c52f2ad9b5",
+   "metadata": {},
+   "source": [
+    "This LlamaPack implements a short-form version of the [SearChain paper by Xu et al.](https://arxiv.org/abs/2304.14732).\n",
+    "\n",
+    "This implementation is adapted from the author's implementation. You can find the official code repository [here](https://github.com/xsc1234/Search-in-the-Chain)."
+ ] + }, + { + "cell_type": "markdown", + "id": "d500e8af-c685-4f11-b176-dd534c7824e5", + "metadata": {}, + "source": [ + "## Load Pack" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a111a49c-e9c2-4a19-96dc-136fc820bbfb", + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.core.llama_pack import download_llama_pack\n", + "\n", + "download_llama_pack(\"SearChainPack\", \"./searchain_pack\")\n", + "from searchain_pack.base import SearChainPack" + ] + }, + { + "cell_type": "markdown", + "id": "37739fc7-6df9-44e3-ac46-22471565af36", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d113b360-739e-4f29-b0db-273ba2d65e2a", + "metadata": {}, + "outputs": [], + "source": [ + "searchain = SearChainPack(\n", + " data_path=\"data\",\n", + " dprtokenizer_path=\"./model/dpr_reader_multi\",\n", + " dprmodel_path=\"./model/dpr_reader_multi\",\n", + " crossencoder_name_or_path=\"./model/Quora_cross_encoder\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "7a531d37-6832-40ab-b579-8dc007e1a1e2", + "metadata": {}, + "source": [ + "## Excute" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4dd90f1c-720a-4ef0-9a35-d1b54be8cd53", + "metadata": {}, + "outputs": [], + "source": [ + "start_idx = 0\n", + "while not start_idx == -1:\n", + " start_idx = execute(\n", + " \"/hotpotqa/hotpot_dev_fullwiki_v1_line.json\", start_idx=start_idx\n", + " )" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/llama-index-packs/llama-index-packs-searchain/llama_index/packs/searchain/BUILD b/llama-index-packs/llama-index-packs-searchain/llama_index/packs/searchain/BUILD new file mode 100644 index 0000000000..db46e8d6c9 --- /dev/null +++ b/llama-index-packs/llama-index-packs-searchain/llama_index/packs/searchain/BUILD @@ -0,0 +1 @@ +python_sources() diff --git a/llama-index-packs/llama-index-packs-searchain/llama_index/packs/searchain/__init__.py b/llama-index-packs/llama-index-packs-searchain/llama_index/packs/searchain/__init__.py new file mode 100644 index 0000000000..467b3ed046 --- /dev/null +++ b/llama-index-packs/llama-index-packs-searchain/llama_index/packs/searchain/__init__.py @@ -0,0 +1,4 @@ +from llama_index.packs.searchain.base import SearChainPack + + +__all__ = ["SearChainPack"] diff --git a/llama-index-packs/llama-index-packs-searchain/llama_index/packs/searchain/base.py b/llama-index-packs/llama-index-packs-searchain/llama_index/packs/searchain/base.py new file mode 100644 index 0000000000..f6bdf4c8ed --- /dev/null +++ b/llama-index-packs/llama-index-packs-searchain/llama_index/packs/searchain/base.py @@ -0,0 +1,244 @@ +import regex +import string +import json +from llama_index.core.llama_pack.base import BaseLlamaPack +from llama_index.core import VectorStoreIndex, SimpleDirectoryReader +from llama_index.llms.openai import OpenAI +from llama_index.core.llms import ChatMessage +from transformers import DPRReader, DPRReaderTokenizer +import torch +from typing import Any +from sentence_transformers import CrossEncoder +import time + + +def _normalize_answer(s): + def remove_articles(text): + return 
+
+    def white_space_fix(text):
+        return " ".join(text.split())
+
+    def remove_punc(text):
+        exclude = set(string.punctuation)
+        return "".join(ch for ch in text if ch not in exclude)
+
+    def lower(text):
+        return text.lower()
+
+    return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+
+def _match_or_not(prediction, ground_truth):
+    norm_predict = _normalize_answer(prediction)
+    norm_answer = _normalize_answer(ground_truth)
+    return norm_answer in norm_predict
+
+
+def _have_seen_or_not(model_cross_encoder, query_item, query_seen_list, query_type):
+    if "Unsolved" in query_type:
+        return False
+    for query_seen in query_seen_list:
+        with torch.no_grad():
+            if model_cross_encoder.predict([(query_seen, query_item)]) > 0.5:
+                return True
+    return False
+
+
+class SearChainPack(BaseLlamaPack):
+    """Simple short form SearChain pack."""
+
+    def __init__(
+        self,
+        data_path: str,
+        dprtokenizer_path: str = "/dpr_reader_multi",
+        dprmodel_path: str = "/dpr_reader_multi",
+        crossencoder_name_or_path: str = "/Quora_cross_encoder",
+        device: str = "cuda",
+        **kwargs: Any,
+    ) -> None:
+        """Init params."""
+        self.device = device
+        self.crossencoder = CrossEncoder(crossencoder_name_or_path, device=self.device)
+        self.documents = SimpleDirectoryReader(data_path).load_data()
+        self.index = VectorStoreIndex.from_documents(self.documents)
+        self.query_engine = self.index.as_query_engine()
+        self.llm = OpenAI()
+
+        self.dprtokenizer = DPRReaderTokenizer.from_pretrained(dprtokenizer_path)
+        self.dprmodel = DPRReader.from_pretrained(dprmodel_path)
+        self.dprmodel.eval()
+        self.dprmodel.to(self.device)
+
+    def _get_answer(self, query, texts, title):
+        print("texts:" + texts)
+        encoded_inputs = self.dprtokenizer(
+            questions=[query],
+            titles=[title],
+            texts=[texts],
+            return_tensors="pt",
+            max_length=510,
+        )
+        outputs = self.dprmodel(**encoded_inputs.to(self.device))
+        start_logits = outputs.start_logits
+        end_logits = outputs.end_logits
+        relevance_logits = outputs.relevance_logits
+
+        answer_start_index = outputs.start_logits.argmax()
+        answer_end_index = outputs.end_logits.argmax()
+        predict_answer_tokens = encoded_inputs.input_ids[
+            0, answer_start_index : answer_end_index + 1
+        ]
+        answer = self.dprtokenizer.decode(predict_answer_tokens)
+        return answer, relevance_logits
+
+    def _ir(self, query, query_seen_list):
+        flag_contibue_label = False
+        query_list = query.split("\n")
+        message = ""
+        for idx in range(len(query_list)):
+            query_item = query_list[idx]
+            if "Query" in query_item and "]:" in query_item:
+                temp = query_item.split("]")
+                if len(temp) < 2:
+                    continue
+                query_type = temp[0]
+                query_item = temp[1]
+                if ":" in query_item:
+                    query_item = query_item[1:]
+                if not _have_seen_or_not(
+                    self.crossencoder, query_item, query_seen_list, query_type
+                ):
+                    now_reference = {}
+                    query_seen_list.append(query_item)
+                    response = str(self.query_engine.query(query_item))
+
+                    answer, relevance_score = self._get_answer(
+                        query=query_item, texts="", title=response
+                    )
+                    now_reference["query"] = query_item
+                    now_reference["answer"] = answer
+                    now_reference["reference"] = response
+                    now_reference["ref_score"] = relevance_score
+                    now_reference["idx"] = response
+
+                    if "Unsolved" in query_type:
+                        message = "[Unsolved Query]:{}<SEP>[Answer]:{}<SEP>[Reference]:{}<SEP>".format(
+                            query_item, answer, response
+                        )
+                        flag_contibue_label = True
+                        break
+                    elif relevance_score > 1.5:
+                        answer_start_idx = idx + 1
+                        predict_answer = ""
+                        while answer_start_idx < len(query_list):
+                            if "Answer" in query_list[answer_start_idx]:
+                                predict_answer = query_list[answer_start_idx]
+                                break
+                            answer_start_idx += 1
+                        match_label = _match_or_not(
+                            prediction=predict_answer, ground_truth=answer
+                        )
+                        if match_label:
+                            continue
+                        else:
+                            message = "[Query]:{}<SEP>[Answer]:{}<SEP>[Reference]:{}<SEP>".format(
+                                query_item, answer, response
+                            )
+                            flag_contibue_label = True
+                            break
+        return message, flag_contibue_label, query_seen_list
+
+    def _extract(self, message_keys_list):
+        text = message_keys_list
+        idx = len(text)
+        while idx > 0:
+            idx = idx - 1
+            item = text[idx]
+            if item.role == "assistant" and "Final Content" in item.content:
+                list_item = item.content.split("\n")
+                for sp in list_item:
+                    if "Final Content" in sp:
+                        return item.content
+        return "Sorry, I still cannot solve this question!"
+
+    def execute(self, data_path, start_idx):
+        data = open(data_path)
+        for k, example in enumerate(data):
+            if k < start_idx:
+                continue
+            example = json.loads(example)
+            q = example["question"]
+            round_count = 0
+            message_keys_list = [
+                ChatMessage(
+                    role="user",
+                    content="""Construct a global reasoning chain for this complex [Question] : " {} " You should generate a query to the search engine based on what you already know at each step of the reasoning chain, starting with [Query]. If you know the answer for [Query], generate it starting with [Answer]. You can try to generate the final answer for the [Question] by referring to the [Query]-[Answer] pairs, starting with [Final Content]. If you don't know the answer, generate a query to search engine based on what you already know and do not know, starting with [Unsolved Query].
+                    For example:
+                    [Question]: "Where do greyhound buses that are in the birthplace of Spirit If...'s performer leave from? "
+                    [Query 1]: Who is the performer of Spirit If... ?
+                    If you don't know the answer:
+                    [Unsolved Query]: Who is the performer of Spirit If... ?
+                    If you know the answer:
+                    [Answer 1]: The performer of Spirit If... is Kevin Drew.
+                    [Query 2]: Where was Kevin Drew born?
+                    If you don't know the answer:
+                    [Unsolved Query]: Where was Kevin Drew born?
+                    If you know the answer:
+                    [Answer 2]: Toronto.
+                    [Query 3]: Where do greyhound buses in Toronto leave from?
+                    If you don't know the answer:
+                    [Unsolved Query]: Where do greyhound buses in Toronto leave from?
+                    If you know the answer:
+                    [Answer 3]: Toronto Coach Terminal.
+                    [Final Content]: The performer of Spirit If... is Kevin Drew [1]. Kevin Drew was born in Toronto [2]. Greyhound buses in Toronto leave from Toronto Coach Terminal [3]. So the final answer is Toronto Coach Terminal.
+
+                    [Question]:"Which magazine was started first Arthur’s Magazine or First for Women?"
+                    [Query 1]: When was Arthur’s Magazine started?
+                    [Answer 1]: 1844.
+                    [Query 2]: When was First for Women started?
+                    [Answer 2]: 1989
+                    [Final Content]: Arthur’s Magazine started in 1844 [1]. First for Women started in 1989 [2]. So Arthur’s Magazine was started first. So the answer is Arthur’s Magazine.
+                    [Question]: {}""".format(
+                        q, q
+                    ),
+                )
+            ]
+            feedback_answer = "continue"
+            predict_answer = ""
+            query_seen_list = []
+            while round_count < 5 and feedback_answer != "end":
+                time.sleep(0.5)
+                rsp = self.llm.chat(message_keys_list)
+                round_count += 1
+                input_str = str(rsp.message.content)
+                message_keys_list.append(
+                    ChatMessage(role="assistant", content=input_str)
+                )
+                predict_answer += input_str
+
+                message, flag_contibue_label, query_seen_list = self._ir(
+                    input_str, query_seen_list
+                )
+                if flag_contibue_label:
+                    feedback = message
+                else:
+                    feedback = "end"
+
+                if feedback == "end":
+                    break
+                # feedback format: [Query]:xxxx<SEP>[Answer]:xxxx<SEP>[Reference]:xxxx<SEP>
+                feedback_list = feedback.split("<SEP>")
+                if "Unsolved Query" not in feedback:
+                    new_prompt = """Reference: {} According to this Reference, the answer for "{}" should be "{}", you can change your answer based on the Reference and continue constructing the reasoning chain to give the final answer for [Question]:{}""".format(
+                        feedback_list[2], feedback_list[0], feedback_list[1], q
+                    )
+                else:
+                    new_prompt = """Reference: {} According to this Reference, the answer for "{}" should be "{}", you can give your answer based on the Reference and continue constructing the reasoning chain to give the final answer for [Question]:{} """.format(
+                        feedback_list[2], feedback_list[0], feedback_list[1], q
+                    )
+                message_keys_list.append(ChatMessage(role="user", content=new_prompt))
+            result = self._extract(message_keys_list)
+            print(result)
+
+        return -1
diff --git a/llama-index-packs/llama-index-packs-searchain/pyproject.toml b/llama-index-packs/llama-index-packs-searchain/pyproject.toml
new file mode 100644
index 0000000000..1a364cb9a0
--- /dev/null
+++ b/llama-index-packs/llama-index-packs-searchain/pyproject.toml
@@ -0,0 +1,59 @@
+[build-system]
+build-backend = "poetry.core.masonry.api"
+requires = ["poetry-core"]
+
+[tool.codespell]
+check-filenames = true
+check-hidden = true
+# Feel free to un-skip examples, and experimental, you will just need to
+# work through many typos (--write-changes and --interactive will help)
+skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb"
+
+[tool.llamahub]
+contains_example = false
+import_path = "llama_index.packs.searchain"
+
+[tool.llamahub.class_authors]
+SearChainPack = "DJC-GO-SOLO"
+
+[tool.mypy]
+disallow_untyped_defs = true
+# Remove venv skip when integrated with pre-commit
+exclude = ["_static", "build", "examples", "notebooks", "venv"]
+ignore_missing_imports = true
+python_version = "3.8"
+
+[tool.poetry]
+authors = ["Your Name <you@example.com>"]
+description = "llama-index packs searchain integration"
+license = "MIT"
+name = "llama-index-packs-searchain"
+packages = [{include = "llama_index/"}]
+readme = "README.md"
+version = "0.1.0"
+
+[tool.poetry.dependencies]
+python = ">=3.8.1,<4.0"
+llama-index-core = "^0.10.0"
+torch = "^2.1.2"
+transformers = "^4.38.1"
+sentence_transformers = "^2.5.1"
+
+[tool.poetry.group.dev.dependencies]
+black = {extras = ["jupyter"], version = "<=23.9.1,>=23.7.0"}
+codespell = {extras = ["toml"], version = ">=v2.2.6"}
+ipython = "8.10.0"
+jupyter = "^1.0.0"
+mypy = "0.991"
+pre-commit = "3.2.0"
+pylint = "2.15.10"
+pytest = "7.2.1"
+pytest-mock = "3.11.1"
+ruff = "0.0.292"
+tree-sitter-languages = "^1.8.0"
+types-Deprecated = ">=0.1.0"
+types-PyYAML = "^6.0.12.12"
+types-protobuf = "^4.24.0.4"
+types-redis = "4.5.5.0"
+types-requests = "2.28.11.8" # TODO: unpin when mypy>0.991
+types-setuptools = "67.1.0.0"
diff --git
a/llama-index-packs/llama-index-packs-searchain/tests/BUILD b/llama-index-packs/llama-index-packs-searchain/tests/BUILD new file mode 100644 index 0000000000..dabf212d7e --- /dev/null +++ b/llama-index-packs/llama-index-packs-searchain/tests/BUILD @@ -0,0 +1 @@ +python_tests() diff --git a/llama-index-packs/llama-index-packs-searchain/tests/__init__.py b/llama-index-packs/llama-index-packs-searchain/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/llama-index-packs/llama-index-packs-searchain/tests/test_packs_searchain.py b/llama-index-packs/llama-index-packs-searchain/tests/test_packs_searchain.py new file mode 100644 index 0000000000..4d55fb1741 --- /dev/null +++ b/llama-index-packs/llama-index-packs-searchain/tests/test_packs_searchain.py @@ -0,0 +1,7 @@ +from llama_index.core.llama_pack import BaseLlamaPack +from llama_index.packs.searchain import SearChainPack + + +def test_class(): + names_of_base_classes = [b.__name__ for b in SearChainPack.__mro__] + assert BaseLlamaPack.__name__ in names_of_base_classes -- GitLab