fix the spell bug and move the sequence judge to training_internlm

pull/528/head
yingtongxiong 2023-12-06 12:03:23 +08:00
parent bffb515d30
commit 16f8ec2354
2 changed files with 5 additions and 7 deletions

Changed file 1 of 2:

@@ -113,6 +113,7 @@ def initialize_model():
     model = wrap_FSDP_model(model)
 
     # check whether the norm module has IS_SEQUENCE_PARALLEL attribute
-    check_sequence_parallel(model)
+    if gpc.config.parallel.sequence_parallel is True:
+        check_sequence_parallel(model)
 
     return model
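Read together with the second file below, this hunk moves the decision of whether to run the check to the call site: initialize_model now only invokes check_sequence_parallel when sequence parallelism is enabled, instead of the checker returning early on its own. Pieced together solely from the lines visible in this hunk, the tail of initialize_model after the commit reads roughly as follows (the earlier body is not part of the diff and is elided):

    def initialize_model():
        ...  # earlier body not shown in the diff
        model = wrap_FSDP_model(model)

        # the guard that used to live inside check_sequence_parallel now sits at the call site
        if gpc.config.parallel.sequence_parallel is True:
            check_sequence_parallel(model)

        return model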

Changed file 2 of 2:

@@ -105,14 +105,11 @@ def set_model_params_layer_name(model):
 def check_sequence_parallel(model):
     """
-    check whether the norm module has IS_SEQUENC_PARALLEL attribute.
-    when the sequence_parallel is True, the norm module should have the IS_SEQUENC_PARALLEL attribute
+    check whether the norm module has IS_SEQUENCE_PARALLEL attribute.
+    when the sequence_parallel is True, the norm module should have the IS_SEQUENCE_PARALLEL attribute
     to illustrate the norm should conduct the all-reduce for its grad.
     """
-    if gpc.config.parallel.sequence_parallel is False:
-        return
-
     if not isinstance(model, nn.ModuleList):
         model = [model]
@@ -124,6 +121,6 @@ def check_sequence_parallel(model):
             if isinstance(module, (RMSNorm, nn.LayerNorm)):
                 for param in module.parameters():
                     assert hasattr(param, IS_SEQUENCE_PARALLEL), (
-                        "when the sequence parallel is True,"
+                        "when the gpc.config.parallel.sequence parallel is True,"
                         "the params of norm module should have IS_SEQUENCE_PARALLEL attribute"
                     )
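With both hunks of this file applied, the checker reads roughly as sketched below. Only the lines shown in the diff are certain: the loop between the two hunks that yields module is an assumption (a plain named_modules() walk over each chunk), nn is torch.nn, and RMSNorm / IS_SEQUENCE_PARALLEL are InternLM internals whose import paths the diff does not show.

    def check_sequence_parallel(model):
        """
        check whether the norm module has IS_SEQUENCE_PARALLEL attribute.
        when the sequence_parallel is True, the norm module should have the IS_SEQUENCE_PARALLEL attribute
        to illustrate the norm should conduct the all-reduce for its grad.
        """
        if not isinstance(model, nn.ModuleList):
            model = [model]

        for chunk in model:  # assumed iteration: the diff does not show how chunks are walked
            for _, module in chunk.named_modules():
                if isinstance(module, (RMSNorm, nn.LayerNorm)):
                    for param in module.parameters():
                        assert hasattr(param, IS_SEQUENCE_PARALLEL), (
                            "when the gpc.config.parallel.sequence parallel is True,"
                            "the params of norm module should have IS_SEQUENCE_PARALLEL attribute"
                        )

Note that the early return on sequence_parallel is False is gone: any caller that needs the old behaviour must now guard the call itself, exactly as initialize_model does in the first file.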