mirror of https://github.com/hpcaitech/ColossalAI
Topics: ai, big-model, data-parallelism, deep-learning, distributed-computing, foundation-models, heterogeneous-training, hpc, inference, large-scale, model-parallelism, pipeline-parallelism
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
134 lines
7.2 KiB
134 lines
7.2 KiB
import copy |
|
import csv |
|
import os |
|
from typing import Dict, List |
|
|
|
from colossalai.logging import DistributedLogger |
|
|
|
from .base import BaseDataset |
|
|
|
# Maps each C-Eval subject key (the CSV file-name stem) to
# [English name, Chinese name, category].
# The table is written grouped by category; the comprehension flattens it while
# keeping the subjects in the order listed here.
ceval_subject_mapping = {
    key: [english_name, chinese_name, category]
    for category, subjects in (
        (
            "STEM",
            (
                ("computer_network", "Computer Network", "计算机网络"),
                ("operating_system", "Operating System", "操作系统"),
                ("computer_architecture", "Computer Architecture", "计算机组成"),
                ("college_programming", "College Programming", "大学编程"),
                ("college_physics", "College Physics", "大学物理"),
                ("college_chemistry", "College Chemistry", "大学化学"),
                ("advanced_mathematics", "Advanced Mathematics", "高等数学"),
                ("probability_and_statistics", "Probability and Statistics", "概率统计"),
                ("discrete_mathematics", "Discrete Mathematics", "离散数学"),
                ("electrical_engineer", "Electrical Engineer", "注册电气工程师"),
                ("metrology_engineer", "Metrology Engineer", "注册计量师"),
                ("high_school_mathematics", "High School Mathematics", "高中数学"),
                ("high_school_physics", "High School Physics", "高中物理"),
                ("high_school_chemistry", "High School Chemistry", "高中化学"),
                ("high_school_biology", "High School Biology", "高中生物"),
                ("middle_school_mathematics", "Middle School Mathematics", "初中数学"),
                ("middle_school_biology", "Middle School Biology", "初中生物"),
                ("middle_school_physics", "Middle School Physics", "初中物理"),
                ("middle_school_chemistry", "Middle School Chemistry", "初中化学"),
                ("veterinary_medicine", "Veterinary Medicine", "兽医学"),
            ),
        ),
        (
            "Social Science",
            (
                ("college_economics", "College Economics", "大学经济学"),
                ("business_administration", "Business Administration", "工商管理"),
                ("marxism", "Marxism", "马克思主义基本原理"),
                ("mao_zedong_thought", "Mao Zedong Thought", "毛泽东思想和中国特色社会主义理论体系概论"),
                ("education_science", "Education Science", "教育学"),
                ("teacher_qualification", "Teacher Qualification", "教师资格"),
                ("high_school_politics", "High School Politics", "高中政治"),
                ("high_school_geography", "High School Geography", "高中地理"),
                ("middle_school_politics", "Middle School Politics", "初中政治"),
                ("middle_school_geography", "Middle School Geography", "初中地理"),
            ),
        ),
        (
            "Humanities",
            (
                ("modern_chinese_history", "Modern Chinese History", "近代史纲要"),
                ("ideological_and_moral_cultivation", "Ideological and Moral Cultivation", "思想道德修养与法律基础"),
                ("logic", "Logic", "逻辑学"),
                ("law", "Law", "法学"),
                ("chinese_language_and_literature", "Chinese Language and Literature", "中国语言文学"),
                ("art_studies", "Art Studies", "艺术学"),
                ("professional_tour_guide", "Professional Tour Guide", "导游资格"),
                ("legal_professional", "Legal Professional", "法律职业资格"),
                ("high_school_chinese", "High School Chinese", "高中语文"),
                ("high_school_history", "High School History", "高中历史"),
                ("middle_school_history", "Middle School History", "初中历史"),
            ),
        ),
        (
            "Other",
            (
                ("civil_servant", "Civil Servant", "公务员"),
                ("sports_science", "Sports Science", "体育学"),
                ("plant_protection", "Plant Protection", "植物保护"),
                ("basic_medicine", "Basic Medicine", "基础医学"),
                ("clinical_medicine", "Clinical Medicine", "临床医学"),
                ("urban_and_rural_planner", "Urban and Rural Planner", "注册城乡规划师"),
                ("accountant", "Accountant", "注册会计师"),
                ("fire_engineer", "Fire Engineer", "注册消防工程师"),
                (
                    "environmental_impact_assessment_engineer",
                    "Environmental Impact Assessment Engineer",
                    "环境影响评价工程师",
                ),
                ("tax_accountant", "Tax Accountant", "税务师"),
                ("physician", "Physician", "医师资格"),
            ),
        ),
    )
    for key, english_name, chinese_name in subjects
}
|
|
|
# Inference arguments attached to every subject; CEvalDataset.load deep-copies
# this template for each subject before adding per-subject entries.
default_inference_kwargs = dict(
    calculate_loss=False,
    all_classes=["A", "B", "C", "D"],
    language="Chinese",
    pretrain=False,
    max_new_tokens=32,
)
|
|
|
|
|
def get_few_shot_data(data: List[Dict], subject):
    """
    Build the few-shot prompt pieces for one subject.

    Args:
        data: Samples to use as demonstrations; each must provide the keys
            "input" and "target".
        subject: Subject name interpolated into the instruction header.

    Returns:
        A list whose first element is the instruction header, followed by one
        "input + target" string per sample, in the order given.
    """
    # NOTE(review): this header uses an ASCII comma, whereas the per-sample
    # "instruction" built in CEvalDataset.load uses a full-width one — confirm
    # whether the difference is intended before unifying them.
    header = f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。"
    return [header, *(sample["input"] + sample["target"] for sample in data)]
|
|
|
|
|
class CEvalDataset(BaseDataset):
    """
    Dataset class for CEval dataset.
    Data source: https://huggingface.co/datasets/ceval/ceval-exam
    This dataset class will convert the original dataset into the inference dataset.
    """

    @staticmethod
    def load(
        path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool
    ) -> Dict[str, Dict[str, Dict]]:
        """
        Read the "dev" and "test" CSV files under ``path`` and convert every row
        into an inference sample.

        Args:
            path: Directory containing a ``dev`` and a ``test`` sub-directory.
                Each file in them is expected to be named
                ``<subject>_<split>.csv`` with ``<subject>`` being a key of
                ``ceval_subject_mapping``.
            logger: Not used by this loader.
            few_shot: When true, every test subject's inference kwargs get a
                "few_shot_data" entry built from that subject's dev samples.
            forward_only: Not used by this loader.
            load_train: Not used by this loader.
            load_reference: Not used by this loader.

        Returns:
            A nested dict ``{split: {subject: {"data": [...], "inference_kwargs": {...}}}}``
            where ``split`` is "dev" or "test" and ``subject`` is the Chinese
            subject name taken from ``ceval_subject_mapping``.

        Raises:
            KeyError: If a file name does not map to a known subject, or if
                ``few_shot`` is set and a test subject has no dev counterpart.
            ValueError: If a CSV row has fewer columns than the split requires.
        """
        dataset = {"dev": {}, "test": {}}

        # "dev" has to be processed first: the few-shot examples attached to a
        # test subject are built from the dev samples loaded in the previous pass.
        for split in ["dev", "test"]:
            suffix = f"_{split}.csv"
            # Columns: id, question, choices A-D (6 in total). Dev rows additionally
            # carry the answer in column 6, which is read as the target below.
            min_columns = 7 if split == "dev" else 6

            files = os.listdir(os.path.join(path, split))
            files.sort()

            for file in files:
                # Strip the "_<split>.csv" suffix to get the subject key, then
                # switch to the Chinese subject name used in prompts and as dict key.
                subject = file[0 : -len(suffix)]
                subject = ceval_subject_mapping[subject][1]

                file_dir = os.path.join(path, split, file)

                dataset[split][subject] = {"data": []}

                # It's been tested that each data sample in one subcategory have same inference arguments.
                dataset[split][subject]["inference_kwargs"] = copy.deepcopy(default_inference_kwargs)

                if split == "test" and few_shot:
                    dataset[split][subject]["inference_kwargs"]["few_shot_data"] = get_few_shot_data(
                        dataset["dev"][subject]["data"], subject
                    )

                with open(file_dir, encoding="utf-8") as f:
                    reader = csv.reader(f)
                    # Skip the header row. The default keeps an empty file from
                    # leaking StopIteration out of this function.
                    next(reader, None)
                    for row in reader:
                        # Dev split have answer and explanation so len(row) is 8
                        # But test split doesn't contain answer and explanation, so len(row) is 6
                        # Validate explicitly instead of with `assert`, which is
                        # removed when Python runs with -O.
                        if len(row) < min_columns:
                            raise ValueError(
                                f"Malformed row in {file_dir} (line {reader.line_num}): "
                                f"expected at least {min_columns} columns, got {len(row)}"
                            )
                        choices = f"A. {row[2]}\nB. {row[3]}\nC. {row[4]}\nD. {row[5]}"
                        data_sample = {
                            "dataset": "ceval",
                            "split": split,
                            "category": subject,
                            "instruction": f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。",
                            "input": f"题目:{row[1]}\n{choices}\n答案:",
                            "output": "",
                            # Only the dev split ships answers.
                            "target": row[6] if split == "dev" else "",
                            "id": int(row[0]),
                        }

                        dataset[split][subject]["data"].append(data_sample)

        return dataset
|
|
|