From e89b127d8ec9c14fc34ff9a1208b630069eb026f Mon Sep 17 00:00:00 2001 From: Michelle <97082656+MichelleMa8@users.noreply.github.com> Date: Mon, 26 Jun 2023 15:26:07 +0800 Subject: [PATCH] [chat]: fix chat evaluation possible bug (#4064) * fix chat eval * fix utils * fix utils * add comment --------- Co-authored-by: Qianran Ma --- applications/Chat/evaluate/metrics.py | 4 ++-- applications/Chat/evaluate/unieval/evaluator.py | 3 ++- applications/Chat/evaluate/utils.py | 13 +------------ 3 files changed, 5 insertions(+), 15 deletions(-) diff --git a/applications/Chat/evaluate/metrics.py b/applications/Chat/evaluate/metrics.py index e220226ec..77f9b6e98 100644 --- a/applications/Chat/evaluate/metrics.py +++ b/applications/Chat/evaluate/metrics.py @@ -141,8 +141,8 @@ def distinct_score(preds: List[str], language: str) -> Dict[str, float]: count_segs = len(pred_seg_list) unique_segs = set(pred_seg_list) count_unique_chars = len(unique_segs) - - cumulative_distinct.append(count_unique_chars / count_segs) + # prevent denominator from being 0 + cumulative_distinct.append(count_unique_chars / (count_segs + 1e-6)) elif language == "en": # calculate distinct 1-gram, 2-gram, 3-gram unique_ngram = [set() for _ in range(0, 3)] diff --git a/applications/Chat/evaluate/unieval/evaluator.py b/applications/Chat/evaluate/unieval/evaluator.py index d7f2f87f8..56cc6d2f9 100644 --- a/applications/Chat/evaluate/unieval/evaluator.py +++ b/applications/Chat/evaluate/unieval/evaluator.py @@ -80,7 +80,8 @@ class SumEvaluator: start_idx = 0 score = [] for cur_n_sent in n_sents: - score.append(sum(sent_score[start_idx:start_idx + cur_n_sent]) / cur_n_sent) + # prevent denominator from being 0 + score.append(sum(sent_score[start_idx:start_idx + cur_n_sent]) / (cur_n_sent + 1e-6)) start_idx += cur_n_sent # Calculate summary-level score for 'coherence' and 'relevance' diff --git a/applications/Chat/evaluate/utils.py b/applications/Chat/evaluate/utils.py index fefe25f5e..406e43db9 100644 --- a/applications/Chat/evaluate/utils.py +++ b/applications/Chat/evaluate/utils.py @@ -72,17 +72,6 @@ def get_data_per_category(data, categories): return data_per_category -def remove_articles(text: str) -> str: - """ - Remove articles "a, an, the" in the given text. - It is used in evaluation of automatic metrics. - - """ - - pattern = re.compile(r"\b(a|an|the)\b", re.UNICODE) - return re.sub(pattern, " ", text) - - def remove_punctuations(text: str) -> str: """ Remove punctuations in the given text. @@ -121,7 +110,7 @@ def preprocessing_text(text: str) -> str: """ - return remove_redundant_space(remove_articles(remove_punctuations(text.lower()))) + return remove_redundant_space(remove_punctuations(text.lower())) def save_automatic_results(model_name: str, automatic_metric_stats: Dict[str, Dict], save_path: str) -> None: