import copy
import csv
import os
from typing import Dict, List

from colossalai.logging import DistributedLogger

from .base import BaseDataset

# Maps each C-Eval subject key (the "<subject>" stem of "<subject>_<split>.csv") to
# [English name, Chinese name, category].
ceval_subject_mapping = {
    "computer_network": ["Computer Network", "计算机网络", "STEM"],
    "operating_system": ["Operating System", "操作系统", "STEM"],
    "computer_architecture": ["Computer Architecture", "计算机组成", "STEM"],
    "college_programming": ["College Programming", "大学编程", "STEM"],
    "college_physics": ["College Physics", "大学物理", "STEM"],
    "college_chemistry": ["College Chemistry", "大学化学", "STEM"],
    "advanced_mathematics": ["Advanced Mathematics", "高等数学", "STEM"],
    "probability_and_statistics": ["Probability and Statistics", "概率统计", "STEM"],
    "discrete_mathematics": ["Discrete Mathematics", "离散数学", "STEM"],
    "electrical_engineer": ["Electrical Engineer", "注册电气工程师", "STEM"],
    "metrology_engineer": ["Metrology Engineer", "注册计量师", "STEM"],
    "high_school_mathematics": ["High School Mathematics", "高中数学", "STEM"],
    "high_school_physics": ["High School Physics", "高中物理", "STEM"],
    "high_school_chemistry": ["High School Chemistry", "高中化学", "STEM"],
    "high_school_biology": ["High School Biology", "高中生物", "STEM"],
    "middle_school_mathematics": ["Middle School Mathematics", "初中数学", "STEM"],
    "middle_school_biology": ["Middle School Biology", "初中生物", "STEM"],
    "middle_school_physics": ["Middle School Physics", "初中物理", "STEM"],
    "middle_school_chemistry": ["Middle School Chemistry", "初中化学", "STEM"],
    "veterinary_medicine": ["Veterinary Medicine", "兽医学", "STEM"],
    "college_economics": ["College Economics", "大学经济学", "Social Science"],
    "business_administration": ["Business Administration", "工商管理", "Social Science"],
    "marxism": ["Marxism", "马克思主义基本原理", "Social Science"],
    "mao_zedong_thought": ["Mao Zedong Thought", "毛泽东思想和中国特色社会主义理论体系概论", "Social Science"],
    "education_science": ["Education Science", "教育学", "Social Science"],
    "teacher_qualification": ["Teacher Qualification", "教师资格", "Social Science"],
    "high_school_politics": ["High School Politics", "高中政治", "Social Science"],
    "high_school_geography": ["High School Geography", "高中地理", "Social Science"],
    "middle_school_politics": ["Middle School Politics", "初中政治", "Social Science"],
    "middle_school_geography": ["Middle School Geography", "初中地理", "Social Science"],
    "modern_chinese_history": ["Modern Chinese History", "近代史纲要", "Humanities"],
    "ideological_and_moral_cultivation": ["Ideological and Moral Cultivation", "思想道德修养与法律基础", "Humanities"],
    "logic": ["Logic", "逻辑学", "Humanities"],
    "law": ["Law", "法学", "Humanities"],
    "chinese_language_and_literature": ["Chinese Language and Literature", "中国语言文学", "Humanities"],
    "art_studies": ["Art Studies", "艺术学", "Humanities"],
    "professional_tour_guide": ["Professional Tour Guide", "导游资格", "Humanities"],
    "legal_professional": ["Legal Professional", "法律职业资格", "Humanities"],
    "high_school_chinese": ["High School Chinese", "高中语文", "Humanities"],
    "high_school_history": ["High School History", "高中历史", "Humanities"],
    "middle_school_history": ["Middle School History", "初中历史", "Humanities"],
    "civil_servant": ["Civil Servant", "公务员", "Other"],
    "sports_science": ["Sports Science", "体育学", "Other"],
    "plant_protection": ["Plant Protection", "植物保护", "Other"],
    "basic_medicine": ["Basic Medicine", "基础医学", "Other"],
    "clinical_medicine": ["Clinical Medicine", "临床医学", "Other"],
    "urban_and_rural_planner": ["Urban and Rural Planner", "注册城乡规划师", "Other"],
    "accountant": ["Accountant", "注册会计师", "Other"],
    "fire_engineer": ["Fire Engineer", "注册消防工程师", "Other"],
    "environmental_impact_assessment_engineer": [
        "Environmental Impact Assessment Engineer",
        "环境影响评价工程师",
        "Other",
    ],
    "tax_accountant": ["Tax Accountant", "税务师", "Other"],
    "physician": ["Physician", "医师资格", "Other"],
}

# Inference arguments shared by every subject; CEvalDataset.load deep-copies this per subject
# so per-subject additions (e.g. "few_shot_data") never leak into the shared template.
default_inference_kwargs = {
    "calculate_loss": False,
    "all_classes": ["A", "B", "C", "D"],
    "language": "Chinese",
    "pretrain": False,
    "max_new_tokens": 32,
}


def _build_instruction(subject: str) -> str:
    """Return the single-choice-question instruction sentence for ``subject``."""
    return f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。"


def get_few_shot_data(data: List[Dict], subject):
    """Build the few-shot prompt pieces for one subject.

    Args:
        data: Samples to use as demonstrations (``CEvalDataset.load`` passes the subject's
            dev-split samples); each must provide ``"input"`` and ``"target"`` strings.
        subject: Subject name interpolated into the instruction sentence.

    Returns:
        A list whose first element is the instruction sentence, followed by each sample's
        ``"input"`` concatenated with its ``"target"``, in the order given.
    """
    few_shot_data = [_build_instruction(subject)]
    few_shot_data.extend(sample["input"] + sample["target"] for sample in data)
    return few_shot_data


class CEvalDataset(BaseDataset):
    """
    Dataset class for CEval dataset.
    Data source: https://huggingface.co/datasets/ceval/ceval-exam
    This dataset class will convert the original dataset into the inference dataset.
    """

    @staticmethod
    def load(
        path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs
    ) -> Dict[str, Dict[str, Dict]]:
        """Load the C-Eval "dev" and "test" splits from ``path``.

        ``path/dev`` and ``path/test`` are scanned (in sorted order) for files named
        ``<subject>_<split>.csv``; other directory entries are ignored. The first CSV row is
        treated as a header and skipped, and blank rows are ignored.

        Args:
            path: Root directory containing the ``dev`` and ``test`` sub-directories.
            logger: Accepted for interface compatibility; not used here.
            few_shot: If True, each test subject's inference kwargs get a ``"few_shot_data"``
                entry built from that subject's dev samples.
            *args, **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            ``{split: {chinese_subject_name: {"data": [sample, ...], "inference_kwargs": {...}}}}``
            for ``split`` in ``("dev", "test")``.

        Raises:
            KeyError: If a file's subject is missing from ``ceval_subject_mapping``, or if
                ``few_shot`` is set and a test subject has no matching dev file.
            ValueError: If a non-empty CSV row has too few columns for its split.
        """
        dataset = {"dev": {}, "test": {}}
        # "dev" must be loaded before "test": test few-shot prompts are built from dev samples.
        for split in ["dev", "test"]:
            suffix = f"_{split}.csv"
            # Dev rows carry answer + explanation (8 columns, answer at index 6);
            # test rows only have id, question and the four choices (6 columns).
            min_columns = 7 if split == "dev" else 6
            split_dir = os.path.join(path, split)

            for file_name in sorted(os.listdir(split_dir)):
                # Skip entries that are not subject CSVs (e.g. hidden OS files) instead of
                # slicing them into a bogus subject key.
                if not file_name.endswith(suffix):
                    continue
                subject = ceval_subject_mapping[file_name[: -len(suffix)]][1]
                file_dir = os.path.join(split_dir, file_name)

                dataset[split][subject] = {"data": []}
                # It's been tested that each data sample in one subcategory has the same inference arguments.
                dataset[split][subject]["inference_kwargs"] = copy.deepcopy(default_inference_kwargs)
                if split == "test" and few_shot:
                    dataset[split][subject]["inference_kwargs"]["few_shot_data"] = get_few_shot_data(
                        dataset["dev"][subject]["data"], subject
                    )

                instruction = _build_instruction(subject)
                # newline="" is required by the csv module so line breaks inside quoted
                # fields (multi-line questions/choices) are preserved verbatim.
                with open(file_dir, encoding="utf-8", newline="") as f:
                    reader = csv.reader(f)
                    next(reader, None)  # skip the header row; tolerate an empty file
                    for row in reader:
                        if not row:
                            continue  # blank line in the CSV
                        if len(row) < min_columns:
                            raise ValueError(
                                f"Malformed C-Eval row in {file_dir}: expected at least "
                                f"{min_columns} columns, got {len(row)}"
                            )
                        choices = f"A. {row[2]}\nB. {row[3]}\nC. {row[4]}\nD. {row[5]}"
                        data_sample = {
                            "dataset": "ceval",
                            "split": split,
                            "category": subject,
                            "instruction": instruction,
                            "input": f"题目:{row[1]}\n{choices}\n答案:",
                            "output": "",
                            "target": row[6] if split == "dev" else "",
                            "id": int(row[0]),
                        }
                        dataset[split][subject]["data"].append(data_sample)

        return dataset