fix(modeling): norm weight should be fp32

Branch: pull/580/head
Author: 877825076@qq.com
Date: 2024-01-11 13:43:04 +08:00
parent 91480c5b63
commit def738a5c8
1 changed file with 4 additions and 1 deletion


@@ -13,7 +13,7 @@ from torch import nn
 from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.core.context.random import _SEED_MANAGER
-from internlm.core.naive_amp import set_output_attr_to_module
+from internlm.core.naive_amp import set_fp32_attr_to_module, set_output_attr_to_module
 from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
 from internlm.initialize.launch import GLOBAL_SEED
 from internlm.model.embedding import Embedding1D
@@ -113,6 +113,8 @@ class PackedFlashBaseLayer1D(nn.Module):
         else:
             self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
             self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
+        set_fp32_attr_to_module(self.norm1)
+        set_fp32_attr_to_module(self.norm2)
 
         if use_swiglu:
             self.mlp = FeedForward(
@@ -360,6 +362,7 @@ class PackedFlashInternLm1D(nn.Module):
                 self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
             else:
                 self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
+            set_fp32_attr_to_module(self.norm)
             self.head = head_cls(
                 in_features=hidden_size,
                 out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
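
Context for the change: the commit tags every LayerNorm/RMSNorm module so that the mixed-precision (naive AMP) wrapper keeps its weights in fp32 rather than casting them to the low-precision training dtype, since normalization layers are numerically sensitive. The sketch below only illustrates that pattern; the marker attribute `_force_fp32` and the helpers `mark_module_fp32` / `cast_model_for_mixed_precision` are hypothetical stand-ins, not the actual `set_fp32_attr_to_module` / NaiveAMP implementation in `internlm.core.naive_amp`.

    # Minimal sketch, assuming a simple attribute-based marker (hypothetical names).
    import torch
    from torch import nn


    def mark_module_fp32(module: nn.Module) -> None:
        # Mark a module so the precision-casting pass below leaves it in fp32.
        module._force_fp32 = True


    def cast_model_for_mixed_precision(model: nn.Module, dtype=torch.bfloat16) -> nn.Module:
        # Cast all parameters to the low-precision dtype, then restore fp32 for
        # any module that was explicitly marked (e.g. LayerNorm / RMSNorm).
        model.to(dtype)
        for m in model.modules():
            if getattr(m, "_force_fp32", False):
                m.float()
        return model


    if __name__ == "__main__":
        block = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
        mark_module_fp32(block[1])
        cast_model_for_mixed_precision(block)
        assert block[0].weight.dtype == torch.bfloat16  # linear weights stay low precision
        assert block[1].weight.dtype == torch.float32   # marked norm weights stay fp32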