replace the named_children by named_modules

pull/528/head
yingtongxiong 2023-12-06 11:01:07 +08:00
parent 2b28923949
commit a9d5ad1b5f
1 changed files with 4 additions and 15 deletions

View File

@ -119,22 +119,11 @@ def check_sequence_parallel(model):
for _chunk in model:
if isinstance(_chunk, NaiveAMPModel):
_chunk = _chunk.model
for _, children in _chunk.named_children():
if isinstance(children, (RMSNorm, nn.LayerNorm)):
for param in children.parameters():
for _, module in _chunk.named_modules():
if isinstance(module, (RMSNorm, nn.LayerNorm)):
for param in module.parameters():
assert hasattr(param, IS_SEQUENCE_PARALLEL), (
"when the sequence parallel is True,"
"the params of norm module should have IS_SEQUENCE_PARALLEL attribute"
)
continue
elif not isinstance(children, nn.ModuleList):
continue
# transformer block
for _, block in enumerate(children): # iterate transformer blocks
for _, sub in block.named_children():
if isinstance(sub, (RMSNorm, nn.LayerNorm)):
for param in sub.parameters():
assert hasattr(param, IS_SEQUENCE_PARALLEL), (
"when the sequence parallel is True,"
"the params of norm module should have IS_SEQUENCE_PARALLEL attribute"
)