pull/528/head
yingtongxiong 2023-12-05 21:03:00 +08:00
parent 62d193c763
commit e6c0d7bf62
1 changed files with 8 additions and 8 deletions

View File

@ -123,10 +123,10 @@ def check_sequence_parallel(model):
# import pdb; pdb.set_trace() # import pdb; pdb.set_trace()
if isinstance(children, (RMSNorm, nn.LayerNorm)): if isinstance(children, (RMSNorm, nn.LayerNorm)):
for param in children.parameters(): for param in children.parameters():
assert hasattr( assert hasattr(param, IS_SEQUENCE_PARALLEL), (
param, IS_SEQUENCE_PARALLEL "when the sequence parallel is True,"
), ("when the sequence parallel is True," "the params of norm module should have IS_SEQUENCE_PARALLEL attribute"
"the params of norm module should have IS_SEQUENCE_PARALLEL attribute") )
continue continue
elif not isinstance(children, nn.ModuleList): elif not isinstance(children, nn.ModuleList):
continue continue
@ -135,7 +135,7 @@ def check_sequence_parallel(model):
for _, sub in block.named_children(): for _, sub in block.named_children():
if isinstance(sub, (RMSNorm, nn.LayerNorm)): if isinstance(sub, (RMSNorm, nn.LayerNorm)):
for param in sub.parameters(): for param in sub.parameters():
assert hasattr( assert hasattr(param, IS_SEQUENCE_PARALLEL), (
param, IS_SEQUENCE_PARALLEL "when the sequence parallel is True,"
), ("when the sequence parallel is True," "the params of norm module should have IS_SEQUENCE_PARALLEL attribute"
"the params of norm module should have IS_SEQUENCE_PARALLEL attribute") )