diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index a47a5cd..faf0d7b 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -13,7 +13,7 @@ from torch import nn
 from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.core.context.random import _SEED_MANAGER
-from internlm.core.naive_amp import set_output_attr_to_module
+from internlm.core.naive_amp import set_fp32_attr_to_module, set_output_attr_to_module
 from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
 from internlm.initialize.launch import GLOBAL_SEED
 from internlm.model.embedding import Embedding1D
@@ -113,6 +113,8 @@ class PackedFlashBaseLayer1D(nn.Module):
         else:
             self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
             self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
+        set_fp32_attr_to_module(self.norm1)
+        set_fp32_attr_to_module(self.norm2)
 
         if use_swiglu:
             self.mlp = FeedForward(
@@ -360,6 +362,7 @@ class PackedFlashInternLm1D(nn.Module):
                 self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
             else:
                 self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
+            set_fp32_attr_to_module(self.norm)
             self.head = head_cls(
                 in_features=hidden_size,
                 out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
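
For context, here is a minimal sketch of the mechanism this patch relies on: `set_fp32_attr_to_module` tags the norm layers so that a mixed-precision wrapper can keep them in fp32 while the rest of the model runs in fp16. The attribute name `is_fp32_module`, the `module_has_fp32_attr` helper, and the `ToyAMPWrapper` class below are illustrative assumptions for this sketch, not InternLM's actual `naive_amp` implementation.

```python
import torch
from torch import nn


# Assumption: the tag is a simple boolean attribute on the module.
def set_fp32_attr_to_module(module: nn.Module) -> None:
    setattr(module, "is_fp32_module", True)


def module_has_fp32_attr(module: nn.Module) -> bool:
    return getattr(module, "is_fp32_module", False)


class ToyAMPWrapper(nn.Module):
    """Casts the model to fp16 but keeps tagged submodules (e.g. LayerNorm) in fp32."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model.half()
        for sub in self.model.modules():
            if module_has_fp32_attr(sub):
                sub.float()  # restore fp32 parameters for the tagged norm layer
                # cast fp16 activations up before the norm, and back down after
                sub.register_forward_pre_hook(
                    lambda m, inputs: tuple(x.float() for x in inputs)
                )
                sub.register_forward_hook(lambda m, inputs, output: output.half())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x.half())


# Usage sketch: only the tagged LayerNorm keeps fp32 weights.
norm = nn.LayerNorm(8)
set_fp32_attr_to_module(norm)
block = nn.Sequential(nn.Linear(8, 8), norm)
wrapped = ToyAMPWrapper(block)
print(wrapped.model[0].weight.dtype)  # torch.float16
print(wrapped.model[1].weight.dtype)  # torch.float32 -- the tagged norm stays in fp32
# A full forward pass in fp16 generally assumes GPU kernels; parameter dtypes are
# inspected here instead so the sketch runs anywhere.
```

The motivation for the patch follows the same reasoning: LayerNorm/RMSNorm statistics are numerically sensitive, so tagging the norm modules lets the AMP machinery run them in fp32 even when the surrounding transformer blocks compute in fp16.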