From 7d68509c4f482c97474e35fe4ecdf3beb956e8d1 Mon Sep 17 00:00:00 2001
From: JiaoPL
Date: Sat, 14 Oct 2023 22:32:10 +0800
Subject: [PATCH] set layer name to parameters after init_model

---
 internlm/core/context/parallel_context.py |  1 +
 internlm/solver/optimizer/utils.py        | 17 ++++-------------
 internlm/train/training_internlm.py       |  9 ++++++++-
 internlm/utils/parallel.py                | 26 ++++++++++++++++++++++++++
 4 files changed, 39 insertions(+), 14 deletions(-)

diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py
index 997bd46..5e44dc9 100644
--- a/internlm/core/context/parallel_context.py
+++ b/internlm/core/context/parallel_context.py
@@ -150,6 +150,7 @@ class ParallelContext(metaclass=SingletonMeta):
         self.virtual_pipeline_parallel_size = None
         self.virtual_pipeline_parallel_rank = None
         self._expert_parallel_group_names = []
+        self.layer_names = {"unknown", "embedding", "norm", "head"}

     @property
     def config(self):
diff --git a/internlm/solver/optimizer/utils.py b/internlm/solver/optimizer/utils.py
index 12575a3..71a1323 100644
--- a/internlm/solver/optimizer/utils.py
+++ b/internlm/solver/optimizer/utils.py
@@ -1,7 +1,6 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

-import copy
 import math
 from abc import ABC, abstractmethod
 from collections import OrderedDict
@@ -32,7 +31,6 @@ except (ModuleNotFoundError, ImportError):
     APEX_AVAILABLE = False

 inf = math.inf
-global_layer_norms = {"unknown": 0.0, "embedding": 0.0, "norm": 0.0, "head": 0.0}


 def flatten(input_):
@@ -228,7 +226,7 @@ def compute_norm(
     enable_cuda_kernels = gradients[0].device.type == "cuda"
     # Norm parameters.
     norm_type = float(norm_type)
-    total_layer_norms = copy.deepcopy(global_layer_norms)
+    total_layer_norms = {layer_name: 0.0 for layer_name in gpc.layer_names}
     layer_grads = {}
     # Calculate norm.
     if norm_type == inf:
@@ -249,7 +247,7 @@
            total_layer_norms[key] = max(value, total_layer_norms[key])

        total_layer_norms_values = move_norm_to_cuda(torch.Tensor(list(total_layer_norms.values())))
-        total_layer_norms_keys = list(global_layer_norms.keys())
+        total_layer_norms_keys = list(total_layer_norms.keys())

        # Take max across all model-parallel GPUs.
        if gpc.is_initialized(ParallelMode.MODEL):
@@ -523,24 +521,17 @@ class ParamBcastSyncHandler:
         for _chunk in model:
             if isinstance(_chunk, NaiveAMPModel):
                 _chunk = _chunk.model
-            for name, children in _chunk.named_children():
+            for _, children in _chunk.named_children():
                 # should be the transformer block definaton in modeling_xxx.py
                 if isinstance(children, nn.ModuleList):
                     # record the block that a parameter belongs to
-                    for idx, block in enumerate(children):
+                    for _, block in enumerate(children):
                         # self._block_to_param[f"{name}.{idx}"] = list(block.parameters())
                         self._block_to_param[block] = list(block.parameters())
-                        for parameter in self._block_to_param[block]:
-                            layer_name = f"{block.__class__.__name__}.{idx}"
-                            global_layer_norms[layer_name] = 0.0
-                            parameter.__setattr__("layer_name", layer_name)
                 else:
                     # record the block that a parameter belongs to
                     # self._block_to_param[name] = list(children.parameters())
                     self._block_to_param[children] = list(children.parameters())
-                    for parameter in self._block_to_param[children]:
-                        layer_name = f"{children.__class__.__name__}"
-                        parameter.__setattr__("layer_name", name)

         alloc_num = 0
         rank_to_go = 0
diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index 1451dc5..943c916 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -52,7 +52,11 @@ from internlm.train.utils import create_param_groups
 from internlm.utils.common import DummyProfile
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
-from internlm.utils.parallel import sync_model_param, sync_model_param_within_tp
+from internlm.utils.parallel import (
+    set_model_params_layer_name,
+    sync_model_param,
+    sync_model_param_within_tp,
+)
 from internlm.utils.registry import MODEL_INITIALIZER
 from internlm.utils.timeout import llm_timeout

@@ -107,6 +111,9 @@ def initialize_model():
     # if fsdp enabled, wrap the model
     model = wrap_FSDP_model(model)

+    # set the layer name as an attribute of the model parameters
+    set_model_params_layer_name(model)
+
     return model


diff --git a/internlm/utils/parallel.py b/internlm/utils/parallel.py
index 3029af5..ba4a713 100644
--- a/internlm/utils/parallel.py
+++ b/internlm/utils/parallel.py
@@ -2,6 +2,7 @@
 # -*- encoding: utf-8 -*-

 import torch.distributed as dist
+from torch import nn

 from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc
@@ -61,3 +62,28 @@
         f"tp={gpc.get_local_rank(ParallelMode.TENSOR)}_pp={gpc.get_local_rank(ParallelMode.PIPELINE)}"
     )
     return log_file_name
+
+
+def set_model_params_layer_name(model):
+    r"""Set the layer name as an attribute of the model parameters.
+
+    Args:
+        model (:class:`torch.nn.Module`): The model whose parameters will be tagged with a ``layer_name`` attribute.
+    """
+    if isinstance(model, nn.ModuleList):
+        _chunk = model[0]
+    else:
+        _chunk = model
+
+    # Create a unique layer name based on the block's class name and index
+    for name, children in _chunk.named_children():
+        if isinstance(children, nn.ModuleList):
+            for idx, block in enumerate(children):
+                for param in block.parameters():
+                    layer_name = f"{block.__class__.__name__}.{idx}"
+                    gpc.layer_names.add(layer_name)
+                    param.__setattr__("layer_name", layer_name)
+        else:
+            for param in children.parameters():
+                gpc.layer_names.add(name)
+                param.__setattr__("layer_name", name)
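
Note (not part of the patch): the following standalone sketch reproduces, in plain PyTorch and outside the InternLM codebase, the pattern this change introduces: every parameter is tagged with a layer_name attribute the way set_model_params_layer_name does right after initialize_model, and gradient norms are then grouped by that tag the way compute_norm uses gpc.layer_names. ToyModel, set_params_layer_name, and per_layer_grad_norms are hypothetical names used only for illustration.

# Standalone sketch only: ToyModel, set_params_layer_name and per_layer_grad_norms
# are illustrative stand-ins, not InternLM APIs.
import torch
from torch import nn


class ToyModel(nn.Module):
    """Tiny stand-in for the real model: embedding -> ModuleList of blocks -> head."""

    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(100, 16)
        self.blocks = nn.ModuleList(nn.Linear(16, 16) for _ in range(2))
        self.head = nn.Linear(16, 100)

    def forward(self, tokens):
        x = self.embedding(tokens)
        for block in self.blocks:
            x = torch.relu(block(x))
        return self.head(x)


def set_params_layer_name(model):
    """Tag every parameter with a layer_name attribute, mirroring set_model_params_layer_name."""
    layer_names = {"unknown"}
    for name, children in model.named_children():
        if isinstance(children, nn.ModuleList):
            # one name per block: "<ClassName>.<index>"
            for idx, block in enumerate(children):
                layer_name = f"{block.__class__.__name__}.{idx}"
                layer_names.add(layer_name)
                for param in block.parameters():
                    param.layer_name = layer_name
        else:
            layer_names.add(name)
            for param in children.parameters():
                param.layer_name = name
    return layer_names


def per_layer_grad_norms(model, layer_names):
    """Accumulate squared L2 gradient norms per layer name, the way compute_norm groups them."""
    norms = {layer_name: 0.0 for layer_name in layer_names}
    for param in model.parameters():
        if param.grad is None:
            continue
        key = getattr(param, "layer_name", "unknown")
        norms[key] += param.grad.norm(2).item() ** 2
    return norms


if __name__ == "__main__":
    model = ToyModel()
    layer_names = set_params_layer_name(model)
    loss = model(torch.randint(0, 100, (4,))).sum()
    loss.backward()
    # e.g. {'unknown': 0.0, 'embedding': ..., 'Linear.0': ..., 'Linear.1': ..., 'head': ...}
    print(per_layer_grad_norms(model, layer_names))

With the patch applied, the analogous grouping happens inside compute_norm, which now builds total_layer_norms from gpc.layer_names and the per-parameter layer_name attribute set after initialize_model, rather than from the removed module-level global_layer_norms dict.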