mirror of https://github.com/InternLM/InternLM
add output embedding tf32 option (#523)
parent
c581cc4c02
commit
9fc252f40e
|
@ -24,6 +24,14 @@ def module_has_fp32_attr(module: nn.Module):
|
|||
return hasattr(module, "is_fp32_module") and getattr(module, "is_fp32_module")
|
||||
|
||||
|
||||
def set_output_attr_to_module(module: nn.Module):
|
||||
setattr(module, "is_output", True)
|
||||
|
||||
|
||||
def module_is_output(module: nn.Module):
|
||||
return hasattr(module, "is_output") and getattr(module, "is_output")
|
||||
|
||||
|
||||
class NaiveAMPModel(nn.Module):
|
||||
"""
|
||||
This is a wrapper class for a model that automatically casts the model, its inputs, and outputs into fp16.
|
||||
|
@ -189,3 +197,8 @@ class NaiveAMPModel(nn.Module):
|
|||
sub_module.to(fp32_dtype)
|
||||
sub_module.register_forward_pre_hook(partial(_pre_forward_hook_for_fp32))
|
||||
sub_module.register_forward_hook(partial(_post_forward_hook_for_fp32))
|
||||
if gpc.config.get("output_tf32", False) and module_is_output(sub_module):
|
||||
sub_module.to(fp32_dtype)
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
sub_module.register_forward_pre_hook(partial(_pre_forward_hook_for_fp32))
|
||||
|
|
|
@ -13,6 +13,7 @@ from torch import nn
|
|||
from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
|
||||
from internlm.core.context.parallel_context import global_context as gpc
|
||||
from internlm.core.context.random import _SEED_MANAGER
|
||||
from internlm.core.naive_amp import set_output_attr_to_module
|
||||
from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
|
||||
from internlm.initialize.launch import GLOBAL_SEED
|
||||
from internlm.model.embedding import Embedding1D
|
||||
|
@ -368,6 +369,7 @@ class PackedFlashInternLm1D(nn.Module):
|
|||
dtype=dtype,
|
||||
weight_scale=embed_grad_scale,
|
||||
)
|
||||
set_output_attr_to_module(self.head)
|
||||
for _, param in self.head.named_parameters():
|
||||
normal_(std=0.0052)(param)
|
||||
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
|
||||
|
|
Loading…
Reference in New Issue