generate_BM25.py
    from rank_bm25 import BM25Okapi
    
    import os
    import json
    from concurrent.futures import ThreadPoolExecutor, wait
    from tqdm import tqdm
    import argparse
    import sys
    sys.path.append('..')
    from splitter import split_long_sentence, get_word_len, regex
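    # Pipeline overview: for each JSONL example, split its long 'context' into
    # chunks, rank the chunks against the 'input' query with BM25, and rebuild
    # a truncated context (up to ~maxLen words) from the best-ranked chunks.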
    # DEBUG
    # os.chdir(os.path.dirname(os.path.abspath(__file__)))
    
    def retriveDoc(query: str, document: str, chunk_size, file_name: str,
                   js, output_list, idx, pbar=None, maxLen=1500):
        # 1. Split the context into chunks
        texts = split_long_sentence(document, regex, chunk_size=chunk_size, filename=file_name)
        # 2. Create the retriever. rank_bm25 expects a tokenized corpus
        #    (a list of token lists), so split each chunk on whitespace.
        retriever = BM25Okapi([text.split() for text in texts])
        # 3. Retrieve every chunk, ranked by BM25 score against the tokenized query
        retrieved_texts = retriever.get_top_n(query=query.split(), documents=texts,
                                              n=len(texts))
        if isinstance(retrieved_texts, str):
            retrieved_texts = [retrieved_texts]
        # 4. Merge the best-ranked chunks until the word budget is exhausted
        context = ''
        for text in retrieved_texts:
            if get_word_len(context) < maxLen:
                context += text
        js['retrieved'] = retrieved_texts
        js['context'] = context
        js['length'] = get_word_len(context)
        output_list[idx] = js
        if pbar:
            pbar.update()
        return
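
    # Example with hypothetical values, showing the in-place update contract:
    #   out = [{}]
    #   retriveDoc(query="Who founded the company?", document=long_text,
    #              chunk_size=200, file_name='example.jsonl',
    #              js={'input': 'Who founded the company?'}, output_list=out, idx=0)
    #   out[0] now holds the added 'retrieved', 'context', and 'length' fields.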
    
    if __name__ == '__main__':
        parser = argparse.ArgumentParser()
        # parser.add_argument("--file_name", default='dureader.jsonl')
        parser.add_argument("--file_name", default='2wikimqa.jsonl')
        parser.add_argument("--source_dir", default='../../LongBench/data')
        parser.add_argument("--dest_dir", default='./test')
        parser.add_argument("--chunk_size", type=int, default=200)
        args = parser.parse_args()
        file_name = args.file_name
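        # Example invocation (defaults shown; adjust paths as needed):
        #   python generate_BM25.py --file_name 2wikimqa.jsonl \
        #       --source_dir ../../LongBench/data --dest_dir ./test --chunk_size 200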
    
        print(f"------  {file_name}  ------")
        with open(os.path.join(args.source_dir, file_name), 'r', encoding='utf-8') as file:
            file_contents = file.readlines()
            # One fresh dict per slot: a multiplied list ([{}] * n) would alias a single dict
            output_data = [{} for _ in file_contents]
            # Resume from an existing output file if one is present
            if os.path.exists(os.path.join(args.dest_dir, file_name)):
                with open(os.path.join(args.dest_dir, file_name), 'r', encoding='utf-8') as f:
                    output_data = [json.loads(line) for line in f]
            loop = tqdm(enumerate(file_contents), total=len(file_contents), desc=f'{file_name}')
            exe_list = []
            with ThreadPoolExecutor(max_workers=10) as executor:
                for index, line in enumerate(file_contents):
                    # Skip entries already retrieved on a previous (resumed) run
                    if output_data[index] and len(output_data[index].get('context', '')) != 0:
                        loop.update()
                        continue
                    line_js = json.loads(line)
                    retriveDoc(query=line_js['input'], document=line_js['context'],
                               chunk_size=args.chunk_size, file_name=file_name,
                               js=line_js, output_list=output_data, idx=index, pbar=loop)
                    # Threaded variant, currently disabled in favor of the synchronous call above:
                    # exe_list.append(executor.submit(retriveDoc, query=line_js['input'], document=line_js['context'],
                    #                         chunk_size=args.chunk_size, file_name=file_name,
                    #                         js=line_js, output_list=output_data, idx=index, pbar=loop))
                wait(exe_list)
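
    # Each saved line is the input record augmented with:
    #   'retrieved': all chunks, ordered by BM25 score (best first)
    #   'context':   concatenation of top-ranked chunks within the word budget
    #   'length':    word count of the rebuilt context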
    # Save results as JSONL, one object per line
        os.makedirs(args.dest_dir, exist_ok=True)
        with open(os.path.join(args.dest_dir, file_name), 'w', encoding='utf-8') as output_file:
            for item in output_data:
                output_file.write(json.dumps(item, ensure_ascii=False) + '\n')