mirror of https://github.com/InternLM/InternLM
replace the named_children by named_modules
parent
2b28923949
commit
a9d5ad1b5f
|
@ -119,22 +119,11 @@ def check_sequence_parallel(model):
|
||||||
for _chunk in model:
|
for _chunk in model:
|
||||||
if isinstance(_chunk, NaiveAMPModel):
|
if isinstance(_chunk, NaiveAMPModel):
|
||||||
_chunk = _chunk.model
|
_chunk = _chunk.model
|
||||||
for _, children in _chunk.named_children():
|
|
||||||
if isinstance(children, (RMSNorm, nn.LayerNorm)):
|
for _, module in _chunk.named_modules():
|
||||||
for param in children.parameters():
|
if isinstance(module, (RMSNorm, nn.LayerNorm)):
|
||||||
|
for param in module.parameters():
|
||||||
assert hasattr(param, IS_SEQUENCE_PARALLEL), (
|
assert hasattr(param, IS_SEQUENCE_PARALLEL), (
|
||||||
"when the sequence parallel is True,"
|
"when the sequence parallel is True,"
|
||||||
"the params of norm module should have IS_SEQUENCE_PARALLEL attribute"
|
"the params of norm module should have IS_SEQUENCE_PARALLEL attribute"
|
||||||
)
|
)
|
||||||
continue
|
|
||||||
elif not isinstance(children, nn.ModuleList):
|
|
||||||
continue
|
|
||||||
# transformer block
|
|
||||||
for _, block in enumerate(children): # iterate transformer blocks
|
|
||||||
for _, sub in block.named_children():
|
|
||||||
if isinstance(sub, (RMSNorm, nn.LayerNorm)):
|
|
||||||
for param in sub.parameters():
|
|
||||||
assert hasattr(param, IS_SEQUENCE_PARALLEL), (
|
|
||||||
"when the sequence parallel is True,"
|
|
||||||
"the params of norm module should have IS_SEQUENCE_PARALLEL attribute"
|
|
||||||
)
|
|
||||||
|
|
Loading…
Reference in New Issue