add fused precision support for norm

pull/319/head
Qu Wenwen 2023-09-18 19:02:07 +08:00
parent ab513e1ddd
commit 98329da327
4 changed files with 128 additions and 4 deletions

View File

@@ -3,7 +3,8 @@
 # adopted from https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/amp
-from typing import Any
+from functools import partial
+from typing import Any, Union

 import torch
 import torch.distributed as dist
@@ -15,6 +16,14 @@ from internlm.core.context import ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc

+
+def set_fp32_attr_to_module(module: nn.Module):
+    setattr(module, "is_fp32_module", True)
+
+
+def module_has_fp32_attr(module: nn.Module):
+    return hasattr(module, "is_fp32_module") and getattr(module, "is_fp32_module")
+
 class NaiveAMPModel(nn.Module):
     """
     This is a wrapper class for a model that automatically casts the model, its inputs, and outputs into fp16.
@@ -51,6 +60,9 @@ class NaiveAMPModel(nn.Module):
         self._sync_buf = False
         self._first_eval_run = False

+        # register hooks for fp32 modules
+        self._register_fp32_parameters_hook()
+
     @property
     def sync_buffer(self):
         """Returns the current state of the buffer synchronization."""
@@ -134,3 +146,55 @@
         if self._output_to_fp32:
             out = self.convert_to_fp32(out)
         return out
+
+    def _register_fp32_parameters_hook(self) -> None:
+        dtype = torch.float32
+
+        def _pre_forward_hook(model: nn.Module, inputs: tuple):  # pylint: disable=W0613
+            inputs_fp32 = []
+            for input_data_ in inputs:
+                if isinstance(input_data_, Tensor) and input_data_.dtype is not dtype:
+                    inputs_fp32.append(input_data_.to(dtype))
+                else:
+                    inputs_fp32.append(input_data_)
+            return tuple(inputs_fp32)
+
+        def _post_forward_hook(model: nn.Module, inputs: tuple, outputs: Union[tuple, Tensor]):  # pylint: disable=W0613
+            outputs_ = []
+            assert isinstance(outputs, (Tensor, tuple))
+            if isinstance(outputs, tuple):
+                for output_data_ in outputs:
+                    if isinstance(output_data_, Tensor) and output_data_.dtype is not self.dtype:
+                        outputs_.append(output_data_.to(self.dtype))
+                    else:
+                        outputs_.append(output_data_)
+                return tuple(outputs_)
+            else:
+                return outputs.to(self.dtype)
+
+        # share the same for-loop for ModuleList and Module
+        if isinstance(self.model, nn.ModuleList):
+            model = self.model
+        else:
+            model = [self.model]
+
+        modules = []
+        # collect the transformer/embedding/head/norm blocks
+        for _chunk in model:
+            if isinstance(_chunk, NaiveAMPModel):
+                _chunk = _chunk.model
+            for _, sub_module in _chunk.named_modules():
+                # should be the transformer block definition in modeling_xxx.py
+                if isinstance(sub_module, nn.ModuleList):
+                    for _, module in enumerate(sub_module):
+                        modules.append(module)
+                else:
+                    # embedding, head, etc. that live outside the transformer block
+                    modules.append(sub_module)
+
+        # register forward pre/post hooks for the modules marked as fp32
+        for sub_module in modules:
+            if module_has_fp32_attr(sub_module):
+                sub_module.to(dtype)
+                sub_module.register_forward_pre_hook(partial(_pre_forward_hook))
+                sub_module.register_forward_hook(partial(_post_forward_hook))
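
The hook mechanism added above can be summarized with a minimal, self-contained PyTorch sketch (not the InternLM wrapper itself): a module tagged via set_fp32_attr_to_module is cast back to fp32, a forward pre-hook up-casts its inputs, and a forward hook down-casts its outputs to the model's low-precision dtype. bfloat16 stands in for the wrapper's dtype here only so the sketch also runs on CPU.

    import torch
    from torch import nn

    def set_fp32_attr_to_module(module: nn.Module):
        setattr(module, "is_fp32_module", True)

    low_dtype = torch.bfloat16  # stands in for NaiveAMPModel's self.dtype

    norm = nn.LayerNorm(8)
    model = nn.Sequential(nn.Linear(8, 8), norm).to(low_dtype)  # whole model in low precision
    set_fp32_attr_to_module(norm)
    norm.to(torch.float32)  # tagged module kept in full precision

    # up-cast inputs entering the tagged module
    norm.register_forward_pre_hook(
        lambda m, inputs: tuple(
            x.float() if isinstance(x, torch.Tensor) and x.dtype != torch.float32 else x
            for x in inputs
        )
    )
    # down-cast its outputs so the rest of the network keeps running in low precision
    norm.register_forward_hook(lambda m, inputs, out: out.to(low_dtype))

    out = model(torch.randn(2, 8, dtype=low_dtype))
    print(norm.weight.dtype, out.dtype)  # torch.float32 torch.bfloat16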

View File

@@ -11,6 +11,7 @@ from torch import nn
 from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
+from internlm.core.naive_amp import set_fp32_attr_to_module
 from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
 from internlm.model.embedding import Embedding1D
 from internlm.model.linear import (
@@ -101,6 +102,8 @@ class PackedFlashBaseLayer1D(nn.Module):
         else:
             self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
             self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
+        set_fp32_attr_to_module(self.norm1)
+        set_fp32_attr_to_module(self.norm2)

         if use_swiglu:
             self.mlp = FeedForward(
@@ -334,6 +337,7 @@ class PackedFlashInternLm1D(nn.Module):
                 self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
             else:
                 self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
+            set_fp32_attr_to_module(self.norm)
             self.head = head_cls(
                 in_features=hidden_size,
                 out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,

View File

@@ -31,7 +31,7 @@ from internlm.solver.beta2_scheduler import Beta2Scheduler
 from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
 from internlm.solver.optimizer import HybridZeroOptimizer
 from internlm.solver.optimizer.utils import ParamBcastSyncHandler
-from internlm.utils.common import DummyProfile
+from internlm.utils.common import DummyProfile, create_param_groups
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 from internlm.utils.parallel import (
@@ -109,8 +109,9 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
         param_bcast_sync_handler = None

     adam_cfg = gpc.config.adam
+    params = create_param_groups(model, adam_cfg.weight_decay)
     naive_optimizer = torch.optim.AdamW(
-        params=[{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}],
+        params=params,
         lr=adam_cfg.lr,
         betas=(adam_cfg.adam_beta1, adam_cfg.adam_beta2),
         eps=adam_cfg.adam_eps,

View File

@@ -7,7 +7,7 @@ import os
 import random
 from contextlib import contextmanager
 from datetime import datetime
-from typing import Union
+from typing import Dict, Tuple, Union

 import numpy as np
 import torch
@@ -236,3 +236,58 @@ class DummyProfile:
     def step(self):
         pass
+
+
+def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict]) -> Tuple[Dict]:
+    """Split parameters into different groups for the optimizer.
+
+    Compatible with multiple param groups; each group should have a name.
+    Parameters kept in fp32 (e.g. the norm modules) are moved into a
+    dedicated "fp32" group so they can be handled separately.
+
+    Args:
+        param_groups (Tuple[Dict]): The list of parameter groups to split.
+
+    Returns:
+        Tuple[Dict]: list of fp32/non-fp32 groups for the optimizer.
+    """
+    if isinstance(param_groups, tuple):
+        param_groups = list(param_groups)  # tuples cannot be modified
+    elif isinstance(param_groups, dict):
+        param_groups = [param_groups]
+    elif not isinstance(param_groups, list):
+        raise ValueError(f"Unknown param group type of {type(param_groups)}")
+
+    fp32_group = {}
+    # create the fp32 group and copy the original group attributes
+    for param_group in param_groups:
+        fp32_group["name"] = "fp32"
+        fp32_group["gate"] = True
+        for ori_key in param_group.keys():
+            if ori_key != "name":
+                if ori_key == "params":
+                    fp32_group[ori_key] = []
+                else:
+                    fp32_group[ori_key] = param_group[ori_key]
+
+    # assign params: move fp32 parameters out of the original groups
+    for param_group in param_groups:
+        new_params = []
+        for param in param_group["params"]:
+            if param.dtype == torch.float32:
+                fp32_group["params"].append(param)
+            else:
+                new_params.append(param)
+        # original group without the fp32 parameters
+        param_group["params"] = new_params
+
+    # append the fp32 group to the original groups
+    param_groups.append(fp32_group)
+
+    return tuple(param_groups)
+
+
+def create_param_groups(model, weight_decay):
+    parameters = {"params": list(model.parameters()), "name": "default", "weight_decay": weight_decay}
+    return split_params_into_different_groups_for_optimizer(parameters)
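
A short usage sketch for the helpers above, assuming the internlm package from this diff is importable and using a toy model whose norm stays in fp32 (standing in for a NaiveAMPModel-wrapped network):

    import torch
    from torch import nn
    from internlm.utils.common import create_param_groups  # path as shown in this diff

    # toy stand-in: the linear runs in half precision, the norm stays in fp32
    model = nn.Sequential(nn.Linear(8, 8).half(), nn.LayerNorm(8))

    groups = create_param_groups(model, weight_decay=0.01)
    for group in groups:
        print(group["name"], group["weight_decay"], [p.dtype for p in group["params"]])
    # expected: a "default" group holding the fp16 linear weights and an "fp32"
    # group holding the norm weights, matching the params passed to AdamW in
    # initialize_optimizer above

    optimizer = torch.optim.AdamW(list(groups), lr=1e-4, weight_decay=0.01)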