mirror of https://github.com/InternLM/InternLM
fix the spell bug and move the sequence judge to training_internlm
parent
bffb515d30
commit
16f8ec2354
|
@ -113,6 +113,7 @@ def initialize_model():
|
||||||
model = wrap_FSDP_model(model)
|
model = wrap_FSDP_model(model)
|
||||||
|
|
||||||
# check whether the norm module has IS_SEQUENCE_PARALLEL attribute
|
# check whether the norm module has IS_SEQUENCE_PARALLEL attribute
|
||||||
|
if gpc.config.parallel.sequence_parallel is True:
|
||||||
check_sequence_parallel(model)
|
check_sequence_parallel(model)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
|
@ -105,14 +105,11 @@ def set_model_params_layer_name(model):
|
||||||
|
|
||||||
def check_sequence_parallel(model):
|
def check_sequence_parallel(model):
|
||||||
"""
|
"""
|
||||||
check whether the norm module has IS_SEQUENC_PARALLEL attribute.
|
check whether the norm module has IS_SEQUENCE_PARALLEL attribute.
|
||||||
when the sequence_parallel is True, the norm module should have the IS_SEQUENC_PARALLEL attribute
|
when the sequence_parallel is True, the norm module should have the IS_SEQUENCE_PARALLEL attribute
|
||||||
to illustrate the norm should conduct the all-reduce for its grad.
|
to illustrate the norm should conduct the all-reduce for its grad.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if gpc.config.parallel.sequence_parallel is False:
|
|
||||||
return
|
|
||||||
|
|
||||||
if not isinstance(model, nn.ModuleList):
|
if not isinstance(model, nn.ModuleList):
|
||||||
model = [model]
|
model = [model]
|
||||||
|
|
||||||
|
@ -124,6 +121,6 @@ def check_sequence_parallel(model):
|
||||||
if isinstance(module, (RMSNorm, nn.LayerNorm)):
|
if isinstance(module, (RMSNorm, nn.LayerNorm)):
|
||||||
for param in module.parameters():
|
for param in module.parameters():
|
||||||
assert hasattr(param, IS_SEQUENCE_PARALLEL), (
|
assert hasattr(param, IS_SEQUENCE_PARALLEL), (
|
||||||
"when the sequence parallel is True,"
|
"when the gpc.config.parallel.sequence parallel is True,"
|
||||||
"the params of norm module should have IS_SEQUENCE_PARALLEL attribute"
|
"the params of norm module should have IS_SEQUENCE_PARALLEL attribute"
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue