From 61f953bb7bd264c6bc78d39ff490e38ad360f5b2 Mon Sep 17 00:00:00 2001 From: gaoyang07 Date: Tue, 7 Nov 2023 17:13:45 +0800 Subject: [PATCH] walk folder to get dataset_type_ids_map --- internlm/data/packed_dataset.py | 6 ++++-- internlm/data/utils.py | 17 +++++++++++------ 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/internlm/data/packed_dataset.py b/internlm/data/packed_dataset.py index 576862e..af4c34a 100644 --- a/internlm/data/packed_dataset.py +++ b/internlm/data/packed_dataset.py @@ -15,7 +15,7 @@ from tqdm import tqdm from internlm.core.context import global_context as gpc from internlm.data.single_dataset import JsonlDataset -from internlm.data.utils import get_dataset_type_id +from internlm.data.utils import get_dataset_type_id, get_dataset_type_ids_map from internlm.utils.logger import get_logger DEFAULT_SEED = 1024 @@ -373,6 +373,8 @@ def get_packed_dataset_without_short_length( datasets = [] delete_samples = 0 + DATASET_TYPE_IDS_MAP = get_dataset_type_ids_map(folder) + if gpc.get_global_rank() == 0: triples = [list(os.walk(folder, followlinks=True))] else: @@ -400,7 +402,7 @@ def get_packed_dataset_without_short_length( len(catch_ml_keys) < 2 ), f"The file name `{fp}` matched the following resample keys:{catch_ml_keys}" - ds_type_id = get_dataset_type_id(path=fp) + ds_type_id = get_dataset_type_id(DATASET_TYPE_IDS_MAP, path=fp) ds = JsonlDataset(fp, ds_type_id, min_length=min_length_num) if hasattr(ds, "old_length"): diff --git a/internlm/data/utils.py b/internlm/data/utils.py index 724fb9f..fbcb6f7 100644 --- a/internlm/data/utils.py +++ b/internlm/data/utils.py @@ -1,21 +1,26 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +import os +import re import torch from internlm.core.context import global_context as gpc -DATASET_TYPE_IDS_MAP = {"en": 0, "cn": 1, "code": 2} + +def get_dataset_type_ids_map(path): + dirlist = list(os.listdir(path)) + dirlist.sort() + return {key: idx for idx, key in enumerate(dirlist)} -def get_dataset_type_id(path): - import re - +def get_dataset_type_id(dataset_type_ids_map, path): match_idxes = [] - for key, idx in DATASET_TYPE_IDS_MAP.items(): + + for key, idx in dataset_type_ids_map.items(): if re.search(rf"/[z_]*{key}/", path): match_idxes.append(idx) - assert len(match_idxes) == 1, f"{path}, match_idxes should be 1, but got {match_idxes} from {DATASET_TYPE_IDS_MAP}" + assert len(match_idxes) == 1, f"{path}, match_idxes should be 1, but got {match_idxes} from {dataset_type_ids_map}" return match_idxes[0]