diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index 1e36a21..1f68af9 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -53,6 +53,7 @@ from internlm.utils.common import DummyProfile
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 from internlm.utils.parallel import (
+    check_sequence_parallel,
     set_model_params_layer_name,
     sync_model_param,
     sync_model_param_within_tp,
@@ -111,6 +112,10 @@ def initialize_model():
     # if fsdp enabled, wrap the model
     model = wrap_FSDP_model(model)
 
+    # check whether the params of norm modules have the IS_SEQUENCE_PARALLEL attribute
+    if gpc.config.parallel.sequence_parallel is True:
+        check_sequence_parallel(model)
+
     return model
 
 
diff --git a/internlm/utils/parallel.py b/internlm/utils/parallel.py
index 9b70fc8..e6bb18f 100644
--- a/internlm/utils/parallel.py
+++ b/internlm/utils/parallel.py
@@ -4,11 +4,14 @@ import torch.distributed as dist
 from torch import nn
 
-from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
+from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
+from internlm.model.utils import try_import_RMSNorm
 from internlm.solver.pipeline_utils import partition_uniform
 
+RMSNorm = try_import_RMSNorm()
+
 
 def is_model_parallel_parameter(p):
     return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)
 
@@ -98,3 +101,26 @@ def set_model_params_layer_name(model):
             layer_param_name = f"{layer_name}-{param_name}"
             param.__setattr__("layer_name", layer_name)
             param.__setattr__("param_name", f"{layer_name}-{param_name}")
+
+
+def check_sequence_parallel(model):
+    """
+    Check whether the params of norm modules have the IS_SEQUENCE_PARALLEL attribute.
+    When sequence_parallel is True, the params of norm modules must carry IS_SEQUENCE_PARALLEL
+    to indicate that the grads of these params should be all-reduced.
+    """
+
+    if not isinstance(model, nn.ModuleList):
+        model = [model]
+
+    for _chunk in model:
+        if isinstance(_chunk, NaiveAMPModel):
+            _chunk = _chunk.model
+
+        for _, module in _chunk.named_modules():
+            if isinstance(module, (RMSNorm, nn.LayerNorm)):
+                for param in module.parameters():
+                    assert hasattr(param, IS_SEQUENCE_PARALLEL), (
+                        "when gpc.config.parallel.sequence_parallel is True, "
+                        "the params of norm modules should have the IS_SEQUENCE_PARALLEL attribute"
+                    )
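
For context, below is a minimal, self-contained sketch of the tagging convention this check relies on. It is not part of the patch: it assumes IS_SEQUENCE_PARALLEL is a string constant used as a parameter attribute name (mirroring internlm.core.context.IS_SEQUENCE_PARALLEL), uses a toy nn.Sequential model with plain nn.LayerNorm in place of the real InternLM model and RMSNorm, and the helper names are hypothetical.

# Minimal sketch of the convention that check_sequence_parallel enforces.
# IS_SEQUENCE_PARALLEL is assumed to be a string constant used as a parameter
# attribute name (mirroring internlm.core.context.IS_SEQUENCE_PARALLEL); the
# toy model and the helper names here are hypothetical, not part of this patch.
from torch import nn

IS_SEQUENCE_PARALLEL = "is_sequence_parallel"  # assumed attribute name


def tag_norm_params_for_sequence_parallel(model: nn.Module) -> None:
    """Mark every norm parameter so a check like check_sequence_parallel passes."""
    for module in model.modules():
        if isinstance(module, nn.LayerNorm):  # the real check also covers RMSNorm
            for param in module.parameters():
                setattr(param, IS_SEQUENCE_PARALLEL, True)


def check_sequence_parallel_toy(model: nn.Module) -> None:
    """Simplified version of the check in this diff (no NaiveAMPModel/chunk handling)."""
    for module in model.modules():
        if isinstance(module, nn.LayerNorm):
            for param in module.parameters():
                assert hasattr(param, IS_SEQUENCE_PARALLEL), (
                    "norm params must carry the IS_SEQUENCE_PARALLEL attribute "
                    "when sequence parallelism is enabled"
                )


if __name__ == "__main__":
    toy = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
    tag_norm_params_for_sequence_parallel(toy)
    check_sequence_parallel_toy(toy)  # passes once the norm params are tagged

In the real code path, the attribute would be set where norm parameters are created or partitioned, and initialize_model then calls check_sequence_parallel to fail fast if any norm parameter was missed.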