From c1fe37d12599ef45c84c0f88612157281a499b62 Mon Sep 17 00:00:00 2001
From: Wenwen Qu
Date: Mon, 25 Sep 2023 10:56:55 +0800
Subject: [PATCH] remove fp32 hook for norm

---
 internlm/core/naive_amp.py          | 23 +++++++----------------
 internlm/model/modeling_internlm.py |  4 ----
 2 files changed, 7 insertions(+), 20 deletions(-)

diff --git a/internlm/core/naive_amp.py b/internlm/core/naive_amp.py
index 2f4d832..6fa84ec 100644
--- a/internlm/core/naive_amp.py
+++ b/internlm/core/naive_amp.py
@@ -155,11 +155,13 @@ class NaiveAMPModel(nn.Module):
                 return x.to(dtype)
             return x
 
-        def _pre_forward_hook(model: nn.Module, inputs: tuple):  # pylint: disable=W0613
+        def _pre_forward_hook_for_fp32(model: nn.Module, inputs: tuple):  # pylint: disable=W0613
             assert isinstance(inputs, tuple)
             return tuple(map(to_fp32, inputs))
 
-        def _post_forward_hook(model: nn.Module, inputs: tuple, outputs: Union[tuple, Tensor]):  # pylint: disable=W0613
+        def _post_forward_hook_for_fp32(
+            model: nn.Module, inputs: tuple, outputs: Union[tuple, Tensor]
+        ):  # pylint: disable=W0613
             assert isinstance(inputs, Union[tuple, Tensor])
             if isinstance(outputs, tuple):
                 return tuple(map(to_fp32, outputs, self.dtype))
@@ -175,22 +177,11 @@ class NaiveAMPModel(nn.Module):
         modules = []
         # record the modules to transformer/embeding/head/norm block
         for _chunk in model:
-            if isinstance(_chunk, NaiveAMPModel):
-                _chunk = _chunk.model
-
-            for child in _chunk.children():
-                # should be the transformer block definaton in modeling_xxx.py
-                if isinstance(child, nn.ModuleList):
-                    for _, block in enumerate(child):
-                        # TODO special case for MoE
-                        modules.extend(list(block.children()))
-                else:
-                    # embedding, head, etc that out of the transformer block
-                    modules.append(child)
+            modules.extend([sub_module for _, sub_module in _chunk.named_modules()])
 
         # register_forward_pre_hook for transformer/embeding/norm/xxx block
         for sub_module in modules:
             if module_has_fp32_attr(sub_module):
                 sub_module.to(dtype)
-                sub_module.register_forward_pre_hook(partial(_pre_forward_hook))
-                sub_module.register_forward_hook(partial(_post_forward_hook))
+                sub_module.register_forward_pre_hook(partial(_pre_forward_hook_for_fp32))
+                sub_module.register_forward_hook(partial(_post_forward_hook_for_fp32))
diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index 858b6f0..64ff4de 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -11,7 +11,6 @@ from torch import nn
 
 from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
-from internlm.core.naive_amp import set_fp32_attr_to_module
 from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
 from internlm.model.embedding import Embedding1D
 from internlm.model.linear import (
@@ -102,8 +101,6 @@ class PackedFlashBaseLayer1D(nn.Module):
         else:
             self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
             self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
-        set_fp32_attr_to_module(self.norm1)
-        set_fp32_attr_to_module(self.norm2)
 
         if use_swiglu:
             self.mlp = FeedForward(
@@ -337,7 +334,6 @@ class PackedFlashInternLm1D(nn.Module):
                 self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
             else:
                 self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
-            set_fp32_attr_to_module(self.norm)
             self.head = head_cls(
                 in_features=hidden_size,
                 out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
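
Note: the fp32 cast-hook machinery itself stays in naive_amp.py; after this patch, hooks are registered only on sub-modules for which module_has_fp32_attr() is true, gathered via named_modules(), and the norm layers in modeling_internlm.py simply stop flagging themselves with set_fp32_attr_to_module. The sketch below is a minimal, self-contained illustration of that flag-plus-hook pattern, not the InternLM implementation; the marker attribute name and the bfloat16 compute dtype are assumptions for the example.

# Minimal sketch (illustrative only): flag one sub-module for fp32 execution
# and cast its inputs/outputs with forward hooks, as the hooks above do.
import torch
from torch import nn

FP32_FLAG = "_force_fp32"          # hypothetical marker attribute
COMPUTE_DTYPE = torch.bfloat16     # low-precision dtype the rest of the model uses


def set_fp32_attr(module: nn.Module) -> None:
    setattr(module, FP32_FLAG, True)


def module_has_fp32_attr(module: nn.Module) -> bool:
    return getattr(module, FP32_FLAG, False)


def _pre_hook(module: nn.Module, inputs: tuple):
    # Cast tensor inputs up to fp32 before the flagged module runs.
    return tuple(x.float() if isinstance(x, torch.Tensor) else x for x in inputs)


def _post_hook(module: nn.Module, inputs: tuple, output):
    # Cast the result back down so downstream modules keep their compute dtype.
    return output.to(COMPUTE_DTYPE) if isinstance(output, torch.Tensor) else output


model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8)).to(COMPUTE_DTYPE)
set_fp32_attr(model[1])            # only the norm is flagged for fp32 execution

for _, sub_module in model.named_modules():
    if module_has_fp32_attr(sub_module):
        sub_module.float()                           # keep the flagged weights in fp32
        sub_module.register_forward_pre_hook(_pre_hook)
        sub_module.register_forward_hook(_post_hook)

out = model(torch.randn(2, 8, dtype=COMPUTE_DTYPE))
print(out.dtype)                   # bfloat16: the norm ran in fp32, output cast back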