Skip to content
Snippets Groups Projects
metrics.py 5.47 KiB
Newer Older
  • Learn to ignore specific revisions
  • bys0318's avatar
    bys0318 committed
    import re
    import string
    
    import jieba
    from fuzzywuzzy import fuzz
    import difflib
    
    from typing import List
    from collections import Counter
    from rouge import Rouge
    
    def normalize_answer(s):
        """Normalize an English answer string for token-level comparison.

        The text is lower-cased, every character in ``string.punctuation`` is
        dropped, the standalone words "a", "an" and "the" are replaced by a
        space, and whitespace runs are collapsed to single spaces (leading and
        trailing whitespace disappears).
        """
        lowered = s.lower()

        # One pass that deletes all ASCII punctuation characters.
        without_punctuation = lowered.translate(
            str.maketrans("", "", string.punctuation)
        )

        # Articles are blanked only after punctuation is gone, so "the," still
        # counts as the article "the".
        without_articles = re.sub(r"\b(a|an|the)\b", " ", without_punctuation)

        # split() with no argument discards empty fields, which collapses the
        # gaps left by the removed articles.
        return " ".join(without_articles.split())
    
    
    def normalize_zh_answer(s):
        """Normalize a Chinese answer string for comparison.

        The text is lower-cased, ASCII and CJK punctuation characters are
        removed, and all whitespace is deleted (tokens are joined with no
        separator, not with spaces).

        The previous ``cn_punctuation`` literal contained an ASCII double quote
        right after its first four characters. That closed the string early and
        turned the rest of the line into a ``#`` comment, so apart from ASCII
        punctuation only the ideographic full stop was ever stripped. The set
        is rebuilt here from explicit pieces so it cannot be truncated that way.
        """
        # Full-width forms (U+FF01-U+FF5E) sit at a fixed offset from their ASCII
        # counterparts. "." is skipped because it does not appear in the
        # punctuation sequence on the original line.
        # NOTE(review): presumably the original literal held these full-width
        # forms before they were folded to ASCII -- confirm against upstream.
        fullwidth = "".join(
            chr(ord(ch) + 0xFEE0) for ch in string.punctuation if ch != "."
        )
        cjk_punctuation = (
            # Full-width white parentheses, half-width ideographic stop,
            # corner brackets and comma.
            "\uff5f\uff60\uff61\uff62\uff63\uff64"
            # White parentheses as they appear on the original line.
            "\u2985\u2986"
            # CJK symbols and punctuation.
            "。、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿"
            # En/em dash, curly quotes, ellipsis, hyphenation point, wavy low line.
            "\u2013\u2014\u2018\u2019\u201b\u201c\u201d\u201e\u201f"
            "\u2026\u2027\ufe4f"
        )
        removable = set(string.punctuation + fullwidth + cjk_punctuation)

        lowered = s.lower()
        without_punctuation = "".join(ch for ch in lowered if ch not in removable)
        # Joining on the empty string removes whitespace entirely.
        return "".join(without_punctuation.split())
    
    def count_score(prediction, ground_truth, **kwargs):
        """Score a counting answer by the share of predicted numbers that match.

        Every run of digits in ``prediction`` is a candidate; the score is the
        fraction of candidates whose text equals ``str(ground_truth)``. Returns
        0.0 when ``prediction`` contains no digits. Extra keyword arguments are
        accepted and ignored.
        """
        candidates = re.findall(r"\d+", prediction)
        if not candidates:
            return 0.0

        target = str(ground_truth)
        hits = sum(1 for candidate in candidates if candidate == target)
        return float(hits / len(candidates))
    
    def retrieval_score(prediction, ground_truth, **kwargs):
        """Score a passage-retrieval answer given as a paragraph number.

        The target id is the digits of the first "Paragraph <digits>" occurrence
        in ``ground_truth``. The score is the fraction of digit runs in
        ``prediction`` equal to that id, or 0.0 when ``prediction`` has no
        digits. Extra keyword arguments are accepted and ignored.

        Raises:
            IndexError: If ``ground_truth`` contains no "Paragraph <digits>".
        """
        ground_truth_id = re.findall(r'Paragraph (\d+)', ground_truth)[0]

        predicted_ids = re.findall(r"\d+", prediction)
        if len(predicted_ids) == 0:
            return 0.0

        correct = [pid for pid in predicted_ids if pid == ground_truth_id]
        return float(len(correct) / len(predicted_ids))
    
    def retrieval_zh_score(prediction, ground_truth, **kwargs):
        """Score a Chinese passage-retrieval answer given as a paragraph number.

        The target id is the digits captured by the first match of the
        paragraph pattern in ``ground_truth``. The score is the fraction of
        digit runs in ``prediction`` equal to that id, or 0.0 when
        ``prediction`` has no digits. Extra keyword arguments are accepted and
        ignored.

        Raises:
            IndexError: If ``ground_truth`` does not match the paragraph pattern.
        """
        target_id = re.findall(r'段落(\d+)', ground_truth)[0]

        found = re.findall(r"\d+", prediction)
        if not found:
            return 0.0

        # list.count tallies the entries equal to the target id.
        return float(found.count(target_id) / len(found))
    
    def code_sim_score(prediction, ground_truth, **kwargs):
        """Compare the first plain line of ``prediction`` with ``ground_truth``.

        Leading newlines are stripped from ``prediction``; the first remaining
        line containing no backtick, no "#" and no "//" is the candidate (the
        empty string if every line contains one of them). The result is
        ``fuzz.ratio(candidate, ground_truth)`` divided by 100. Extra keyword
        arguments are accepted and ignored.
        """
        skip_markers = ('`', '#', '//')
        lines = prediction.lstrip('\n').split('\n')

        # First line free of all markers, or "" when there is none.
        candidate = next(
            (line for line in lines
             if not any(marker in line for marker in skip_markers)),
            "",
        )
        return fuzz.ratio(candidate, ground_truth) / 100
    
    def classification_score(prediction, ground_truth, **kwargs):
        """Score a classification answer against the expected class label.

        Args:
            prediction: Model output text.
            ground_truth: Expected class name.
            **kwargs: Must contain ``all_classes``, the candidate class names.

        Returns:
            ``1 / k`` when the ground truth is among the ``k`` class names that
            occur verbatim in ``prediction`` (names that are merely proper
            substrings of the ground truth are not counted), and 0.0 when
            other class names occur but the ground truth does not. When no
            class name occurs in ``prediction`` at all, returns 1.0 if the
            class most similar to ``prediction`` (by
            ``difflib.SequenceMatcher`` ratio) is the ground truth, else 0.0.

        Raises:
            KeyError: If ``all_classes`` is not supplied.
        """
        all_classes = kwargs["all_classes"]

        # Build the filtered list in one pass. The earlier code called
        # list.remove() while iterating the same list, which skipped the
        # element after each removal and could leave substrings of the ground
        # truth behind.
        matched = [
            name
            for name in all_classes
            if name in prediction
            and not (name in ground_truth and name != ground_truth)
        ]

        if matched:
            # The earlier test `em_match_list != 0` compared a list with an int
            # and was always true, so the fallback below could never run.
            return 1.0 / len(matched) if ground_truth in matched else 0.0

        # Fallback: pick the class whose text is closest to the prediction.
        best_match = None
        highest_similarity = 0
        for candidate in all_classes:
            similarity = difflib.SequenceMatcher(None, candidate, prediction).ratio()
            if similarity > highest_similarity:
                highest_similarity = similarity
                best_match = candidate
        return float(best_match == ground_truth)
        
    def rouge_score(prediction, ground_truth, **kwargs):
        """Return the ROUGE-L F-score of ``prediction`` against ``ground_truth``.

        Scoring is best-effort: if ``Rouge.get_scores`` raises for the given
        inputs, the score is 0.0 instead of an error. Extra keyword arguments
        are accepted and ignored.
        """
        rouge = Rouge()
        try:
            scores = rouge.get_scores([prediction], [ground_truth], avg=True)
        except Exception:
            # Catch Exception rather than using a bare `except:` so that
            # KeyboardInterrupt and SystemExit still propagate and a long
            # evaluation run can be interrupted.
            return 0.0
        return scores["rouge-l"]["f"]
    
    def rouge_zh_score(prediction, ground_truth, **kwargs):
        """Return ``rouge_score`` for Chinese text after jieba word segmentation.

        Both strings are segmented with ``jieba.cut(..., cut_all=False)`` and
        the segments are joined with single spaces before being passed to
        ``rouge_score``. Extra keyword arguments are accepted and ignored.
        """
        segmented_prediction, segmented_truth = (
            " ".join(jieba.cut(text, cut_all=False))
            for text in (prediction, ground_truth)
        )
        return rouge_score(segmented_prediction, segmented_truth)
    
    def f1_score(prediction, ground_truth, **kwargs):
        """Return the bag-of-tokens F1 between two token sequences.

        Args:
            prediction: Sequence of hashable tokens produced by the model.
            ground_truth: Sequence of hashable reference tokens.
            **kwargs: Accepted and ignored.

        Returns:
            The harmonic mean of precision and recall over the multiset
            overlap of the two sequences, as a float. 0.0 when they share no
            token (which includes either sequence being empty).
        """
        # Counter intersection keeps, per token, the smaller of the two counts.
        overlap = Counter(prediction) & Counter(ground_truth)
        num_same = sum(overlap.values())
        if num_same == 0:
            # Return a float here as well, so the result type does not depend
            # on the inputs (this branch previously returned the int 0).
            return 0.0

        # num_same > 0 implies both sequences are non-empty, so neither
        # division can be by zero.
        precision = 1.0 * num_same / len(prediction)
        recall = 1.0 * num_same / len(ground_truth)
        return (2 * precision * recall) / (precision + recall)
    
    def qa_f1_score(prediction, ground_truth, **kwargs):
        """Return the token-level F1 of two English answers.

        Each string is passed through ``normalize_answer`` and split on
        whitespace; the two token lists are then scored with ``f1_score``.
        Extra keyword arguments are accepted and ignored.
        """
        return f1_score(
            normalize_answer(prediction).split(),
            normalize_answer(ground_truth).split(),
        )
    
    
    def qa_f1_zh_score(prediction, ground_truth, **kwargs):
        """Return the token-level F1 of two Chinese answers.

        Each string is segmented with ``jieba.cut(..., cut_all=False)``, every
        segment is passed through ``normalize_zh_answer``, and segments that
        normalize to the empty string are dropped. The two token lists are
        then scored with ``f1_score``. Extra keyword arguments are accepted and
        ignored.
        """

        def _normalized_tokens(text):
            # Segment, normalize each piece, and keep only non-empty results.
            pieces = (
                normalize_zh_answer(piece)
                for piece in jieba.cut(text, cut_all=False)
            )
            return [piece for piece in pieces if piece]

        return f1_score(
            _normalized_tokens(prediction),
            _normalized_tokens(ground_truth),
        )