
support UniEval and add CHRF metric (#3924)

Co-authored-by: Yuanchen Xu <yuanchen.xu00@gmail.com>
Yuanchen 1 year ago committed by GitHub
commit 21c4c0b1a0
  1. applications/Chat/evaluate/README.md (105 lines changed)
  2. applications/Chat/evaluate/config/config_cn.json (12 lines changed)
  3. applications/Chat/evaluate/config/config_en.json (73 lines changed)
  4. applications/Chat/evaluate/eval.py (2 lines changed)
  5. applications/Chat/evaluate/evaluator.py (142 lines changed)
  6. applications/Chat/evaluate/gpt_evaluate.py (2 lines changed)
  7. applications/Chat/evaluate/metrics.py (22 lines changed)
  8. applications/Chat/evaluate/unieval/__init__.py (12 lines changed)
  9. applications/Chat/evaluate/unieval/evaluator.py (330 lines changed)
  10. applications/Chat/evaluate/unieval/scorer.py (101 lines changed)
  11. applications/Chat/evaluate/unieval/utils.py (248 lines changed)
  12. applications/Chat/evaluate/utils.py (2 lines changed)

applications/Chat/evaluate/README.md (105 lines changed)

@ -12,12 +12,13 @@ pip install -r requirements.txt
## Evaluation Pipeline
The whole evaluation pipeline consists of two methods:
The whole evaluation pipeline consists of three methods:
1. `GPT Evaluation`: evaluates model predictions using GPT models.
* Compare the performance of two different models (battle).
* Rate the model according to pre-defined metrics using prompting design.
2. `Automatic Evaluation`: evaluates model predictions using automatic metrics.
3. `UniEval`: evaluates model predictions using UniEval models (English only).
### Evaluation Category
@ -75,7 +76,9 @@ GPT evaluation uses GPT models to evaluate the prediction of different models an
GPT models evaluate the quality of model predictions based on the given prompt words and give a score between 1 and 5.
> **NOTE:** Even for the same metric, the details of its prompt words and CoT (Chain-of-Thought) can differ depending on which category you want to evaluate. For example, the prompt words for the metric `correctness` shown here are "The answer should be in line with common sense, life experience, etc." (this is for the category `brainstorming`), but for the category `extraction`, the prompt words can be "Answers should extract the required information accurately and should not contain any incorrect or misleading information." You can find all the prompt words and CoT (Chain-of-Thought) in `prompt/evaluation_prompt`.
> **NOTE 1:** Even for the same metric, the details of its prompt words and CoT (Chain-of-Thought) can differ depending on which category you want to evaluate. For example, the prompt words for the metric `correctness` shown here are "The answer should be in line with common sense, life experience, etc." (this is for the category `brainstorming`), but for the category `extraction`, the prompt words can be "Answers should extract the required information accurately and should not contain any incorrect or misleading information." You can find all the prompt words and CoT (Chain-of-Thought) in `prompt/evaluation_prompt`.
> **NOTE 2:** To add customized metrics, you can refer to [FAQ](#faq).
#### Automatic Evaluation
@ -85,7 +88,7 @@ There are two ways to obtain reference answers:
* For instructions coming from human-designed problems (such as roleplay and chat), the reference answers are generated by GPT-3.5.
* For instructions related to classic NLP problems (such as classification, extraction and summarization), the reference answers are collected from open-source datasets with target answers.
There are 5 types of automatic evaluation metrics listed in the table below:
There are 6 types of automatic evaluation metrics listed in the table below:
| Automatic Evaluation Metric | Description |
| :---------------------------------: | :----------------------------------------------------------- |
@ -94,6 +97,25 @@ There are 5 types of automatic evaluation metrics listed in the table below:
| Distinct | Measure the diversity of the generated text by counting the unique n-grams. |
| BERTScore | Measure the semantic similarity between tokens of predictions and references with BERT. |
| Precision<br/> Recall<br/> F1 Score | Measure the number of overlaps between prediction and reference (designed for classification and extraction categories). |
| CHRF | Measure the similarity of character n-grams between prediction and reference. |
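As a rough illustration, the CHRF score in `metrics.py` is computed per sentence with NLTK and then averaged over the whole category. A minimal sketch for the English case (Chinese text is segmented with jieba before the same call) might look like this:

```python
# Sketch only: sentence-level CHRF averaged over a corpus (English case).
import statistics
from nltk.translate.chrf_score import sentence_chrf

def average_chrf(preds, targets):
    scores = [
        sentence_chrf(target.split(), pred.split())
        for pred, target in zip(preds, targets)
    ]
    return {"chrf": statistics.mean(scores)}

print(average_chrf(["a cat sat on the mat"], ["the cat sat on the mat"]))
```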
#### UniEval Evaluation
UniEval converts all evaluation tasks of different dimensions (metrics) into Boolean QA problems and utilizes the model to answer with "Yes" or "No". Compared with similarity-based metrics such as ROUGE and BLEU, UniEval can achieve a more comprehensive evaluation. In addition, UniEval demonstrates the ability to transfer to unseen dimensions and tasks.
In our evaluation pipeline, two pre-trained UniEval evaluators are used. One is [unieval-sum](https://huggingface.co/MingZhong/unieval-sum) and the other is [unieval-dialog](https://huggingface.co/MingZhong/unieval-dialog). The two models cover three tasks, `summarization`, `dialogue` and `data2text`, and each task has different evaluation dimensions.
| UniEval Model | Task | Dimension(Metric) |
| :------------: | :----------------- | :--- |
| unieval-sum | summarization | coherence: whether the summary is coherent<br/>consistency: whether the claim is consistent with the given document<br/>fluency: whether the paragraph is fluent<br/>relevance: whether the summary is relevant to the reference |
| unieval-sum | data2text | naturalness: whether the utterance is fluent<br/>informativeness: whether the utterance is informative according to the reference |
| unieval-dialog | dialogue | naturalness: whether the response is natural in the dialogue<br/>coherence: whether the response is coherent in the dialogue history<br/>understandability: whether the response is understandable in the dialogue |
> **NOTE 1:** Task "data2text" uses the same model as task "summarization".
> **NOTE 2:** In the UniEval paper, the `unieval-sum` model demonstrates the best transfer ability, so you can evaluate your customized metric with this model. Details of adding customized metrics can be found in [FAQ](#faq).
> **NOTE 3:** We chose not to include all metrics provided in UniEval in our pipeline because the data structure and content of the instructions we want to evaluate are not suitable for direct use of some UniEval metrics.
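Under the hood, scoring one dimension is a single forward pass of the underlying T5-based evaluator: the text is wrapped into a Boolean question, and the score is the probability of "Yes" normalized against "No". Below is a minimal sketch of how the `fluency` dimension could be scored with `unieval-sum` (it mirrors `unieval/scorer.py` but skips batching and error handling; the example sentence is made up):

```python
# Sketch only: score the "fluency" dimension with the unieval-sum checkpoint.
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "MingZhong/unieval-sum"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).eval()

# Frame the evaluation as a Boolean QA problem.
text = "Colossal-AI is an integrated large-scale model training system."
src = "question: Is this a fluent paragraph? </s> paragraph: " + text

enc = tokenizer(src, return_tensors="pt", truncation=True, max_length=1024)
# T5 needs decoder input ids; a dummy one-token target is enough.
tgt = tokenizer("No", return_tensors="pt").input_ids[:, :1]

with torch.no_grad():
    logits = model(**enc, labels=tgt).logits[:, 0, :]
probs = torch.softmax(logits, dim=-1)
pos = probs[0, tokenizer("Yes").input_ids[0]]  # P("Yes")
neg = probs[0, tokenizer("No").input_ids[0]]   # P("No")
print(float(pos / (pos + neg)))                # fluency score in [0, 1]
```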
## Evaluation Process
@ -215,19 +237,26 @@ The following is an example of a Chinese GPT evaluation prompt. In an evaluation
#### Configuration
The following is an example of a Chinese config file. The configuration file can control how the pipeline evaluates the model. You need to specify GPT evaluation metrics and automatic metrics in key `GPT` and `Metrics`. You can find an example Chinese config file in `config`.
The following is an example of an English config file. The configuration file controls how the pipeline evaluates the model. You need to specify GPT evaluation metrics, automatic metrics and UniEval metrics in the keys `GPT`, `Metrics` and `UniEval` (English only). You can find an example English config file in `config`.
```json
{
"language": "cn",
"language": "en",
"path_for_UniEval": {
"summarization": "path to unieval-sum model",
"dialogue": "path to unieval-dialog model",
"data2text": "path to unieval-sum model"
},
"category": {
"brainstorming": {
"GPT": ["relevance", "creativity", "practicality", "correctness"],
"Metrics": ["Distinct"]
"Metrics": ["Distinct"],
"UniEval": ["summarization-fluency", "data2text-naturalness", "data2text-informativeness"]
},
"chat": {
"GPT": [ "relevance", "naturalness", "engagingness", "reasonableness"],
"Metrics": ["Distinct"]
"Metrics": ["Distinct"],
"UniEval": ["dialogue-naturalness", "dialogue-coherence", "dialogue-understandability"]
}
}
}
@ -235,27 +264,33 @@ The following is an example of a Chinese config file. The configuration file can
`"language"`: the language used to evaluate the model capability. We only support Chinese `"cn"` for now.
`"path_for_UniEval"`: path to the UniEval model.
`"category"`: the category/categories needed to evaluate the model capability.
`"GPT"`: the metrics you want to use for GPT evaluation.
`"Metrics"`: the metrics you want to use for automatic metrics evaluation.
`"UniEval"`: the metrics you want to use for UniEval metrics evaluation. The metric has to be in the `"{task}-{metric}"` format because different tasks have same metrics such as naturalness and coherence.
You can remove a key such as `"Metrics"` to skip evaluating answers with its corresponding evaluation metrics.
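For reference, the pipeline simply splits each `"UniEval"` entry on the hyphen to decide which UniEval model to load and which dimension to score, roughly like this (see `evaluator.py`):

```python
# How a "{task}-{metric}" entry is interpreted.
task, metric = "dialogue-naturalness".split("-")
print(task)    # "dialogue"     -> load the unieval-dialog model
print(metric)  # "naturalness"  -> the dimension passed to that model
```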
You can create your config file based on the available settings listed in the following table.
| "category" | "GPT" | "Metrics" |
| :--------------: | :---------------------: | :---------: |
| "brainstorming" | "language organization" | "BLEU" |
| "chat" | "relevance" | "ROUGE" |
| "classification" | "creativity" | "Distinct" |
| "closed_qa" | "practicality" | "BERTScore" |
| "extraction" | "correctness" | "Precision" |
| "generation" | "naturalness" | "Recall" |
| "open_qa" | "engagingness" | "F1 score" |
| "rewriting" | "reasonableness" | |
| "roleplay" | "diversity" | |
| "summarization" | "fidelity" | |
| | "conciseness" | |
| "category" | "GPT" | "Metrics" | "UniEval" |
| :--------------: | :---------------------: | :---------: | :--------------------------: |
| "brainstorming" | "language organization" | "BLEU" | "dialogue-naturalness" |
| "chat" | "relevance" | "ROUGE" | "dialogue-coherence" |
| "classification" | "creativity" | "Distinct" | "dialogue-understandability" |
| "closed_qa" | "practicality" | "BERTScore" | "data2text-naturalness" |
| "extraction" | "correctness" | "Precision" | "data2text-informativeness" |
| "generation" | "naturalness" | "Recall" | "summarization-coherence" |
| "open_qa" | "engagingness" | "F1 score" | "summarization-consistency" |
| "rewriting" | "reasonableness" | "CHRF" | "summarization-fluency" |
| "roleplay" | "diversity" | | "summarization-relevance" |
| "summarization" | "fidelity" | | |
| | "conciseness" | | |
> **NOTE:** For categories that don't have standard answers, such as `brainstorming`, you should avoid automatic metrics such as `BLEU` and `ROUGE`, which are based on similarity measures, and use `Distinct` instead in your config file.
@ -290,23 +325,36 @@ For example, if you want to add a new metric `persuasiveness` into category `bra
"id": 1,
"category": "brainstorming",
"metrics": {
"persuasiveness": "说服力(1-5):XXX"
"persuasiveness": "persuasiveness(1-5):a short description for persuasiveness"
},
"CoT": {
"persuasiveness": "XXX\n\n说服力:"
"persuasiveness": "CoT for persuasiveness\n\npersuasiveness:"
},
"prompt": "你是一个好助手。请你为下面“头脑风暴”问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}"
"prompt": "You are a good assistant. Please rate the given answer to the \"brainstorming\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}"
}
}
```
</details>
<details><summary><b>How can I add a new UniEval evaluation metric?</b></summary>
For example, if you want to add a new metric `persuasiveness` into task `data2text`, you should add a Boolean QA question about the metric in the function `add_question` in `unieval/utils.py`. Please note that how effectively the model evaluates this metric is unknown, and you may need some experiments to test whether the model is capable of evaluating it.
```python
if task == 'data2text':
if dimension == 'persuasiveness':
cur_input = 'question: Is this a persuasive utterance? </s> utterance: ' + output[i]
```
</details>
## To Do
- [x] Add evaluation for English capability
- [ ] Support UniEval
- [x] Support UniEval
- [x] Support GPT-4 evaluation
- [ ] Support GPT evaluation with reference in the prompt
## Citations
@ -327,4 +375,13 @@ For example, if you want to add a new metric `persuasiveness` into category `bra
archivePrefix={arXiv},
primaryClass={cs.CL}
}
@misc{zhong2022unified,
title={Towards a Unified Multi-Dimensional Evaluator for Text Generation},
author={Ming Zhong and Yang Liu and Da Yin and Yuning Mao and Yizhu Jiao and Pengfei Liu and Chenguang Zhu and Heng Ji and Jiawei Han},
year={2022},
eprint={2210.07197},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```

applications/Chat/evaluate/config/config_cn.json (12 lines changed)

@ -34,7 +34,8 @@
"Metrics": [
"Precision",
"Recall",
"F1 score"
"F1 score",
"CHRF"
]
},
"closed_qa": {
@ -46,7 +47,8 @@
"Metrics": [
"BLEU",
"ROUGE",
"BERTScore"
"BERTScore",
"CHRF"
]
},
"extraction": {
@ -58,7 +60,8 @@
"Metrics": [
"Precision",
"Recall",
"F1 score"
"F1 score",
"CHRF"
]
},
"generation": {
@ -116,7 +119,8 @@
"Metrics": [
"BLEU",
"ROUGE",
"BERTScore"
"BERTScore",
"CHRF"
]
}
}

applications/Chat/evaluate/config/config_en.json (73 lines changed)

@ -1,5 +1,10 @@
{
"language": "en",
"path_for_UniEval": {
"summarization": "path to unieval-sum",
"dialogue": "path to unieval-dialog",
"data2text": "path to unieval-sum"
},
"category": {
"brainstorming": {
"GPT": [
@ -11,6 +16,11 @@
],
"Metrics": [
"Distinct"
],
"UniEval": [
"summarization-fluency",
"data2text-naturalness",
"data2text-informativeness"
]
},
"chat": {
@ -23,6 +33,14 @@
],
"Metrics": [
"Distinct"
],
"UniEval": [
"summarization-fluency",
"dialogue-naturalness",
"dialogue-coherence",
"dialogue-understandability",
"data2text-naturalness",
"data2text-informativeness"
]
},
"classification": {
@ -34,7 +52,13 @@
"Metrics": [
"Precision",
"Recall",
"F1 score"
"F1 score",
"CHRF"
],
"UniEval": [
"summarization-fluency",
"data2text-naturalness",
"data2text-informativeness"
]
},
"closed_qa": {
@ -46,7 +70,13 @@
"Metrics": [
"BLEU",
"ROUGE",
"BERTScore"
"BERTScore",
"CHRF"
],
"UniEval": [
"summarization-fluency",
"data2text-naturalness",
"data2text-informativeness"
]
},
"extraction": {
@ -58,7 +88,13 @@
"Metrics": [
"Precision",
"Recall",
"F1 score"
"F1 score",
"CHRF"
],
"UniEval": [
"summarization-fluency",
"data2text-naturalness",
"data2text-informativeness"
]
},
"generation": {
@ -71,6 +107,11 @@
"BLEU",
"ROUGE",
"BERTScore"
],
"UniEval": [
"summarization-fluency",
"data2text-naturalness",
"data2text-informativeness"
]
},
"open_qa": {
@ -81,6 +122,11 @@
],
"Metrics": [
"Distinct"
],
"UniEval": [
"summarization-fluency",
"data2text-naturalness",
"data2text-informativeness"
]
},
"rewriting": {
@ -93,6 +139,11 @@
"BLEU",
"ROUGE",
"BERTScore"
],
"UniEval": [
"summarization-fluency",
"data2text-naturalness",
"data2text-informativeness"
]
},
"roleplay": {
@ -104,6 +155,11 @@
],
"Metrics": [
"Distinct"
],
"UniEval": [
"summarization-fluency",
"data2text-naturalness",
"data2text-informativeness"
]
},
"summarization": {
@ -116,7 +172,16 @@
"Metrics": [
"BLEU",
"ROUGE",
"BERTScore"
"BERTScore",
"CHRF"
],
"UniEval": [
"summarization-coherence",
"summarization-consistency",
"summarization-fluency",
"summarization-relevance",
"data2text-naturalness",
"data2text-informativeness"
]
}
}

applications/Chat/evaluate/eval.py (2 lines changed)

@ -40,7 +40,7 @@ def main(args):
# initialize evaluator
evaluator = Evaluator(metrics_per_category, battle_prompt, gpt_evaluation_prompt, args.gpt_model,
config["language"])
config["language"], config.get("path_for_UniEval", None))
if len(args.model_name_list) == 2:
answers1 = jload(args.answer_file_list[0])
answers2 = jload(args.answer_file_list[1])

applications/Chat/evaluate/evaluator.py (142 lines changed)

@ -4,6 +4,7 @@ from typing import Any, Dict, List
import gpt_evaluate
import metrics
import pandas as pd
import unieval
from utils import analyze_automatic_results, get_data_per_category, save_automatic_results
@ -15,13 +16,15 @@ class Evaluator(object):
"""
def __init__(self, params: Dict[str, Any], battle_prompt: Dict[str, Any], gpt_evaluation_prompt: Dict[str, Any],
gpt_model: str, language: str) -> None:
gpt_model: str, language: str, path_for_UniEval: Dict[str, str]) -> None:
self.params = params
self.battle_prompt = battle_prompt
self.gpt_evaluation_prompt = gpt_evaluation_prompt
self.gpt_model = gpt_model
self.language = language
self.path_for_UniEval = path_for_UniEval
self.automatic_metric_stats = dict()
self.unieval_metric_stats = dict()
self.gpt_evaluation_results = dict()
self.battle_results = []
@ -47,16 +50,18 @@ class Evaluator(object):
return metrics.bleu_score(preds=predicts_list, targets=targets_list, language=language)
elif metric == "ROUGE":
return metrics.rouge_score(preds=predicts_list, targets=targets_list, language=language)
elif (metric == "Distinct"):
elif metric == "Distinct":
return metrics.distinct_score(preds=predicts_list, language=language)
elif (metric == "BERTScore"):
elif metric == "BERTScore":
return metrics.bert_score(preds=predicts_list, targets=targets_list, language=language)
elif (metric == "Precision"):
elif metric == "Precision":
return metrics.precision(preds=predicts_list, targets=targets_list, language=language)
elif (metric == "Recall"):
elif metric == "Recall":
return metrics.recall(preds=predicts_list, targets=targets_list, language=language)
elif (metric == "F1 score"):
elif metric == "F1 score":
return metrics.F1_score(preds=predicts_list, targets=targets_list, language=language)
elif metric == "CHRF":
return metrics.chrf_score(preds=predicts_list, targets=targets_list, language=language)
else:
raise ValueError(f"Unexpected metric")
@ -69,6 +74,9 @@ class Evaluator(object):
print(f"Category {category} specified in your config doesn't have corresponding answers!")
continue
if self.params[category].get("Metrics", None) is None:
continue
category_metrics = self.params[category]["Metrics"]
self.automatic_metric_stats[category] = {}
@ -80,12 +88,68 @@ class Evaluator(object):
for metric in category_metrics:
self.automatic_metric_stats[category].update(switch(metric=metric, language=self.language))
# UniEval evaluation
# self.unieval_metric_stats's key is "task" instead of "category".
# Iterating "task" first will avoid repeated loading models because one task corresponds to one UniEval model.
# If key is "category", different models will be loaded for multiple times across categories because the user may require different task(models) to evaluate one category.
for category in self.params:
if len(answers_per_category[category]) == 0:
print(f"Category {category} specified in your config doesn't have corresponding answers!")
continue
if self.params[category].get("UniEval", None) is None:
continue
if self.params[category]["UniEval"] and self.language == "cn":
raise Exception(
"UniEval doesn't support Chinese! Please remove UniEval config in your Chinese config file.")
category_metrics = self.params[category]["UniEval"]
for task, metric in [tuple(category_metric.split("-")) for category_metric in category_metrics]:
if self.unieval_metric_stats.get(task, None) is None:
self.unieval_metric_stats[task] = {category: {metric: 0}}
elif self.unieval_metric_stats[task].get(category, None) is None:
self.unieval_metric_stats[task][category] = {metric: 0}
else:
self.unieval_metric_stats[task][category][metric] = 0
for task in self.unieval_metric_stats:
if self.path_for_UniEval is None:
raise Exception(f"Please specify the path for UniEval model in the config file!")
if self.path_for_UniEval.get(task, None) is None:
raise Exception(f"Please specify the model path for task {task} in the config file!")
print(f"Load UniEval model for task {task}.")
uni_evaluator = unieval.get_evaluator(task, model_name_or_path=self.path_for_UniEval[task])
for category in self.unieval_metric_stats[task]:
targets_list = [
target["target"] if target["target"] else target["output"]
for target in targets_per_category[category]
]
predicts_list = [answer["output"] for answer in answers_per_category[category]]
sources_list = [answer["instruction"] + answer["input"] for answer in answers_per_category[category]]
data = unieval.convert_data_to_unieval_format(predicts_list, sources_list, targets_list)
scores = uni_evaluator.evaluate(data,
category,
dims=list(self.unieval_metric_stats[task][category].keys()),
overall=False)
avg_scores = unieval.calculate_average_score(scores)
self.unieval_metric_stats[task][category].update(avg_scores)
# gpt evaluation
for category in self.params:
if len(answers_per_category[category]) == 0:
print(f"Category {category} specified in your config doesn't have corresponding answers!")
continue
if self.params[category].get("GPT", None) is None:
continue
category_metrics = self.params[category]["GPT"]
prompt = self.gpt_evaluation_prompt.get(category, None)
@ -106,29 +170,43 @@ class Evaluator(object):
save_path = os.path.join(path, "gpt_evaluate", "battle_results")
gpt_evaluate.save_battle_results(self.battle_results, model_name_list[0], model_name_list[1], save_path)
else:
# Save evaluation results for automatic metrics
automatic_base_save_path = os.path.join(path, "automatic_results")
automatic_results_save_path = os.path.join(automatic_base_save_path, "evaluation_results")
save_automatic_results(model_name_list[0], self.automatic_metric_stats, automatic_results_save_path)
# Save charts and csv.
automatic_analyses_save_path = os.path.join(automatic_base_save_path, "evaluation_analyses")
analyze_automatic_results(automatic_results_save_path, automatic_analyses_save_path)
# Save evaluation results for GPT evaluation metrics.
gpt_base_save_path = os.path.join(path, "gpt_evaluate", "gpt_evaluate_results")
gpt_evaluation_results_save_path = os.path.join(gpt_base_save_path, "evaluation_results")
all_evaluations = gpt_evaluate.save_gpt_evaluation_results(model_name_list[0], self.gpt_evaluation_results,
gpt_evaluation_results_save_path)
# Start to calculate scores and save statistics.
gpt_evaluation_statistics_save_path = os.path.join(gpt_base_save_path, "evaluation_statistics")
gpt_evaluate.save_gpt_evaluation_statistics(model_name_list[0], all_evaluations,
gpt_evaluation_statistics_save_path)
# Save charts and csv.
gpt_evaluation_analyses_save_path = os.path.join(gpt_base_save_path, "evaluation_analyses")
gpt_evaluate.analyze_gpt_evaluation_statistics(gpt_evaluation_statistics_save_path,
gpt_evaluation_analyses_save_path)
if self.automatic_metric_stats:
# Save evaluation results for automatic metrics
automatic_base_save_path = os.path.join(path, "automatic_results")
automatic_results_save_path = os.path.join(automatic_base_save_path, "evaluation_results")
save_automatic_results(model_name_list[0], self.automatic_metric_stats, automatic_results_save_path)
# Save charts and csv.
automatic_analyses_save_path = os.path.join(automatic_base_save_path, "evaluation_analyses")
analyze_automatic_results(automatic_results_save_path, automatic_analyses_save_path)
if self.unieval_metric_stats:
# Save evaluation results for UniEval metrics
unieval_base_save_path = os.path.join(path, "unieval_results")
unieval_results_save_path = os.path.join(unieval_base_save_path, "evaluation_results")
unieval.save_unieval_results(model_name_list[0], self.unieval_metric_stats, unieval_results_save_path)
# Save charts and csv.
unieval_analyses_save_path = os.path.join(unieval_base_save_path, "evaluation_analyses")
unieval.analyze_unieval_results(unieval_results_save_path, unieval_analyses_save_path)
if self.gpt_evaluation_results:
# Save evaluation results for GPT evaluation metrics.
gpt_base_save_path = os.path.join(path, "gpt_evaluate", "gpt_evaluate_results")
gpt_evaluation_results_save_path = os.path.join(gpt_base_save_path, "evaluation_results")
all_evaluations = gpt_evaluate.save_gpt_evaluation_results(model_name_list[0],
self.gpt_evaluation_results,
gpt_evaluation_results_save_path)
# Start to calculate scores and save statistics.
gpt_evaluation_statistics_save_path = os.path.join(gpt_base_save_path, "evaluation_statistics")
gpt_evaluate.save_gpt_evaluation_statistics(model_name_list[0], all_evaluations,
gpt_evaluation_statistics_save_path)
# Save charts and csv.
gpt_evaluation_analyses_save_path = os.path.join(gpt_base_save_path, "evaluation_analyses")
gpt_evaluate.analyze_gpt_evaluation_statistics(gpt_evaluation_statistics_save_path,
gpt_evaluation_analyses_save_path)

applications/Chat/evaluate/gpt_evaluate.py (2 lines changed)

@ -599,7 +599,7 @@ def analyze_gpt_evaluation_statistics(statistics_path: str, save_path: str) -> N
for category in tqdm.tqdm(
frame_per_category.keys(),
desc=f"category: ",
desc=f"GPT evaluation: ",
total=len(frame_per_category.keys()),
):
data = pd.DataFrame(frame_per_category[category])

applications/Chat/evaluate/metrics.py (22 lines changed)

@ -4,6 +4,7 @@ from typing import Dict, List
import jieba
from bert_score import score
from nltk.translate.bleu_score import sentence_bleu
from nltk.translate.chrf_score import sentence_chrf
from rouge_chinese import Rouge as Rouge_cn
from rouge_score import rouge_scorer as Rouge_en
from sklearn.metrics import f1_score, precision_score, recall_score
@ -40,6 +41,27 @@ def bleu_score(preds: List[str], targets: List[str], language: str) -> Dict[str,
return bleu_scores
def chrf_score(preds: List[str], targets: List[str], language: str) -> Dict[str, float]:
"""Calculate CHRF Score Metric in sentence level.
"""
chrf_score = {"chrf": 0}
cumulative_chrf = []
for pred, target in zip(preds, targets):
if language == "cn":
pred_list = ' '.join(jieba.cut(preprocessing_text(pred))).split()
target_list = ' '.join(jieba.cut(preprocessing_text(target))).split()
elif language == "en":
pred_list = preprocessing_text(pred).split()
target_list = preprocessing_text(target).split()
cumulative_chrf.append(sentence_chrf(target_list, pred_list))
chrf_score["chrf"] = statistics.mean(cumulative_chrf)
return chrf_score
def rouge_cn_score(preds: List[str], targets: List[str]) -> Dict[str, float]:
"""Calculate Chinese ROUGE Score Metric

applications/Chat/evaluate/unieval/__init__.py (12 lines changed)

@ -0,0 +1,12 @@
from .evaluator import get_evaluator
from .utils import (
analyze_unieval_results,
calculate_average_score,
convert_data_to_unieval_format,
save_unieval_results,
)
__all__ = [
'get_evaluator', 'convert_data_to_unieval_format', 'calculate_average_score', 'save_unieval_results',
'analyze_unieval_results'
]

applications/Chat/evaluate/unieval/evaluator.py (330 lines changed)

@ -0,0 +1,330 @@
# MIT License
# Copyright (c) 2022 Ming Zhong
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import numpy as np
from nltk import sent_tokenize
from .scorer import UniEvaluator
from .utils import add_question
class SumEvaluator:
def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
""" Set up evaluator for text summarization """
self.scorer = UniEvaluator(
model_name_or_path='MingZhong/unieval-sum' if model_name_or_path == "" else model_name_or_path,
max_length=max_length,
device=device,
cache_dir=cache_dir)
self.task = 'summarization'
self.dimensions = ['coherence', 'consistency', 'fluency', 'relevance']
def evaluate(self, data, category, dims=None, overall=True):
"""
Get the scores of all the given dimensions
category: The category to be evaluated.
dims: A list of dimensions to be evaluated. If dims is None, SumEvaluator will evaluate
four dimensions: coherence, consistency, fluency, relevance.
overall: indicates whether the overall score is to be calculated.
Overall score can be customized to a combination of scores based on different
dimensions. The default here is the average score of all the given dimensions.
"""
n_data = len(data)
eval_scores = [{} for _ in range(n_data)]
if dims == None:
eval_dims = self.dimensions
else:
assert isinstance(dims, list)
eval_dims = dims
for dim in eval_dims:
# Calculate average sentence-level scores for 'consistency' and 'fluency'
if dim == 'consistency' or dim == 'fluency':
src_list, output_list = [], []
n_sents = [] # the number of sentences in each generated summary
for i in range(n_data):
source = data[i]['source']
system_outputs = sent_tokenize(data[i]['system_output'])
n_sents.append(len(system_outputs))
for j in range(len(system_outputs)):
src_list.append(source)
output_list.append(system_outputs[j])
input_list = add_question(dimension=dim, output=output_list, src=src_list, task=self.task)
sent_score = self.scorer.score(input_list, self.task, category, dim)
# Get average score for each sample
start_idx = 0
score = []
for cur_n_sent in n_sents:
score.append(sum(sent_score[start_idx:start_idx + cur_n_sent]) / cur_n_sent)
start_idx += cur_n_sent
# Calculate summary-level score for 'coherence' and 'relevance'
elif dim == 'coherence' or dim == 'relevance':
src_list, output_list, ref_list = [], [], []
for i in range(n_data):
src_list.append(data[i]['source'])
output_list.append(data[i]['system_output'])
if dim == 'relevance':
ref_list.append(data[i]['reference'])
input_list = add_question(dimension=dim, output=output_list, src=src_list, ref=ref_list, task=self.task)
score = self.scorer.score(input_list, self.task, category, dim)
# Please customize other dimensions here for summarization
else:
raise NotImplementedError('The input format for this dimension is still undefined. \
Please customize it first.')
for i in range(n_data):
eval_scores[i][dim] = score[i]
# Customize your overall score here.
if overall == True:
for i in range(n_data):
eval_scores[i]['overall'] = np.mean(list(eval_scores[i].values()))
return eval_scores
class DialogEvaluator:
def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
""" Set up evaluator for dialogues """
self.scorer = UniEvaluator(
model_name_or_path='MingZhong/unieval-dialog' if model_name_or_path == "" else model_name_or_path,
max_length=max_length,
device=device,
cache_dir=cache_dir)
self.task = 'dialogue'
self.dimensions = ['naturalness', 'coherence', 'engagingness', 'groundedness', 'understandability']
def evaluate(self, data, category, dims=None, overall=True):
"""
Get the scores of all the given dimensions
category: The category to be evaluated.
dims: A list of dimensions to be evaluated. If dims is None, DialogEvaluator will evaluate
five dimensions: naturalness, coherence, engagingness, groundedness and understandability.
overall: indicates whether the overall score is to be calculated.
Overall score can be customized to a combination of scores based on different
dimensions. The default here is the average score of all the given dimensions.
"""
n_data = len(data)
eval_scores = [{} for _ in range(n_data)]
if dims == None:
eval_dims = self.dimensions
else:
assert isinstance(dims, list)
eval_dims = dims
for dim in eval_dims:
# Calculate summation score for 'engagingness'
if dim == 'engagingness':
src_list, output_list, context_list = [], [], []
n_sents = [] # the number of sentences in each generated response
for i in range(n_data):
source = data[i]['source']
context = data[i]['context']
system_outputs = sent_tokenize(data[i]['system_output'])
n_sents.append(len(system_outputs))
for j in range(len(system_outputs)):
src_list.append(source)
context_list.append(context)
output_list.append(system_outputs[j])
input_list = add_question(dimension=dim,
output=output_list,
src=src_list,
context=context_list,
task=self.task)
sent_score = self.scorer.score(input_list, self.task, category, dim)
# Get the summation score for each sample
start_idx = 0
score = []
for cur_n_sent in n_sents:
score.append(sum(sent_score[start_idx:start_idx + cur_n_sent]))
start_idx += cur_n_sent
# Calculate turn-level score for other dimensions
elif dim in ['naturalness', 'coherence', 'groundedness', 'understandability']:
src_list, output_list, context_list = [], [], []
for i in range(n_data):
src_list.append(data[i]['source'])
output_list.append(data[i]['system_output'])
context_list.append(data[i]['context'])
input_list = add_question(dimension=dim,
output=output_list,
src=src_list,
context=context_list,
task=self.task)
score = self.scorer.score(input_list, self.task, category, dim)
# Please customize other dimensions here for summarization
else:
raise NotImplementedError('The input format for this dimension is still undefined. \
Please customize it first.')
for i in range(n_data):
eval_scores[i][dim] = score[i]
# Customize your overall score here.
if overall == True:
for i in range(n_data):
eval_scores[i]['overall'] = np.mean(list(eval_scores[i].values()))
return eval_scores
class D2tEvaluator:
def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
""" Set up evaluator for data-to-text """
self.scorer = UniEvaluator(
model_name_or_path='MingZhong/unieval-sum' if model_name_or_path == "" else model_name_or_path,
max_length=max_length,
device=device,
cache_dir=cache_dir)
self.task = 'data2text'
self.dimensions = ['naturalness', 'informativeness']
def evaluate(self, data, category, dims=None, overall=True):
"""
Get the scores of all the given dimensions
category: The category to be evaluated.
dims: A list of dimensions to be evaluated. If dims is None, D2tEvaluator will evaluate
two dimensions: naturalness and informativeness.
overall: indicates whether the overall score is to be calculated.
Overall score can be customized to a combination of scores based on different
dimensions. The default here is the average score of all the given dimensions.
"""
n_data = len(data)
eval_scores = [{} for _ in range(n_data)]
if dims == None:
eval_dims = self.dimensions
else:
assert isinstance(dims, list)
eval_dims = dims
for dim in eval_dims:
output_list, ref_list = [], []
for i in range(n_data):
output_list.append(data[i]['system_output'])
ref_list.append(data[i]['reference'])
input_list = add_question(dimension=dim, output=output_list, ref=ref_list, task=self.task)
score = self.scorer.score(input_list, self.task, category, dim)
for i in range(n_data):
eval_scores[i][dim] = score[i]
# Customize your overall score here.
if overall == True:
for i in range(n_data):
eval_scores[i]['overall'] = np.mean(list(eval_scores[i].values()))
return eval_scores
class FactEvaluator:
def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
""" Set up evaluator for factual consistency detection """
self.scorer = UniEvaluator(
model_name_or_path='MingZhong/unieval-fact' if model_name_or_path == "" else model_name_or_path,
max_length=max_length,
device=device,
cache_dir=cache_dir)
self.task = 'fact'
self.dim = 'consistency'
def evaluate(self, data, category):
"""
Get the factual consistency score (only 1 dimension for this task)
category: The category to be evaluated.
"""
n_data = len(data)
eval_scores = [{} for _ in range(n_data)]
# Calculate average sentence-level scores for factual consistency
src_list, output_list = [], []
n_sents = [] # the number of sentences in the claim
for i in range(n_data):
source = data[i]['source']
system_outputs = sent_tokenize(data[i]['system_output'])
n_sents.append(len(system_outputs))
for j in range(len(system_outputs)):
src_list.append(source)
output_list.append(system_outputs[j])
input_list = add_question(dimension=self.dim, output=output_list, src=src_list, task=self.task)
sent_score = self.scorer.score(input_list, self.task, category, self.dim)
# Get average score for each sample
start_idx = 0
score = []
for cur_n_sent in n_sents:
score.append(sum(sent_score[start_idx:start_idx + cur_n_sent]) / cur_n_sent)
start_idx += cur_n_sent
for i in range(n_data):
eval_scores[i][self.dim] = score[i]
return eval_scores
def get_evaluator(task, model_name_or_path="", max_length=1024, device='cuda:0', cache_dir=None):
assert task in ['summarization', 'dialogue', 'data2text', 'fact']
if task == 'summarization':
return SumEvaluator(model_name_or_path=model_name_or_path,
max_length=max_length,
device=device,
cache_dir=cache_dir)
elif task == 'dialogue':
return DialogEvaluator(model_name_or_path=model_name_or_path,
max_length=max_length,
device=device,
cache_dir=cache_dir)
elif task == 'data2text':
return D2tEvaluator(model_name_or_path=model_name_or_path,
max_length=max_length,
device=device,
cache_dir=cache_dir)
elif task == 'fact':
return FactEvaluator(model_name_or_path=model_name_or_path,
max_length=max_length,
device=device,
cache_dir=cache_dir)
else:
raise NotImplementedError('Other tasks are not implemented, \
please customize specific tasks here.')

applications/Chat/evaluate/unieval/scorer.py (101 lines changed)

@ -0,0 +1,101 @@
# MIT License
# Copyright (c) 2022 Ming Zhong
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import torch
import torch.nn as nn
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer
class UniEvaluator:
def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
""" Set up model """
self.device = device
self.max_length = max_length
self.config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir)
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir)
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, config=self.config, cache_dir=cache_dir)
self.model.eval()
self.model.to(device)
self.softmax = nn.Softmax(dim=1)
self.pos_id = self.tokenizer("Yes")["input_ids"][0]
self.neg_id = self.tokenizer("No")["input_ids"][0]
def score(self, inputs, task, category, dim, batch_size=8):
"""
Get scores for the given samples.
final_score = positive_score / (positive_score + negative_score)
"""
# The implementation of "forward" in T5 still requires decoder_input_ids.
# Therefore, we construct a random one-word target sequence.
# The content of the target has no effect on the final scores.
tgts = ["No" for _ in range(len(inputs))]
pos_score_list, neg_score_list = [], []
for i in tqdm(range(0, len(inputs), batch_size), desc=f"{category}-({dim}-{task}): "):
src_list = inputs[i:i + batch_size]
tgt_list = tgts[i:i + batch_size]
try:
with torch.no_grad():
encoded_src = self.tokenizer(src_list,
max_length=self.max_length,
truncation=True,
padding=True,
return_tensors='pt')
encoded_tgt = self.tokenizer(tgt_list,
max_length=self.max_length,
truncation=True,
padding=True,
return_tensors='pt')
src_tokens = encoded_src['input_ids'].to(self.device)
src_mask = encoded_src['attention_mask'].to(self.device)
tgt_tokens = encoded_tgt['input_ids'].to(self.device)[:, 0].unsqueeze(-1)
output = self.model(input_ids=src_tokens, attention_mask=src_mask, labels=tgt_tokens)
logits = output.logits.view(-1, self.model.config.vocab_size)
pos_score = self.softmax(logits)[:, self.pos_id] # Yes
neg_score = self.softmax(logits)[:, self.neg_id] # No
cur_pos_score = [x.item() for x in pos_score]
cur_neg_score = [x.item() for x in neg_score]
pos_score_list += cur_pos_score
neg_score_list += cur_neg_score
except RuntimeError:
print(f'source: {src_list}')
print(f'target: {tgt_list}')
exit(0)
score_list = []
for i in range(len(pos_score_list)):
score_list.append(pos_score_list[i] / (pos_score_list[i] + neg_score_list[i]))
return score_list

applications/Chat/evaluate/unieval/utils.py (248 lines changed)

@ -0,0 +1,248 @@
# MIT License
# Copyright (c) 2022 Ming Zhong
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import os
from typing import Dict
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import tqdm
def add_question(dimension, output, src=None, ref=None, context=None, task=None):
"""
Add questions to generate input in Bool-QA format for UniEval.
dimension: specific dimension to be evaluated
src: source input for different NLG tasks. For example, source document for summarization
and dialogue history for dialogue response generation.
output: output text generated by the models
ref: human-annotated groundtruth
context: the context needed to evaluate certain specific dimensions. For example,
additional factual information when evaluating engagingness and groundedness in dialogues.
"""
input_with_question = []
for i in range(len(output)):
# For summarization
if task == 'summarization':
if dimension == 'fluency':
cur_input = 'question: Is this a fluent paragraph? </s> paragraph: ' + output[i]
elif dimension == 'coherence':
cur_input = 'question: Is this a coherent summary to the document? </s> summary: ' + output[
i] + ' </s> document: ' + src[i]
elif dimension == 'consistency':
cur_input = 'question: Is this claim consistent with the document? </s> claim: ' + output[
i] + ' </s> document: ' + src[i]
elif dimension == 'relevance':
cur_input = 'question: Is this summary relevant to the reference? </s> summary: ' + output[
i] + ' </s> reference: ' + ref[i]
else:
raise NotImplementedError(
'The input format for this dimension is still undefined. Please customize it first.')
# For dialogues
elif task == 'dialogue':
if dimension == 'naturalness':
cur_input = 'question: Is this a natural response in the dialogue? </s> response: ' + output[i]
elif dimension == 'coherence':
cur_input = 'question: Is this a coherent response given the dialogue history? </s> response: '\
+ output[i] + ' </s> dialogue history: ' + src[i]
elif dimension == 'engagingness':
cur_input = 'question: Is this an engaging and informative response according to the dialogue history and fact? </s> response: '\
+ output[i] + ' </s> dialogue history: ' + src[i] + ' </s> fact: ' + context[i]
elif dimension == 'groundedness':
cur_input = 'question: Is this response consistent with knowledge in the fact? </s> response: '\
+ output[i] + ' </s> fact: ' + context[i]
elif dimension == 'understandability':
cur_input = 'question: Is this an understandable response in the dialogue? </s> response: ' + output[i]
else:
raise NotImplementedError(
'The input format for this dimension is still undefined. Please customize it first.')
# For data-to-text
elif task == 'data2text':
if dimension == 'naturalness':
cur_input = 'question: Is this a fluent utterance? </s> utterance: ' + output[i]
elif dimension == 'informativeness':
cur_input = 'question: Is this sentence informative according to the reference? </s> sentence: '\
+ output[i] + ' </s> reference: ' + ref[i]
else:
raise NotImplementedError(
'The input format for this dimension is still undefined. Please customize it first.')
# For factual consistency detection
elif task == 'fact':
if dimension == 'consistency':
cur_input = 'question: Is this claim consistent with the document? </s> claim: ' + output[
i] + ' </s> document: ' + src[i]
else:
raise NotImplementedError('No other dimensions for the factual consistency detection task.')
# For new customized tasks
else:
raise NotImplementedError('Other tasks are not implemented, please customize specific tasks here.')
input_with_question.append(cur_input)
return input_with_question
def convert_data_to_unieval_format(output_list, src_list=None, ref_list=None):
"""
Convert the data into UniEval's format.
output_list: a list of model output
src_list: source input for different NLG tasks. For example, source document for summarization
and dialogue history for dialogue response generation
ref_list: human-annotated groundtruth
"""
json_data = []
for i in range(len(output_list)):
cur = {}
cur['system_output'] = output_list[i]
if src_list is not None:
cur['source'] = src_list[i]
if ref_list is not None:
cur['reference'] = ref_list[i]
cur['context'] = ""
json_data.append(cur)
return json_data
def calculate_average_score(scores):
"""
Calculate average scores for different metrics
scores: a list of scores for different metrics for each answer
"""
metrics = {metric: 0 for metric in scores[0]}
for score in scores:
for metric in score:
metrics[metric] += score[metric]
for metric in metrics:
metrics[metric] /= len(scores)
return metrics
def save_unieval_results(model_name: str, unieval_metric_stats: Dict[str, Dict], save_path: str) -> None:
"""
Save UniEval evaluation results of different categories for one model.
"""
if not os.path.exists(save_path):
os.makedirs(save_path)
unieval_metric_stats_per_category = {}
for task, category_stat in unieval_metric_stats.items():
for category, metric_stat in category_stat.items():
if unieval_metric_stats_per_category.get(category, None) is None:
unieval_metric_stats_per_category[category] = {}
for metric, score in metric_stat.items():
unieval_metric_stats_per_category[category][f"{metric}-{task}"] = score
automatic_df = pd.DataFrame(unieval_metric_stats_per_category)
automatic_df.to_csv(os.path.join(save_path, f"{model_name}_results.csv"), index=True)
def read_unieval_results(results_path: str, file_name: str) -> Dict[str, Dict]:
"""
Read a csv file and return a dictionary which stores scores per metric.
"""
results = pd.read_csv(os.path.join(results_path, file_name), index_col=0)
results_dict = {metric: {} for metric in list(results.index)}
for i, metric in enumerate(results_dict.keys()):
for j, category in enumerate(list(results.columns)):
if pd.isnull(results.iloc[i][j]):
continue
results_dict[metric][category] = results.iloc[i][j]
return results_dict
def analyze_unieval_results(results_path: str, save_path: str) -> None:
"""
Analyze and visualize all csv files in the given folder.
"""
if not os.path.exists(results_path):
raise Exception(f'The given directory "{results_path}" doesn\'t exist! No results found!')
all_statistics = {}
for file_name in os.listdir(results_path):
if file_name.endswith("_results.csv"):
model_name = file_name.split("_results.csv")[0]
all_statistics[model_name] = read_unieval_results(results_path, file_name)
if len(list(all_statistics.keys())) == 0:
raise Exception(f'There are no csv files in the given directory "{results_path}"!')
frame_all = {"model": [], "category": [], "metric": [], "score": []}
frame_per_metric = {}
for model_name, model_statistics in all_statistics.items():
for metric, metric_statistics in model_statistics.items():
if frame_per_metric.get(metric) is None:
frame_per_metric[metric] = {"model": [], "category": [], "score": []}
for category, category_score in metric_statistics.items():
frame_all["model"].append(model_name)
frame_all["category"].append(category)
frame_all["metric"].append(metric)
frame_all["score"].append(category_score)
frame_per_metric[metric]["model"].append(model_name)
frame_per_metric[metric]["category"].append(category)
frame_per_metric[metric]["score"].append(category_score)
if not os.path.exists(save_path):
os.makedirs(save_path)
frame_all = pd.DataFrame(frame_all)
frame_all.to_csv(os.path.join(save_path, "unieval_statistics.csv"))
for metric in tqdm.tqdm(
frame_per_metric.keys(),
desc=f"UniEval metrics: ",
total=len(frame_per_metric.keys()),
):
data = pd.DataFrame(frame_per_metric[metric])
sns.set()
fig = plt.figure(figsize=(16, 10))
fig = sns.barplot(x="category", y="score", hue="model", data=data, dodge=True)
fig.set_title(
f"Comparison between Different Models for Metric {metric.split('-')[0].title()} in Task {metric.split('-')[1].title()}"
)
plt.xlabel("Evaluation Category")
plt.ylabel("Score")
figure = fig.get_figure()
figure.savefig(os.path.join(save_path, f"{metric}.png"), dpi=400)
plt.close()

applications/Chat/evaluate/utils.py (2 lines changed)

@ -199,7 +199,7 @@ def analyze_automatic_results(results_path: str, save_path: str) -> None:
for metric in tqdm.tqdm(
frame_per_metric.keys(),
desc=f"metric: ",
desc=f"automatic metrics: ",
total=len(frame_per_metric.keys()),
):
data = pd.DataFrame(frame_per_metric[metric])
