diff --git a/internlm/utils/parallel.py b/internlm/utils/parallel.py index a2b150b..b90b9d6 100644 --- a/internlm/utils/parallel.py +++ b/internlm/utils/parallel.py @@ -119,22 +119,11 @@ def check_sequence_parallel(model): for _chunk in model: if isinstance(_chunk, NaiveAMPModel): _chunk = _chunk.model - for _, children in _chunk.named_children(): - if isinstance(children, (RMSNorm, nn.LayerNorm)): - for param in children.parameters(): + + for _, module in _chunk.named_modules(): + if isinstance(module, (RMSNorm, nn.LayerNorm)): + for param in module.parameters(): assert hasattr(param, IS_SEQUENCE_PARALLEL), ( "when the sequence parallel is True," "the params of norm module should have IS_SEQUENCE_PARALLEL attribute" ) - continue - elif not isinstance(children, nn.ModuleList): - continue - # transformer block - for _, block in enumerate(children): # iterate transformer blocks - for _, sub in block.named_children(): - if isinstance(sub, (RMSNorm, nn.LayerNorm)): - for param in sub.parameters(): - assert hasattr(param, IS_SEQUENCE_PARALLEL), ( - "when the sequence parallel is True," - "the params of norm module should have IS_SEQUENCE_PARALLEL attribute" - )