mirror of https://github.com/InternLM/InternLM
fix a bug
parent
61f953bb7b
commit
2f1812e8c7
|
@ -33,7 +33,7 @@ from internlm.data.packed_dataset import (
|
||||||
PackedDatasetWithoutCuSeqlen,
|
PackedDatasetWithoutCuSeqlen,
|
||||||
get_packed_dataset_without_short_length,
|
get_packed_dataset_without_short_length,
|
||||||
)
|
)
|
||||||
from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
|
from internlm.data.utils import get_dataset_type_ids_map, unpack_data
|
||||||
from internlm.model.embedding import Embedding1D
|
from internlm.model.embedding import Embedding1D
|
||||||
from internlm.model.linear import (
|
from internlm.model.linear import (
|
||||||
FeedForward,
|
FeedForward,
|
||||||
|
@ -133,7 +133,7 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
|
||||||
|
|
||||||
# wrap the model
|
# wrap the model
|
||||||
grp = gpc.get_group(ParallelMode.ZERO1)
|
grp = gpc.get_group(ParallelMode.ZERO1)
|
||||||
model = FSDP(
|
model = FSDP( # pylint: disable=unexpected-keyword-arg
|
||||||
module=model,
|
module=model,
|
||||||
process_group=grp,
|
process_group=grp,
|
||||||
sharding_strategy=ShardingStrategy.FULL_SHARD,
|
sharding_strategy=ShardingStrategy.FULL_SHARD,
|
||||||
|
@ -219,11 +219,11 @@ def get_train_data_loader(
|
||||||
|
|
||||||
# Get the dataset types
|
# Get the dataset types
|
||||||
dataset_types = None
|
dataset_types = None
|
||||||
dataset_types = list(DATASET_TYPE_IDS_MAP.keys())
|
|
||||||
data_cfg = gpc.config.data
|
data_cfg = gpc.config.data
|
||||||
|
|
||||||
# Get the sample weight dictionary
|
# Get the sample weight dictionary
|
||||||
train_folder = data_cfg.train_folder
|
train_folder = data_cfg.train_folder
|
||||||
|
dataset_types = list(get_dataset_type_ids_map(train_folder).keys())
|
||||||
|
|
||||||
if not train_folder:
|
if not train_folder:
|
||||||
train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len)
|
train_ds = RandomDataset(num_samples=1000000, max_len=data_cfg.seq_len)
|
||||||
|
|
Loading…
Reference in New Issue