pull/528/head
yingtongxiong 2023-12-05 21:03:00 +08:00
parent 62d193c763
commit e6c0d7bf62
1 changed files with 8 additions and 8 deletions

View File

@ -123,10 +123,10 @@ def check_sequence_parallel(model):
# import pdb; pdb.set_trace()
if isinstance(children, (RMSNorm, nn.LayerNorm)):
for param in children.parameters():
assert hasattr(
param, IS_SEQUENCE_PARALLEL
), ("when the sequence parallel is True,"
"the params of norm module should have IS_SEQUENCE_PARALLEL attribute")
assert hasattr(param, IS_SEQUENCE_PARALLEL), (
"when the sequence parallel is True,"
"the params of norm module should have IS_SEQUENCE_PARALLEL attribute"
)
continue
elif not isinstance(children, nn.ModuleList):
continue
@ -135,7 +135,7 @@ def check_sequence_parallel(model):
for _, sub in block.named_children():
if isinstance(sub, (RMSNorm, nn.LayerNorm)):
for param in sub.parameters():
assert hasattr(
param, IS_SEQUENCE_PARALLEL
), ("when the sequence parallel is True,"
"the params of norm module should have IS_SEQUENCE_PARALLEL attribute")
assert hasattr(param, IS_SEQUENCE_PARALLEL), (
"when the sequence parallel is True,"
"the params of norm module should have IS_SEQUENCE_PARALLEL attribute"
)