diff --git a/internlm/utils/parallel.py b/internlm/utils/parallel.py
index ba4a713..001af22 100644
--- a/internlm/utils/parallel.py
+++ b/internlm/utils/parallel.py
@@ -6,6 +6,7 @@ from torch import nn
 
 from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc
+from internlm.core.naive_amp import NaiveAMPModel
 
 
 def is_model_parallel_parameter(p):
@@ -70,20 +71,21 @@ def set_model_params_layer_name(model):
     Args:
         model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency.
     """
-    if isinstance(model, nn.ModuleList):
-        _chunk = model[0]
-    else:
-        _chunk = model
+    if not isinstance(model, nn.ModuleList):
+        model = [model]
 
-    # Create a unique layer name based on the block's class name and index
-    for name, children in _chunk.named_children():
-        if isinstance(children, nn.ModuleList):
-            for idx, block in enumerate(children):
-                for param in block.parameters():
-                    layer_name = f"{block.__class__.__name__}.{idx}"
-                    gpc.layer_names.add(layer_name)
-                    param.__setattr__("layer_name", layer_name)
-        else:
-            for param in children.parameters():
-                gpc.layer_names.add(name)
-                param.__setattr__("layer_name", name)
+    for _chunk in model:
+        if isinstance(_chunk, NaiveAMPModel):
+            _chunk = _chunk.model
+        # Create a unique layer name based on the block's class name and index
+        for name, children in _chunk.named_children():
+            if isinstance(children, nn.ModuleList):
+                for idx, block in enumerate(children):
+                    for param in block.parameters():
+                        layer_name = f"{block.__class__.__name__}.{idx}"
+                        gpc.layer_names.add(layer_name)
+                        param.__setattr__("layer_name", layer_name)
+            else:
+                for param in children.parameters():
+                    gpc.layer_names.add(name)
+                    param.__setattr__("layer_name", name)
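
For context, here is a minimal standalone sketch of what the patched `set_model_params_layer_name` does: instead of tagging only the first chunk, it now normalizes a single model into a one-element list, walks every chunk, and unwraps the AMP wrapper before naming parameters. `gpc.layer_names` and `NaiveAMPModel` are InternLM internals, so the sketch replaces them with a plain `set` and a hypothetical `FakeAMPWrapper` so it runs on its own; `ToyModel` is likewise made up for illustration.

```python
import torch
from torch import nn

# Stand-in for gpc.layer_names (an InternLM global) -- just a plain set here.
layer_names = set()


class FakeAMPWrapper(nn.Module):
    """Hypothetical stand-in for NaiveAMPModel: holds the real module as `.model`."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model


def set_model_params_layer_name(model):
    """Tag every parameter with a `layer_name` attribute, mirroring the patch."""
    # After the patch, a single module is normalized into a one-element list,
    # so pipeline-parallel chunk lists and plain models share one code path.
    if not isinstance(model, nn.ModuleList):
        model = [model]

    for _chunk in model:
        # Unwrap the AMP wrapper to reach the underlying module.
        if isinstance(_chunk, FakeAMPWrapper):
            _chunk = _chunk.model
        for name, children in _chunk.named_children():
            if isinstance(children, nn.ModuleList):
                # Blocks in a ModuleList get "<ClassName>.<index>" names.
                for idx, block in enumerate(children):
                    for param in block.parameters():
                        layer_name = f"{block.__class__.__name__}.{idx}"
                        layer_names.add(layer_name)
                        setattr(param, "layer_name", layer_name)
            else:
                # Other children are named after their attribute name.
                for param in children.parameters():
                    layer_names.add(name)
                    setattr(param, "layer_name", name)


if __name__ == "__main__":
    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.embedding = nn.Embedding(10, 4)
            self.blocks = nn.ModuleList(nn.Linear(4, 4) for _ in range(2))

    # The case the patch newly handles: a ModuleList of wrapped chunks,
    # as produced by pipeline parallelism with interleaved chunks.
    chunks = nn.ModuleList([FakeAMPWrapper(ToyModel()), FakeAMPWrapper(ToyModel())])
    set_model_params_layer_name(chunks)
    print(sorted(layer_names))  # ['Linear.0', 'Linear.1', 'embedding']
```

The key design change is the normalization step: by turning a bare module into `[model]`, the old single-chunk special case (`_chunk = model[0]`, which silently skipped every chunk after the first) disappears, and the `NaiveAMPModel` check ensures the walk reaches the real submodules rather than the mixed-precision wrapper.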