Feat(norm)/support fused precision (#319)

* add fused precision support for norm

* refactor code

* refactor code

* change the granularity of hook

* fix bugs if self.model is ModuleList

* add dtype condition for post hook

* refactor code for split group

* refactor code for pre/post hook

* refactor code for split group

* remove fp32 hook for norm

* unit tests for fused precision

* add doc for fused precision

* add doc for En. version

* reformat docs

* Update mixed_precision.rst

* Update mixed_precision.po

* update mixed_precision.po
pull/373/head
Wenwen Qu 2023-09-26 20:39:55 +08:00 committed by GitHub
parent 96b20cd43f
commit 655e9dae40
9 changed files with 391 additions and 13 deletions

View File

@@ -43,39 +43,42 @@ msgstr "Training API"
msgid "并行训练"
msgstr "Parallel Training"
#: ../../source/index.rst:51 9234725f3c464731993d73607608c874
#: ../../source/index.rst:51
msgid "混合精度"
msgstr "Mixed Precision"
#: ../../source/index.rst:59 9234725f3c464731993d73607608c874
msgid "模型备份"
msgstr "Model Checkpointing"
#: ../../source/index.rst:59 8e4ce037017f4510b2892a66003877fa
#: ../../source/index.rst:67 8e4ce037017f4510b2892a66003877fa
msgid "性能分析"
msgstr "Profiler"
#: ../../source/index.rst:67 a36e02819ecd4b448a8cb4ebbecb6600
#: ../../source/index.rst:75 a36e02819ecd4b448a8cb4ebbecb6600
msgid "训练监控"
msgstr "Monitor"
#: ../../source/index.rst:75 b912e292486f455c8b5cdd75962e8ac2
#: ../../source/index.rst:83 b912e292486f455c8b5cdd75962e8ac2
msgid "训练样例"
msgstr "Example"
#: ../../source/index.rst:83 ea9e9281720941a1830e5df7a2badf7a
#: ../../source/index.rst:91 ea9e9281720941a1830e5df7a2badf7a
msgid "常见问题"
msgstr "Q&A"
#: ../../source/index.rst:91 e08edc5aa1c74965b10084b393b88fae
#: ../../source/index.rst:99 e08edc5aa1c74965b10084b393b88fae
msgid "索引和表格"
msgstr "Indices and tables"
#: ../../source/index.rst:93 f3fdca059caa49dcad09aa44be7f02d6
#: ../../source/index.rst:101 f3fdca059caa49dcad09aa44be7f02d6
msgid ":ref:`genindex`"
msgstr ""
#: ../../source/index.rst:94 b3791e811315435097bb507edc3f4b9b
#: ../../source/index.rst:102 b3791e811315435097bb507edc3f4b9b
msgid ":ref:`modindex`"
msgstr ""
#: ../../source/index.rst:95 a164b772960f4ab8b18c7e8820f69f55
#: ../../source/index.rst:103 a164b772960f4ab8b18c7e8820f69f55
msgid ":ref:`search`"
msgstr ""

View File

@@ -0,0 +1,85 @@
# SOME DESCRIPTIVE TITLE.
# Copyright (C) 2023, InternLM Team
# This file is distributed under the same license as the InternLM package.
# FIRST AUTHOR <EMAIL@ADDRESS>, 2023.
#
#, fuzzy
msgid ""
msgstr ""
"Project-Id-Version: InternLM \n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2023-09-26 17:04+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: en\n"
"Language-Team: en <LL@li.org>\n"
"Plural-Forms: nplurals=2; plural=(n != 1);\n"
"MIME-Version: 1.0\n"
"Content-Type: text/plain; charset=utf-8\n"
"Content-Transfer-Encoding: 8bit\n"
"Generated-By: Babel 2.12.1\n"
#: ../../source/mixed_precision.rst:2
msgid "混合精度"
msgstr "Mixed Precision"
#: ../../source/mixed_precision.rst:3
msgid ""
"混合精度是指在模型训练的过程中同时使用16位和32位浮点数类型是一种在最小化精度损失的前提下加速模型训练的方法。 "
"混合精度通过让模型的某些部分使用32位浮点数以保持数值稳定性并在其余部分利用半精度浮点数加速训练并可以减少内存使用在评估指标如准确率方面仍可以获得同等的训练效果。"
msgstr ""
"Mixed precision refers to using both 16-bit and 32-bit floating-point "
"types to train model, which can accelerate the model training while "
"minimizing the accuracy loss. Mixed precision training uses 32-bit "
"floating-point types in certain parts of the model to maintain numerical "
"stability, and accelerate training and reduce memory usage by using "
"16-bit floating-point types in other parts. Mixed precision can achieve "
"the same training effect in evaluating indicators such as accuracy."
#: internlm.core.naive_amp.NaiveAMPModel:1 of
msgid ""
"This is a wrapper class for a model that automatically casts the model, "
"its inputs, and outputs into fp16. It also provides options to cast the "
"output back to fp32 and to synchronize buffers."
msgstr ""
#: internlm.core.naive_amp.NaiveAMPModel of
msgid "参数"
msgstr ""
#: internlm.core.naive_amp.NaiveAMPModel:4 of
msgid "The model to be wrapped and cast into fp16."
msgstr ""
#: internlm.core.naive_amp.NaiveAMPModel:6 of
msgid "If True, the output of this module is cast into fp32. Defaults to True."
msgstr ""
#: internlm.core.naive_amp.NaiveAMPModel:8 of
msgid ""
"The parallel group mode used in this module. Defaults to "
"``ParallelMode.DATA``."
msgstr ""
#: internlm.core.naive_amp.NaiveAMPModel:11 of
msgid "If True, the buffers are synchronized. Defaults to True."
msgstr ""
#: ../../source/mixed_precision.rst:8
msgid "InternLM默认将模型转换为16位浮点数类型进行训练在配置文件中可以设置默认类型为其他数据类型。在使用混合精度时需要在构建模型时使用"
msgstr ""
"InternLM converts the model to 16-bit floating-point types for model "
"training by default (the default type can be set to other data types in "
"the configuration file). When using mixed precision, it is necessary to "
"use "
#: ../../source/mixed_precision.rst:14
msgid "将模型的某个子模块设置为32位浮点数类型进行训练InternLM会在模型训练时自动将数据类型转换成需要的精度。"
msgstr ""
"to set a sub-module of the model to 16-bit floating-point types for "
"training, and InternLM will automatically convert the data type to the "
"required precision during model training."
#: ../../source/mixed_precision.rst:16
msgid "例如:"
msgstr "For example:"

View File

@@ -47,6 +47,14 @@ InternLM
   parallel

混合精度
-------------------

.. toctree::
   :maxdepth: 2

   mixed_precision

模型备份
--------------------

View File

@@ -0,0 +1,36 @@
混合精度
-----------------

混合精度是指在模型训练的过程中同时使用16位和32位浮点数类型是一种在最小化精度损失的前提下加速模型训练的方法。
混合精度通过让模型的某些部分使用32位浮点数以保持数值稳定性并在其余部分利用半精度浮点数加速训练并可以减少内存使用在评估指标如准确率方面仍可以获得同等的训练效果。

.. autoclass:: internlm.core.naive_amp.NaiveAMPModel

InternLM默认将模型转换为16位浮点数类型进行训练在配置文件中可以设置默认类型为其他数据类型。在使用混合精度时需要在构建模型时使用

.. code-block:: python

    set_fp32_attr_to_module(module)  # the sub-module to be trained in fp32

将模型的某个子模块设置为32位浮点数类型进行训练InternLM会在模型训练时自动将数据类型转换成需要的精度。

例如:

.. code-block:: python

    class MlpModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = nn.Linear(4, 1, bias=False)
            self.linear2 = nn.Linear(1, 4, bias=False)

    model = MlpModel()

    # set model.linear2 as fp32 module
    set_fp32_attr_to_module(model.linear2)

    # apply mixed precision
    model = NaiveAMPModel(
        model=model,
        output_to_fp32=True,
        dtype=torch.bfloat16,
        sync_buffer=False,
    )
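To sanity-check the documentation example above, the following sketch (not part of this commit) wraps the same MlpModel and inspects the resulting dtypes. The forward method, the input tensor, and the dtype checks are added here purely for illustration, and a real InternLM run would initialize the distributed context before constructing NaiveAMPModel.

import torch
from torch import nn

from internlm.core.naive_amp import NaiveAMPModel, set_fp32_attr_to_module


class MlpModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(4, 1, bias=False)
        self.linear2 = nn.Linear(1, 4, bias=False)

    def forward(self, x):  # added for this sketch; the doc snippet omits forward()
        return self.linear2(self.linear1(x))


model = MlpModel()
set_fp32_attr_to_module(model.linear2)  # mark linear2 to be trained in fp32

model = NaiveAMPModel(
    model=model,
    output_to_fp32=True,
    dtype=torch.bfloat16,
    sync_buffer=False,
)

# after wrapping, unmarked modules are cast to bf16 while the marked module stays fp32
print(model.model.linear1.weight.dtype)  # torch.bfloat16
print(model.model.linear2.weight.dtype)  # torch.float32

# inputs are cast to bf16, linear2 runs in fp32 via the registered hooks,
# and the final output is cast back to fp32 because output_to_fp32=True
out = model(torch.rand(2, 4))
print(out.dtype)  # torch.float32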

View File

@@ -133,7 +133,7 @@ ZeRO1.5 的实现使用了分层分片的概念,通过配置值 ``parallel.zero1``
    hybrid_zero_optimizer = dict(
        # Enable low_level_optimizer overlap_communication
        overlap_sync_grad=True,
        overlap_sync_grad=True,
        overlap_sync_param=True,
        # bucket size for nccl communication params
        reduce_bucket_size=512 * 1024 * 1024,

View File

@@ -3,7 +3,8 @@
# adopted from https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/amp
from typing import Any
from functools import partial
from typing import Any, Union
import torch
import torch.distributed as dist
@@ -15,6 +16,14 @@ from internlm.core.context import ParallelMode
from internlm.core.context.parallel_context import global_context as gpc


def set_fp32_attr_to_module(module: nn.Module):
    setattr(module, "is_fp32_module", True)


def module_has_fp32_attr(module: nn.Module):
    return hasattr(module, "is_fp32_module") and getattr(module, "is_fp32_module")

class NaiveAMPModel(nn.Module):
"""
This is a wrapper class for a model that automatically casts the model, its inputs, and outputs into fp16.
@@ -51,6 +60,9 @@ class NaiveAMPModel(nn.Module):
            self._sync_buf = False
        self._first_eval_run = False

        # register hook for fp32 module
        self._register_fp32_parameters_hook()

    @property
    def sync_buffer(self):
        """Returns the current state of the buffer synchronization."""
@@ -134,3 +146,46 @@
        if self._output_to_fp32:
            out = self.convert_to_fp32(out)
        return out

    def _register_fp32_parameters_hook(self) -> None:
        """
        Set modules to fp32 and register automatic conversion hooks in the forward pass.
        The fp32 modules are marked by set_fp32_attr_to_module(.)
        """
        dtype = torch.float32

        def to_fp32(x, dtype=dtype):
            if isinstance(x, Tensor) and x.dtype != dtype:
                return x.to(dtype)
            return x

        def _pre_forward_hook_for_fp32(model: nn.Module, inputs: tuple):  # pylint: disable=W0613
            assert isinstance(inputs, tuple)
            return tuple(map(to_fp32, inputs))

        def _post_forward_hook_for_fp32(
            model: nn.Module, inputs: tuple, outputs: Union[tuple, Tensor]
        ):  # pylint: disable=W0613
            assert isinstance(inputs, (tuple, Tensor))
            # cast the fp32 module's outputs back to the model dtype for downstream modules
            if isinstance(outputs, tuple):
                return tuple(to_fp32(output, dtype=self.dtype) for output in outputs)
            else:
                return to_fp32(outputs, dtype=self.dtype)

        # wrap a single Module in a list so ModuleList and Module share the same loop below
        if isinstance(self.model, nn.ModuleList):
            model = self.model
        else:
            model = [self.model]

        # collect all sub-modules (transformer/embedding/head/norm blocks)
        modules = []
        for _chunk in model:
            modules.extend([sub_module for _, sub_module in _chunk.named_modules()])

        # cast marked modules to fp32 and register the pre/post forward hooks
        for sub_module in modules:
            if module_has_fp32_attr(sub_module):
                sub_module.to(dtype)
                sub_module.register_forward_pre_hook(_pre_forward_hook_for_fp32)
                sub_module.register_forward_hook(_post_forward_hook_for_fp32)
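The hooks above rely on a standard PyTorch mechanism: a forward pre-hook may return replacement inputs, and a forward hook may return a replacement output. The standalone sketch below reproduces the same pattern with plain PyTorch, outside InternLM; the run_in_fp32 helper and the LayerNorm example are illustrative only, not code from this repository.

import torch
from torch import nn


def run_in_fp32(module: nn.Module, compute_dtype: torch.dtype) -> nn.Module:
    """Illustrative helper: keep `module` in fp32 while its neighbours run in `compute_dtype`."""
    module.to(torch.float32)

    def pre_hook(mod, inputs):
        # returning a tuple replaces the inputs seen by the module
        return tuple(
            x.to(torch.float32) if isinstance(x, torch.Tensor) else x for x in inputs
        )

    def post_hook(mod, inputs, output):
        # cast the result back so downstream low-precision modules are unaffected
        if isinstance(output, torch.Tensor):
            return output.to(compute_dtype)
        return tuple(o.to(compute_dtype) if isinstance(o, torch.Tensor) else o for o in output)

    module.register_forward_pre_hook(pre_hook)
    module.register_forward_hook(post_hook)
    return module


if __name__ == "__main__":
    net = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8), nn.Linear(8, 8)).to(torch.bfloat16)
    run_in_fp32(net[1], compute_dtype=torch.bfloat16)  # the LayerNorm computes in fp32

    out = net(torch.rand(2, 8, dtype=torch.bfloat16))
    print(out.dtype)  # torch.bfloat16 -- the fp32 norm is transparent to the rest of the net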

View File

@@ -31,6 +31,7 @@ from internlm.solver.beta2_scheduler import Beta2Scheduler
from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
from internlm.solver.optimizer import HybridZeroOptimizer
from internlm.solver.optimizer.utils import ParamBcastSyncHandler
from internlm.train.utils import create_param_groups
from internlm.utils.common import DummyProfile
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
@@ -109,8 +110,9 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
        param_bcast_sync_handler = None

    adam_cfg = gpc.config.adam
    params = create_param_groups(model, adam_cfg.weight_decay)
    naive_optimizer = torch.optim.AdamW(
        params=[{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}],
        params=params,
        lr=adam_cfg.lr,
        betas=(adam_cfg.adam_beta1, adam_cfg.adam_beta2),
        eps=adam_cfg.adam_eps,

internlm/train/utils.py (new file, 61 lines)
View File

@@ -0,0 +1,61 @@
from typing import Dict, Tuple

import torch


def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict]) -> Tuple[Dict]:
    """Split parameters into different groups for the optimizer.

    Compatible with multiple param groups; each group should have a name.

    Args:
        param_groups (Tuple[Dict]): The list of parameter groups to split.

    Input Example:
        >>> (
        >>>     {'name': 'default', 'params': [tensor], 'weight_decay': xxx},
        >>> )

    Returns:
        Tuple[Dict]: list of fp16/fp32 param groups for the optimizer.

    Output Example:
        >>> (
        >>>     {'name': 'default', 'params': [tensor], 'weight_decay': xxx},
        >>>     {'name': 'fp32', 'params': [tensor], 'weight_decay': xxx},
        >>>     ...,
        >>> )
    """

    if isinstance(param_groups, tuple):
        param_groups = list(param_groups)  # Tuple cannot be modified
    elif isinstance(param_groups, dict):
        param_groups = [param_groups]
    elif not isinstance(param_groups, list):
        raise ValueError(f"Unknown param group type of {type(param_groups)}")

    fp32_group = {"name": "fp32", "params": []}

    for pgroup in param_groups:
        # copy attributes from the origin group; we assume the input param_groups
        # only has one group, so each attribute is copied only once.
        for ori_key in pgroup.keys():
            if ori_key not in ("name", "params"):
                fp32_group[ori_key] = pgroup[ori_key]
        # assign params: fp32 params move to the new group, the rest stay
        origin_params = []
        for param in pgroup["params"]:
            if param.dtype == torch.float32:
                fp32_group["params"].append(param)
            else:
                origin_params.append(param)
        # bf16 param group, the first group in the param_groups
        pgroup["params"] = origin_params

    param_groups.append(fp32_group)

    return tuple(param_groups)


def create_param_groups(model, weight_decay):
    parameters = {"params": list(model.parameters()), "name": "default", "weight_decay": weight_decay}
    return split_params_into_different_groups_for_optimizer(parameters)
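As a usage sketch for the helper above: the TinyModel below and its dtypes are hypothetical, but create_param_groups and the resulting group layout come straight from this file.

import torch
from torch import nn

from internlm.train.utils import create_param_groups


class TinyModel(nn.Module):
    """Hypothetical model with one bf16 and one fp32 parameter tensor."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8, bias=False).to(torch.bfloat16)
        self.norm = nn.LayerNorm(8)  # stays in the default fp32


groups = create_param_groups(TinyModel(), weight_decay=0.05)

# groups[0]: {"name": "default", "params": [bf16 tensors], "weight_decay": 0.05}
# groups[1]: {"name": "fp32",    "params": [fp32 tensors], "weight_decay": 0.05}
for group in groups:
    print(group["name"], [p.dtype for p in group["params"]])

# the tuple of groups can be handed to the optimizer, as initialize_optimizer now does
optimizer = torch.optim.AdamW(params=list(groups), lr=1e-4)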

View File

@@ -0,0 +1,128 @@
import multiprocessing as mp
from functools import partial

import pytest
import torch
from torch import nn

from internlm.core.naive_amp import NaiveAMPModel, set_fp32_attr_to_module
from internlm.model.modeling_internlm import PackedFlashBaseLayer1D
from internlm.train.utils import create_param_groups
from tests.test_model.test_model_internlm import build_environment, seed_all


def _pre_forward_hook_for_check(model, inputs):  # pylint: disable=W0613
    assert all(_.dtype == torch.float32 for _ in inputs)


def _post_forward_hook_for_check(model, inputs, outputs):  # pylint: disable=W0613
    if isinstance(outputs, tuple):
        assert all(_.dtype == torch.half for _ in outputs)
    else:
        assert outputs.dtype == torch.half


def check_fused_precision(args):
    # init
    rank, world_size = args
    device = torch.device("cuda")
    build_environment(rank, world_size)

    # fix seed
    seed_all(1024)

    # define model
    model = PackedFlashBaseLayer1D(
        hidden_size=16,  # 768
        num_attention_heads=2,  # 12
        mlp_ratio=2,
        attn_drop_rate=0.0,
        drop_rate=0.0,
        dtype=torch.bfloat16,
        layer_norm_epsilon=1e-5,
        checkpoint=False,
        layer_idx=0,
        residual_in_fp32=False,
        device=device,
        norm_type="rmsnorm",
        dropout_selective_checkpoint=True,
        use_scaled_init=True,
        use_swiglu=True,
    )
    model = model.to(device)
    set_fp32_attr_to_module(model.norm1)

    model = NaiveAMPModel(
        model=model,
        output_to_fp32=True,
        dtype=torch.half,
        sync_buffer=False,
    )

    model.model.norm1.register_forward_pre_hook(partial(_pre_forward_hook_for_check))
    model.model.norm1.register_forward_hook(partial(_post_forward_hook_for_check))

    hidden_states = torch.rand(1, 1, 16).to(device).requires_grad_()

    # forward
    model(hidden_states)


class MlpModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(4, 1, bias=False).half()
        self.linear2 = nn.Linear(1, 4, bias=False).float()


def check_split_fused_group(args):
    # init
    rank, world_size = args
    device = torch.device("cuda")
    build_environment(rank, world_size)
    rtol, atol = (1e-3, 5e-3)

    # fix seed
    seed_all(1024)

    # define model
    model = MlpModel().to(device)
    groups = create_param_groups(model, weight_decay=0.05)

    standard_group = (
        {
            "name": "default",
            "params": [torch.Tensor([[0.3088, 0.2935, -0.2900, 0.4280]]).to(torch.float16).to(device).requires_grad_()],
            "weight_decay": 0.05,
        },
        {
            "name": "fp32",
            "params": [torch.Tensor([[0.6273], [0.4844], [-0.0463], [-0.0090]]).to(device).requires_grad_()],
            "weight_decay": 0.05,
        },
    )

    # check groups params
    for t1, t2 in zip(groups, standard_group):
        # assert t1["name"] == t2["name"]
        assert all(
            torch.allclose(p1, p2, rtol=rtol, atol=atol, equal_nan=True) for p1, p2 in zip(t1["params"], t2["params"])
        )


@pytest.mark.fused_precision
def test_fused_precision():
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=8) as pool:
        pool.map(check_fused_precision, [[rank, 8] for rank in range(8)])
        pool.close()
        pool.join()


@pytest.mark.split_groups
def test_split_fused_groups():
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=8) as pool:
        pool.map(check_split_fused_group, [[rank, 8] for rank in range(8)])
        pool.close()
        pool.join()


if __name__ == "__main__":
    pytest.main(["-s", "-q", "test_norm.py"])
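For reference, a hedged sketch of invoking these tests by their pytest marks; the file path is hypothetical, the mark names come from the decorators above, and the marks would normally be registered in the project's pytest configuration to avoid unknown-mark warnings.

# Hypothetical invocation by mark; adjust the path to wherever this test file lives.
import pytest

if __name__ == "__main__":
    pytest.main(["-s", "-q", "-m", "fused_precision or split_groups", "tests/"])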