You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ColossalAI/applications/ColossalEval/colossal_eval/dataset/ceval.py

133 lines
7.0 KiB

import copy
import csv
import os
from typing import Dict, List
from colossalai.logging import DistributedLogger
from .base import BaseDataset
ceval_subject_mapping = {
"computer_network": ["Computer Network", "计算机网络", "STEM"],
"operating_system": ["Operating System", "操作系统", "STEM"],
"computer_architecture": ["Computer Architecture", "计算机组成", "STEM"],
"college_programming": ["College Programming", "大学编程", "STEM"],
"college_physics": ["College Physics", "大学物理", "STEM"],
"college_chemistry": ["College Chemistry", "大学化学", "STEM"],
"advanced_mathematics": ["Advanced Mathematics", "高等数学", "STEM"],
"probability_and_statistics": ["Probability and Statistics", "概率统计", "STEM"],
"discrete_mathematics": ["Discrete Mathematics", "离散数学", "STEM"],
"electrical_engineer": ["Electrical Engineer", "注册电气工程师", "STEM"],
"metrology_engineer": ["Metrology Engineer", "注册计量师", "STEM"],
"high_school_mathematics": ["High School Mathematics", "高中数学", "STEM"],
"high_school_physics": ["High School Physics", "高中物理", "STEM"],
"high_school_chemistry": ["High School Chemistry", "高中化学", "STEM"],
"high_school_biology": ["High School Biology", "高中生物", "STEM"],
"middle_school_mathematics": ["Middle School Mathematics", "初中数学", "STEM"],
"middle_school_biology": ["Middle School Biology", "初中生物", "STEM"],
"middle_school_physics": ["Middle School Physics", "初中物理", "STEM"],
"middle_school_chemistry": ["Middle School Chemistry", "初中化学", "STEM"],
"veterinary_medicine": ["Veterinary Medicine", "兽医学", "STEM"],
"college_economics": ["College Economics", "大学经济学", "Social Science"],
"business_administration": ["Business Administration", "工商管理", "Social Science"],
"marxism": ["Marxism", "马克思主义基本原理", "Social Science"],
"mao_zedong_thought": ["Mao Zedong Thought", "毛泽东思想和中国特色社会主义理论体系概论", "Social Science"],
"education_science": ["Education Science", "教育学", "Social Science"],
"teacher_qualification": ["Teacher Qualification", "教师资格", "Social Science"],
"high_school_politics": ["High School Politics", "高中政治", "Social Science"],
"high_school_geography": ["High School Geography", "高中地理", "Social Science"],
"middle_school_politics": ["Middle School Politics", "初中政治", "Social Science"],
"middle_school_geography": ["Middle School Geography", "初中地理", "Social Science"],
"modern_chinese_history": ["Modern Chinese History", "近代史纲要", "Humanities"],
"ideological_and_moral_cultivation": ["Ideological and Moral Cultivation", "思想道德修养与法律基础", "Humanities"],
"logic": ["Logic", "逻辑学", "Humanities"],
"law": ["Law", "法学", "Humanities"],
"chinese_language_and_literature": ["Chinese Language and Literature", "中国语言文学", "Humanities"],
"art_studies": ["Art Studies", "艺术学", "Humanities"],
"professional_tour_guide": ["Professional Tour Guide", "导游资格", "Humanities"],
"legal_professional": ["Legal Professional", "法律职业资格", "Humanities"],
"high_school_chinese": ["High School Chinese", "高中语文", "Humanities"],
"high_school_history": ["High School History", "高中历史", "Humanities"],
"middle_school_history": ["Middle School History", "初中历史", "Humanities"],
"civil_servant": ["Civil Servant", "公务员", "Other"],
"sports_science": ["Sports Science", "体育学", "Other"],
"plant_protection": ["Plant Protection", "植物保护", "Other"],
"basic_medicine": ["Basic Medicine", "基础医学", "Other"],
"clinical_medicine": ["Clinical Medicine", "临床医学", "Other"],
"urban_and_rural_planner": ["Urban and Rural Planner", "注册城乡规划师", "Other"],
"accountant": ["Accountant", "注册会计师", "Other"],
"fire_engineer": ["Fire Engineer", "注册消防工程师", "Other"],
"environmental_impact_assessment_engineer": ["Environmental Impact Assessment Engineer", "环境影响评价工程师", "Other"],
"tax_accountant": ["Tax Accountant", "税务师", "Other"],
"physician": ["Physician", "医师资格", "Other"],
}
default_inference_kwargs = {
"calculate_loss": False,
"all_classes": ["A", "B", "C", "D"],
"language": "Chinese",
"pretrain": False,
"max_new_tokens": 32,
}
def get_few_shot_data(data: List[Dict]):
few_shot_data = []
for i in data:
few_shot_data.append(i["input"] + i["target"])
return few_shot_data
class CEvalDataset(BaseDataset):
"""
Dataset class for CEval dataset.
Data source: https://huggingface.co/datasets/ceval/ceval-exam
This dataset class will convert the original dataset into the inference dataset.
"""
@staticmethod
def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]:
dataset = {"dev": {}, "test": {}}
for split in ["dev", "test"]:
files = os.listdir(os.path.join(path, split))
files.sort()
for file in files:
subject = file[0 : -len(f"_{split}.csv")]
subject = ceval_subject_mapping[subject][1]
file_dir = os.path.join(path, split, file)
dataset[split][subject] = {"data": []}
# It's been tested that each data sample in one subcategory have same inference arguments.
dataset[split][subject]["inference_kwargs"] = copy.deepcopy(default_inference_kwargs)
if split == "test" and few_shot:
dataset[split][subject]["inference_kwargs"]["few_shot_data"] = get_few_shot_data(
dataset["dev"][subject]["data"]
)
with open(file_dir, encoding="utf-8") as f:
reader = csv.reader(f)
_ = next(reader)
for row in reader:
# Dev split have answer and explanation so len(row) is 8
# But test split doesn't contain answer and explanation, so len(row) is 6
assert len(row) >= 6
choices = f"A. {row[2]}\nB. {row[3]}\nC. {row[4]}\nD. {row[5]}"
data_sample = {
"dataset": "ceval",
"split": split,
"category": subject,
"instruction": f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。",
"input": f"题目:{row[1]}\n{choices}\n答案:",
"output": "",
"target": row[6] if split == "dev" else "",
"id": int(row[0]),
}
dataset[split][subject]["data"].append(data_sample)
return dataset