fix a bug

pull/477/head
gaoyang07 2023-11-07 17:38:46 +08:00
parent 61f953bb7b
commit 2f1812e8c7
1 changed files with 3 additions and 3 deletions

View File

@ -33,7 +33,7 @@ from internlm.data.packed_dataset import (
PackedDatasetWithoutCuSeqlen,
get_packed_dataset_without_short_length,
)
from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
from internlm.data.utils import get_dataset_type_ids_map, unpack_data
from internlm.model.embedding import Embedding1D
from internlm.model.linear import (
FeedForward,
@ -133,7 +133,7 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
# wrap the model
grp = gpc.get_group(ParallelMode.ZERO1)
model = FSDP(
model = FSDP( # pylint: disable=unexpected-keyword-arg
module=model,
process_group=grp,
sharding_strategy=ShardingStrategy.FULL_SHARD,
@ -219,11 +219,11 @@ def get_train_data_loader(
# Get the dataset types
dataset_types = None
dataset_types = list(DATASET_TYPE_IDS_MAP.keys())
data_cfg = gpc.config.data
# Get the sample weight dictionary
train_folder = data_cfg.train_folder
dataset_types = list(get_dataset_type_ids_map(train_folder).keys())
if not train_folder:
train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len)