mirror of https://github.com/InternLM/InternLM
fix(modeling): norm weight should be fp32
parent 91480c5b63
commit def738a5c8
@@ -13,7 +13,7 @@ from torch import nn
 from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.core.context.random import _SEED_MANAGER
-from internlm.core.naive_amp import set_output_attr_to_module
+from internlm.core.naive_amp import set_fp32_attr_to_module, set_output_attr_to_module
 from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
 from internlm.initialize.launch import GLOBAL_SEED
 from internlm.model.embedding import Embedding1D
@@ -113,6 +113,8 @@ class PackedFlashBaseLayer1D(nn.Module):
         else:
             self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
             self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
+        set_fp32_attr_to_module(self.norm1)
+        set_fp32_attr_to_module(self.norm2)

         if use_swiglu:
             self.mlp = FeedForward(
@@ -360,6 +362,7 @@ class PackedFlashInternLm1D(nn.Module):
                 self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
             else:
                 self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
+            set_fp32_attr_to_module(self.norm)
             self.head = head_cls(
                 in_features=hidden_size,
                 out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
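
The commit tags every LayerNorm/RMSNorm module with set_fp32_attr_to_module so that InternLM's naive-AMP wrapper keeps the norm weights in fp32 while the rest of the model runs in reduced precision. Below is a minimal, self-contained sketch of that pattern: the attribute name _fp32_module and the ToyAMPWrapper class are illustrative assumptions, not InternLM's actual internals; only the flagging of norm modules mirrors the diff above.

# Minimal sketch of the "fp32 module" pattern, under assumed names.
import torch
from torch import nn


def set_fp32_attr_to_module(module: nn.Module) -> None:
    # Tag a module so the precision wrapper leaves its parameters in fp32.
    module._fp32_module = True


class ToyAMPWrapper(nn.Module):
    # Casts the wrapped model to fp16, then restores fp32 for tagged modules
    # and casts their inputs back to fp32 via a forward pre-hook.
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model.half()
        for mod in self.model.modules():
            if getattr(mod, "_fp32_module", False):
                mod.float()  # e.g. LayerNorm weight/bias stay in fp32
                mod.register_forward_pre_hook(
                    lambda m, inputs: tuple(
                        x.float() if torch.is_tensor(x) else x for x in inputs
                    )
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x.half())


block = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
set_fp32_attr_to_module(block[1])  # same call pattern as the diff above
ToyAMPWrapper(block)
print(block[0].weight.dtype, block[1].weight.dtype)  # torch.float16 torch.float32

Keeping the norm parameters and their inputs in fp32 avoids rounding away the small gain/bias updates that fp16 accumulation would lose, which is the usual reason norm weights are kept in full precision under mixed-precision training.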