refactor kernel (#142)

pull/143/head
ver217 2022-01-13 16:47:17 +08:00 committed by GitHub
parent 4a3d3446b0
commit f68eddfb3d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 334 additions and 414 deletions

View File

@ -1,4 +1,3 @@
include *.txt README.md
recursive-include requirements *.txt
recursive-include colossalai *.cpp *.h *.cu *.tr *.cuh *.cc
recursive-include csrc *.cpp *.h *.cu *.tr *.cuh *.cc
recursive-include colossalai *.cpp *.h *.cu *.tr *.cuh *.cc

View File

@ -1,8 +1,5 @@
from .jit.bias_dropout_add import bias_dropout_add_fused_train, bias_dropout_add_fused_inference
from .jit.bias_gelu import bias_gelu_impl
from .cuda_native import LayerNorm, FusedScaleMaskSoftmax, MultiHeadAttention
__all__ = [
"bias_dropout_add_fused_train", "bias_dropout_add_fused_inference", "bias_gelu_impl",
"LayerNorm", "FusedScaleMaskSoftmax", "MultiHeadAttention"
]

View File

@ -1,17 +1,3 @@
from .builder import _build_cuda_native_kernel
CUDA_NATIVE_KERNEL_BUILD = False
def build_cuda_native_kernel():
global CUDA_NATIVE_KERNEL_BUILD
if CUDA_NATIVE_KERNEL_BUILD == False:
_build_cuda_native_kernel()
CUDA_NATIVE_KERNEL_BUILD = True
build_cuda_native_kernel()
from .layer_norm import MixedFusedLayerNorm as LayerNorm
from .scaled_softmax import FusedScaleMaskSoftmax
from .multihead_attention import MultiHeadAttention
from .multihead_attention import MultiHeadAttention

View File

@ -1,114 +0,0 @@
import os
import pathlib
import subprocess
from torch.utils import cpp_extension
# Setting this param to a list has a problem of generating different
# compilation commands (with diferent order of architectures) and
# leading to recompilation of fused kernels. Set it to empty string
# to avoid recompilation and assign arch flags explicity in
# extra_cuda_cflags below
os.environ["TORCH_CUDA_ARCH_LIST"] = ""
def _build_cuda_native_kernel():
# Check if cuda 11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
if int(bare_metal_major) >= 11:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')
# Build path
basepath = pathlib.Path(__file__).parent.absolute()
srcpath = basepath / 'csrc'
buildpath = basepath / 'build'
_create_build_dir(buildpath)
# Helper function to build the kernels.
def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
return cpp_extension.load(
name=name,
sources=sources,
build_directory=buildpath,
extra_cflags=[
'-O3',
],
extra_include_paths=[str(srcpath / 'kernels' / 'include')],
extra_cuda_cflags=['-O3', '-gencode', 'arch=compute_70,code=sm_70', '--use_fast_math'] +
extra_cuda_flags + cc_flag,
verbose=False)
# ==============
# Fused softmax.
# ==============
extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda']
# Upper triangular softmax.
sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp',
srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu']
colossal_scaled_upper_triang_masked_softmax = _cpp_extention_load_helper(
"colossal_scaled_upper_triang_masked_softmax",
sources, extra_cuda_flags)
# Masked softmax.
sources=[srcpath / 'scaled_masked_softmax.cpp',
srcpath / 'scaled_masked_softmax_cuda.cu']
colossal_scaled_masked_softmax = _cpp_extention_load_helper(
"colossal_scaled_masked_softmax", sources, extra_cuda_flags)
# =================================
# Mixed precision fused layer norm.
# =================================
extra_cuda_flags = ['-maxrregcount=50']
sources = [srcpath / 'layer_norm_cuda.cpp', srcpath / 'layer_norm_cuda_kernel.cu']
colossal_layer_norm_cuda = _cpp_extention_load_helper("colossal_layer_norm_cuda", sources,
extra_cuda_flags)
# ==========================================
# Mixed precision Transformer Encoder Layer.
# ==========================================
extra_cuda_flags = ['-std=c++14',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'-U__CUDA_NO_HALF2_OPERATORS__',
'-DTHRUST_IGNORE_CUB_VERSION_CHECK']
sources = [srcpath / 'multihead_attention_1d.cpp']
kernel_sources = ["cublas_wrappers.cu",
"transform_kernels.cu",
"dropout_kernels.cu",
"normalize_kernels.cu",
"softmax_kernels.cu",
"general_kernels.cu",
"cuda_util.cu"]
sources += [(srcpath / 'kernels' / cu_file) for cu_file in kernel_sources]
colossal_multihead_attention = _cpp_extention_load_helper("colossal_multihead_attention", sources,
extra_cuda_flags)
def _get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
return raw_output, bare_metal_major, bare_metal_minor
def _create_build_dir(buildpath):
try:
os.mkdir(buildpath)
except OSError:
if not os.path.isdir(buildpath):
print(f"Creation of the build directory {buildpath} failed")

View File

@ -1,7 +1,4 @@
/*This code from NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif

View File

@ -71,3 +71,202 @@
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
// Forward/backward compatiblity hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
switch (TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Byte: \
{ \
using scalar_t_##LEVEL = uint8_t; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes(T *x,
T val,
int lanes = 1,
bool share_result = false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64)
{
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1)
{
if (tid < i)
x[tid] = x[tid] + x[tid + i];
__syncthreads();
}
T final;
if (tid < 32)
{
if (blockSize >= 64)
final = x[tid] + x[tid + 32];
else
final = val;
// __SYNCWARP();
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1)
final = final + __shfl_down_sync(0xffffffff, final, i);
}
if (share_result)
{
if (tid < lanes)
x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes_max_op(T *x,
T val,
int lanes = 1,
bool share_result = false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64)
{
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1)
{
if (tid < i)
x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
__syncthreads();
}
T final;
if (tid < 32)
{
if (blockSize >= 64)
final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
else
final = val;
// __SYNCWARP();
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1)
final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
}
if (share_result)
{
if (tid < lanes)
x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}

View File

@ -34,10 +34,10 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
input_, weight_, bias_, mean, invvar = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
grad_input, grad_weight, grad_bias \
= colossal_layer_norm_cuda.backward_affine(
grad_output.contiguous(), mean, invvar,
input_, ctx.normalized_shape,
weight_, bias_, ctx.eps)
= colossal_layer_norm_cuda.backward_affine(
grad_output.contiguous(), mean, invvar,
input_, ctx.normalized_shape,
weight_, bias_, ctx.eps)
return grad_input, grad_weight, grad_bias, None, None
@ -48,7 +48,11 @@ class MixedFusedLayerNorm(torch.nn.Module):
super(MixedFusedLayerNorm, self).__init__()
global colossal_layer_norm_cuda
colossal_layer_norm_cuda = importlib.import_module("colossal_layer_norm_cuda")
if colossal_layer_norm_cuda is None:
try:
colossal_layer_norm_cuda = importlib.import_module("colossal_layer_norm_cuda")
except ImportError:
raise RuntimeError('MixedFusedLayerNorm requires cuda extensions')
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)

View File

@ -34,6 +34,7 @@ def calc_offset(sizes):
colossal_multihead_attention = None
@dataclass
class Config:
max_batch_tokens: int # max batch token numbers
@ -94,7 +95,7 @@ class MultiHeadAttention1DFunc(Function):
input_mask = input_mask.to(torch.half)
grad_input, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, \
grad_out_proj_bias, grad_norm_weight, grad_norm_bias = backward_func(
ctx.config.layer_id, grad_output, output, input, input_mask, in_proj_weight, \
ctx.config.layer_id, grad_output, output, input, input_mask, in_proj_weight,
in_proj_bias, out_proj_weight, out_proj_bias, norm_weight, norm_bias)
return (grad_input, None, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight,
@ -142,7 +143,10 @@ class MultiHeadAttention(nn.Module):
# Load cuda modules if needed
global colossal_multihead_attention
if colossal_multihead_attention is None:
colossal_multihead_attention = importlib.import_module("colossal_multihead_attention")
try:
colossal_multihead_attention = importlib.import_module("colossal_multihead_attention")
except ImportError:
raise RuntimeError('MultiHeadAttention requires cuda extensions')
# create the layer in cuda kernels.
cuda_module = colossal_multihead_attention
@ -210,14 +214,14 @@ class MultiHeadAttention(nn.Module):
with torch.no_grad():
self.in_proj_weight.copy_(
attn_qkvw_global.view(3, hs, hs)[:,
int(hs * rank_in_pg /
self.pg_size):int(hs * (rank_in_pg + 1) /
self.pg_size), :])
int(hs * rank_in_pg /
self.pg_size):int(hs * (rank_in_pg + 1) /
self.pg_size), :])
self.in_proj_bias.copy_(
attn_qkvb_global.view(3, hs)[:,
int(hs * rank_in_pg /
self.pg_size):int(hs * (rank_in_pg + 1) /
self.pg_size)])
int(hs * rank_in_pg /
self.pg_size):int(hs * (rank_in_pg + 1) /
self.pg_size)])
attn_ow_global = torch.empty(hs, hs)
nn.init.xavier_uniform_(attn_ow_global, 1.0)
@ -226,9 +230,9 @@ class MultiHeadAttention(nn.Module):
attn_ow_global = attn_ow_global.cpu()
with torch.no_grad():
self.out_proj_weight.copy_(attn_ow_global[:,
int(hs * rank_in_pg /
self.pg_size):int(hs * (rank_in_pg + 1) /
self.pg_size)])
int(hs * rank_in_pg /
self.pg_size):int(hs * (rank_in_pg + 1) /
self.pg_size)])
else:
attn_qkvw = self.in_proj_weight.view(-1, hs)

View File

@ -21,7 +21,10 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs, scale):
import colossal_scaled_upper_triang_masked_softmax
try:
import colossal_scaled_upper_triang_masked_softmax
except ImportError:
raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')
scale_t = torch.tensor([scale])
softmax_results = colossal_scaled_upper_triang_masked_softmax.forward(
@ -33,7 +36,10 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
@staticmethod
def backward(ctx, output_grads):
import colossal_scaled_upper_triang_masked_softmax
try:
import colossal_scaled_upper_triang_masked_softmax
except ImportError:
raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')
softmax_results, scale_t = ctx.saved_tensors
input_grads = colossal_scaled_upper_triang_masked_softmax.backward(
@ -53,7 +59,10 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs, mask, scale):
import colossal_scaled_masked_softmax
try:
import colossal_scaled_masked_softmax
except ImportError:
raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')
scale_t = torch.tensor([scale])
@ -63,7 +72,10 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
@staticmethod
def backward(ctx, output_grads):
import colossal_scaled_masked_softmax
try:
import colossal_scaled_masked_softmax
except ImportError:
raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')
softmax_results, scale_t = ctx.saved_tensors
@ -179,6 +191,9 @@ class FusedScaleMaskSoftmax(nn.Module):
@staticmethod
def get_batch_per_block(sq, sk, b, np):
import colossal_scaled_masked_softmax
try:
import colossal_scaled_masked_softmax
except ImportError:
raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')
return colossal_scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)

View File

@ -1,3 +1,8 @@
from .option import _set_jit_fusion_options
from .bias_dropout_add import bias_dropout_add_fused_train, bias_dropout_add_fused_inference
from .bias_gelu import bias_gelu_impl
_set_jit_fusion_options()
_set_jit_fusion_options()
__all__ = [
"bias_dropout_add_fused_train", "bias_dropout_add_fused_inference", "bias_gelu_impl",
]

View File

@ -2,6 +2,7 @@ import torch
JIT_OPTIONS_SET = False
def _set_jit_fusion_options():
"""Set PyTorch JIT layer fusion options."""
global JIT_OPTIONS_SET

View File

@ -65,8 +65,7 @@ class FusedAdam(torch.optim.Optimizer):
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
self.multi_tensor_adam = colossal_C.multi_tensor_adam
else:
raise RuntimeError(
'apex.optimizers.FusedAdam requires cuda extensions')
raise RuntimeError('FusedAdam requires cuda extensions')
def zero_grad(self):
if self.set_grad_none:

View File

@ -73,8 +73,7 @@ class FusedLAMB(torch.optim.Optimizer):
[0], dtype=torch.int, device=self.param_groups[0]["params"][0].device)
self.multi_tensor_lamb = colossal_C.multi_tensor_lamb
else:
raise RuntimeError(
'apex.optimizers.FusedLAMB requires cuda extensions')
raise RuntimeError('FusedLAMB requires cuda extensions')
self.adam_w_mode = 1 if adam_w_mode else 0
self.set_grad_none = set_grad_none

View File

@ -90,8 +90,7 @@ class FusedSGD(Optimizer):
[0], dtype=torch.int, device=self.param_groups[0]["params"][0].device)
self.multi_tensor_sgd = colossal_C.multi_tensor_sgd
else:
raise RuntimeError(
'apex.optimizers.FusedSGD requires cuda extensions')
raise RuntimeError('FusedSGD requires cuda extensions')
def __setstate__(self, state):
super(FusedSGD, self).__setstate__(state)

View File

@ -1,10 +0,0 @@
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif

View File

@ -1,202 +0,0 @@
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h
#include <ATen/ATen.h>
#include "compat.h"
// Forward/backward compatiblity hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
switch (TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Byte: \
{ \
using scalar_t_##LEVEL = uint8_t; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes(T *x,
T val,
int lanes = 1,
bool share_result = false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64)
{
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1)
{
if (tid < i)
x[tid] = x[tid] + x[tid + i];
__syncthreads();
}
T final;
if (tid < 32)
{
if (blockSize >= 64)
final = x[tid] + x[tid + 32];
else
final = val;
// __SYNCWARP();
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1)
final = final + __shfl_down_sync(0xffffffff, final, i);
}
if (share_result)
{
if (tid < lanes)
x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes_max_op(T *x,
T val,
int lanes = 1,
bool share_result = false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64)
{
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1)
{
if (tid < i)
x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
__syncthreads();
}
T final;
if (tid < 32)
{
if (blockSize >= 64)
final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
else
final = val;
// __SYNCWARP();
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1)
final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
}
if (share_result)
{
if (tid < lanes)
x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}

114
setup.py
View File

@ -11,8 +11,7 @@ this_dir = os.path.dirname(os.path.abspath(__file__))
def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output(
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
@ -23,8 +22,7 @@ def get_cuda_bare_metal_version(cuda_dir):
def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(
cuda_dir)
raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
torch_binary_major = torch.version.cuda.split(".")[0]
torch_binary_minor = torch.version.cuda.split(".")[1]
@ -40,6 +38,13 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
"You can try commenting out this check (at your own risk).")
def append_nvcc_threads(nvcc_extra_args):
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
return nvcc_extra_args + ["--threads", "4"]
return nvcc_extra_args
def fetch_requirements(path):
with open(path, 'r') as fd:
return [r.strip() for r in fd.readlines()]
@ -67,8 +72,8 @@ print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
if TORCH_MAJOR == 0 and TORCH_MINOR < 4:
raise RuntimeError("Colossal-AI requires Pytorch 0.4 or newer.\n" +
if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 8):
raise RuntimeError("Colossal-AI requires Pytorch 1.8 or newer.\n" +
"The latest stable release can be obtained from https://pytorch.org/")
cmdclass = {}
@ -79,22 +84,9 @@ ext_modules = []
# and
# https://github.com/NVIDIA/apex/issues/456
# https://github.com/pytorch/pytorch/commit/eb7b39e02f7d75c26d8a795ea8c7fd911334da7e#diff-4632522f237f1e4e728cb824300403ac
version_ge_1_1 = []
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0):
version_ge_1_1 = ['-DVERSION_GE_1_1']
version_ge_1_3 = []
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
version_ge_1_3 = ['-DVERSION_GE_1_3']
version_ge_1_5 = []
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
version_ge_1_5 = ['-DVERSION_GE_1_5']
version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5
version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']
if "--cuda_ext" in sys.argv:
if TORCH_MAJOR == 0:
raise RuntimeError("--cuda_ext requires Pytorch 1.0 or later, "
"found torch.__version__ = {}".format(torch.__version__))
sys.argv.remove("--cuda_ext")
if CUDA_HOME is None:
@ -103,19 +95,66 @@ if "--cuda_ext" in sys.argv:
else:
check_cuda_torch_binary_vs_bare_metal(CUDA_HOME)
ext_modules.append(
CUDAExtension(name='colossal_C',
sources=['csrc/colossal_C_frontend.cpp',
'csrc/multi_tensor_sgd_kernel.cu',
'csrc/multi_tensor_scale_kernel.cu',
'csrc/multi_tensor_adam.cu',
'csrc/multi_tensor_l2norm_kernel.cu',
'csrc/multi_tensor_lamb.cu'],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc': ['-lineinfo',
'-O3',
# '--resource-usage',
'--use_fast_math'] + version_dependent_macros}))
def cuda_ext_helper(name, sources, extra_cuda_flags):
return CUDAExtension(name=name,
sources=[os.path.join('colossalai/kernel/cuda_native/csrc', path) for path in sources],
include_dirs=[os.path.join(
this_dir, 'colossalai/kernel/cuda_native/csrc/kernels/include')],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc': append_nvcc_threads(['-O3',
'--use_fast_math'] + version_dependent_macros + extra_cuda_flags)})
ext_modules.append(cuda_ext_helper('colossal_C',
['colossal_C_frontend.cpp',
'multi_tensor_sgd_kernel.cu',
'multi_tensor_scale_kernel.cu',
'multi_tensor_adam.cu',
'multi_tensor_l2norm_kernel.cu',
'multi_tensor_lamb.cu'],
['-lineinfo']))
cc_flag = ['-gencode', 'arch=compute_70,code=sm_70']
_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')
extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda']
ext_modules.append(cuda_ext_helper('colossal_scaled_upper_triang_masked_softmax',
['scaled_upper_triang_masked_softmax.cpp',
'scaled_upper_triang_masked_softmax_cuda.cu'],
extra_cuda_flags + cc_flag))
ext_modules.append(cuda_ext_helper('colossal_scaled_masked_softmax',
['scaled_masked_softmax.cpp', 'scaled_masked_softmax_cuda.cu'],
extra_cuda_flags + cc_flag))
extra_cuda_flags = ['-maxrregcount=50']
ext_modules.append(cuda_ext_helper('colossal_layer_norm_cuda',
['layer_norm_cuda.cpp', 'layer_norm_cuda_kernel.cu'],
extra_cuda_flags + cc_flag))
extra_cuda_flags = ['-std=c++14',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'-U__CUDA_NO_HALF2_OPERATORS__',
'-DTHRUST_IGNORE_CUB_VERSION_CHECK']
ext_modules.append(cuda_ext_helper('colossal_multihead_attention',
['multihead_attention_1d.cpp',
'kernels/cublas_wrappers.cu',
'kernels/transform_kernels.cu',
'kernels/dropout_kernels.cu',
'kernels/normalize_kernels.cu',
'kernels/softmax_kernels.cu',
'kernels/general_kernels.cu',
'kernels/cuda_util.cu'],
extra_cuda_flags + cc_flag))
install_requires = fetch_requirements('requirements/requirements.txt')
@ -123,14 +162,17 @@ install_requires = fetch_requirements('requirements/requirements.txt')
setup(
name='colossalai',
version='0.0.1-beta',
packages=find_packages(exclude=('csrc',
packages=find_packages(exclude=('benchmark',
'docker',
'tests',
'docs',
'examples',
'tests',
'scripts',
'requirements',
'*.egg-info',)),
description='An integrated large-scale model training system with efficient parallelization techniques',
ext_modules=ext_modules,
cmdclass={'build_ext': BuildExtension} if ext_modules else {},
package_data={'colossalai': ['kernel/cuda_native/csrc/*']},
install_requires=install_requires,
)
)