From c581cc4c02dca892bb73d6b8e013c4d147cc9407 Mon Sep 17 00:00:00 2001
From: ytxiong <45058324+yingtongxiong@users.noreply.github.com>
Date: Wed, 6 Dec 2023 12:06:22 +0800
Subject: [PATCH] fix(model): add IS_SEQUENCE_PARALLEL check for norm module (#528)

* add IS_SEQUENCE_PARALLEL check for norm module

* fix lint

* remove comments

* replace named_children with named_modules

* fix lint

* fix the spelling bug and move the sequence-parallel check to training_internlm
---
 internlm/train/training_internlm.py |  5 +++++
 internlm/utils/parallel.py          | 28 +++++++++++++++++++++++++++-
 2 files changed, 32 insertions(+), 1 deletion(-)

diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index 1e36a21..1f68af9 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -53,6 +53,7 @@ from internlm.utils.common import DummyProfile
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 from internlm.utils.parallel import (
+    check_sequence_parallel,
     set_model_params_layer_name,
     sync_model_param,
     sync_model_param_within_tp,
@@ -111,6 +112,10 @@ def initialize_model():
     # if fsdp enabled, wrap the model
     model = wrap_FSDP_model(model)
 
+    # check that the params of norm modules have the IS_SEQUENCE_PARALLEL attribute
+    if gpc.config.parallel.sequence_parallel is True:
+        check_sequence_parallel(model)
+
     return model
 
 
diff --git a/internlm/utils/parallel.py b/internlm/utils/parallel.py
index 9b70fc8..e6bb18f 100644
--- a/internlm/utils/parallel.py
+++ b/internlm/utils/parallel.py
@@ -4,11 +4,14 @@
 import torch.distributed as dist
 from torch import nn
 
-from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
+from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
+from internlm.model.utils import try_import_RMSNorm
 from internlm.solver.pipeline_utils import partition_uniform
 
+RMSNorm = try_import_RMSNorm()
+
 
 def is_model_parallel_parameter(p):
     return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)
@@ -98,3 +101,26 @@ def set_model_params_layer_name(model):
                     layer_param_name = f"{layer_name}-{param_name}"
                     param.__setattr__("layer_name", layer_name)
                     param.__setattr__("param_name", f"{layer_name}-{param_name}")
+
+
+def check_sequence_parallel(model):
+    """
+    Check whether each norm module has the IS_SEQUENCE_PARALLEL attribute.
+    When sequence_parallel is True, the parameters of norm modules should carry the
+    IS_SEQUENCE_PARALLEL attribute to indicate that their gradients need to be all-reduced.
+    """
+
+    if not isinstance(model, nn.ModuleList):
+        model = [model]
+
+    for _chunk in model:
+        if isinstance(_chunk, NaiveAMPModel):
+            _chunk = _chunk.model
+
+        for _, module in _chunk.named_modules():
+            if isinstance(module, (RMSNorm, nn.LayerNorm)):
+                for param in module.parameters():
+                    assert hasattr(param, IS_SEQUENCE_PARALLEL), (
+                        "when gpc.config.parallel.sequence_parallel is True, "
+                        "the parameters of norm modules should have the IS_SEQUENCE_PARALLEL attribute"
+                    )
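
Note on the enforced contract: the standalone sketch below illustrates what check_sequence_parallel expects when gpc.config.parallel.sequence_parallel is enabled, namely that every parameter of every norm module already carries the IS_SEQUENCE_PARALLEL attribute before the check runs. It is a minimal approximation, not part of the patch: the attribute string value and the tag_norm_params / check_norm_params helpers are hypothetical names used only for illustration; in InternLM the real constant is defined in internlm.core.context and the parameter tagging happens elsewhere in the codebase.

# Minimal sketch of the contract checked by this patch.
# Assumptions: the string value "is_sequence_parallel" and the helper names
# below are hypothetical; InternLM defines the real constant in
# internlm.core.context and tags norm parameters elsewhere in the codebase.
import torch.nn as nn

IS_SEQUENCE_PARALLEL = "is_sequence_parallel"  # assumed string value of the constant


def tag_norm_params(model: nn.Module) -> None:
    # Mark every parameter of every LayerNorm so a check like
    # check_sequence_parallel() would pass.
    for module in model.modules():
        if isinstance(module, nn.LayerNorm):
            for param in module.parameters():
                setattr(param, IS_SEQUENCE_PARALLEL, True)


def check_norm_params(model: nn.Module) -> None:
    # Simplified stand-in for the check added in this patch: assert that every
    # norm parameter carries the attribute.
    for module in model.modules():
        if isinstance(module, nn.LayerNorm):
            for param in module.parameters():
                assert hasattr(param, IS_SEQUENCE_PARALLEL), (
                    "norm parameters must be tagged with IS_SEQUENCE_PARALLEL "
                    "when sequence parallelism is enabled"
                )


if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
    tag_norm_params(model)    # without this call, check_norm_params would raise
    check_norm_params(model)  # passes: every LayerNorm parameter is tagged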