fix set layer name

pull/412/head
JiaoPL 2023-10-14 22:45:35 +08:00
parent 7d68509c4f
commit 7920168179
1 changed files with 18 additions and 16 deletions

View File

@ -6,6 +6,7 @@ from torch import nn
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
from internlm.core.context import global_context as gpc from internlm.core.context import global_context as gpc
from internlm.core.naive_amp import NaiveAMPModel
def is_model_parallel_parameter(p): def is_model_parallel_parameter(p):
@ -70,11 +71,12 @@ def set_model_params_layer_name(model):
Args: Args:
model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency. model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency.
""" """
if isinstance(model, nn.ModuleList): if not isinstance(model, nn.ModuleList):
_chunk = model[0] model = [model]
else:
_chunk = model
for _chunk in model:
if isinstance(_chunk, NaiveAMPModel):
_chunk = _chunk.model
# Create a unique layer name based on the block's class name and index # Create a unique layer name based on the block's class name and index
for name, children in _chunk.named_children(): for name, children in _chunk.named_children():
if isinstance(children, nn.ModuleList): if isinstance(children, nn.ModuleList):