add IS_SEQUENCE_PARALLEL check for norm module

pull/528/head
yingtongxiong 2023-12-05 20:58:26 +08:00
parent 2dbbab7418
commit 62d193c763
2 changed files with 46 additions and 1 deletion


@@ -53,6 +53,7 @@ from internlm.utils.common import DummyProfile
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 from internlm.utils.parallel import (
+    check_sequence_parallel,
     set_model_params_layer_name,
     sync_model_param,
     sync_model_param_within_tp,
@@ -111,6 +112,9 @@ def initialize_model():
     # if fsdp enabled, wrap the model
     model = wrap_FSDP_model(model)
 
+    # check whether the norm module has IS_SEQUENCE_PARALLEL attribute
+    check_sequence_parallel(model)
+
     return model
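Note that the check wired into `initialize_model` above only verifies the flag; it assumes the model-building code has already stamped `IS_SEQUENCE_PARALLEL` onto every norm parameter. Below is a minimal sketch of that marking step in plain PyTorch; `mark_norm_params_as_sequence_parallel` and the literal flag value are illustrative placeholders, not InternLM's actual helper or constant.

from torch import nn

# placeholder for internlm.core.context.IS_SEQUENCE_PARALLEL
# (the real constant's string value may differ)
IS_SEQUENCE_PARALLEL = "is_sequence_parallel"


def mark_norm_params_as_sequence_parallel(model: nn.Module) -> None:
    # hypothetical helper: InternLM's modeling code is expected to set this flag
    # itself; check_sequence_parallel only verifies that it is present
    for module in model.modules():
        if isinstance(module, nn.LayerNorm):  # RMSNorm params would be marked the same way
            for param in module.parameters():
                setattr(param, IS_SEQUENCE_PARALLEL, True)


layer = nn.LayerNorm(8)
mark_norm_params_as_sequence_parallel(layer)
assert all(getattr(p, IS_SEQUENCE_PARALLEL, False) for p in layer.parameters())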


@@ -4,11 +4,14 @@
 import torch.distributed as dist
 from torch import nn
 
-from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
+from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
+from internlm.model.utils import try_import_RMSNorm
 from internlm.solver.pipeline_utils import partition_uniform
 
+RMSNorm = try_import_RMSNorm()
+
 
 def is_model_parallel_parameter(p):
     return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)
 
@@ -98,3 +101,41 @@ def set_model_params_layer_name(model):
             layer_param_name = f"{layer_name}-{param_name}"
             param.__setattr__("layer_name", layer_name)
             param.__setattr__("param_name", f"{layer_name}-{param_name}")
+
+
+def check_sequence_parallel(model):
+    """
+    Check whether the norm modules have the IS_SEQUENCE_PARALLEL attribute.
+    When sequence_parallel is True, the params of each norm module should carry the
+    IS_SEQUENCE_PARALLEL attribute to indicate that their grads need an all-reduce.
+    """
+    if gpc.config.parallel.sequence_parallel is False:
+        return
+
+    if not isinstance(model, nn.ModuleList):
+        model = [model]
+
+    for _chunk in model:
+        if isinstance(_chunk, NaiveAMPModel):
+            _chunk = _chunk.model
+
+        for _, children in _chunk.named_children():
+            if isinstance(children, (RMSNorm, nn.LayerNorm)):
+                # top-level norm module, e.g. the final norm
+                for param in children.parameters():
+                    assert hasattr(param, IS_SEQUENCE_PARALLEL), (
+                        "when sequence parallel is True, "
+                        "the params of the norm module should have the IS_SEQUENCE_PARALLEL attribute"
+                    )
+            elif isinstance(children, nn.ModuleList):
+                # transformer blocks: check the norm sub-modules inside each block
+                for block in children:
+                    for _, sub in block.named_children():
+                        if isinstance(sub, (RMSNorm, nn.LayerNorm)):
+                            for param in sub.parameters():
+                                assert hasattr(param, IS_SEQUENCE_PARALLEL), (
+                                    "when sequence parallel is True, "
+                                    "the params of the norm module should have the IS_SEQUENCE_PARALLEL attribute"
+                                )