# Code adapted from https://github.com/THUDM/LongBench/blob/main/metrics.py
# Code adapted from https://github.com/hendrycks/math/blob/main/modeling/math_equivalence.py
# Code adapted from https://github.com/ruixiangcui/AGIEval/blob/main/src/evaluation.py

import difflib
import re
import string
from collections import Counter

import jieba
from fuzzywuzzy import fuzz
from rouge import Rouge

# Nested mapping: benchmark name -> metric name -> list of names the metric is
# listed for. NOTE(review): "ALL" presumably means "every subcategory of the
# benchmark" -- confirm against the code that consumes this table.
metrics4subcategory = {
    "pretrain": {
        "perplexity": ["ALL"],
        "ppl_score": ["ALL"],
        "per_byte_perplexity": ["ALL"],
        "per_byte_ppl_score": ["ALL"],
    },
    # The commented are non 4-choice questions.
    "agieval": {
        "combined_single_choice_accuracy": [
            # "lsat-ar",
            # "lsat-lr",
            # "lsat-rc",
            "logiqa-en",
            "sat-math",
            "sat-en",
            # "aqua-rat",
            "sat-en-without-passage",
            "gaokao-english",
            "logiqa-zh",
            "gaokao-chinese",
            "gaokao-geography",
            "gaokao-history",
            "gaokao-biology",
            "gaokao-chemistry",
        ],
        "first_token_accuracy": [
            # "lsat-ar",
            # "lsat-lr",
            # "lsat-rc",
            "logiqa-en",
            "sat-math",
            "sat-en",
            # "aqua-rat",
            "sat-en-without-passage",
            "gaokao-english",
            "logiqa-zh",
            "gaokao-chinese",
            "gaokao-geography",
            "gaokao-history",
            "gaokao-biology",
            "gaokao-chemistry",
        ],
        "single_choice_accuracy": [
            # "lsat-ar",
            # "lsat-lr",
            # "lsat-rc",
            "logiqa-en",
            "sat-math",
            "sat-en",
            # "aqua-rat",
            "sat-en-without-passage",
            "gaokao-english",
            "logiqa-zh",
            "gaokao-chinese",
            "gaokao-geography",
            "gaokao-history",
            "gaokao-biology",
            "gaokao-chemistry",
        ],
        "multi_choice_accuracy": ["jec-qa-kd", "jec-qa-ca", "gaokao-physics", "gaokao-mathqa"],
        "math_equivalence": ["gaokao-mathcloze", "math"],
        "perplexity": ["ALL"],
        "ppl_score_over_choices": [
            "lsat-ar",
            "lsat-lr",
            "lsat-rc",
            "logiqa-en",
            "sat-math",
            "sat-en",
            "aqua-rat",
            "sat-en-without-passage",
            "gaokao-english",
            "logiqa-zh",
            "jec-qa-kd",
            "jec-qa-ca",
            "gaokao-chinese",
            "gaokao-geography",
            "gaokao-history",
            "gaokao-biology",
            "gaokao-chemistry",
            "gaokao-physics",
            "gaokao-mathqa",
        ],
        "ppl_score": ["ALL"],
    },
    "cmmlu": {
        "first_token_accuracy": ["ALL"],
        "single_choice_accuracy": ["ALL"],
        "perplexity": ["ALL"],
        "ppl_score_over_choices": ["ALL"],
        "ppl_score": ["ALL"],
    },
    "gaokaobench": {
        "combined_single_choice_accuracy": [
            "English MCQs",
            "Biology MCQs",
            "Chemistry MCQs",
            "History MCQs",
            "Math I MCQs",
            "Math II MCQs",
            "Political Science MCQs",
        ],
        "first_token_accuracy": [
            "English MCQs",
            "Biology MCQs",
            "Chemistry MCQs",
            "History MCQs",
            "Math I MCQs",
            "Math II MCQs",
            "Political Science MCQs",
        ],
        "single_choice_accuracy": [
            "English MCQs",
            "Biology MCQs",
            "Chemistry MCQs",
            "History MCQs",
            "Math I MCQs",
            "Math II MCQs",
            "Political Science MCQs",
        ],
        "multi_choice_accuracy": [
            "Chinese Lang and Usage MCQs",
            "Chinese Modern Lit",
            "English Fill in Blanks",
            "English Reading Comp",
            "Geography MCQs",
            "Physics MCQs",
            "English Cloze Test",
        ],
        "math_equivalence": ["Math I Fill-in-the-Blank", "Math II Fill-in-the-Blank"],
        "rouge_score": ["English Language Cloze Passage"],
        "rouge_zh_score": [
            "Chinese Language Famous Passages and Sentences Dictation",
            "Chemistry Open-ended Questions",
            "History Open-ended Questions",
            "Biology Open-ended Questions",
            "Political Science Open-ended Questions",
            "English Language Error Correction",
            "Chinese Language Language and Writing Skills Open-ended Questions",
            "Math II Open-ended Questions",
            "Chinese Language Literary Text Reading",
            "Chinese Language Ancient Poetry Reading",
            "Chinese Language Classical Chinese Reading",
            "Physics Open-ended Questions",
            "Math I Open-ended Questions",
            "Geography Open-ended Questions",
            "Chinese Language Practical Text Reading",
        ],
        "perplexity": ["ALL"],
        "ppl_score_over_choices": ["ALL"],
        "ppl_score": ["ALL"],
    },
    "longbench": {
        "f1_score": ["hotpotqa", "2wikimqa", "musique", "narrativeqa", "qasper", "multifieldqa_en", "triviaqa"],
        "f1_zh_score": ["multifieldqa_zh"],
        "rouge_score": ["gov_report", "qmsum", "multi_news", "samsum"],
        "rouge_zh_score": ["dureader", "vcsum"],
        "retrieval_score": ["passage_retrieval_en"],
        "retrieval_zh_score": ["passage_retrieval_zh"],
        "classification_score": ["trec", "lsht"],
        "code_sim_score": ["lcc", "repobench-p"],
        "count_score": ["passage_count"],
        "perplexity": ["ALL"],
        "ppl_score": ["ALL"],
    },
    "mmlu": {
        "first_token_accuracy": ["ALL"],
        "single_choice_accuracy": ["ALL"],
        "accuracy": ["ALL"],
        "perplexity": ["ALL"],
        "ppl_score_over_choices": ["ALL"],
        "ppl_score": ["ALL"],
    },
    "mtbench": {"mtbench_single_judge": ["ALL"]},
    "cvalues": {"first_token_accuracy": ["ALL"]},
    "safetybench_zh": {"first_token_accuracy": ["ALL"]},
    "safetybench_en": {"first_token_accuracy": ["ALL"]},
}


def _fix_fracs(string):
    r"""Add the missing braces around the arguments of every ``\frac``.

    ``\frac12`` becomes ``\frac{1}{2}`` and ``\frac1{72}`` becomes
    ``\frac{1}{72}``; a ``\frac`` that is already followed by ``{`` is copied
    through untouched.

    Args:
        string: LaTeX-like answer text.

    Returns:
        The rewritten text, or ``string`` unchanged when some ``\frac`` is
        followed by fewer than two characters (including nothing at all).
    """
    head, *segments = string.split("\\frac")
    pieces = [head]
    for segment in segments:
        pieces.append("\\frac")
        if segment.startswith("{"):
            pieces.append(segment)
            continue
        # Explicit length check instead of `assert`: it survives `python -O`
        # and also covers an empty segment (input ending in "\frac").
        if len(segment) < 2:
            return string
        numerator, denominator, rest = segment[0], segment[1], segment[2:]
        if denominator != "{":
            pieces.append("{" + numerator + "}{" + denominator + "}" + rest)
        else:
            # The second argument is already braced: wrap only the first one.
            pieces.append("{" + numerator + "}" + denominator + rest)
    return "".join(pieces)


def _fix_a_slash_b(string):
    r"""Rewrite a bare integer fraction ``a/b`` as ``\frac{a}{b}``.

    Args:
        string: Candidate text.

    Returns:
        ``\frac{a}{b}`` when ``string`` is exactly two canonical integers
        separated by one ``/``; otherwise ``string`` unchanged.
    """
    parts = string.split("/")
    if len(parts) != 2:
        return string
    try:
        a = int(parts[0])
        b = int(parts[1])
    except ValueError:
        return string
    # Reject non-canonical spellings such as " 1/2" or "01/2" that int()
    # tolerates; only an exact "a/b" is rewritten.
    if string != "{}/{}".format(a, b):
        return string
    return "\\frac{" + str(a) + "}{" + str(b) + "}"


def _remove_right_units(string):
    r"""Drop a trailing unit written as ``\text{ ...`` from ``string``.

    Args:
        string: LaTeX-like answer text.

    Returns:
        The part before ``"\text{ "`` when that marker is present, otherwise
        ``string`` unchanged.

    Raises:
        AssertionError: If the marker occurs more than once.
    """
    # "\\text{ " only ever occurs (at least in the val set) when describing units
    if "\\text{ " not in string:
        return string
    splits = string.split("\\text{ ")
    if len(splits) != 2:
        # Raised explicitly (not via `assert`) so the check is kept under -O
        # while callers still see the same exception type.
        raise AssertionError("expected exactly one '\\text{ ' unit marker")
    return splits[0]


def _fix_sqrt(string):
    r"""Brace the single-character argument of every ``\sqrt``.

    ``\sqrt3`` becomes ``\sqrt{3}``; ``\sqrt{...}`` is left alone, as is a
    ``\sqrt`` with nothing after it.

    Args:
        string: LaTeX-like answer text.

    Returns:
        The rewritten text.
    """
    if "\\sqrt" not in string:
        return string
    head, *segments = string.split("\\sqrt")
    pieces = [head]
    for segment in segments:
        # `segment` is empty when "\sqrt" ends the text or is directly
        # followed by another "\sqrt"; there is no argument to brace then.
        if segment and segment[0] != "{":
            pieces.append("\\sqrt{" + segment[0] + "}" + segment[1:])
        else:
            pieces.append("\\sqrt" + segment)
    return "".join(pieces)


def _strip_string(string):
    r"""Normalise a LaTeX-like answer so equivalent answers compare equal.

    The steps run in a fixed order; several of them depend on earlier ones
    (for example spaces are removed only after ``\sqrt`` has been fixed).

    Args:
        string: Answer text.

    Returns:
        The normalised text (the empty string stays empty).

    Raises:
        AssertionError: Propagated from ``_remove_right_units``.
    """
    # linebreaks
    string = string.replace("\n", "")

    # remove inverse spaces
    string = string.replace("\\!", "")

    # replace \\ with \
    string = string.replace("\\\\", "\\")

    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")

    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs
    string = string.replace("\\$", "")

    # remove units (on the right)
    string = _remove_right_units(string)

    # remove percentage (the original did this twice, the second time with
    # the invalid escape "\%", which denotes the very same two characters)
    string = string.replace("\\%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # if empty, return empty string
    if not string:
        return string
    if string[0] == ".":
        string = "0" + string

    # to consider: get rid of e.g. "k = " or "q = " at beginning
    sides = string.split("=")
    if len(sides) == 2 and len(sides[0]) <= 2:
        string = sides[1]

    # fix sqrt3 --> sqrt{3}
    string = _fix_sqrt(string)

    # remove spaces
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1).
    string = _fix_fracs(string)

    # manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = _fix_a_slash_b(string)

    return string


def parse_math_answer(raw_string):
    r"""Extract the final answer from free-form model output.

    Strategies, in order: the contents of the last ``\boxed{...}``; otherwise
    the text between dollar signs; otherwise the text after the last ``=`` or,
    failing that, the last number in the text.

    Args:
        raw_string: Model output.

    Returns:
        The extracted answer string, or ``None`` when nothing was found.
    """

    def remove_boxed(s):
        # Unwrap "\boxed{...}"; None for anything that is not exactly that.
        left = "\\boxed{"
        if s is None or not s.startswith(left) or not s.endswith("}"):
            return None
        answer = s[len(left) : -1]
        if "=" in answer:
            answer = answer.split("=")[-1].lstrip(" ")
        return answer

    def last_boxed_only_string(text):
        # Return the last "\boxed..." (or "\fbox...") up to its matching
        # closing brace, or None if the braces never balance.
        idx = text.rfind("\\boxed")
        if idx < 0:
            idx = text.rfind("\\fbox")
            if idx < 0:
                return None
        depth = 0
        for i in range(idx, len(text)):
            char = text[i]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    return text[idx : i + 1]
        return None

    def get_answer_with_dollar_sign(s):
        # Greedy match between the first and the last "$" on a line.
        last_match = None
        matches = re.findall(r"\$(.*)\$", s)
        if matches:
            last_match = matches[-1]
            if "=" in last_match:
                last_match = last_match.split("=")[-1].lstrip(" ")
        return last_match

    def get_answer_without_dollar_sign(s):
        last_match = None
        if "=" in s:
            last_match = s.split("=")[-1].lstrip(" ").rstrip(".")
            # "\\n" is a literal backslash followed by "n", not a newline.
            if "\\n" in last_match:
                last_match = last_match.split("\\n")[0]
        else:
            matches = re.findall(r"(?:\$)?\d+(?:\.\d+)?(?![\w\d])", s)
            if matches:
                last_match = matches[-1]
        return last_match

    if "\\boxed" in raw_string:
        answer = remove_boxed(last_boxed_only_string(raw_string))
    else:
        answer = get_answer_with_dollar_sign(raw_string)
        if not answer:
            answer = get_answer_without_dollar_sign(raw_string)
    return answer


def math_equivalence(prediction, reference, **kwargs):
    """Tell whether the answer parsed from ``prediction`` equals ``reference``.

    Args:
        prediction: Raw model output; its answer is extracted with
            ``parse_math_answer``.
        reference: Gold answer, compared as given (it is not parsed).
        **kwargs: Accepted and ignored.

    Returns:
        ``True`` when both sides are equal after ``_strip_string``
        normalisation. If normalisation raises, the parsed prediction and the
        reference are compared directly instead. ``False`` whenever either
        side is ``None``.
    """
    prediction = parse_math_answer(prediction)

    if prediction is None and reference is None:
        print("WARNING: Both None")
        return False

    if prediction is None or reference is None:
        return False

    try:
        return _strip_string(prediction) == _strip_string(reference)
    except Exception:
        # Deliberate best effort: fall back to comparing the raw values.
        return prediction == reference


# ---------------------------------------------------------------------------
# NOTE(review): SOURCE TEXT LOST HERE -- restore from version control.
#
# The text of this file breaks off in the middle of `multi_choice_accuracy`
# and resumes in the tail of a later function. The two surviving fragments are
# kept verbatim below so that nothing further is lost; they are not valid
# Python on their own and have NOT been reconstructed.
#
# Fragment 1 (head of `multi_choice_accuracy`):
#     def multi_choice_accuracy(prediction, reference, **kwargs):
#         # Only find uppercase letters not surrounded by lowercase letters
#         all_classes = kwargs.get("all_classes", None)
#         if all_classes:
#             pattern = f"(?
#
# Fragment 2 (tail of an unidentified function):
#     highest_similarity: highest_similarity = similarity best_match = string score = float(best_match == reference) return score
#
# `normalize_answer` and `normalize_zh_answer` (called further down) and
# several metric names listed in `metrics4subcategory` (for example
# `single_choice_accuracy`, `classification_score`, `count_score`) have no
# definition in the surviving text; presumably they lived in the lost part,
# together with the users of the `difflib`, `string` and `fuzz` imports.
# ---------------------------------------------------------------------------


def rouge_score(prediction, reference, **kwargs):
    """Return the ROUGE-L F-measure of ``prediction`` against ``reference``.

    Args:
        prediction: Hypothesis text.
        reference: Reference text.
        **kwargs: Accepted and ignored.

    Returns:
        The ``["rouge-l"]["f"]`` entry of ``Rouge().get_scores(..., avg=True)``,
        or ``0.0`` when the scorer raises.
    """
    rouge = Rouge()
    try:
        scores = rouge.get_scores([prediction], [reference], avg=True)
    except Exception:
        # Deliberate best effort: an input the scorer rejects scores zero.
        return 0.0
    return scores["rouge-l"]["f"]


def rouge_zh_score(prediction, reference, **kwargs):
    """ROUGE-L F-measure for text that must be segmented with jieba first.

    Both texts are cut with ``jieba.cut(..., cut_all=False)`` and re-joined
    with single spaces before being passed to ``rouge_score``.

    Args:
        prediction: Hypothesis text.
        reference: Reference text.
        **kwargs: Accepted and ignored.

    Returns:
        The value returned by ``rouge_score`` for the segmented texts.
    """
    prediction = " ".join(jieba.cut(prediction, cut_all=False))
    reference = " ".join(jieba.cut(reference, cut_all=False))
    return rouge_score(prediction, reference)


def _f1_score(prediction, reference, **kwargs):
    """Token-level F1 between two token sequences (multiset overlap).

    Args:
        prediction: Sequence of predicted tokens.
        reference: Sequence of reference tokens.
        **kwargs: Accepted and ignored.

    Returns:
        ``0`` when the sequences share no token, otherwise the harmonic mean
        of precision and recall.
    """
    common = Counter(prediction) & Counter(reference)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction)
    recall = 1.0 * num_same / len(reference)
    return (2 * precision * recall) / (precision + recall)


def f1_score(prediction, reference, **kwargs):
    """Token-level F1 over whitespace tokens of the normalised texts.

    Both texts go through ``normalize_answer`` and are then split on
    whitespace.

    Args:
        prediction: Predicted answer text.
        reference: Reference answer text.
        **kwargs: Accepted and ignored.

    Returns:
        The value returned by ``_f1_score`` for the two token lists.
    """
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(reference).split()
    return _f1_score(prediction_tokens, ground_truth_tokens)


def f1_zh_score(prediction, reference, **kwargs):
    """Token-level F1 over jieba tokens of the two texts.

    Each token produced by ``jieba.cut(..., cut_all=False)`` is passed through
    ``normalize_zh_answer``; tokens that come back empty are dropped.

    Args:
        prediction: Predicted answer text.
        reference: Reference answer text.
        **kwargs: Accepted and ignored.

    Returns:
        The value returned by ``_f1_score`` for the two token lists.
    """
    prediction_tokens = [normalize_zh_answer(token) for token in jieba.cut(prediction, cut_all=False)]
    ground_truth_tokens = [normalize_zh_answer(token) for token in jieba.cut(reference, cut_all=False)]
    prediction_tokens = [token for token in prediction_tokens if len(token) > 0]
    ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0]
    return _f1_score(prediction_tokens, ground_truth_tokens)