import argparse
import json
import os

import openai

from evaluator import Evaluator
from utils import jload


def main(args):
    assert len(args.answer_file_list) == len(
        args.model_name_list), "The number of answer files and model names should be equal!"

    # load config
    config = jload(args.config_file)

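    # A minimal sketch of the expected config layout, inferred from the lookups
    # below; the placeholder names are illustrative assumptions, not the keys of
    # any shipped config:
    #
    # {
    #     "language": "cn",
    #     "category": {
    #         "<category_name>": {
    #             "<metric_type>": ["<metric_name>", ...]
    #         }
    #     }
    # }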
    if config["language"] == "cn":
        # get metric settings for all categories
        metrics_per_category = {}
        for category in config["category"].keys():
            metrics_all = {}
            for metric_type, metrics in config["category"][category].items():
                metrics_all[metric_type] = metrics
            metrics_per_category[category] = metrics_all

        battle_prompt = None
        if args.battle_prompt_file:
            battle_prompt = jload(args.battle_prompt_file)

        gpt_evaluation_prompt = None
        if args.gpt_evaluation_prompt_file:
            gpt_evaluation_prompt = jload(args.gpt_evaluation_prompt_file)

        if len(args.model_name_list) == 2 and not battle_prompt:
            raise Exception("No prompt file for battle provided. Please specify the prompt file for battle!")

        if len(args.model_name_list) == 1 and not gpt_evaluation_prompt:
            raise Exception(
                "No prompt file for gpt evaluation provided. Please specify the prompt file for gpt evaluation!")

        # initialize evaluator
        evaluator = Evaluator(metrics_per_category, battle_prompt, gpt_evaluation_prompt)

        if len(args.model_name_list) == 2:
            # two models: run a pairwise battle between their answers
            answers1 = jload(args.answer_file_list[0])
            answers2 = jload(args.answer_file_list[1])

            assert len(answers1) == len(answers2), "The number of answers for two models should be equal!"

            evaluator.battle(answers1=answers1, answers2=answers2)
            evaluator.save(args.save_path, args.model_name_list)
        elif len(args.model_name_list) == 1:
            # one model: evaluate its answers against the ground-truth targets
            targets = jload(args.target_file)
            answers = jload(args.answer_file_list[0])

            assert len(targets) == len(answers), "The number of target answers and model answers should be equal!"

            evaluator.evaluate(answers=answers, targets=targets)
            evaluator.save(args.save_path, args.model_name_list)
        else:
            raise ValueError("Unsupported number of answer files and model names!")
    else:
        raise ValueError(f'Unsupported language {config["language"]}!')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='ColossalAI LLM evaluation pipeline.')
    parser.add_argument('--config_file',
                        type=str,
                        default=None,
                        required=True,
                        help='path to the evaluation config file')
    parser.add_argument('--battle_prompt_file', type=str, default=None, help='path to the prompt file for battle')
    parser.add_argument('--gpt_evaluation_prompt_file',
                        type=str,
                        default=None,
                        help='path to the prompt file for gpt evaluation')
    parser.add_argument('--target_file', type=str, default=None, help='path to the target answer (ground truth) file')
    parser.add_argument('--answer_file_list',
                        type=str,
                        nargs='+',
                        default=[],
                        required=True,
                        help='paths to the answer files of at most 2 models')
    parser.add_argument('--model_name_list',
                        type=str,
                        nargs='+',
                        default=[],
                        required=True,
                        help='the names of at most 2 models')
    parser.add_argument('--save_path', type=str, default="results", help='path to save evaluation results')
    parser.add_argument('--openai_key', type=str, default=None, required=True, help='your OpenAI API key')
    args = parser.parse_args()

    if args.openai_key is not None:
        os.environ["OPENAI_API_KEY"] = args.openai_key
        openai.api_key = os.getenv("OPENAI_API_KEY")

    main(args)
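# A hypothetical usage sketch. The script name, file paths, and model names
# below are illustrative assumptions, not values shipped with the repository:
#
#   python eval.py \
#       --config_file config/config_cn.json \
#       --battle_prompt_file prompt/battle_prompt_cn.json \
#       --answer_file_list answers/model_a.json answers/model_b.json \
#       --model_name_list model_a model_b \
#       --save_path results \
#       --openai_key $OPENAI_API_KEY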