set layer name to parameters after init_model

pull/412/head
JiaoPL 2023-10-14 22:32:10 +08:00
parent 646f1b45fa
commit 7d68509c4f
4 changed files with 39 additions and 14 deletions

View File

@@ -150,6 +150,7 @@ class ParallelContext(metaclass=SingletonMeta):
         self.virtual_pipeline_parallel_size = None
         self.virtual_pipeline_parallel_rank = None
         self._expert_parallel_group_names = []
+        self.layer_names = {"unknown", "embedding", "norm", "head"}

     @property
     def config(self):
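The four seeded names cover parameters that sit outside the transformer block list (embedding, final norm, output head, plus an "unknown" fallback); per-block names are added to this set at startup by set_model_params_layer_name in the last file of this diff. As a rough illustration of how the set is consumed (the block class name "MyTransformerBlock" and the sample values are made up), the compute_norm change below turns it into a per-layer accumulator:

    # Illustrative sketch only: "MyTransformerBlock" and the observed values are made up.
    layer_names = {"unknown", "embedding", "norm", "head", "MyTransformerBlock.0"}

    # Seed one accumulator per known layer, as compute_norm does with gpc.layer_names,
    # then fold per-layer values in (max is the inf-norm case).
    total_layer_norms = {layer_name: 0.0 for layer_name in layer_names}
    observed = {"MyTransformerBlock.0": 0.7, "embedding": 1.2}  # dummy per-layer grad norms
    for key, value in observed.items():
        total_layer_norms[key] = max(value, total_layer_norms[key])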

View File

@@ -1,7 +1,6 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

-import copy
 import math
 from abc import ABC, abstractmethod
 from collections import OrderedDict
@@ -32,7 +31,6 @@ except (ModuleNotFoundError, ImportError):
     APEX_AVAILABLE = False

 inf = math.inf
-global_layer_norms = {"unknown": 0.0, "embedding": 0.0, "norm": 0.0, "head": 0.0}


 def flatten(input_):
@@ -228,7 +226,7 @@ def compute_norm(
     enable_cuda_kernels = gradients[0].device.type == "cuda"
     # Norm parameters.
     norm_type = float(norm_type)
-    total_layer_norms = copy.deepcopy(global_layer_norms)
+    total_layer_norms = {layer_name: 0.0 for layer_name in gpc.layer_names}
     layer_grads = {}
     # Calculate norm.
     if norm_type == inf:
@@ -249,7 +247,7 @@ def compute_norm(
             total_layer_norms[key] = max(value, total_layer_norms[key])

         total_layer_norms_values = move_norm_to_cuda(torch.Tensor(list(total_layer_norms.values())))
-        total_layer_norms_keys = list(global_layer_norms.keys())
+        total_layer_norms_keys = list(total_layer_norms.keys())

         # Take max across all model-parallel GPUs.
         if gpc.is_initialized(ParallelMode.MODEL):
@@ -523,24 +521,17 @@ class ParamBcastSyncHandler:
         for _chunk in model:
             if isinstance(_chunk, NaiveAMPModel):
                 _chunk = _chunk.model

-            for name, children in _chunk.named_children():
+            for _, children in _chunk.named_children():
                 # should be the transformer block definaton in modeling_xxx.py
                 if isinstance(children, nn.ModuleList):
                     # record the block that a parameter belongs to
-                    for idx, block in enumerate(children):
+                    for _, block in enumerate(children):
                         # self._block_to_param[f"{name}.{idx}"] = list(block.parameters())
                         self._block_to_param[block] = list(block.parameters())
-                        for parameter in self._block_to_param[block]:
-                            layer_name = f"{block.__class__.__name__}.{idx}"
-                            global_layer_norms[layer_name] = 0.0
-                            parameter.__setattr__("layer_name", layer_name)
                 else:
                     # record the block that a parameter belongs to
                     # self._block_to_param[name] = list(children.parameters())
                     self._block_to_param[children] = list(children.parameters())
-                    for parameter in self._block_to_param[children]:
-                        layer_name = f"{children.__class__.__name__}"
-                        parameter.__setattr__("layer_name", name)

         alloc_num = 0
         rank_to_go = 0
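With the module-level global_layer_norms dict removed, compute_norm now derives its buckets from gpc.layer_names, and each gradient is attributed to a bucket via the layer_name attribute carried by its parameter. The hunks above do not show that attribution step, so the helpers below are only a sketch of the idea, not InternLM's actual code; group_grads_by_layer, per_layer_inf_norm, and the fallback to the "unknown" bucket are assumptions.

    import torch
    from torch import nn


    def group_grads_by_layer(model: nn.Module) -> dict:
        """Sketch: bucket gradients by the layer_name attribute that
        set_model_params_layer_name() attaches to each parameter."""
        layer_grads: dict = {}
        for param in model.parameters():
            if param.grad is None:
                continue
            # Parameters that were never tagged fall back to the seeded "unknown" bucket.
            name = getattr(param, "layer_name", "unknown")
            layer_grads.setdefault(name, []).append(param.grad)
        return layer_grads


    def per_layer_inf_norm(layer_grads: dict) -> dict:
        """Sketch: the inf-norm of a layer is the largest absolute gradient entry."""
        return {name: max(g.abs().max().item() for g in grads) for name, grads in layer_grads.items()}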

View File

@@ -52,7 +52,11 @@ from internlm.train.utils import create_param_groups
 from internlm.utils.common import DummyProfile
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
-from internlm.utils.parallel import sync_model_param, sync_model_param_within_tp
+from internlm.utils.parallel import (
+    set_model_params_layer_name,
+    sync_model_param,
+    sync_model_param_within_tp,
+)
 from internlm.utils.registry import MODEL_INITIALIZER
 from internlm.utils.timeout import llm_timeout
@@ -107,6 +111,9 @@ def initialize_model():
     # if fsdp enabled, wrap the model
     model = wrap_FSDP_model(model)

+    # set the layer name as an attribute of the model parameters
+    set_model_params_layer_name(model)
+
     return model
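After initialize_model() returns, every parameter should therefore carry a layer_name attribute. A quick sanity check might look like the sketch below; it assumes it runs inside a training entry point where the distributed context, gpc.config, and the imports of initialize_model and gpc are already in place.

    model = initialize_model()

    # Every parameter should now be tagged; untagged ones would show up as "<missing>".
    tagged = {getattr(p, "layer_name", "<missing>") for p in model.parameters()}
    print(sorted(tagged))           # e.g. ["embedding", "head", "norm", "<Block>.0", ...]
    print(sorted(gpc.layer_names))  # the same names, plus the seeded defaults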

View File

@@ -2,6 +2,7 @@
 # -*- encoding: utf-8 -*-

 import torch.distributed as dist
+from torch import nn

 from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc
@@ -61,3 +62,28 @@ def get_parallel_log_file_name():
         f"tp={gpc.get_local_rank(ParallelMode.TENSOR)}_pp={gpc.get_local_rank(ParallelMode.PIPELINE)}"
     )
     return log_file_name
+
+
+def set_model_params_layer_name(model):
+    r"""Set the layer name as an attribute of the model parameters.
+
+    Args:
+        model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency.
+    """
+    if isinstance(model, nn.ModuleList):
+        _chunk = model[0]
+    else:
+        _chunk = model
+
+    # Create a unique layer name based on the block's class name and index
+    for name, children in _chunk.named_children():
+        if isinstance(children, nn.ModuleList):
+            for idx, block in enumerate(children):
+                for param in block.parameters():
+                    layer_name = f"{block.__class__.__name__}.{idx}"
+                    gpc.layer_names.add(layer_name)
+                    param.__setattr__("layer_name", layer_name)
+        else:
+            for param in children.parameters():
+                gpc.layer_names.add(name)
+                param.__setattr__("layer_name", name)
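The tagging scheme itself only relies on named_children() and nn.ModuleList, so it can be exercised outside InternLM with a toy model. The TinyModel class and the local layer_names set (standing in for gpc.layer_names) below are made up for illustration; this is a sketch of the same logic, not the library code.

    from torch import nn


    class TinyModel(nn.Module):
        """Made-up stand-in for a model chunk: embedding, block list, norm, head."""

        def __init__(self) -> None:
            super().__init__()
            self.embedding = nn.Embedding(100, 16)
            self.blocks = nn.ModuleList(nn.Linear(16, 16) for _ in range(2))
            self.norm = nn.LayerNorm(16)
            self.head = nn.Linear(16, 100)


    layer_names = {"unknown", "embedding", "norm", "head"}  # stand-in for gpc.layer_names
    model = TinyModel()

    # Same scheme as set_model_params_layer_name: parameters inside the ModuleList get
    # "<ClassName>.<idx>", everything else keeps its attribute name from named_children().
    for name, children in model.named_children():
        if isinstance(children, nn.ModuleList):
            for idx, block in enumerate(children):
                for param in block.parameters():
                    layer_name = f"{block.__class__.__name__}.{idx}"
                    layer_names.add(layer_name)
                    setattr(param, "layer_name", layer_name)
        else:
            for param in children.parameters():
                layer_names.add(name)
                setattr(param, "layer_name", name)

    print(sorted(layer_names))  # ['Linear.0', 'Linear.1', 'embedding', 'head', 'norm', 'unknown']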