From e6c0d7bf623c844ff9ee1563e08ecf386f8ce8bf Mon Sep 17 00:00:00 2001 From: yingtongxiong <974106207@qq.com> Date: Tue, 5 Dec 2023 21:03:00 +0800 Subject: [PATCH] fix lint --- internlm/utils/parallel.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/internlm/utils/parallel.py b/internlm/utils/parallel.py index 7726c77..14fb2dc 100644 --- a/internlm/utils/parallel.py +++ b/internlm/utils/parallel.py @@ -123,10 +123,10 @@ def check_sequence_parallel(model): # import pdb; pdb.set_trace() if isinstance(children, (RMSNorm, nn.LayerNorm)): for param in children.parameters(): - assert hasattr( - param, IS_SEQUENCE_PARALLEL - ), ("when the sequence parallel is True," - "the params of norm module should have IS_SEQUENCE_PARALLEL attribute") + assert hasattr(param, IS_SEQUENCE_PARALLEL), ( + "when the sequence parallel is True," + "the params of norm module should have IS_SEQUENCE_PARALLEL attribute" + ) continue elif not isinstance(children, nn.ModuleList): continue @@ -135,7 +135,7 @@ def check_sequence_parallel(model): for _, sub in block.named_children(): if isinstance(sub, (RMSNorm, nn.LayerNorm)): for param in sub.parameters(): - assert hasattr( - param, IS_SEQUENCE_PARALLEL - ), ("when the sequence parallel is True," - "the params of norm module should have IS_SEQUENCE_PARALLEL attribute") + assert hasattr(param, IS_SEQUENCE_PARALLEL), ( + "when the sequence parallel is True," + "the params of norm module should have IS_SEQUENCE_PARALLEL attribute" + )