mirror of https://github.com/InternLM/InternLM
fix(model): add IS_SEQUENCE_PARALLEL check for norm module (#528)
* add IS_SEQUENCE_PARALLEL check for norm module
* fix lint
* remove comments
* replace named_children with named_modules
* fix lint
* fix the spelling bug and move the sequence-parallel check to training_internlm

pull/530/head
parent 2dbbab7418
commit c581cc4c02
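The check is gated by the sequence_parallel flag that initialize_model() reads from gpc.config.parallel. For reference, a sketch of that section of a training config is shown below; apart from sequence_parallel, the field names and values are illustrative placeholders rather than a copy of an actual InternLM config.

# parallel section of a training config (only sequence_parallel matters here;
# the other entries are illustrative placeholders)
parallel = dict(
    zero1=8,                 # ZeRO-1 partition size (placeholder)
    tensor=2,                # tensor-parallel world size (placeholder)
    pipeline=dict(size=1),   # pipeline-parallel size (placeholder)
    sequence_parallel=True,  # enables check_sequence_parallel() in initialize_model()
)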
training_internlm.py

@@ -53,6 +53,7 @@ from internlm.utils.common import DummyProfile
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 from internlm.utils.parallel import (
+    check_sequence_parallel,
     set_model_params_layer_name,
     sync_model_param,
     sync_model_param_within_tp,
@@ -111,6 +112,10 @@ def initialize_model():
     # if fsdp enabled, wrap the model
     model = wrap_FSDP_model(model)
 
+    # check whether the norm module has IS_SEQUENCE_PARALLEL attribute
+    if gpc.config.parallel.sequence_parallel is True:
+        check_sequence_parallel(model)
+
     return model
 
 
internlm/utils/parallel.py

@@ -4,11 +4,14 @@
 import torch.distributed as dist
 from torch import nn
 
-from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
+from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
+from internlm.model.utils import try_import_RMSNorm
 from internlm.solver.pipeline_utils import partition_uniform
 
+RMSNorm = try_import_RMSNorm()
+
 
 def is_model_parallel_parameter(p):
     return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)
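The module-level RMSNorm = try_import_RMSNorm() added above resolves which RMSNorm class the model actually uses, so the isinstance check in check_sequence_parallel below matches it. The helper presumably follows the usual optional-import pattern; the sketch below illustrates that pattern only and is not the actual internlm.model.utils implementation (the apex import path and the fallback class are assumptions).

import torch
from torch import nn


def try_import_RMSNorm():
    """Return a fused RMSNorm if an optimized kernel is installed, else a plain PyTorch fallback."""
    try:
        from apex.normalization import FusedRMSNorm as RMSNorm  # assumed optional dependency
        return RMSNorm
    except ImportError:
        class RMSNorm(nn.Module):
            """Plain RMSNorm: scale x by 1/sqrt(mean(x^2) + eps), then apply a learned weight."""

            def __init__(self, dim: int, eps: float = 1e-5):
                super().__init__()
                self.eps = eps
                self.weight = nn.Parameter(torch.ones(dim))

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

        return RMSNorm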
@@ -98,3 +101,26 @@ def set_model_params_layer_name(model):
                     layer_param_name = f"{layer_name}-{param_name}"
                     param.__setattr__("layer_name", layer_name)
                     param.__setattr__("param_name", f"{layer_name}-{param_name}")
+
+
+def check_sequence_parallel(model):
+    """
+    Check that the parameters of the norm modules carry the IS_SEQUENCE_PARALLEL attribute.
+    When sequence_parallel is True, every norm parameter must have the IS_SEQUENCE_PARALLEL
+    attribute to indicate that its gradient needs an extra all-reduce.
+    """
+
+    if not isinstance(model, nn.ModuleList):
+        model = [model]
+
+    for _chunk in model:
+        if isinstance(_chunk, NaiveAMPModel):
+            _chunk = _chunk.model
+
+        for _, module in _chunk.named_modules():
+            if isinstance(module, (RMSNorm, nn.LayerNorm)):
+                for param in module.parameters():
+                    assert hasattr(param, IS_SEQUENCE_PARALLEL), (
+                        "when gpc.config.parallel.sequence_parallel is True, "
+                        "the params of the norm module should have the IS_SEQUENCE_PARALLEL attribute"
+                    )
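check_sequence_parallel only asserts that the attribute is already present; the attribute itself is presumably stamped onto the norm parameters where the model is built. The snippet below is a self-contained sketch of that convention, assuming IS_SEQUENCE_PARALLEL is simply the name of a parameter attribute (as its use with hasattr implies); the value "is_sequence_parallel" and the tag_norm_params helper are illustrative, not InternLM code.

import torch
from torch import nn

# Assumption: the constant is just the string name of a flag stamped onto parameters.
IS_SEQUENCE_PARALLEL = "is_sequence_parallel"


def tag_norm_params(model: nn.Module) -> None:
    """Illustrative helper: mark every norm parameter so the committed check passes."""
    for module in model.modules():
        if isinstance(module, nn.LayerNorm):
            for param in module.parameters():
                setattr(param, IS_SEQUENCE_PARALLEL, True)


def check_norm_params(model: nn.Module) -> None:
    """Simplified version of check_sequence_parallel: every norm parameter must carry the flag."""
    for module in model.modules():
        if isinstance(module, nn.LayerNorm):
            for param in module.parameters():
                assert hasattr(param, IS_SEQUENCE_PARALLEL), "norm parameter is missing the sequence-parallel flag"


if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
    tag_norm_params(model)
    check_norm_params(model)  # passes; skip tag_norm_params() and the assert fires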