fix layer norm with pp

pull/449/head
JiaoPL 2023-10-26 14:50:54 +08:00
parent 9ac5ab3101
commit a6051335b7
1 changed file with 9 additions and 2 deletions


@@ -7,6 +7,7 @@ from torch import nn
 from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import NaiveAMPModel
+from internlm.solver.pipeline_utils import partition_uniform
 
 
 def is_model_parallel_parameter(p):
@@ -70,18 +71,24 @@ def set_model_params_layer_name(model):
     Args:
         model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency.
     """
+    pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
+    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+    all_parts = partition_uniform(gpc.config.model.num_layers, pipeline_size, gpc.config.model.num_chunks)
+    parts = all_parts[pipeline_rank]
+
     if not isinstance(model, nn.ModuleList):
         model = [model]
 
-    for _chunk in model:
+    for chunk_idx, _chunk in enumerate(model):
         if isinstance(_chunk, NaiveAMPModel):
             _chunk = _chunk.model
+        chunk_start = parts[chunk_idx][0]
         # Create a unique layer name based on the block's class name and index
         for _, children in _chunk.named_children():
             if isinstance(children, nn.ModuleList):
                 for idx, block in enumerate(children):
                     for param_name, param in block.named_parameters():
-                        layer_name = f"{block.__class__.__name__}Block{idx}"
+                        layer_name = f"{block.__class__.__name__}Block{idx + chunk_start}"
                         layer_param_name = f"{layer_name}-{param_name}"
                         param.__setattr__("layer_name", layer_name)
                         param.__setattr__("param_name", layer_param_name)
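
The offset matters because, under pipeline parallelism, each stage's nn.ModuleList restarts its local block indices at 0, so parameters on different stages could receive the same layer_name; presumably the per-layer norm statistics this commit fixes are keyed by that name, so duplicates across stages would be merged. Below is a minimal sketch of the idea. partition_uniform_sketch is a hypothetical stand-in for internlm.solver.pipeline_utils.partition_uniform, and its chunk interleaving is an assumption; the real function may group spans differently.

# Hypothetical stand-in: split num_layers into pipeline_size * num_chunks
# contiguous spans and return, per pipeline rank, a list of (start, end)
# tuples, one per chunk.
def partition_uniform_sketch(num_layers, pipeline_size, num_chunks):
    num_parts = pipeline_size * num_chunks
    base, extra = divmod(num_layers, num_parts)
    spans, start = [], 0
    for i in range(num_parts):
        end = start + base + (1 if i < extra else 0)
        spans.append((start, end))
        start = end
    # Assumed grouping: rank r owns chunks r, r + pipeline_size, ...
    return [[spans[c * pipeline_size + r] for c in range(num_chunks)] for r in range(pipeline_size)]

all_parts = partition_uniform_sketch(num_layers=8, pipeline_size=2, num_chunks=1)
for rank, parts in enumerate(all_parts):
    chunk_start = parts[0][0]  # first global layer owned by this rank's chunk
    for idx in range(parts[0][1] - chunk_start):
        # Before the fix both ranks would name their blocks Block0..Block3,
        # colliding across stages; the offset makes the names globally
        # unique (rank 0 -> Block0..Block3, rank 1 -> Block4..Block7).
        print(f"rank {rank}: Block{idx + chunk_start}")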