mirror of https://github.com/InternLM/InternLM

commit 7d68509c4f (parent 646f1b45fa)

    set layer name to parameters after init_model

@@ -150,6 +150,7 @@ class ParallelContext(metaclass=SingletonMeta):
         self.virtual_pipeline_parallel_size = None
         self.virtual_pipeline_parallel_rank = None
         self._expert_parallel_group_names = []
+        self.layer_names = {"unknown", "embedding", "norm", "head"}

     @property
     def config(self):

@@ -1,7 +1,6 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

-import copy
 import math
 from abc import ABC, abstractmethod
 from collections import OrderedDict

@@ -32,7 +31,6 @@ except (ModuleNotFoundError, ImportError):
     APEX_AVAILABLE = False

 inf = math.inf
-global_layer_norms = {"unknown": 0.0, "embedding": 0.0, "norm": 0.0, "head": 0.0}


 def flatten(input_):

@@ -228,7 +226,7 @@ def compute_norm(
     enable_cuda_kernels = gradients[0].device.type == "cuda"
     # Norm parameters.
     norm_type = float(norm_type)
-    total_layer_norms = copy.deepcopy(global_layer_norms)
+    total_layer_norms = {layer_name: 0.0 for layer_name in gpc.layer_names}
     layer_grads = {}
     # Calculate norm.
     if norm_type == inf:

@@ -249,7 +247,7 @@ def compute_norm(
            total_layer_norms[key] = max(value, total_layer_norms[key])

        total_layer_norms_values = move_norm_to_cuda(torch.Tensor(list(total_layer_norms.values())))
-       total_layer_norms_keys = list(global_layer_norms.keys())
+       total_layer_norms_keys = list(total_layer_norms.keys())

        # Take max across all model-parallel GPUs.
        if gpc.is_initialized(ParallelMode.MODEL):

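The hunks above only show the edited fragments of compute_norm. As a rough standalone sketch (not the repository's implementation; per_layer_grad_norms and its arguments are invented for illustration), per-parameter gradients can be bucketed into per-layer norms via the layer_name attribute that this commit attaches to every parameter:

    import torch

    def per_layer_grad_norms(parameters, layer_names, norm_type=2.0):
        # Assumes each parameter was tagged with a `layer_name` attribute beforehand;
        # untagged parameters fall back into the "unknown" bucket.
        totals = {name: 0.0 for name in layer_names}
        for p in parameters:
            if p.grad is None:
                continue
            key = getattr(p, "layer_name", "unknown")
            totals[key] = totals.get(key, 0.0) + p.grad.norm(norm_type).item() ** norm_type
        # Reduce each accumulated bucket back to an L_p norm.
        return {k: v ** (1.0 / norm_type) for k, v in totals.items()}
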
@@ -523,24 +521,17 @@ class ParamBcastSyncHandler:
         for _chunk in model:
             if isinstance(_chunk, NaiveAMPModel):
                 _chunk = _chunk.model
-            for name, children in _chunk.named_children():
+            for _, children in _chunk.named_children():
                 # should be the transformer block definaton in modeling_xxx.py
                 if isinstance(children, nn.ModuleList):
                     # record the block that a parameter belongs to
-                    for idx, block in enumerate(children):
+                    for _, block in enumerate(children):
                         # self._block_to_param[f"{name}.{idx}"] = list(block.parameters())
                         self._block_to_param[block] = list(block.parameters())
-                        for parameter in self._block_to_param[block]:
-                            layer_name = f"{block.__class__.__name__}.{idx}"
-                            global_layer_norms[layer_name] = 0.0
-                            parameter.__setattr__("layer_name", layer_name)
                 else:
                     # record the block that a parameter belongs to
                     # self._block_to_param[name] = list(children.parameters())
                     self._block_to_param[children] = list(children.parameters())
-                    for parameter in self._block_to_param[children]:
-                        layer_name = f"{children.__class__.__name__}"
-                        parameter.__setattr__("layer_name", name)

         alloc_num = 0
         rank_to_go = 0

@@ -52,7 +52,11 @@ from internlm.train.utils import create_param_groups
 from internlm.utils.common import DummyProfile
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
-from internlm.utils.parallel import sync_model_param, sync_model_param_within_tp
+from internlm.utils.parallel import (
+    set_model_params_layer_name,
+    sync_model_param,
+    sync_model_param_within_tp,
+)
 from internlm.utils.registry import MODEL_INITIALIZER
 from internlm.utils.timeout import llm_timeout

@@ -107,6 +111,9 @@ def initialize_model():
     # if fsdp enabled, wrap the model
     model = wrap_FSDP_model(model)

+    # set the layer name as an attribute of the model parameters
+    set_model_params_layer_name(model)
+
     return model

@@ -2,6 +2,7 @@
 # -*- encoding: utf-8 -*-

 import torch.distributed as dist
+from torch import nn

 from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc

@@ -61,3 +62,28 @@ def get_parallel_log_file_name():
         f"tp={gpc.get_local_rank(ParallelMode.TENSOR)}_pp={gpc.get_local_rank(ParallelMode.PIPELINE)}"
     )
     return log_file_name
+
+
+def set_model_params_layer_name(model):
+    r"""Set the layer name as an attribute of the model parameters.
+
+    Args:
+        model (:class:`torch.nn.Module`): A PyTorch model whose parameters will be tagged with their layer name.
+    """
+    if isinstance(model, nn.ModuleList):
+        _chunk = model[0]
+    else:
+        _chunk = model
+
+    # Create a unique layer name based on the block's class name and index
+    for name, children in _chunk.named_children():
+        if isinstance(children, nn.ModuleList):
+            for idx, block in enumerate(children):
+                for param in block.parameters():
+                    layer_name = f"{block.__class__.__name__}.{idx}"
+                    gpc.layer_names.add(layer_name)
+                    param.__setattr__("layer_name", layer_name)
+        else:
+            for param in children.parameters():
+                gpc.layer_names.add(name)
+                param.__setattr__("layer_name", name)

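As a quick usage illustration (a self-contained toy example, not part of the commit; ToyBlock, ToyModel and tag_layer_names are invented names, and gpc is left out), the same tagging scheme applies to any module that keeps its transformer blocks in an nn.ModuleList:

    from torch import nn

    class ToyBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(8, 8)

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.embedding = nn.Embedding(16, 8)
            self.blocks = nn.ModuleList(ToyBlock() for _ in range(2))
            self.head = nn.Linear(8, 16)

    def tag_layer_names(model, layer_names):
        # Simplified stand-in for set_model_params_layer_name, without gpc.
        for name, children in model.named_children():
            if isinstance(children, nn.ModuleList):
                for idx, block in enumerate(children):
                    layer_name = f"{block.__class__.__name__}.{idx}"
                    layer_names.add(layer_name)
                    for param in block.parameters():
                        param.layer_name = layer_name
            else:
                layer_names.add(name)
                for param in children.parameters():
                    param.layer_name = name

    layer_names = {"unknown", "embedding", "norm", "head"}
    model = ToyModel()
    tag_layer_names(model, layer_names)
    print(sorted(layer_names))           # includes "ToyBlock.0" and "ToyBlock.1"
    print(model.head.weight.layer_name)  # "head"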