mirror of https://github.com/InternLM/InternLM
remove fp32 hook for norm
parent
72bb3125a3
commit
c1fe37d125
|
@ -155,11 +155,13 @@ class NaiveAMPModel(nn.Module):
|
||||||
return x.to(dtype)
|
return x.to(dtype)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def _pre_forward_hook(model: nn.Module, inputs: tuple): # pylint: disable=W0613
|
def _pre_forward_hook_for_fp32(model: nn.Module, inputs: tuple): # pylint: disable=W0613
|
||||||
assert isinstance(inputs, tuple)
|
assert isinstance(inputs, tuple)
|
||||||
return tuple(map(to_fp32, inputs))
|
return tuple(map(to_fp32, inputs))
|
||||||
|
|
||||||
def _post_forward_hook(model: nn.Module, inputs: tuple, outputs: Union[tuple, Tensor]): # pylint: disable=W0613
|
def _post_forward_hook_for_fp32(
|
||||||
|
model: nn.Module, inputs: tuple, outputs: Union[tuple, Tensor]
|
||||||
|
): # pylint: disable=W0613
|
||||||
assert isinstance(inputs, Union[tuple, Tensor])
|
assert isinstance(inputs, Union[tuple, Tensor])
|
||||||
if isinstance(outputs, tuple):
|
if isinstance(outputs, tuple):
|
||||||
return tuple(map(to_fp32, outputs, self.dtype))
|
return tuple(map(to_fp32, outputs, self.dtype))
|
||||||
|
@ -175,22 +177,11 @@ class NaiveAMPModel(nn.Module):
|
||||||
modules = []
|
modules = []
|
||||||
# record the modules to transformer/embeding/head/norm block
|
# record the modules to transformer/embeding/head/norm block
|
||||||
for _chunk in model:
|
for _chunk in model:
|
||||||
if isinstance(_chunk, NaiveAMPModel):
|
modules.extend([sub_module for _, sub_module in _chunk.named_modules()])
|
||||||
_chunk = _chunk.model
|
|
||||||
|
|
||||||
for child in _chunk.children():
|
|
||||||
# should be the transformer block definaton in modeling_xxx.py
|
|
||||||
if isinstance(child, nn.ModuleList):
|
|
||||||
for _, block in enumerate(child):
|
|
||||||
# TODO special case for MoE
|
|
||||||
modules.extend(list(block.children()))
|
|
||||||
else:
|
|
||||||
# embedding, head, etc that out of the transformer block
|
|
||||||
modules.append(child)
|
|
||||||
|
|
||||||
# register_forward_pre_hook for transformer/embeding/norm/xxx block
|
# register_forward_pre_hook for transformer/embeding/norm/xxx block
|
||||||
for sub_module in modules:
|
for sub_module in modules:
|
||||||
if module_has_fp32_attr(sub_module):
|
if module_has_fp32_attr(sub_module):
|
||||||
sub_module.to(dtype)
|
sub_module.to(dtype)
|
||||||
sub_module.register_forward_pre_hook(partial(_pre_forward_hook))
|
sub_module.register_forward_pre_hook(partial(_pre_forward_hook_for_fp32))
|
||||||
sub_module.register_forward_hook(partial(_post_forward_hook))
|
sub_module.register_forward_hook(partial(_post_forward_hook_for_fp32))
|
||||||
|
|
|
@ -11,7 +11,6 @@ from torch import nn
|
||||||
|
|
||||||
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
|
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
|
||||||
from internlm.core.context.parallel_context import global_context as gpc
|
from internlm.core.context.parallel_context import global_context as gpc
|
||||||
from internlm.core.naive_amp import set_fp32_attr_to_module
|
|
||||||
from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
|
from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
|
||||||
from internlm.model.embedding import Embedding1D
|
from internlm.model.embedding import Embedding1D
|
||||||
from internlm.model.linear import (
|
from internlm.model.linear import (
|
||||||
|
@ -102,8 +101,6 @@ class PackedFlashBaseLayer1D(nn.Module):
|
||||||
else:
|
else:
|
||||||
self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
|
self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
|
||||||
self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
|
self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
|
||||||
set_fp32_attr_to_module(self.norm1)
|
|
||||||
set_fp32_attr_to_module(self.norm2)
|
|
||||||
|
|
||||||
if use_swiglu:
|
if use_swiglu:
|
||||||
self.mlp = FeedForward(
|
self.mlp = FeedForward(
|
||||||
|
@ -337,7 +334,6 @@ class PackedFlashInternLm1D(nn.Module):
|
||||||
self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
|
self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
|
||||||
else:
|
else:
|
||||||
self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
|
self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
|
||||||
set_fp32_attr_to_module(self.norm)
|
|
||||||
self.head = head_cls(
|
self.head = head_cls(
|
||||||
in_features=hidden_size,
|
in_features=hidden_size,
|
||||||
out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
|
out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
|
||||||
|
|
Loading…
Reference in New Issue