mirror of https://github.com/hpcaitech/ColossalAI
refactor kernel (#142)
parent 4a3d3446b0
commit f68eddfb3d
@@ -1,4 +1,3 @@
include *.txt README.md
recursive-include requirements *.txt
recursive-include colossalai *.cpp *.h *.cu *.tr *.cuh *.cc
recursive-include csrc *.cpp *.h *.cu *.tr *.cuh *.cc
recursive-include colossalai *.cpp *.h *.cu *.tr *.cuh *.cc
@@ -1,8 +1,5 @@
from .jit.bias_dropout_add import bias_dropout_add_fused_train, bias_dropout_add_fused_inference
from .jit.bias_gelu import bias_gelu_impl
from .cuda_native import LayerNorm, FusedScaleMaskSoftmax, MultiHeadAttention

__all__ = [
    "bias_dropout_add_fused_train", "bias_dropout_add_fused_inference", "bias_gelu_impl",
    "LayerNorm", "FusedScaleMaskSoftmax", "MultiHeadAttention"
]
@@ -1,17 +1,3 @@
from .builder import _build_cuda_native_kernel

CUDA_NATIVE_KERNEL_BUILD = False


def build_cuda_native_kernel():
    global CUDA_NATIVE_KERNEL_BUILD
    if CUDA_NATIVE_KERNEL_BUILD == False:
        _build_cuda_native_kernel()
        CUDA_NATIVE_KERNEL_BUILD = True


build_cuda_native_kernel()

from .layer_norm import MixedFusedLayerNorm as LayerNorm
from .scaled_softmax import FusedScaleMaskSoftmax
from .multihead_attention import MultiHeadAttention
from .multihead_attention import MultiHeadAttention
@@ -1,114 +0,0 @@
import os
import pathlib
import subprocess

from torch.utils import cpp_extension

# Setting this param to a list has a problem of generating different
# compilation commands (with different order of architectures) and
# leading to recompilation of fused kernels. Set it to empty string
# to avoid recompilation and assign arch flags explicitly in
# extra_cuda_cflags below
os.environ["TORCH_CUDA_ARCH_LIST"] = ""


def _build_cuda_native_kernel():

    # Check if cuda 11 is installed for compute capability 8.0
    cc_flag = []
    _, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
    if int(bare_metal_major) >= 11:
        cc_flag.append('-gencode')
        cc_flag.append('arch=compute_80,code=sm_80')

    # Build path
    basepath = pathlib.Path(__file__).parent.absolute()
    srcpath = basepath / 'csrc'
    buildpath = basepath / 'build'
    _create_build_dir(buildpath)

    # Helper function to build the kernels.
    def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
        return cpp_extension.load(
            name=name,
            sources=sources,
            build_directory=buildpath,
            extra_cflags=[
                '-O3',
            ],
            extra_include_paths=[str(srcpath / 'kernels' / 'include')],
            extra_cuda_cflags=['-O3', '-gencode', 'arch=compute_70,code=sm_70', '--use_fast_math'] +
            extra_cuda_flags + cc_flag,
            verbose=False)

    # ==============
    # Fused softmax.
    # ==============

    extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__',
                        '-U__CUDA_NO_HALF_CONVERSIONS__',
                        '--expt-relaxed-constexpr',
                        '--expt-extended-lambda']

    # Upper triangular softmax.
    sources = [srcpath / 'scaled_upper_triang_masked_softmax.cpp',
               srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu']
    colossal_scaled_upper_triang_masked_softmax = _cpp_extention_load_helper(
        "colossal_scaled_upper_triang_masked_softmax",
        sources, extra_cuda_flags)

    # Masked softmax.
    sources = [srcpath / 'scaled_masked_softmax.cpp',
               srcpath / 'scaled_masked_softmax_cuda.cu']
    colossal_scaled_masked_softmax = _cpp_extention_load_helper(
        "colossal_scaled_masked_softmax", sources, extra_cuda_flags)

    # =================================
    # Mixed precision fused layer norm.
    # =================================

    extra_cuda_flags = ['-maxrregcount=50']
    sources = [srcpath / 'layer_norm_cuda.cpp', srcpath / 'layer_norm_cuda_kernel.cu']
    colossal_layer_norm_cuda = _cpp_extention_load_helper("colossal_layer_norm_cuda", sources,
                                                          extra_cuda_flags)

    # ==========================================
    # Mixed precision Transformer Encoder Layer.
    # ==========================================

    extra_cuda_flags = ['-std=c++14',
                        '-U__CUDA_NO_HALF_OPERATORS__',
                        '-U__CUDA_NO_HALF_CONVERSIONS__',
                        '-U__CUDA_NO_HALF2_OPERATORS__',
                        '-DTHRUST_IGNORE_CUB_VERSION_CHECK']

    sources = [srcpath / 'multihead_attention_1d.cpp']
    kernel_sources = ["cublas_wrappers.cu",
                      "transform_kernels.cu",
                      "dropout_kernels.cu",
                      "normalize_kernels.cu",
                      "softmax_kernels.cu",
                      "general_kernels.cu",
                      "cuda_util.cu"]
    sources += [(srcpath / 'kernels' / cu_file) for cu_file in kernel_sources]
    colossal_multihead_attention = _cpp_extention_load_helper("colossal_multihead_attention", sources,
                                                              extra_cuda_flags)


def _get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
    output = raw_output.split()
    release_idx = output.index("release") + 1
    release = output[release_idx].split(".")
    bare_metal_major = release[0]
    bare_metal_minor = release[1][0]

    return raw_output, bare_metal_major, bare_metal_minor


def _create_build_dir(buildpath):
    try:
        os.mkdir(buildpath)
    except OSError:
        if not os.path.isdir(buildpath):
            print(f"Creation of the build directory {buildpath} failed")
@@ -1,7 +1,4 @@
/* This code is from NVIDIA apex:
 * https://github.com/NVIDIA/apex
 * with minor changes. */

// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
@@ -71,3 +71,202 @@
    default: \
        AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
    }

// Forward/backward compatibility hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
//   const at::Type& payload;
//   TypeShim(const at::Type& type) : payload(type) {}
//   // Enable trivial conversion to a const at::Type& for pre-3aeb78
//   operator const at::Type&(){ return payload; };
//   // Enable dispatch switch statements to take *this directly for post-3aeb78
//   //operator at::ScalarType(){ return payload.; };
// };

#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
    switch (TYPE) \
    { \
    case at::ScalarType::Float: \
    { \
        using scalar_t_##LEVEL = float; \
        __VA_ARGS__; \
        break; \
    } \
    case at::ScalarType::Half: \
    { \
        using scalar_t_##LEVEL = at::Half; \
        __VA_ARGS__; \
        break; \
    } \
    default: \
        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
    }

#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
    switch (TYPE) \
    { \
    case at::ScalarType::Float: \
    { \
        using scalar_t_##LEVEL = float; \
        __VA_ARGS__; \
        break; \
    } \
    case at::ScalarType::Half: \
    { \
        using scalar_t_##LEVEL = at::Half; \
        __VA_ARGS__; \
        break; \
    } \
    case at::ScalarType::Byte: \
    { \
        using scalar_t_##LEVEL = uint8_t; \
        __VA_ARGS__; \
        break; \
    } \
    default: \
        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
    }

#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
    switch (TYPE) \
    { \
    case at::ScalarType::Double: \
    { \
        using scalar_t_##LEVEL = double; \
        __VA_ARGS__; \
        break; \
    } \
    case at::ScalarType::Float: \
    { \
        using scalar_t_##LEVEL = float; \
        __VA_ARGS__; \
        break; \
    } \
    case at::ScalarType::Half: \
    { \
        using scalar_t_##LEVEL = at::Half; \
        __VA_ARGS__; \
        break; \
    } \
    default: \
        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
    }

#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
    switch (TYPE) \
    { \
    case at::ScalarType::Double: \
    { \
        using scalar_t_##LEVEL = double; \
        __VA_ARGS__; \
        break; \
    } \
    case at::ScalarType::Float: \
    { \
        using scalar_t_##LEVEL = float; \
        __VA_ARGS__; \
        break; \
    } \
    default: \
        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
    }

template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes(T *x,
                                                     T val,
                                                     int lanes = 1,
                                                     bool share_result = false) // lanes is intended to be <= 32.
{
    int tid = threadIdx.x + threadIdx.y * blockDim.x;
    int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.

    if (blockSize >= 64)
    {
        x[tid] = val;
        __syncthreads();
    }

#pragma unroll
    for (int i = (blockSize >> 1); i >= 64; i >>= 1)
    {
        if (tid < i)
            x[tid] = x[tid] + x[tid + i];
        __syncthreads();
    }

    T final;

    if (tid < 32)
    {
        if (blockSize >= 64)
            final = x[tid] + x[tid + 32];
        else
            final = val;
        // __SYNCWARP();

#pragma unroll
        for (int i = 16; i >= lanes; i >>= 1)
            final = final + __shfl_down_sync(0xffffffff, final, i);
    }

    if (share_result)
    {
        if (tid < lanes)
            x[tid] = final; // EpilogueOp
        // Make sure the smem result is visible to all warps.
        __syncthreads();
    }

    return final;
}

template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes_max_op(T *x,
                                                            T val,
                                                            int lanes = 1,
                                                            bool share_result = false) // lanes is intended to be <= 32.
{
    int tid = threadIdx.x + threadIdx.y * blockDim.x;
    int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.

    if (blockSize >= 64)
    {
        x[tid] = val;
        __syncthreads();
    }

#pragma unroll
    for (int i = (blockSize >> 1); i >= 64; i >>= 1)
    {
        if (tid < i)
            x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
        __syncthreads();
    }

    T final;

    if (tid < 32)
    {
        if (blockSize >= 64)
            final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
        else
            final = val;
        // __SYNCWARP();

#pragma unroll
        for (int i = 16; i >= lanes; i >>= 1)
            final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
    }

    if (share_result)
    {
        if (tid < lanes)
            x[tid] = final; // EpilogueOp
        // Make sure the smem result is visible to all warps.
        __syncthreads();
    }

    return final;
}
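The DISPATCH_* macros above expand to a switch over at::ScalarType so the same kernel body can be instantiated per dtype. As a clarifying aid only, here is a rough Python analogue of that idea (the helper and the toy kernels are ours, not part of the diff): pick a dtype-specialised implementation at runtime and fail loudly for unsupported types.

import torch

def dispatch_float_and_half(tensor, kernels, name):
    """kernels maps a torch.dtype to its specialised implementation."""
    try:
        kernel = kernels[tensor.dtype]
    except KeyError:
        raise TypeError(f"{name} not implemented for '{tensor.dtype}'")
    return kernel(tensor)

# Toy "specialisations" standing in for dtype-templated CUDA kernels.
kernels = {
    torch.float32: lambda t: t * 2.0,
    torch.float16: lambda t: (t.float() * 2.0).half(),
}
print(dispatch_float_and_half(torch.ones(2, dtype=torch.float16), kernels, "double_op"))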
@@ -34,10 +34,10 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
        input_, weight_, bias_, mean, invvar = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        grad_input, grad_weight, grad_bias \
            = colossal_layer_norm_cuda.backward_affine(
                grad_output.contiguous(), mean, invvar,
                input_, ctx.normalized_shape,
                weight_, bias_, ctx.eps)
            = colossal_layer_norm_cuda.backward_affine(
                grad_output.contiguous(), mean, invvar,
                input_, ctx.normalized_shape,
                weight_, bias_, ctx.eps)

        return grad_input, grad_weight, grad_bias, None, None
@@ -48,7 +48,11 @@ class MixedFusedLayerNorm(torch.nn.Module):
        super(MixedFusedLayerNorm, self).__init__()

        global colossal_layer_norm_cuda
        colossal_layer_norm_cuda = importlib.import_module("colossal_layer_norm_cuda")
        if colossal_layer_norm_cuda is None:
            try:
                colossal_layer_norm_cuda = importlib.import_module("colossal_layer_norm_cuda")
            except ImportError:
                raise RuntimeError('MixedFusedLayerNorm requires cuda extensions')

        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
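The change above (repeated below for MultiHeadAttention and the softmax wrappers) introduces one pattern: import the pre-built extension lazily on first use, cache it in a module-level global, and raise an actionable RuntimeError when the package was installed without the CUDA extensions. A condensed, self-contained sketch of that pattern, with placeholder helper and module names:

import importlib

_ext_cache = None  # module-level cache, mirroring the globals used in the diff


def _load_cuda_ext(module_name="colossal_layer_norm_cuda"):
    """Import the pre-built CUDA extension once and reuse it afterwards."""
    global _ext_cache
    if _ext_cache is None:
        try:
            _ext_cache = importlib.import_module(module_name)
        except ImportError:
            raise RuntimeError(f'{module_name} requires cuda extensions')
    return _ext_cache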
@@ -34,6 +34,7 @@ def calc_offset(sizes):

colossal_multihead_attention = None


@dataclass
class Config:
    max_batch_tokens: int  # max batch token numbers
@@ -94,7 +95,7 @@ class MultiHeadAttention1DFunc(Function):
            input_mask = input_mask.to(torch.half)
        grad_input, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, \
            grad_out_proj_bias, grad_norm_weight, grad_norm_bias = backward_func(
                ctx.config.layer_id, grad_output, output, input, input_mask, in_proj_weight, \
                ctx.config.layer_id, grad_output, output, input, input_mask, in_proj_weight,
                in_proj_bias, out_proj_weight, out_proj_bias, norm_weight, norm_bias)

        return (grad_input, None, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight,
@@ -142,7 +143,10 @@ class MultiHeadAttention(nn.Module):
        # Load cuda modules if needed
        global colossal_multihead_attention
        if colossal_multihead_attention is None:
            colossal_multihead_attention = importlib.import_module("colossal_multihead_attention")
            try:
                colossal_multihead_attention = importlib.import_module("colossal_multihead_attention")
            except ImportError:
                raise RuntimeError('MultiHeadAttention requires cuda extensions')

        # create the layer in cuda kernels.
        cuda_module = colossal_multihead_attention
@@ -210,14 +214,14 @@ class MultiHeadAttention(nn.Module):
            with torch.no_grad():
                self.in_proj_weight.copy_(
                    attn_qkvw_global.view(3, hs, hs)[:,
                                                     int(hs * rank_in_pg /
                                                         self.pg_size):int(hs * (rank_in_pg + 1) /
                                                                           self.pg_size), :])
                                                     int(hs * rank_in_pg /
                                                         self.pg_size):int(hs * (rank_in_pg + 1) /
                                                                           self.pg_size), :])
                self.in_proj_bias.copy_(
                    attn_qkvb_global.view(3, hs)[:,
                                                 int(hs * rank_in_pg /
                                                     self.pg_size):int(hs * (rank_in_pg + 1) /
                                                                       self.pg_size)])
                                                 int(hs * rank_in_pg /
                                                     self.pg_size):int(hs * (rank_in_pg + 1) /
                                                                       self.pg_size)])

            attn_ow_global = torch.empty(hs, hs)
            nn.init.xavier_uniform_(attn_ow_global, 1.0)
@@ -226,9 +230,9 @@ class MultiHeadAttention(nn.Module):
            attn_ow_global = attn_ow_global.cpu()
            with torch.no_grad():
                self.out_proj_weight.copy_(attn_ow_global[:,
                                                          int(hs * rank_in_pg /
                                                              self.pg_size):int(hs * (rank_in_pg + 1) /
                                                                                self.pg_size)])
                                                          int(hs * rank_in_pg /
                                                              self.pg_size):int(hs * (rank_in_pg + 1) /
                                                                                self.pg_size)])

        else:
            attn_qkvw = self.in_proj_weight.view(-1, hs)
@@ -21,7 +21,10 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):

    @staticmethod
    def forward(ctx, inputs, scale):
        import colossal_scaled_upper_triang_masked_softmax
        try:
            import colossal_scaled_upper_triang_masked_softmax
        except ImportError:
            raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')

        scale_t = torch.tensor([scale])
        softmax_results = colossal_scaled_upper_triang_masked_softmax.forward(
@@ -33,7 +36,10 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):

    @staticmethod
    def backward(ctx, output_grads):
        import colossal_scaled_upper_triang_masked_softmax
        try:
            import colossal_scaled_upper_triang_masked_softmax
        except ImportError:
            raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')

        softmax_results, scale_t = ctx.saved_tensors
        input_grads = colossal_scaled_upper_triang_masked_softmax.backward(
@@ -53,7 +59,10 @@ class ScaledMaskedSoftmax(torch.autograd.Function):

    @staticmethod
    def forward(ctx, inputs, mask, scale):
        import colossal_scaled_masked_softmax
        try:
            import colossal_scaled_masked_softmax
        except ImportError:
            raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')

        scale_t = torch.tensor([scale])
@@ -63,7 +72,10 @@ class ScaledMaskedSoftmax(torch.autograd.Function):

    @staticmethod
    def backward(ctx, output_grads):
        import colossal_scaled_masked_softmax
        try:
            import colossal_scaled_masked_softmax
        except ImportError:
            raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')

        softmax_results, scale_t = ctx.saved_tensors
@@ -179,6 +191,9 @@ class FusedScaleMaskSoftmax(nn.Module):

    @staticmethod
    def get_batch_per_block(sq, sk, b, np):
        import colossal_scaled_masked_softmax
        try:
            import colossal_scaled_masked_softmax
        except ImportError:
            raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')

        return colossal_scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
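For context, the classes patched above are ordinary torch.autograd.Function wrappers: forward calls into the compiled kernel and stashes whatever backward needs via ctx.save_for_backward. A minimal runnable sketch of that wrapper shape follows; the commented lines mark where the colossal_scaled_* extension calls go, and the placeholder arithmetic is ours, only there to keep the sketch executable.

import torch


class ScaledOpSketch(torch.autograd.Function):

    @staticmethod
    def forward(ctx, inputs, scale):
        scale_t = torch.tensor([scale])
        # softmax_results = colossal_scaled_masked_softmax.forward(inputs, mask, scale_t[0])
        results = inputs * scale_t[0]            # placeholder for the fused kernel
        ctx.save_for_backward(results, scale_t)  # stash tensors for backward
        return results

    @staticmethod
    def backward(ctx, output_grads):
        results, scale_t = ctx.saved_tensors
        # input_grads = colossal_scaled_masked_softmax.backward(output_grads, results, scale_t[0])
        input_grads = output_grads * scale_t[0]  # gradient of the placeholder op
        # one entry per forward argument; None for the non-tensor scale
        return input_grads, None


x = torch.randn(2, 4, requires_grad=True)
ScaledOpSketch.apply(x, 0.5).sum().backward()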
@@ -1,3 +1,8 @@
from .option import _set_jit_fusion_options
from .bias_dropout_add import bias_dropout_add_fused_train, bias_dropout_add_fused_inference
from .bias_gelu import bias_gelu_impl
_set_jit_fusion_options()

_set_jit_fusion_options()
__all__ = [
    "bias_dropout_add_fused_train", "bias_dropout_add_fused_inference", "bias_gelu_impl",
]
@@ -2,6 +2,7 @@ import torch

JIT_OPTIONS_SET = False


def _set_jit_fusion_options():
    """Set PyTorch JIT layer fusion options."""
    global JIT_OPTIONS_SET
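These JIT options exist so that small elementwise chains such as bias + dropout + residual-add can be fused by TorchScript instead of launching separate kernels. Roughly what the bias_dropout_add_fused_* helpers compute, as a hedged approximation (our sketch, not the exact ColossalAI source):

import torch


@torch.jit.script
def bias_dropout_add(x, bias, residual, prob: float, training: bool):
    # bias add -> dropout -> residual add, written so TorchScript can fuse it
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    return residual + out


x, bias, residual = torch.randn(4, 8), torch.randn(8), torch.randn(4, 8)
y = bias_dropout_add(x, bias, residual, 0.1, True)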
@@ -65,8 +65,7 @@ class FusedAdam(torch.optim.Optimizer):
            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
            self.multi_tensor_adam = colossal_C.multi_tensor_adam
        else:
            raise RuntimeError(
                'apex.optimizers.FusedAdam requires cuda extensions')
            raise RuntimeError('FusedAdam requires cuda extensions')

    def zero_grad(self):
        if self.set_grad_none:
@@ -73,8 +73,7 @@ class FusedLAMB(torch.optim.Optimizer):
                [0], dtype=torch.int, device=self.param_groups[0]["params"][0].device)
            self.multi_tensor_lamb = colossal_C.multi_tensor_lamb
        else:
            raise RuntimeError(
                'apex.optimizers.FusedLAMB requires cuda extensions')
            raise RuntimeError('FusedLAMB requires cuda extensions')

        self.adam_w_mode = 1 if adam_w_mode else 0
        self.set_grad_none = set_grad_none
@@ -90,8 +90,7 @@ class FusedSGD(Optimizer):
                [0], dtype=torch.int, device=self.param_groups[0]["params"][0].device)
            self.multi_tensor_sgd = colossal_C.multi_tensor_sgd
        else:
            raise RuntimeError(
                'apex.optimizers.FusedSGD requires cuda extensions')
            raise RuntimeError('FusedSGD requires cuda extensions')

    def __setstate__(self, state):
        super(FusedSGD, self).__setstate__(state)
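All three fused optimizers now share the same guard: the colossal_C extension (built by setup.py when --cuda_ext is passed) must be importable, otherwise construction fails with a clear error. A minimal sketch of that guard, with an illustrative helper name:

def bind_multi_tensor_adam():
    """Return colossal_C.multi_tensor_adam or fail with an actionable error."""
    try:
        import colossal_C  # pre-built by setup.py with the --cuda_ext option
    except ImportError:
        raise RuntimeError('FusedAdam requires cuda extensions')
    return colossal_C.multi_tensor_adam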
@@ -1,10 +0,0 @@
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif

#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
csrc/type_shim.h
@@ -1,202 +0,0 @@
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h
#include <ATen/ATen.h>
#include "compat.h"

// Forward/backward compatibility hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
//   const at::Type& payload;
//   TypeShim(const at::Type& type) : payload(type) {}
//   // Enable trivial conversion to a const at::Type& for pre-3aeb78
//   operator const at::Type&(){ return payload; };
//   // Enable dispatch switch statements to take *this directly for post-3aeb78
//   //operator at::ScalarType(){ return payload.; };
// };

#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
    switch (TYPE) \
    { \
    case at::ScalarType::Float: \
    { \
        using scalar_t_##LEVEL = float; \
        __VA_ARGS__; \
        break; \
    } \
    case at::ScalarType::Half: \
    { \
        using scalar_t_##LEVEL = at::Half; \
        __VA_ARGS__; \
        break; \
    } \
    default: \
        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
    }

#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
    switch (TYPE) \
    { \
    case at::ScalarType::Float: \
    { \
        using scalar_t_##LEVEL = float; \
        __VA_ARGS__; \
        break; \
    } \
    case at::ScalarType::Half: \
    { \
        using scalar_t_##LEVEL = at::Half; \
        __VA_ARGS__; \
        break; \
    } \
    case at::ScalarType::Byte: \
    { \
        using scalar_t_##LEVEL = uint8_t; \
        __VA_ARGS__; \
        break; \
    } \
    default: \
        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
    }

#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
    switch (TYPE) \
    { \
    case at::ScalarType::Double: \
    { \
        using scalar_t_##LEVEL = double; \
        __VA_ARGS__; \
        break; \
    } \
    case at::ScalarType::Float: \
    { \
        using scalar_t_##LEVEL = float; \
        __VA_ARGS__; \
        break; \
    } \
    case at::ScalarType::Half: \
    { \
        using scalar_t_##LEVEL = at::Half; \
        __VA_ARGS__; \
        break; \
    } \
    default: \
        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
    }

#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
    switch (TYPE) \
    { \
    case at::ScalarType::Double: \
    { \
        using scalar_t_##LEVEL = double; \
        __VA_ARGS__; \
        break; \
    } \
    case at::ScalarType::Float: \
    { \
        using scalar_t_##LEVEL = float; \
        __VA_ARGS__; \
        break; \
    } \
    default: \
        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
    }

template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes(T *x,
                                                     T val,
                                                     int lanes = 1,
                                                     bool share_result = false) // lanes is intended to be <= 32.
{
    int tid = threadIdx.x + threadIdx.y * blockDim.x;
    int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.

    if (blockSize >= 64)
    {
        x[tid] = val;
        __syncthreads();
    }

#pragma unroll
    for (int i = (blockSize >> 1); i >= 64; i >>= 1)
    {
        if (tid < i)
            x[tid] = x[tid] + x[tid + i];
        __syncthreads();
    }

    T final;

    if (tid < 32)
    {
        if (blockSize >= 64)
            final = x[tid] + x[tid + 32];
        else
            final = val;
        // __SYNCWARP();

#pragma unroll
        for (int i = 16; i >= lanes; i >>= 1)
            final = final + __shfl_down_sync(0xffffffff, final, i);
    }

    if (share_result)
    {
        if (tid < lanes)
            x[tid] = final; // EpilogueOp
        // Make sure the smem result is visible to all warps.
        __syncthreads();
    }

    return final;
}

template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes_max_op(T *x,
                                                            T val,
                                                            int lanes = 1,
                                                            bool share_result = false) // lanes is intended to be <= 32.
{
    int tid = threadIdx.x + threadIdx.y * blockDim.x;
    int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.

    if (blockSize >= 64)
    {
        x[tid] = val;
        __syncthreads();
    }

#pragma unroll
    for (int i = (blockSize >> 1); i >= 64; i >>= 1)
    {
        if (tid < i)
            x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
        __syncthreads();
    }

    T final;

    if (tid < 32)
    {
        if (blockSize >= 64)
            final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
        else
            final = val;
        // __SYNCWARP();

#pragma unroll
        for (int i = 16; i >= lanes; i >>= 1)
            final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
    }

    if (share_result)
    {
        if (tid < lanes)
            x[tid] = final; // EpilogueOp
        // Make sure the smem result is visible to all warps.
        __syncthreads();
    }

    return final;
}
setup.py
@@ -11,8 +11,7 @@ this_dir = os.path.dirname(os.path.abspath(__file__))


def get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output(
        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
    output = raw_output.split()
    release_idx = output.index("release") + 1
    release = output[release_idx].split(".")
@@ -23,8 +22,7 @@ def get_cuda_bare_metal_version(cuda_dir):


def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(
        cuda_dir)
    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
    torch_binary_major = torch.version.cuda.split(".")[0]
    torch_binary_minor = torch.version.cuda.split(".")[1]
@@ -40,6 +38,13 @@ def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
                           "You can try commenting out this check (at your own risk).")


def append_nvcc_threads(nvcc_extra_args):
    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
        return nvcc_extra_args + ["--threads", "4"]
    return nvcc_extra_args


def fetch_requirements(path):
    with open(path, 'r') as fd:
        return [r.strip() for r in fd.readlines()]
@@ -67,8 +72,8 @@ print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])

if TORCH_MAJOR == 0 and TORCH_MINOR < 4:
    raise RuntimeError("Colossal-AI requires Pytorch 0.4 or newer.\n" +
if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 8):
    raise RuntimeError("Colossal-AI requires Pytorch 1.8 or newer.\n" +
                       "The latest stable release can be obtained from https://pytorch.org/")

cmdclass = {}
@@ -79,22 +84,9 @@ ext_modules = []
# and
# https://github.com/NVIDIA/apex/issues/456
# https://github.com/pytorch/pytorch/commit/eb7b39e02f7d75c26d8a795ea8c7fd911334da7e#diff-4632522f237f1e4e728cb824300403ac
version_ge_1_1 = []
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0):
    version_ge_1_1 = ['-DVERSION_GE_1_1']
version_ge_1_3 = []
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
    version_ge_1_3 = ['-DVERSION_GE_1_3']
version_ge_1_5 = []
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
    version_ge_1_5 = ['-DVERSION_GE_1_5']
version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5
version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']

if "--cuda_ext" in sys.argv:
    if TORCH_MAJOR == 0:
        raise RuntimeError("--cuda_ext requires Pytorch 1.0 or later, "
                           "found torch.__version__ = {}".format(torch.__version__))

    sys.argv.remove("--cuda_ext")

    if CUDA_HOME is None:
@@ -103,19 +95,66 @@ if "--cuda_ext" in sys.argv:
    else:
        check_cuda_torch_binary_vs_bare_metal(CUDA_HOME)

    ext_modules.append(
        CUDAExtension(name='colossal_C',
                      sources=['csrc/colossal_C_frontend.cpp',
                               'csrc/multi_tensor_sgd_kernel.cu',
                               'csrc/multi_tensor_scale_kernel.cu',
                               'csrc/multi_tensor_adam.cu',
                               'csrc/multi_tensor_l2norm_kernel.cu',
                               'csrc/multi_tensor_lamb.cu'],
                      extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                          'nvcc': ['-lineinfo',
                                                   '-O3',
                                                   # '--resource-usage',
                                                   '--use_fast_math'] + version_dependent_macros}))
    def cuda_ext_helper(name, sources, extra_cuda_flags):
        return CUDAExtension(name=name,
                             sources=[os.path.join('colossalai/kernel/cuda_native/csrc', path) for path in sources],
                             include_dirs=[os.path.join(
                                 this_dir, 'colossalai/kernel/cuda_native/csrc/kernels/include')],
                             extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                                 'nvcc': append_nvcc_threads(['-O3',
                                                                              '--use_fast_math'] + version_dependent_macros + extra_cuda_flags)})

    ext_modules.append(cuda_ext_helper('colossal_C',
                                       ['colossal_C_frontend.cpp',
                                        'multi_tensor_sgd_kernel.cu',
                                        'multi_tensor_scale_kernel.cu',
                                        'multi_tensor_adam.cu',
                                        'multi_tensor_l2norm_kernel.cu',
                                        'multi_tensor_lamb.cu'],
                                       ['-lineinfo']))

    cc_flag = ['-gencode', 'arch=compute_70,code=sm_70']
    _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
    if int(bare_metal_major) >= 11:
        cc_flag.append('-gencode')
        cc_flag.append('arch=compute_80,code=sm_80')

    extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__',
                        '-U__CUDA_NO_HALF_CONVERSIONS__',
                        '--expt-relaxed-constexpr',
                        '--expt-extended-lambda']

    ext_modules.append(cuda_ext_helper('colossal_scaled_upper_triang_masked_softmax',
                                       ['scaled_upper_triang_masked_softmax.cpp',
                                        'scaled_upper_triang_masked_softmax_cuda.cu'],
                                       extra_cuda_flags + cc_flag))

    ext_modules.append(cuda_ext_helper('colossal_scaled_masked_softmax',
                                       ['scaled_masked_softmax.cpp', 'scaled_masked_softmax_cuda.cu'],
                                       extra_cuda_flags + cc_flag))

    extra_cuda_flags = ['-maxrregcount=50']

    ext_modules.append(cuda_ext_helper('colossal_layer_norm_cuda',
                                       ['layer_norm_cuda.cpp', 'layer_norm_cuda_kernel.cu'],
                                       extra_cuda_flags + cc_flag))

    extra_cuda_flags = ['-std=c++14',
                        '-U__CUDA_NO_HALF_OPERATORS__',
                        '-U__CUDA_NO_HALF_CONVERSIONS__',
                        '-U__CUDA_NO_HALF2_OPERATORS__',
                        '-DTHRUST_IGNORE_CUB_VERSION_CHECK']

    ext_modules.append(cuda_ext_helper('colossal_multihead_attention',
                                       ['multihead_attention_1d.cpp',
                                        'kernels/cublas_wrappers.cu',
                                        'kernels/transform_kernels.cu',
                                        'kernels/dropout_kernels.cu',
                                        'kernels/normalize_kernels.cu',
                                        'kernels/softmax_kernels.cu',
                                        'kernels/general_kernels.cu',
                                        'kernels/cuda_util.cu'],
                                       extra_cuda_flags + cc_flag))


install_requires = fetch_requirements('requirements/requirements.txt')
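cuda_ext_helper replaces the deleted JIT builder: every kernel is now declared as a CUDAExtension and compiled once at install time, which is what the lazy importlib.import_module calls in the Python modules above rely on. A stripped-down sketch of the same wiring for a single hypothetical extension (names and paths here are placeholders, not the real ColossalAI targets):

import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

this_dir = os.path.dirname(os.path.abspath(__file__))

demo_ext = CUDAExtension(
    name='demo_cuda_ext',                              # placeholder module name
    sources=['csrc/demo.cpp', 'csrc/demo_kernel.cu'],  # placeholder sources
    include_dirs=[os.path.join(this_dir, 'csrc/include')],
    extra_compile_args={'cxx': ['-O3'],
                        'nvcc': ['-O3', '--use_fast_math']})

setup(name='demo-package',
      ext_modules=[demo_ext],
      cmdclass={'build_ext': BuildExtension})

After installation the module is importable as import demo_cuda_ext, mirroring how colossal_layer_norm_cuda and friends are imported lazily above.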
@@ -123,14 +162,17 @@ install_requires = fetch_requirements('requirements/requirements.txt')
setup(
    name='colossalai',
    version='0.0.1-beta',
    packages=find_packages(exclude=('csrc',
    packages=find_packages(exclude=('benchmark',
                                    'docker',
                                    'tests',
                                    'docs',
                                    'examples',
                                    'tests',
                                    'scripts',
                                    'requirements',
                                    '*.egg-info',)),
    description='An integrated large-scale model training system with efficient parallelization techniques',
    ext_modules=ext_modules,
    cmdclass={'build_ext': BuildExtension} if ext_modules else {},
    package_data={'colossalai': ['kernel/cuda_native/csrc/*']},
    install_requires=install_requires,
)
)