mirror of https://github.com/InternLM/InternLM
add fused precision support for norm
parent
ab513e1ddd
commit
98329da327
|
@ -3,7 +3,8 @@
|
|||
|
||||
# adopted from https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/amp
|
||||
|
||||
from typing import Any
|
||||
from functools import partial
|
||||
from typing import Any, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -15,6 +16,14 @@ from internlm.core.context import ParallelMode
|
|||
from internlm.core.context.parallel_context import global_context as gpc
|
||||
|
||||
|
||||
def set_fp32_attr_to_module(module: nn.Module):
|
||||
setattr(module, "is_fp32_module", True)
|
||||
|
||||
|
||||
def module_has_fp32_attr(module: nn.Module):
|
||||
return hasattr(module, "is_fp32_module") and getattr(module, "is_fp32_module")
|
||||
|
||||
|
||||
class NaiveAMPModel(nn.Module):
|
||||
"""
|
||||
This is a wrapper class for a model that automatically casts the model, its inputs, and outputs into fp16.
|
||||
|
@ -51,6 +60,9 @@ class NaiveAMPModel(nn.Module):
|
|||
self._sync_buf = False
|
||||
self._first_eval_run = False
|
||||
|
||||
# register hook for fp32 module
|
||||
self._register_fp32_parameters_hook()
|
||||
|
||||
@property
|
||||
def sync_buffer(self):
|
||||
"""Returns the current state of the buffer synchronization."""
|
||||
|
@ -134,3 +146,55 @@ class NaiveAMPModel(nn.Module):
|
|||
if self._output_to_fp32:
|
||||
out = self.convert_to_fp32(out)
|
||||
return out
|
||||
|
||||
def _register_fp32_parameters_hook(self) -> None:
|
||||
dtype = torch.float32
|
||||
|
||||
def _pre_forward_hook(model: nn.Module, inputs: tuple): # pylint: disable=W0613
|
||||
inputs_fp32 = []
|
||||
for input_data_ in inputs:
|
||||
if isinstance(input_data_, Tensor) and input_data_.dtype is not dtype:
|
||||
inputs_fp32.append(input_data_.to(dtype))
|
||||
else:
|
||||
inputs_fp32.append(input_data_)
|
||||
return tuple(inputs_fp32)
|
||||
|
||||
def _post_forward_hook(model: nn.Module, inputs: tuple, outputs: Union[tuple, Tensor]): # pylint: disable=W0613
|
||||
outputs_ = []
|
||||
assert isinstance(outputs, (Tensor, tuple))
|
||||
if isinstance(outputs, tuple):
|
||||
for output_data_ in outputs:
|
||||
if isinstance(output_data_, Tensor) and output_data_.dtype is not self.dtype:
|
||||
outputs_.append(output_data_.to(self.dtype))
|
||||
else:
|
||||
outputs_.append(output_data_)
|
||||
return tuple(outputs_)
|
||||
else:
|
||||
return outputs.to(self.dtype)
|
||||
|
||||
# just want to share same for loop for ModuleList and Module
|
||||
if not isinstance(self.model, nn.ModuleList):
|
||||
model = [self.model]
|
||||
|
||||
modules = []
|
||||
# record the modules to transformer/embeding/head/norm block
|
||||
for _chunk in model:
|
||||
if isinstance(_chunk, NaiveAMPModel):
|
||||
_chunk = _chunk.model
|
||||
|
||||
for _, sub_module in _chunk.named_modules():
|
||||
# should be the transformer block definaton in modeling_xxx.py
|
||||
if isinstance(sub_module, nn.ModuleList):
|
||||
for _, module in enumerate(sub_module):
|
||||
modules.append(module)
|
||||
|
||||
else:
|
||||
# embedding, head, etc that out of the transformer block
|
||||
modules.append(sub_module)
|
||||
|
||||
# register_forward_pre_hook for transformer/embeding/norm/xxx block
|
||||
for sub_module in modules:
|
||||
if module_has_fp32_attr(sub_module):
|
||||
sub_module.to(dtype)
|
||||
sub_module.register_forward_pre_hook(partial(_pre_forward_hook))
|
||||
sub_module.register_forward_hook(partial(_post_forward_hook))
|
||||
|
|
|
@ -11,6 +11,7 @@ from torch import nn
|
|||
|
||||
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
|
||||
from internlm.core.context.parallel_context import global_context as gpc
|
||||
from internlm.core.naive_amp import set_fp32_attr_to_module
|
||||
from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
|
||||
from internlm.model.embedding import Embedding1D
|
||||
from internlm.model.linear import (
|
||||
|
@ -101,6 +102,8 @@ class PackedFlashBaseLayer1D(nn.Module):
|
|||
else:
|
||||
self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
|
||||
self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
|
||||
set_fp32_attr_to_module(self.norm1)
|
||||
set_fp32_attr_to_module(self.norm2)
|
||||
|
||||
if use_swiglu:
|
||||
self.mlp = FeedForward(
|
||||
|
@ -334,6 +337,7 @@ class PackedFlashInternLm1D(nn.Module):
|
|||
self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
|
||||
else:
|
||||
self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
|
||||
set_fp32_attr_to_module(self.norm)
|
||||
self.head = head_cls(
|
||||
in_features=hidden_size,
|
||||
out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
|
||||
|
|
|
@ -31,7 +31,7 @@ from internlm.solver.beta2_scheduler import Beta2Scheduler
|
|||
from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
|
||||
from internlm.solver.optimizer import HybridZeroOptimizer
|
||||
from internlm.solver.optimizer.utils import ParamBcastSyncHandler
|
||||
from internlm.utils.common import DummyProfile
|
||||
from internlm.utils.common import DummyProfile, create_param_groups
|
||||
from internlm.utils.logger import get_logger
|
||||
from internlm.utils.megatron_timers import megatron_timer as timer
|
||||
from internlm.utils.parallel import (
|
||||
|
@ -109,8 +109,9 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
|
|||
param_bcast_sync_handler = None
|
||||
|
||||
adam_cfg = gpc.config.adam
|
||||
params = create_param_groups(model, adam_cfg.weight_decay)
|
||||
naive_optimizer = torch.optim.AdamW(
|
||||
params=[{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}],
|
||||
params=params,
|
||||
lr=adam_cfg.lr,
|
||||
betas=(adam_cfg.adam_beta1, adam_cfg.adam_beta2),
|
||||
eps=adam_cfg.adam_eps,
|
||||
|
|
|
@ -7,7 +7,7 @@ import os
|
|||
import random
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from typing import Union
|
||||
from typing import Dict, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -236,3 +236,58 @@ class DummyProfile:
|
|||
|
||||
def step(self):
|
||||
pass
|
||||
|
||||
|
||||
def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict]) -> Tuple[Dict]:
|
||||
"""Split parameters into different MoE groups for optimizer
|
||||
Compatiable with muiltiple param groups, each should have a name
|
||||
|
||||
Args:
|
||||
param_groups (Tuple[Dict]):
|
||||
The list of parameter groups to split
|
||||
|
||||
Returns:
|
||||
Tuple[Dict]:
|
||||
list of MoE/non-MoE groups for optimizer
|
||||
"""
|
||||
if isinstance(param_groups, tuple):
|
||||
param_groups = list(param_groups) # Tuple cannot be modified
|
||||
elif isinstance(param_groups, dict):
|
||||
param_groups = [param_groups]
|
||||
elif not isinstance(param_groups, list):
|
||||
raise ValueError(f"Unknown param group type of {type(param_groups)}")
|
||||
|
||||
fp32_group = {}
|
||||
# Create fp32 and moe groups and copy origin attribute
|
||||
for param_group in param_groups:
|
||||
# copy attribute for fp32 group
|
||||
fp32_group["name"] = "fp32"
|
||||
fp32_group["gate"] = True
|
||||
for ori_key in param_group.keys():
|
||||
if ori_key != "name":
|
||||
if ori_key == "params":
|
||||
fp32_group[ori_key] = []
|
||||
else:
|
||||
fp32_group[ori_key] = param_group[ori_key]
|
||||
|
||||
# Assign param
|
||||
for param_group in param_groups:
|
||||
new_params = []
|
||||
for param in param_group["params"]:
|
||||
if param.dtype == torch.float32:
|
||||
fp32_group["params"].append(param)
|
||||
else:
|
||||
new_params.append(param)
|
||||
# origin group without fp32 or moe parameter
|
||||
param_group["params"] = new_params
|
||||
|
||||
# append to origin group
|
||||
param_groups.append(fp32_group)
|
||||
|
||||
return tuple(param_groups)
|
||||
|
||||
|
||||
def create_param_groups(model, weight_decay):
|
||||
parameters = {"params": list(model.parameters()), "name": "default", "weight_decay": weight_decay}
|
||||
|
||||
return split_params_into_different_groups_for_optimizer(parameters)
|
||||
|
|
Loading…
Reference in New Issue