add output embedding tf32 option (#523)

jiaopenglong 2023-12-06 13:50:59 +08:00 committed by GitHub
parent c581cc4c02
commit 9fc252f40e
2 changed files with 15 additions and 0 deletions

internlm/core/naive_amp.py

@@ -24,6 +24,14 @@ def module_has_fp32_attr(module: nn.Module):
     return hasattr(module, "is_fp32_module") and getattr(module, "is_fp32_module")
 
 
+def set_output_attr_to_module(module: nn.Module):
+    setattr(module, "is_output", True)
+
+
+def module_is_output(module: nn.Module):
+    return hasattr(module, "is_output") and getattr(module, "is_output")
+
+
 class NaiveAMPModel(nn.Module):
     """
     This is a wrapper class for a model that automatically casts the model, its inputs, and outputs into fp16.
@@ -189,3 +197,8 @@ class NaiveAMPModel(nn.Module):
                 sub_module.to(fp32_dtype)
                 sub_module.register_forward_pre_hook(partial(_pre_forward_hook_for_fp32))
                 sub_module.register_forward_hook(partial(_post_forward_hook_for_fp32))
+            if gpc.config.get("output_tf32", False) and module_is_output(sub_module):
+                sub_module.to(fp32_dtype)
+                torch.backends.cudnn.allow_tf32 = True
+                torch.backends.cuda.matmul.allow_tf32 = True
+                sub_module.register_forward_pre_hook(partial(_pre_forward_hook_for_fp32))
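
The new helpers tag the output module, and when the output_tf32 config key is set, the wrapper keeps that module in fp32 while allowing its matmuls to run as TF32 and casting its inputs up via a pre-forward hook. Below is a minimal standalone sketch of that pattern; the toy model, the output_tf32 flag (standing in for gpc.config.get("output_tf32", False)), and the hook body are illustrative assumptions, and bf16 is used in place of fp16 so the sketch runs on CPU.

    from functools import partial

    import torch
    from torch import nn


    def set_output_attr_to_module(module: nn.Module):
        setattr(module, "is_output", True)


    def module_is_output(module: nn.Module):
        return hasattr(module, "is_output") and getattr(module, "is_output")


    def _pre_forward_hook_for_fp32(module, inputs):
        # Cast low-precision tensor inputs up to fp32 before the module runs.
        return tuple(
            x.float() if isinstance(x, torch.Tensor) and x.is_floating_point() else x
            for x in inputs
        )


    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 4)).bfloat16()
    set_output_attr_to_module(model[-1])  # tag the head as the output module

    output_tf32 = True  # stands in for gpc.config.get("output_tf32", False)
    for sub_module in model.modules():
        if output_tf32 and module_is_output(sub_module):
            sub_module.to(torch.float32)
            # TF32 trades a few mantissa bits for much faster fp32 matmuls on Ampere+ GPUs.
            torch.backends.cudnn.allow_tf32 = True
            torch.backends.cuda.matmul.allow_tf32 = True
            sub_module.register_forward_pre_hook(partial(_pre_forward_hook_for_fp32))

    out = model(torch.randn(2, 8).bfloat16())
    print(out.dtype)  # torch.float32: the tagged head ran in fp32
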

internlm/model/modeling_internlm.py

@@ -13,6 +13,7 @@ from torch import nn
 from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.core.context.random import _SEED_MANAGER
+from internlm.core.naive_amp import set_output_attr_to_module
 from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
 from internlm.initialize.launch import GLOBAL_SEED
 from internlm.model.embedding import Embedding1D
@@ -368,6 +369,7 @@ class PackedFlashInternLm1D(nn.Module):
                 dtype=dtype,
                 weight_scale=embed_grad_scale,
             )
+            set_output_attr_to_module(self.head)
             for _, param in self.head.named_parameters():
                 normal_(std=0.0052)(param)
         if gpc.get_world_size(ParallelMode.TENSOR) > 1:
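
On the model side, the head is tagged right after it is constructed so the naive AMP wrapper can find it. Opting in is then a single top-level key in the training config, since gpc.config.get("output_tf32", False) reads straight from it. A hypothetical config fragment follows; the key name comes from this commit, while the file path and the neighbouring field are assumptions.

    # e.g. in configs/7B_sft.py, alongside the existing top-level options
    model_type = "INTERNLM"
    output_tf32 = True  # keep the output head in fp32 and allow TF32 matmuls
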