diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index dd8e190..7010ca5 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -33,7 +33,7 @@ from internlm.data.packed_dataset import (
     PackedDatasetWithoutCuSeqlen,
     get_packed_dataset_without_short_length,
 )
-from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
+from internlm.data.utils import get_dataset_type_ids_map, unpack_data
 from internlm.model.embedding import Embedding1D
 from internlm.model.linear import (
     FeedForward,
@@ -133,7 +133,7 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
 
         # wrap the model
         grp = gpc.get_group(ParallelMode.ZERO1)
-        model = FSDP(
+        model = FSDP(  # pylint: disable=unexpected-keyword-arg
             module=model,
             process_group=grp,
             sharding_strategy=ShardingStrategy.FULL_SHARD,
@@ -219,11 +219,11 @@ def get_train_data_loader(
 
     # Get the dataset types
     dataset_types = None
-    dataset_types = list(DATASET_TYPE_IDS_MAP.keys())
     data_cfg = gpc.config.data
 
     # Get the sample weight dictionary
     train_folder = data_cfg.train_folder
+    dataset_types = list(get_dataset_type_ids_map(train_folder).keys())
 
     if not train_folder:
         train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len)
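
The diff above replaces the module-level DATASET_TYPE_IDS_MAP constant with a get_dataset_type_ids_map(train_folder) helper, so dataset-type ids are derived from the training folder rather than hard-coded, and the lookup is moved below the point where train_folder is read from the config. The real helper lives in internlm/data/utils.py and is not part of this diff; the following is only a minimal sketch of the idea, assuming each sub-directory of train_folder corresponds to one dataset type and ids are assigned from the sorted directory names.

# Illustrative sketch only -- not the implementation in internlm/data/utils.py.
# Assumption: one dataset type per sub-directory of train_folder, with ids
# assigned in sorted order so they stay stable across runs.
import os

def get_dataset_type_ids_map(train_folder: str) -> dict:
    if not train_folder:
        # No folder configured (e.g. the RandomDataset path): fall back to a single default type.
        return {"en": 0}
    subdirs = sorted(
        d for d in os.listdir(train_folder)
        if os.path.isdir(os.path.join(train_folder, d))
    )
    return {name: idx for idx, name in enumerate(subdirs)}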