diff --git a/internlm/core/naive_amp.py b/internlm/core/naive_amp.py
index 7470659..04f3ee2 100644
--- a/internlm/core/naive_amp.py
+++ b/internlm/core/naive_amp.py
@@ -3,7 +3,8 @@
 
 # adopted from https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/amp
 
-from typing import Any
+from functools import partial
+from typing import Any, Union
 
 import torch
 import torch.distributed as dist
@@ -15,6 +16,14 @@ from internlm.core.context import ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
 
 
+def set_fp32_attr_to_module(module: nn.Module):
+    setattr(module, "is_fp32_module", True)
+
+
+def module_has_fp32_attr(module: nn.Module):
+    return hasattr(module, "is_fp32_module") and getattr(module, "is_fp32_module")
+
+
 class NaiveAMPModel(nn.Module):
     """
     This is a wrapper class for a model that automatically casts the model, its inputs, and outputs into fp16.
@@ -51,6 +60,9 @@ class NaiveAMPModel(nn.Module):
         self._sync_buf = False
         self._first_eval_run = False
 
+        # register forward hooks for modules flagged as fp32
+        self._register_fp32_parameters_hook()
+
     @property
     def sync_buffer(self):
         """Returns the current state of the buffer synchronization."""
@@ -134,3 +146,55 @@ class NaiveAMPModel(nn.Module):
             if self._output_to_fp32:
                 out = self.convert_to_fp32(out)
         return out
+
+    def _register_fp32_parameters_hook(self) -> None:
+        dtype = torch.float32
+
+        def _pre_forward_hook(model: nn.Module, inputs: tuple):  # pylint: disable=W0613
+            inputs_fp32 = []
+            for input_data_ in inputs:
+                if isinstance(input_data_, Tensor) and input_data_.dtype is not dtype:
+                    inputs_fp32.append(input_data_.to(dtype))
+                else:
+                    inputs_fp32.append(input_data_)
+            return tuple(inputs_fp32)
+
+        def _post_forward_hook(model: nn.Module, inputs: tuple, outputs: Union[tuple, Tensor]):  # pylint: disable=W0613
+            outputs_ = []
+            assert isinstance(outputs, (Tensor, tuple))
+            if isinstance(outputs, tuple):
+                for output_data_ in outputs:
+                    if isinstance(output_data_, Tensor) and output_data_.dtype is not self.dtype:
+                        outputs_.append(output_data_.to(self.dtype))
+                    else:
+                        outputs_.append(output_data_)
+                return tuple(outputs_)
+            else:
+                return outputs.to(self.dtype)
+
+        # self.model may be a single nn.Module or an nn.ModuleList of chunks;
+        # normalize to a list so one loop handles both cases
+        model = self.model if isinstance(self.model, nn.ModuleList) else [self.model]
+
+        modules = []
+        # collect the transformer/embedding/head/norm blocks
+        for _chunk in model:
+            if isinstance(_chunk, NaiveAMPModel):
+                _chunk = _chunk.model
+
+            for _, sub_module in _chunk.named_modules():
+                # should be the transformer block definition in modeling_xxx.py
+                if isinstance(sub_module, nn.ModuleList):
+                    for _, module in enumerate(sub_module):
+                        modules.append(module)
+
+                else:
+                    # embedding, head, etc. that live outside the transformer blocks
+                    modules.append(sub_module)
+
+        # register forward hooks for the transformer/embedding/norm/etc. blocks flagged as fp32
+        for sub_module in modules:
+            if module_has_fp32_attr(sub_module):
+                sub_module.to(dtype)
+                sub_module.register_forward_pre_hook(partial(_pre_forward_hook))
+                sub_module.register_forward_hook(partial(_post_forward_hook))
diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index 64ff4de..858b6f0 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -11,6 +11,7 @@ from torch import nn
 
 from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
+from internlm.core.naive_amp import set_fp32_attr_to_module
 from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
 from internlm.model.embedding import Embedding1D
 from internlm.model.linear import (
@@ -101,6 +102,8 @@ class PackedFlashBaseLayer1D(nn.Module):
         else:
             self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
             self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
+        set_fp32_attr_to_module(self.norm1)
+        set_fp32_attr_to_module(self.norm2)
 
         if use_swiglu:
             self.mlp = FeedForward(
@@ -334,6 +337,7 @@ class PackedFlashInternLm1D(nn.Module):
                 self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
             else:
                 self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
+            set_fp32_attr_to_module(self.norm)
             self.head = head_cls(
                 in_features=hidden_size,
                 out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index e08d4ec..416c8f2 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -31,7 +31,7 @@ from internlm.solver.beta2_scheduler import Beta2Scheduler
 from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
 from internlm.solver.optimizer import HybridZeroOptimizer
 from internlm.solver.optimizer.utils import ParamBcastSyncHandler
-from internlm.utils.common import DummyProfile
+from internlm.utils.common import DummyProfile, create_param_groups
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 from internlm.utils.parallel import (
@@ -109,8 +109,9 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
         param_bcast_sync_handler = None
 
     adam_cfg = gpc.config.adam
+    params = create_param_groups(model, adam_cfg.weight_decay)
     naive_optimizer = torch.optim.AdamW(
-        params=[{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}],
+        params=params,
         lr=adam_cfg.lr,
         betas=(adam_cfg.adam_beta1, adam_cfg.adam_beta2),
         eps=adam_cfg.adam_eps,
diff --git a/internlm/utils/common.py b/internlm/utils/common.py
index f3b58c0..4b9b047 100644
--- a/internlm/utils/common.py
+++ b/internlm/utils/common.py
@@ -7,7 +7,7 @@ import os
 import random
 from contextlib import contextmanager
 from datetime import datetime
-from typing import Union
+from typing import Dict, Tuple, Union
 
 import numpy as np
 import torch
@@ -236,3 +236,58 @@ class DummyProfile:
 
     def step(self):
         pass
+
+
+def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict]) -> Tuple[Dict]:
+    """Split parameters into separate groups for the optimizer.
+    Compatible with multiple param groups, each of which should have a name.
+
+    Args:
+        param_groups (Tuple[Dict]):
+            The list of parameter groups to split
+
+    Returns:
+        Tuple[Dict]:
+            list of fp32 and non-fp32 param groups for the optimizer
+    """
+    if isinstance(param_groups, tuple):
+        param_groups = list(param_groups)  # Tuple cannot be modified
+    elif isinstance(param_groups, dict):
+        param_groups = [param_groups]
+    elif not isinstance(param_groups, list):
+        raise ValueError(f"Unknown param group type of {type(param_groups)}")
+
+    fp32_group = {}
+    # create the fp32 group and copy the attributes of the original groups
+    for param_group in param_groups:
+        # name/flag the fp32 group, then copy every attribute except "name"
+        fp32_group["name"] = "fp32"
+        fp32_group["gate"] = True
+        for ori_key in param_group.keys():
+            if ori_key != "name":
+                if ori_key == "params":
+                    fp32_group[ori_key] = []
+                else:
+                    fp32_group[ori_key] = param_group[ori_key]
+
+    # assign params to the fp32 group by dtype
+    for param_group in param_groups:
+        new_params = []
+        for param in param_group["params"]:
param_group["params"]: + if param.dtype == torch.float32: + fp32_group["params"].append(param) + else: + new_params.append(param) + # origin group without fp32 or moe parameter + param_group["params"] = new_params + + # append to origin group + param_groups.append(fp32_group) + + return tuple(param_groups) + + +def create_param_groups(model, weight_decay): + parameters = {"params": list(model.parameters()), "name": "default", "weight_decay": weight_decay} + + return split_params_into_different_groups_for_optimizer(parameters)