diff --git a/internlm/utils/parallel.py b/internlm/utils/parallel.py index b90b9d6..908aa80 100644 --- a/internlm/utils/parallel.py +++ b/internlm/utils/parallel.py @@ -119,7 +119,7 @@ def check_sequence_parallel(model): for _chunk in model: if isinstance(_chunk, NaiveAMPModel): _chunk = _chunk.model - + for _, module in _chunk.named_modules(): if isinstance(module, (RMSNorm, nn.LayerNorm)): for param in module.parameters():