mirror of https://github.com/InternLM/InternLM
fix layer norm with pp
parent
9ac5ab3101
commit
a6051335b7
|
@ -7,6 +7,7 @@ from torch import nn
|
|||
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.core.naive_amp import NaiveAMPModel
|
||||
from internlm.solver.pipeline_utils import partition_uniform
|
||||
|
||||
|
||||
def is_model_parallel_parameter(p):
|
||||
|
@ -70,18 +71,24 @@ def set_model_params_layer_name(model):
|
|||
Args:
|
||||
model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency.
|
||||
"""
|
||||
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
||||
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
all_parts = partition_uniform(gpc.config.model.num_layers, pipeline_size, gpc.config.model.num_chunks)
|
||||
parts = all_parts[pipeline_rank]
|
||||
|
||||
if not isinstance(model, nn.ModuleList):
|
||||
model = [model]
|
||||
|
||||
for _chunk in model:
|
||||
for chunk_idx, _chunk in enumerate(model):
|
||||
if isinstance(_chunk, NaiveAMPModel):
|
||||
_chunk = _chunk.model
|
||||
chunk_start = parts[chunk_idx][0]
|
||||
# Create a unique layer name based on the block's class name and index
|
||||
for _, children in _chunk.named_children():
|
||||
if isinstance(children, nn.ModuleList):
|
||||
for idx, block in enumerate(children):
|
||||
for param_name, param in block.named_parameters():
|
||||
layer_name = f"{block.__class__.__name__}Block{idx}"
|
||||
layer_name = f"{block.__class__.__name__}Block{idx + chunk_start}"
|
||||
layer_param_name = f"{layer_name}-{param_name}"
|
||||
param.__setattr__("layer_name", layer_name)
|
||||
param.__setattr__("param_name", layer_param_name)
|
||||
|
|
Loading…
Reference in New Issue