From 16f8ec2354b57585992fe4e84aae52d727e8f8af Mon Sep 17 00:00:00 2001
From: yingtongxiong <974106207@qq.com>
Date: Wed, 6 Dec 2023 12:03:23 +0800
Subject: [PATCH] fix the spelling error and move the sequence-parallel check
 to training_internlm

---
 internlm/train/training_internlm.py | 3 ++-
 internlm/utils/parallel.py          | 9 +++------
 2 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index 91d3780..1f68af9 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -113,7 +113,8 @@ def initialize_model():
         model = wrap_FSDP_model(model)
 
     # check whether the norm module has the IS_SEQUENCE_PARALLEL attribute
-    check_sequence_parallel(model)
+    if gpc.config.parallel.sequence_parallel is True:
+        check_sequence_parallel(model)
 
     return model
 
diff --git a/internlm/utils/parallel.py b/internlm/utils/parallel.py
index 908aa80..e6bb18f 100644
--- a/internlm/utils/parallel.py
+++ b/internlm/utils/parallel.py
@@ -105,14 +105,11 @@ def set_model_params_layer_name(model):
 
 def check_sequence_parallel(model):
     """
-    check whether the norm module has IS_SEQUENC_PARALLEL attribute.
-    when the sequence_parallel is True, the norm module should have the IS_SEQUENC_PARALLEL attribute
+    check whether the norm module has the IS_SEQUENCE_PARALLEL attribute.
+    when sequence_parallel is True, the norm module should have the IS_SEQUENCE_PARALLEL attribute
     to illustrate the norm should conduct the all-reduce for its grad.
     """
-    if gpc.config.parallel.sequence_parallel is False:
-        return
-
     if not isinstance(model, nn.ModuleList):
         model = [model]
 
@@ -124,6 +121,6 @@
             if isinstance(module, (RMSNorm, nn.LayerNorm)):
                 for param in module.parameters():
                     assert hasattr(param, IS_SEQUENCE_PARALLEL), (
-                        "when the sequence parallel is True,"
+                        "when gpc.config.parallel.sequence_parallel is True, "
                         "the params of norm module should have IS_SEQUENCE_PARALLEL attribute"
                     )
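
For review context: after this patch, check_sequence_parallel() is a pure assertion and the
gpc.config.parallel.sequence_parallel guard lives at the call site in initialize_model().
Below is a minimal runnable sketch of that post-patch shape. The IS_SEQUENCE_PARALLEL value,
the toy model, and the bare sequence_parallel flag are illustrative assumptions, not InternLM's
actual definitions; the real check also matches InternLM's RMSNorm and reads the flag from gpc.

    from torch import nn

    # Assumed attribute name for this sketch; the real constant is defined
    # in internlm/utils/parallel.py.
    IS_SEQUENCE_PARALLEL = "is_sequence_parallel"

    def check_sequence_parallel(model):
        """Assert every norm parameter is tagged for sequence-parallel grad all-reduce."""
        if not isinstance(model, nn.ModuleList):
            model = [model]
        for chunk in model:
            for module in chunk.modules():
                # The real check also matches InternLM's RMSNorm; plain
                # nn.LayerNorm keeps this sketch dependency-free.
                if isinstance(module, nn.LayerNorm):
                    for param in module.parameters():
                        assert hasattr(param, IS_SEQUENCE_PARALLEL), (
                            "when gpc.config.parallel.sequence_parallel is True, "
                            "the params of norm module should have IS_SEQUENCE_PARALLEL attribute"
                        )

    # Hypothetical usage mirroring the patched call site in initialize_model().
    model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
    for param in model[1].parameters():
        setattr(param, IS_SEQUENCE_PARALLEL, True)  # tag norm params for grad all-reduce

    sequence_parallel = True  # stands in for gpc.config.parallel.sequence_parallel
    if sequence_parallel is True:
        check_sequence_parallel(model)  # raises AssertionError if any tag is missing

Keeping the guard at the call site means the helper no longer depends on gpc, which is what
the subject line means by moving the sequence-parallel check to training_internlm.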