mirror of https://github.com/InternLM/InternLM
add output embedding tf32 option (#523)
parent c581cc4c02
commit 9fc252f40e
@@ -24,6 +24,14 @@ def module_has_fp32_attr(module: nn.Module):
     return hasattr(module, "is_fp32_module") and getattr(module, "is_fp32_module")
 
 
+def set_output_attr_to_module(module: nn.Module):
+    setattr(module, "is_output", True)
+
+
+def module_is_output(module: nn.Module):
+    return hasattr(module, "is_output") and getattr(module, "is_output")
+
+
 class NaiveAMPModel(nn.Module):
     """
     This is a wrapper class for a model that automatically casts the model, its inputs, and outputs into fp16.
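
The two helpers added above only tag a module with an is_output attribute and read that tag back later. A minimal usage sketch, assuming the helpers are importable from internlm.core.naive_amp as the second hunk below suggests; the head layer and its sizes are hypothetical, not taken from this diff:

import torch.nn as nn

from internlm.core.naive_amp import module_is_output, set_output_attr_to_module

head = nn.Linear(4096, 32000)    # hypothetical output projection
set_output_attr_to_module(head)  # sets head.is_output = True

assert module_is_output(head)           # tagged module is recognized
assert not module_is_output(nn.GELU())  # untagged modules return False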
@@ -189,3 +197,8 @@ class NaiveAMPModel(nn.Module):
                 sub_module.to(fp32_dtype)
                 sub_module.register_forward_pre_hook(partial(_pre_forward_hook_for_fp32))
                 sub_module.register_forward_hook(partial(_post_forward_hook_for_fp32))
+            if gpc.config.get("output_tf32", False) and module_is_output(sub_module):
+                sub_module.to(fp32_dtype)
+                torch.backends.cudnn.allow_tf32 = True
+                torch.backends.cuda.matmul.allow_tf32 = True
+                sub_module.register_forward_pre_hook(partial(_pre_forward_hook_for_fp32))
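
The new branch keeps a tagged output module in fp32, switches PyTorch's matmul and cuDNN kernels to TF32, and registers the usual fp32 pre-forward cast hook for it. The option is read from the top level of the training config via gpc.config.get("output_tf32", False), so enabling it would look roughly like the sketch below; only the output_tf32 key is implied by this commit, the file name is illustrative:

# in the training config (e.g. a configs/*.py file loaded into gpc.config)
output_tf32 = True  # run the tagged output head in fp32 with TF32 tensor-core kernels

Note that torch.backends.cuda.matmul.allow_tf32 and torch.backends.cudnn.allow_tf32 are process-wide switches, so turning the option on also permits TF32 in any other fp32 computation in the same process.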
@@ -13,6 +13,7 @@ from torch import nn
 from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.core.context.random import _SEED_MANAGER
+from internlm.core.naive_amp import set_output_attr_to_module
 from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
 from internlm.initialize.launch import GLOBAL_SEED
 from internlm.model.embedding import Embedding1D
@@ -368,6 +369,7 @@ class PackedFlashInternLm1D(nn.Module):
                 dtype=dtype,
                 weight_scale=embed_grad_scale,
             )
+            set_output_attr_to_module(self.head)
             for _, param in self.head.named_parameters():
                 normal_(std=0.0052)(param)
             if gpc.get_world_size(ParallelMode.TENSOR) > 1:
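
Taken together, the model side only tags self.head; NaiveAMPModel then moves the tagged head to fp32, enables TF32, and casts its inputs up before the forward pass. A standalone sketch of the runtime effect, with made-up sizes and tensors for illustration:

import torch

# process-wide switches flipped by the new NaiveAMPModel branch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

head = torch.nn.Linear(4096, 1000).cuda().float()  # output head kept in fp32
hidden = torch.randn(8, 4096, device="cuda", dtype=torch.float16)

# the registered pre-forward hook casts fp16 activations to fp32;
# the fp32 matmul may then run on TF32 tensor cores
logits = head(hidden.float())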