mirror of https://github.com/InternLM/InternLM
Feat(norm)/support fused precision (#319)
* add fused precision support for norm
* refactor code
* refactor code
* change the granularity of hook
* fix bugs if self.model is ModuleList
* add dtype condition for post hook
* refactor code for split group
* refactor code for pre/post hook
* refactor code for split group
* remove fp32 hook for norm
* unit tests for fused precision
* add doc for fused precision
* add doc for En. version
* reformat docs
* Update mixed_precision.rst
* Update mixed_precision.po
* update mixed_precision.po

pull/373/head
parent 96b20cd43f
commit 655e9dae40
@@ -43,39 +43,42 @@ msgstr "Training API"

msgid "并行训练"
msgstr "Parallel Training"

#: ../../source/index.rst:51 9234725f3c464731993d73607608c874
#: ../../source/index.rst:51
msgid "混合精度"
msgstr "Mixed Precision"

#: ../../source/index.rst:59 9234725f3c464731993d73607608c874
msgid "模型备份"
msgstr "Model Checkpointing"

#: ../../source/index.rst:59 8e4ce037017f4510b2892a66003877fa
#: ../../source/index.rst:67 8e4ce037017f4510b2892a66003877fa
msgid "性能分析"
msgstr "Profiler"

#: ../../source/index.rst:67 a36e02819ecd4b448a8cb4ebbecb6600
#: ../../source/index.rst:75 a36e02819ecd4b448a8cb4ebbecb6600
msgid "训练监控"
msgstr "Monitor"

#: ../../source/index.rst:75 b912e292486f455c8b5cdd75962e8ac2
#: ../../source/index.rst:83 b912e292486f455c8b5cdd75962e8ac2
msgid "训练样例"
msgstr "Example"

#: ../../source/index.rst:83 ea9e9281720941a1830e5df7a2badf7a
#: ../../source/index.rst:91 ea9e9281720941a1830e5df7a2badf7a
msgid "常见问题"
msgstr "Q&A"

#: ../../source/index.rst:91 e08edc5aa1c74965b10084b393b88fae
#: ../../source/index.rst:99 e08edc5aa1c74965b10084b393b88fae
msgid "索引和表格"
msgstr "Indices and tables"

#: ../../source/index.rst:93 f3fdca059caa49dcad09aa44be7f02d6
#: ../../source/index.rst:101 f3fdca059caa49dcad09aa44be7f02d6
msgid ":ref:`genindex`"
msgstr ""

#: ../../source/index.rst:94 b3791e811315435097bb507edc3f4b9b
#: ../../source/index.rst:102 b3791e811315435097bb507edc3f4b9b
msgid ":ref:`modindex`"
msgstr ""

#: ../../source/index.rst:95 a164b772960f4ab8b18c7e8820f69f55
#: ../../source/index.rst:103 a164b772960f4ab8b18c7e8820f69f55
msgid ":ref:`search`"
msgstr ""
@@ -0,0 +1,85 @@
# SOME DESCRIPTIVE TITLE.
# Copyright (C) 2023, InternLM Team
# This file is distributed under the same license as the InternLM package.
# FIRST AUTHOR <EMAIL@ADDRESS>, 2023.
#
#, fuzzy
msgid ""
msgstr ""
"Project-Id-Version: InternLM \n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2023-09-26 17:04+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: en\n"
"Language-Team: en <LL@li.org>\n"
"Plural-Forms: nplurals=2; plural=(n != 1);\n"
"MIME-Version: 1.0\n"
"Content-Type: text/plain; charset=utf-8\n"
"Content-Transfer-Encoding: 8bit\n"
"Generated-By: Babel 2.12.1\n"

#: ../../source/mixed_precision.rst:2
msgid "混合精度"
msgstr "Mixed Precision"

#: ../../source/mixed_precision.rst:3
msgid ""
"混合精度是指在模型训练的过程中同时使用16位和32位浮点数类型,是一种在最小化精度损失的前提下加速模型训练的方法。 "
"混合精度通过让模型的某些部分使用32位浮点数以保持数值稳定性,并在其余部分利用半精度浮点数加速训练并可以减少内存使用,在评估指标(如准确率)方面仍可以获得同等的训练效果。"
msgstr ""
"Mixed precision refers to using both 16-bit and 32-bit floating-point "
"types to train a model, which accelerates training while minimizing the "
"loss in accuracy. Mixed precision training uses 32-bit floating-point "
"types in certain parts of the model to maintain numerical stability, and "
"accelerates training and reduces memory usage by using 16-bit "
"floating-point types in the remaining parts, while still achieving "
"comparable results on evaluation metrics such as accuracy."

#: internlm.core.naive_amp.NaiveAMPModel:1 of
msgid ""
"This is a wrapper class for a model that automatically casts the model, "
"its inputs, and outputs into fp16. It also provides options to cast the "
"output back to fp32 and to synchronize buffers."
msgstr ""

#: internlm.core.naive_amp.NaiveAMPModel of
msgid "参数"
msgstr ""

#: internlm.core.naive_amp.NaiveAMPModel:4 of
msgid "The model to be wrapped and cast into fp16."
msgstr ""

#: internlm.core.naive_amp.NaiveAMPModel:6 of
msgid "If True, the output of this module is cast into fp32. Defaults to True."
msgstr ""

#: internlm.core.naive_amp.NaiveAMPModel:8 of
msgid ""
"The parallel group mode used in this module. Defaults to "
"``ParallelMode.DATA``."
msgstr ""

#: internlm.core.naive_amp.NaiveAMPModel:11 of
msgid "If True, the buffers are synchronized. Defaults to True."
msgstr ""

#: ../../source/mixed_precision.rst:8
msgid "InternLM默认将模型转换为16位浮点数类型进行训练(在配置文件中可以设置默认类型为其他数据类型)。在使用混合精度时,需要在构建模型时使用"
msgstr ""
"InternLM converts the model to 16-bit floating-point types for training "
"by default (the default type can be set to other data types in the "
"configuration file). When using mixed precision, it is necessary to use "

#: ../../source/mixed_precision.rst:14
msgid "将模型的某个子模块设置为32位浮点数类型进行训练,InternLM会在模型训练时自动将数据类型转换成需要的精度。"
msgstr ""
"to set a sub-module of the model to 32-bit floating-point types for "
"training, and InternLM will automatically convert the data type to the "
"required precision during model training."

#: ../../source/mixed_precision.rst:16
msgid "例如:"
msgstr "For example:"
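The stability point above can be made concrete with a couple of lines of plain PyTorch: fp16 has a narrow range (max normal value ~65504) and coarse resolution near 1.0, which is exactly what keeping sensitive sub-modules in fp32 avoids. This is an illustration only, not code from this commit:

import torch

x = torch.tensor(65504.0, dtype=torch.float16)  # fp16 max normal value
print(x * 2)           # tensor(inf, dtype=torch.float16): overflows in fp16
print(x.float() * 2)   # tensor(131008.): fine in fp32

one = torch.tensor(1.0, dtype=torch.float16)
print(one + 1e-4)           # tensor(1., dtype=torch.float16): the small update is lost
print(one.float() + 1e-4)   # tensor(1.0001): preserved in fp32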
@@ -47,6 +47,14 @@ InternLM

   parallel

混合精度
-------------------

.. toctree::
   :maxdepth: 2

   mixed_precision

模型备份
--------------------
@@ -0,0 +1,36 @@
混合精度
-----------------
混合精度是指在模型训练的过程中同时使用16位和32位浮点数类型,是一种在最小化精度损失的前提下加速模型训练的方法。
混合精度通过让模型的某些部分使用32位浮点数以保持数值稳定性,并在其余部分利用半精度浮点数加速训练并可以减少内存使用,在评估指标(如准确率)方面仍可以获得同等的训练效果。

.. autoclass:: internlm.core.naive_amp.NaiveAMPModel

InternLM默认将模型转换为16位浮点数类型进行训练(在配置文件中可以设置默认类型为其他数据类型)。在使用混合精度时,需要在构建模型时使用

.. code-block:: python

    set_fp32_attr_to_module(module)  # module: 需要以 fp32 运行的子模块

将模型的某个子模块设置为32位浮点数类型进行训练,InternLM会在模型训练时自动将数据类型转换成需要的精度。

例如:

.. code-block:: python

    class MlpModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = nn.Linear(4, 1, bias=False)
            self.linear2 = nn.Linear(1, 4, bias=False)

    model = MlpModel()
    # set model.linear2 as fp32 module
    set_fp32_attr_to_module(model.linear2)

    # apply mixed precision
    model = NaiveAMPModel(
        model=model,
        output_to_fp32=True,
        dtype=torch.bfloat16,
        sync_buffer=False,
    )
@@ -133,7 +133,7 @@ ZeRO1.5 的实现使用了分层分片的概念,通过配置值 ``parallel.zer

    hybrid_zero_optimizer = dict(
        # Enable low_level_optimzer overlap_communication
        overlap_sync_grad=True,
        overlap_sync_grad=True,
        overlap_sync_param=True,
        # bucket size for nccl communication params
        reduce_bucket_size=512 * 1024 * 1024,
@@ -3,7 +3,8 @@

# adopted from https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/amp

from typing import Any
from functools import partial
from typing import Any, Union

import torch
import torch.distributed as dist
@@ -15,6 +16,14 @@ from internlm.core.context import ParallelMode
from internlm.core.context.parallel_context import global_context as gpc


def set_fp32_attr_to_module(module: nn.Module):
    setattr(module, "is_fp32_module", True)


def module_has_fp32_attr(module: nn.Module):
    return hasattr(module, "is_fp32_module") and getattr(module, "is_fp32_module")


class NaiveAMPModel(nn.Module):
    """
    This is a wrapper class for a model that automatically casts the model, its inputs, and outputs into fp16.
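The two helpers above form a simple marker protocol: ``set_fp32_attr_to_module`` flags a sub-module, and the wrapper later queries that flag. A minimal standalone check of the contract (illustrative only):

from torch import nn

layer = nn.Linear(4, 4)
assert not module_has_fp32_attr(layer)  # unmarked modules are left alone
set_fp32_attr_to_module(layer)
assert module_has_fp32_attr(layer)      # marked modules will be pinned to fp32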
@@ -51,6 +60,9 @@ class NaiveAMPModel(nn.Module):
        self._sync_buf = False
        self._first_eval_run = False

        # register hook for fp32 module
        self._register_fp32_parameters_hook()

    @property
    def sync_buffer(self):
        """Returns the current state of the buffer synchronization."""
@@ -134,3 +146,46 @@
        if self._output_to_fp32:
            out = self.convert_to_fp32(out)
        return out

    def _register_fp32_parameters_hook(self) -> None:
        """
        Set module to fp32 and register automatic conversion hook in the forward pass.
        The fp32 modules are marked by set_fp32_attr_to_module(.)
        """
        dtype = torch.float32

        def to_fp32(x, dtype=dtype):
            if isinstance(x, Tensor) and x.dtype != dtype:
                return x.to(dtype)
            return x

        def _pre_forward_hook_for_fp32(model: nn.Module, inputs: tuple):  # pylint: disable=W0613
            assert isinstance(inputs, tuple)
            # cast all tensor inputs of the fp32 module up to fp32
            return tuple(map(to_fp32, inputs))

        def _post_forward_hook_for_fp32(
            model: nn.Module, inputs: tuple, outputs: Union[tuple, Tensor]
        ):  # pylint: disable=W0613
            assert isinstance(outputs, (tuple, Tensor))
            # cast the fp32 module's outputs back to the wrapper's dtype (fp16/bf16)
            if isinstance(outputs, tuple):
                return tuple(to_fp32(output, self.dtype) for output in outputs)
            else:
                return to_fp32(outputs, self.dtype)

        # just want to share the same for loop for ModuleList and Module
        if isinstance(self.model, nn.ModuleList):
            model = self.model
        else:
            model = [self.model]

        modules = []
        # record the sub-modules of each transformer/embedding/head/norm block
        for _chunk in model:
            modules.extend(sub_module for _, sub_module in _chunk.named_modules())

        # register forward pre/post hooks for every module marked as fp32
        for sub_module in modules:
            if module_has_fp32_attr(sub_module):
                sub_module.to(dtype)
                sub_module.register_forward_pre_hook(_pre_forward_hook_for_fp32)
                sub_module.register_forward_hook(_post_forward_hook_for_fp32)
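The hook pattern registered above can be reproduced on a plain ``nn.Module``. A standalone sketch under assumed names (``demo_fp32_island`` and the LayerNorm layer are illustrative, not InternLM APIs): run one marked sub-module in fp32 inside an otherwise bf16 model.

import torch
from torch import nn

def demo_fp32_island():
    layer = nn.LayerNorm(8).float()  # keep this layer's params in fp32

    def pre_hook(module, inputs):
        # cast incoming activations up to fp32 before the layer runs
        return tuple(x.float() if torch.is_tensor(x) else x for x in inputs)

    def post_hook(module, inputs, output):
        # cast the result back down so the rest of the network stays bf16
        return output.to(torch.bfloat16)

    layer.register_forward_pre_hook(pre_hook)
    layer.register_forward_hook(post_hook)

    x = torch.randn(2, 8, dtype=torch.bfloat16)
    out = layer(x)
    assert out.dtype == torch.bfloat16  # output restored to half precision

demo_fp32_island()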
@@ -31,6 +31,7 @@ from internlm.solver.beta2_scheduler import Beta2Scheduler
from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
from internlm.solver.optimizer import HybridZeroOptimizer
from internlm.solver.optimizer.utils import ParamBcastSyncHandler
from internlm.train.utils import create_param_groups
from internlm.utils.common import DummyProfile
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
@@ -109,8 +110,9 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
        param_bcast_sync_handler = None

    adam_cfg = gpc.config.adam
    params = create_param_groups(model, adam_cfg.weight_decay)
    naive_optimizer = torch.optim.AdamW(
        params=[{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}],
        params=params,
        lr=adam_cfg.lr,
        betas=(adam_cfg.adam_beta1, adam_cfg.adam_beta2),
        eps=adam_cfg.adam_eps,
@@ -0,0 +1,61 @@
from typing import Dict, Tuple

import torch


def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict]) -> Tuple[Dict]:
    """Split parameters into different groups for the optimizer.
    Compatible with multiple param groups; each should have a name.

    Args:
        param_groups (Tuple[Dict]): The list of parameter groups to split.
        Input Example:
        >>> (
        >>>     {'name': 'default', 'params': [tensor], 'weight_decay': xxx},
        >>> )

    Returns:
        Tuple[Dict]: list of fp16/fp32 param groups for the optimizer.
        Output Example:
        >>> (
        >>>     {'name': 'default', 'params': [tensor], 'weight_decay': xxx},
        >>>     {'name': 'fp32', 'params': [tensor], 'weight_decay': xxx},
        >>>     ...,
        >>> )
    """
    if isinstance(param_groups, tuple):
        param_groups = list(param_groups)  # Tuple cannot be modified
    elif isinstance(param_groups, dict):
        param_groups = [param_groups]
    elif not isinstance(param_groups, list):
        raise ValueError(f"Unknown param group type of {type(param_groups)}")

    fp32_group = {"name": "fp32", "params": []}
    for pgroup in param_groups:
        # copy attributes from the origin group; we assume the input param_groups
        # only have one group, so the attributes will not be copied multiple times.
        for ori_key in pgroup.keys():
            if ori_key not in ("name", "params"):
                fp32_group[ori_key] = pgroup[ori_key]
        # assign params: fp32 params move to the fp32 group, the rest stay put
        origin_params = []
        for param in pgroup["params"]:
            if param.dtype == torch.float32:
                fp32_group["params"].append(param)
            else:
                origin_params.append(param)
        # bf16 param group, the first group in the param_groups
        pgroup["params"] = origin_params

    param_groups.append(fp32_group)

    return tuple(param_groups)


def create_param_groups(model, weight_decay):
    parameters = {"params": list(model.parameters()), "name": "default", "weight_decay": weight_decay}

    return split_params_into_different_groups_for_optimizer(parameters)
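For what the new helper produces, a short usage sketch (assumes internlm is importable; ``MixedMlp`` is an illustrative module that mirrors the ``MlpModel`` used in the unit test below):

import torch
from torch import nn

from internlm.train.utils import create_param_groups

class MixedMlp(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(4, 1, bias=False).half()   # fp16 params
        self.linear2 = nn.Linear(1, 4, bias=False).float()  # fp32 params

groups = create_param_groups(MixedMlp(), weight_decay=0.05)
# groups[0]: {'name': 'default', 'params': [fp16 tensors], 'weight_decay': 0.05}
# groups[1]: {'name': 'fp32',    'params': [fp32 tensors], 'weight_decay': 0.05}
assert all(p.dtype != torch.float32 for p in groups[0]["params"])
assert all(p.dtype == torch.float32 for p in groups[1]["params"])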
@@ -0,0 +1,128 @@
import multiprocessing as mp
from functools import partial

import pytest
import torch
from torch import nn

from internlm.core.naive_amp import NaiveAMPModel, set_fp32_attr_to_module
from internlm.model.modeling_internlm import PackedFlashBaseLayer1D
from internlm.train.utils import create_param_groups
from tests.test_model.test_model_internlm import build_environment, seed_all


def _pre_forward_hook_for_check(model, inputs):  # pylint: disable=W0613
    # inputs to the fp32 module should have been cast up to fp32
    assert all(_.dtype == torch.float32 for _ in inputs)


def _post_forward_hook_for_check(model, inputs, outputs):  # pylint: disable=W0613
    # outputs should have been cast back to the wrapper's dtype (fp16 here)
    if isinstance(outputs, tuple):
        assert all(_.dtype == torch.half for _ in outputs)
    else:
        assert outputs.dtype == torch.half


def check_fused_precision(args):
    # init
    rank, world_size = args
    device = torch.device("cuda")
    build_environment(rank, world_size)

    # fix seed
    seed_all(1024)
    # define model
    model = PackedFlashBaseLayer1D(
        hidden_size=16,  # 768
        num_attention_heads=2,  # 12
        mlp_ratio=2,
        attn_drop_rate=0.0,
        drop_rate=0.0,
        dtype=torch.bfloat16,
        layer_norm_epsilon=1e-5,
        checkpoint=False,
        layer_idx=0,
        residual_in_fp32=False,
        device=device,
        norm_type="rmsnorm",
        dropout_selective_checkpoint=True,
        use_scaled_init=True,
        use_swiglu=True,
    )
    model = model.to(device)
    set_fp32_attr_to_module(model.norm1)
    model = NaiveAMPModel(
        model=model,
        output_to_fp32=True,
        dtype=torch.half,
        sync_buffer=False,
    )
    model.model.norm1.register_forward_pre_hook(_pre_forward_hook_for_check)
    model.model.norm1.register_forward_hook(_post_forward_hook_for_check)

    hidden_states = torch.rand(1, 1, 16).to(device).requires_grad_()

    # forward
    model(hidden_states)


class MlpModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(4, 1, bias=False).half()
        self.linear2 = nn.Linear(1, 4, bias=False).float()


def check_split_fused_group(args):
    # init
    rank, world_size = args
    device = torch.device("cuda")
    build_environment(rank, world_size)
    rtol, atol = (1e-3, 5e-3)

    # fix seed
    seed_all(1024)
    # define model
    model = MlpModel().to(device)
    groups = create_param_groups(model, weight_decay=0.05)

    standard_group = (
        {
            "name": "default",
            "params": [torch.Tensor([[0.3088, 0.2935, -0.2900, 0.4280]]).to(torch.float16).to(device).requires_grad_()],
            "weight_decay": 0.05,
        },
        {
            "name": "fp32",
            "params": [torch.Tensor([[0.6273], [0.4844], [-0.0463], [-0.0090]]).to(device).requires_grad_()],
            "weight_decay": 0.05,
        },
    )

    # check group names and params
    for t1, t2 in zip(groups, standard_group):
        assert t1["name"] == t2["name"]
        assert all(
            torch.allclose(p1, p2, rtol=rtol, atol=atol, equal_nan=True) for p1, p2 in zip(t1["params"], t2["params"])
        )


@pytest.mark.fused_precision
def test_fused_precision():
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=8) as pool:
        pool.map(check_fused_precision, [[rank, 8] for rank in range(8)])
        pool.close()
        pool.join()


@pytest.mark.split_groups
def test_split_fused_groups():
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=8) as pool:
        pool.map(check_split_fused_group, [[rank, 8] for rank in range(8)])
        pool.close()
        pool.join()


if __name__ == "__main__":
    pytest.main(["-s", "-q", "test_norm.py"])