From a9d5ad1b5f66f84d36647ce79d89d3d9aaa0d16d Mon Sep 17 00:00:00 2001 From: yingtongxiong <974106207@qq.com> Date: Wed, 6 Dec 2023 11:01:07 +0800 Subject: [PATCH] replace the named_children by named_modules --- internlm/utils/parallel.py | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/internlm/utils/parallel.py b/internlm/utils/parallel.py index a2b150b..b90b9d6 100644 --- a/internlm/utils/parallel.py +++ b/internlm/utils/parallel.py @@ -119,22 +119,11 @@ def check_sequence_parallel(model): for _chunk in model: if isinstance(_chunk, NaiveAMPModel): _chunk = _chunk.model - for _, children in _chunk.named_children(): - if isinstance(children, (RMSNorm, nn.LayerNorm)): - for param in children.parameters(): + + for _, module in _chunk.named_modules(): + if isinstance(module, (RMSNorm, nn.LayerNorm)): + for param in module.parameters(): assert hasattr(param, IS_SEQUENCE_PARALLEL), ( "when the sequence parallel is True," "the params of norm module should have IS_SEQUENCE_PARALLEL attribute" ) - continue - elif not isinstance(children, nn.ModuleList): - continue - # transformer block - for _, block in enumerate(children): # iterate transformer blocks - for _, sub in block.named_children(): - if isinstance(sub, (RMSNorm, nn.LayerNorm)): - for param in sub.parameters(): - assert hasattr(param, IS_SEQUENCE_PARALLEL), ( - "when the sequence parallel is True," - "the params of norm module should have IS_SEQUENCE_PARALLEL attribute" - )