diff --git a/colossalai/kernel/jit/option.py b/colossalai/kernel/jit/option.py
index d95905897..aa41f5767 100644
--- a/colossalai/kernel/jit/option.py
+++ b/colossalai/kernel/jit/option.py
@@ -1,5 +1,11 @@
 import torch
 
+from colossalai.nn.layer.colossalai_layer import Embedding, Linear
+from colossalai.utils import get_current_device
+
+from .bias_dropout_add import bias_dropout_add_fused_train
+from .bias_gelu import bias_gelu_impl
+
 JIT_OPTIONS_SET = False
 
 
@@ -30,3 +36,44 @@ def set_jit_fusion_options():
         torch._C._jit_override_can_fuse_on_gpu(True)
 
     JIT_OPTIONS_SET = True
+
+
+def warmup_jit_fusion(batch_size: int,
+                      hidden_size: int,
+                      seq_length: int = 512,
+                      vocab_size: int = 32768,
+                      dtype: torch.dtype = torch.float32):
+    """ Compile JIT functions before the main training steps """
+
+    embed = Embedding(vocab_size, hidden_size).to(get_current_device())
+    linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_current_device())
+    linear_2 = Linear(hidden_size * 4, hidden_size, skip_bias_add=True).to(get_current_device())
+
+    x = torch.randint(vocab_size, (batch_size, seq_length), dtype=torch.long, device=get_current_device())
+    x = embed(x)
+    y, y_bias = linear_1(x)
+    z, z_bias = linear_2(y)
+    # Warmup JIT fusions with the input grad_enable state of both forward
+    # prop and recomputation
+    for bias_grad, input_grad in zip([True, True], [False, True]):
+        for _ in range(10):
+            bias = torch.rand_like(y_bias, dtype=dtype, device=get_current_device())
+            input_ = torch.rand_like(y, dtype=dtype, device=get_current_device())
+            bias.requires_grad, input_.requires_grad = bias_grad, input_grad
+            bias_gelu_impl(input_, bias)
+
+    # Warmup fused bias+dropout+add
+    dropout_rate = 0.1
+    # Warmup JIT fusions with the input grad_enable state of both forward
+    # prop and recomputation
+    for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]):
+        for _ in range(10):
+            input_ = torch.rand_like(z, dtype=dtype, device=get_current_device())
+            residual = torch.rand_like(x, dtype=dtype, device=get_current_device())
+            bias = torch.rand_like(z_bias, dtype=dtype, device=get_current_device())
+            input_.requires_grad = input_grad
+            bias.requires_grad = bias_grad
+            residual.requires_grad = residual_grad
+            bias_dropout_add_fused_train(input_, bias, residual, dropout_rate)
+
+    torch.cuda.empty_cache()
diff --git a/colossalai/nn/__init__.py b/colossalai/nn/__init__.py
index 91fc0da55..910ad2031 100644
--- a/colossalai/nn/__init__.py
+++ b/colossalai/nn/__init__.py
@@ -1,6 +1,6 @@
+from ._ops import *
 from .layer import *
 from .loss import *
 from .lr_scheduler import *
 from .metric import *
 from .optimizer import *
-from ._ops import *
diff --git a/colossalai/nn/layer/parallel_1d/layers.py b/colossalai/nn/layer/parallel_1d/layers.py
index 0edc5e37b..88ecdf691 100644
--- a/colossalai/nn/layer/parallel_1d/layers.py
+++ b/colossalai/nn/layer/parallel_1d/layers.py
@@ -7,6 +7,9 @@ from typing import Callable, Tuple
 
 import torch
 import torch.nn.functional as F
+from torch import Tensor
+from torch.nn.parameter import Parameter
+
 from colossalai.communication import broadcast
 from colossalai.context import ParallelMode, seed
 from colossalai.core import global_context as gpc
@@ -14,18 +17,33 @@ from colossalai.global_variables import tensor_parallel_env as env
 from colossalai.kernel import LayerNorm
 from colossalai.nn import init as init
 from colossalai.registry import LAYERS
-from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict,
-                                            partition_tensor_parallel_state_dict)
+from colossalai.utils.checkpointing import (
+    broadcast_state_dict,
+    gather_tensor_parallel_state_dict,
+    partition_tensor_parallel_state_dict,
+)
 from colossalai.utils.cuda import get_current_device
-from torch import Tensor
-from torch.nn.parameter import Parameter
-from ..vanilla import VanillaPatchEmbedding, VanillaLayerNorm
+
 from ..base_layer import ParallelLayer
 from ..colossalai_layer._utils import ColossalaiModule
 from ..utils import divide, set_tensor_parallel_attribute_by_partition
-from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input,
-                     split_forward_gather_backward)
+from ..vanilla import VanillaLayerNorm, VanillaPatchEmbedding
 from ._operation import linear_with_async_comm
+from ._utils import (
+    gather_forward_split_backward,
+    get_parallel_input,
+    reduce_grad,
+    reduce_input,
+    set_parallel_input,
+    split_forward_gather_backward,
+)
+
+Fast_LN = None
+try:
+    from apex.contrib.layer_norm.layer_norm import FastLayerNorm
+    Fast_LN = FastLayerNorm
+except ImportError:
+    pass
 
 
 @LAYERS.register_module
@@ -102,19 +120,15 @@ class LayerNorm1D(ColossalaiModule):
     ]
 
     def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None):
-        from apex.normalization import FusedLayerNorm
-
-        fast_ln_installed = False
-        try:
-            from apex.contrib.layer_norm.layer_norm import FastLayerNorm
-            fast_ln_installed = True
-        except ImportError:
-            pass
-
-        if fast_ln_installed and normalized_shape in self._fast_ln_supported_sizes:
-            norm = FastLayerNorm(normalized_shape, eps=eps).to(dtype)
+        if Fast_LN is not None and normalized_shape in self._fast_ln_supported_sizes:
+            norm = Fast_LN(normalized_shape, eps=eps).to(dtype)
         else:
-            norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype)
+            norm = None
+            try:
+                from apex.normalization import FusedLayerNorm
+                norm = FusedLayerNorm(normalized_shape, eps=eps).to(dtype)
+            except ImportError:
+                norm = LayerNorm(normalized_shape, eps=eps).to(dtype)
         super().__init__(norm)
 
     def _load_from_state_dict(self, state_dict, prefix, *args):
diff --git a/colossalai/nn/layer/parallel_3d/_utils.py b/colossalai/nn/layer/parallel_3d/_utils.py
index 759810f5e..364191a79 100644
--- a/colossalai/nn/layer/parallel_3d/_utils.py
+++ b/colossalai/nn/layer/parallel_3d/_utils.py
@@ -5,7 +5,6 @@ import torch
 from torch import Tensor
 
 from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D
-from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.global_variables import tensor_parallel_env as env
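
Usage note (not part of the diff): a minimal sketch of how the new warmup_jit_fusion helper might be called before the training loop, assuming Colossal-AI has already been launched so that its layers and get_current_device() resolve correctly; the launch call and the sizes below are illustrative assumptions, not values taken from this patch.

    import torch
    import colossalai

    from colossalai.kernel.jit.option import set_jit_fusion_options, warmup_jit_fusion

    # Assumed setup: initialize the distributed/parallel context first (config is a placeholder).
    colossalai.launch_from_torch(config={})

    # Enable the TorchScript fuser settings, then compile the fused bias+GeLU and
    # bias+dropout+add kernels once so the first real training step does not pay
    # the JIT compilation cost.
    set_jit_fusion_options()
    warmup_jit_fusion(batch_size=8,          # placeholder values for illustration
                      hidden_size=768,
                      seq_length=512,
                      vocab_size=32768,
                      dtype=torch.float16)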