From f68eddfb3d3e261df4182bfb42ba4ac872af325a Mon Sep 17 00:00:00 2001 From: ver217 Date: Thu, 13 Jan 2022 16:47:17 +0800 Subject: [PATCH] refactor kernel (#142) --- MANIFEST.in | 3 +- colossalai/kernel/__init__.py | 3 - colossalai/kernel/cuda_native/__init__.py | 16 +- colossalai/kernel/cuda_native/builder.py | 114 ---------- .../cuda_native/csrc}/colossal_C_frontend.cpp | 0 colossalai/kernel/cuda_native/csrc/compat.h | 5 +- .../cuda_native/csrc}/multi_tensor_adam.cu | 0 .../cuda_native/csrc}/multi_tensor_apply.cuh | 0 .../csrc}/multi_tensor_l2norm_kernel.cu | 0 .../cuda_native/csrc}/multi_tensor_lamb.cu | 0 .../csrc}/multi_tensor_scale_kernel.cu | 0 .../csrc}/multi_tensor_sgd_kernel.cu | 0 .../kernel/cuda_native/csrc/type_shim.h | 199 +++++++++++++++++ colossalai/kernel/cuda_native/layer_norm.py | 14 +- .../kernel/cuda_native/multihead_attention.py | 26 ++- .../kernel/cuda_native/scaled_softmax.py | 25 ++- colossalai/kernel/jit/__init__.py | 7 +- colossalai/kernel/jit/option.py | 1 + colossalai/nn/optimizer/fused_adam.py | 3 +- colossalai/nn/optimizer/fused_lamb.py | 3 +- colossalai/nn/optimizer/fused_sgd.py | 3 +- csrc/compat.h | 10 - csrc/type_shim.h | 202 ------------------ setup.py | 114 ++++++---- 24 files changed, 334 insertions(+), 414 deletions(-) delete mode 100644 colossalai/kernel/cuda_native/builder.py rename {csrc => colossalai/kernel/cuda_native/csrc}/colossal_C_frontend.cpp (100%) rename {csrc => colossalai/kernel/cuda_native/csrc}/multi_tensor_adam.cu (100%) rename {csrc => colossalai/kernel/cuda_native/csrc}/multi_tensor_apply.cuh (100%) rename {csrc => colossalai/kernel/cuda_native/csrc}/multi_tensor_l2norm_kernel.cu (100%) rename {csrc => colossalai/kernel/cuda_native/csrc}/multi_tensor_lamb.cu (100%) rename {csrc => colossalai/kernel/cuda_native/csrc}/multi_tensor_scale_kernel.cu (100%) rename {csrc => colossalai/kernel/cuda_native/csrc}/multi_tensor_sgd_kernel.cu (100%) delete mode 100644 csrc/compat.h delete mode 100644 csrc/type_shim.h diff --git a/MANIFEST.in b/MANIFEST.in index a406adf97..48a44e0b4 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,3 @@ include *.txt README.md recursive-include requirements *.txt -recursive-include colossalai *.cpp *.h *.cu *.tr *.cuh *.cc -recursive-include csrc *.cpp *.h *.cu *.tr *.cuh *.cc \ No newline at end of file +recursive-include colossalai *.cpp *.h *.cu *.tr *.cuh *.cc \ No newline at end of file diff --git a/colossalai/kernel/__init__.py b/colossalai/kernel/__init__.py index 32bab15e5..d3d0be02b 100644 --- a/colossalai/kernel/__init__.py +++ b/colossalai/kernel/__init__.py @@ -1,8 +1,5 @@ -from .jit.bias_dropout_add import bias_dropout_add_fused_train, bias_dropout_add_fused_inference -from .jit.bias_gelu import bias_gelu_impl from .cuda_native import LayerNorm, FusedScaleMaskSoftmax, MultiHeadAttention __all__ = [ - "bias_dropout_add_fused_train", "bias_dropout_add_fused_inference", "bias_gelu_impl", "LayerNorm", "FusedScaleMaskSoftmax", "MultiHeadAttention" ] diff --git a/colossalai/kernel/cuda_native/__init__.py b/colossalai/kernel/cuda_native/__init__.py index 33394224f..a35158b72 100644 --- a/colossalai/kernel/cuda_native/__init__.py +++ b/colossalai/kernel/cuda_native/__init__.py @@ -1,17 +1,3 @@ -from .builder import _build_cuda_native_kernel - -CUDA_NATIVE_KERNEL_BUILD = False - - -def build_cuda_native_kernel(): - global CUDA_NATIVE_KERNEL_BUILD - if CUDA_NATIVE_KERNEL_BUILD == False: - _build_cuda_native_kernel() - CUDA_NATIVE_KERNEL_BUILD = True - - -build_cuda_native_kernel() - from .layer_norm import 
MixedFusedLayerNorm as LayerNorm from .scaled_softmax import FusedScaleMaskSoftmax -from .multihead_attention import MultiHeadAttention \ No newline at end of file +from .multihead_attention import MultiHeadAttention diff --git a/colossalai/kernel/cuda_native/builder.py b/colossalai/kernel/cuda_native/builder.py deleted file mode 100644 index 9f1d10e6e..000000000 --- a/colossalai/kernel/cuda_native/builder.py +++ /dev/null @@ -1,114 +0,0 @@ -import os -import pathlib -import subprocess - -from torch.utils import cpp_extension - -# Setting this param to a list has a problem of generating different -# compilation commands (with diferent order of architectures) and -# leading to recompilation of fused kernels. Set it to empty string -# to avoid recompilation and assign arch flags explicity in -# extra_cuda_cflags below -os.environ["TORCH_CUDA_ARCH_LIST"] = "" - - -def _build_cuda_native_kernel(): - - # Check if cuda 11 is installed for compute capability 8.0 - cc_flag = [] - _, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME) - if int(bare_metal_major) >= 11: - cc_flag.append('-gencode') - cc_flag.append('arch=compute_80,code=sm_80') - - # Build path - basepath = pathlib.Path(__file__).parent.absolute() - srcpath = basepath / 'csrc' - buildpath = basepath / 'build' - _create_build_dir(buildpath) - - # Helper function to build the kernels. - def _cpp_extention_load_helper(name, sources, extra_cuda_flags): - return cpp_extension.load( - name=name, - sources=sources, - build_directory=buildpath, - extra_cflags=[ - '-O3', - ], - extra_include_paths=[str(srcpath / 'kernels' / 'include')], - extra_cuda_cflags=['-O3', '-gencode', 'arch=compute_70,code=sm_70', '--use_fast_math'] + - extra_cuda_flags + cc_flag, - verbose=False) - - # ============== - # Fused softmax. - # ============== - - extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__', - '-U__CUDA_NO_HALF_CONVERSIONS__', - '--expt-relaxed-constexpr', - '--expt-extended-lambda'] - - # Upper triangular softmax. - sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp', - srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu'] - colossal_scaled_upper_triang_masked_softmax = _cpp_extention_load_helper( - "colossal_scaled_upper_triang_masked_softmax", - sources, extra_cuda_flags) - - # Masked softmax. - sources=[srcpath / 'scaled_masked_softmax.cpp', - srcpath / 'scaled_masked_softmax_cuda.cu'] - colossal_scaled_masked_softmax = _cpp_extention_load_helper( - "colossal_scaled_masked_softmax", sources, extra_cuda_flags) - - # ================================= - # Mixed precision fused layer norm. - # ================================= - - extra_cuda_flags = ['-maxrregcount=50'] - sources = [srcpath / 'layer_norm_cuda.cpp', srcpath / 'layer_norm_cuda_kernel.cu'] - colossal_layer_norm_cuda = _cpp_extention_load_helper("colossal_layer_norm_cuda", sources, - extra_cuda_flags) - - # ========================================== - # Mixed precision Transformer Encoder Layer. 
- # ========================================== - - extra_cuda_flags = ['-std=c++14', - '-U__CUDA_NO_HALF_OPERATORS__', - '-U__CUDA_NO_HALF_CONVERSIONS__', - '-U__CUDA_NO_HALF2_OPERATORS__', - '-DTHRUST_IGNORE_CUB_VERSION_CHECK'] - - sources = [srcpath / 'multihead_attention_1d.cpp'] - kernel_sources = ["cublas_wrappers.cu", - "transform_kernels.cu", - "dropout_kernels.cu", - "normalize_kernels.cu", - "softmax_kernels.cu", - "general_kernels.cu", - "cuda_util.cu"] - sources += [(srcpath / 'kernels' / cu_file) for cu_file in kernel_sources] - colossal_multihead_attention = _cpp_extention_load_helper("colossal_multihead_attention", sources, - extra_cuda_flags) - - -def _get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) - output = raw_output.split() - release_idx = output.index("release") + 1 - release = output[release_idx].split(".") - bare_metal_major = release[0] - bare_metal_minor = release[1][0] - - return raw_output, bare_metal_major, bare_metal_minor - - -def _create_build_dir(buildpath): - try: - os.mkdir(buildpath) - except OSError: - if not os.path.isdir(buildpath): - print(f"Creation of the build directory {buildpath} failed") diff --git a/csrc/colossal_C_frontend.cpp b/colossalai/kernel/cuda_native/csrc/colossal_C_frontend.cpp similarity index 100% rename from csrc/colossal_C_frontend.cpp rename to colossalai/kernel/cuda_native/csrc/colossal_C_frontend.cpp diff --git a/colossalai/kernel/cuda_native/csrc/compat.h b/colossalai/kernel/cuda_native/csrc/compat.h index def1c2158..00066dc95 100644 --- a/colossalai/kernel/cuda_native/csrc/compat.h +++ b/colossalai/kernel/cuda_native/csrc/compat.h @@ -1,7 +1,4 @@ -/*This code from NVIDIA apex: - * https://github.com/NVIDIA/apex - * with minor changes. 
*/ - +// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h #ifndef TORCH_CHECK #define TORCH_CHECK AT_CHECK #endif diff --git a/csrc/multi_tensor_adam.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu similarity index 100% rename from csrc/multi_tensor_adam.cu rename to colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu diff --git a/csrc/multi_tensor_apply.cuh b/colossalai/kernel/cuda_native/csrc/multi_tensor_apply.cuh similarity index 100% rename from csrc/multi_tensor_apply.cuh rename to colossalai/kernel/cuda_native/csrc/multi_tensor_apply.cuh diff --git a/csrc/multi_tensor_l2norm_kernel.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu similarity index 100% rename from csrc/multi_tensor_l2norm_kernel.cu rename to colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu diff --git a/csrc/multi_tensor_lamb.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu similarity index 100% rename from csrc/multi_tensor_lamb.cu rename to colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu diff --git a/csrc/multi_tensor_scale_kernel.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu similarity index 100% rename from csrc/multi_tensor_scale_kernel.cu rename to colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu diff --git a/csrc/multi_tensor_sgd_kernel.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu similarity index 100% rename from csrc/multi_tensor_sgd_kernel.cu rename to colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu diff --git a/colossalai/kernel/cuda_native/csrc/type_shim.h b/colossalai/kernel/cuda_native/csrc/type_shim.h index 845615feb..2cae99a2d 100644 --- a/colossalai/kernel/cuda_native/csrc/type_shim.h +++ b/colossalai/kernel/cuda_native/csrc/type_shim.h @@ -71,3 +71,202 @@ default: \ AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \ } + +// Forward/backward compatiblity hack around +// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288 +// pending more future-proof guidance from upstream. +// struct TypeShim +// { +// const at::Type& payload; +// TypeShim(const at::Type& type) : payload(type) {} +// // Enable trivial conversion to a const at::Type& for pre-3aeb78 +// operator const at::Type&(){ return payload; }; +// // Enable dispatch switch statements to take *this directly for post-3aeb78 +// //operator at::ScalarType(){ return payload.; }; +// }; + +#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \ + switch (TYPE) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + +#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \ + switch (TYPE) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Byte: \ + { \ + using scalar_t_##LEVEL = uint8_t; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + +#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) 
\ + switch (TYPE) \ + { \ + case at::ScalarType::Double: \ + { \ + using scalar_t_##LEVEL = double; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Float: \ + { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + +#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \ + switch (TYPE) \ + { \ + case at::ScalarType::Double: \ + { \ + using scalar_t_##LEVEL = double; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Float: \ + { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + +template +__device__ __forceinline__ T reduce_block_into_lanes(T *x, + T val, + int lanes = 1, + bool share_result = false) // lanes is intended to be <= 32. +{ + int tid = threadIdx.x + threadIdx.y * blockDim.x; + int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32. + + if (blockSize >= 64) + { + x[tid] = val; + __syncthreads(); + } + +#pragma unroll + for (int i = (blockSize >> 1); i >= 64; i >>= 1) + { + if (tid < i) + x[tid] = x[tid] + x[tid + i]; + __syncthreads(); + } + + T final; + + if (tid < 32) + { + if (blockSize >= 64) + final = x[tid] + x[tid + 32]; + else + final = val; + // __SYNCWARP(); + +#pragma unroll + for (int i = 16; i >= lanes; i >>= 1) + final = final + __shfl_down_sync(0xffffffff, final, i); + } + + if (share_result) + { + if (tid < lanes) + x[tid] = final; // EpilogueOp + // Make sure the smem result is visible to all warps. + __syncthreads(); + } + + return final; +} + +template +__device__ __forceinline__ T reduce_block_into_lanes_max_op(T *x, + T val, + int lanes = 1, + bool share_result = false) // lanes is intended to be <= 32. +{ + int tid = threadIdx.x + threadIdx.y * blockDim.x; + int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32. + + if (blockSize >= 64) + { + x[tid] = val; + __syncthreads(); + } + +#pragma unroll + for (int i = (blockSize >> 1); i >= 64; i >>= 1) + { + if (tid < i) + x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i])); + __syncthreads(); + } + + T final; + + if (tid < 32) + { + if (blockSize >= 64) + final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32])); + else + final = val; + // __SYNCWARP(); + +#pragma unroll + for (int i = 16; i >= lanes; i >>= 1) + final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i))); + } + + if (share_result) + { + if (tid < lanes) + x[tid] = final; // EpilogueOp + // Make sure the smem result is visible to all warps. 
+ __syncthreads(); + } + + return final; +} \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/layer_norm.py b/colossalai/kernel/cuda_native/layer_norm.py index 22b6efa01..4e1d486d9 100644 --- a/colossalai/kernel/cuda_native/layer_norm.py +++ b/colossalai/kernel/cuda_native/layer_norm.py @@ -34,10 +34,10 @@ class FusedLayerNormAffineFunction(torch.autograd.Function): input_, weight_, bias_, mean, invvar = ctx.saved_tensors grad_input = grad_weight = grad_bias = None grad_input, grad_weight, grad_bias \ - = colossal_layer_norm_cuda.backward_affine( - grad_output.contiguous(), mean, invvar, - input_, ctx.normalized_shape, - weight_, bias_, ctx.eps) + = colossal_layer_norm_cuda.backward_affine( + grad_output.contiguous(), mean, invvar, + input_, ctx.normalized_shape, + weight_, bias_, ctx.eps) return grad_input, grad_weight, grad_bias, None, None @@ -48,7 +48,11 @@ class MixedFusedLayerNorm(torch.nn.Module): super(MixedFusedLayerNorm, self).__init__() global colossal_layer_norm_cuda - colossal_layer_norm_cuda = importlib.import_module("colossal_layer_norm_cuda") + if colossal_layer_norm_cuda is None: + try: + colossal_layer_norm_cuda = importlib.import_module("colossal_layer_norm_cuda") + except ImportError: + raise RuntimeError('MixedFusedLayerNorm requires cuda extensions') if isinstance(normalized_shape, numbers.Integral): normalized_shape = (normalized_shape,) diff --git a/colossalai/kernel/cuda_native/multihead_attention.py b/colossalai/kernel/cuda_native/multihead_attention.py index 52e1c0bcf..aab76cc5f 100644 --- a/colossalai/kernel/cuda_native/multihead_attention.py +++ b/colossalai/kernel/cuda_native/multihead_attention.py @@ -34,6 +34,7 @@ def calc_offset(sizes): colossal_multihead_attention = None + @dataclass class Config: max_batch_tokens: int # max batch token numbers @@ -94,7 +95,7 @@ class MultiHeadAttention1DFunc(Function): input_mask = input_mask.to(torch.half) grad_input, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, \ grad_out_proj_bias, grad_norm_weight, grad_norm_bias = backward_func( - ctx.config.layer_id, grad_output, output, input, input_mask, in_proj_weight, \ + ctx.config.layer_id, grad_output, output, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, out_proj_bias, norm_weight, norm_bias) return (grad_input, None, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, @@ -142,7 +143,10 @@ class MultiHeadAttention(nn.Module): # Load cuda modules if needed global colossal_multihead_attention if colossal_multihead_attention is None: - colossal_multihead_attention = importlib.import_module("colossal_multihead_attention") + try: + colossal_multihead_attention = importlib.import_module("colossal_multihead_attention") + except ImportError: + raise RuntimeError('MultiHeadAttention requires cuda extensions') # create the layer in cuda kernels. 
cuda_module = colossal_multihead_attention @@ -210,14 +214,14 @@ class MultiHeadAttention(nn.Module): with torch.no_grad(): self.in_proj_weight.copy_( attn_qkvw_global.view(3, hs, hs)[:, - int(hs * rank_in_pg / - self.pg_size):int(hs * (rank_in_pg + 1) / - self.pg_size), :]) + int(hs * rank_in_pg / + self.pg_size):int(hs * (rank_in_pg + 1) / + self.pg_size), :]) self.in_proj_bias.copy_( attn_qkvb_global.view(3, hs)[:, - int(hs * rank_in_pg / - self.pg_size):int(hs * (rank_in_pg + 1) / - self.pg_size)]) + int(hs * rank_in_pg / + self.pg_size):int(hs * (rank_in_pg + 1) / + self.pg_size)]) attn_ow_global = torch.empty(hs, hs) nn.init.xavier_uniform_(attn_ow_global, 1.0) @@ -226,9 +230,9 @@ class MultiHeadAttention(nn.Module): attn_ow_global = attn_ow_global.cpu() with torch.no_grad(): self.out_proj_weight.copy_(attn_ow_global[:, - int(hs * rank_in_pg / - self.pg_size):int(hs * (rank_in_pg + 1) / - self.pg_size)]) + int(hs * rank_in_pg / + self.pg_size):int(hs * (rank_in_pg + 1) / + self.pg_size)]) else: attn_qkvw = self.in_proj_weight.view(-1, hs) diff --git a/colossalai/kernel/cuda_native/scaled_softmax.py b/colossalai/kernel/cuda_native/scaled_softmax.py index c1388e299..e2584e83b 100644 --- a/colossalai/kernel/cuda_native/scaled_softmax.py +++ b/colossalai/kernel/cuda_native/scaled_softmax.py @@ -21,7 +21,10 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): @staticmethod def forward(ctx, inputs, scale): - import colossal_scaled_upper_triang_masked_softmax + try: + import colossal_scaled_upper_triang_masked_softmax + except ImportError: + raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions') scale_t = torch.tensor([scale]) softmax_results = colossal_scaled_upper_triang_masked_softmax.forward( @@ -33,7 +36,10 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): @staticmethod def backward(ctx, output_grads): - import colossal_scaled_upper_triang_masked_softmax + try: + import colossal_scaled_upper_triang_masked_softmax + except ImportError: + raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions') softmax_results, scale_t = ctx.saved_tensors input_grads = colossal_scaled_upper_triang_masked_softmax.backward( @@ -53,7 +59,10 @@ class ScaledMaskedSoftmax(torch.autograd.Function): @staticmethod def forward(ctx, inputs, mask, scale): - import colossal_scaled_masked_softmax + try: + import colossal_scaled_masked_softmax + except ImportError: + raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions') scale_t = torch.tensor([scale]) @@ -63,7 +72,10 @@ class ScaledMaskedSoftmax(torch.autograd.Function): @staticmethod def backward(ctx, output_grads): - import colossal_scaled_masked_softmax + try: + import colossal_scaled_masked_softmax + except ImportError: + raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions') softmax_results, scale_t = ctx.saved_tensors @@ -179,6 +191,9 @@ class FusedScaleMaskSoftmax(nn.Module): @staticmethod def get_batch_per_block(sq, sk, b, np): - import colossal_scaled_masked_softmax + try: + import colossal_scaled_masked_softmax + except ImportError: + raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions') return colossal_scaled_masked_softmax.get_batch_per_block(sq, sk, b, np) diff --git a/colossalai/kernel/jit/__init__.py b/colossalai/kernel/jit/__init__.py index 3f0e888bf..374e43537 100644 --- a/colossalai/kernel/jit/__init__.py +++ b/colossalai/kernel/jit/__init__.py @@ -1,3 +1,8 @@ from .option import _set_jit_fusion_options +from .bias_dropout_add import 
bias_dropout_add_fused_train, bias_dropout_add_fused_inference +from .bias_gelu import bias_gelu_impl +_set_jit_fusion_options() -_set_jit_fusion_options() \ No newline at end of file +__all__ = [ + "bias_dropout_add_fused_train", "bias_dropout_add_fused_inference", "bias_gelu_impl", +] diff --git a/colossalai/kernel/jit/option.py b/colossalai/kernel/jit/option.py index 06823ad3e..73c0b7b57 100644 --- a/colossalai/kernel/jit/option.py +++ b/colossalai/kernel/jit/option.py @@ -2,6 +2,7 @@ import torch JIT_OPTIONS_SET = False + def _set_jit_fusion_options(): """Set PyTorch JIT layer fusion options.""" global JIT_OPTIONS_SET diff --git a/colossalai/nn/optimizer/fused_adam.py b/colossalai/nn/optimizer/fused_adam.py index 8bcd3841a..cb75d073b 100644 --- a/colossalai/nn/optimizer/fused_adam.py +++ b/colossalai/nn/optimizer/fused_adam.py @@ -65,8 +65,7 @@ class FusedAdam(torch.optim.Optimizer): self._dummy_overflow_buf = torch.cuda.IntTensor([0]) self.multi_tensor_adam = colossal_C.multi_tensor_adam else: - raise RuntimeError( - 'apex.optimizers.FusedAdam requires cuda extensions') + raise RuntimeError('FusedAdam requires cuda extensions') def zero_grad(self): if self.set_grad_none: diff --git a/colossalai/nn/optimizer/fused_lamb.py b/colossalai/nn/optimizer/fused_lamb.py index 8a340a9f3..952cccd50 100644 --- a/colossalai/nn/optimizer/fused_lamb.py +++ b/colossalai/nn/optimizer/fused_lamb.py @@ -73,8 +73,7 @@ class FusedLAMB(torch.optim.Optimizer): [0], dtype=torch.int, device=self.param_groups[0]["params"][0].device) self.multi_tensor_lamb = colossal_C.multi_tensor_lamb else: - raise RuntimeError( - 'apex.optimizers.FusedLAMB requires cuda extensions') + raise RuntimeError('FusedLAMB requires cuda extensions') self.adam_w_mode = 1 if adam_w_mode else 0 self.set_grad_none = set_grad_none diff --git a/colossalai/nn/optimizer/fused_sgd.py b/colossalai/nn/optimizer/fused_sgd.py index 4986aa5f5..9e29f67f7 100644 --- a/colossalai/nn/optimizer/fused_sgd.py +++ b/colossalai/nn/optimizer/fused_sgd.py @@ -90,8 +90,7 @@ class FusedSGD(Optimizer): [0], dtype=torch.int, device=self.param_groups[0]["params"][0].device) self.multi_tensor_sgd = colossal_C.multi_tensor_sgd else: - raise RuntimeError( - 'apex.optimizers.FusedSGD requires cuda extensions') + raise RuntimeError('FusedSGD requires cuda extensions') def __setstate__(self, state): super(FusedSGD, self).__setstate__(state) diff --git a/csrc/compat.h b/csrc/compat.h deleted file mode 100644 index 00066dc95..000000000 --- a/csrc/compat.h +++ /dev/null @@ -1,10 +0,0 @@ -// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h -#ifndef TORCH_CHECK -#define TORCH_CHECK AT_CHECK -#endif - -#ifdef VERSION_GE_1_3 -#define DATA_PTR data_ptr -#else -#define DATA_PTR data -#endif \ No newline at end of file diff --git a/csrc/type_shim.h b/csrc/type_shim.h deleted file mode 100644 index e9696dea6..000000000 --- a/csrc/type_shim.h +++ /dev/null @@ -1,202 +0,0 @@ -// modified from https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h -#include -#include "compat.h" - -// Forward/backward compatiblity hack around -// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288 -// pending more future-proof guidance from upstream. 
-// struct TypeShim -// { -// const at::Type& payload; -// TypeShim(const at::Type& type) : payload(type) {} -// // Enable trivial conversion to a const at::Type& for pre-3aeb78 -// operator const at::Type&(){ return payload; }; -// // Enable dispatch switch statements to take *this directly for post-3aeb78 -// //operator at::ScalarType(){ return payload.; }; -// }; - -#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \ - switch (TYPE) \ - { \ - case at::ScalarType::Float: \ - { \ - using scalar_t_##LEVEL = float; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Half: \ - { \ - using scalar_t_##LEVEL = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } - -#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \ - switch (TYPE) \ - { \ - case at::ScalarType::Float: \ - { \ - using scalar_t_##LEVEL = float; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Half: \ - { \ - using scalar_t_##LEVEL = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Byte: \ - { \ - using scalar_t_##LEVEL = uint8_t; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } - -#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \ - switch (TYPE) \ - { \ - case at::ScalarType::Double: \ - { \ - using scalar_t_##LEVEL = double; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Float: \ - { \ - using scalar_t_##LEVEL = float; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Half: \ - { \ - using scalar_t_##LEVEL = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } - -#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \ - switch (TYPE) \ - { \ - case at::ScalarType::Double: \ - { \ - using scalar_t_##LEVEL = double; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Float: \ - { \ - using scalar_t_##LEVEL = float; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } - -template -__device__ __forceinline__ T reduce_block_into_lanes(T *x, - T val, - int lanes = 1, - bool share_result = false) // lanes is intended to be <= 32. -{ - int tid = threadIdx.x + threadIdx.y * blockDim.x; - int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32. - - if (blockSize >= 64) - { - x[tid] = val; - __syncthreads(); - } - -#pragma unroll - for (int i = (blockSize >> 1); i >= 64; i >>= 1) - { - if (tid < i) - x[tid] = x[tid] + x[tid + i]; - __syncthreads(); - } - - T final; - - if (tid < 32) - { - if (blockSize >= 64) - final = x[tid] + x[tid + 32]; - else - final = val; - // __SYNCWARP(); - -#pragma unroll - for (int i = 16; i >= lanes; i >>= 1) - final = final + __shfl_down_sync(0xffffffff, final, i); - } - - if (share_result) - { - if (tid < lanes) - x[tid] = final; // EpilogueOp - // Make sure the smem result is visible to all warps. - __syncthreads(); - } - - return final; -} - -template -__device__ __forceinline__ T reduce_block_into_lanes_max_op(T *x, - T val, - int lanes = 1, - bool share_result = false) // lanes is intended to be <= 32. -{ - int tid = threadIdx.x + threadIdx.y * blockDim.x; - int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32. 
- - if (blockSize >= 64) - { - x[tid] = val; - __syncthreads(); - } - -#pragma unroll - for (int i = (blockSize >> 1); i >= 64; i >>= 1) - { - if (tid < i) - x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i])); - __syncthreads(); - } - - T final; - - if (tid < 32) - { - if (blockSize >= 64) - final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32])); - else - final = val; - // __SYNCWARP(); - -#pragma unroll - for (int i = 16; i >= lanes; i >>= 1) - final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i))); - } - - if (share_result) - { - if (tid < lanes) - x[tid] = final; // EpilogueOp - // Make sure the smem result is visible to all warps. - __syncthreads(); - } - - return final; -} \ No newline at end of file diff --git a/setup.py b/setup.py index 20ddf3477..41f1ffc8b 100644 --- a/setup.py +++ b/setup.py @@ -11,8 +11,7 @@ this_dir = os.path.dirname(os.path.abspath(__file__)) def get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output( - [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) output = raw_output.split() release_idx = output.index("release") + 1 release = output[release_idx].split(".") @@ -23,8 +22,7 @@ def get_cuda_bare_metal_version(cuda_dir): def check_cuda_torch_binary_vs_bare_metal(cuda_dir): - raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version( - cuda_dir) + raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir) torch_binary_major = torch.version.cuda.split(".")[0] torch_binary_minor = torch.version.cuda.split(".")[1] @@ -40,6 +38,13 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir): "You can try commenting out this check (at your own risk).") +def append_nvcc_threads(nvcc_extra_args): + _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) + if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2: + return nvcc_extra_args + ["--threads", "4"] + return nvcc_extra_args + + def fetch_requirements(path): with open(path, 'r') as fd: return [r.strip() for r in fd.readlines()] @@ -67,8 +72,8 @@ print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) TORCH_MAJOR = int(torch.__version__.split('.')[0]) TORCH_MINOR = int(torch.__version__.split('.')[1]) -if TORCH_MAJOR == 0 and TORCH_MINOR < 4: - raise RuntimeError("Colossal-AI requires Pytorch 0.4 or newer.\n" + +if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 8): + raise RuntimeError("Colossal-AI requires Pytorch 1.8 or newer.\n" + "The latest stable release can be obtained from https://pytorch.org/") cmdclass = {} @@ -79,22 +84,9 @@ ext_modules = [] # and # https://github.com/NVIDIA/apex/issues/456 # https://github.com/pytorch/pytorch/commit/eb7b39e02f7d75c26d8a795ea8c7fd911334da7e#diff-4632522f237f1e4e728cb824300403ac -version_ge_1_1 = [] -if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0): - version_ge_1_1 = ['-DVERSION_GE_1_1'] -version_ge_1_3 = [] -if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2): - version_ge_1_3 = ['-DVERSION_GE_1_3'] -version_ge_1_5 = [] -if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4): - version_ge_1_5 = ['-DVERSION_GE_1_5'] -version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5 +version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5'] if "--cuda_ext" in sys.argv: - if TORCH_MAJOR == 0: - raise RuntimeError("--cuda_ext requires Pytorch 1.0 or later, " - "found 
torch.__version__ = {}".format(torch.__version__)) - sys.argv.remove("--cuda_ext") if CUDA_HOME is None: @@ -103,19 +95,66 @@ if "--cuda_ext" in sys.argv: else: check_cuda_torch_binary_vs_bare_metal(CUDA_HOME) - ext_modules.append( - CUDAExtension(name='colossal_C', - sources=['csrc/colossal_C_frontend.cpp', - 'csrc/multi_tensor_sgd_kernel.cu', - 'csrc/multi_tensor_scale_kernel.cu', - 'csrc/multi_tensor_adam.cu', - 'csrc/multi_tensor_l2norm_kernel.cu', - 'csrc/multi_tensor_lamb.cu'], - extra_compile_args={'cxx': ['-O3'] + version_dependent_macros, - 'nvcc': ['-lineinfo', - '-O3', - # '--resource-usage', - '--use_fast_math'] + version_dependent_macros})) + def cuda_ext_helper(name, sources, extra_cuda_flags): + return CUDAExtension(name=name, + sources=[os.path.join('colossalai/kernel/cuda_native/csrc', path) for path in sources], + include_dirs=[os.path.join( + this_dir, 'colossalai/kernel/cuda_native/csrc/kernels/include')], + extra_compile_args={'cxx': ['-O3'] + version_dependent_macros, + 'nvcc': append_nvcc_threads(['-O3', + '--use_fast_math'] + version_dependent_macros + extra_cuda_flags)}) + + ext_modules.append(cuda_ext_helper('colossal_C', + ['colossal_C_frontend.cpp', + 'multi_tensor_sgd_kernel.cu', + 'multi_tensor_scale_kernel.cu', + 'multi_tensor_adam.cu', + 'multi_tensor_l2norm_kernel.cu', + 'multi_tensor_lamb.cu'], + ['-lineinfo'])) + + cc_flag = ['-gencode', 'arch=compute_70,code=sm_70'] + _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME) + if int(bare_metal_major) >= 11: + cc_flag.append('-gencode') + cc_flag.append('arch=compute_80,code=sm_80') + + extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__', + '-U__CUDA_NO_HALF_CONVERSIONS__', + '--expt-relaxed-constexpr', + '--expt-extended-lambda'] + + ext_modules.append(cuda_ext_helper('colossal_scaled_upper_triang_masked_softmax', + ['scaled_upper_triang_masked_softmax.cpp', + 'scaled_upper_triang_masked_softmax_cuda.cu'], + extra_cuda_flags + cc_flag)) + + ext_modules.append(cuda_ext_helper('colossal_scaled_masked_softmax', + ['scaled_masked_softmax.cpp', 'scaled_masked_softmax_cuda.cu'], + extra_cuda_flags + cc_flag)) + + extra_cuda_flags = ['-maxrregcount=50'] + + ext_modules.append(cuda_ext_helper('colossal_layer_norm_cuda', + ['layer_norm_cuda.cpp', 'layer_norm_cuda_kernel.cu'], + extra_cuda_flags + cc_flag)) + + extra_cuda_flags = ['-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', + '-U__CUDA_NO_HALF_CONVERSIONS__', + '-U__CUDA_NO_HALF2_OPERATORS__', + '-DTHRUST_IGNORE_CUB_VERSION_CHECK'] + + ext_modules.append(cuda_ext_helper('colossal_multihead_attention', + ['multihead_attention_1d.cpp', + 'kernels/cublas_wrappers.cu', + 'kernels/transform_kernels.cu', + 'kernels/dropout_kernels.cu', + 'kernels/normalize_kernels.cu', + 'kernels/softmax_kernels.cu', + 'kernels/general_kernels.cu', + 'kernels/cuda_util.cu'], + extra_cuda_flags + cc_flag)) install_requires = fetch_requirements('requirements/requirements.txt') @@ -123,14 +162,17 @@ install_requires = fetch_requirements('requirements/requirements.txt') setup( name='colossalai', version='0.0.1-beta', - packages=find_packages(exclude=('csrc', + packages=find_packages(exclude=('benchmark', + 'docker', 'tests', 'docs', + 'examples', 'tests', + 'scripts', + 'requirements', '*.egg-info',)), description='An integrated large-scale model training system with efficient parallelization techniques', ext_modules=ext_modules, cmdclass={'build_ext': BuildExtension} if ext_modules else {}, - package_data={'colossalai': ['kernel/cuda_native/csrc/*']}, 
install_requires=install_requires, -) \ No newline at end of file +)
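
Editor's note on usage after this refactor (a sketch, not part of the patch): the CUDA kernels are no longer JIT-compiled at import time by the removed builder.py; they are built ahead of time by setup.py when "--cuda_ext" is passed (for example: python setup.py install --cuda_ext), and the fused bias/dropout/gelu JIT helpers now live under colossalai.kernel.jit instead of colossalai.kernel. The snippet below uses only names exported by the modules touched in this diff; the tensor shapes and hidden size are illustrative.

    import torch
    from colossalai.kernel import LayerNorm  # MixedFusedLayerNorm
    from colossalai.kernel.jit import bias_dropout_add_fused_train, bias_gelu_impl

    # Constructing the layer imports the colossal_layer_norm_cuda extension; with this
    # patch it raises RuntimeError('MixedFusedLayerNorm requires cuda extensions') when
    # the package was built without --cuda_ext, instead of an unhandled ImportError.
    norm = LayerNorm(1024).cuda()
    x = torch.randn(4, 16, 1024, device='cuda')
    y = norm(x)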