mirror of https://github.com/InternLM/InternLM
remove fp32 hook for norm
parent 72bb3125a3
commit c1fe37d125
@@ -155,11 +155,13 @@ class NaiveAMPModel(nn.Module):
                 return x.to(dtype)
             return x
 
-        def _pre_forward_hook(model: nn.Module, inputs: tuple):  # pylint: disable=W0613
+        def _pre_forward_hook_for_fp32(model: nn.Module, inputs: tuple):  # pylint: disable=W0613
             assert isinstance(inputs, tuple)
             return tuple(map(to_fp32, inputs))
 
-        def _post_forward_hook(model: nn.Module, inputs: tuple, outputs: Union[tuple, Tensor]):  # pylint: disable=W0613
+        def _post_forward_hook_for_fp32(
+            model: nn.Module, inputs: tuple, outputs: Union[tuple, Tensor]
+        ):  # pylint: disable=W0613
             assert isinstance(outputs, (tuple, Tensor))
             if isinstance(outputs, tuple):
                 return tuple(to_fp32(out, self.dtype) for out in outputs)
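The renamed hooks follow a common mixed-precision pattern: cast a flagged module's inputs up to fp32 before its forward pass, then cast its outputs back down to the wrapper's working dtype. Below is a minimal standalone sketch of that pattern using plain PyTorch forward hooks; the helper names and the LayerNorm example are illustrative and do not reproduce InternLM's exact implementation.

import torch
from torch import nn

def to_fp32(x, dtype=torch.float32):
    # Cast only floating-point tensors; pass everything else through untouched.
    if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
        return x.to(dtype)
    return x

def make_fp32_hooks(out_dtype):
    def pre_hook(module, inputs):  # before forward: low precision -> fp32
        return tuple(to_fp32(x) for x in inputs)

    def post_hook(module, inputs, outputs):  # after forward: fp32 -> out_dtype
        if isinstance(outputs, tuple):
            return tuple(to_fp32(x, out_dtype) for x in outputs)
        return to_fp32(outputs, out_dtype)

    return pre_hook, post_hook

norm = nn.LayerNorm(8)                      # weights stay in fp32
pre, post = make_fp32_hooks(torch.float16)  # the surrounding model runs in fp16
norm.register_forward_pre_hook(pre)
norm.register_forward_hook(post)

x = torch.randn(2, 8, dtype=torch.float16)
y = norm(x)
print(y.dtype)  # torch.float16: computed in fp32, returned in fp16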
@@ -175,22 +177,11 @@ class NaiveAMPModel(nn.Module):
         modules = []
-        # record the modules to transformer/embedding/head/norm block
         for _chunk in model:
             if isinstance(_chunk, NaiveAMPModel):
                 _chunk = _chunk.model
-
-            for child in _chunk.children():
-                # should be the transformer block definition in modeling_xxx.py
-                if isinstance(child, nn.ModuleList):
-                    for _, block in enumerate(child):
-                        # TODO: special case for MoE
-                        modules.extend(list(block.children()))
-                else:
-                    # embedding, head, etc. that are outside the transformer block
-                    modules.append(child)
+            modules.extend([sub_module for _, sub_module in _chunk.named_modules()])
 
-        # register_forward_pre_hook for transformer/embedding/norm/xxx block
         for sub_module in modules:
             if module_has_fp32_attr(sub_module):
                 sub_module.to(dtype)
-                sub_module.register_forward_pre_hook(partial(_pre_forward_hook))
-                sub_module.register_forward_hook(partial(_post_forward_hook))
+                sub_module.register_forward_pre_hook(partial(_pre_forward_hook_for_fp32))
+                sub_module.register_forward_hook(partial(_post_forward_hook_for_fp32))
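The new collection logic walks every submodule via named_modules() and only registers the fp32 hooks on modules that carry the fp32 flag. The sketch below shows that flag-and-register flow under the assumption that set_fp32_attr_to_module and module_has_fp32_attr are simple attribute markers; the attribute name used here is invented for illustration.

import torch
from torch import nn

_FP32_ATTR = "is_fp32_module"  # hypothetical marker attribute

def set_fp32_attr_to_module(module: nn.Module):
    setattr(module, _FP32_ATTR, True)

def module_has_fp32_attr(module: nn.Module) -> bool:
    return getattr(module, _FP32_ATTR, False)

model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8), nn.Linear(8, 8)).to(torch.float16)
set_fp32_attr_to_module(model[1])  # opt the norm into fp32 execution

# Mirror of the registration loop: named_modules() yields every nested submodule,
# so a flag set anywhere in the hierarchy is honored.
for _, sub_module in model.named_modules():
    if module_has_fp32_attr(sub_module):
        sub_module.to(torch.float32)
        # the fp32 pre/post forward hooks from the previous sketch would be attached here

print(model[1].weight.dtype)  # torch.float32; the two Linear layers stay fp16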
@@ -11,7 +11,6 @@ from torch import nn
 from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
-from internlm.core.naive_amp import set_fp32_attr_to_module
 from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
 from internlm.model.embedding import Embedding1D
 from internlm.model.linear import (
@@ -102,8 +101,6 @@ class PackedFlashBaseLayer1D(nn.Module):
         else:
             self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
             self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
-        set_fp32_attr_to_module(self.norm1)
-        set_fp32_attr_to_module(self.norm2)
 
         if use_swiglu:
             self.mlp = FeedForward(
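For context, the two removed calls are where this block used to opt its norms into the fp32 path. A heavily simplified sketch of that (now removed) pattern follows; the block layout is reduced to the bare minimum and set_fp32_attr_to_module is stubbed as a plain attribute marker.

import torch
from torch import nn

def set_fp32_attr_to_module(module: nn.Module):  # stand-in for the real helper
    module.is_fp32_module = True

class TinyBlock(nn.Module):
    def __init__(self, hidden_size: int, layer_norm_epsilon: float = 1e-5):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
        self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
        # These are the calls the commit removes: without them the norms are no
        # longer special-cased by the fp32 hook registration.
        set_fp32_attr_to_module(self.norm1)
        set_fp32_attr_to_module(self.norm2)
        self.mlp = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        return x + self.mlp(self.norm2(self.norm1(x)))

blk = TinyBlock(16)
print(getattr(blk.norm1, "is_fp32_module", False))  # True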
@@ -337,7 +334,6 @@ class PackedFlashInternLm1D(nn.Module):
                 self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
             else:
                 self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
-            set_fp32_attr_to_module(self.norm)
             self.head = head_cls(
                 in_features=hidden_size,
                 out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
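With set_fp32_attr_to_module(self.norm) gone, module_has_fp32_attr returns False for the final norm, so the registration loop applies neither the fp32 cast nor the hooks; the norm simply follows the model-wide dtype, just like the head. A small sketch of that net effect, assuming the flag is a plain attribute:

import torch
from torch import nn

def module_has_fp32_attr(module: nn.Module) -> bool:  # stand-in for the real helper
    return getattr(module, "is_fp32_module", False)

norm = nn.LayerNorm(16)
head = nn.Linear(16, 32, bias=False)
tail = nn.Sequential(norm, head).to(torch.float16)

# No module was flagged, so the loop below is a no-op: the norm is neither cast
# back to fp32 nor wrapped with the fp32 forward hooks.
for _, sub_module in tail.named_modules():
    if module_has_fp32_attr(sub_module):
        sub_module.to(torch.float32)

print(norm.weight.dtype, head.weight.dtype)  # torch.float16 torch.float16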