diff --git a/internlm/utils/parallel.py b/internlm/utils/parallel.py index 7726c77..14fb2dc 100644 --- a/internlm/utils/parallel.py +++ b/internlm/utils/parallel.py @@ -123,10 +123,10 @@ def check_sequence_parallel(model): # import pdb; pdb.set_trace() if isinstance(children, (RMSNorm, nn.LayerNorm)): for param in children.parameters(): - assert hasattr( - param, IS_SEQUENCE_PARALLEL - ), ("when the sequence parallel is True," - "the params of norm module should have IS_SEQUENCE_PARALLEL attribute") + assert hasattr(param, IS_SEQUENCE_PARALLEL), ( + "when the sequence parallel is True," + "the params of norm module should have IS_SEQUENCE_PARALLEL attribute" + ) continue elif not isinstance(children, nn.ModuleList): continue @@ -135,7 +135,7 @@ def check_sequence_parallel(model): for _, sub in block.named_children(): if isinstance(sub, (RMSNorm, nn.LayerNorm)): for param in sub.parameters(): - assert hasattr( - param, IS_SEQUENCE_PARALLEL - ), ("when the sequence parallel is True," - "the params of norm module should have IS_SEQUENCE_PARALLEL attribute") + assert hasattr(param, IS_SEQUENCE_PARALLEL), ( + "when the sequence parallel is True," + "the params of norm module should have IS_SEQUENCE_PARALLEL attribute" + )