mirror of https://github.com/hpcaitech/ColossalAI
[kernel] added jit warmup (#1792)
parent
76e64cb67c
commit
4268ae017b
|
@ -1,5 +1,11 @@
|
|||
import torch
|
||||
|
||||
from colossalai.nn.layer.colossalai_layer import Embedding, Linear
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .bias_dropout_add import bias_dropout_add_fused_train
|
||||
from .bias_gelu import bias_gelu_impl
|
||||
|
||||
JIT_OPTIONS_SET = False
|
||||
|
||||
|
||||
|
@ -30,3 +36,44 @@ def set_jit_fusion_options():
|
|||
torch._C._jit_override_can_fuse_on_gpu(True)
|
||||
|
||||
JIT_OPTIONS_SET = True
|
||||
|
||||
|
||||
def warmup_jit_fusion(batch_size: int,
|
||||
hidden_size: int,
|
||||
seq_length: int = 512,
|
||||
vocab_size: int = 32768,
|
||||
dtype: torch.dtype = torch.float32):
|
||||
""" Compilie JIT functions before the main training steps """
|
||||
|
||||
embed = Embedding(vocab_size, hidden_size).to(get_current_device())
|
||||
linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_current_device())
|
||||
linear_2 = Linear(hidden_size * 4, hidden_size, skip_bias_add=True).to(get_current_device())
|
||||
|
||||
x = torch.randint(vocab_size, (batch_size, seq_length), dtype=torch.long, device=get_current_device())
|
||||
x = embed(x)
|
||||
y, y_bias = linear_1(x)
|
||||
z, z_bias = linear_2(y)
|
||||
# Warmup JIT fusions with the input grad_enable state of both forward
|
||||
# prop and recomputation
|
||||
for bias_grad, input_grad in zip([True, True], [False, True]):
|
||||
for _ in range(10):
|
||||
bias = torch.rand_like(y_bias, dtype=dtype, device=get_current_device())
|
||||
input_ = torch.rand_like(y, dtype=dtype, device=get_current_device())
|
||||
bias.requires_grad, input_.requires_grad = bias_grad, input_grad
|
||||
bias_gelu_impl(input_, bias)
|
||||
|
||||
# Warmup fused bias+dropout+add
|
||||
dropout_rate = 0.1
|
||||
# Warmup JIT fusions with the input grad_enable state of both forward
|
||||
# prop and recomputation
|
||||
for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]):
|
||||
for _ in range(10):
|
||||
input_ = torch.rand_like(z, dtype=dtype, device=get_current_device())
|
||||
residual = torch.rand_like(x, dtype=dtype, device=get_current_device())
|
||||
bias = torch.rand_like(z_bias, dtype=dtype, device=get_current_device())
|
||||
input_.requires_grad = input_grad
|
||||
bias.requires_grad = bias_grad
|
||||
residual.requires_grad = residual_grad
|
||||
bias_dropout_add_fused_train(input_, bias, residual, dropout_rate)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from ._ops import *
|
||||
from .layer import *
|
||||
from .loss import *
|
||||
from .lr_scheduler import *
|
||||
from .metric import *
|
||||
from .optimizer import *
|
||||
from ._ops import *
|
||||
|
|
|
@ -7,6 +7,9 @@ from typing import Callable, Tuple
|
|||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from colossalai.communication import broadcast
|
||||
from colossalai.context import ParallelMode, seed
|
||||
from colossalai.core import global_context as gpc
|
||||
|
@ -14,18 +17,33 @@ from colossalai.global_variables import tensor_parallel_env as env
|
|||
from colossalai.kernel import LayerNorm
|
||||
from colossalai.nn import init as init
|
||||
from colossalai.registry import LAYERS
|
||||
from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict,
|
||||
partition_tensor_parallel_state_dict)
|
||||
from colossalai.utils.checkpointing import (
|
||||
broadcast_state_dict,
|
||||
gather_tensor_parallel_state_dict,
|
||||
partition_tensor_parallel_state_dict,
|
||||
)
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from torch import Tensor
|
||||
from torch.nn.parameter import Parameter
|
||||
from ..vanilla import VanillaPatchEmbedding, VanillaLayerNorm
|
||||
|
||||
from ..base_layer import ParallelLayer
|
||||
from ..colossalai_layer._utils import ColossalaiModule
|
||||
from ..utils import divide, set_tensor_parallel_attribute_by_partition
|
||||
from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input,
|
||||
split_forward_gather_backward)
|
||||
from ..vanilla import VanillaLayerNorm, VanillaPatchEmbedding
|
||||
from ._operation import linear_with_async_comm
|
||||
from ._utils import (
|
||||
gather_forward_split_backward,
|
||||
get_parallel_input,
|
||||
reduce_grad,
|
||||
reduce_input,
|
||||
set_parallel_input,
|
||||
split_forward_gather_backward,
|
||||
)
|
||||
|
||||
Fast_LN = None
|
||||
try:
|
||||
from apex.contrib.layer_norm.layer_norm import FastLayerNorm
|
||||
Fast_LN = FastLayerNorm
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
|
@ -102,19 +120,15 @@ class LayerNorm1D(ColossalaiModule):
|
|||
]
|
||||
|
||||
def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
|
||||
from apex.normalization import FusedLayerNorm
|
||||
|
||||
fast_ln_installed = False
|
||||
try:
|
||||
from apex.contrib.layer_norm.layer_norm import FastLayerNorm
|
||||
fast_ln_installed = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if fast_ln_installed and normalized_shape in self._fast_ln_supported_sizes:
|
||||
norm = FastLayerNorm(normalized_shape, eps=eps).to(dtype)
|
||||
if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes:
|
||||
norm = Fast_LN(normalized_shape, eps=eps).to(dtype)
|
||||
else:
|
||||
norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype)
|
||||
norm = None
|
||||
try:
|
||||
from apex.normalization import FusedLayerNorm
|
||||
norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype)
|
||||
except ImportError:
|
||||
norm = LayerNorm(normalized_shape, eps=eps).to(dtype)
|
||||
super().__init__(norm)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args):
|
||||
|
|
|
@ -5,7 +5,6 @@ import torch
|
|||
from torch import Tensor
|
||||
|
||||
from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.global_variables import tensor_parallel_env as env
|
||||
|
||||
|
|
Loading…
Reference in New Issue