diff --git a/internlm/data/packed_dataset.py b/internlm/data/packed_dataset.py
index 576862e..af4c34a 100644
--- a/internlm/data/packed_dataset.py
+++ b/internlm/data/packed_dataset.py
@@ -15,7 +15,7 @@ from tqdm import tqdm
 
 from internlm.core.context import global_context as gpc
 from internlm.data.single_dataset import JsonlDataset
-from internlm.data.utils import get_dataset_type_id
+from internlm.data.utils import get_dataset_type_id, get_dataset_type_ids_map
 from internlm.utils.logger import get_logger
 
 DEFAULT_SEED = 1024
@@ -373,6 +373,8 @@ def get_packed_dataset_without_short_length(
     datasets = []
     delete_samples = 0
 
+    DATASET_TYPE_IDS_MAP = get_dataset_type_ids_map(folder)
+
     if gpc.get_global_rank() == 0:
         triples = [list(os.walk(folder, followlinks=True))]
     else:
@@ -400,7 +402,7 @@
                 len(catch_ml_keys) < 2
             ), f"The file name `{fp}` matched the following resample keys:{catch_ml_keys}"
 
-            ds_type_id = get_dataset_type_id(path=fp)
+            ds_type_id = get_dataset_type_id(DATASET_TYPE_IDS_MAP, path=fp)
             ds = JsonlDataset(fp, ds_type_id, min_length=min_length_num)
 
             if hasattr(ds, "old_length"):
diff --git a/internlm/data/utils.py b/internlm/data/utils.py
index 724fb9f..fbcb6f7 100644
--- a/internlm/data/utils.py
+++ b/internlm/data/utils.py
@@ -1,21 +1,26 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
+import os
+import re
 import torch
 
 from internlm.core.context import global_context as gpc
 
-DATASET_TYPE_IDS_MAP = {"en": 0, "cn": 1, "code": 2}
+
+def get_dataset_type_ids_map(path):
+    dirlist = list(os.listdir(path))
+    dirlist.sort()
+    return {key: idx for idx, key in enumerate(dirlist)}
 
 
-def get_dataset_type_id(path):
-    import re
-
+
+def get_dataset_type_id(dataset_type_ids_map, path):
     match_idxes = []
-    for key, idx in DATASET_TYPE_IDS_MAP.items():
+
+    for key, idx in dataset_type_ids_map.items():
         if re.search(rf"/[z_]*{key}/", path):
             match_idxes.append(idx)
-    assert len(match_idxes) == 1, f"{path}, match_idxes should be 1, but got {match_idxes} from {DATASET_TYPE_IDS_MAP}"
+    assert len(match_idxes) == 1, f"{path}, match_idxes should be 1, but got {match_idxes} from {dataset_type_ids_map}"
     return match_idxes[0]
-
 
diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index dd8e190..7010ca5 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -33,7 +33,7 @@ from internlm.data.packed_dataset import (
     PackedDatasetWithoutCuSeqlen,
     get_packed_dataset_without_short_length,
 )
-from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
+from internlm.data.utils import get_dataset_type_ids_map, unpack_data
 from internlm.model.embedding import Embedding1D
 from internlm.model.linear import (
     FeedForward,
@@ -133,7 +133,7 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
 
     # wrap the model
     grp = gpc.get_group(ParallelMode.ZERO1)
-    model = FSDP(
+    model = FSDP(  # pylint: disable=unexpected-keyword-arg
         module=model,
         process_group=grp,
         sharding_strategy=ShardingStrategy.FULL_SHARD,
@@ -219,11 +219,11 @@ def get_train_data_loader(
 
     # Get the dataset types
     dataset_types = None
-    dataset_types = list(DATASET_TYPE_IDS_MAP.keys())
     data_cfg = gpc.config.data
 
     # Get the sample weight dictionary
     train_folder = data_cfg.train_folder
+    dataset_types = list(get_dataset_type_ids_map(train_folder).keys())
 
     if not train_folder:
         train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len)
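
For context, a minimal sketch of how the patched helpers behave. The folder layout and file names (`root`, `part_0.jsonl`) are made up for illustration; only `get_dataset_type_ids_map` and `get_dataset_type_id` come from the patch above, and running this assumes a checkout with the patch applied.

```python
import os
import tempfile

from internlm.data.utils import get_dataset_type_id, get_dataset_type_ids_map

# One subdirectory per dataset type under the (hypothetical) training folder.
root = tempfile.mkdtemp()
for name in ("en", "cn", "code"):
    os.makedirs(os.path.join(root, name))

# The map is now derived from the sorted directory listing instead of the
# old hard-coded {"en": 0, "cn": 1, "code": 2}, so here it comes out as
# {"cn": 0, "code": 1, "en": 2}.
ids_map = get_dataset_type_ids_map(root)

# A file path must contain exactly one map key as a path component
# (matched via the /[z_]*{key}/ pattern), otherwise the assert inside
# get_dataset_type_id fires.
type_id = get_dataset_type_id(ids_map, os.path.join(root, "en", "part_0.jsonl"))
assert type_id == ids_map["en"]
```

Because the ids are derived from `sorted(os.listdir(...))`, every rank reading the same training folder computes the same mapping, but the ids no longer match the old hard-coded values and will shift if type subdirectories are added or removed.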