mirror of https://github.com/hpcaitech/ColossalAI
commit 239cd92eff (parent f71e63b0f3)
@@ -6,6 +6,7 @@ from .colossalai import ColossalDataset
from .gaokaobench import GaoKaoBenchDataset
from .longbench import LongBenchDataset
from .mmlu import MMLUDataset
from .mtbench import MTBenchDataset

__all__ = [
    "AGIEvalDataset",
@@ -16,4 +17,5 @@ __all__ = [
    "LongBenchDataset",
    "MMLUDataset",
    "ColossalDataset",
    "MTBenchDataset",
]
@@ -0,0 +1,72 @@
import copy
import json
import os
from collections import defaultdict
from typing import Dict, List

from colossal_eval.utils import get_json_list

from colossalai.logging import DistributedLogger

from .base import BaseDataset

default_inference_kwargs = {
    "calculate_loss": False,
    "all_classes": None,
    "language": "English",
    "pretrain": False,
    "max_new_tokens": 1024,
    "turns": 2,
}


class MTBenchDataset(BaseDataset):
    """
    Dataset class for mt_bench dataset.
    Data source: https://github.com/lm-sys/FastChat/blob/main/fastchat/llm_judge/data/mt_bench/question.jsonl
    This dataset class will convert the original dataset into the inference dataset.
    """

    def __init__(self, path, logger, few_shot):
        self.multiturn = True
        self.dataset = self.load(path, logger, few_shot)

    @staticmethod
    def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
        dataset = {"test": defaultdict(dict)}

        file_path = os.path.join(path, "question.jsonl")
        ref_path = os.path.join(path, "reference_answer/gpt-4.jsonl")

        reference = defaultdict(list)
        ref_origin = get_json_list(ref_path)
        for ref in ref_origin:
            reference[ref["question_id"]] = ref["choices"][0]["turns"]

        with open(file_path, "r", encoding="utf-8") as file:
            for line in file:
                question = json.loads(line)
                category = question["category"]
                turn_number = len(question["turns"])
                data_point = {
                    "id": question["question_id"],
                    "dataset": "mtbench",
                    "split": "test",
                    "category": category,
                    "instruction": question["turns"],
                    "input": "",
                    "output": [],
                    "target": [""] * turn_number
                    if question["question_id"] not in reference
                    else reference[question["question_id"]],
                }

                if category in dataset["test"]:
                    dataset["test"][category]["data"].append(data_point)
                else:
                    dataset["test"][category] = {
                        "data": [data_point],
                        "inference_kwargs": copy.deepcopy(default_inference_kwargs),
                    }

        return dataset
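As a usage sketch (not part of the diff): assuming the mt_bench questions and GPT-4 reference answers have been downloaded under a local folder, loading produces one entry per category, each holding the per-question records and a deep copy of default_inference_kwargs. The path, the None logger, and the "writing" category below are placeholders.

# Hypothetical usage; the real scripts pass a DistributedLogger and a configured path.
from colossal_eval.dataset import MTBenchDataset

mtbench = MTBenchDataset("path/to/mt_bench", None, False)
category = mtbench.dataset["test"]["writing"]

print(category["inference_kwargs"]["turns"])  # 2: two user turns per question
print(category["data"][0]["instruction"])     # list with both turn prompts
print(category["data"][0]["output"])          # [] until inference fills it, turn by turn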
@@ -1,12 +1,15 @@
import os
from typing import Dict, List

import colossal_eval.evaluate.dataset_evaluator.metrics as metric_helper
import numpy as np
import tqdm
from colossal_eval.utils import jdump

LabelBasedMetrics = ["first_token_accuracy", "matthews_correlation"]
LossBasedMetrics = ["perplexity", "ppl_score", "ppl_score_over_choices", "per_byte_perplexity", "per_byte_ppl_score"]
CombinedMetrics = ["combined_single_choice_accuracy"]
GPTMetrics = ["mtbench_single_judge"]
OtherMetrics = [
    "f1_score",
    "f1_zh_score",
@@ -29,8 +32,9 @@ class DatasetEvaluator(object):
 
     """
 
-    def __init__(self):
-        pass
+    def __init__(self, config_path: str, save_path: str):
+        self.config_path = config_path
+        self.save_path = save_path
 
     def _calculate_label_metrics(self, metric: str, category: str):
         """Calculate label-based metrics."""
@@ -156,6 +160,24 @@ class DatasetEvaluator(object):
        self.evaluation_results[metric][category] = (total_score, len(self.data[category]["data"]))
        self.evaluation_results[metric]["ALL"] += total_score * weight

    def _calculate_gpt_metrics(self, metric: str, category: str):
        """Calculate gpt metrics."""
        weight = len(self.data[category]["data"]) / self.metric_total_length[metric]

        metric_method = eval("gpt_helper." + metric)

        judgements, avg_ratings = metric_method(self.data[category]["data"], self.config_path)
        self.judgements[category] = judgements

        self.evaluation_results[metric][category] = (np.mean(avg_ratings), len(self.data[category]["data"]))
        self.evaluation_results[metric]["ALL"] += np.mean(avg_ratings) * weight

        for i in range(avg_ratings.shape[0]):
            if f"{metric}_{i+1}" not in self.evaluation_results:
                self.evaluation_results[f"{metric}_{i+1}"] = {cat: 0 for cat in (["ALL"] + self.categories)}
            self.evaluation_results[f"{metric}_{i+1}"][category] = (avg_ratings[i], len(self.data[category]["data"]))
            self.evaluation_results[f"{metric}_{i+1}"]["ALL"] += avg_ratings[i] * weight

    def _calculate_loss_metrics(self, metric: str, category: str):
        """Calculate perplexity."""
        if metric == "perplexity":
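Outside the class, the aggregation in _calculate_gpt_metrics reduces to a weighted average: each category contributes its mean judge rating (and its per-turn ratings) in proportion to its share of questions. A toy illustration with made-up numbers and category names:

import numpy as np

# Per-turn mean ratings returned by the judge for each category (toy values).
category_ratings = {"writing": np.array([8.0, 7.0]), "roleplay": np.array([6.0, 5.5])}
category_sizes = {"writing": 10, "roleplay": 10}
total = sum(category_sizes.values())

overall = sum(np.mean(r) * category_sizes[c] / total for c, r in category_ratings.items())
turn_1 = sum(r[0] * category_sizes[c] / total for c, r in category_ratings.items())
print(overall, turn_1)  # weighted "ALL" entry and the turn-specific "metric_1" entry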
@@ -217,10 +239,20 @@ class DatasetEvaluator(object):
                for category in self.suggested_categories[metric]:
                    self._calculate_combined_metrics(metric, category)
                    pbar.update(1)
            elif metric in GPTMetrics:
                for category in self.suggested_categories[metric]:
                    self._calculate_gpt_metrics(metric, category)
                    pbar.update(1)
            elif metric in OtherMetrics:
                for category in self.suggested_categories[metric]:
                    self._calculate_other_metrics(metric, category)
                    pbar.update(1)
            else:
                raise Exception(f"{metric} not supported.")

        if self.judgements:
            judgement_path = os.path.join(self.save_path, f"{self.model_name}_judgements.json")
            jdump(self.judgements, judgement_path)

        return self.evaluation_results
@@ -240,6 +272,7 @@ class DatasetEvaluator(object):
        self.model_name = model_name
        self.categories = list(data.keys())
        self.metrics = metrics
        self.judgements = {}

        self.evaluation_results = {
            metric: {category: 0 for category in (["ALL"] + self.categories)} for metric in self.metrics
@@ -0,0 +1,151 @@
# Code adapted from https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge

import ast
import concurrent.futures
import copy
import json
import os
import re
import time
from typing import Any, Dict, List

import numpy as np
import openai
import tqdm

MODEL = "gpt-4"

API_MAX_RETRY = 16
API_RETRY_SLEEP = 10
API_ERROR_OUTPUT = "$ERROR$"

NEED_REF_CATS = ["math", "reasoning", "coding"]

one_score_pattern = re.compile("\[\[(\d+\.?\d*)\]\]")
one_score_pattern_backup = re.compile("\[(\d+\.?\d*)\]")


def load_mt_prompts(prompt_file: str):
    prompts = {}
    with open(prompt_file) as fin:
        for line in fin:
            line = json.loads(line)
            prompts[line["name"]] = line
    return prompts


def get_mt_prompt(prompts: Dict[str, str], multiturn: bool, math: bool):
    if math and multiturn:
        return prompts["single-math-v1-multi-turn"]
    elif math and not multiturn:
        return prompts["single-math-v1"]
    elif not math and multiturn:
        return prompts["single-v1-multi-turn"]
    elif not math and not multiturn:
        return prompts["single-v1"]


def chat_compeletion_openai(messages: List[Dict], temperature: float = 0.0, max_tokens: int = 2048):
    output = API_ERROR_OUTPUT
    model = MODEL
    for _ in range(API_MAX_RETRY):
        try:
            response = openai.ChatCompletion.create(
                model=model,
                messages=messages,
                n=1,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            output = response["choices"][0]["message"]["content"]
            break
        except openai.error.OpenAIError as e:
            print(type(e), e)
            time.sleep(API_RETRY_SLEEP)

    return output


def get_mtbench_judgements(question: Dict[str, Any], prompts: Dict[str, str]):
    id = question["id"]
    judgement = {"id": id, "judgements": [], "ratings": []}
    category = question["category"]
    math = category in NEED_REF_CATS
    turn_number = len(question["instruction"])

    for num in range(turn_number):
        assert (len(question["target"]) >= 1 and math) or not math
        kwargs = {}
        if num >= 1:
            prompt = get_mt_prompt(prompts, multiturn=True, math=math)
            if len(question["target"]) >= 1 and math:
                kwargs = {f"ref_answer_{i+1}": question["target"][i] for i in range(len(question["target"]))}
            user_prompt = prompt["prompt_template"].format(
                question_1=question["instruction"][0],
                question_2=question["instruction"][1],
                answer_1=question["output"][0],
                answer_2=question["output"][1],
                **kwargs,
            )
        else:
            prompt = get_mt_prompt(prompts, multiturn=False, math=math)
            if len(question["target"]) >= 1 and math:
                kwargs = {"ref_answer_1": question["target"][0]}
            user_prompt = prompt["prompt_template"].format(
                question=question["instruction"][0],
                answer=question["output"][0],
                **kwargs,
            )

        rating = -1
        sys_prompt = prompt["system_prompt"]
        messages = [{"role": "system", "content": sys_prompt}, {"role": "user", "content": user_prompt}]

        judgement_str = chat_compeletion_openai(messages, temperature=0.0, max_tokens=2048)
        match = re.search(one_score_pattern, judgement_str)
        if not match:
            match = re.search(one_score_pattern_backup, judgement_str)
        if match:
            rating = ast.literal_eval(match.groups()[0])
        else:
            rating = -1

        judgement["judgements"].append(judgement_str)
        judgement["ratings"].append(rating)

    return judgement


def mtbench_single_judge(data: List[Dict], config_path: str):
    judgements = []

    prompt_dir = os.path.dirname(config_path)
    prompts = load_mt_prompts(os.path.join(prompt_dir, "mtbench_judge_prompts.jsonl"))

    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
        futures = []
        for i, question in enumerate(data):
            future = executor.submit(get_mtbench_judgements, question, prompts)
            futures.append(future)

        for future in tqdm.tqdm(
            concurrent.futures.as_completed(futures),
            desc=f"MTBench single judge for {data[0]['category']}",
            total=len(futures),
        ):
            judgements.append(future.result())

    judgements.sort(key=lambda x: x["id"])

    judgements_by_id = {j["id"]: j for j in judgements}

    data_to_dump = copy.deepcopy(data)

    for d in data_to_dump:
        id = d["id"]
        d["judgements"] = judgements_by_id[id]["judgements"]
        d["ratings"] = judgements_by_id[id]["ratings"]

    avg_ratings = np.mean([j["ratings"] for j in judgements], axis=0)

    return data_to_dump, avg_ratings
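The two regexes follow FastChat's single-answer grading convention: the judge is instructed to wrap its verdict in double square brackets, with single brackets as a fallback. A quick, self-contained check of the extraction logic (the judgement text is invented):

import ast
import re

one_score_pattern = re.compile("\[\[(\d+\.?\d*)\]\]")
one_score_pattern_backup = re.compile("\[(\d+\.?\d*)\]")

sample = "The reply is accurate and well structured. Rating: [[8.5]]"
match = re.search(one_score_pattern, sample) or re.search(one_score_pattern_backup, sample)
rating = ast.literal_eval(match.groups()[0]) if match else -1
print(rating)  # 8.5; a judgement with no bracketed score falls back to -1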
@@ -185,6 +185,7 @@ metrics4subcategory = {
        "ppl_score_over_choices": ["ALL"],
        "ppl_score": ["ALL"],
    },
    "mtbench": {"mtbench_single_judge": ["ALL"]},
}
@@ -333,9 +333,12 @@ class HuggingFaceModel(BaseModel):
         self.str_label_map = {choice: idx for idx, choice in enumerate(self.choices)}
 
+        turn = 0 if not isinstance(data[0]["output"], list) else len(data[0]["output"]) + 1
+        turn_desc = "" if turn == 0 else f"-turn{turn}"
+
         bar = tqdm(
             range(math.ceil(len(data) / self.batch_size)),
-            desc=f"{data[0]['dataset']}-{data[0]['category']} Inference steps",
+            desc=f"{data[0]['dataset']}-{data[0]['category']}{turn_desc} Inference steps",
             disable=not is_rank_0(),
         )
         loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
@@ -384,7 +387,10 @@ class HuggingFaceModel(BaseModel):
             for j in range(len(batch_prompt)):
                 if not pretrain:
-                    answers[i + j]["output"] = batch_decodes[j].strip()
+                    if isinstance(answers[i + j]["output"], list):
+                        answers[i + j]["output"].append(batch_decodes[j].strip())
+                    else:
+                        answers[i + j]["output"] = batch_decodes[j].strip()
 
                 if isinstance(scores, torch.Tensor):
                     answers[i + j]["softmax_over_choices"] = probs[j]
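The new branch is what lets a multi-turn record accumulate one completion per turn while single-turn records keep a plain string. A toy illustration of the two cases (the answer dicts are invented):

def record_output(answer: dict, decoded: str) -> None:
    # Mirrors the branch above: list-valued outputs grow, string-valued outputs are overwritten.
    if isinstance(answer["output"], list):
        answer["output"].append(decoded.strip())
    else:
        answer["output"] = decoded.strip()

multi = {"output": []}    # mt_bench-style record
single = {"output": ""}   # ordinary single-turn record
record_output(multi, "turn-1 reply ")
record_output(multi, "turn-2 reply")
record_output(single, "the only reply")
print(multi["output"])    # ['turn-1 reply', 'turn-2 reply']
print(single["output"])   # 'the only reply'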
@@ -171,6 +171,9 @@ def get_batch_prompt(
    for b in batch:
        few_shot_prefix = ""
        if few_shot_data is not None:
            assert not isinstance(b["instruction"], list), print(
                f"When performing few-shot, {b['dataset']} shouldn't be a multiturn dataset."
            )
            # For few-shot, only need input. Otherwise use instruction (in AGIEval).
            query_text = b["input"] if b.get("input", "") != "" else b["instruction"]
@@ -181,11 +184,24 @@ def get_batch_prompt(
                 raise Exception("When using few-shot, target answer should be a string.")
 
             few_shot_prefix = get_few_shot_prefix(conv, few_shot_data, tokenizer, language, max_tokens)
-        else:
-            query_text = b["instruction"] + "\n\n" + b["input"] if b.get("input", "") != "" else b["instruction"]
 
-        conv.append_message(conv.roles[0], few_shot_prefix + query_text)
-        conv.append_message(conv.roles[1], None)
+            conv.append_message(conv.roles[0], few_shot_prefix + query_text)
+            conv.append_message(conv.roles[1], None)
+        else:
+            if not isinstance(b["instruction"], list):
+                query_text = (
+                    b["instruction"] + "\n\n" + b["input"] if b.get("input", "") != "" else b["instruction"]
+                )
+                conv.append_message(conv.roles[0], query_text)
+                conv.append_message(conv.roles[1], None)
+            else:
+                assert len(b["instruction"]) >= len(b["output"]) + 1
+                cur_turns = len(b["output"])
+                for turn in range(cur_turns):
+                    conv.append_message(conv.roles[0], b["instruction"][turn])
+                    conv.append_message(conv.roles[1], b["output"][turn])
+                conv.append_message(conv.roles[0], b["instruction"][cur_turns])
+                conv.append_message(conv.roles[1], None)
 
         batch_prompt.append(conv.get_prompt())
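In the new multiturn branch, the prompt for turn N is rebuilt from every earlier instruction paired with the model's own earlier output, followed by the next instruction and an empty assistant slot. A self-contained sketch with a minimal stand-in for the conversation template (the real conv object comes from the prompt utilities and is not shown in this diff):

class MiniConv:
    # Minimal stand-in for the conversation template used in get_batch_prompt.
    def __init__(self):
        self.roles = ("USER", "ASSISTANT")
        self.messages = []

    def append_message(self, role, message):
        self.messages.append((role, message))

    def get_prompt(self):
        return "\n".join(f"{r}: {m if m is not None else ''}" for r, m in self.messages)

b = {"instruction": ["First question?", "Follow-up question?"], "output": ["First answer."]}
conv = MiniConv()
cur_turns = len(b["output"])
for turn in range(cur_turns):
    conv.append_message(conv.roles[0], b["instruction"][turn])
    conv.append_message(conv.roles[1], b["output"][turn])
conv.append_message(conv.roles[0], b["instruction"][cur_turns])
conv.append_message(conv.roles[1], None)
print(conv.get_prompt())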
@@ -11,7 +11,7 @@ def main(args):
     evaluation_results = {dataset["name"]: {} for dataset in config["dataset"]}
     evaluation_results_table = {dataset["name"]: {} for dataset in config["dataset"]}
-    evaluator = DatasetEvaluator()
+    evaluator = DatasetEvaluator(args.config, args.evaluation_results_save_path)
 
     for dataset_parameter in config["dataset"]:
         dataset_name = dataset_parameter["name"]
@@ -26,6 +26,8 @@ def main(args):
        results = evaluator.get_evaluation_results(data, dataset_name, model_name, metrics)

        for metric, score in results.items():
            if metric not in results_metric_model:
                results_metric_model[metric] = {model["name"]: None for model in config["model"]}
            results_metric_model[metric][model_name] = score["ALL"]

        evaluation_results[dataset_name][model_name] = results
@@ -71,6 +71,7 @@ def main(args):
    inference_data = {}
    debug_args = {}
    few_shot_args = {}
    multiturn_args = {}

    config = utils.jload(args.config)
@@ -102,6 +103,13 @@ def main(args):
        dataset_ = dataset_class(path, logger, dataset_parameter["few_shot"])

        dataset_.save(save_path)

        if hasattr(dataset_, "multiturn") and dataset_.multiturn:
            multiturn_args[dataset_name] = True
            logger.info(f"{dataset_parameter['dataset_class']} is a multiturn dataset.")
        else:
            multiturn_args[dataset_name] = False

        inference_data[dataset_name] = dataset_.dataset["test"]

    for model_parameter in model_parameters:
@@ -117,7 +125,10 @@ def main(args):
        for dataset_name, split_data in inference_data.items():
            start = 0
            prev_questions = None
            for category, category_data in split_data.items():
                num_turn = category_data["inference_kwargs"].get("turns", 1)

                if few_shot_args[dataset_name] and category_data["inference_kwargs"].get("few_shot_data", None) is None:
                    raise Exception(f"Dataset {dataset_name} doesn't have few-shot data for category {category}!")
@@ -132,11 +143,16 @@ def main(args):
                 start = (start + redundant) % world_size
 
-                questions = category_data["data"][sum(lengths[0:rank]) : sum(lengths[0:rank]) + lengths[rank]]
+                for turn in range(num_turn):
+                    if turn == 0:
+                        questions = category_data["data"][sum(lengths[0:rank]) : sum(lengths[0:rank]) + lengths[rank]]
+                    else:
+                        questions = prev_questions
 
-                answers_per_rank = model_.inference(
-                    questions, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name]
-                )
+                    answers_per_rank = model_.inference(
+                        questions, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name]
+                    )
+                    prev_questions = answers_per_rank
 
                 answers_to_dump["data"] = answers_per_rank
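The turn loop threads each rank's answers back in as the next turn's questions, so by the second pass every record already carries its turn-1 reply in "output". A stripped-down sketch of that control flow with a stubbed inference call (all names here are placeholders):

def stub_inference(questions):
    # Stand-in for model_.inference: append one reply per record for the current turn.
    for q in questions:
        q["output"].append(f"reply to turn {len(q['output']) + 1}")
    return questions

category_data = {"data": [{"instruction": ["first question", "follow-up"], "output": []}]}
num_turn = 2

prev_questions = None
for turn in range(num_turn):
    questions = category_data["data"] if turn == 0 else prev_questions
    answers_per_rank = stub_inference(questions)
    prev_questions = answers_per_rank

print(prev_questions[0]["output"])  # ['reply to turn 1', 'reply to turn 2']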