[kernel] added jit warmup (#1792)

pull/1828/head
アマデウス 2022-11-08 16:22:23 +08:00 committed by GitHub
parent 76e64cb67c
commit 4268ae017b
4 changed files with 81 additions and 21 deletions


@@ -1,5 +1,11 @@
 import torch
+
+from colossalai.nn.layer.colossalai_layer import Embedding, Linear
+from colossalai.utils import get_current_device
+
+from .bias_dropout_add import bias_dropout_add_fused_train
+from .bias_gelu import bias_gelu_impl

 JIT_OPTIONS_SET = False
@@ -30,3 +36,44 @@ def set_jit_fusion_options():
             torch._C._jit_override_can_fuse_on_gpu(True)

         JIT_OPTIONS_SET = True
+
+
+def warmup_jit_fusion(batch_size: int,
+                      hidden_size: int,
+                      seq_length: int = 512,
+                      vocab_size: int = 32768,
+                      dtype: torch.dtype = torch.float32):
+    """ Compile JIT functions before the main training steps """
+    embed = Embedding(vocab_size, hidden_size).to(get_current_device())
+    linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_current_device())
+    linear_2 = Linear(hidden_size * 4, hidden_size, skip_bias_add=True).to(get_current_device())
+
+    x = torch.randint(vocab_size, (batch_size, seq_length), dtype=torch.long, device=get_current_device())
+    x = embed(x)
+    y, y_bias = linear_1(x)
+    z, z_bias = linear_2(y)
+
+    # Warmup JIT fusions with the input grad_enable state of both forward
+    # prop and recomputation
+    for bias_grad, input_grad in zip([True, True], [False, True]):
+        for _ in range(10):
+            bias = torch.rand_like(y_bias, dtype=dtype, device=get_current_device())
+            input_ = torch.rand_like(y, dtype=dtype, device=get_current_device())
+            bias.requires_grad, input_.requires_grad = bias_grad, input_grad
+            bias_gelu_impl(input_, bias)
+
+    # Warmup fused bias+dropout+add
+    dropout_rate = 0.1
+    # Warmup JIT fusions with the input grad_enable state of both forward
+    # prop and recomputation
+    for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]):
+        for _ in range(10):
+            input_ = torch.rand_like(z, dtype=dtype, device=get_current_device())
+            residual = torch.rand_like(x, dtype=dtype, device=get_current_device())
+            bias = torch.rand_like(z_bias, dtype=dtype, device=get_current_device())
+            input_.requires_grad = input_grad
+            bias.requires_grad = bias_grad
+            residual.requires_grad = residual_grad
+            bias_dropout_add_fused_train(input_, bias, residual, dropout_rate)
+
+    torch.cuda.empty_cache()
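
For orientation: the new helper is meant to run once, after set_jit_fusion_options() and before the first training step. Below is a minimal usage sketch, assuming both functions are re-exported from colossalai.kernel.jit and that the distributed context has already been initialized (e.g. via colossalai.launch_from_torch); the sizes are illustrative and not part of this commit.

import torch

from colossalai.kernel.jit import set_jit_fusion_options, warmup_jit_fusion

# Assumed setup: colossalai.launch_from_torch(config={}) has already run, so the
# Embedding/Linear layers built inside the warmup resolve to the configured
# tensor-parallel implementation.
set_jit_fusion_options()                  # enable the fuser flags once per process
warmup_jit_fusion(batch_size=32,          # illustrative sizes; match your model config
                  hidden_size=1024,
                  seq_length=512,
                  vocab_size=32768,
                  dtype=torch.float16)

# ...build the model and optimizer, then start training; the first real step
# no longer pays the JIT compilation cost for bias+GeLU and bias+dropout+add.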


@@ -1,6 +1,6 @@
+from ._ops import *
 from .layer import *
 from .loss import *
 from .lr_scheduler import *
 from .metric import *
 from .optimizer import *
-from ._ops import *


@@ -7,6 +7,9 @@ from typing import Callable, Tuple
 import torch
 import torch.nn.functional as F
+from torch import Tensor
+from torch.nn.parameter import Parameter

 from colossalai.communication import broadcast
 from colossalai.context import ParallelMode, seed
 from colossalai.core import global_context as gpc
@@ -14,18 +17,33 @@ from colossalai.global_variables import tensor_parallel_env as env
 from colossalai.kernel import LayerNorm
 from colossalai.nn import init as init
 from colossalai.registry import LAYERS
-from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict,
-                                            partition_tensor_parallel_state_dict)
+from colossalai.utils.checkpointing import (
+    broadcast_state_dict,
+    gather_tensor_parallel_state_dict,
+    partition_tensor_parallel_state_dict,
+)
 from colossalai.utils.cuda import get_current_device
-from torch import Tensor
-from torch.nn.parameter import Parameter
-from ..vanilla import VanillaPatchEmbedding, VanillaLayerNorm
 from ..base_layer import ParallelLayer
 from ..colossalai_layer._utils import ColossalaiModule
 from ..utils import divide, set_tensor_parallel_attribute_by_partition
-from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input,
-                     split_forward_gather_backward)
+from ..vanilla import VanillaLayerNorm, VanillaPatchEmbedding
 from ._operation import linear_with_async_comm
+from ._utils import (
+    gather_forward_split_backward,
+    get_parallel_input,
+    reduce_grad,
+    reduce_input,
+    set_parallel_input,
+    split_forward_gather_backward,
+)

+Fast_LN = None
+try:
+    from apex.contrib.layer_norm.layer_norm import FastLayerNorm
+    Fast_LN = FastLayerNorm
+except ImportError:
+    pass

 @LAYERS.register_module
@@ -102,19 +120,15 @@ class LayerNorm1D(ColossalaiModule):
     ]

     def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
-        from apex.normalization import FusedLayerNorm
-
-        fast_ln_installed = False
-        try:
-            from apex.contrib.layer_norm.layer_norm import FastLayerNorm
-            fast_ln_installed = True
-        except ImportError:
-            pass
-
-        if fast_ln_installed and normalized_shape in self._fast_ln_supported_sizes:
-            norm = FastLayerNorm(normalized_shape, eps=eps).to(dtype)
+        if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes:
+            norm = Fast_LN(normalized_shape, eps=eps).to(dtype)
         else:
-            norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype)
+            norm = None
+            try:
+                from apex.normalization import FusedLayerNorm
+                norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype)
+            except ImportError:
+                norm = LayerNorm(normalized_shape, eps=eps).to(dtype)
         super().__init__(norm)

     def _load_from_state_dict(self, state_dict, prefix, *args):
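
The rewritten constructor above boils down to a three-way backend choice made once at construction time: apex's FastLayerNorm for hidden sizes listed in _fast_ln_supported_sizes, apex's FusedLayerNorm otherwise, and ColossalAI's own LayerNorm kernel when apex is not installed. The standalone sketch below mirrors that selection outside the class; the helper name pick_layernorm, its placeholder supported_sizes default, and the torch.nn.LayerNorm stand-in for colossalai.kernel.LayerNorm are assumptions for illustration only.

import torch

# Mirror of the fallback order introduced above.  FastLayerNorm is only importable
# when apex is built with its layer_norm contrib extension.
try:
    from apex.contrib.layer_norm.layer_norm import FastLayerNorm
except ImportError:
    FastLayerNorm = None


def pick_layernorm(normalized_shape: int,
                   supported_sizes=(1024, 2048, 4096),   # placeholder list for the sketch
                   eps: float = 1e-5,
                   dtype=None):
    if FastLayerNorm is not None and normalized_shape in supported_sizes:
        return FastLayerNorm(normalized_shape, eps=eps).to(dtype)
    try:
        from apex.normalization import FusedLayerNorm
        return FusedLayerNorm(normalized_shape, eps=eps).to(dtype)
    except ImportError:
        # torch.nn.LayerNorm stands in for colossalai.kernel.LayerNorm so the
        # sketch runs without apex or a built CUDA extension.
        return torch.nn.LayerNorm(normalized_shape, eps=eps).to(dtype)


norm = pick_layernorm(1024, dtype=torch.float32)   # picks the best available backend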


@@ -5,7 +5,6 @@ import torch
 from torch import Tensor
 from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D
-from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.global_variables import tensor_parallel_env as env