diff --git a/README.md b/README.md index 65a51ad6196664d6d24bdf26f25fa4f2f42ee0d8..2d2bc545912e07652c2393207126418e7d73d913 100644 --- a/README.md +++ b/README.md @@ -1,27 +1,34 @@  <p align="center"> - 🤗 <a href="https://huggingface.co/datasets/THUDM/LongBench" target="_blank">HF Repo</a> • 📃 Paper coming soon! + 🤗 <a href="https://huggingface.co/datasets/THUDM/LongBench" target="_blank">HF Repo</a> • 📃 <a href="https://arxiv.org/abs/2308.14508" target="_blank">Paper</a> </p> 阅读[中文版本](README_ZH.md). # 📖 LongBench: A Bilingual, Multitask Benchmark for Long Context Understanding -**LongBench** is the first benchmark for bilingual, multitask, and comprehensive assessment of **long context understanding** capabilities of large language models. LongBench includes different languages (Chinese and English) to provide a more comprehensive evaluation of the large models' multilingual capabilities on long contexts. In addition, LongBench is composed of six major categories and twenty different tasks, covering key long-text application scenarios such as multi-document QA, single-document QA, summarization, Few-shot learning, code completion, and synthesis tasks. +**LongBench** is the first benchmark for bilingual, multitask, and comprehensive assessment of **long context understanding** capabilities of large language models. LongBench includes different languages (Chinese and English) to provide a more comprehensive evaluation of the large models' multilingual capabilities on long contexts. In addition, LongBench is composed of six major categories and twenty one different tasks, covering key long-text application scenarios such as single-document QA, multi-document QA, summarization, few-shot learning, synthetic tasks and code completion. We are fully aware of the potentially high costs involved in the model evaluation process, especially in the context of long context scenarios (such as manual annotation costs or API call costs). Therefore, we adopt a fully automated evaluation method, aimed at measuring and evaluating the model's ability to understand long contexts at the lowest cost. -LongBench includes 13 English tasks, 5 Chinese tasks, and 2 code tasks, with the average length of most tasks ranging from 5k to 15k. For detailed statistics and construction methods of LongBench tasks, please refer [here](task.md). +LongBench includes 14 English tasks, 5 Chinese tasks, and 2 code tasks, with the average length of most tasks ranging from 5k to 15k, and a total of 4,750 test data. For detailed statistics and construction methods of LongBench tasks, please refer [here](task.md). In addition, we provide LongBench-E, a test set with a more uniform length distribution constructed by uniform sampling, with comparable amounts of data in the 0-4k, 4k-8k, and 8k+ length intervals to provide an analysis of the model's performance variations at different input lengths. + | Task Type | \#English Task | \#Chinese Task | \#Code Task | | :-------: | :--------------------: | :--------------------: | :------------------: | | Multi-document QA | 3 | 1 | - | | Single-document QA | 3 | 1 | - | -| Summarization | 2 | 1 | - | +| Summarization | 3 | 1 | - | | Few-shot learning | 3 | 1 | - | | Synthetic Tasks | 2 | 1 | - | | Code Completion | - | - | 2 | +## 🔥 Updates +**[2023/08/29]** The [LongBench paper](https://arxiv.org/abs/2308.14508) is released, along with several important updates to LongBench: +1. 
**More comprehensive datasets**: The MultiNews dataset for multi-document summarization is added to the summarization tasks, and the summarization task SAMSum is added to the Few-shot learning tasks, replacing the previous QA task NQ. TriviaQA and RepoBench-P are resampled to ensure a more appropriate data length; +2. **More uniformed length distribution**: LongBench-E is obtained by uniform sampling according to length, featuring a comparable amount of test data in the length intervals of 0-4k, 4-8k, and 8k+, which is more suitable for evaluating the model's ability in different input lengths variation; +3. **All evaluation codes made public**: The code for evaluating all models has been made public, and the code for retrieval-based and summarization-based long context compression strategies are also provided. + ## 🔍 Table of Contents - [🖥️ Leaderboard](#leaderboard) - [⚙️ How to evaluate on LongBench](#how-to-evaluate-on-LongBench) @@ -38,35 +45,35 @@ Here is the average scores (%) on the main task categories in both Chinese and E #### English | | Avg | Single-Doc QA | Multi-Doc QA | Summarization | Few-shot Learning | Code Completion | Synthetic Tasks | | ----------------- | :--: | :-----------: | :----------: | :-----------: | :---------------: | :-------------: | :-------------: | -| GPT-3.5-Turbo-16k | 45.5 | 39.8 | 38.7 | 26.5 | 76.0 | 54.5 | 37.8 | -| Llama2-7B-chat-4k | 29.0 | 24.8 | 21.4 | 23.9 | 50.5 | 47.3 | 5.9 | -| LongChat-7B-16k | 33.7 | 29.3 | 16.1 | 25.8 | 59.9 | 57.0 | 14.2 | -| XGen-7B-8k | 28.7 | 24.5 | 20.4 | 24.8 | 58.7 | 38.0 | 5.6 | -| InternLM-7B-8k | 24.7 | 17.1 | 20.8 | 13.3 | 52.7 | 39.7 | 4.7 | -| ChatGLM2-6B | 26.0 | 23.1 | 15.0 | 22.9 | 46.1 | 46.1 | 2.7 | -| ChatGLM2-6B-32k | 42.7 | 32.8 | 34.0 | 28.6 | 68.1 | 52.7 | 39.8 | +| GPT-3.5-Turbo-16k | 44.0 | 39.8 | 38.7 | 26.5 | 67.1 | 54.1 | 37.8 | +| Llama2-7B-chat-4k | 31.0 | 24.9 | 22.6 | 24.7 | 60.0 | 48.1 | 5.9 | +| LongChat-v1.5-7B-32k | 34.3 | 28.7 | 20.6 | 26.7 | 60.0 | 54.1 | 15.8 | +| XGen-7B-8k | 28.3 | 24.6 | 20.4 | 24.7 | 56.2 | 38.6 | 5.3 | +| InternLM-7B-8k | 24.2 | 17.4 | 20.2 | 16.1 | 50.3 | 36.4 | 4.5 | +| ChatGLM2-6B | 26.6 | 23.1 | 16.2 | 23.2 | 48.2 | 46.1 | 2.8 | +| ChatGLM2-6B-32k | 40.9 | 32.9 | 33.7 | 27.6 | 59.1 | 52.7 | 39.2 | +| Vicuna-v1.5-7B-16k | 31.9 | 28.0 | 18.6 | 26.0 | 66.2 | 47.3 | 5.5 | #### Chinese | | Avg | Single-Doc QA | Multi-Doc QA | Summarization | Few-shot Learning | Code Completion | Synthetic Tasks | | ----------------- | :--: | :-----------: | :----------: | :-----------: | :---------------: | :-------------: | :-------------: | -| GPT-3.5-Turbo-16k | 44.5 | 61.2 | 28.7 | 16.0 | 29.2 | 54.5 | 77.5 | -| Llama2-7B-chat-4k | 13.5 | 11.6 | 1.9 | 0.2 | 19.8 | 47.3 | 0.5 | -| LongChat-7B-16k | 23.7 | 26.6 | 19.1 | 14.0 | 20.8 | 57.0 | 4.8 | -| XGen-7B-8k | 14.5 | 14.2 | 9.1 | 1.5 | 20.0 | 38.0 | 4.2 | -| InternLM-7B-8k | 18.6 | 33.3 | 8.9 | 13.0 | 15.5 | 39.7 | 0.9 | -| ChatGLM2-6B | 22.5 | 33.0 | 15.2 | 14.6 | 20.5 | 46.1 | 5.5 | -| ChatGLM2-6B-32k | 41.3 | 52.0 | 34.3 | 16.3 | 29.9 | 52.7 | 62.5 | - -#### Radar Chart on Long Context Capability - +| GPT-3.5-Turbo-16k | 44.5 | 61.2 | 28.7 | 16.0 | 29.2 | 54.1 | 77.5 | +| Llama2-7B-chat-4k | 14.3 | 11.9 | 5.2 | 0.2 | 19.8 | 48.1 | 0.5 | +| LongChat-v1.5-7B-32k | 23.9 | 29.1 | 19.5 | 9.9 | 23.2 | 54.1 | 7.6 | +| XGen-7B-8k | 15.1 | 14.8 | 11.0 | 2.2 | 20.5 | 38.6 | 3.5 | +| InternLM-7B-8k | 18.3 | 33.6 | 11.1 | 12.4 | 15.2 | 36.4 | 0.9 | +| ChatGLM2-6B | 22.9 | 33.2 | 16.3 | 14.5 | 20.8 | 46.1 | 6.5 | +| ChatGLM2-6B-32k 
| 41.7 | 51.6 | 37.6 | 16.2 | 27.7 | 52.7 | 64.5 | +| Vicuna-v1.5-7B-16k | 26.4 | 43.0 | 19.3 | 15.1 | 28.8 | 47.3 | 5.0 | + +#### Radar Chart on Long Context Capability  #### Variation of Abilities under Different Context Lengths -To more specifically analyze the models' relative performance under different context lengths, the following chart shows the average relative scores on all tasks over different context length intervals. - +To specifically analyze the model's performance under different context lengths, the following chart shows the models' total scores averaged across all tasks by task category over different context length intervals in LongBench-E. -> Note: Assume that the model scores x on the data within a specific length range of a task, and y on all data of that task, then the model's **relative score** for that length range is (x/y-1). To better compare the trends of different models, we shift all the lines to 0 on 0-4k. + <a name="how-to-evaluate-on-LongBench"></a> ## ⚙️ How to evaluate on LongBench @@ -76,18 +83,28 @@ You can download and load the **LongBench** data through the Hugging Face datase ```python from datasets import load_dataset -datasets = ["hotpotqa", "2wikimqa", "musique", "dureader", "narrativeqa", "qasper", "multifieldqa_en", \ - "multifieldqa_zh", "gov_report", "qmsum", "vcsum", "trec", "nq", "triviaqa", "lsht", "passage_count", \ - "passage_retrieval_en", "passage_retrieval_zh", "lcc", "repobench-p"] +datasets = ["narrativeqa", "qasper", "multifieldqa_en", "multifieldqa_zh", "hotpotqa", "2wikimqa", "musique", \ + "dureader", "gov_report", "qmsum", "multi_news", "vcsum", "trec", "triviaqa", "samsum", "lsht", \ + "passage_count", "passage_retrieval_en", "passage_retrieval_zh", "lcc", "repobench-p"] for dataset in datasets: data = load_dataset('THUDM/LongBench', dataset, split='test') ``` +Similarly, you can load the **LongBench-E** data +```python +from datasets import load_dataset + +datasets = ["qasper", "multifieldqa_en", "hotpotqa", "2wikimqa", "gov_report", "multi_news", "trec", \ + "triviaqa", "samsum", "passage_count", "passage_retrieval_en", "lcc", "repobench-p"] + +for dataset in datasets: + data = load_dataset('THUDM/LongBench', f"{dataset}_e", split='test') +``` Alternatively, you can download the folder from [this link](https://huggingface.co/datasets/THUDM/LongBench/resolve/main/data.zip) to load the data. #### Data Format -All data in **LongBench** are standardized to the following format: +All data in **LongBench** (LongBench-E) are standardized to the following format: ```json { @@ -103,15 +120,25 @@ All data in **LongBench** are standardized to the following format: ``` #### Evaluation -Install the requirements with pip: `pip install -r requirements.txt`. We provide an evaluation code using ChatGLM2-6B as an example. First, run the [pred.py](pred.py) under the repository: +Install the requirements with pip: `pip install -r requirements.txt`. For Llama-2 based models, we recommend using Flash Attention for optimization and saving GPU memory The relevant dependencies can be installed according to the code base of [Flash Attention](https://github.com/Dao-AILab/flash-attention). + +First, run [pred.py](pred.py) and select the model you want to evaluate via `--model`. 
Let's take ChatGLM2-6B-32k as an example (the HuggingFace model weights will be downloaded automatically according to the path in [model2path.json](config/model2path.json); you can change the path in this file to load the model weights locally): +```bash +CUDA_VISIBLE_DEVICES=0 python pred.py --model chatglm2-6b-32k +``` +You can obtain the model's output on all LongBench datasets in the `pred/` folder corresponding to the model name. Similarly, with the `--e` flag: ```bash -CUDA_VISIBLE_DEVICES=0 python pred.py +CUDA_VISIBLE_DEVICES=0 python pred.py --model chatglm2-6b-32k --e ``` -You can get the model outputs on all datasets in the `pred/` folder. After that, run the evaluation code in [eval.py](eval.py): +You can obtain the output on LongBench-E in the `pred_e/` folder. After that, run the evaluation code in [eval.py](eval.py): ```bash -python eval.py +python eval.py --model chatglm2-6b-32k ``` -You can get the evaluation results on all datasets in `result.json`. Please note that in `config/`, we provide the input format suitable for each dataset and the maximum output length. Feel free to modify them to better suit the model you want to evaluate. After modification, when evaluating with [pred.py](pred.py), the data will be automatically organized according to the new format to get the corresponding model output. +You can get the evaluation results on all datasets in `result.json`. The average score of the model over the different length intervals of all LongBench-E datasets can be obtained with the `--e` flag. + +Please note that in `config/`, we provide the input format suitable for each dataset and the maximum output length. Feel free to modify them to better suit the model you want to evaluate. After modification, when evaluating with [pred.py](pred.py), the data will be automatically organized according to the new format to get the corresponding model output. + +In addition, we provide the code for the long context compression evaluation based on retrieval and summarization (see Section 4.2 of the LongBench paper for implementation details) in the folders `retrieval/` and `summ/`, respectively. 
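+For a quick look at the numbers, the `result.json` written by [eval.py](eval.py) simply maps each dataset name to its score (or, with `--e`, to a dict of scores per length interval). Below is a minimal sketch for printing these results, assuming the default output paths used in this repository and the `chatglm2-6b-32k` example above:
```python
import json

model = "chatglm2-6b-32k"  # same example model as above
for folder in ("pred", "pred_e"):  # standard LongBench vs. LongBench-E results
    path = f"{folder}/{model}/result.json"
    try:
        with open(path, encoding="utf-8") as f:
            scores = json.load(f)
    except FileNotFoundError:
        continue  # skip variants that have not been evaluated yet
    print(f"== {path} ==")
    for dataset, score in sorted(scores.items()):
        # without --e the value is a single score; with --e it is a dict over
        # the length intervals "0-4k", "4-8k" and "8k+"
        print(f"{dataset}: {score}")
```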
<a name="evaluation-result-on-each-dataset"></a> ## 📊 Evaluation Result on Each Dataset @@ -121,81 +148,90 @@ The following tables show the Zero-shot evaluation results (%) on all datasets, #### Single-Document QA | | NarrativeQA | Qasper | MultiFieldQA-en | MultiFieldQA-zh | | ----------------- | :---------: | :----: | :-------------: | :-------------: | -| GPT-3.5-Turbo-16k | 23.6 | 43.3 | 52.3 | 61.2 | -| Llama2-7B-chat-4k | 19.1 | 19.6 | 35.8 | 11.6 | -| LongChat-7B-16k | 21.6 | 21.6 | 44.6 | 26.6 | -| XGen-7B-8k | 17.9 | 18.3 | 37.2 | 14.2 | -| InternLM-7B-8k | 12.4 | 16.8 | 22.3 | 33.3 | -| ChatGLM2-6B | 11.2 | 23.7 | 34.2 | 33.0 | -| ChatGLM2-6B-32k | 20.4 | 32.2 | 45.7 | 52.0 | +| GPT-3.5-Turbo-16k | 23.6 | 43.3 | 52.3 | 61.2 | +| Llama2-7B-chat-4k | 18.7 | 19.2 | 36.8 | 11.9 | +| LongChat-v1.5-7B-32k | 16.9 | 27.7 | 41.4 | 29.1 | +| XGen-7B-8k | 18.0 | 18.1 | 37.7 | 14.8 | +| InternLM-7B-8k | 12.1 | 16.7 | 23.4 | 33.6 | +| ChatGLM2-6B | 11.8 | 22.5 | 35.0 | 33.2 | +| ChatGLM2-6B-32k | 21.1 | 31.5 | 46.2 | 51.6 | +| Vicuna-v1.5-7B-16k | 19.4 | 26.1 | 38.5 | 43.0 | #### Multi-Document QA - | | HotpotQA | 2WikiMQA | Musique | DuReader (zh) | | ----------------- | :------: | :------: | :-----: | :-----------: | -| GPT-3.5-Turbo-16k | 51.6 | 37.7 | 26.9 | 28.7 | -| Llama2-7B-chat-4k | 24.3 | 31.4 | 8.6 | 1.9 | -| LongChat-7B-16k | 22.4 | 16.8 | 9.1 | 19.1 | -| XGen-7B-8k | 28.3 | 21.5 | 11.5 | 9.1 | -| InternLM-7B-8k | 27.9 | 24.0 | 10.3 | 8.9 | -| ChatGLM2-6B | 20.2 | 19.6 | 5.3 | 15.2 | -| ChatGLM2-6B-32k | 44.9 | 34.9 | 22.2 | 34.3 | +| GPT-3.5-Turbo-16k | 51.6 | 37.7 | 26.9 | 28.7 | +| Llama2-7B-chat-4k | 25.4 | 32.8 | 9.4 | 5.2 | +| LongChat-v1.5-7B-32k | 31.5 | 20.6 | 9.7 | 19.5 | +| XGen-7B-8k | 29.7 | 21.1 | 10.3 | 11.0 | +| InternLM-7B-8k | 28.7 | 22.8 | 9.0 | 11.1 | +| ChatGLM2-6B | 22.4 | 20.1 | 6.1 | 16.3 | +| ChatGLM2-6B-32k | 45.1 | 34.0 | 21.9 | 37.6 | +| Vicuna-v1.5-7B-16k | 25.3 | 20.8 | 9.8 | 19.3 | #### Summarization - -| | GovReport | QMSum | VCSUM (zh) | -| :---------------- | :-------: | :---: | :--------: | -| GPT-3.5-Turbo-16k | 29.5 | 23.4 | 16.0 | -| Llama2-7B-chat-4k | 27.3 | 20.6 | 0.2 | -| LongChat-7B-16k | 28.4 | 23.2 | 14.0 | -| XGen-7B-8k | 27.8 | 21.7 | 1.5 | -| InternLM-7B-8k | 9.8 | 16.8 | 13.0 | -| ChatGLM2-6B | 23.7 | 22.2 | 14.6 | -| ChatGLM2-6B-32k | 33.3 | 23.9 | 16.3 | +| | GovReport | QMSum | MultiNews | VCSUM (zh) | +|:-----------|:---------:|:-----:|:-----:|:-----:| +| GPT-3.5-Turbo-16k | 29.5 | 23.4 | 26.7 | 16.0 | +| Llama2-7B-chat-4k | 27.3 | 20.8 | 25.8 | 0.2 | +| LongChat-v1.5-7B-32k | 30.8 | 22.7 | 26.4 | 9.9 | +| XGen-7B-8k | 27.3 | 20.5 | 26.2 | 2.2 | +| InternLM-7B-8k | 9.7 | 15.9 | 22.8 | 12.4 | +| ChatGLM2-6B | 23.2 | 21.1 | 25.2 | 14.5 | +| ChatGLM2-6B-32k | 32.4 | 24.0 | 26.5 | 16.2 | +| Vicuna-v1.5-7B-16k | 27.9 | 22.8 | 27.2 | 15.1 | #### Few-shot Learning - -| | TREC | NQ | TriviaQA | LSHT (zh) | -| ----------------- | :--: | :--: | :------: | :-------: | -| GPT-3.5-Turbo-16k | 68.0 | 73.0 | 87.1 | 29.2 | -| Llama2-7B-chat-4k | 60.5 | 31.4 | 59.7 | 19.8 | -| LongChat-7B-16k | 61.5 | 44.8 | 73.5 | 20.8 | -| XGen-7B-8k | 66.0 | 43.2 | 67.0 | 20.0 | -| InternLM-7B-8k | 49.0 | 47.6 | 61.6 | 15.5 | -| ChatGLM2-6B | 44.0 | 34.5 | 59.8 | 20.5 | -| ChatGLM2-6B-32k | 62.0 | 64.9 | 77.6 | 29.9 | - -#### Code Completion - -| | LCC | RepoBench-P | -| ----------------- | :--: | :---------: | -| GPT-3.5-Turbo-16k | 54.7 | 54.3 | -| Llama2-7B-chat-4k | 52.3 | 42.4 | -| LongChat-7B-16k | 59.2 | 54.7 | -| XGen-7B-8k | 38.8 | 37.3 | -| 
InternLM-7B-8k | 45.5 | 34.0 | -| ChatGLM2-6B | 48.4 | 43.7 | -| ChatGLM2-6B-32k | 55.4 | 50.0 | +| | TREC | TriviaQA | SAMSum | LSHT (zh) | +| --- | :-: | :-: | :-: | :-: | +| GPT-3.5-Turbo-16k | 68.0 | 91.4 | 41.7 | 29.2 | +| Llama2-7B-chat-4k | 61.5 | 77.8 | 40.7 | 19.8 | +| LongChat-v1.5-7B-32k | 63.5 | 82.3 | 34.2 | 23.2 | +| XGen-7B-8k | 65.5 | 77.8 | 25.3 | 20.5 | +| InternLM-7B-8k | 52.0 | 77.8 | 21.2 | 15.2 | +| ChatGLM2-6B | 44.5 | 70.6 | 29.5 | 20.8 | +| ChatGLM2-6B-32k | 62.5 | 78.7 | 36.3 | 27.7 | +| Vicuna-v1.5-7B-16k | 71.5 | 86.2 | 40.8 | 28.8 | #### Synthetic Tasks +| | Passage Count | PassageRetrieval-en | PassageRetrieval-zh | +| --- | :-: | :-: | :-: | +| GPT-3.5-Turbo-16k | 4.5 | 71.0 | 77.5 | +| Llama2-7B-chat-4k | 2.1 | 9.8 | 0.5 | +| LongChat-v1.5-7B-32k | 1.0 | 30.5 | 7.6 | +| XGen-7B-8k | 2.1 | 8.5 | 3.5 | +| InternLM-7B-8k | 3.0 | 6.0 | 0.9 | +| ChatGLM2-6B | 2.5 | 3.0 | 6.5 | +| ChatGLM2-6B-32k | 1.5 | 77.0 | 64.5 | +| Vicuna-v1.5-7B-16k | 6.5 | 4.5 | 5.0 | -| | PassageRetrieval-en | Passage Count | PassageRetrieval-zh | -| ----------------- | :-----------------: | :-----------: | :-----------------: | -| GPT-3.5-Turbo-16k | 71.0 | 4.5 | 77.5 | -| Llama2-7B-chat-4k | 9.2 | 2.5 | 0.5 | -| LongChat-7B-16k | 24.0 | 4.5 | 4.8 | -| XGen-7B-8k | 9.0 | 2.2 | 4.2 | -| InternLM-7B-8k | 6.5 | 2.9 | 0.9 | -| ChatGLM2-6B | 3.2 | 2.1 | 5.5 | -| ChatGLM2-6B-32k | 77.5 | 2.0 | 62.5 | +#### Code Completion +| | LCC | RepoBench-P | +| --- | :-: | :-: | +| GPT-3.5-Turbo-16k | 54.7 | 53.6 | +| Llama2-7B-chat-4k | 52.4 | 43.8 | +| LongChat-v1.5-7B-32k | 53.0 | 55.3 | +| XGen-7B-8k | 38.6 | 38.6 | +| InternLM-7B-8k | 44.1 | 28.8 | +| ChatGLM2-6B | 49.0 | 43.2 | +| ChatGLM2-6B-32k | 55.6 | 49.9 | +| Vicuna-v1.5-7B-16k | 51.0 | 43.5 | <a name="acknowledgement"></a> ## 📄 Acknowledgement -- Some of the tasks of **LongBench** are based on the datasets proposed by previous researchers, including [HotpotQA](https://hotpotqa.github.io/), [2WikiMultihopQA](https://aclanthology.org/2020.coling-main.580/), [Musique](https://arxiv.org/abs/2108.00573), [DuReader](https://github.com/baidu/DuReader), [NarrativeQA](https://arxiv.org/pdf/1712.07040.pdf), [Qasper](https://arxiv.org/pdf/2105.03011.pdf), [GovReport](https://arxiv.org/pdf/2104.02112.pdf), [QMSum](https://arxiv.org/pdf/2104.05938.pdf), [VCSUM](https://arxiv.org/abs/2305.05280), [TriviaQA](https://nlp.cs.washington.edu/triviaqa/), [NQ](https://ai.google.com/research/NaturalQuestions/), [TREC](https://aclanthology.org/C02-1150.pdf), [LSHT](http://tcci.ccf.org.cn/conference/2014/dldoc/evatask6.pdf), [LCC](https://arxiv.org/abs/2306.14893) and [RepoBench-P](https://arxiv.org/abs/2306.03091). 
+- Some of the tasks of LongBench are based on the datasets proposed by previous researchers, including [HotpotQA](https://hotpotqa.github.io/), [2WikiMultihopQA](https://aclanthology.org/2020.coling-main.580/), [MuSiQue](https://arxiv.org/abs/2108.00573), [DuReader](https://github.com/baidu/DuReader), [NarrativeQA](https://arxiv.org/pdf/1712.07040.pdf), [Qasper](https://arxiv.org/pdf/2105.03011.pdf), [GovReport](https://arxiv.org/pdf/2104.02112.pdf), [QMSum](https://arxiv.org/pdf/2104.05938.pdf), [MultiNews](https://aclanthology.org/P19-1102.pdf),[VCSUM](https://arxiv.org/abs/2305.05280), [TriviaQA](https://nlp.cs.washington.edu/triviaqa/), [TREC](https://aclanthology.org/C02-1150.pdf), [SAMSum](https://aclanthology.org/D19-5409.pdf),[LSHT](http://tcci.ccf.org.cn/conference/2014/dldoc/evatask6.pdf), [LCC](https://arxiv.org/abs/2306.14893) and [RepoBench-P](https://arxiv.org/abs/2306.03091). <a name="citation"></a> ## 📝 Citation -This is a joint work by **THU-KEG** and **Zhipu AI**. We are currently working on the paper, and the citation information will be updated when it's ready. Please stay tuned~ - -When citing our work, please cite all of the original dataset papers. The relevant citation information is listed [here](refs/ref.bib). +``` +@misc{bai2023longbench, + title={LongBench: A Bilingual, Multitask Benchmark for Long Context Understanding}, + author={Yushi Bai and Xin Lv and Jiajie Zhang and Hongchang Lyu and Jiankai Tang and Zhidian Huang and Zhengxiao Du and Xiao Liu and Aohan Zeng and Lei Hou and Yuxiao Dong and Jie Tang and Juanzi Li}, + year={2023}, + eprint={2308.14508}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} +``` +When citing our work, please kindly consider citing the original dataset papers. The relevant citation information is listed [here](refs/ref.bib). diff --git a/README_ZH.md b/README_ZH.md index d0d71564dc810a4571ef41552d620827510e708a..d10c701d3a2678904e3a57a7a86788aba618771f 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -1,27 +1,34 @@  <p align="center"> - 🤗 <a href="https://huggingface.co/datasets/THUDM/LongBench" target="_blank">HF Repo</a> • 📃 Paper coming soon! + 🤗 <a href="https://huggingface.co/datasets/THUDM/LongBench" target="_blank">HF Repo</a> • 📃 <a href="https://arxiv.org/abs/2308.14508" target="_blank">Paper</a> </p> Read this in [English](README.md). # 📖 LongBench: 多任务中英双语长文本理解评测基准 -**LongBench**是第一个多任务、中英双语、针对大语言模型**长文本理解能力**的评测基准。在目前大模型多语言能力引起广泛关注的背景下,LongBench涵盖了不同的语言(中文和英文),以此来对大模型在长文本下的多语言能力进行更全面的评估。同时,LongBench由六大类、二十个不同的任务组成,覆盖了单文档QA、多文档QA、摘要、Few-shot学习、代码补全和合成任务等关键的长文本应用场景。 +**LongBench**是第一个多任务、中英双语、针对大语言模型**长文本理解能力**的评测基准。在目前大模型多语言能力引起广泛关注的背景下,LongBench涵盖了不同的语言(中文和英文),以此来对大模型在长文本下的多语言能力进行更全面的评估。同时,LongBench由六大类、二十一个不同的任务组成,覆盖了单文档QA、多文档QA、摘要、Few-shot学习、合成任务和代码补全等关键的长文本应用场景。 我们深知模型评测过程中可能产生的高昂成本,尤其是长文本场景下(如人工标注成本或API调用成本)。因此,我们采用了一种全自动的评测方式,旨在以最低的成本,最有效地衡量和评估模型的长文本理解能力。 -LongBench包含13个英文任务、5个中文任务和2个代码任务,多数任务的平均长度在5k-15k之间,共包含约4500条测试数据。关于LongBench数据集的具体统计及任务构造方式请参考[这里](task_zh.md)。 +LongBench包含14个英文任务、5个中文任务和2个代码任务,多数任务的平均长度在5k-15k之间,共包含4750条测试数据。关于LongBench数据集的具体统计及任务构造方式请参考[这里](task_zh.md)。此外,我们还通过均匀采样得到了长度分布更均匀的测试集合LongBench-E,在0-4k、4k-8k、8k+长度区间内的数据量相当,以提供模型在不同长度下性能变化的分析。 + | 任务类型 | 英文任务数 | 中文任务数 | 代码任务数 | | :----------: | :--------: | :--------: | :--------: | | 单文档QA | 3 | 1 | - | | 多文档QA | 3 | 1 | - | -| 摘要 | 2 | 1 | - | +| 摘要 | 3 | 1 | - | | Few-shot学习 | 3 | 1 | - | | 合成任务 | 2 | 1 | - | | 代码补全 | - | - | 2 | +## 🔥 更新信息 +**[2023/08/29]** [LongBench论文](https://arxiv.org/abs/2308.14508)发布,同时对LongBench进行了以下几项重要更新: +1. 
**更全面的数据集**:在摘要任务中增加了多文档摘要MultiNews数据集,在Few-shot学习任务中增加了摘要任务SAMSum,代替之前的QA任务NQ,并对TriviaQA, RepoBench-P进行重新采样以保证数据长度更加合适; +2. **更均匀的长度分布**:根据长度进行均匀采样得到了LongBench-E,其包含LongBench中的13个长度分布更加均匀的英文数据集,LongBench-E在0-4k,4-8k,8k+长度区间内均有数量相当的测试数据,更加适合评价模型在不同输入长度上的能力变化; +3. **全部评测代码公开**:评测所有模型的代码已公开,同时提供了基于检索、分段摘要的长文本压缩策略代码。 + ## 🔍 目录 - [🖥️ 排行榜](#排行榜) - [⚙️ 如何在LongBench上评测模型](#如何在LongBench上评测模型) @@ -38,33 +45,34 @@ LongBench包含13个英文任务、5个中文任务和2个代码任务,多数 #### 英文榜单 | | Avg | 单文档QA | 多文档QA | 摘要 | Few-shot学习 | 代码补全 | 合成任务 | | --- | :-: | :-: | :-: | :-: | :-: | :-: | :-: | -| GPT-3.5-Turbo-16k | 45.5 | 39.8 | 38.7 | 26.5 | 76.0 | 54.5 | 37.8 | -| Llama2-7B-chat-4k | 29.0 | 24.8 | 21.4 | 23.9 | 50.5 | 47.3 | 5.9 | -| LongChat-7B-16k | 33.7 | 29.3 | 16.1 | 25.8 | 59.9 | 57.0 | 14.2 | -| XGen-7B-8k | 28.7 | 24.5 | 20.4 | 24.8 | 58.7 | 38.0 | 5.6 | -| InternLM-7B-8k | 24.7 | 17.1 | 20.8 | 13.3 | 52.7 | 39.7 | 4.7 | -| ChatGLM2-6B | 26.0 | 23.1 | 15.0 | 22.9 | 46.1 | 46.1 | 2.7 | -| ChatGLM2-6B-32k | 42.7 | 32.8 | 34.0 | 28.6 | 68.1 | 52.7 | 39.8 | +| GPT-3.5-Turbo-16k | 44.0 | 39.8 | 38.7 | 26.5 | 67.1 | 54.1 | 37.8 | +| Llama2-7B-chat-4k | 31.0 | 24.9 | 22.6 | 24.7 | 60.0 | 48.1 | 5.9 | +| LongChat-v1.5-7B-32k | 34.3 | 28.7 | 20.6 | 26.7 | 60.0 | 54.1 | 15.8 | +| XGen-7B-8k | 28.3 | 24.6 | 20.4 | 24.7 | 56.2 | 38.6 | 5.3 | +| InternLM-7B-8k | 24.2 | 17.4 | 20.2 | 16.1 | 50.3 | 36.4 | 4.5 | +| ChatGLM2-6B | 26.6 | 23.1 | 16.2 | 23.2 | 48.2 | 46.1 | 2.8 | +| ChatGLM2-6B-32k | 40.9 | 32.9 | 33.7 | 27.6 | 59.1 | 52.7 | 39.2 | +| Vicuna-v1.5-7B-16k | 31.9 | 28.0 | 18.6 | 26.0 | 66.2 | 47.3 | 5.5 | #### 中文榜单 | | Avg | 单文档QA | 多文档QA | 摘要 | Few-shot学习 | 代码补全 | 合成任务 | |-------|:---:|:-------------:|:------------:|:-------------:|:-----------------:|:---------------:|:----------------:| -| GPT-3.5-Turbo-16k | 44.5 | 61.2 | 28.7 | 16.0 | 29.2 | 54.5 | 77.5 | -| Llama2-7B-chat-4k | 13.5 | 11.6 | 1.9 | 0.2 | 19.8 | 47.3 | 0.5 | -| LongChat-7B-16k | 23.7 | 26.6 | 19.1 | 14.0 | 20.8 | 57.0 | 4.8 | -| XGen-7B-8k | 14.5 | 14.2 | 9.1 | 1.5 | 20.0 | 38.0 | 4.2 | -| InternLM-7B-8k | 18.6 | 33.3 | 8.9 | 13.0 | 15.5 | 39.7 | 0.9 | -| ChatGLM2-6B | 22.5 | 33.0 | 15.2 | 14.6 | 20.5 | 46.1 | 5.5 | -| ChatGLM2-6B-32k | 41.3 | 52.0 | 34.3 | 16.3 | 29.9 | 52.7 | 62.5 | +| GPT-3.5-Turbo-16k | 44.5 | 61.2 | 28.7 | 16.0 | 29.2 | 54.1 | 77.5 | +| Llama2-7B-chat-4k | 14.3 | 11.9 | 5.2 | 0.2 | 19.8 | 48.1 | 0.5 | +| LongChat-v1.5-7B-32k | 23.9 | 29.1 | 19.5 | 9.9 | 23.2 | 54.1 | 7.6 | +| XGen-7B-8k | 15.1 | 14.8 | 11.0 | 2.2 | 20.5 | 38.6 | 3.5 | +| InternLM-7B-8k | 18.3 | 33.6 | 11.1 | 12.4 | 15.2 | 36.4 | 0.9 | +| ChatGLM2-6B | 22.9 | 33.2 | 16.3 | 14.5 | 20.8 | 46.1 | 6.5 | +| ChatGLM2-6B-32k | 41.7 | 51.6 | 37.6 | 16.2 | 27.7 | 52.7 | 64.5 | +| Vicuna-v1.5-7B-16k | 26.4 | 43.0 | 19.3 | 15.1 | 28.8 | 47.3 | 5.0 | #### 长文本任务能力雷达图  #### 不同长度文本下的能力变化 -为了更有针对性地分析模型在不同文本长度下的相对表现,下图展示了模型在不同文本长度区间上,所有任务上的平均相对分数。 - +为了更有针对性地分析模型在不同文本长度下的表现,下图展示了模型在LongBench-E中不同文本长度区间上,所有任务上按照任务类别进行平均的总分。 -> 注:假设模型在某个任务的特定长度范围内数据上得分为x,在该任务所有数据上得分为y,则模型在该长度范围的**相对分数**为(x/y-1)。为了更好比较不同模型的变化趋势,我们在0-4k将所有折线平移至0。 + <a name="如何在LongBench上评测模型"></a> ## ⚙️ 如何在LongBench上评测模型 @@ -74,17 +82,27 @@ LongBench包含13个英文任务、5个中文任务和2个代码任务,多数 ```python from datasets import load_dataset -datasets = ["hotpotqa", "2wikimqa", "musique", "dureader", "narrativeqa", "qasper", "multifieldqa_en", \ - "multifieldqa_zh", "gov_report", "qmsum", "vcsum", "trec", "nq", "triviaqa", "lsht", "passage_count", \ - "passage_retrieval_en", "passage_retrieval_zh", "lcc", "repobench-p"] +datasets = 
["narrativeqa", "qasper", "multifieldqa_en", "multifieldqa_zh", "hotpotqa", "2wikimqa", "musique", \ + "dureader", "gov_report", "qmsum", "multi_news", "vcsum", "trec", "triviaqa", "samsum", "lsht", \ + "passage_count", "passage_retrieval_en", "passage_retrieval_zh", "lcc", "repobench-p"] for dataset in datasets: data = load_dataset('THUDM/LongBench', dataset, split='test') ``` +类似地,也可以载入**LongBench-E**的数据 +```python +from datasets import load_dataset + +datasets = ["qasper", "multifieldqa_en", "hotpotqa", "2wikimqa", "gov_report", "multi_news", "trec", \ + "triviaqa", "samsum", "passage_count", "passage_retrieval_en", "lcc", "repobench-p"] + +for dataset in datasets: + data = load_dataset('THUDM/LongBench', f"{dataset}_e", split='test') +``` 同样地,你也可以直接用这个[链接](https://huggingface.co/datasets/THUDM/LongBench/resolve/main/data.zip)下载所有的评测数据。 #### 数据格式 -**LongBench**中所有数据都统一为以下格式: +**LongBench**(LongBench-E)中所有数据都统一为以下格式: ```json { "input": "任务的输入/指令,通常较短,比如QA中的问题、Few-shot任务中的提问等", @@ -99,15 +117,25 @@ for dataset in datasets: ``` #### 评测 -通过pip安装依赖:`pip install -r requirements.txt`。我们以ChatGLM2-6B为例提供了一份评测代码。首先,运行仓库下的[pred.py](pred.py) +通过pip安装依赖:`pip install -r requirements.txt`。对于基于Llama-2的模型,我们推荐使用Flash Attention进行优化并节省显存,可以根据[Flash Attention](https://github.com/Dao-AILab/flash-attention)的代码库来安装相关依赖。 + +首先,运行仓库下的[pred.py](pred.py),并通过`--model`选择你想评测的模型,我们以ChatGLM2-6B-32k模型为例(代码将会根据[model2path.json](config/model2path.json)中的路径自动下载HuggingFace模型,你可以修改此文件中的路径以从本地载入模型参数): +```bash +CUDA_VISIBLE_DEVICES=0 python pred.py --model chatglm2-6b-32k +``` +可以在`pred/`对应模型名称的文件夹下得到模型在LongBench所有数据集下的输出,类似地,通过`--e`命令: ```bash -CUDA_VISIBLE_DEVICES=0 python pred.py +CUDA_VISIBLE_DEVICES=0 python pred.py --model chatglm2-6b-32k --e ``` -可以在`pred/`文件夹下得到模型在所有数据集下的输出,此后运行[eval.py](eval.py)的评测代码: +可以在`pred_e/`对应模型名称的文件夹下得到模型在LongBench-E所有数据集下的输出。此后运行[eval.py](eval.py)的评测代码: ```bash -python eval.py +python eval.py --model chatglm2-6b-32k ``` -可以在`result.json`中得到在各数据集上的评测结果。请注意,我们在`config/`下提供了我们总结出来的在各数据集上适合的输入格式和最大输出长度限制,在评测的时候可以进行修改以更好地适用你要评测的模型,修改后在[pred.py](pred.py)评测时会自动按照新的格式去整理数据并得到对应的模型输出。 +可以在存储模型输出文件夹下的`result.json`中得到模型在LongBench各数据集上的评测结果。通过`--e`命令也可以得到模型在LongBench-E所有数据集中不同长度区间内的平均得分。 + +请注意,我们在`config/`下提供了我们总结出来的在各数据集上适合的输入格式和最大输出长度限制,在评测的时候可以进行修改以更好地适用你要评测的模型,修改后在[pred.py](pred.py)评测时会自动按照新的格式去整理数据并得到对应的模型输出。 + +此外我们还提供了基于检索和分段摘要的长文本压缩评测代码(实现方式参考LongBench论文中的4.2节),分别在`retrieval/`和`summ/`两个文件夹下。 <a name="详细评测结果"></a> ## 📊 详细评测结果 @@ -117,74 +145,88 @@ python eval.py | | NarrativeQA | Qasper | MultiFieldQA-en | MultiFieldQA-zh | |-------------------|:-----------:|:------:|:---------------:|:---------------:| | GPT-3.5-Turbo-16k | 23.6 | 43.3 | 52.3 | 61.2 | -| Llama2-7B-chat-4k | 19.1 | 19.6 | 35.8 | 11.6 | -| LongChat-7B-16k | 21.6 | 21.6 | 44.6 | 26.6 | -| XGen-7B-8k | 17.9 | 18.3 | 37.2 | 14.2 | -| InternLM-7B-8k | 12.4 | 16.8 | 22.3 | 33.3 | -| ChatGLM2-6B | 11.2 | 23.7 | 34.2 | 33.0 | -| ChatGLM2-6B-32k | 20.4 | 32.2 | 45.7 | 52.0 | +| Llama2-7B-chat-4k | 18.7 | 19.2 | 36.8 | 11.9 | +| LongChat-v1.5-7B-32k | 16.9 | 27.7 | 41.4 | 29.1 | +| XGen-7B-8k | 18.0 | 18.1 | 37.7 | 14.8 | +| InternLM-7B-8k | 12.1 | 16.7 | 23.4 | 33.6 | +| ChatGLM2-6B | 11.8 | 22.5 | 35.0 | 33.2 | +| ChatGLM2-6B-32k | 21.1 | 31.5 | 46.2 | 51.6 | +| Vicuna-v1.5-7B-16k | 19.4 | 26.1 | 38.5 | 43.0 | #### 多文档QA | | HotpotQA | 2WikiMQA | Musique | DuReader (zh) | |----------------------|:--------:|:--------:|:-------:|:--------:| | GPT-3.5-Turbo-16k | 51.6 | 37.7 | 26.9 | 28.7 | -| Llama2-7B-chat-4k | 24.3 | 31.4 | 8.6 | 1.9 
| -| LongChat-7B-16k | 22.4 | 16.8 | 9.1 | 19.1 | -| XGen-7B-8k | 28.3 | 21.5 | 11.5 | 9.1 | -| InternLM-7B-8k | 27.9 | 24.0 | 10.3 | 8.9 | -| ChatGLM2-6B | 20.2 | 19.6 | 5.3 | 15.2 | -| ChatGLM2-6B-32k | 44.9 | 34.9 | 22.2 | 34.3 | +| Llama2-7B-chat-4k | 25.4 | 32.8 | 9.4 | 5.2 | +| LongChat-v1.5-7B-32k | 31.5 | 20.6 | 9.7 | 19.5 | +| XGen-7B-8k | 29.7 | 21.1 | 10.3 | 11.0 | +| InternLM-7B-8k | 28.7 | 22.8 | 9.0 | 11.1 | +| ChatGLM2-6B | 22.4 | 20.1 | 6.1 | 16.3 | +| ChatGLM2-6B-32k | 45.1 | 34.0 | 21.9 | 37.6 | +| Vicuna-v1.5-7B-16k | 25.3 | 20.8 | 9.8 | 19.3 | #### 摘要 -| | GovReport | QMSum | VCSUM (zh) | -|:-----------|:---------:|:-----:|:-----:| -| GPT-3.5-Turbo-16k | 29.5 | 23.4 | 16.0 | -| Llama2-7B-chat-4k | 27.3 | 20.6 | 0.2 | -| LongChat-7B-16k | 28.4 | 23.2 | 14.0 | -| XGen-7B-8k | 27.8 | 21.7 | 1.5 | -| InternLM-7B-8k | 9.8 | 16.8 | 13.0 | -| ChatGLM2-6B | 23.7 | 22.2 | 14.6 | -| ChatGLM2-6B-32k | 33.3 | 23.9 | 16.3 | +| | GovReport | QMSum | MultiNews | VCSUM (zh) | +|:-----------|:---------:|:-----:|:-----:|:-----:| +| GPT-3.5-Turbo-16k | 29.5 | 23.4 | 26.7 | 16.0 | +| Llama2-7B-chat-4k | 27.3 | 20.8 | 25.8 | 0.2 | +| LongChat-v1.5-7B-32k | 30.8 | 22.7 | 26.4 | 9.9 | +| XGen-7B-8k | 27.3 | 20.5 | 26.2 | 2.2 | +| InternLM-7B-8k | 9.7 | 15.9 | 22.8 | 12.4 | +| ChatGLM2-6B | 23.2 | 21.1 | 25.2 | 14.5 | +| ChatGLM2-6B-32k | 32.4 | 24.0 | 26.5 | 16.2 | +| Vicuna-v1.5-7B-16k | 27.9 | 22.8 | 27.2 | 15.1 | #### Few-shot学习 -| | TREC | NQ | TriviaQA | LSHT (zh) | +| | TREC | TriviaQA | SAMSum | LSHT (zh) | | --- | :-: | :-: | :-: | :-: | -| GPT-3.5-Turbo-16k | 68.0 | 73.0 | 87.1 | 29.2 | -| Llama2-7B-chat-4k | 60.5 | 31.4 | 59.7 | 19.8 | -| LongChat-7B-16k | 61.5 | 44.8 | 73.5 | 20.8 | -| XGen-7B-8k | 66.0 | 43.2 | 67.0 | 20.0 | -| InternLM-7B-8k | 49.0 | 47.6 | 61.6 | 15.5 | -| ChatGLM2-6B | 44.0 | 34.5 | 59.8 | 20.5 | -| ChatGLM2-6B-32k | 62.0 | 64.9 | 77.6 | 29.9 | +| GPT-3.5-Turbo-16k | 68.0 | 91.4 | 41.7 | 29.2 | +| Llama2-7B-chat-4k | 61.5 | 77.8 | 40.7 | 19.8 | +| LongChat-v1.5-7B-32k | 63.5 | 82.3 | 34.2 | 23.2 | +| XGen-7B-8k | 65.5 | 77.8 | 25.3 | 20.5 | +| InternLM-7B-8k | 52.0 | 77.8 | 21.2 | 15.2 | +| ChatGLM2-6B | 44.5 | 70.6 | 29.5 | 20.8 | +| ChatGLM2-6B-32k | 62.5 | 78.7 | 36.3 | 27.7 | +| Vicuna-v1.5-7B-16k | 71.5 | 86.2 | 40.8 | 28.8 | + +#### 合成任务 +| | Passage Count | PassageRetrieval-en | PassageRetrieval-zh | +| --- | :-: | :-: | :-: | +| GPT-3.5-Turbo-16k | 4.5 | 71.0 | 77.5 | +| Llama2-7B-chat-4k | 2.1 | 9.8 | 0.5 | +| LongChat-v1.5-7B-32k | 1.0 | 30.5 | 7.6 | +| XGen-7B-8k | 2.1 | 8.5 | 3.5 | +| InternLM-7B-8k | 3.0 | 6.0 | 0.9 | +| ChatGLM2-6B | 2.5 | 3.0 | 6.5 | +| ChatGLM2-6B-32k | 1.5 | 77.0 | 64.5 | +| Vicuna-v1.5-7B-16k | 6.5 | 4.5 | 5.0 | #### 代码补全 | | LCC | RepoBench-P | | --- | :-: | :-: | -| GPT-3.5-Turbo-16k | 54.7 | 54.3 | -| Llama2-7B-chat-4k | 52.3 | 42.4 | -| LongChat-7B-16k | 59.2 | 54.7 | -| XGen-7B-8k | 38.8 | 37.3 | -| InternLM-7B-8k | 45.5 | 34.0 | -| ChatGLM2-6B | 48.4 | 43.7 | -| ChatGLM2-6B-32k | 55.4 | 50.0 | - -#### 合成任务 -| | PassageRetrieval-en | Passage Count | PassageRetrieval-zh | -| --- | :-: | :-: | :-: | -| GPT-3.5-Turbo-16k | 71.0 | 4.5 | 77.5 | -| Llama2-7B-chat-4k | 9.2 | 2.5 | 0.5 | -| LongChat-7B-16k | 24.0 | 4.5 | 4.8 | -| XGen-7B-8k | 9.0 | 2.2 | 4.2 | -| InternLM-7B-8k | 6.5 | 2.9 | 0.9 | -| ChatGLM2-6B | 3.2 | 2.1 | 5.5 | -| ChatGLM2-6B-32k | 77.5 | 2.0 | 62.5 | +| GPT-3.5-Turbo-16k | 54.7 | 53.6 | +| Llama2-7B-chat-4k | 52.4 | 43.8 | +| LongChat-v1.5-7B-32k | 53.0 | 55.3 | +| XGen-7B-8k | 38.6 | 38.6 | +| 
InternLM-7B-8k | 44.1 | 28.8 | +| ChatGLM2-6B | 49.0 | 43.2 | +| ChatGLM2-6B-32k | 55.6 | 49.9 | +| Vicuna-v1.5-7B-16k | 51.0 | 43.5 | <a name="致谢"></a> ## 📄 致谢 -- **LongBench**的部分任务基于之前的研究者提出的数据集构建,包括[HotpotQA](https://hotpotqa.github.io/),[2WikiMultihopQA](https://aclanthology.org/2020.coling-main.580/),[Musique](https://arxiv.org/abs/2108.00573),[DuReader](https://github.com/baidu/DuReader),[NarrativeQA](https://arxiv.org/pdf/1712.07040.pdf),[Qasper](https://arxiv.org/pdf/2105.03011.pdf),[GovReport](https://arxiv.org/pdf/2104.02112.pdf),[QMSum](https://arxiv.org/pdf/2104.05938.pdf),[VCSUM](https://arxiv.org/abs/2305.05280),[TriviaQA](https://nlp.cs.washington.edu/triviaqa/),[NQ](https://ai.google.com/research/NaturalQuestions/),[TREC](https://aclanthology.org/C02-1150.pdf),[LSHT](http://tcci.ccf.org.cn/conference/2014/dldoc/evatask6.pdf),[LCC](https://arxiv.org/abs/2306.14893)和[RepoBench-P](https://arxiv.org/abs/2306.03091)。 +- LongBench的部分任务基于之前的研究者提出的数据集构建,包括[HotpotQA](https://hotpotqa.github.io/),[2WikiMultihopQA](https://aclanthology.org/2020.coling-main.580/),[MuSiQue](https://arxiv.org/abs/2108.00573),[DuReader](https://github.com/baidu/DuReader),[NarrativeQA](https://arxiv.org/pdf/1712.07040.pdf),[Qasper](https://arxiv.org/pdf/2105.03011.pdf),[GovReport](https://arxiv.org/pdf/2104.02112.pdf),[QMSum](https://arxiv.org/pdf/2104.05938.pdf),[MultiNews](https://aclanthology.org/P19-1102.pdf),[VCSUM](https://arxiv.org/abs/2305.05280),[TriviaQA](https://nlp.cs.washington.edu/triviaqa/),[TREC](https://aclanthology.org/C02-1150.pdf),[SAMSum](https://aclanthology.org/D19-5409.pdf),[LSHT](http://tcci.ccf.org.cn/conference/2014/dldoc/evatask6.pdf),[LCC](https://arxiv.org/abs/2306.14893)和[RepoBench-P](https://arxiv.org/abs/2306.03091)。 <a name="引用"></a> ## 📝 引用 -本工作由**THU-KEG**和**Zhipu AI**共同完成,相关论文正在撰写中,届时将更新引用信息,敬请关注~ - -如果您使用Longbench,请一并引用LongBench所基于的数据集对应的论文,相关引用信息在[这里](refs/ref.bib)。 +``` +@misc{bai2023longbench, + title={LongBench: A Bilingual, Multitask Benchmark for Long Context Understanding}, + author={Yushi Bai and Xin Lv and Jiajie Zhang and Hongchang Lyu and Jiankai Tang and Zhidian Huang and Zhengxiao Du and Xiao Liu and Aohan Zeng and Lei Hou and Yuxiao Dong and Jie Tang and Juanzi Li}, + year={2023}, + eprint={2308.14508}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} +``` +如果您使用Longbench,请考虑引用LongBench所基于的数据集对应的论文,相关引用信息在[这里](refs/ref.bib)。 diff --git a/config/dataset2maxlen.json b/config/dataset2maxlen.json index 07e5638d291a38061ebb081e17db8f931728e37d..79d0d9990e5799c845ebcf839c1ee1a4ff14873e 100644 --- a/config/dataset2maxlen.json +++ b/config/dataset2maxlen.json @@ -1,22 +1,23 @@ { - "passage_count": 32, - "trec": 64, - "nq": 32, - "triviaqa": 32, - "hotpotqa": 32, - "musique": 32, - "2wikimqa": 32, "narrativeqa": 128, "qasper": 128, + "multifieldqa_en": 64, + "multifieldqa_zh": 64, + "hotpotqa": 32, + "2wikimqa": 32, + "musique": 32, + "dureader": 128, "gov_report": 512, "qmsum": 512, - "passage_retrieval_zh": 32, - "passage_retrieval_en": 32, - "lsht": 64, - "dureader": 128, + "multi_news": 512, "vcsum": 512, - "multifieldqa_en": 64, - "multifieldqa_zh": 64, + "trec": 64, + "triviaqa": 32, + "samsum": 128, + "lsht": 64, + "passage_count": 32, + "passage_retrieval_en": 32, + "passage_retrieval_zh": 32, "lcc": 64, "repobench-p": 64 } \ No newline at end of file diff --git a/config/dataset2prompt.json b/config/dataset2prompt.json index a902636fd588c0e778aa3f442282a8cfc94dd7a5..1c85f6bc0f0df4e42131aa49a867797ee7043ecf 100644 --- a/config/dataset2prompt.json +++ 
b/config/dataset2prompt.json @@ -1,22 +1,23 @@ { - "passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n{context}\n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ", - "passage_retrieval_en": "Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n{context}\n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. The answer format must be like \"Paragraph 1\", \"Paragraph 2\", etc.\n\nThe answer is: ", - "trec": "Please determine the type of the question below. Here are some examples of questions.\n\n{context}\n{input}", - "nq": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}", - "triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}", - "hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", - "musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", - "2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", "narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: {context}\n\nNow, answer the question based on the story asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:", "qasper": "You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nArticle: {context}\n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". 
Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:", + "multifieldqa_en": "Read the following text and answer briefly.\n\n{context}\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", + "multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n{context}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:", + "hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", + "2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", + "musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", + "dureader": "请基于给定的文章回答下述问题。\n\n文章:{context}\n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:", "gov_report": "You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{context}\n\nNow, write a one-page summary of the report.\n\nSummary:", "qmsum": "You are given a meeting transcript and a query containing a question or instruction. Answer the query in one or more sentences.\n\nTranscript:\n{context}\n\nNow, answer the query based on the above meeting transcript in one or more sentences.\n\nQuery: {input}\nAnswer:", - "passage_retrieval_zh": "以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n{context}\n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是\"段落1\",\"段落2\"等格式\n\n答案是:", - "lsht": "请判断给定新闻的类别,下面是一些例子。\n\n{context}\n{input}", - "dureader": "请基于给定的文章回答下述问题。\n\n文章:{context}\n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:", + "multi_news": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\nSummary:", "vcsum": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n{context}\n\n会议总结:", - "multifieldqa_en": "Read the following text and answer briefly.\n\n{context}\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", - "multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n{context}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:", + "trec": "Please determine the type of the question below. Here are some examples of questions.\n\n{context}\n{input}", + "triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}", + "samsum": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n{context}\n\n{input}", + "lsht": "请判断给定新闻的类别,下面是一些例子。\n\n{context}\n{input}", + "passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. 
In other words, how many non-repeating paragraphs are there in total?\n\n{context}\n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ", + "passage_retrieval_en": "Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n{context}\n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. The answer format must be like \"Paragraph 1\", \"Paragraph 2\", etc.\n\nThe answer is: ", + "passage_retrieval_zh": "以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n{context}\n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是\"段落1\",\"段落2\"等格式\n\n答案是:", "lcc": "Please complete the code given below. \n{context}Next line of code:\n", "repobench-p": "Please complete the code given below. \n{context}{input}Next line of code:\n" } \ No newline at end of file diff --git a/config/model2maxlen.json b/config/model2maxlen.json new file mode 100644 index 0000000000000000000000000000000000000000..a13f42c9d783abf917acaf9b067ae8b4b52d0c9f --- /dev/null +++ b/config/model2maxlen.json @@ -0,0 +1,9 @@ +{ + "llama2-7b-chat-4k": 3500, + "longchat-v1.5-7b-32k": 31500, + "xgen-7b-8k": 7500, + "internlm-7b-8k": 7500, + "chatglm2-6b": 31500, + "chatglm2-6b-32k": 31500, + "vicuna-v1.5-7b-16k": 15500 +} \ No newline at end of file diff --git a/config/model2path.json b/config/model2path.json new file mode 100644 index 0000000000000000000000000000000000000000..c59ea2535e5c69e7814229679adf3581908b0a38 --- /dev/null +++ b/config/model2path.json @@ -0,0 +1,9 @@ +{ + "llama2-7b-chat-4k": "meta-llama/Llama-2-7b-chat-hf", + "longchat-v1.5-7b-32k": "lmsys/longchat-7b-v1.5-32k", + "xgen-7b-8k": "Salesforce/xgen-7b-8k-inst", + "internlm-7b-8k": "internlm/internlm-chat-7b-8k", + "chatglm2-6b": "THUDM/chatglm2-6b", + "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k", + "vicuna-v1.5-7b-16k": "lmsys/vicuna-7b-v1.5-16k" +} \ No newline at end of file diff --git a/eval.py b/eval.py index 635d1319e41d9b388352c55b737f9e4554ca9711..0a1d559c48eccfaf6f814f54866840a1735e712a 100644 --- a/eval.py +++ b/eval.py @@ -1,5 +1,7 @@ import os import json +import argparse +import numpy as np from metrics import ( qa_f1_score, @@ -14,20 +16,21 @@ from metrics import ( ) dataset2metric = { - "hotpotqa": qa_f1_score, - "2wikimqa": qa_f1_score, - "musique": qa_f1_score, - "dureader": rouge_zh_score, "narrativeqa": qa_f1_score, "qasper": qa_f1_score, "multifieldqa_en": qa_f1_score, "multifieldqa_zh": qa_f1_zh_score, + "hotpotqa": qa_f1_score, + "2wikimqa": qa_f1_score, + "musique": qa_f1_score, + "dureader": rouge_zh_score, "gov_report": rouge_score, "qmsum": rouge_score, + "multi_news": rouge_score, "vcsum": rouge_zh_score, "trec": classification_score, - "nq": qa_f1_score, "triviaqa": qa_f1_score, + "samsum": rouge_score, "lsht": classification_score, "passage_retrieval_en": retrieval_score, "passage_count": count_score, @@ -36,28 +39,71 @@ dataset2metric = { "repobench-p": code_sim_score, } +def parse_args(args=None): + parser = argparse.ArgumentParser() + parser.add_argument('--model', type=str, default=None) + parser.add_argument('--e', action='store_true', help="Evaluate on LongBench-E") + return parser.parse_args(args) + +def scorer_e(dataset, predictions, answers, lengths, all_classes): + scores = {"0-4k": [], "4-8k": [], "8k+": []} + for (prediction, ground_truths, length) in zip(predictions, answers, lengths): + score = 0. 
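+        # few-shot / classification datasets are scored on the first line of the prediction only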
+ if dataset in ["trec", "triviaqa", "samsum", "lsht"]: + prediction = prediction.lstrip('\n').split('\n')[0] + for ground_truth in ground_truths: + score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes)) + if length < 4000: + scores["0-4k"].append(score) + elif length < 8000: + scores["4-8k"].append(score) + else: + scores["8k+"].append(score) + for key in scores.keys(): + scores[key] = round(100 * np.mean(scores[key]), 2) + return scores + def scorer(dataset, predictions, answers, all_classes): total_score = 0. for (prediction, ground_truths) in zip(predictions, answers): score = 0. + if dataset in ["trec", "triviaqa", "samsum", "lsht"]: + prediction = prediction.lstrip('\n').split('\n')[0] for ground_truth in ground_truths: score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes)) total_score += score return round(100 * total_score / len(predictions), 2) if __name__ == '__main__': + args = parse_args() scores = dict() - all_files = os.listdir("pred/") + if args.e: + path = f"pred_e/{args.model}/" + else: + path = f"pred/{args.model}/" + all_files = os.listdir(path) + print("Evaluating on:", all_files) for filename in all_files: - predictions, answers = [], [] + if not filename.endswith("jsonl"): + continue + predictions, answers, lengths = [], [], [] dataset = filename.split('.')[0] - with open(f"pred/{filename}", "r") as f: + with open(f"{path}{filename}", "r", encoding="utf-8") as f: for line in f: data = json.loads(line) predictions.append(data["pred"]) answers.append(data["answers"]) all_classes = data["all_classes"] - score = scorer(dataset, predictions, answers, all_classes) + if "length" in data: + lengths.append(data["length"]) + if args.e: + score = scorer_e(dataset, predictions, answers, lengths, all_classes) + else: + score = scorer(dataset, predictions, answers, all_classes) scores[dataset] = score - with open("result.json", "w") as f: + if args.e: + out_path = f"pred_e/{args.model}/result.json" + else: + out_path = f"pred/{args.model}/result.json" + with open(out_path, "w") as f: json.dump(scores, f, ensure_ascii=False, indent=4) diff --git a/misc/.DS_Store b/misc/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..0398f6cc86d9df8f7229ba9af0dda70fb99cc6e5 Binary files /dev/null and b/misc/.DS_Store differ diff --git a/misc/curve.png b/misc/curve.png index 63d64bb8049910566cbfd8a579b1d459f0286784..89d75cac38eb648f480b7108d967681201ef45dc 100644 Binary files a/misc/curve.png and b/misc/curve.png differ diff --git a/misc/overview.png b/misc/overview.png new file mode 100644 index 0000000000000000000000000000000000000000..8fb1d50ef60e210710a2ce196f91b2a20a10e1e3 Binary files /dev/null and b/misc/overview.png differ diff --git a/misc/radar.png b/misc/radar.png index baecc55c03a164cfb0732c65825e5b5639b85bf4..f0c11b4ac1de515ca65daa999da10d356ca11214 100644 Binary files a/misc/radar.png and b/misc/radar.png differ diff --git a/pred.py b/pred.py index d84dee8a8f2358d8bdbbbc16b50102f68d7b8166..467d38b90b16bdf872596145b02b6777bf02c2e7 100644 --- a/pred.py +++ b/pred.py @@ -2,14 +2,49 @@ import os from datasets import load_dataset import torch import json -from transformers import AutoTokenizer, AutoModel +from transformers import AutoTokenizer, LlamaTokenizer, LlamaForCausalLM, AutoModelForCausalLM from tqdm import tqdm +import numpy as np +import random +import argparse +from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn -# This is the customized building 
prompt for chat models, here is an example for ChatGLM2 -def build_chat(tokenizer, prompt): - return tokenizer.build_prompt(prompt) +def parse_args(args=None): + parser = argparse.ArgumentParser() + parser.add_argument('--model', type=str, default=None, choices=["llama2-7b-chat-4k", "longchat-v1.5-7b-32k", "xgen-7b-8k", "internlm-7b-8k", "chatglm2-6b", "chatglm2-6b-32k", "vicuna-v1.5-7b-16k"]) + parser.add_argument('--e', action='store_true', help="Evaluate on LongBench-E") + return parser.parse_args(args) -def get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset, device): +# This is the customized building prompt for chat models +def build_chat(tokenizer, prompt, model_name): + if "chatglm" in model_name: + prompt = tokenizer.build_prompt(prompt) + elif "longchat" in model_name or "vicuna" in model_name: + from fastchat.model import get_conversation_template + conv = get_conversation_template("vicuna") + conv.append_message(conv.roles[0], prompt) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + elif "llama2" in model_name: + prompt = f"[INST]{prompt}[/INST]" + elif "xgen" in model_name: + header = ( + "A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n" + ) + prompt = header + f" ### Human: {prompt}\n###" + elif "internlm" in model_name: + prompt = f"<|User|>:{prompt}<eoh>\n<|Bot|>:" + return prompt + +def post_process(response, model_name): + if "xgen" in model_name: + response = response.strip().replace("Assistant:", "") + elif "internlm" in model_name: + response = response.split("<eoa>")[0] + return response + +def get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset, device, model_name): preds = [] for json_obj in tqdm(data): prompt = prompt_format.format(**json_obj) @@ -18,44 +53,107 @@ def get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset if len(tokenized_prompt) > max_length: half = int(max_length/2) prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True)+tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True) - if dataset not in ["lcc", "repobench-p", "trec", "nq", "triviaqa", "lsht"]: # chat models are better off without build prompt on these tasks - prompt = build_chat(tokenizer, prompt) + if dataset not in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]: # chat models are better off without build prompts on these tasks + prompt = build_chat(tokenizer, prompt, model_name) input = tokenizer(prompt, truncation=False, return_tensors="pt").to(device) context_length = input.input_ids.shape[-1] - output = model.generate( - **input, - max_new_tokens=max_gen, - num_beams=1, - do_sample=False, - temperature=1.0, - )[0] + if dataset == "samsum": # prevent illegal output on samsum (model endlessly repeat "\nDialogue"), might be a prompting issue + output = model.generate( + **input, + max_new_tokens=max_gen, + num_beams=1, + do_sample=False, + temperature=1.0, + min_length=context_length+1, + eos_token_id=[tokenizer.eos_token_id, tokenizer.encode("\n", add_special_tokens=False)[-1]], + )[0] + else: + output = model.generate( + **input, + max_new_tokens=max_gen, + num_beams=1, + do_sample=False, + temperature=1.0, + )[0] pred = tokenizer.decode(output[context_length:], skip_special_tokens=True) - preds.append({"pred": pred, "answers": json_obj["answers"], "all_classes": json_obj["all_classes"]}) + pred = post_process(pred, 
model_name) + preds.append({"pred": pred, "answers": json_obj["answers"], "all_classes": json_obj["all_classes"], "length": json_obj["length"]}) return preds +def seed_everything(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + torch.cuda.manual_seed_all(seed) + +def load_model_and_tokenizer(path, model_name, device): + if "chatglm" in model_name or "internlm" in model_name or "xgen" in model_name: + tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True, torch_dtype=torch.bfloat16).to(device) + elif "llama2" in model_name: + replace_llama_attn_with_flash_attn() + tokenizer = LlamaTokenizer.from_pretrained(path) + model = LlamaForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16).to(device) + elif "longchat" in model_name or "vicuna" in model_name: + from fastchat.model import load_model + replace_llama_attn_with_flash_attn() + model, _ = load_model( + path, + device='cpu', + num_gpus=0, + load_8bit=False, + cpu_offloading=False, + debug=False, + ) + model = model.to(device) + model = model.bfloat16() + tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False) + model = model.eval() + return model, tokenizer if __name__ == '__main__': - datasets = ["hotpotqa", "2wikimqa", "musique", "dureader", "narrativeqa", "qasper", "multifieldqa_en", "multifieldqa_zh", "gov_report", \ - "qmsum", "vcsum", "trec", "nq", "triviaqa", "lsht", "passage_count", "passage_retrieval_en", "passage_retrieval_zh", "lcc", "repobench-p"] + seed_everything(42) + args = parse_args() + model2path = json.load(open("config/model2path.json", "r")) + model2maxlen = json.load(open("config/model2maxlen.json", "r")) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - # define your model (ChatGLM2-6B, for instance) - tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True) - model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True, torch_dtype=torch.bfloat16).to(device) - model = model.eval() - # define max_length - max_length = 31500 + model_name = args.model + # define your model + model, tokenizer = load_model_and_tokenizer(model2path[model_name], model_name, device) + max_length = model2maxlen[model_name] + if args.e: + datasets = ["qasper", "multifieldqa_en", "hotpotqa", "2wikimqa", "gov_report", "multi_news", \ + "trec", "triviaqa", "samsum", "passage_count", "passage_retrieval_en", "lcc", "repobench-p"] + else: + datasets = ["narrativeqa", "qasper", "multifieldqa_en", "multifieldqa_zh", "hotpotqa", "2wikimqa", "musique", \ + "dureader", "gov_report", "qmsum", "multi_news", "vcsum", "trec", "triviaqa", "samsum", "lsht", \ + "passage_count", "passage_retrieval_en", "passage_retrieval_zh", "lcc", "repobench-p"] # we design specific prompt format and max generation length for each task, feel free to modify them to optimize model output dataset2prompt = json.load(open("config/dataset2prompt.json", "r")) dataset2maxlen = json.load(open("config/dataset2maxlen.json", "r")) # predict on each dataset if not os.path.exists("pred"): os.makedirs("pred") + if not os.path.exists("pred_e"): + os.makedirs("pred_e") for dataset in datasets: - data = load_dataset('THUDM/LongBench', dataset, split='test') + if args.e: + data = load_dataset('THUDM/LongBench', f"{dataset}_e", split='test') + if not 
os.path.exists(f"pred_e/{model_name}"): + os.makedirs(f"pred_e/{model_name}") + out_path = f"pred_e/{model_name}/{dataset}.jsonl" + else: + data = load_dataset('THUDM/LongBench', dataset, split='test') + if not os.path.exists(f"pred/{model_name}"): + os.makedirs(f"pred/{model_name}") + out_path = f"pred/{model_name}/{dataset}.jsonl" prompt_format = dataset2prompt[dataset] max_gen = dataset2maxlen[dataset] - preds = get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset, device) - with open(f"pred/{dataset}.jsonl", "w", encoding="utf-8") as f: + preds = get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset, device, model_name) + with open(out_path, "w", encoding="utf-8") as f: for pred in preds: json.dump(pred, f, ensure_ascii=False) f.write('\n') \ No newline at end of file diff --git a/refs/ref.bib b/refs/ref.bib index b037c7326a8d16e50a8b373966323ab93025bbc2..d2eac4c549eda31b75644e5b6c342bae6ac0b565 100644 --- a/refs/ref.bib +++ b/refs/ref.bib @@ -81,14 +81,20 @@ year={2017} } -@article{kwiatkowski2019natural, - title={Natural questions: a benchmark for question answering research}, - author={Kwiatkowski, Tom and Palomaki, Jennimaria and Redfield, Olivia and Collins, Michael and Parikh, Ankur and Alberti, Chris and Epstein, Danielle and Polosukhin, Illia and Devlin, Jacob and Lee, Kenton and others}, - journal={Transactions of the Association for Computational Linguistics}, - volume={7}, - pages={453--466}, - year={2019}, - publisher={MIT Press One Rogers Street, Cambridge, MA 02142-1209, USA journals-info~…} +@article{gliwa2019samsum, + title={SAMSum Corpus: A Human-annotated Dialogue Dataset for Abstractive Summarization}, + author={Gliwa, Bogdan and Mochol, Iwona and Biesek, Maciej and Wawer, Aleksander}, + journal={EMNLP-IJCNLP 2019}, + pages={70}, + year={2019} +} + +@inproceedings{fabbri2019multi, + title={Multi-News: A Large-Scale Multi-Document Summarization Dataset and Abstractive Hierarchical Model}, + author={Fabbri, Alexander Richard and Li, Irene and She, Tianwei and Li, Suyi and Radev, Dragomir}, + booktitle={Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics}, + pages={1074--1084}, + year={2019} } @inproceedings{li2002learning, diff --git a/requirements.txt b/requirements.txt index 653e553a4bfb7ab5ed82a1ccb960d4399a98f189..c137186bbeed2c3c5c21c4ad3fdf5ff63a7c9b95 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,5 @@ rouge jieba fuzzywuzzy torch -transformers==4.31.0 \ No newline at end of file +transformers==4.31.0 +einops \ No newline at end of file diff --git a/retrieval/BM25/BM25.sh b/retrieval/BM25/BM25.sh new file mode 100644 index 0000000000000000000000000000000000000000..022baf760294bc4574495444c57ce71ae8a08381 --- /dev/null +++ b/retrieval/BM25/BM25.sh @@ -0,0 +1,38 @@ +#!/bin/bash +chunk_size=500 +work_dir="../../LongBench" # dir for storing data + +source_dir="${work_dir}/data" # source LongBench dir +dest_dir=""${work_dir}/B${chunk_size}/data"" + +file_names=() +allowed_files=("multifieldqa_en.jsonl" "qasper.jsonl" "2wikimqa.jsonl" "dureader.jsonl" "hotpotqa.jsonl" "narrativeqa.jsonl" "musique.jsonl" "multifieldqa_zh.jsonl") +# store all jsonl files +while IFS= read -r -d '' file; do + base_name=$(basename "$file") + # Check if the file name is in the allowed_files list + if [[ " ${allowed_files[@]} " =~ " ${base_name} " ]]; then + file_names+=("$base_name") + fi +done < <(find "$source_dir" -type f -name "*.jsonl" -print0) + +# concurrent execution +group_size=3 + 
+for ((start=0; start<${#file_names[@]}; start+=group_size)); do + end=$((start + group_size - 1)) + echo "Index Range:$start ~ $end" + current_group=("${file_names[@]:start:group_size}") + for file in "${current_group[@]}"; do + fileName=$(basename "${file}") + python generate_BM25.py \ + --file_name $fileName \ + --source_dir $source_dir \ + --dest_dir $dest_dir \ + --chunk_size $chunk_size \ + & + done + wait +done + +cp ../LongBench.py "${work_dir}/B${chunk_size}" \ No newline at end of file diff --git a/retrieval/BM25/generate_BM25.py b/retrieval/BM25/generate_BM25.py new file mode 100644 index 0000000000000000000000000000000000000000..49f2edf20c422604cda907bcd0fd0377aa30584e --- /dev/null +++ b/retrieval/BM25/generate_BM25.py @@ -0,0 +1,77 @@ +from rank_bm25 import BM25Okapi + +import os +import json +from concurrent.futures import ThreadPoolExecutor, wait +from tqdm import tqdm +import argparse +import sys +sys.path.append('..') +from splitter import split_long_sentence, get_word_len, regex +# DEBUG +# os.chdir(os.path.dirname(os.path.abspath(__file__))) + +def retriveDoc(query: str, document: str, chunk_size, file_name:str, + js, output_list, idx, pbar=None, maxLen=1500): + # 1. Splits the context into pieces + texts = split_long_sentence(document, regex, chunk_size=chunk_size, filename=file_name) + # 2. Creates retriver, adds texts + retriever = BM25Okapi(texts) + # 3. Retrive and merge + retrieved_texts = retriever.get_top_n(query=query, documents=texts, + n=len(texts)) + retrieved_texts = [retrieved_texts] if type(retrieved_texts) == str else retrieved_texts + context = '' + for text in retrieved_texts: + if get_word_len(context) < maxLen: + context += text + js['retrieved'] = retrieved_texts if type(retrieved_texts) == list else [retrieved_texts] + js['context'] = context + js['length'] = get_word_len(context) + output_list[index] = js + if pbar: + pbar.update() + return + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + # parser.add_argument("--file_name", default='dureader.jsonl') + parser.add_argument("--file_name", default='2wikimqa.jsonl') + parser.add_argument("--source_dir", default='../../LongBench/data') + parser.add_argument("--dest_dir", default='./test') + parser.add_argument("--chunk_size", type=int, default=200) + args = parser.parse_args() + file_name = args.file_name + + print(f"------ {file_name} ------") + with open(os.path.join(args.source_dir, file_name), 'r', encoding='utf-8') as file: + file_contents = file.readlines() + output_data = [{}] * len(file_contents) + if (os.path.exists(os.path.join(args.dest_dir, file_name))): + with open(os.path.join(args.dest_dir, file_name), 'r', encoding='utf-8') as f: + lines = f.readlines() + lines = [line for line in lines] + output_data = [json.loads(line) for line in lines] + loop = tqdm(enumerate(file_contents), total=len(file_contents), desc=f'{file_name}') + exe_list = [] + with ThreadPoolExecutor(max_workers=10) as executor: + # for index, line in loop: + for index, line in enumerate(file_contents): + if (output_data[index] != {} or + "context" in output_data[index].keys() and len(output_data[index]['context']) != 0): + loop.update() + continue + line_js = json.loads(line) + retriveDoc(query=line_js['input'], document=line_js['context'], + chunk_size=args.chunk_size, file_name=file_name, + js=line_js, output_list=output_data, idx=index, pbar=loop) + # exe_list.append(executor.submit(retriveDoc, query=line_js['input'], document=line_js['context'], + # chunk_size=args.chunk_size, 
file_name=file_name, + # js=line_js, output_list=output_data, idx=index, pbar=loop)) + # loop.set_description(f'{file_name}') + wait(exe_list) + # saving + os.makedirs(args.dest_dir, exist_ok=True) + with open(os.path.join(args.dest_dir, file_name), 'w', encoding='utf-8') as output_file: + for item in output_data: + output_file.write(json.dumps(item, ensure_ascii=False) + '\n') diff --git a/retrieval/README.md b/retrieval/README.md new file mode 100644 index 0000000000000000000000000000000000000000..024f6242843e396a0750d6e2b70013f1c2dd011c --- /dev/null +++ b/retrieval/README.md @@ -0,0 +1,47 @@ +## Introduction +This folder is to conduct retrieval-based context compression on LongBench using 3 retrievers. +- BM25 +- [Contriever](https://github.com/facebookresearch/contriever) +- OpenAI Embedding ([text-embedding-ada-002](https://openai.com/blog/new-and-improved-embedding-model)) + +First, download the LongBench dataset from HuggingFace and save them in `../LongBench/`, resulting in the folder structure: +``` +LongBench/ + LongBench/ + data/ + Put raw LongBench data here. + 2wikimqa.jsonl + ... + retrieval/ + BM25/ + contriever/ + contriever/: github + mcontriever/: huggingface + embedding/ + README.md: This file. +``` +## Usage + +Install the requirements with pip: `pip install -r requirements.txt` + +### Retrieval + +We take contriever method as an example. +1. Clone contriever from https://github.com/facebookresearch/contriever +2. Replace the files in contriever directory with `contriever/passage_retrieval.py` and `contriever/generate_passage_embeddings.py` +3. Get mcontriever model from https://huggingface.co/facebook/mcontriever +4. run `mContriever.sh` +5. Each line within the JSONL file is expanded by adding a new item "retrieved", which represents the retrieval outcomes of the original context. These results are sorted according to the retriever's criteria. + +### Evaluation + +We take ChatGLM2-6B-32k as an example. First run [pred.py](pred.py): +```bash +python pred.py --model chatglm2-6b-32k --data C200 --top_k 7 +``` +Then evaluate via [eval.py](eval.py): +```bash +python eval.py --model chatglm2-6b-32k --data C200_7 +``` + +Then the evaluation files are in `result_chatglm2-6b-32k`. 
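+
+For reference, [pred.py](pred.py) rebuilds the model input from the `retrieved` field before prediction by concatenating the top `k` chunks, which are already sorted by the retriever. Below is a minimal sketch of that step; the `build_context` helper, the file path, and the `top_k` value are illustrative only, not part of the repo:
+```python
+import json
+
+def build_context(example: dict, top_k: int = 7) -> dict:
+    # "retrieved" holds the chunks sorted best-first by the retriever;
+    # keeping only the top_k chunks yields the compressed context.
+    example["context"] = "".join(example["retrieved"][:top_k])
+    return example
+
+with open("C200/data/2wikimqa.jsonl", encoding="utf-8") as f:  # illustrative path
+    examples = [build_context(json.loads(line)) for line in f]
+```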
\ No newline at end of file diff --git a/retrieval/contriever/LB2mC.py b/retrieval/contriever/LB2mC.py new file mode 100644 index 0000000000000000000000000000000000000000..fa866b763f51051795c21d9e3b61b008235f7f08 --- /dev/null +++ b/retrieval/contriever/LB2mC.py @@ -0,0 +1,74 @@ +import os +import json +import pandas as pd +import argparse +import re +from tqdm import tqdm + +import sys +sys.path.append('..') +from splitter import split_long_sentence, regex +import concurrent.futures + +# DEBUG +# os.chdir(os.path.dirname(os.path.abspath(__file__))) + +parser = argparse.ArgumentParser() +parser.add_argument("--input_folder", type=str, default='../source/docqa_only') +parser.add_argument("--chunk_size", type=int, default=200) +parser.add_argument("--output_folder", type=str, default='../datasets/C200_t/split') +args = parser.parse_args() + + + +def process_jsonl_file(input_file, output_folder, chunk_size=100, filename='Unknown'): + with open(input_file, 'r', encoding='utf-8') as f_in: + lines = f_in.readlines() + # for idx, line in enumerate(lines): + loop = tqdm(lines, desc=filename) + for line in loop: + data = json.loads(line) + context = data.get('context', '') + chunks = split_long_sentence(context, regex, chunk_size, filename) + output_folder_name = os.path.join(output_folder, os.path.splitext(os.path.basename(input_file))[0]) + if not os.path.exists(output_folder_name): + os.makedirs(output_folder_name) + output_data = [] + for i, chunk in enumerate(chunks): + output_datum = { + 'id': data['_id'] + '_' + str(i), + 'text': chunk, + 'title': '' + } + output_data.append(output_datum) + output_data = pd.DataFrame(output_data, index=range(len(output_data))) + output_tsv_file = os.path.join(output_folder_name, data['_id'] + '.tsv') + output_data.to_csv(output_tsv_file, sep='\t', index=False) + + output_jsonl_file = os.path.join(output_folder_name, data['_id'] + '.jsonl') + output_data = { + 'id': data['_id'], + # 'lang': 'zh' if "_zh" in input_file else 'en', + 'lang' : 'zh' if 'zh' in data.get('context', '') else 'en', + 'question': data.get('input', ''), + 'answers': [] + } + with open(output_jsonl_file, 'w', encoding='utf-8') as f_out: + f_out.write(json.dumps(output_data, ensure_ascii=False) + '\n') + +def process_all_jsonl_files(input_folder, output_folder, chunk_size=1700): + if not os.path.exists(output_folder): + os.makedirs(output_folder) + + loop = tqdm(os.listdir(input_folder)) + allowed_files = ["multifieldqa_en.jsonl", "qasper.jsonl", "2wikimqa.jsonl", "dureader.jsonl", "hotpotqa.jsonl", "narrativeqa.jsonl", "musique.jsonl", "multifieldqa_zh.jsonl"] + for filename in loop: + if filename.endswith('.jsonl') and filename in allowed_files: + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: + input_file = os.path.join(input_folder, filename) + loop.set_description(f"totalFile") + # process_jsonl_file(input_file, output_folder, chunk_size, filename) + executor.submit(process_jsonl_file, input_file, output_folder, chunk_size, filename) + # print("split {} done!".format(filename)) + +process_all_jsonl_files(args.input_folder, args.output_folder, chunk_size=args.chunk_size) diff --git a/retrieval/contriever/generate_passage_embeddings.py b/retrieval/contriever/generate_passage_embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..455152a7918caa99bc0cd4ab93d455d55d3915d3 --- /dev/null +++ b/retrieval/contriever/generate_passage_embeddings.py @@ -0,0 +1,128 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os + +import argparse +import csv +import logging +import pickle + +import numpy as np +import torch + +import transformers + +import src.slurm +import src.contriever +import src.utils +import src.data +import src.normalize_text + +def embed_passages(args, passages, model, tokenizer): + total = 0 + allids, allembeddings = [], [] + batch_ids, batch_text = [], [] + with torch.no_grad(): + for k, p in enumerate(passages): + batch_ids.append(p["id"]) + if args.no_title or not "title" in p: + text = p["text"] + else: + text = p["title"] + " " + p["text"] + if args.lowercase: + text = text.lower() + if args.normalize_text: + text = src.normalize_text.normalize(text) + batch_text.append(text) + + if len(batch_text) == args.per_gpu_batch_size or k == len(passages) - 1: + + encoded_batch = tokenizer.batch_encode_plus( + batch_text, + return_tensors="pt", + max_length=args.passage_maxlength, + padding=True, + truncation=True, + ) + + encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()} + embeddings = model(**encoded_batch) + + embeddings = embeddings.cpu() + total += len(batch_ids) + allids.extend(batch_ids) + allembeddings.append(embeddings) + + batch_text = [] + batch_ids = [] + if k % 100000 == 0 and k > 0: + print(f"Encoded passages {total}") + if [] != allembeddings: + allembeddings = torch.cat(allembeddings, dim=0).numpy() + return allids, allembeddings + + +def main(args): + model, tokenizer, _ = src.contriever.load_retriever(args.model_name_or_path) + print(f"Model loaded from {args.model_name_or_path}.", flush=True) + model.eval() + model = model.cuda() + if not args.no_fp16: + model = model.half() + for psg in args.psgs_list: + passages = src.data.load_passages(psg) + + shard_size = len(passages) // args.num_shards + start_idx = args.shard_id * shard_size + end_idx = start_idx + shard_size + if args.shard_id == args.num_shards - 1: + end_idx = len(passages) + + passages = passages[start_idx:end_idx] + # print(f"Embedding generation for {len(passages)} passages from idx {start_idx} to {end_idx}.") + + allids, allembeddings = embed_passages(args, passages, model, tokenizer) + + # save_file = os.path.join(args.output_dir, args.prefix + f"_{args.shard_id:02d}") + def get_file_name_without_extension(file_path): + base_name = os.path.basename(file_path) # 获取文件名 + file_name_without_extension = os.path.splitext(base_name)[0] # 去除后缀 + return file_name_without_extension + fileName = get_file_name_without_extension(psg) + save_file = os.path.join(args.output_dir, fileName) + os.makedirs(args.output_dir, exist_ok=True) + print(f"Saving {len(allids)} passage embeddings to {save_file}.") + with open(save_file, mode="wb") as f: + pickle.dump((allids, allembeddings), f) + + print(f"Total passages processed {len(allids)}. 
Written to {save_file}.") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--psgs_list", nargs='+', required=True) + # parser.add_argument("--passages", type=str, default=None, help="Path to passages (.tsv file)") + parser.add_argument("--output_dir", type=str, default="wikipedia_embeddings", help="dir path to save embeddings") + parser.add_argument("--prefix", type=str, default="passages", help="prefix path to save embeddings") + parser.add_argument("--shard_id", type=int, default=0, help="Id of the current shard") + parser.add_argument("--num_shards", type=int, default=1, help="Total number of shards") + parser.add_argument( + "--per_gpu_batch_size", type=int, default=512, help="Batch size for the passage encoder forward pass" + ) + parser.add_argument("--passage_maxlength", type=int, default=512, help="Maximum number of tokens in a passage") + parser.add_argument( + "--model_name_or_path", type=str, help="path to directory containing model weights and config file" + ) + parser.add_argument("--no_fp16", action="store_true", help="inference in fp32") + parser.add_argument("--no_title", action="store_true", help="title not added to the passage body") + parser.add_argument("--lowercase", action="store_true", help="lowercase text before encoding") + parser.add_argument("--normalize_text", action="store_true", help="lowercase text before encoding") + + args = parser.parse_args() + + src.slurm.init_distributed_mode(args) + + main(args) diff --git a/retrieval/contriever/mContriever.sh b/retrieval/contriever/mContriever.sh new file mode 100644 index 0000000000000000000000000000000000000000..ca8cfe3a790dbce52e20dff3ade177b7a573633e --- /dev/null +++ b/retrieval/contriever/mContriever.sh @@ -0,0 +1,81 @@ +#!/bin/bash +chunk_size=200 +work_dir="../../LongBench" # dir for storing data + +source_dir="${work_dir}/data" # source LongBench dir +chunk_dir="C${chunk_size}" +split_dir="${work_dir}/${chunk_dir}/split" +embed_dir="${work_dir}/${chunk_dir}/embed" +retrieved_dir="${work_dir}/${chunk_dir}/output" +python LB2mC.py \ + --chunk_size ${chunk_size} \ + --output_folder ${split_dir}\ + --input_folder ${source_dir} +folder_names=() +# Traverse all subfolders under `split` dir +for folder in "$split_dir"/*; do + if [ -d "$folder" ]; then + # get the name of subfolder + folder_name=$(basename "$folder") + # concat + folder_path="$split_dir/$folder_name" + echo "$folder_path" + folder_names+=("$folder_name") + fi +done + +# Traverse all subfolders under `split` dir +for folder in "${folder_names[@]}"; do + file_paths=() + # Traverse all files in a subfolder + for file in "$split_dir"/"$folder"/*.tsv; do + if [ -f "$file" ]; then + fileName=$(basename "${file%.*}") + file_paths+=("${split_dir}/${folder}/${fileName}.tsv") + fi + done + # Converts an array to a ' ' separated string + files_str=$(IFS=' '; echo "${file_paths[*]}") + # generate embeddings + python ./contriever/generate_passage_embeddings.py \ + --model_name_or_path ./contriever/mcontriever \ + --output_dir ${embed_dir}/${folder} \ + --psgs_list $files_str\ + --shard_id 0 --num_shards 1 \ + --lowercase --normalize_text + + # generate results of retrieval + tsv_files=("$split_dir/$folder"/*.tsv) + # concurrent execution + group_size=5 + + for ((start=0; start<${#tsv_files[@]}; start+=group_size)); do + end=$((start + group_size - 1)) + echo "Index Range:$start ~ $end" + current_group=("${tsv_files[@]:start:group_size}") + + for ((index=0; index<${#current_group[@]}; index+=1)); do + 
file=${current_group[index]} + fileName=$(basename "${file%.*}") + python ./contriever/passage_retrieval.py \ + --model_name_or_path ./contriever/mcontriever \ + --passages ${split_dir}/${folder}/${fileName}.tsv \ + --passages_embeddings ${embed_dir}/${folder}/${fileName} \ + --data ${split_dir}/${folder}/${fileName}.jsonl \ + --output_dir ${retrieved_dir}/${folder} \ + --lowercase --normalize_text \ + --device "cuda" \ + & + # --device "cuda:$(expr 4 + $index % 4)" \ + done + wait + done + + python merge_output.py \ + --input_folder "${retrieved_dir}/${folder}" \ + --output_file "${work_dir}/${chunk_dir}/mc2LB/${folder}.jsonl" \ + --input_dataFile "${source_dir}/${folder}.jsonl" \ + --output_dataFile "${work_dir}/${chunk_dir}/data/${folder}.jsonl" +done + +cp ../LongBench.py "${work_dir}/${chunk_dir}" \ No newline at end of file diff --git a/retrieval/contriever/merge_output.py b/retrieval/contriever/merge_output.py new file mode 100644 index 0000000000000000000000000000000000000000..e027359d1663a622916d7d803b1ea7290501ebea --- /dev/null +++ b/retrieval/contriever/merge_output.py @@ -0,0 +1,68 @@ +import os +import json +import argparse +from tqdm import tqdm +import sys +sys.path.append('..') +from splitter import get_word_len + +# os.chdir(os.path.dirname(os.path.abspath(__file__))) +parser = argparse.ArgumentParser() +parser.add_argument('--input_folder', type=str, default='mcontriever_output', + help='Path to the input folder containing jsonl files.') +parser.add_argument('--output_file', type=str, default='CONTENT.jsonl', + help='Output jsonl file name.') +parser.add_argument('--input_dataFile', type=str, default='inputData.jsonl', + help='Input datum jsonl file name.') +parser.add_argument('--output_dataFile', type=str, default='DATA.jsonl', + help='Output datum jsonl file name.') +args = parser.parse_args() + + +def merge_text(jsonl_file, maxLen=1500): + with open(jsonl_file, 'r', encoding='utf-8') as f: + data_list = json.load(f) + context_list = data_list['ctxs'] + merged_text = '' + retrieved = [] + for item in context_list: + if get_word_len(merged_text) < maxLen: + merged_text += item['text'] + '\n\n' + retrieved += [item['text']] + output_data = { + 'context': merged_text, + 'id': data_list['id'], + 'retrieved': retrieved + } + + return output_data + +def process_all_jsonl_files(args): + input_folder = args.input_folder + output_file = args.output_file + # data_name = os.path.basename(os.path.normpath(input_folder)) + os.makedirs(os.path.dirname(output_file), exist_ok=True) + output_data_list = [] + with open(output_file, 'w', encoding='utf-8') as f_out: + # print("input_folder", input_folder) + loop = tqdm(os.listdir(input_folder), desc="merge") + for filename in loop: + if filename.endswith('.jsonl'): + jsonl_file = os.path.join(input_folder, filename) + output_data = merge_text(jsonl_file) + output_data_list += [output_data] + f_out.write(json.dumps(output_data, ensure_ascii=False) + '\n') + os.makedirs(os.path.dirname(args.output_dataFile), exist_ok=True) + with open(args.input_dataFile, 'r', encoding='utf-8') as in_data: + with open(args.output_dataFile, 'w', encoding='utf-8') as out_data: + for line in in_data: + data_l = json.loads(line) + for modified_data in output_data_list: + if data_l['_id'] == modified_data['id']: + data_l['context'] = modified_data['context'] + data_l['length'] = get_word_len(data_l['context']) + data_l['retrieved'] = modified_data['retrieved'] + break + out_data.write(json.dumps(data_l, ensure_ascii=False) + '\n') + 
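+# Entry point: merge the per-question retrieval outputs back into LongBench-style
+# jsonl files, rewriting the "context", "length" and "retrieved" fields.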
+process_all_jsonl_files(args) diff --git a/retrieval/contriever/passage_retrieval.py b/retrieval/contriever/passage_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..c77ed94b1e0b3639357993e25c35c0ca370f05fc --- /dev/null +++ b/retrieval/contriever/passage_retrieval.py @@ -0,0 +1,254 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import argparse +import csv +import json +import logging +import pickle +import time +import glob +from pathlib import Path + +import numpy as np +import torch +import transformers + +import src.index +import src.contriever +import src.utils +import src.slurm +import src.data +from src.evaluation import calculate_matches +import src.normalize_text + +os.environ["TOKENIZERS_PARALLELISM"] = "true" +# os.chdir(os.path.dirname(os.path.abspath(__file__))) + +def embed_queries(args, queries, model, tokenizer): + model.eval() + embeddings, batch_question = [], [] + with torch.no_grad(): + + for k, q in enumerate(queries): + if args.lowercase: + q = q.lower() + if args.normalize_text: + q = src.normalize_text.normalize(q) + batch_question.append(q) + + if len(batch_question) == args.per_gpu_batch_size or k == len(queries) - 1: + + encoded_batch = tokenizer.batch_encode_plus( + batch_question, + return_tensors="pt", + max_length=args.question_maxlength, + padding=True, + truncation=True, + ) + encoded_batch = {k: v.to(args.device) for k, v in encoded_batch.items()} + output = model(**encoded_batch) + embeddings.append(output.cpu()) + + batch_question = [] + + embeddings = torch.cat(embeddings, dim=0) + print(f"Questions embeddings shape: {embeddings.size()}") + + return embeddings.numpy() + + +def index_encoded_data(index, embedding_files, indexing_batch_size): + allids = [] + allembeddings = np.array([]) + for i, file_path in enumerate(embedding_files): + print(f"Loading file {file_path}") + with open(file_path, "rb") as fin: + ids, embeddings = pickle.load(fin) + + allembeddings = np.vstack((allembeddings, embeddings)) if allembeddings.size else embeddings + allids.extend(ids) + while allembeddings.shape[0] > indexing_batch_size: + allembeddings, allids = add_embeddings(index, allembeddings, allids, indexing_batch_size) + + while allembeddings.shape[0] > 0: + allembeddings, allids = add_embeddings(index, allembeddings, allids, indexing_batch_size) + + # print("Data indexing completed.") + + +def add_embeddings(index, embeddings, ids, indexing_batch_size): + end_idx = min(indexing_batch_size, embeddings.shape[0]) + ids_toadd = ids[:end_idx] + embeddings_toadd = embeddings[:end_idx] + ids = ids[end_idx:] + embeddings = embeddings[end_idx:] + index.index_data(ids_toadd, embeddings_toadd) + return embeddings, ids + + +def validate(data, workers_num): + match_stats = calculate_matches(data, workers_num) + top_k_hits = match_stats.top_k_hits + + print("Validation results: top k documents hits %s", top_k_hits) + top_k_hits = [v / len(data) for v in top_k_hits] + message = "" + for k in [5, 10, 20, 100]: + if k <= len(top_k_hits): + message += f"R@{k}: {top_k_hits[k-1]} " + print(message) + return match_stats.questions_doc_hits + + +def add_passages(data, passages, top_passages_and_scores): + # add passages to original data + merged_data = [] + assert len(data) == len(top_passages_and_scores) + for i, d in enumerate(data): + results_and_scores = top_passages_and_scores[i] + docs 
= [passages[doc_id] for doc_id in results_and_scores[0]] + scores = [str(score) for score in results_and_scores[1]] + ctxs_num = len(docs) + d["ctxs"] = [ + { + "id": results_and_scores[0][c], + "title": docs[c]["title"], + "text": docs[c]["text"], + "score": scores[c], + } + for c in range(ctxs_num) + ] + + +def add_hasanswer(data, hasanswer): + # add hasanswer to data + for i, ex in enumerate(data): + for k, d in enumerate(ex["ctxs"]): + d["hasanswer"] = hasanswer[i][k] + +def load_data(data_path): + if data_path.endswith(".json"): + with open(data_path, "r") as fin: + data = json.load(fin) + elif data_path.endswith(".jsonl"): + data = [] + with open(data_path, "r") as fin: + for k, example in enumerate(fin): + example = json.loads(example) + data.append(example) + return data + + +def main(args): + + print(f"Loading model from: {args.model_name_or_path}") + model, tokenizer, _ = src.contriever.load_retriever(args.model_name_or_path) + model.eval() + model = model.to(args.device) + if not args.no_fp16: + model = model.half() + + index = src.index.Indexer(args.projection_size, args.n_subquantizers, args.n_bits) + + # index all passages + input_paths = glob.glob(args.passages_embeddings) + input_paths = sorted(input_paths) + embeddings_dir = os.path.dirname(input_paths[0]) + index_path = os.path.join(embeddings_dir, "index.faiss") + if args.save_or_load_index and os.path.exists(index_path): + index.deserialize_from(embeddings_dir) + else: + # print(f"Indexing passages from files {input_paths}") + start_time_indexing = time.time() + index_encoded_data(index, input_paths, args.indexing_batch_size) + print(f"Indexing time: {time.time()-start_time_indexing:.1f} s.") + if args.save_or_load_index: + index.serialize(embeddings_dir) + + # load passages + passages = src.data.load_passages(args.passages) + passage_id_map = {x["id"]: x for x in passages} + + data_paths = glob.glob(args.data) + alldata = [] + for path in data_paths: + data = load_data(path) + output_path = os.path.join(args.output_dir, os.path.basename(path)) + + queries = [ex["question"] for ex in data] + questions_embedding = embed_queries(args, queries, model, tokenizer) + + # get top k results + start_time_retrieval = time.time() + # top_ids_and_scores = index.search_knn(questions_embedding, args.n_docs) + top_ids_and_scores = index.search_knn(questions_embedding, len(passages)) + print(f"{len(passages)} psgs: Search time: {time.time()-start_time_retrieval:.1f} s.") + + add_passages(data, passage_id_map, top_ids_and_scores) + # hasanswer = validate(data, args.validation_workers) + # add_hasanswer(data, hasanswer) + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, "w") as fout: + for ex in data: + json.dump(ex, fout, ensure_ascii=False) + fout.write("\n") + print(f"Saved results to {output_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--data", + # required=True, + type=str, + + default='./data_longbench/split/2wikimqa/0a64d8873482d91efc595a508218c6ce881c13c95028039e.jsonl', + help=".json file containing question and answers, similar format to reader data", + ) + parser.add_argument("--passages", type=str, default='./data_longbench/split/2wikimqa/0a64d8873482d91efc595a508218c6ce881c13c95028039e.tsv', + help="Path to passages (.tsv file)") + parser.add_argument("--passages_embeddings", type=str, default='./data_longbench/mEmbeddings/2wikimqa/0a64d8873482d91efc595a508218c6ce881c13c95028039e.jsonl', + help="Glob path to encoded passages") + 
parser.add_argument( + "--output_dir", type=str, default=None, help="Results are written to outputdir with data suffix" + ) + parser.add_argument("--n_docs", type=int, default=100, help="Number of documents to retrieve per questions") + parser.add_argument( + "--validation_workers", type=int, default=32, help="Number of parallel processes to validate results" + ) + parser.add_argument("--per_gpu_batch_size", type=int, default=64, help="Batch size for question encoding") + parser.add_argument( + "--save_or_load_index", action="store_true", help="If enabled, save index and load index if it exists" + ) + parser.add_argument( + "--model_name_or_path", type=str, default='./../mcontriever', + help="path to directory containing model weights and config file" + ) + parser.add_argument("--no_fp16", action="store_true", help="inference in fp32") + parser.add_argument("--question_maxlength", type=int, default=512, help="Maximum number of tokens in a question") + parser.add_argument( + "--indexing_batch_size", type=int, default=1000000, help="Batch size of the number of passages indexed" + ) + parser.add_argument("--projection_size", type=int, default=768) + parser.add_argument( + "--n_subquantizers", + type=int, + default=0, + help="Number of subquantizer used for vector quantization, if 0 flat index is used", + ) + parser.add_argument("--n_bits", type=int, default=8, help="Number of bits per subquantizer") + parser.add_argument("--lang", nargs="+") + parser.add_argument("--dataset", type=str, default="none") + parser.add_argument("--lowercase", action="store_true", default=True, help="lowercase text before encoding") + parser.add_argument("--normalize_text", action="store_true", default=True, help="normalize text") + parser.add_argument("--device", type=str, default='cuda', help="normalize text") + + args = parser.parse_args() + src.slurm.init_distributed_mode(args) + main(args) diff --git a/retrieval/embedding/generate_openai_embedding.py b/retrieval/embedding/generate_openai_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..421d2f42afe71f0edabd495091afe9099155ff28 --- /dev/null +++ b/retrieval/embedding/generate_openai_embedding.py @@ -0,0 +1,99 @@ +import openai +from openai.embeddings_utils import cosine_similarity +openai.api_key="KEY" +openai.proxy="" + +import os +import json +from concurrent.futures import ThreadPoolExecutor, wait +from tqdm import tqdm +import argparse +import sys +sys.path.append("..") +from splitter import split_long_sentence, get_word_len, regex +# os.chdir(os.path.dirname(os.path.abspath(__file__))) + +def retriveDoc(query: str, document: str, chunk_size, file_name:str, + js, output_list, idx, pbar=None, maxLen=1500): + # 1. Splits the context into pieces + texts = split_long_sentence(document, regex, chunk_size=chunk_size, filename=file_name) + # 2. Creates retriver, adds texts + # 3. 
Retrives and merges + # https://platform.openai.com/docs/api-reference/embeddings/object?lang=python + texts_embeddings = openai.Embedding.create( + model="text-embedding-ada-002", + input=texts + ) + query_embeddings = openai.Embedding.create( + model="text-embedding-ada-002", + input=query + ) + similarity = [] + for emb in texts_embeddings['data']: + similarity.append(cosine_similarity(emb['embedding'], query_embeddings['data'][0]['embedding'])) + sorted_pairs=sorted(zip(similarity, texts), reverse=True) + retrieved_texts = [pair[1] for pair in sorted_pairs] + retrieved_texts = [retrieved_texts] if type(retrieved_texts) == str else retrieved_texts + context = '' + for text in retrieved_texts: + if get_word_len(context) < maxLen: + context += text + js['retrieved'] = retrieved_texts if type(retrieved_texts) == list else [retrieved_texts] + js['context'] = context + js['length'] = get_word_len(context) + output_list[index] = js + if pbar: + pbar.update() + return + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + # parser.add_argument("--file_name", default='dureader.jsonl') + parser.add_argument("--file_name", default='musique.jsonl') + parser.add_argument("--source_dir", default='../source/docqa_only') + parser.add_argument("--dest_dir", default='./test') + parser.add_argument("--chunk_size", type=int, default=200) + args = parser.parse_args() + file_name = args.file_name + + print(f"------ {file_name} ------") + with open(os.path.join(args.source_dir, file_name), 'r', encoding='utf-8') as file: + file_contents = file.readlines() + # DEBUG + # file_contents = file_contents[:10] + # with tqdm(total=len(file_contents)) as pbar, ThreadPoolExecutor(max_workers=1) as executor: + output_data = [{}] * len(file_contents) + if (os.path.exists(os.path.join(args.dest_dir, file_name))): + with open(os.path.join(args.dest_dir, file_name), 'r', encoding='utf-8') as f: + lines = f.readlines() + lines = [line for line in lines] + output_data = [json.loads(line) for line in lines] + def saving(): + os.makedirs(args.dest_dir, exist_ok=True) + with open(os.path.join(args.dest_dir, file_name), 'w', encoding='utf-8') as output_file: + for item in output_data: + output_file.write(json.dumps(item, ensure_ascii=False) + '\n') + loop = tqdm(enumerate(file_contents), total=len(file_contents), desc=f'{file_name}') + exe_list = [] + with ThreadPoolExecutor(max_workers=3) as executor: + # for index, line in loop: + for index, line in enumerate(file_contents): + if (output_data[index] != {} or + "context" in output_data[index].keys() and len(output_data[index]['context']) != 0): + loop.update() + continue + line_js = json.loads(line) + + try: + # retriveDoc(query=line_js['input'], document=line_js['context'], + # chunk_size=args.chunk_size, file_name=file_name, + # js=line_js, output_list=output_data, idx=index, pbar=loop) + exe_list.append(executor.submit(retriveDoc, query=line_js['input'], document=line_js['context'], + chunk_size=args.chunk_size, file_name=file_name, + js=line_js, output_list=output_data, idx=index, pbar=loop)) + except Exception as e: + saving() + print(e) + wait(exe_list) + saving() + diff --git a/retrieval/embedding/openai_embedding.sh b/retrieval/embedding/openai_embedding.sh new file mode 100644 index 0000000000000000000000000000000000000000..7af5564c7d81b7cd71a3c1d1ff1f0e8807fb9e5e --- /dev/null +++ b/retrieval/embedding/openai_embedding.sh @@ -0,0 +1,39 @@ +#!/bin/bash +chunk_size=500 +work_dir="../../LongBench" # dir for storing data + +source_dir="${work_dir}/data" # 
source LongBench dir +dest_dir=""${work_dir}/E${chunk_size}/data"" + +file_names=() + +allowed_files=("multifieldqa_en.jsonl" "qasper.jsonl" "2wikimqa.jsonl" "dureader.jsonl" "hotpotqa.jsonl" "narrativeqa.jsonl" "musique.jsonl" "multifieldqa_zh.jsonl") +# store all jsonl files +while IFS= read -r -d '' file; do + base_name=$(basename "$file") + # Check if the file name is in the allowed_files list + if [[ " ${allowed_files[@]} " =~ " ${base_name} " ]]; then + file_names+=("$base_name") + fi +done < <(find "$source_dir" -type f -name "*.jsonl" -print0) + +# concurrent execution +group_size=3 + +for ((start=0; start<${#file_names[@]}; start+=group_size)); do + end=$((start + group_size - 1)) + echo "Index Range:$start ~ $end" + current_group=("${file_names[@]:start:group_size}") + for file in "${current_group[@]}"; do + fileName=$(basename "${file}") + python generate_openai_embedding.py \ + --file_name $fileName \ + --source_dir $source_dir \ + --dest_dir $dest_dir \ + --chunk_size $chunk_size \ + & + done + wait +done + +cp ../LongBench.py "${work_dir}/E${chunk_size}" \ No newline at end of file diff --git a/retrieval/pred.py b/retrieval/pred.py new file mode 100644 index 0000000000000000000000000000000000000000..9874df89d45477eb97a6dd4c38b3b25631f9e3bb --- /dev/null +++ b/retrieval/pred.py @@ -0,0 +1,158 @@ +import os +from datasets import load_dataset +import torch +import json +from transformers import AutoTokenizer, AutoModel, LlamaTokenizer, LlamaForCausalLM, AutoModelForCausalLM +from tqdm import tqdm +import argparse +# DEBUG +# os.chdir(os.path.dirname(os.path.abspath(__file__))) + +def parse_args(args=None): + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, default="chatglm2-6b") + parser.add_argument("--top_k", type=int, default=3) + parser.add_argument("--data", type=str, default="B500") + return parser.parse_args(args) + +# This is the customized building prompt for chat models, here is an example for ChatGLM2 +def build_chat(tokenizer, prompt, model_name): + if "chatglm" in model_name: + prompt = tokenizer.build_prompt(prompt) + elif "longchat" in model_name or "vicuna" in model_name: + from fastchat.model import get_conversation_template + conv = get_conversation_template("vicuna") + conv.append_message(conv.roles[0], prompt) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + elif "llama2" in model_name: + prompt = f"[INST]{prompt}[/INST]" + elif "xgen" in model_name: + header = ( + "A chat between a curious human and an artificial intelligence assistant. 
" + "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n" + ) + prompt = header + f" ### Human: {prompt}\n###" + elif "internlm" in model_name: + prompt = f"<|User|>:{prompt}<eoh>\n<|Bot|>:" + return prompt + +def get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset, device, model_name, args): + preds = [{}] * len(data) + if os.path.exists(f"{args.model}_pred_{args.data}_{args.top_k}/{dataset}.jsonl"): + with open(f"{args.model}_pred_{args.data}_{args.top_k}/{dataset}.jsonl", "r", encoding="utf-8") as f: + for index, item in enumerate(f): + preds[index] = json.loads(item) + for index, json_obj in enumerate(tqdm(data, desc=f"{dataset}")): + if preds[index] != {}: + continue + if args.top_k != 0: + json_obj['context'] = "".join(json_obj['retrieved'][:args.top_k]) + prompt = prompt_format.format(**json_obj) + prompt = build_chat(tokenizer, prompt, model_name) + if "chatgpt" in model_name: + output = openai.ChatCompletion.create(model="gpt-3.5-turbo-16k", + messages=[{"role": "user", "content": prompt}], max_tokens=max_gen, + temperature=1.0) + pred = output['choices'][0]['message']['content'] + context_length = output['usage']['prompt_tokens'] + else: + # truncate to fit max_length (we suggest truncate in the middle, since the left and right side may contain crucial instructions) + tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0] + if len(tokenized_prompt) > max_length: + half = int(max_length/2) + prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True)+tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True) + if dataset not in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]: # chat models are better off without build prompt on these tasks + prompt = build_chat(tokenizer, prompt, model_name) + context_length = input.input_ids.shape[-1] + + input = tokenizer(prompt, truncation=False, return_tensors="pt").to(device) + if dataset == "samsum": # prevent illegal output on samsum (model endlessly repeat "\nDialogue"), might be a prompting issue + output = model.generate( + **input, + max_new_tokens=max_gen, + num_beams=1, + do_sample=False, + temperature=1.0, + min_length=context_length+1, + eos_token_id=[tokenizer.eos_token_id, tokenizer.encode("\n", add_special_tokens=False)[-1]], + )[0] + else: + output = model.generate( + **input, + max_new_tokens=max_gen, + num_beams=1, + do_sample=False, + temperature=1.0, + )[0] + pred = tokenizer.decode(output[context_length:], skip_special_tokens=True) + pred = post_process(pred, model_name) + preds[index] = {"pred": pred, "answers": json_obj["answers"], "all_classes": json_obj["all_classes"], + "context_length": context_length} + with open(f"{args.model}_pred_{args.data}_{args.top_k}/{dataset}.jsonl", "w", encoding="utf-8") as f: + for pred in preds: + json.dump(pred, f, ensure_ascii=False) + f.write('\n') + return preds + +def post_process(response, model_name): + if "xgen" in model_name: + response = response.strip().replace("Assistant:", "") + elif "internlm" in model_name: + response = response.split("<eoa>")[0] + return response + +def load_model_and_tokenizer(model2path, model_name, device): + if "chatgpt" in model_name: + return model_name, model_name + else: + if "chatglm" in model_name or "internlm" in model_name or "xgen" in model_name: + tokenizer = AutoTokenizer.from_pretrained(model2path[model_name], trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained(model2path[model_name], 
trust_remote_code=True, torch_dtype=torch.bfloat16).to(device) + elif "llama2" in model_name: + tokenizer = LlamaTokenizer.from_pretrained(model2path[model_name]) + model = LlamaForCausalLM.from_pretrained(model2path[model_name], torch_dtype=torch.bfloat16).to(device) + elif "longchat" in model_name or "vicuna" in model_name: + from fastchat.model import load_model + model, _ = load_model( + model2path[model_name], + device='cpu', + num_gpus=0, + load_8bit=False, + cpu_offloading=False, + debug=False, + ) + model = model.to(device) + model = model.bfloat16() + tokenizer = AutoTokenizer.from_pretrained(model2path[model_name], trust_remote_code=True, use_fast=False) + model = model.eval() + return model, tokenizer + +if __name__ == '__main__': + args = parse_args() + model_name = args.model + if "chatgpt" in model_name: + import openai + # openai.api_base="" + openai.api_key = "YOUR_KEY" + # Retrieval is fit for these datasets + datasets = ["multifieldqa_en", "qasper", "2wikimqa", "dureader", \ + "hotpotqa", "narrativeqa", "musique", "multifieldqa_zh"] + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + # load configs + model2path = json.load(open("../config/model2path.json", "r")) + model2maxlen = json.load(open("../config/model2maxlen.json", "r")) + # we design specific prompt format and max generation length for each task, feel free to modify them to optimize model output + dataset2prompt = json.load(open("../config/dataset2prompt.json", "r")) + dataset2maxlen = json.load(open("../config/dataset2maxlen.json", "r")) + # define your model + model, tokenizer = load_model_and_tokenizer(model2path, model_name, device) + max_length = model2maxlen[model_name] + # predict on each dataset + os.makedirs(f"{args.model}_pred_{args.data}_{args.top_k}", exist_ok=True) + for dataset in datasets: + data = load_dataset(f'../LongBench/{args.data}/LongBench.py', dataset, split='test', + download_mode='force_redownload') # force to load from dir + prompt_format = dataset2prompt[dataset] + max_gen = dataset2maxlen[dataset] + preds = get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset, device, model_name, args) \ No newline at end of file diff --git a/retrieval/requirements.txt b/retrieval/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..882720937931c2ec5c884d6314f042fca6dcb7a0 --- /dev/null +++ b/retrieval/requirements.txt @@ -0,0 +1,3 @@ +rank_bm25 +openai +beir \ No newline at end of file diff --git a/retrieval/splitter.py b/retrieval/splitter.py new file mode 100644 index 0000000000000000000000000000000000000000..a29ef753210f3f4b2226023628603266f5817b66 --- /dev/null +++ b/retrieval/splitter.py @@ -0,0 +1,45 @@ +import re +def split_long_sentence(sentence, regex, chunk_size=200, filename='Unknown'): + chunks = [] + sentences = re.split(regex, sentence) + current_chunk = "" + for s in sentences: + if current_chunk and get_word_len(current_chunk) + get_word_len(s) <= chunk_size: + current_chunk += ' ' if s == '' else s + else: + if current_chunk: + chunks.append(current_chunk) + # if (len(current_chunk) > chunk_size*5): + current_len = get_word_len(current_chunk) + if (current_len > chunk_size * 1.5): + print(f"\n{filename}-{len(chunks)-1} Chunk size: {current_len}") + + current_chunk = s + + if current_chunk: + chunks.append(current_chunk) + + return chunks + +def get_word_list(s1): + # Separate sentences by word, Chinese by word, English by word, numbers by space + regEx = re.compile('[\W]') + res = 
re.compile(r"([\u4e00-\u9fa5])") # [\u4e00-\u9fa5] for Chinese + + p1 = regEx.split(s1.lower()) + str1_list = [] + for str in p1: + if res.split(str) == None: + str1_list.append(str) + else: + ret = res.split(str) + for ch in ret: + str1_list.append(ch) + + list_word1 = [w for w in str1_list if len(w.strip()) > 0] + + return list_word1 +def get_word_len(s1): + return len(get_word_list(s1)) + +regex = r'([。?!;\n.!?;]\s*)' diff --git a/summ/.DS_Store b/summ/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..e5d473b0db5c1d6cd9af846974e8630679a815d7 Binary files /dev/null and b/summ/.DS_Store differ diff --git a/summ/README.md b/summ/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d60c37e05f67b6cd8505c9a43eaba61e414d4863 --- /dev/null +++ b/summ/README.md @@ -0,0 +1,30 @@ +## Introduction +The `compress.py` script is designed to address the challenge of processing excessively long texts using large language models. It provides a solution to reduce both the runtime and inference costs of running these models on long texts. The script works by segmenting the lengthy input into smaller segments, processing them individually through a chosen model, and then concatenating the outputs to reconstruct the final summarized result. + +## Applicability +The `compress.py` script is particularly useful for summarization tasks involving the following datasets: +- `qmsum.jsonl` +- `gov_report.jsonl` +- `vcsum.jsonl` +- `multinews.jsonl` + +## Motivation +The motivation behind this script is to tackle the challenges posed by processing extremely lengthy texts using large language models. Running these models on overly long texts can result in excessive execution times and high inference costs. The script addresses this by dividing the input text into manageable segments, summarizing them using a model, and finally merging the summaries to reconstruct the original long text. + +## Usage +1. Clone this repository. +2. Open the `compress.py` script in a text editor. +3. Replace the keys or paths for the models (`glm2`, `gpt-16k`, `Llama2`) with your desired paths. +4. Define the paths for your raw data folder (`raw_data_folder`), the new data folder (`new_data_folder`), and the folder for storing compressed context texts (`compressed_context_path`). +5. Uncomment the code related to `flash_attn` if you wish to accelerate Llama2's inference using flash attention. If not needed, you can comment out this section. +6. Save the changes. + +## Features +- Supports checkpoint resumption, allowing you to resume processing from where it was last paused. +- Automatically loads files from checkpoints. +- Utilizes `flash_attn` to speed up Llama2 inference (can be commented out if not applicable). 
+ +## Example +Here's an example of how to use the script: +```shell +python compress.py --model glm2 --max_len 500 diff --git a/summ/compress.py b/summ/compress.py new file mode 100644 index 0000000000000000000000000000000000000000..b7180ad8e58f3290a427ec99c28bdf7ff8ec5655 --- /dev/null +++ b/summ/compress.py @@ -0,0 +1,342 @@ +import requests +import json +import jsonlines +from transformers import AutoTokenizer, AutoModel +import matplotlib.pyplot as plt +from tqdm import tqdm +import concurrent.futures +from concurrent.futures import ThreadPoolExecutor +import multiprocessing +import argparse +from transformers import LlamaForCausalLM, LlamaTokenizer +import torch +import os +from threading import Lock +import time +import re +from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn +parser = argparse.ArgumentParser() +parser.add_argument('--max_len', type=int, default=500) +parser.add_argument('--model', type=str, default="Llama2") # glm2, gpt-16k, Llama2 +args = parser.parse_args() +print(args) +GPT_key = "" #openai key +GPT_MODEL = 'gpt-3.5-turbo-16k' +GLM_MODEL = "THUDM/chatglm2-6b-32k" +LLAMA_MODEL = "meta-llama/Llama-2-7b-chat-hf" +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +jsonl_files = ['qmsum.jsonl', 'gov_report.jsonl', 'vcsum.jsonl','multinews.jsonl'] +raw_data_folder = '../LongBench/data' #raw data folder +new_data_folder = args.model+'_'+str(args.max_len)+'/data' #compressed data folder +compressed_context_path ='../LongBench/compressed_data_'+str(args.max_len)+'/data' #compressed context folder +if not os.path.exists(new_data_folder): + os.makedirs(new_data_folder) +if not os.path.exists(compressed_context_path): + os.makedirs(compressed_context_path) + +def build_chat(tokenizer, prompt, model_name): + if "glm2" in model_name: + prompt = tokenizer.build_prompt(prompt) + elif "Llama2" in model_name: + prompt = f"[INST]{prompt}[/INST]" + elif "xgen" in model_name: + header = ( + "A chat between a curious human and an artificial intelligence assistant. 
" + "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n" + ) + prompt = header + f" ### Human: {prompt}\n###" + elif "internlm" in model_name: + prompt = f"<|User|>:{prompt}<eoh>\n<|Bot|>:" + return prompt + +if args.model=="glm2": + model_path = GLM_MODEL + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True,max_length=1024) + # model = load_model_on_gpus(model_path, num_gpus=4) + model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16).to(device) + model = model.eval() + def generate_response(prompt): + tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0] + max_length = 31500 + if len(tokenized_prompt) > max_length: + half = int(max_length/2) + prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True)+tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True) + prompt = build_chat(tokenizer, prompt, 'glm2') + input = tokenizer(prompt, truncation=False, return_tensors="pt").to(device) + context_length = input.input_ids.shape[-1] + output = model.generate( + **input, + max_new_tokens=200, + num_beams=1, + do_sample=False, + temperature=1.0, + )[0] + pred = tokenizer.decode(output[context_length:], skip_special_tokens=True) + return pred + +elif args.model=="gpt-16k": + #using GPT API and tokenizer + def query(messages, force_commit=False): + tries = 0 + while tries < 5: + tries += 1 + try: + headers = { + 'Authorization': GPT_key + } + resp = requests.post("https://api.openai.com/v1/chat/completions", json = { + "model": GPT_MODEL, + "messages": messages, + "temperature": 1.0 + }, headers=headers, timeout=120) + if resp.status_code != 200: + raise Exception(resp.text) + resp = resp.json() + break + except KeyboardInterrupt as e: + raise e + except Exception as e: + if "maximum context length" in str(e): + raise e + print("Error Occurs: \"%s\" Retry ..."%(str(e))) + else: + print("Max tries. 
Failed.") + return + return resp["choices"][0]["message"]["content"] + + def generate_response(prompt): + msg = [{"role": "user", "content": prompt}] + result = query(msg) + return result + +elif args.model=="Llama2": + replace_llama_attn_with_flash_attn() + tokenizer = LlamaTokenizer.from_pretrained(LLAMA_MODEL) + model =LlamaForCausalLM.from_pretrained(LLAMA_MODEL, torch_dtype=torch.bfloat16).to(device) + model.eval() + def generate_response(prompt): + tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0] + max_length = 3500 + if len(tokenized_prompt) > max_length: + half = int(max_length/2) + prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True)+tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True) + prompt = build_chat(tokenizer, prompt, 'Llama2') + input = tokenizer(prompt, truncation=False, return_tensors="pt").to(device) + context_length = input.input_ids.shape[-1] + output = model.generate( + **input, + max_new_tokens=200, + num_beams=1, + do_sample=False, + temperature=1.0, + )[0] + pred = tokenizer.decode(output[context_length:], skip_special_tokens=True) + return pred + +def get_word_list(s1): + + regEx = re.compile('[\W]') + res = re.compile(r"([\u4e00-\u9fa5])") # [\u4e00-\u9fa5] for zh + + p1 = regEx.split(s1.lower()) + str1_list = [] + for str in p1: + if res.split(str) == None: + str1_list.append(str) + else: + ret = res.split(str) + for ch in ret: + str1_list.append(ch) + + list_word1 = [w for w in str1_list if len(w.strip()) > 0] + return list_word1 + +def get_word_len(s1): + return len(get_word_list(s1)) + +def data_spilt(data_test,max_len=args.max_len): + + data_len=len(data_test) + #split data_test to n parts averagely according to data_len + data_words_len = get_word_len(data_test) + split_len = int(data_words_len/max_len) + + text_len = int(data_len/split_len) + #start position of each part + text = data_test + text_list=[] + text_num = 0 + while text_num <= split_len: + text_num += 1 + try: + for i in range(text_len,text_len+1500): + + if text[i] in ['\n', '.', '。', '?', '?', '!', '!']: + # cut off the text until the end of the line, using the length of the text to decide when to stop + text_list.append(text[:i+1]) + text = text[i+1:] + break + except: + if text not in text_list: + text_list.append(text) + if text!='' and text not in text_list: + text_list.append(text) + + return text_list + +def compress(data_test, max_len=args.max_len, language='en',_id=None,dataset_type=None): + try: + text_list = data_spilt(data_test) + responses = [] + compressed_data_list = [] # List to store compressed data + # compressed_context_path = + compressed_id = 0 # Counter for compressed_id + prompt_en = 'Summerize the context above. The max length of the summary is 200 words.' 
+ prompt_zh = '请对上面的文本进行总结。最大长度为200个字。' + for text_num, text in enumerate(text_list): + # Your existing code to generate prompt based on language + if language == 'zh': + prompt = text + '\n' + prompt_zh + elif language == 'en': + prompt = text + '\n' + prompt_en + else: + prompt = text + '\n' + prompt_en + print('language error') + response = generate_response(prompt) + responses.append(response) + compressed_context = response # Change this based on your response generation logic + # Construct the compressed data entry + compressed_entry = { + "input": "", # Fill in the input/command here + "raw_context": text, # Original text + "compressed_context": compressed_context, # Compressed text + "compressed_id": compressed_id, + "answers": [], # Fill in the answers + "length": get_word_len(text), # Length of the original text + "dataset": str(dataset_type), # Fill in the dataset name + "language": language, + "all_classes": None, # Fill in the categories + "_id": _id # Fill in the ID + } + + compressed_data_list.append(compressed_entry) + compressed_id += 1 + + # Save the compressed data to a file + path = os.path.join(compressed_context_path, str(args.model)+'_'+str(max_len)+'_'+str(dataset_type)+'.jsonl') + with jsonlines.open(path, 'a') as writer: + writer.write_all(compressed_data_list) + #save repsonses to file + if language == 'en': + for i in range(len(responses)): + responses[i] = 'Paragraph Summary'+str(i+1)+" : "+responses[i]+'\n' + elif language == 'zh': + for i in range(len(responses)): + responses[i] = '段落摘要'+str(i+1)+" : "+responses[i]+'\n' + + new_text_words_len =get_word_len(new_text) + print('new_text_words_len :',new_text_words_len) + except Exception as e: + print(str(e), sep=" | ") + new_text = data_test + return new_text + +def handle_item(item, max_len): + try: + context = item['context'] + language = item['language'] + _id = item['_id'] + dataset_type = item['dataset'] + compressd_context = compress(context, max_len,language,_id,dataset_type) + item['context'] = compressd_context + if 'all_classes' not in item.keys(): + item['all_classes'] = 'None' + #count the length of compressd context basen on language + item['length'] = get_word_len(compressd_context) + + except Exception as e: + print('Compress Fail:',str(e), sep=" | ") + return item + +def save_data(data, file_name): + with jsonlines.open(file_name, 'a') as writer: + writer.write_all(data) + return data + +def load_data(file_name): + with jsonlines.open(file_name) as reader: + data = list(reader) + return data + +def parallel_process_data(data, start_index, handle_item, workers=20, callback=None, checkpoint_interval=20): + save_lock = Lock() + with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor: + futures = [] + id_set = set() + for i in range(start_index): + id_set.add(data[i]['_id']) + + for i in range(start_index, len(data)): + item = data[i] + future = executor.submit(handle_item, item, args.max_len) + futures.append(future) + if i > 0 and i % checkpoint_interval == 0: + + processed_data = [] + + with tqdm(total=len(futures)) as pbar: + for future in concurrent.futures.as_completed(futures): + + result = future.result() + if result['_id'] not in id_set: + id_set.add(result['_id']) + processed_data.append(result) + if callback: + callback(result) + pbar.update(1) + + with save_lock: + save_data(processed_data, new_file_path) + save_data(processed_data, checkpoint_file) + + processed_data = [] + print('final save data') + with tqdm(total=len(futures)) as pbar: + for future in 
concurrent.futures.as_completed(futures): + result = future.result() + if result['_id'] not in id_set: + id_set.add(result['_id']) + processed_data.append(result) + if callback: + callback(result) + pbar.update(1) + + with save_lock: + try: + save_data(processed_data, new_file_path) + save_data(processed_data, checkpoint_file) + except Exception as e: + time.sleep(1) + print('Save data:',str(e), sep=" | ") + save_data(processed_data, new_file_path) + save_data(processed_data, checkpoint_file) + +for jsonl_file in jsonl_files: + raw_file_path = os.path.join(raw_data_folder, jsonl_file) + print('load data from', raw_file_path) + data = load_data(raw_file_path) + print('start compress') + checkpoint_folder = os.path.join(new_data_folder, 'checkpoint') + if not os.path.exists(checkpoint_folder): + os.makedirs(checkpoint_folder) + checkpoint_file = os.path.join(checkpoint_folder, jsonl_file + '_checkpoint.jsonl') + start_index = 0 + if os.path.exists(checkpoint_file): + processed_data = load_data(checkpoint_file) + start_index = len(processed_data) + print('load checkpoint from', checkpoint_file, 'start_index:', start_index) + new_file_path = os.path.join(new_data_folder, jsonl_file) + parallel_process_data(data, start_index, handle_item, workers=19, checkpoint_interval=19) + print('end compress') + print('save data to', new_file_path) \ No newline at end of file diff --git a/task.md b/task.md index 121cd2fa94552988333ae65d8a562883b8b76a03..e42221d7c0949c1d2f1177b14f17a20e78965640 100644 --- a/task.md +++ b/task.md @@ -1,27 +1,28 @@ -# Task statistics +# LongBench statistics | Task | Task Type | Eval metric | Avg len |Language | \#Sample | | :-------- | :-----------:| :-----------: |:-------: | :-----------: |:--------: | -| HotpotQA | Multi-doc QA | F1 |9,149 |EN |200 | -| 2WikiMultihopQA| Multi-doc QA | F1 |4,885 |EN |200 | -| MuSiQue| Multi-doc QA | F1 |11,018 |EN |200 | +| HotpotQA | Multi-doc QA | F1 |9,151 |EN |200 | +| 2WikiMultihopQA| Multi-doc QA | F1 |4,887 |EN |200 | +| MuSiQue| Multi-doc QA | F1 |11,214 |EN |200 | | DuReader| Multi-doc QA | Rouge-L |15,768 |ZH |200 | | MultiFieldQA-en| Single-doc QA | F1 |4,559 |EN |150 | -| MultiFieldQA-zh| Single-doc QA | F1 |6,771 |ZH |200 | -| NarrativeQA| Single-doc QA | F1 |18,405 |EN |200 | +| MultiFieldQA-zh| Single-doc QA | F1 |6,701 |ZH |200 | +| NarrativeQA| Single-doc QA | F1 |18,409 |EN |200 | | Qasper| Single-doc QA | F1 |3,619 |EN |200 | -| GovReport| Summarization | Rouge-L |8,169 |EN |200 | -| QMSum| Summarization | Rouge-L |10,546 |EN |200 | -| VCSUM| Summarization | Rouge-L |15,147 |ZH |200 | -| TriviaQA| Few shot | F1 |8,015 |EN |200 | -| NQ| Few shot | F1 |8,210 |EN |200 | -| TREC| Few shot | Accuracy |5,176 |EN |200 | -| LSHT| Few shot | Accuracy |22,333 |ZH |200 | -| PassageRetrieval-en| Synthetic | Accuracy |9,288 |EN |200 | +| GovReport| Summarization | Rouge-L |8,734 |EN |200 | +| QMSum| Summarization | Rouge-L |10,614 |EN |200 | +| MultiNews| Summarization | Rouge-L |2,113 |EN |200 | +| VCSUM| Summarization | Rouge-L |15,380 |ZH |200 | +| TriviaQA| Few shot | F1 |8,209 |EN |200 | +| SAMSum| Few shot | Rouge-L |6,258 |EN |200 | +| TREC| Few shot | Accuracy |5,177 |EN |200 | +| LSHT| Few shot | Accuracy |22,337 |ZH |200 | +| PassageRetrieval-en| Synthetic | Accuracy |9,289 |EN |200 | | PassageCount| Synthetic | Accuracy |11,141 |EN |200 | | PassageRetrieval-zh | Synthetic | Accuracy |6,745 |ZH |200 | | LCC| Code | Edit Sim |1,235 |Python/C#/Java |500 | -| RepoBench-P| Code | Edit Sim |5,622 |Python/Java |500 | +| 
RepoBench-P| Code | Edit Sim |4,206 |Python/Java |500 |

> Note: In order to avoid discrepancies caused by different tokenizers, we use the word count (using Python's split function) to calculate the average length of English datasets and code datasets, and use the character count to calculate the average length of Chinese datasets.

@@ -35,11 +36,13 @@
| DuReader | Answer related Chinese questions based on multiple retrieved documents |
| MultiFieldQA-en | Answer English questions based on a long article, which comes from a relatively diverse field |
| MultiFieldQA-zh | Answer Chinese questions based on a long article, which comes from a relatively diverse field |
-| NarrativeQA | Ask questions based on stories or scripts, including understanding of important elements such as characters, plots, themes, etc. |
-| Qasper | Ask questions based on a NLP research paper, questions proposed and answered by NLP practitioners |
+| NarrativeQA | Answer questions based on stories or scripts, including understanding of important elements such as characters, plots, themes, etc. |
+| Qasper | Answer questions based on an NLP research paper, with questions proposed and answered by NLP practitioners |
| GovReport | A summarization task that requires summarizing government work reports |
+| MultiNews | A multi-document summarization task that requires summarizing over multiple news articles |
| QMSum | A summarization task that requires summarizing meeting records based on user queries |
| VCSUM | A summarization task that requires summarizing Chinese meeting records |
+| SAMSum | A dialogue summarization task, providing several few-shot examples |
| TriviaQA | Single document question answering task, providing several few-shot examples |
| NQ | Single document question answering task, providing several few-shot examples |
| TREC | A classification task that requires categorizing questions, including 50 categories in total |
@@ -57,11 +60,28 @@
- The tasks of [HotpotQA](https://hotpotqa.github.io/), [2WikiMultihopQA](https://aclanthology.org/2020.coling-main.580/), [MuSiQue](https://arxiv.org/abs/2108.00573), and [DuReader](https://github.com/baidu/DuReader) are built based on the original datasets and processed to be suitable for long context evaluation. Specifically, for questions in the validation set, we select the evidence passage that contains the answer and several distracting articles. These articles together with the original question constitute the input of the tasks.
- The tasks of MultiFieldQA-zh and MultiFieldQA-en consist of long article data from about 10 sources, including LaTeX papers, judicial documents, government work reports, and PDF documents indexed by Google. For each long article, we invite several PhD and master students to annotate, i.e., to ask questions based on the long article and give the correct answers. To better automate evaluation, we ask the annotators to propose questions with definitive answers as much as possible.
-- The tasks of [NarrativeQA](https://arxiv.org/pdf/1712.07040.pdf), [Qasper](https://arxiv.org/pdf/2105.03011.pdf), [GovReport](https://arxiv.org/pdf/2104.02112.pdf), and [QMSum](https://arxiv.org/pdf/2104.05938.pdf) directly use the data provided by the original papers. In the specific construction, we use the template provided by [ZeroSCROLLS](https://www.zero.scrolls-benchmark.com/) to convert the corresponding data into pure text input.
+- The tasks of [NarrativeQA](https://arxiv.org/pdf/1712.07040.pdf), [Qasper](https://arxiv.org/pdf/2105.03011.pdf), [GovReport](https://arxiv.org/pdf/2104.02112.pdf), [QMSum](https://arxiv.org/pdf/2104.05938.pdf) and [MultiNews](https://aclanthology.org/P19-1102.pdf) directly use the data provided by the original papers. In the specific construction, we use the template provided by [ZeroSCROLLS](https://www.zero.scrolls-benchmark.com/) to convert the corresponding data into pure text input.
- The [VCSUM](https://arxiv.org/abs/2305.05280) task is built based on the original dataset, and we design a template to convert the corresponding data into pure text input.
-- The tasks of [TriviaQA](https://nlp.cs.washington.edu/triviaqa/) and [NQ](https://ai.google.com/research/NaturalQuestions/) are constructed in the manner of [CoLT5](https://arxiv.org/abs/2303.09752), which provides several examples of question and answering based on documents, and requires the language model to answer related questions based on new documents.
-- The tasks of [TREC](https://aclanthology.org/C02-1150.pdf) and [LSHT](http://tcci.ccf.org.cn/conference/2014/dldoc/evatask6.pdf) are built based on the original datasets. For each question in the validation set, we sample several data from the training set to form few-shot examples. These examples together with the questions in the validation set constitute the input for this task.
+- The [TriviaQA](https://nlp.cs.washington.edu/triviaqa/) task is constructed in the manner of [CoLT5](https://arxiv.org/abs/2303.09752), which provides several examples of question answering based on documents, and requires the language model to answer related questions based on new documents.
+- The tasks of [SAMSum](https://aclanthology.org/D19-5409.pdf), [TREC](https://aclanthology.org/C02-1150.pdf) and [LSHT](http://tcci.ccf.org.cn/conference/2014/dldoc/evatask6.pdf) are built based on the original datasets. For each question in the validation set, we sample several data from the training set to form few-shot examples. These examples together with the questions in the validation set constitute the input for this task.
- The PassageRetrieval-en task is constructed based on English Wikipedia. For each piece of data, we randomly sample 30 paragraphs from English Wikipedia and select one for summarization (using GPT-3.5-Turbo). This task requires the model to give the original paragraph name to which the summary corresponds.
- The PassageCount task is constructed based on English Wikipedia. For each piece of data, we randomly sample several passages from English Wikipedia, repeat each paragraph at random several times, and finally shuffle the paragraphs. This task requires the model to determine the total number of different paragraphs in the given context.
- The PassageRetrieval-zh task is constructed based on [C4](https://arxiv.org/abs/1910.10683). For each piece of data, we randomly sample several Chinese paragraphs from C4 and select one of them for summarization (using GPT-3.5-Turbo). This task requires the model to give the original paragraph name to which the summary corresponds.
-- For the [LCC](https://arxiv.org/abs/2306.14893) task, we sample from the original code completion dataset. In the [RepoBench-P](https://arxiv.org/abs/2306.03091) task, we select the most challenging XF-F (Cross-File-First) setting from the original dataset and refer to the Oracle-Filled scenario in the paper.
For each original piece of data, we randomly extract multiple cross-file code snippets, including the gold cross-file code snippet, and concatenate them as input, requiring the model to effectively use cross-file code for completion. \ No newline at end of file +- For the [LCC](https://arxiv.org/abs/2306.14893) task, we sample from the original code completion dataset. In the [RepoBench-P](https://arxiv.org/abs/2306.03091) task, we select the most challenging XF-F (Cross-File-First) setting from the original dataset and refer to the Oracle-Filled scenario in the paper. For each original piece of data, we randomly extract multiple cross-file code snippets, including the gold cross-file code snippet, and concatenate them as input, requiring the model to effectively use cross-file code for completion. + +# LongBench-E statistics +| Task | Task Type | \#data in 0-4k | \#data in 4-8k | \#data in 8k+| +| :--------- | :-----------:| :-----------: |:---------: | :-------------: | +| HotpotQA | Multi-doc QA | 100 |100 |100 | +| 2WikiMultihopQA| Multi-doc QA | 100 |100 |100 | +| MultiFieldQA-en| Single-doc QA | 67 |70 |13 | +| Qasper| Single-doc QA | 100 |100 |24 | +| GovReport| Summarization | 100 |100 |100 | +| MultiNews| Summarization | 100 |100 |94 | +| TriviaQA| Few shot | 100 |100 |100 | +| SAMSum| Few shot | 100 |100 |100 | +| TREC| Few shot | 100 |100 |100 | +| PassageRetrieval-en| Synthetic | 100 |100 |100 | +| PassageCount| Synthetic | 100 |100 |100 | +| LCC| Code | 100 |100 |100 | +| RepoBench-P| Code | 100 |100 |100 | diff --git a/task_zh.md b/task_zh.md index af2a1ae15a1fb0633fde314976db3215289c09ee..35e4c8fa45d56c829b3d371fe18dcde26f80a876 100644 --- a/task_zh.md +++ b/task_zh.md @@ -1,27 +1,28 @@ -# 任务统计 +# LongBench任务统计 | 任务 | 任务类型 | 评价指标 | 平均长度 |语言 | Sample数量| | :--------- | :-----------:| :-----------: |:---------: | :-------------: |:---------: | -| HotpotQA | 多文档QA | F1 |9,149 |英文 |200 | -| 2WikiMultihopQA| 多文档QA | F1 |4,885 |英文 |200 | -| MuSiQue| 多文档QA | F1 |11,018 |英文 |200 | +| HotpotQA | 多文档QA | F1 |9,151 |英文 |200 | +| 2WikiMultihopQA| 多文档QA | F1 |4,887 |英文 |200 | +| MuSiQue| 多文档QA | F1 |11,214 |英文 |200 | | DuReader| 多文档QA | Rouge-L |15,768 |中文 |200 | | MultiFieldQA-en| 单文档QA | F1 |4,559 |英文 |150 | -| MultiFieldQA-zh| 单文档QA | F1 |6,771 |中文 |200 | -| NarrativeQA| 单文档QA | F1 |18,405 |英文 |200 | +| MultiFieldQA-zh| 单文档QA | F1 |6,701 |中文 |200 | +| NarrativeQA| 单文档QA | F1 |18,409 |英文 |200 | | Qasper| 单文档QA | F1 |3,619 |英文 |200 | -| GovReport| 摘要 | Rouge-L |8,169 |英文 |200 | -| QMSum| 摘要 | Rouge-L |10,546 |英文 |200 | -| VCSUM| 摘要 | Rouge-L |15,147 |中文 |200 | -| TriviaQA| Few shot | F1 |8,015 |英文 |200 | -| NQ| Few shot | F1 |8,210 |英文 |200 | -| TREC| Few shot | Accuracy |5,176 |英文 |200 | -| LSHT| Few shot | Accuracy |22,333 |中文 |200 | -| PassageRetrieval-en| 合成任务 | Accuracy |9,288 |英文 |200 | +| GovReport| 摘要 | Rouge-L |8,734 |英文 |200 | +| QMSum| 摘要 | Rouge-L |10,614 |英文 |200 | +| MultiNews| 摘要 | Rouge-L |2,113 |英文 |200 | +| VCSUM| 摘要 | Rouge-L |15,380 |中文 |200 | +| TriviaQA| Few shot | F1 |8,209 |英文 |200 | +| SAMSum| Few shot | Rouge-L |6,258 |英文 |200 | +| TREC| Few shot | Accuracy |5,177 |英文 |200 | +| LSHT| Few shot | Accuracy |22,337 |中文 |200 | +| PassageRetrieval-en| 合成任务 | Accuracy |9,289 |英文 |200 | | PassageCount| 合成任务 | Accuracy |11,141 |英文 |200 | | PassageRetrieval-zh | 合成任务 | Accuracy |6,745 |中文 |200 | | LCC| 代码 | Edit Sim |1,235 |Python/C#/Java |500 | -| RepoBench-P| 代码 | Edit Sim |5,622 |Python/Java |500 | +| RepoBench-P| 代码 | Edit Sim |4,206 |Python/Java |500 | > 
注:为了避免不同Tokenizer统计的差距,我们使用单词数(Python的split函数)来统计英文数据集和代码数据集的平均长度,使用汉字数来统计中文数据集的平均长度。 @@ -39,9 +40,10 @@ | Qasper | 基于单篇论文的提出,问题由NLP的读者提出,并由NLP从业者回答 | | GovReport | 摘要任务,要求对政府的工作报告进行总结摘要 | | QMSum | 摘要任务,要求基于用户的查询对会议记录进行摘要 | +| MultiNews | 多文档摘要任务,要求基于多篇新闻进行摘要 | | VCSUM | 摘要任务,要求对中文会议记录进行总结摘要 | | TriviaQA | 单文档问答任务,提供若干的Few Shot样例 | -| NQ | 单文档问答任务,提供若干的Few Shot样例 | +| SAMSum | 对话摘要任务,提供若干的Few Shot样例 | | TREC | 分类任务,要求对问题进行分类,一共包含50个类别 | | LSHT | 中文分类任务,要求对新闻进行分类,一共包含24个类别 | | PassageRetrieval-en | 给定30个英文维基的段落,判断给定的摘要属于哪个段落 | @@ -56,11 +58,28 @@ - [HotpotQA](https://hotpotqa.github.io/), [2WikiMultihopQA](https://aclanthology.org/2020.coling-main.580/), [MuSiQue](https://arxiv.org/abs/2108.00573)和[DuReader](https://github.com/baidu/DuReader)任务基于原始的数据集构建,并进行相关处理使其适用于长文本评测。具体地,对于验证集中的问题,我们会选取包含答案的evidence passage和若干干扰的文章,这些文章和原始的问题共同组成了相关任务的输入。 - MultiFiedQA-zh和MultiFieldQA-en任务由约10种来源的长文本数据组成,包含Latex论文、裁判文书、政府工作报告和谷歌索引的PDF文档等。对于每篇长文本,我们邀请了若干博士生和硕士生来进行标注,即基于长文本提问,并给出正确的答案。为了更好地进行自动化评测,我们要求标注员尽可能提出有确定性答案的问题。 -- [NarrativeQA](https://arxiv.org/pdf/1712.07040.pdf), [Qasper](https://arxiv.org/pdf/2105.03011.pdf), [GovReport](https://arxiv.org/pdf/2104.02112.pdf)和[QMSum](https://arxiv.org/pdf/2104.05938.pdf)任务直接使用原论文提供的数据。在具体的构建中,我们使用[ZeroSCROLLS](https://www.zero.scrolls-benchmark.com/)提供的模板来将对应的数据转换为纯文本的输入。 +- [NarrativeQA](https://arxiv.org/pdf/1712.07040.pdf), [Qasper](https://arxiv.org/pdf/2105.03011.pdf), [GovReport](https://arxiv.org/pdf/2104.02112.pdf),[QMSum](https://arxiv.org/pdf/2104.05938.pdf)和[MultiNews](https://aclanthology.org/P19-1102.pdf)任务直接使用原论文提供的数据。在具体的构建中,我们使用[ZeroSCROLLS](https://www.zero.scrolls-benchmark.com/)提供的模板来将对应的数据转换为纯文本的输入。 - [VCSUM](https://arxiv.org/abs/2305.05280)任务基于原始的数据集构建,我们针对该数据设计了相应的模板将对应的数据转换为纯文本的输入。 -- [TriviaQA](https://nlp.cs.washington.edu/triviaqa/)和[NQ](https://ai.google.com/research/NaturalQuestions/)任务参考[CoLT5](https://arxiv.org/abs/2303.09752)的方式进行构建,即会提供若干基于文档进行问答的样例,并要求语言模型基于新的文档回答相关问题。 -- [TREC](https://aclanthology.org/C02-1150.pdf)和[LSHT](http://tcci.ccf.org.cn/conference/2014/dldoc/evatask6.pdf)任务基于原始的数据集构建。对于验证集中的每个问题,我们采样训练集中的若干数据组成Few-shot样例。这些样例会和验证集中的问题共同组成该任务的输入。 +- [TriviaQA](https://nlp.cs.washington.edu/triviaqa/)任务参考[CoLT5](https://arxiv.org/abs/2303.09752)的方式进行构建,即会提供若干基于文档进行问答的样例,并要求语言模型基于新的文档回答相关问题。 +- [SAMSum](https://aclanthology.org/D19-5409.pdf),[TREC](https://aclanthology.org/C02-1150.pdf)和[LSHT](http://tcci.ccf.org.cn/conference/2014/dldoc/evatask6.pdf)任务基于原始的数据集构建。对于验证集中的每个问题,我们采样训练集中的若干数据组成Few-shot样例。这些样例会和验证集中的问题共同组成该任务的输入。 - PassageRetrieval-en任务基于英文维基进行构造。对于每条数据,我们随机采样30段英文维基的段落,并选取其中一段进行摘要(使用GPT-3.5-Turbo)。该任务要求模型给出摘要应该对应哪个的原始段落。 - PassageCount任务基于英文维基进行构造。对于每条数据,我们随机采样若干英文维基的段落,并将其中的每个段落随机重复若干次,最后将段落随机打乱。该任务要求模型判断给定的若干的段落中不重复的段落一共有几个。 - PassageRetrieval-zh任务基于[C4](https://arxiv.org/abs/1910.10683)进行构造。对于每条数据,我们随机采样若干段来自于C4的中文段落,并选取其中一段进行摘要(使用GPT-3.5-Turbo)。该任务要求模型给出摘要对应的那个原始段落名称。 -- [LCC](https://arxiv.org/abs/2306.14893)任务我们基于原始的代码补全数据集采样构建。[RepoBench-P](https://arxiv.org/abs/2306.03091)任务中我们选取了原数据集最具挑战性的XF-F(Cross-File-First)设定,并且参考原文中的Oracle-Filled场景,对于每一条原始数据我们随机抽取包括有效跨文件代码片段(gold snippet)在内的多个跨文件代码片段,将其拼接后作为输入,要求模型从其中利用有效的跨文件代码以补全当前文件中的代码。 \ No newline at end of file +- [LCC](https://arxiv.org/abs/2306.14893)任务我们基于原始的代码补全数据集采样构建。[RepoBench-P](https://arxiv.org/abs/2306.03091)任务中我们选取了原数据集最具挑战性的XF-F(Cross-File-First)设定,并且参考原文中的Oracle-Filled场景,对于每一条原始数据我们随机抽取包括有效跨文件代码片段(gold snippet)在内的多个跨文件代码片段,将其拼接后作为输入,要求模型从其中利用有效的跨文件代码以补全当前文件中的代码。 + +# LongBench-E数据统计 +| 任务 | 任务类型 | 0-4k数据量 | 4-8k数据量 |8k+数据量| +| :--------- 
| :-----------:| :-----------: |:---------: | :-------------: | +| HotpotQA | 多文档QA | 100 |100 |100 | +| 2WikiMultihopQA| 多文档QA | 100 |100 |100 | +| MultiFieldQA-en| 单文档QA | 67 |70 |13 | +| Qasper| 单文档QA | 100 |100 |24 | +| GovReport| 摘要 | 100 |100 |100 | +| MultiNews| 摘要 | 100 |100 |94 | +| TriviaQA| Few shot | 100 |100 |100 | +| SAMSum| Few shot | 100 |100 |100 | +| TREC| Few shot | 100 |100 |100 | +| PassageRetrieval-en| 合成任务 | 100 |100 |100 | +| PassageCount| 合成任务 | 100 |100 |100 | +| LCC| 代码 | 100 |100 |100 | +| RepoBench-P| 代码 | 100 |100 |100 |
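
Both task.md and task_zh.md rely on the same length convention: English and code samples are measured by word count (Python's split), Chinese samples by character count, and LongBench-E groups samples into the 0-4k, 4-8k, and 8k+ intervals listed above. The sketch below only illustrates that convention; the helper names and the simple threshold-based bucketing are assumptions for illustration, not the repository's actual sampling code.

```python
# Illustrative sketch (not the repository's sampling code): measure sample
# length following the convention above and assign a LongBench-E bucket.
def sample_length(text: str, language: str) -> int:
    # character count for Chinese, word count (split) for English and code
    return len(text) if language == "zh" else len(text.split())

def length_bucket(length: int) -> str:
    # LongBench-E length intervals: 0-4k, 4-8k, 8k+
    if length < 4000:
        return "0-4k"
    if length < 8000:
        return "4-8k"
    return "8k+"

# example
print(length_bucket(sample_length("An example long context ...", "en")))
```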