mirror of https://github.com/InternLM/InternLM
fix lint
parent
62d193c763
commit
e6c0d7bf62
|
@ -123,10 +123,10 @@ def check_sequence_parallel(model):
|
||||||
# import pdb; pdb.set_trace()
|
# import pdb; pdb.set_trace()
|
||||||
if isinstance(children, (RMSNorm, nn.LayerNorm)):
|
if isinstance(children, (RMSNorm, nn.LayerNorm)):
|
||||||
for param in children.parameters():
|
for param in children.parameters():
|
||||||
assert hasattr(
|
assert hasattr(param, IS_SEQUENCE_PARALLEL), (
|
||||||
param, IS_SEQUENCE_PARALLEL
|
"when the sequence parallel is True,"
|
||||||
), ("when the sequence parallel is True,"
|
"the params of norm module should have IS_SEQUENCE_PARALLEL attribute"
|
||||||
"the params of norm module should have IS_SEQUENCE_PARALLEL attribute")
|
)
|
||||||
continue
|
continue
|
||||||
elif not isinstance(children, nn.ModuleList):
|
elif not isinstance(children, nn.ModuleList):
|
||||||
continue
|
continue
|
||||||
|
@ -135,7 +135,7 @@ def check_sequence_parallel(model):
|
||||||
for _, sub in block.named_children():
|
for _, sub in block.named_children():
|
||||||
if isinstance(sub, (RMSNorm, nn.LayerNorm)):
|
if isinstance(sub, (RMSNorm, nn.LayerNorm)):
|
||||||
for param in sub.parameters():
|
for param in sub.parameters():
|
||||||
assert hasattr(
|
assert hasattr(param, IS_SEQUENCE_PARALLEL), (
|
||||||
param, IS_SEQUENCE_PARALLEL
|
"when the sequence parallel is True,"
|
||||||
), ("when the sequence parallel is True,"
|
"the params of norm module should have IS_SEQUENCE_PARALLEL attribute"
|
||||||
"the params of norm module should have IS_SEQUENCE_PARALLEL attribute")
|
)
|
||||||
|
|
Loading…
Reference in New Issue