diff --git a/internlm/core/naive_amp.py b/internlm/core/naive_amp.py
index 9bead52..62970d9 100644
--- a/internlm/core/naive_amp.py
+++ b/internlm/core/naive_amp.py
@@ -24,6 +24,14 @@ def module_has_fp32_attr(module: nn.Module):
     return hasattr(module, "is_fp32_module") and getattr(module, "is_fp32_module")
 
 
+def set_output_attr_to_module(module: nn.Module):
+    setattr(module, "is_output", True)
+
+
+def module_is_output(module: nn.Module):
+    return hasattr(module, "is_output") and getattr(module, "is_output")
+
+
 class NaiveAMPModel(nn.Module):
     """
     This is a wrapper class for a model that automatically casts the model, its inputs, and outputs into fp16.
@@ -189,3 +197,8 @@ class NaiveAMPModel(nn.Module):
                 sub_module.to(fp32_dtype)
                 sub_module.register_forward_pre_hook(partial(_pre_forward_hook_for_fp32))
                 sub_module.register_forward_hook(partial(_post_forward_hook_for_fp32))
+            if gpc.config.get("output_tf32", False) and module_is_output(sub_module):
+                sub_module.to(fp32_dtype)
+                torch.backends.cudnn.allow_tf32 = True
+                torch.backends.cuda.matmul.allow_tf32 = True
+                sub_module.register_forward_pre_hook(partial(_pre_forward_hook_for_fp32))
diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index 204f71f..a47a5cd 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -13,6 +13,7 @@ from torch import nn
 from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.core.context.random import _SEED_MANAGER
+from internlm.core.naive_amp import set_output_attr_to_module
 from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
 from internlm.initialize.launch import GLOBAL_SEED
 from internlm.model.embedding import Embedding1D
@@ -368,6 +369,7 @@ class PackedFlashInternLm1D(nn.Module):
                 dtype=dtype,
                 weight_scale=embed_grad_scale,
             )
+            set_output_attr_to_module(self.head)
             for _, param in self.head.named_parameters():
                 normal_(std=0.0052)(param)
             if gpc.get_world_size(ParallelMode.TENSOR) > 1: