From 62d193c763360f27d031242a1da0f3696afcdfcc Mon Sep 17 00:00:00 2001
From: yingtongxiong <974106207@qq.com>
Date: Tue, 5 Dec 2023 20:58:26 +0800
Subject: [PATCH] add IS_SEQUENCE_PARALLEL check for norm module

---
 internlm/train/training_internlm.py |  4 ++++
 internlm/utils/parallel.py          | 42 +++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 45 insertions(+), 1 deletion(-)

diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index 1e36a21..91d3780 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -53,6 +53,7 @@ from internlm.utils.common import DummyProfile
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 from internlm.utils.parallel import (
+    check_sequence_parallel,
     set_model_params_layer_name,
     sync_model_param,
     sync_model_param_within_tp,
@@ -111,6 +112,9 @@ def initialize_model():
     # if fsdp enabled, wrap the model
     model = wrap_FSDP_model(model)
 
+    # check whether the norm module params have the IS_SEQUENCE_PARALLEL attribute
+    check_sequence_parallel(model)
+
     return model
 
 
diff --git a/internlm/utils/parallel.py b/internlm/utils/parallel.py
index 9b70fc8..7726c77 100644
--- a/internlm/utils/parallel.py
+++ b/internlm/utils/parallel.py
@@ -4,11 +4,14 @@
 import torch.distributed as dist
 from torch import nn
 
-from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
+from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
+from internlm.model.utils import try_import_RMSNorm
 from internlm.solver.pipeline_utils import partition_uniform
 
+RMSNorm = try_import_RMSNorm()
+
 
 def is_model_parallel_parameter(p):
     return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)
@@ -98,3 +101,40 @@ def set_model_params_layer_name(model):
                    layer_param_name = f"{layer_name}-{param_name}"
                    param.__setattr__("layer_name", layer_name)
                    param.__setattr__("param_name", f"{layer_name}-{param_name}")
+
+
+def check_sequence_parallel(model):
+    """
+    Check whether the parameters of the norm modules have the IS_SEQUENCE_PARALLEL attribute.
+    When sequence_parallel is True, every parameter of a norm module must carry this attribute
+    to indicate that the all-reduce should be conducted for its gradient.
+    """
+
+    if gpc.config.parallel.sequence_parallel is False:
+        return
+
+    if not isinstance(model, nn.ModuleList):
+        model = [model]
+
+    for _chunk in model:
+        if isinstance(_chunk, NaiveAMPModel):
+            _chunk = _chunk.model
+        for _, children in _chunk.named_children():
+            if isinstance(children, (RMSNorm, nn.LayerNorm)):
+                for param in children.parameters():
+                    assert hasattr(param, IS_SEQUENCE_PARALLEL), (
+                        "when sequence_parallel is True, the params of the norm "
+                        "module should have the IS_SEQUENCE_PARALLEL attribute"
+                    )
+                continue
+            elif not isinstance(children, nn.ModuleList):
+                continue
+            # iterate over the transformer blocks and check their norm sub-modules
+            for block in children:
+                for _, sub in block.named_children():
+                    if isinstance(sub, (RMSNorm, nn.LayerNorm)):
+                        for param in sub.parameters():
+                            assert hasattr(param, IS_SEQUENCE_PARALLEL), (
+                                "when sequence_parallel is True, the params of the norm "
+                                "module should have the IS_SEQUENCE_PARALLEL attribute"
+                            )
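
Reviewer note (not part of the applied patch): check_sequence_parallel() only verifies
that norm parameters already carry the IS_SEQUENCE_PARALLEL attribute; it does not set
it. Below is a minimal sketch of that contract, assuming IS_SEQUENCE_PARALLEL is a plain
string constant analogous to IS_TENSOR_PARALLEL. The helper mark_norm_params() and the
standalone LayerNorm are hypothetical and only illustrate the condition the new assert
enforces.

    from torch import nn

    # assumption: IS_SEQUENCE_PARALLEL is a plain string key, mirroring IS_TENSOR_PARALLEL
    IS_SEQUENCE_PARALLEL = "is_sequence_parallel"


    def mark_norm_params(module: nn.Module) -> None:
        # hypothetical helper: tag every parameter of a norm module so that the
        # gradient handling code knows to all-reduce its grad under sequence parallelism
        for param in module.parameters():
            setattr(param, IS_SEQUENCE_PARALLEL, True)


    norm = nn.LayerNorm(8)  # stand-in for the RMSNorm / nn.LayerNorm inside a block
    mark_norm_params(norm)

    # this is the per-parameter condition that check_sequence_parallel() asserts
    assert all(hasattr(p, IS_SEQUENCE_PARALLEL) for p in norm.parameters())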