From 655e9dae40ba20df20db61bfb89cdae742f1a5f8 Mon Sep 17 00:00:00 2001
From: Wenwen Qu
Date: Tue, 26 Sep 2023 20:39:55 +0800
Subject: [PATCH] Feat(norm)/support fused precision (#319)

* add fused precision support for norm
* refactor code
* refactor code
* change the granularity of hook
* fix bugs if self.model is ModuleList
* add dtype condition for post hook
* refactor code for split group
* refactor code for pre/post hook
* refactor code for split group
* remove fp32 hook for norm
* unit tests for fused precision
* add doc for fused precision
* add doc for En. version
* reformat docs
* Update mixed_precision.rst
* Update mixed_precision.po
* update mixed_precision.po
---
 doc/code-docs/locales/en/LC_MESSAGES/index.po |  23 ++--
 .../locales/en/LC_MESSAGES/mixed_precision.po |  85 ++++++++++++
 doc/code-docs/source/index.rst                |   8 ++
 doc/code-docs/source/mixed_precision.rst      |  36 +++++
 doc/code-docs/source/parallel.rst             |   2 +-
 internlm/core/naive_amp.py                    |  57 +++++++-
 internlm/train/training_internlm.py           |   4 +-
 internlm/train/utils.py                       |  61 +++++++++
 .../test_fused_precision.py                   | 128 ++++++++++++++++++
 9 files changed, 391 insertions(+), 13 deletions(-)
 create mode 100644 doc/code-docs/locales/en/LC_MESSAGES/mixed_precision.po
 create mode 100644 doc/code-docs/source/mixed_precision.rst
 create mode 100644 internlm/train/utils.py
 create mode 100644 tests/test_model/test_fused_precision/test_fused_precision.py

diff --git a/doc/code-docs/locales/en/LC_MESSAGES/index.po b/doc/code-docs/locales/en/LC_MESSAGES/index.po
index 25645c6..7d0c4ec 100644
--- a/doc/code-docs/locales/en/LC_MESSAGES/index.po
+++ b/doc/code-docs/locales/en/LC_MESSAGES/index.po
@@ -43,39 +43,42 @@ msgstr "Training API"
 msgid "并行训练"
 msgstr "Parallel Training"
 
-#: ../../source/index.rst:51 9234725f3c464731993d73607608c874
+#: ../../source/index.rst:51
+msgid "混合精度"
+msgstr "Mixed Precision"
+
+#: ../../source/index.rst:59 9234725f3c464731993d73607608c874
 msgid "模型备份"
 msgstr "Model Checkpointing"
 
-#: ../../source/index.rst:59 8e4ce037017f4510b2892a66003877fa
+#: ../../source/index.rst:67 8e4ce037017f4510b2892a66003877fa
 msgid "性能分析"
 msgstr "Profiler"
 
-#: ../../source/index.rst:67 a36e02819ecd4b448a8cb4ebbecb6600
+#: ../../source/index.rst:75 a36e02819ecd4b448a8cb4ebbecb6600
 msgid "训练监控"
 msgstr "Monitor"
 
-#: ../../source/index.rst:75 b912e292486f455c8b5cdd75962e8ac2
+#: ../../source/index.rst:83 b912e292486f455c8b5cdd75962e8ac2
 msgid "训练样例"
 msgstr "Example"
 
-#: ../../source/index.rst:83 ea9e9281720941a1830e5df7a2badf7a
+#: ../../source/index.rst:91 ea9e9281720941a1830e5df7a2badf7a
 msgid "常见问题"
 msgstr "Q&A"
 
-#: ../../source/index.rst:91 e08edc5aa1c74965b10084b393b88fae
+#: ../../source/index.rst:99 e08edc5aa1c74965b10084b393b88fae
 msgid "索引和表格"
 msgstr "Indices and tables"
 
-#: ../../source/index.rst:93 f3fdca059caa49dcad09aa44be7f02d6
+#: ../../source/index.rst:101 f3fdca059caa49dcad09aa44be7f02d6
 msgid ":ref:`genindex`"
 msgstr ""
 
-#: ../../source/index.rst:94 b3791e811315435097bb507edc3f4b9b
+#: ../../source/index.rst:102 b3791e811315435097bb507edc3f4b9b
 msgid ":ref:`modindex`"
 msgstr ""
 
-#: ../../source/index.rst:95 a164b772960f4ab8b18c7e8820f69f55
+#: ../../source/index.rst:103 a164b772960f4ab8b18c7e8820f69f55
 msgid ":ref:`search`"
 msgstr ""
-
diff --git a/doc/code-docs/locales/en/LC_MESSAGES/mixed_precision.po b/doc/code-docs/locales/en/LC_MESSAGES/mixed_precision.po
new file mode 100644
index 0000000..2520d1c
--- /dev/null
+++ b/doc/code-docs/locales/en/LC_MESSAGES/mixed_precision.po
@@ -0,0 +1,85 @@
+# SOME DESCRIPTIVE TITLE.
+# Copyright (C) 2023, InternLM Team
+# This file is distributed under the same license as the InternLM package.
+# FIRST AUTHOR , 2023.
+#
+#, fuzzy
+msgid ""
+msgstr ""
+"Project-Id-Version: InternLM \n"
+"Report-Msgid-Bugs-To: \n"
+"POT-Creation-Date: 2023-09-26 17:04+0800\n"
+"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
+"Last-Translator: FULL NAME \n"
+"Language: en\n"
+"Language-Team: en \n"
+"Plural-Forms: nplurals=2; plural=(n != 1);\n"
+"MIME-Version: 1.0\n"
+"Content-Type: text/plain; charset=utf-8\n"
+"Content-Transfer-Encoding: 8bit\n"
+"Generated-By: Babel 2.12.1\n"
+
+#: ../../source/mixed_precision.rst:2
+msgid "混合精度"
+msgstr "Mixed Precision"
+
+#: ../../source/mixed_precision.rst:3
+msgid ""
+"混合精度是指在模型训练的过程中同时使用16位和32位浮点数类型,是一种在最小化精度损失的前提下加速模型训练的方法。 "
+"混合精度通过让模型的某些部分使用32位浮点数以保持数值稳定性,并在其余部分利用半精度浮点数加速训练并可以减少内存使用,在评估指标(如准确率)方面仍可以获得同等的训练效果。"
+msgstr ""
+"Mixed precision refers to using both 16-bit and 32-bit floating-point "
+"types to train a model, which accelerates training while minimizing the "
+"loss in accuracy. Mixed precision training keeps certain parts of the "
+"model in 32-bit floating point to maintain numerical stability, and uses "
+"16-bit floating point in the remaining parts to speed up training and "
+"reduce memory usage, while still achieving the same results on "
+"evaluation metrics such as accuracy."
+
+#: internlm.core.naive_amp.NaiveAMPModel:1 of
+msgid ""
+"This is a wrapper class for a model that automatically casts the model, "
+"its inputs, and outputs into fp16. It also provides options to cast the "
+"output back to fp32 and to synchronize buffers."
+msgstr ""
+
+#: internlm.core.naive_amp.NaiveAMPModel of
+msgid "参数"
+msgstr ""
+
+#: internlm.core.naive_amp.NaiveAMPModel:4 of
+msgid "The model to be wrapped and cast into fp16."
+msgstr ""
+
+#: internlm.core.naive_amp.NaiveAMPModel:6 of
+msgid "If True, the output of this module is cast into fp32. Defaults to True."
+msgstr ""
+
+#: internlm.core.naive_amp.NaiveAMPModel:8 of
+msgid ""
+"The parallel group mode used in this module. Defaults to "
+"``ParallelMode.DATA``."
+msgstr ""
+
+#: internlm.core.naive_amp.NaiveAMPModel:11 of
+msgid "If True, the buffers are synchronized. Defaults to True."
+msgstr ""
+
+#: ../../source/mixed_precision.rst:8
+msgid "InternLM默认将模型转换为16位浮点数类型进行训练(在配置文件中可以设置默认类型为其他数据类型)。在使用混合精度时,需要在构建模型时使用"
+msgstr ""
+"InternLM converts the model to 16-bit floating-point types for model "
+"training by default (the default type can be set to other data types in "
+"the configuration file). When using mixed precision, it is necessary to "
+"use "
+
+#: ../../source/mixed_precision.rst:14
+msgid "将模型的某个子模块设置为32位浮点数类型进行训练,InternLM会在模型训练时自动将数据类型转换成需要的精度。"
+msgstr ""
+"to set a sub-module of the model to 32-bit floating-point types for "
+"training, and InternLM will automatically convert the data type to the "
+"required precision during model training."
+
+#: ../../source/mixed_precision.rst:16
+msgid "例如:"
+msgstr "For example:"
diff --git a/doc/code-docs/source/index.rst b/doc/code-docs/source/index.rst
index c01ac54..8811af2 100644
--- a/doc/code-docs/source/index.rst
+++ b/doc/code-docs/source/index.rst
@@ -47,6 +47,14 @@ InternLM
 
    parallel
 
+混合精度
+-------------------
+
+.. toctree::
+   :maxdepth: 2
+
+   mixed_precision
+
 模型备份
 --------------------
 
diff --git a/doc/code-docs/source/mixed_precision.rst b/doc/code-docs/source/mixed_precision.rst
new file mode 100644
index 0000000..59955e0
--- /dev/null
+++ b/doc/code-docs/source/mixed_precision.rst
@@ -0,0 +1,36 @@
+混合精度
+-----------------
+混合精度是指在模型训练的过程中同时使用16位和32位浮点数类型,是一种在最小化精度损失的前提下加速模型训练的方法。
+混合精度通过让模型的某些部分使用32位浮点数以保持数值稳定性,并在其余部分利用半精度浮点数加速训练并可以减少内存使用,在评估指标(如准确率)方面仍可以获得同等的训练效果。
+
+.. autoclass:: internlm.core.naive_amp.NaiveAMPModel
+
+InternLM默认将模型转换为16位浮点数类型进行训练(在配置文件中可以设置默认类型为其他数据类型)。在使用混合精度时,需要在构建模型时使用
+
+.. code-block:: python
+
+    set_fp32_attr_to_module(module)  # module: 需要以fp32运行的子模块
+
+将模型的某个子模块设置为32位浮点数类型进行训练,InternLM会在模型训练时自动将数据类型转换成需要的精度。
+
+例如:
+
+.. code-block:: python
+
+    class MlpModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear1 = nn.Linear(4, 1, bias=False)
+            self.linear2 = nn.Linear(1, 4, bias=False)
+
+    model = MlpModel()
+    # set model.linear2 as fp32 module
+    set_fp32_attr_to_module(model.linear2)
+
+    # apply mixed precision
+    model = NaiveAMPModel(
+        model=model,
+        output_to_fp32=True,
+        dtype=torch.bfloat16,
+        sync_buffer=False,
+    )
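For reference, the example in the new document can also be written as a self-contained script. The sketch below is an illustration only, not part of the patch: it assumes ``set_fp32_attr_to_module`` and ``NaiveAMPModel`` are imported from ``internlm.core.naive_amp`` (the new test file below imports them from there) and that the usual InternLM distributed context is initialized before the wrapper is built or run, as the test does via ``build_environment``. The ``forward`` method is added only to make the toy model complete.

.. code-block:: python

    import torch
    from torch import nn

    from internlm.core.naive_amp import NaiveAMPModel, set_fp32_attr_to_module


    class MlpModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = nn.Linear(4, 1, bias=False)
            self.linear2 = nn.Linear(1, 4, bias=False)

        def forward(self, x):
            return self.linear2(self.linear1(x))


    model = MlpModel()
    # mark linear2 so it keeps running in fp32
    set_fp32_attr_to_module(model.linear2)

    # wrap the model: everything else is cast to bfloat16,
    # and the final output is cast back to fp32
    model = NaiveAMPModel(
        model=model,
        output_to_fp32=True,
        dtype=torch.bfloat16,
        sync_buffer=False,
    )

    # with the InternLM environment initialized, a forward pass now runs
    # linear2 in fp32 and the rest of the model in bfloat16:
    # out = model(torch.rand(2, 4, dtype=torch.bfloat16))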
diff --git a/doc/code-docs/source/parallel.rst b/doc/code-docs/source/parallel.rst
index 5f593c0..6de9545 100644
--- a/doc/code-docs/source/parallel.rst
+++ b/doc/code-docs/source/parallel.rst
@@ -133,7 +133,7 @@ ZeRO1.5 的实现使用了分层分片的概念,通过配置值 ``parallel.zer
 
     hybrid_zero_optimizer = dict(
         # Enable low_level_optimzer overlap_communication
-        overlap_sync_grad=True,
+        overlap_sync_grad=True, overlap_sync_param=True,
         # bucket size for nccl communication params
         reduce_bucket_size=512 * 1024 * 1024,
diff --git a/internlm/core/naive_amp.py b/internlm/core/naive_amp.py
index 7470659..e02d151 100644
--- a/internlm/core/naive_amp.py
+++ b/internlm/core/naive_amp.py
@@ -3,7 +3,8 @@
 
 # adopted from https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/amp
 
-from typing import Any
+from functools import partial
+from typing import Any, Union
 
 import torch
 import torch.distributed as dist
@@ -15,6 +16,14 @@ from internlm.core.context import ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
 
 
+def set_fp32_attr_to_module(module: nn.Module):
+    setattr(module, "is_fp32_module", True)
+
+
+def module_has_fp32_attr(module: nn.Module):
+    return hasattr(module, "is_fp32_module") and getattr(module, "is_fp32_module")
+
+
 class NaiveAMPModel(nn.Module):
     """
     This is a wrapper class for a model that automatically casts the model, its inputs, and outputs into fp16.
@@ -51,6 +60,9 @@ class NaiveAMPModel(nn.Module):
         self._sync_buf = False
         self._first_eval_run = False
 
+        # register hooks for modules marked as fp32
+        self._register_fp32_parameters_hook()
+
     @property
     def sync_buffer(self):
         """Returns the current state of the buffer synchronization."""
@@ -134,3 +146,46 @@ class NaiveAMPModel(nn.Module):
         if self._output_to_fp32:
             out = self.convert_to_fp32(out)
         return out
+
+    def _register_fp32_parameters_hook(self) -> None:
+        """
+        Cast marked modules to fp32 and register automatic conversion hooks for the forward pass.
+        The fp32 modules are marked by set_fp32_attr_to_module().
+        """
+        dtype = torch.float32
+
+        def to_fp32(x, dtype=dtype):
+            if isinstance(x, Tensor) and x.dtype != dtype:
+                return x.to(dtype)
+            return x
+
+        def _pre_forward_hook_for_fp32(model: nn.Module, inputs: tuple):  # pylint: disable=W0613
+            assert isinstance(inputs, tuple)
+            return tuple(map(to_fp32, inputs))
+
+        def _post_forward_hook_for_fp32(
+            model: nn.Module, inputs: tuple, outputs: Union[tuple, Tensor]
+        ):  # pylint: disable=W0613
+            assert isinstance(inputs, (tuple, Tensor))
+            if isinstance(outputs, tuple):
+                return tuple(to_fp32(output, dtype=self.dtype) for output in outputs)
+            else:
+                return to_fp32(outputs, dtype=self.dtype)
+
+        # share the same loop for both nn.ModuleList and a single nn.Module
+        if isinstance(self.model, nn.ModuleList):
+            model = self.model
+        else:
+            model = [self.model]
+
+        modules = []
+        # collect all sub-modules (transformer/embedding/head/norm blocks)
+        for _chunk in model:
+            modules.extend([sub_module for _, sub_module in _chunk.named_modules()])
+
+        # cast the marked fp32 modules and register forward pre/post hooks for them
+        for sub_module in modules:
+            if module_has_fp32_attr(sub_module):
+                sub_module.to(dtype)
+                sub_module.register_forward_pre_hook(partial(_pre_forward_hook_for_fp32))
+                sub_module.register_forward_hook(partial(_post_forward_hook_for_fp32))
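The registration above relies on standard PyTorch forward hooks: the pre-hook promotes the marked sub-module's inputs to fp32, and the post-hook demotes its outputs back to the wrapper's low-precision dtype, so the fp32 "island" is invisible to the rest of the model. The snippet below is a standalone illustration of that mechanism using only plain PyTorch; the module layout and dtype choices are illustrative assumptions, not InternLM code.

.. code-block:: python

    import torch
    from torch import nn

    low_dtype = torch.bfloat16

    def cast_inputs_to_fp32(module, inputs):
        # runs before forward: promote the marked module's inputs to fp32
        return tuple(x.float() if isinstance(x, torch.Tensor) else x for x in inputs)

    def cast_outputs_back(module, inputs, outputs):
        # runs after forward: demote outputs back to the training dtype
        if isinstance(outputs, torch.Tensor):
            return outputs.to(low_dtype)
        return tuple(o.to(low_dtype) if isinstance(o, torch.Tensor) else o for o in outputs)

    model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8), nn.Linear(8, 8)).to(low_dtype)

    # keep the LayerNorm in fp32, as set_fp32_attr_to_module + NaiveAMPModel do for marked modules
    norm = model[1]
    norm.float()
    norm.register_forward_pre_hook(cast_inputs_to_fp32)
    norm.register_forward_hook(cast_outputs_back)

    out = model(torch.randn(2, 8, dtype=low_dtype))
    print(out.dtype)  # torch.bfloat16: surrounding layers never see the fp32 sub-module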
diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index e08d4ec..883129d 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -31,6 +31,7 @@ from internlm.solver.beta2_scheduler import Beta2Scheduler
 from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
 from internlm.solver.optimizer import HybridZeroOptimizer
 from internlm.solver.optimizer.utils import ParamBcastSyncHandler
+from internlm.train.utils import create_param_groups
 from internlm.utils.common import DummyProfile
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
@@ -109,8 +110,9 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
         param_bcast_sync_handler = None
 
     adam_cfg = gpc.config.adam
+    params = create_param_groups(model, adam_cfg.weight_decay)
     naive_optimizer = torch.optim.AdamW(
-        params=[{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}],
+        params=params,
         lr=adam_cfg.lr,
         betas=(adam_cfg.adam_beta1, adam_cfg.adam_beta2),
         eps=adam_cfg.adam_eps,
diff --git a/internlm/train/utils.py b/internlm/train/utils.py
new file mode 100644
index 0000000..211cb53
--- /dev/null
+++ b/internlm/train/utils.py
@@ -0,0 +1,61 @@
+from typing import Dict, Tuple
+
+import torch
+
+
+def split_params_into_different_groups_for_optimizer(param_groups: Tuple[Dict]) -> Tuple[Dict]:
+    """Split parameters into different groups for the optimizer.
+    Compatible with multiple param groups; each group should have a name.
+
+    Args:
+        param_groups (Tuple[Dict]): The list of parameter groups to split
+        Input Example:
+        >>> (
+        >>>     {'name': 'default', 'params': [tensor], 'weight_decay': xxx},
+        >>> )
+
+    Returns:
+        Tuple[Dict]: list of fp16/fp32 param groups for the optimizer
+        Output Example:
+        >>> (
+        >>>     {'name': 'default', 'params': [tensor], 'weight_decay': xxx},
+        >>>     {'name': 'fp32', 'params': [tensor], 'weight_decay': xxx},
+        >>>     ...,
+        >>> )
+
+    Note:
+        The fp32 group is appended after the original bf16/fp16 group(s).
+    """
+    if isinstance(param_groups, tuple):
+        param_groups = list(param_groups)  # Tuple cannot be modified
+    elif isinstance(param_groups, dict):
+        param_groups = [param_groups]
+    elif not isinstance(param_groups, list):
+        raise ValueError(f"Unknown param group type of {type(param_groups)}")
+
+    fp32_group = {"name": "fp32", "params": []}
+    for pgroup in param_groups:
+        # copy attributes from the original group; we assume the input param_groups
+        # only has one group, so the attributes will not be copied multiple times.
+        for ori_key in pgroup.keys():
+            if ori_key not in ("name", "params"):
+                fp32_group[ori_key] = pgroup[ori_key]
+        # assign each param to the fp32 group or keep it in the original group
+        origin_params = []
+        for param in pgroup["params"]:
+            if param.dtype == torch.float32:
+                fp32_group["params"].append(param)
+            else:
+                origin_params.append(param)
+        # bf16/fp16 params stay in the original group, which remains first in param_groups
+        pgroup["params"] = origin_params
+
+    param_groups.append(fp32_group)
+
+    return tuple(param_groups)
+
+
+def create_param_groups(model, weight_decay):
+    parameters = {"params": list(model.parameters()), "name": "default", "weight_decay": weight_decay}
+
+    return split_params_into_different_groups_for_optimizer(parameters)
diff --git a/tests/test_model/test_fused_precision/test_fused_precision.py b/tests/test_model/test_fused_precision/test_fused_precision.py
new file mode 100644
index 0000000..e368813
--- /dev/null
+++ b/tests/test_model/test_fused_precision/test_fused_precision.py
@@ -0,0 +1,128 @@
+import multiprocessing as mp
+from functools import partial
+
+import pytest
+import torch
+from torch import nn
+
+from internlm.core.naive_amp import NaiveAMPModel, set_fp32_attr_to_module
+from internlm.model.modeling_internlm import PackedFlashBaseLayer1D
+from internlm.train.utils import create_param_groups
+from tests.test_model.test_model_internlm import build_environment, seed_all
+
+
+def _pre_forward_hook_for_check(model, inputs):  # pylint: disable=W0613
+    assert all(_.dtype == torch.float32 for _ in inputs)
+
+
+def _post_forward_hook_for_check(model, inputs, outputs):  # pylint: disable=W0613
+    if isinstance(outputs, tuple):
+        assert all(_.dtype == torch.half for _ in outputs)
+    else:
+        assert outputs.dtype == torch.half
+
+
+def check_fused_precision(args):
+    # init
+    rank, world_size = args
+    device = torch.device("cuda")
+    build_environment(rank, world_size)
+
+    # fix seed
+    seed_all(1024)
+    # define model
+    model = PackedFlashBaseLayer1D(
+        hidden_size=16,  # 768
+        num_attention_heads=2,  # 12
+        mlp_ratio=2,
+        attn_drop_rate=0.0,
+        drop_rate=0.0,
+        dtype=torch.bfloat16,
+        layer_norm_epsilon=1e-5,
+        checkpoint=False,
+        layer_idx=0,
+        residual_in_fp32=False,
+        device=device,
+        norm_type="rmsnorm",
+        dropout_selective_checkpoint=True,
+        use_scaled_init=True,
+        use_swiglu=True,
+    )
+    model = model.to(device)
+    set_fp32_attr_to_module(model.norm1)
+    model = NaiveAMPModel(
+        model=model,
+        output_to_fp32=True,
+        dtype=torch.half,
+        sync_buffer=False,
+    )
+    model.model.norm1.register_forward_pre_hook(partial(_pre_forward_hook_for_check))
+    model.model.norm1.register_forward_hook(partial(_post_forward_hook_for_check))
+
+    hidden_states = torch.rand(1, 1, 16).to(device).requires_grad_()
+
+    # forward
+    model(hidden_states)
+
+
+class MlpModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear1 = nn.Linear(4, 1, bias=False).half()
+        self.linear2 = nn.Linear(1, 4, bias=False).float()
+
+
+def check_split_fused_group(args):
+    # init
+    rank, world_size = args
+    device = torch.device("cuda")
+    build_environment(rank, world_size)
+    rtol, atol = (1e-3, 5e-3)
+
+    # fix seed
+    seed_all(1024)
+    # define model
+    model = MlpModel().to(device)
+    groups = create_param_groups(model, weight_decay=0.05)
+
+    standard_group = (
+        {
+            "name": "default",
+            "params": [torch.Tensor([[0.3088, 0.2935, -0.2900, 0.4280]]).to(torch.float16).to(device).requires_grad_()],
+            "weight_decay": 0.05,
+        },
+        {
+            "name": "fp32",
+            "params": [torch.Tensor([[0.6273], [0.4844], [-0.0463], [-0.0090]]).to(device).requires_grad_()],
+            "weight_decay": 0.05,
+        },
+    )
+
+    # check groups params
+    for t1, t2 in zip(groups, standard_group):
+        # assert t1["name"] == t2["name"]
+        assert all(
+            torch.allclose(p1, p2, rtol=rtol, atol=atol, equal_nan=True) for p1, p2 in zip(t1["params"], t2["params"])
+        )
+
+
+@pytest.mark.fused_precision
+def test_fused_precision():
+    ctx = mp.get_context("spawn")
+    with ctx.Pool(processes=8) as pool:
+        pool.map(check_fused_precision, [[rank, 8] for rank in range(8)])
+        pool.close()
+        pool.join()
+
+
+@pytest.mark.split_groups
+def test_split_fused_groups():
+    ctx = mp.get_context("spawn")
+    with ctx.Pool(processes=8) as pool:
+        pool.map(check_split_fused_group, [[rank, 8] for rank in range(8)])
+        pool.close()
+        pool.join()
+
+
+if __name__ == "__main__":
+    pytest.main(["-s", "-q", "test_fused_precision.py"])