remove fp32 hook for norm

pull/319/head
Wenwen Qu 2023-09-25 10:56:55 +08:00
parent 72bb3125a3
commit c1fe37d125
2 changed files with 7 additions and 20 deletions

View File

@ -155,11 +155,13 @@ class NaiveAMPModel(nn.Module):
return x.to(dtype)
return x
def _pre_forward_hook(model: nn.Module, inputs: tuple): # pylint: disable=W0613
def _pre_forward_hook_for_fp32(model: nn.Module, inputs: tuple): # pylint: disable=W0613
assert isinstance(inputs, tuple)
return tuple(map(to_fp32, inputs))
def _post_forward_hook(model: nn.Module, inputs: tuple, outputs: Union[tuple, Tensor]): # pylint: disable=W0613
def _post_forward_hook_for_fp32(
model: nn.Module, inputs: tuple, outputs: Union[tuple, Tensor]
): # pylint: disable=W0613
assert isinstance(inputs, Union[tuple, Tensor])
if isinstance(outputs, tuple):
return tuple(map(to_fp32, outputs, self.dtype))
@ -175,22 +177,11 @@ class NaiveAMPModel(nn.Module):
modules = []
# record the modules to transformer/embeding/head/norm block
for _chunk in model:
if isinstance(_chunk, NaiveAMPModel):
_chunk = _chunk.model
for child in _chunk.children():
# should be the transformer block definaton in modeling_xxx.py
if isinstance(child, nn.ModuleList):
for _, block in enumerate(child):
# TODO special case for MoE
modules.extend(list(block.children()))
else:
# embedding, head, etc that out of the transformer block
modules.append(child)
modules.extend([sub_module for _, sub_module in _chunk.named_modules()])
# register_forward_pre_hook for transformer/embeding/norm/xxx block
for sub_module in modules:
if module_has_fp32_attr(sub_module):
sub_module.to(dtype)
sub_module.register_forward_pre_hook(partial(_pre_forward_hook))
sub_module.register_forward_hook(partial(_post_forward_hook))
sub_module.register_forward_pre_hook(partial(_pre_forward_hook_for_fp32))
sub_module.register_forward_hook(partial(_post_forward_hook_for_fp32))

View File

@ -11,7 +11,6 @@ from torch import nn
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
from internlm.core.context.parallel_context import global_context as gpc
from internlm.core.naive_amp import set_fp32_attr_to_module
from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
from internlm.model.embedding import Embedding1D
from internlm.model.linear import (
@ -102,8 +101,6 @@ class PackedFlashBaseLayer1D(nn.Module):
else:
self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
set_fp32_attr_to_module(self.norm1)
set_fp32_attr_to_module(self.norm2)
if use_swiglu:
self.mlp = FeedForward(
@ -337,7 +334,6 @@ class PackedFlashInternLm1D(nn.Module):
self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
else:
self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
set_fp32_attr_to_module(self.norm)
self.head = head_cls(
in_features=hidden_size,
out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,