diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml index a34a60669..2cad504f3 100644 --- a/.github/workflows/build_on_pr.yml +++ b/.github/workflows/build_on_pr.yml @@ -140,7 +140,7 @@ jobs: - name: Install Colossal-AI run: | - CUDA_EXT=1 pip install -v -e . + BUILD_EXT=1 pip install -v -e . pip install -r requirements/requirements-test.txt - name: Store Colossal-AI Cache diff --git a/.github/workflows/build_on_schedule.yml b/.github/workflows/build_on_schedule.yml index 03f9c53f1..8c03fee53 100644 --- a/.github/workflows/build_on_schedule.yml +++ b/.github/workflows/build_on_schedule.yml @@ -55,7 +55,7 @@ jobs: if: steps.check-avai.outputs.avai == 'true' run: | [ ! -z "$(ls -A /github/home/cuda_ext_cache/)" ] && cp -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/ - CUDA_EXT=1 pip install -v -e . + BUILD_EXT=1 pip install -v -e . cp -r /__w/ColossalAI/ColossalAI/build /github/home/cuda_ext_cache/ pip install -r requirements/requirements-test.txt diff --git a/MANIFEST.in b/MANIFEST.in index ad26b634a..f0a5611ef 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,4 @@ include *.txt README.md recursive-include requirements *.txt recursive-include colossalai *.cpp *.h *.cu *.tr *.cuh *.cc *.pyi -recursive-include op_builder *.py +recursive-include extensions *.py *.cpp *.h *.cu *.tr *.cuh *.cc *.pyi diff --git a/colossalai/accelerator/base_accelerator.py b/colossalai/accelerator/base_accelerator.py index a550cd7a2..33c113999 100644 --- a/colossalai/accelerator/base_accelerator.py +++ b/colossalai/accelerator/base_accelerator.py @@ -48,6 +48,12 @@ class BaseAccelerator(ABC): # ======================= # device APIs # ======================= + @abstractmethod + def get_version(self) -> str: + """ + Return the version of the accelerator which torch is built against. + """ + @abstractmethod def get_current_device(self) -> torch.device: """ @@ -66,6 +72,7 @@ class BaseAccelerator(ABC): Bind the current process to a device. """ + @abstractmethod def get_device_name(self, device: Union[torch.device, int]) -> str: """ Return the name of the device. diff --git a/colossalai/accelerator/cpu_accelerator.py b/colossalai/accelerator/cpu_accelerator.py index c1f01b4f7..080aa61e8 100644 --- a/colossalai/accelerator/cpu_accelerator.py +++ b/colossalai/accelerator/cpu_accelerator.py @@ -24,6 +24,12 @@ class CpuAccelerator(BaseAccelerator): # ======================= # device APIs # ======================= + def get_version(self) -> str: + """ + Return the version of the accelerator which torch is built against. + """ + return "" + def get_current_device(self) -> torch.device: """ Return the current device. diff --git a/colossalai/accelerator/cuda_accelerator.py b/colossalai/accelerator/cuda_accelerator.py index bdaf53bd5..f1ab487d4 100644 --- a/colossalai/accelerator/cuda_accelerator.py +++ b/colossalai/accelerator/cuda_accelerator.py @@ -21,6 +21,12 @@ class CudaAccelerator(BaseAccelerator): # ======================= # device APIs # ======================= + def get_version(self) -> str: + """ + Return the version of the accelerator which torch is built against. + """ + return torch.version.cuda + def get_current_device(self) -> torch.device: """ Return the current device. diff --git a/colossalai/accelerator/npu_accelerator.py b/colossalai/accelerator/npu_accelerator.py index ba3ddc552..1a86f84cb 100644 --- a/colossalai/accelerator/npu_accelerator.py +++ b/colossalai/accelerator/npu_accelerator.py @@ -27,6 +27,12 @@ class NpuAccelerator(BaseAccelerator): # ======================= # device APIs # ======================= + def get_version(self) -> str: + """ + Return the version of the accelerator which torch is built against. + """ + return torch.version.npu + def get_current_device(self) -> torch.device: """ Return the current device. diff --git a/colossalai/kernel/__init__.py b/colossalai/kernel/__init__.py index 5356fbf48..e69de29bb 100644 --- a/colossalai/kernel/__init__.py +++ b/colossalai/kernel/__init__.py @@ -1,14 +0,0 @@ -from .cpu_adam_loader import CPUAdamLoader -from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention -from .extensions.flash_attention import AttnMaskType -from .flash_attention_loader import ColoAttention, FlashAttentionLoader - -__all__ = [ - "LayerNorm", - "FusedScaleMaskSoftmax", - "MultiHeadAttention", - "CPUAdamLoader", - "FlashAttentionLoader", - "ColoAttention", - "AttnMaskType", -] diff --git a/colossalai/kernel/base_kernel_loader.py b/colossalai/kernel/base_kernel_loader.py deleted file mode 100644 index ff7a43261..000000000 --- a/colossalai/kernel/base_kernel_loader.py +++ /dev/null @@ -1,28 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Dict, List - -from .extensions.base_extension import BaseExtension - - -class BaseKernelLoader(ABC): - """ - Usage: - kernel_loader = KernelLoader() - kernel = kernel_loader.load() - """ - - def __init__(self, extension_map: Dict[str, BaseExtension], supported_device: List[str]): - self._extension_map = extension_map - self._supported_device = supported_device - - def run_checks(self): - # run supported device check and other possible checks - pass - - @abstractmethod - def fetch_kernel(self): - pass - - def load(self): - self.run_checks() - return self.fetch_kernel() diff --git a/colossalai/kernel/cpu_adam_loader.py b/colossalai/kernel/cpu_adam_loader.py deleted file mode 100644 index 0df6bd49b..000000000 --- a/colossalai/kernel/cpu_adam_loader.py +++ /dev/null @@ -1,64 +0,0 @@ -import platform -from collections import OrderedDict - -from .base_kernel_loader import BaseKernelLoader -from .extensions.cpu_adam import ArmCPUAdamExtension, X86CPUAdamExtension - - -class CPUAdamLoader(BaseKernelLoader): - """ - CPU Adam Loader - - Usage: - # init - cpu_adam = CPUAdamLoader().load() - cpu_adam_op = cpu_adam.CPUAdamOptimizer( - alpha, beta1, beta2, epsilon, weight_decay, adamw_mode, - ) - ... - # optim step - cpu_adam_op.step( - step, lr, beta1, beta2, epsilon, weight_decay, bias_correction, - params, grads, exp_avg, exp_avg_sq, loss_scale, - ) - - Args: - func CPUAdamOptimizer: - alpha (float): learning rate. Default to 1e-3. - beta1 (float): coefficients used for computing running averages of gradient. Default to 0.9. - beta2 (float): coefficients used for computing running averages of its square. Default to 0.99. - epsilon (float): term added to the denominator to improve numerical stability. Default to 1e-8. - weight_decay (float): weight decay (L2 penalty). Default to 0. - adamw_mode (bool): whether to use the adamw. Default to True. - func step: - step (int): current step. - lr (float): learning rate. - beta1 (float): coefficients used for computing running averages of gradient. - beta2 (float): coefficients used for computing running averages of its square. - epsilon (float): term added to the denominator to improve numerical stability. - weight_decay (float): weight decay (L2 penalty). - bias_correction (bool): whether to use bias correction. - params (torch.Tensor): parameter. - grads (torch.Tensor): gradient. - exp_avg (torch.Tensor): exp average. - exp_avg_sq (torch.Tensor): exp average square. - loss_scale (float): loss scale value. - """ - - def __init__(self): - super().__init__( - extension_map=OrderedDict( - arm=ArmCPUAdamExtension, - x86=X86CPUAdamExtension, - ), - supported_device=["cpu"], - ) - - def fetch_kernel(self): - if platform.machine() == "x86_64": - kernel = self._extension_map["x86"]().fetch() - elif platform.machine() in ["aarch64", "aarch64_be", "armv8b", "armv8l"]: - kernel = self._extension_map["arm"]().fetch() - else: - raise Exception("not supported") - return kernel diff --git a/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu b/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu deleted file mode 100644 index 2b1b366b1..000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cu +++ /dev/null @@ -1,63 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#include "column_remap.cuh" -#include "util.cuh" - -const int SHUF_BLOCKSIZE_X = 256; -const int SHUF_BLOCKSIZE_Y = 16; - -__global__ void column_remap_kernel -( - const half* __restrict__ x, - half* __restrict__ x_new, - const int x_width, - const int x_height, - const uint32_t* x_map -) -{ - int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x; - int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y; - if (x_column >= x_width) return; - //if (x_row >= x_height) return; - - int x_stride = x_width; - int x_idx = x_row * x_stride + x_column; - - int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height); - int x_idx_end = x_row_end * x_stride + x_column; - - int s_column = x_map[x_column]; - int s_idx = x_row * x_stride + s_column; - - while (x_idx < x_idx_end) - { - x_new[x_idx] = x[s_idx]; - x_idx += x_stride; - s_idx += x_stride; - } -} - -// Remap columns in x to correspond to sequential group index before matmul -// -// perform x -> seq_x such that seq_x @ seq_w == x @ w - -void column_remap_cuda -( - const half* x, - half* x_new, - const int x_height, - const int x_width, - const uint32_t* x_map -) -{ - dim3 threads(SHUF_BLOCKSIZE_X, 1, 1); - - dim3 blocks - ( - (x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X, - (x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y, - 1 - ); - - column_remap_kernel<<>>(x, x_new, x_width, x_height, x_map); -} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh b/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh deleted file mode 100644 index 0364e38c4..000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/column_remap.cuh +++ /dev/null @@ -1,19 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _column_remap_cuh -#define _column_remap_cuh - -#include -#include -#include - -void column_remap_cuda -( - const half* x, - half* x_new, - const int x_height, - const int x_width, - const uint32_t* x_map -); - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh b/colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh deleted file mode 100644 index c5258813e..000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/cu_compat.cuh +++ /dev/null @@ -1,58 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _cuda_compat_cuh -#define _cuda_compat_cuh - -// atomicAdd for half types, to support CC < 7.x - -__device__ __forceinline__ void atomicAdd_half(half* address, half val) -{ - unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); - unsigned int old = *address_as_ui; - unsigned int assumed; - - do - { - assumed = old; - __half_raw hsum; - hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); - half tmpres = __hadd(hsum, val); - hsum = __half_raw(tmpres); - old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; - old = atomicCAS(address_as_ui, assumed, old); - } - while (assumed != old); -} - -// atomicAdd for half2 types - -__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) -{ - unsigned int* address_as_ui = (unsigned int*)address; - unsigned int old = *address_as_ui; - unsigned int assumed; - do - { - assumed = old; - half2 old_val = *((half2*)&old); - half2 new_val = __hadd2(old_val, val); - old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); - } - while (assumed != old); -} - -// - -#if defined(__CUDA_ARCH__) || defined(USE_ROCM) -#if __CUDA_ARCH__ < 700 || defined(USE_ROCM) - -__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); } - -#if __CUDA_ARCH__ < 600 || defined(USE_ROCM) -__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); } -#endif - -#endif -#endif - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu b/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu deleted file mode 100644 index 4416027c8..000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cu +++ /dev/null @@ -1,75 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#define _cuda_buffers_cu -#include "cuda_buffers.cuh" - -CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL}; -// __constant__ half2 q4_table[16][256]; -// half2 q4_table_host[16][256]; -// bool q4_table_init = false; - -CudaBuffers::CudaBuffers -( - int _device, - int _temp_state_size, - half* _temp_state, - half* _temp_dq -) : - device(_device), - temp_state_size(_temp_state_size), - temp_state(_temp_state), - temp_dq(_temp_dq) -{ - cudaSetDevice(_device); - - cudaStreamCreate(&alt_stream_1); - cudaStreamCreate(&alt_stream_2); - cudaStreamCreate(&alt_stream_3); - cudaEventCreate(&alt_stream_1_done); - cudaEventCreate(&alt_stream_2_done); - cudaEventCreate(&alt_stream_3_done); -} - -CudaBuffers::~CudaBuffers() -{ - cudaStreamDestroy(alt_stream_1); - cudaStreamDestroy(alt_stream_2); - cudaStreamDestroy(alt_stream_3); - cudaEventDestroy(alt_stream_1_done); - cudaEventDestroy(alt_stream_2_done); - cudaEventDestroy(alt_stream_3_done); -} - -CudaBuffers* get_buffers(const int device_index) -{ - return g_buffers[device_index]; -} - -void prepare_buffers_cuda -( - int _device, - int _temp_state_size, - half* _temp_state, - half* _temp_dq -) -{ - CudaBuffers* buffers = new CudaBuffers - ( - _device, - _temp_state_size, - _temp_state, - _temp_dq - ); - - g_buffers[_device] = buffers; -} - -void cleanup_buffers_cuda() -{ - for (int i = 0; i < CUDA_MAX_DEVICES; i++) - { - if (!g_buffers[i]) continue; - delete g_buffers[i]; - g_buffers[i] = NULL; - } -} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh b/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh deleted file mode 100644 index 0bf2057c6..000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/cuda_buffers.cuh +++ /dev/null @@ -1,55 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _cuda_buffers_cuh -#define _cuda_buffers_cuh - -#include -#include -#include -#include - -const int CUDA_MAX_DEVICES = 16; - -// #ifndef _cuda_buffers_cu -// extern __constant__ half2 q4_table[16][256]; -// #endif - -class CudaBuffers -{ -public: - int device; - - half* temp_state; // [max_hidden_rows * intermediate_size] - int temp_state_size; - half* temp_dq; // size of largest quant tensor * 8 - - cudaStream_t alt_stream_1; - cudaStream_t alt_stream_2; - cudaStream_t alt_stream_3; - cudaEvent_t alt_stream_1_done; - cudaEvent_t alt_stream_2_done; - cudaEvent_t alt_stream_3_done; - - CudaBuffers - ( - int _device, - int _temp_state_size, - half* _temp_state, - half* _temp_dq - ); - ~CudaBuffers(); -}; - -CudaBuffers* get_buffers(const int device_index); - -void prepare_buffers_cuda -( - int _device, - int _temp_state_size, - half* _temp_state, - half* _temp_dq -); - -void cleanup_buffers_cuda(); - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh b/colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh deleted file mode 100644 index 5cd2e8553..000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/hip_compat.cuh +++ /dev/null @@ -1,49 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _hip_compat_cuh -#define _hip_compat_cuh - -// Workaround for a bug in hipamd, backported from upstream. -__device__ __forceinline__ __half __compat_hrcp(__half x) { - return __half_raw{ - static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))}; -} - -__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) { - return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)), - static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))}; -} - -#define hrcp __compat_hrcp -#define h2rcp __compat_h2rcp - -// Workaround for hipify_python using rocblas instead of hipblas. -__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle, - hipblasOperation_t transA, - hipblasOperation_t transB, - int m, - int n, - int k, - const half* alpha, - const half* AP, - int lda, - const half* BP, - int ldb, - const half* beta, - half* CP, - int ldc) { - return hipblasHgemm(handle, transA, transB, m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(AP), lda, - reinterpret_cast(BP), ldb, - reinterpret_cast(beta), - reinterpret_cast(CP), ldc); -} - -#define rocblas_handle hipblasHandle_t -#define rocblas_operation_none HIPBLAS_OP_N -#define rocblas_get_stream hipblasGetStream -#define rocblas_set_stream hipblasSetStream -#define rocblas_hgemm __compat_hipblasHgemm - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp b/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp deleted file mode 100644 index bcc0e4390..000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/linear_gptq.cpp +++ /dev/null @@ -1,254 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#include -#include -#include -#include -#include -#include -#include -#include "util.cuh" -#include "tuning.h" -#include "cuda_buffers.cuh" -#include "q4_matrix.cuh" -#include "q4_matmul.cuh" -#include "column_remap.cuh" - -// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a -// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of -// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console. - -void check_cuda(cudaError_t ret) -{ - switch (ret) - { - case cudaSuccess: - break; - - case cudaUnspecified: - printf(" **** Unspecified error\n"); - TORCH_CHECK(false, "CUDA error"); - break; - - default: - printf(" **** CUDA error\n"); \ - printf(" **** %s\n", cudaGetErrorString(ret)); \ - TORCH_CHECK(false, "CUDA error"); \ - break; - } -} - -// Some decluttering macros - -#define STRINGIFY_(__x) #__x -#define STRINGIFY(__x) STRINGIFY_(__x) -#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) -#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) -#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") -#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") -#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod)) -#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small") - -#define TORCH_CHECK_DEVICE_INDEX(__index) \ -do { \ - TORCH_CHECK(__index >= 0, "no device index"); \ - TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \ -} while(0) - -#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \ -do { \ - TORCH_CHECK_DTYPE(__w, kInt); \ - TORCH_CHECK_DTYPE(__w_scales, kHalf); \ - TORCH_CHECK_DTYPE(__w_zeros, kInt); \ - TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \ - TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \ - TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \ - TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \ -} while(0) - -int get_groupsize(torch::Tensor w, torch::Tensor w_zeros) -{ - int groupsize = w.size(0) * 8 / w_zeros.size(0); - TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]") - return groupsize; -} - - -// Tuning parameters - -ExLlamaTuning tuningParams; - -void set_tuning_params -( - int matmul_recons_thd, - bool matmul_fused_remap, - bool matmul_no_half2 -) -{ - tuningParams.matmul_recons_thd = matmul_recons_thd; - tuningParams.matmul_fused_remap = matmul_fused_remap; - tuningParams.matmul_no_half2 = matmul_no_half2; -} - - -// Release all unmanaged objects allocated by the extension - -void cleanup() -{ - cleanup_buffers_cuda(); - g_q4_free_matrices(); -} - - -// Prepare buffers for forward pass - -void prepare_buffers -( - torch::Device device, - torch::Tensor temp_state, - torch::Tensor temp_dq -) -{ - int device_index = device.index(); - TORCH_CHECK_DEVICE_INDEX(device_index); - const at::cuda::OptionalCUDAGuard device_guard(device); - - prepare_buffers_cuda - ( - device_index, - // buffer size used for sanity checks - temp_state.numel(), - (half*) temp_state.data_ptr(), - (half*) temp_dq.data_ptr() - ); -} - - -// Create Q4Matrix, return handle - -uintptr_t make_q4 -( - torch::Tensor qweight, - torch::Tensor qzeros, - torch::Tensor scales, - torch::Tensor g_idx, - int device -) -{ - TORCH_CHECK_DTYPE(qweight, kInt); - TORCH_CHECK_DTYPE(qzeros, kInt); - TORCH_CHECK_DTYPE(scales, kHalf); - TORCH_CHECK_DTYPE_OPT(g_idx, kInt); - TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8); - TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1); - TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1); - - int width = qweight.size(1); - int height = qweight.size(0) * 8; - int groups = qzeros.size(0); - - Q4Matrix* m = new Q4Matrix - ( - height, - width, - groups, - - (uint32_t*) qweight.data_ptr(), - (uint32_t*) qzeros.data_ptr(), - (half*) scales.data_ptr(), - g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(), - - device - ); - - g_q4_keep_matrix(m); - return reinterpret_cast (m); -} - - -// Matmul half @ quant -> half - -void q4_matmul -( - torch::Tensor x, - uintptr_t w, - torch::Tensor out -) -{ - Q4Matrix* wm = reinterpret_cast (w); - - TORCH_CHECK_DTYPE(x, kHalf); - TORCH_CHECK_DTYPE(out, kHalf); - TORCH_CHECK_SHAPES(x, 0, out, 0, 1); - TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes") - - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - - int x_height = x.size(0); - - if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd) - { - q4_matmul_cuda - ( - &tuningParams, - (half*) x.data_ptr(), - x_height, - wm, - (half*) out.data_ptr() - ); - } - else - { - q4_matmul_recons_cuda - ( - &tuningParams, - (half*) x.data_ptr(), - x_height, - wm, - (half*) out.data_ptr(), - at::cuda::getCurrentCUDABlasHandle() - ); - } -} - - -// Remap columns in half tensor - -void column_remap -( - torch::Tensor x, - torch::Tensor x_new, - torch::Tensor x_map -) -{ - TORCH_CHECK_DTYPE(x, kHalf); - TORCH_CHECK_DTYPE(x_new, kHalf); - TORCH_CHECK_DTYPE(x_map, kInt); - TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1); - - int height = x.size(0); - int width = x.size(1); - - TORCH_CHECK_BUFFER_SIZE(x_new, height * width); - - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - - column_remap_cuda - ( - (half*) x.data_ptr(), - (half*) x_new.data_ptr(), - height, - width, - (uint32_t*) x_map.data_ptr() - ); -} - - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - m.def("set_tuning_params", &set_tuning_params, "set_tuning_params"); - m.def("prepare_buffers", &prepare_buffers, "prepare_buffers"); - m.def("cleanup", &cleanup, "cleanup"); - m.def("make_q4", &make_q4, "make_q4"); - m.def("q4_matmul", &q4_matmul, "q4_matmul"); -} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh b/colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh deleted file mode 100644 index 2fd5ab0b3..000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/matrix.cuh +++ /dev/null @@ -1,294 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _matrix_cuh -#define _matrix_cuh - -#include -#include - -class MatrixView_half -{ -public: - const half* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } - __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } - __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); } - __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; } -}; - -class MatrixView_half_rw -{ -public: - half* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } - __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } - __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; } - __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; } - __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; } -}; - -class MatrixView_q4_row -{ -public: - const uint32_t* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ int item(int row, int column) const - { - int shift = (column & 0x07) * 4; - return (data[row * width / 8 + column / 8] >> shift) & 0x0f; - } -}; - -class MatrixView_q4_column -{ -public: - const uint32_t* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ int item(int row, int column) const - { - int shift = (row & 0x07) * 4; - return (data[row / 8 * width + column] >> shift) & 0x0f; - } - - __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; } - __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; } -}; - -// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu - -// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale - -__device__ __forceinline__ half2 dot_product_8 -( - const half2 acc, - MatrixView_half& h_, - const int h_row, - const int h_column, // divisible by 8 - MatrixView_q4_column& v_, - const int v_row, // divisible by 8 - const int v_column, - const half2 v_scale_2, - const uint32_t v_zero, // + 1 (!!) - const int count -) -{ - const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column); - const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); - half2 result = acc; - - for (int i = 0; i < count; i++) - { - uint32_t v_read = *v_ptr; v_ptr += v_.width; - - half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); - half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); - half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); - half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); - half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); - half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); - half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); - half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); - - half2 v_01 = __halves2half2(v_0, v_1); - half2 v_23 = __halves2half2(v_2, v_3); - half2 v_45 = __halves2half2(v_4, v_5); - half2 v_67 = __halves2half2(v_6, v_7); - -// half2 v_01 = q4_table[v_zero - 1][(v_read ) & 0xff]; // (constant memory is too slow apparently) -// half2 v_23 = q4_table[v_zero - 1][(v_read >> 8) & 0xff]; -// half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff]; -// half2 v_67 = q4_table[v_zero - 1][(v_read >> 24) ]; - - half2 tmp = __hmul2(*h_ptr++, v_01); - tmp = __hfma2(*h_ptr++, v_23, tmp); - tmp = __hfma2(*h_ptr++, v_45, tmp); - tmp = __hfma2(*h_ptr++, v_67, tmp); - result = __hfma2(v_scale_2, tmp, result); - } - - return result; -} - -__device__ __forceinline__ half dot_product_8_h -( - const half acc, - MatrixView_half& h_, - const int h_row, - const int h_column, // divisible by 8 - MatrixView_q4_column& v_, - const int v_row, // divisible by 8 - const int v_column, - const half v_scale, - const uint32_t v_zero, // + 1 (!!) - const int count -) -{ - const half* h_ptr = h_.item_ptr(h_row, h_column); - const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); - half result = acc; - - for (int i = 0; i < count; i++) - { - uint32_t v_read = *v_ptr; v_ptr += v_.width; - - half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); - half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); - half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); - half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); - half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); - half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); - half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); - half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); - - half tmp = __hmul(*h_ptr++, v_0); - tmp = __hfma(*h_ptr++, v_1, tmp); - tmp = __hfma(*h_ptr++, v_2, tmp); - tmp = __hfma(*h_ptr++, v_3, tmp); - tmp = __hfma(*h_ptr++, v_4, tmp); - tmp = __hfma(*h_ptr++, v_5, tmp); - tmp = __hfma(*h_ptr++, v_6, tmp); - tmp = __hfma(*h_ptr++, v_7, tmp); - result = __hfma(v_scale, tmp, result); - } - - return result; -} - -// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map - -__device__ __forceinline__ half2 dot_product_8_x_map -( - const half2 acc, - MatrixView_half& h_, - const int h_row, - const int h_column, // divisible by 8 - MatrixView_q4_column& v_, - const int v_row, // divisible by 8 - const int v_column, - const half2 v_scale_2, - const uint32_t v_zero, // + 1 (!!) - const int count, - const uint32_t* x_map -) -{ - const half* h_ptr = h_.item_ptr(h_row, 0); - const uint32_t* x_map_ptr = x_map + h_column; - const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); - half2 result = acc; - - for (int i = 0; i < count; i++) - { - uint32_t v_read = *v_ptr; v_ptr += v_.width; - - half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); - half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); - half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); - half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); - half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); - half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); - half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); - half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); - - half2 v_01 = __halves2half2(v_0, v_1); - half2 v_23 = __halves2half2(v_2, v_3); - half2 v_45 = __halves2half2(v_4, v_5); - half2 v_67 = __halves2half2(v_6, v_7); - - half h_0 = h_ptr[*x_map_ptr++]; - half h_1 = h_ptr[*x_map_ptr++]; - half h_2 = h_ptr[*x_map_ptr++]; - half h_3 = h_ptr[*x_map_ptr++]; - half h_4 = h_ptr[*x_map_ptr++]; - half h_5 = h_ptr[*x_map_ptr++]; - half h_6 = h_ptr[*x_map_ptr++]; - half h_7 = h_ptr[*x_map_ptr++]; - - half2 h_01 = __halves2half2(h_0, h_1); - half2 h_23 = __halves2half2(h_2, h_3); - half2 h_45 = __halves2half2(h_4, h_5); - half2 h_67 = __halves2half2(h_6, h_7); - - half2 tmp = __hmul2(h_01, v_01); - tmp = __hfma2(h_23, v_23, tmp); - tmp = __hfma2(h_45, v_45, tmp); - tmp = __hfma2(h_67, v_67, tmp); - result = __hfma2(v_scale_2, tmp, result); - } - - return result; -} - -__device__ __forceinline__ half dot_product_8_x_map_h -( - const half acc, - MatrixView_half& h_, - const int h_row, - const int h_column, // divisible by 8 - MatrixView_q4_column& v_, - const int v_row, // divisible by 8 - const int v_column, - const half v_scale, - const uint32_t v_zero, // + 1 (!!) - const int count, - const uint32_t* x_map -) -{ - const half* h_ptr = h_.item_ptr(h_row, 0); - const uint32_t* x_map_ptr = x_map + h_column; - const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); - half result = acc; - - for (int i = 0; i < count; i++) - { - uint32_t v_read = *v_ptr; v_ptr += v_.width; - - half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); - half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); - half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); - half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); - half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); - half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); - half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); - half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); - - half tmp = __hmul(h_ptr[*x_map_ptr++], v_0); - tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp); - result = __hfma(v_scale, tmp, result); - } - - return result; -} - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu deleted file mode 100644 index f47daeb0e..000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cu +++ /dev/null @@ -1,260 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#include "q4_matmul.cuh" -#include "column_remap.cuh" -#include "util.cuh" -#include "matrix.cuh" -#include "cu_compat.cuh" -#include "cuda_buffers.cuh" -#if defined(USE_ROCM) -#include "hip_compat.cuh" -#endif - -const int THREADS_X = 32; // Block size and thread count along columns in w and out -const int THREADS_Y = 1; // Block size and thread count along rows in x and out - -typedef void (*fp_q4_matmul_kernel) -( - const half*, - const uint32_t*, - half*, - const half*, - const uint32_t*, - const int, - const int, - const int, - const int, - const int, - const uint32_t*, - bool -); - -template -__global__ void q4_matmul_kernel -( - const half* __restrict__ x, - const uint32_t* __restrict__ w, - half* __restrict__ out, - const half* __restrict__ w_scales, - const uint32_t* __restrict__ w_zeros, - const int height, - const int dim, - const int width, - const int groupsize, - const int block_size_z, - const uint32_t* __restrict__ x_map, - bool no_zero -) -{ - // Start of block - - int x_column = block_size_z * blockIdx.z; - int x_column_end = min(dim, block_size_z * (blockIdx.z + 1)); - - int w_column = THREADS_X * blockIdx.x + threadIdx.x; - int x_row = THREADS_Y * blockIdx.y + threadIdx.y; - - int iterations = (x_column_end - x_column) / 8; - - // Views - - MatrixView_half x_(x, height, dim); - MatrixView_half w_scales_(w_scales, dim / groupsize, width); - MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width); - MatrixView_q4_column w_(w, dim, width); - MatrixView_half_rw out_(out, height, width); - - // Zero output - - if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0) - { - *((uint32_t*) out_.item_ptr(x_row, w_column)) = 0; - __syncthreads(); - } - - // Loop over part of x row (and w column) - - half2 acc = {}; - half acc_h = {}; - - if constexpr (use_groupsize) - { - // For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group boundary, so this - // could be slightly faster - - for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize) - { - if constexpr (use_half2) - { - half2 w_scale = w_scales_.item_half2half2(group, w_column); - uint32_t w_zero = w_zeros_.item(group, w_column) + 1; - - if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); - else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); - } - else - { - half w_scale = w_scales_.item(group, w_column); - uint32_t w_zero = w_zeros_.item(group, w_column) + 1; - - if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); - else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); - } - } - } - else - { - // Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache - - for (int k = x_column; k < x_column + iterations * 8; k += 8) - { - if constexpr (use_half2) - { - int group = k / groupsize; - half2 w_scale = w_scales_.item_half2half2(group, w_column); - uint32_t w_zero = w_zeros_.item(group, w_column) + 1; - - if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); - else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); - } - else - { - int group = k / groupsize; - half w_scale = w_scales_.item(group, w_column); - uint32_t w_zero = w_zeros_.item(group, w_column) + 1; - - if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); - else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); - } - } - } - - // Add to block result - - if constexpr (use_half2) - { - half result = __hadd(__low2half(acc), __high2half(acc)); - atomicAdd(out_.item_ptr(x_row, w_column), result); - } - else - { - atomicAdd(out_.item_ptr(x_row, w_column), acc_h); - } -} - -fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map) -{ - // - if (tuningParams->matmul_no_half2) { - if (block_size_z % groupsize == 0) { - if (x_map) return q4_matmul_kernel; - else return q4_matmul_kernel; - } else { - if (x_map) return q4_matmul_kernel; - else return q4_matmul_kernel; - } - } else { - if (block_size_z % groupsize == 0) - { - if (x_map) return q4_matmul_kernel; - else return q4_matmul_kernel; - } else { - if (x_map) return q4_matmul_kernel; - else return q4_matmul_kernel; - } - } -}; - -// Compute y = x @ w - -void q4_matmul_cuda -( - ExLlamaTuning* tuningParams, - const half* x, - const int x_height, - const Q4Matrix* w, - half* out, - bool no_zero, - cudaStream_t alt_stream -) -{ - int height = x_height; - int dim = w->height; - int width = w->width; - - cudaSetDevice(w->device); - - uint32_t* x_map = w->cuda_x_map; - const half* x_mapped = x; - if (x_map && !tuningParams->matmul_fused_remap && !alt_stream) - { - CudaBuffers* buffers = get_buffers(w->device); - column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map); - x_mapped = buffers->temp_state; - x_map = NULL; - } - - int block_size_z; - if (w->width == 4096) block_size_z = 384; // 7B - else if (w->width == 11008) block_size_z = 256; - else if (w->width == 5120) block_size_z = 384; // 13B - else if (w->width == 13824) block_size_z = 256; - else if (w->width == 6656) block_size_z = 256; // 33B - else if (w->width == 17920) block_size_z = 128; - else block_size_z = 256; - - //if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half)); - - dim3 threads(THREADS_X, THREADS_Y, 1); - - dim3 blocks - ( - (width + threads.x - 1) / threads.x, - (height + threads.y - 1) / threads.y, - (dim + block_size_z - 1) / block_size_z - ); - - fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map); - - kernel<<>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero); -} - -void q4_matmul_recons_cuda -( - ExLlamaTuning* tuningParams, - const half* x, - const int x_height, - Q4Matrix* w, - half* out, - const cublasHandle_t handle, - bool no_zero -) -{ - int height = x_height; - int dim = w->height; - int width = w->width; - - cudaSetDevice(w->device); - CudaBuffers* buffers = get_buffers(w->device); - - const half* x_mapped = x; - if (w->cuda_x_map) - { - TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "temp_state buffer is too small"); - column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map); - x_mapped = buffers->temp_state; - } - - w->reconstruct(buffers->temp_dq); - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700 - const float alpha = 1.0f; - const float beta = no_zero ? 1.0f : 0.0f; - cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width, - x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width); -#else - const half alpha = __float2half(1.0f); - const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f); - cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width); -#endif -} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh b/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh deleted file mode 100644 index 09f3e1a63..000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/q4_matmul.cuh +++ /dev/null @@ -1,43 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _q4_matmul_cuh -#define _q4_matmul_cuh - -#include -#include -#include -#include -#include - -#include "q4_matrix.cuh" -#include "tuning.h" - -// Workaround for hipify_python using rocblas instead of hipblas. -#if defined(USE_ROCM) -#include -#define rocblas_handle hipblasHandle_t -#endif - -void q4_matmul_cuda -( - ExLlamaTuning* tuningParams, - const half* x, - const int x_height, - const Q4Matrix* w, - half* out, - bool no_zero = false, - cudaStream_t alt_stream = NULL -); - -void q4_matmul_recons_cuda -( - ExLlamaTuning* tuningParams, - const half* x, - const int x_height, - Q4Matrix* w, - half* out, - const cublasHandle_t handle, - bool no_zero = false -); - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu deleted file mode 100644 index 9c61143f5..000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cu +++ /dev/null @@ -1,225 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#include "q4_matrix.cuh" -#include -#include "util.cuh" -#include "matrix.cuh" - -using namespace std; - -const int UNSHUF_BLOCKSIZE_X = 64; - -const int RECONS_THREADS_X = 64; // Block size and thread count along columns in out, each thread converts 1 column -const int RECONS_THREADS_Y = 1; // Block size and thread count along rows in x and out, each thread converts 8 rows - -vector g_q4_matrices; - -void g_q4_keep_matrix(Q4Matrix* m) -{ - g_q4_matrices.push_back(m); -} - -void g_q4_free_matrices() -{ - for (const auto& m : g_q4_matrices) delete m; - g_q4_matrices.clear(); -} - -Q4Matrix::Q4Matrix -( - const int _height, - const int _width, - const int _groups, - - uint32_t* _qweight, - uint32_t* _qzeros, - half* _scales, - uint32_t* _g_idx, - - const int _device -) : - height(_height), - width(_width), - groups(_groups), - device(_device) -{ - cudaSetDevice(device); - - cuda_qweight = _qweight; - cuda_qzeros = _qzeros; - cuda_scales = _scales; - - groupsize = height / groups; - - if (_g_idx) make_sequential(_g_idx); -} - -Q4Matrix::~Q4Matrix() -{ -} - -// Make sequential - -__global__ void make_sequential_kernel -( - const uint32_t* __restrict__ w, - uint32_t* __restrict__ w_new, - const uint32_t* __restrict__ x_map, - const int w_height, - const int w_width -) -{ - const uint64_t* w2 = (uint64_t*) w; - uint64_t* w_new2 = (uint64_t*) w_new; - int w2_stride = w_width >> 1; - - int w2_column = UNSHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x; - if (w2_column >= w2_stride) return; - - int w_new2_row = blockIdx.y; - - int x_map_idx = w_new2_row << 3; - - uint64_t dst = 0; - - #pragma unroll - for (int i = 0; i < 8; i++) - { - int source_row = x_map[x_map_idx++]; - - int w2_row = source_row >> 3; - int w2_subrow = source_row & 0x07; - int w2_row_shift = w2_subrow << 2; - int wnew2_row_shift = i << 2; - - uint64_t src = w2[w2_row * w2_stride + w2_column]; - src >>= w2_row_shift; - src &= 0x0000000f0000000f; - src <<= wnew2_row_shift; - dst |= src; - } - - w_new2[w_new2_row * w2_stride + w2_column] = dst; -} - -void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx) -{ - uint32_t* cuda_new_qweight = NULL; - cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t)); - cudaMalloc(&cuda_x_map, height * sizeof(uint32_t)); // TODO: Should probably be allocated in PyTorch - - uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t)); - uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t)); - uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t)); - - // Group histogram - - for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++; - - // Group map - - for (int i = 0, acc = 0; i < groups; i++) - { - short tmp = cpu_g_idx_map[i]; - cpu_g_idx_map[i] = acc; - acc += tmp; - } - - // X map (inverse) - - for (int row = 0; row < height; row++) - { - uint32_t target_group = cpu_g_idx[row]; - uint32_t target_row = cpu_g_idx_map[target_group]; - cpu_g_idx_map[target_group]++; - cpu_x_map_inv[row] = target_row; - } - - // X map - - for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row; - - // Move to CUDA - - cudaMemcpyAsync(cuda_x_map, cpu_x_map, height * sizeof(uint32_t), cudaMemcpyHostToDevice); - - // Rearrange rows in w - - dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1); - dim3 blocks - ( - (width + UNSHUF_BLOCKSIZE_X * 2 - 1) / (UNSHUF_BLOCKSIZE_X * 2), - height / 8, - 1 - ); - - make_sequential_kernel<<>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width); - - // Replace qweights - - cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice); - - // Cleanup - - cudaDeviceSynchronize(); - cudaFree(cuda_new_qweight); - free(cpu_g_idx_map); - free(cpu_x_map); - free(cpu_x_map_inv); -} - -__global__ void reconstruct_kernel -( - const uint32_t* __restrict__ w, - half* __restrict__ out, // (y) - const half* __restrict__ w_scales, - const uint32_t* __restrict__ w_zeros, - const int height, - const int width, - const int groupsize -) -{ - // Start of block - - int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x; - int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8; - if (column >= width) return; - - // Views - - MatrixView_q4_column w_(w, height, width); - MatrixView_half_rw out_(out, height, width); - MatrixView_half w_scales_(w_scales, height / groupsize, width); - MatrixView_q4_row w_zeros_(w_zeros, height / groupsize, width); - - // Groupsize version - - int group = row / groupsize; - - half w_scale = w_scales_.item(group, column); - uint32_t w_zero = w_zeros_.item(group, column) + 1; - - uint32_t w_read = w_.item_uint32_t(row, column); - half* out_ptr = out_.item_ptr(row, column); - - #pragma unroll - for (int s = 0; s < 32; s += 4) - { - half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale); - *out_ptr = w_item; out_ptr += out_.width; - } -} - -void Q4Matrix::reconstruct(half* out) -{ - dim3 threads(RECONS_THREADS_X, RECONS_THREADS_Y, 1); - - dim3 blocks - ( - (width + threads.x - 1) / threads.x, - (height / 8 + threads.y - 1) / threads.y, - 1 - ); - - reconstruct_kernel<<>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize); -} diff --git a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh b/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh deleted file mode 100644 index 50cb72a41..000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/q4_matrix.cuh +++ /dev/null @@ -1,53 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _q4_matrix_cuh -#define _q4_matrix_cuh - -#include -#include -#include - -class Q4Matrix -{ -public: - - int device; - - int height; - int width; - int groups; - int groupsize; - - uint32_t* cuda_qweight = NULL; - uint32_t* cuda_qzeros = NULL; - half* cuda_scales = NULL; - uint32_t* cuda_x_map = NULL; - - Q4Matrix - ( - const int _height, - const int _width, - const int _groups, - - uint32_t* _qweight, - uint32_t* _qzeros, - half* _scales, - uint32_t* _g_idx, - - const int _device - ); - - ~Q4Matrix(); - - void reconstruct(half* out); - -private: - - void make_sequential(const uint32_t* cpu_g_idx); - -}; - -void g_q4_keep_matrix(Q4Matrix* m); -void g_q4_free_matrices(); - -#endif \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/gptq/tuning.h b/colossalai/kernel/cuda_native/csrc/gptq/tuning.h deleted file mode 100644 index e413b8a96..000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/tuning.h +++ /dev/null @@ -1,12 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _tuning_h -#define _tuning_h - -struct ExLlamaTuning { - int matmul_recons_thd; - bool matmul_fused_remap; - bool matmul_no_half2; -}; - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/gptq/util.cuh b/colossalai/kernel/cuda_native/csrc/gptq/util.cuh deleted file mode 100644 index 7b3975732..000000000 --- a/colossalai/kernel/cuda_native/csrc/gptq/util.cuh +++ /dev/null @@ -1,33 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _util_cuh -#define _util_cuh - -#include -#include -#include -#include - -#if defined(USE_ROCM) -#define cudaUnspecified hipErrorUnknown -#else -#define cudaUnspecified cudaErrorApiFailureBase -#endif - -// React to failure on return code != cudaSuccess - -#define _cuda_check(fn) \ -do { \ - {_cuda_err = fn;} \ - if (_cuda_err != cudaSuccess) goto _cuda_fail; \ -} while(false) - -// React to failure on return code == 0 - -#define _alloc_check(fn) \ -do { \ - if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \ - else _cuda_err = cudaSuccess; \ -} while(false) - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu b/colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu deleted file mode 100644 index 58d26235a..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/cross_entropy.cu +++ /dev/null @@ -1,191 +0,0 @@ -#include "block_reduce.h" -#include "cuda_util.h" -#include "kernels.h" -#include "ls_cub.cuh" - -ls::cub::CachingDeviceAllocator g_allocator(true); - -template -__global__ void ls_cross_entropy_fw_kernel( - const T *__restrict__ inputs, const int *__restrict__ targets, - float *__restrict__ outputs, float *__restrict__ nll_loss_outputs, - const int padding_idx, const float epsilon, const int vocab_size) { - /* step1: compute each thread's max_logit and sum_exp_logit, store in - * max_input, sum_exp_logit */ - const int block_start = blockIdx.x * vocab_size; - const int left_idx = block_start + threadIdx.x; - const int right_idx = (blockIdx.x + 1) * vocab_size; - float max_input[1] = {REDUCE_FLOAT_INF_NEG}; - float sum_logits[2] = {0.f, 0.f}; // logit and logit exp - int target_tid = targets[blockIdx.x]; - - if (target_tid == padding_idx) { - if (threadIdx.x == 0) { - nll_loss_outputs[blockIdx.x] = 0.f; - outputs[blockIdx.x] = 0.f; - } - return; - } - - for (int i = left_idx; i < right_idx; i += blockDim.x) { - max_input[0] = fmaxf(max_input[0], static_cast(inputs[i])); - } - blockReduce(max_input); - __shared__ float s_max_input; - if (threadIdx.x == 0) { - s_max_input = max_input[0]; - } - __syncthreads(); - - for (int i = left_idx; i < right_idx; i += blockDim.x) { - float logit = static_cast(inputs[i]) - s_max_input; - sum_logits[0] += logit; - sum_logits[1] += expf(logit); - } - - blockReduce(sum_logits); - __shared__ float s_sum_logit; - __shared__ float s_sum_exp; - if (threadIdx.x == 0) { - s_sum_logit = sum_logits[0]; - s_sum_exp = sum_logits[1]; - } - __syncthreads(); - - float eps_i = epsilon / (vocab_size - 1); - if (threadIdx.x == 0) { - // neg_log_prob = log(sum(exp(x - x_max))) - (x - x_max) - float nll_loss = logf(s_sum_exp) - - static_cast(inputs[block_start + target_tid]) + - s_max_input; - nll_loss_outputs[blockIdx.x] = nll_loss; - float sum_nll_loss = vocab_size * logf(s_sum_exp) - s_sum_logit; - outputs[blockIdx.x] = - (1.f - epsilon - eps_i) * nll_loss + eps_i * sum_nll_loss; - } -} - -template -__global__ void ls_cross_entropy_bw_kernel( - const float *__restrict__ grad_outputs, const T *__restrict__ inputs, - const int *__restrict__ targets, T *__restrict__ grad_inputs, - const int padding_idx, const float epsilon, const int vocab_size) { - /* step1: compute each thread's max_logit and sum_exp_logit, store in - * max_input, sum_exp_logit */ - const int block_start = blockIdx.x * vocab_size; - const int left_idx = block_start + threadIdx.x; - const int right_idx = (blockIdx.x + 1) * vocab_size; - float max_input[1] = {REDUCE_FLOAT_INF_NEG}; - float sum_logits[1] = {0.f}; - const float grad_out = static_cast(grad_outputs[0]); - int target_tid = targets[blockIdx.x]; - - if (target_tid == padding_idx) { - for (int i = left_idx; i < right_idx; i += blockDim.x) { - grad_inputs[i] = 0.f; - } - return; - } - - for (int i = left_idx; i < right_idx; i += blockDim.x) { - max_input[0] = fmaxf(max_input[0], static_cast(inputs[i])); - } - blockReduce(max_input); - __shared__ float s_max_input; - if (threadIdx.x == 0) { - s_max_input = max_input[0]; - } - __syncthreads(); - - for (int i = left_idx; i < right_idx; i += blockDim.x) { - float logit = static_cast(inputs[i]) - s_max_input; - sum_logits[0] += expf(logit); - } - - blockReduce(sum_logits); - __shared__ float s_sum_exp; - if (threadIdx.x == 0) { - s_sum_exp = sum_logits[0]; - } - __syncthreads(); - - float eps_i = epsilon / (vocab_size - 1); - float nll_weight = 1.0 - epsilon - eps_i; - - for (int i = left_idx; i < right_idx; i += blockDim.x) { - float prob = expf(static_cast(inputs[i]) - s_max_input) / s_sum_exp; - float grad = 0; - grad += (vocab_size * prob - 1) * eps_i; - grad += prob * nll_weight; - if ((i - block_start) == target_tid) { - grad -= nll_weight; - } - grad_inputs[i] = grad_out * grad; - } -} - -template -void launch_cross_entropy_fw(const T *inputs_ptr, const int *targets_ptr, - float *outputs_ptr, float *nll_loss_ptr, - float *loss_buffer, const int padding_idx, - const float epsilon, const int batch_size, - const int seq_len, const int vocab_size, - cudaStream_t stream) { - int grid_dim = batch_size * seq_len; - float *nll_loss_buffer = loss_buffer + grid_dim; - ls_cross_entropy_fw_kernel<<>>( - inputs_ptr, targets_ptr, loss_buffer, nll_loss_buffer, padding_idx, - epsilon, vocab_size); - - int num_items = grid_dim; - void *d_temp_storage = NULL; - size_t temp_storage_bytes = 0; - CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, - loss_buffer, outputs_ptr, - num_items, stream)); - CHECK_GPU_ERROR( - g_allocator.DeviceAllocate(&d_temp_storage, temp_storage_bytes)); - CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, - loss_buffer, outputs_ptr, - num_items, stream)); - CHECK_GPU_ERROR(ls::cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, - nll_loss_buffer, nll_loss_ptr, - num_items, stream)); - CHECK_GPU_ERROR(g_allocator.DeviceFree(d_temp_storage)); -} - -template void launch_cross_entropy_fw( - const float *inputs_ptr, const int *targets_ptr, float *outputs_ptr, - float *nll_loss_ptr, float *loss_buffer, const int padding_idx, - const float epsilon, const int batch_size, const int seq_len, - const int vocab_size, cudaStream_t stream); - -template void launch_cross_entropy_fw<__half>( - const __half *inputs_ptr, const int *targets_ptr, float *outputs_ptr, - float *nll_loss_ptr, float *loss_buffer, const int padding_idx, - const float epsilon, const int batch_size, const int seq_len, - const int vocab_size, cudaStream_t stream); - -template -void launch_cross_entropy_bw(const float *grad_outputs_ptr, const T *inputs_ptr, - const int *targets_ptr, T *grad_inputs_ptr, - const int padding_idx, const float epsilon, - const int batch_size, const int seq_len, - const int vocab_size, cudaStream_t stream) { - int grid_dim = batch_size * seq_len; - ls_cross_entropy_bw_kernel<<>>( - grad_outputs_ptr, inputs_ptr, targets_ptr, grad_inputs_ptr, padding_idx, - epsilon, vocab_size); -} - -template void launch_cross_entropy_bw( - const float *grad_outputs_ptr, const float *inputs_ptr, - const int *targets_ptr, float *grad_inputs_ptr, const int padding_idx, - const float epsilon, const int batch_size, const int seq_len, - const int vocab_size, cudaStream_t stream); - -template void launch_cross_entropy_bw<__half>( - const float *grad_outputs_ptr, const __half *inputs_ptr, - const int *targets_ptr, __half *grad_inputs_ptr, const int padding_idx, - const float epsilon, const int batch_size, const int seq_len, - const int vocab_size, cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cublas_wrappers.cu b/colossalai/kernel/cuda_native/csrc/kernels/cublas_wrappers.cu deleted file mode 100644 index 09f34763f..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/cublas_wrappers.cu +++ /dev/null @@ -1,88 +0,0 @@ -/* Copyright 2021 The LightSeq Team - Copyright Microsoft DeepSpeed - This file is adapted from Microsoft DeepSpeed - Licensed under the MIT License. -*/ -#include "cublas_wrappers.h" - -int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const float *alpha, const float *beta, const float *A, - const float *B, float *C, cublasGemmAlgo_t algo) { - cublasStatus_t status = - cublasGemmEx(handle, transa, transb, m, n, k, (const void *)alpha, - (const void *)A, CUDA_R_32F, (transa == CUBLAS_OP_N) ? m : k, - (const void *)B, CUDA_R_32F, (transb == CUBLAS_OP_N) ? k : n, - (const void *)beta, C, CUDA_R_32F, m, CUDA_R_32F, algo); - - if (status != CUBLAS_STATUS_SUCCESS) { - fprintf(stderr, - "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", - m, n, k, (int)status); - return EXIT_FAILURE; - } - return 0; -} - -int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const float *alpha, const float *beta, const __half *A, - const __half *B, __half *C, cublasGemmAlgo_t algo) { - cublasStatus_t status = cublasGemmEx( - handle, transa, transb, m, n, k, (const void *)alpha, (const void *)A, - CUDA_R_16F, (transa == CUBLAS_OP_N) ? m : k, (const void *)B, CUDA_R_16F, - (transb == CUBLAS_OP_N) ? k : n, (const void *)beta, (void *)C, - CUDA_R_16F, m, CUDA_R_32F, algo); - - if (status != CUBLAS_STATUS_SUCCESS) { - fprintf(stderr, - "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", - m, n, k, (int)status); - return EXIT_FAILURE; - } - return 0; -} - -int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k, - const float *alpha, const float *beta, - const float *A, const float *B, float *C, - cublasOperation_t op_A, cublasOperation_t op_B, - int stride_A, int stride_B, int stride_C, - int batch, cublasGemmAlgo_t algo) { - cublasStatus_t status = cublasGemmStridedBatchedEx( - handle, op_A, op_B, m, n, k, alpha, A, CUDA_R_32F, - (op_A == CUBLAS_OP_N) ? m : k, stride_A, B, CUDA_R_32F, - (op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, CUDA_R_32F, m, stride_C, - batch, CUDA_R_32F, algo); - - if (status != CUBLAS_STATUS_SUCCESS) { - fprintf(stderr, - "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, " - "error: %d) \n", - batch, m, n, k, (int)status); - return EXIT_FAILURE; - } - return 0; -} - -int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k, - const float *alpha, const float *beta, - const __half *A, const __half *B, __half *C, - cublasOperation_t op_A, cublasOperation_t op_B, - int stride_A, int stride_B, int stride_C, - int batch, cublasGemmAlgo_t algo) { - cublasStatus_t status = cublasGemmStridedBatchedEx( - handle, op_A, op_B, m, n, k, alpha, A, CUDA_R_16F, - (op_A == CUBLAS_OP_N) ? m : k, stride_A, B, CUDA_R_16F, - (op_B == CUBLAS_OP_N) ? k : n, stride_B, beta, C, CUDA_R_16F, m, stride_C, - batch, CUDA_R_32F, algo); - - if (status != CUBLAS_STATUS_SUCCESS) { - fprintf(stderr, - "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", - m, n, k, (int)status); - return EXIT_FAILURE; - } - - return 0; -} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu b/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu deleted file mode 100644 index e5ac17308..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu +++ /dev/null @@ -1,169 +0,0 @@ -#include -#include -#include -#include "cuda_util.h" - -/* GPU function guard */ -std::string _cudaGetErrorString(cudaError_t error) { - return cudaGetErrorString(error); -} - -std::string _cudaGetErrorString(cublasStatus_t error) { - switch (error) { - case CUBLAS_STATUS_SUCCESS: - return "CUBLAS_STATUS_SUCCESS"; - - case CUBLAS_STATUS_NOT_INITIALIZED: - return "CUBLAS_STATUS_NOT_INITIALIZED"; - - case CUBLAS_STATUS_ALLOC_FAILED: - return "CUBLAS_STATUS_ALLOC_FAILED"; - - case CUBLAS_STATUS_INVALID_VALUE: - return "CUBLAS_STATUS_INVALID_VALUE"; - - case CUBLAS_STATUS_ARCH_MISMATCH: - return "CUBLAS_STATUS_ARCH_MISMATCH"; - - case CUBLAS_STATUS_MAPPING_ERROR: - return "CUBLAS_STATUS_MAPPING_ERROR"; - - case CUBLAS_STATUS_EXECUTION_FAILED: - return "CUBLAS_STATUS_EXECUTION_FAILED"; - - case CUBLAS_STATUS_INTERNAL_ERROR: - return "CUBLAS_STATUS_INTERNAL_ERROR"; - - case CUBLAS_STATUS_NOT_SUPPORTED: - return "CUBLAS_STATUS_NOT_SUPPORTED"; - - case CUBLAS_STATUS_LICENSE_ERROR: - return "CUBLAS_STATUS_LICENSE_ERROR"; - } - return "CUBLAS_UNKNOW"; -} - -template -void check_gpu_error(T result, char const *const func, const char *const file, - int const line) { - if (result) { - throw std::runtime_error(std::string("[CUDA][ERROR] ") + +file + "(" + - std::to_string(line) + - "): " + (_cudaGetErrorString(result)) + "\n"); - } -} - -template void check_gpu_error(cudaError_t result, - char const *const func, - const char *const file, - int const line); -template void check_gpu_error(cublasStatus_t result, - char const *const func, - const char *const file, - int const line); - -template -void print_vec(const T *outv, std::string outn, int num_output_ele) { - std::cout << outn << ": "; - std::vector hout(num_output_ele, (T)0); - cudaMemcpy(hout.data(), outv, num_output_ele * sizeof(T), - cudaMemcpyDeviceToHost); - for (int i = 0; i < num_output_ele; i++) { - std::cout << hout[i] << ", "; - } - std::cout << std::endl; -} - -template <> -void print_vec<__half>(const __half *outv, std::string outn, - int num_output_ele) { - std::cout << outn << ": "; - std::vector<__half> hout(num_output_ele, (__half)0.f); - cudaMemcpy(hout.data(), outv, num_output_ele * sizeof(__half), - cudaMemcpyDeviceToHost); - for (int i = 0; i < num_output_ele; i++) { - std::cout << __half2float(hout[i]) << ", "; - } - std::cout << std::endl; -} - -template void print_vec(const float *outv, std::string outn, - int num_output_ele); - -template void print_vec(const int *outv, std::string outn, - int num_output_ele); - -template void print_vec<__half>(const __half *outv, std::string outn, - int num_output_ele); - -template -T *cuda_malloc(size_t ele_num) { - size_t byte_size = ele_num * sizeof(T); - T *pdata = nullptr; - CHECK_GPU_ERROR(cudaMalloc((void **)&pdata, byte_size)); - return pdata; -} - -template float *cuda_malloc(size_t ele_num); - -template __half *cuda_malloc<__half>(size_t ele_num); - -template uint8_t *cuda_malloc(size_t ele_num); - -void cuda_free(void *pdata) { - if (pdata != nullptr) { - cudaFree(pdata); - } -} - -template -struct _isnan { - __device__ bool operator()(T a) const { return isnan(a); } -}; - -template <> -struct _isnan<__half> { - __device__ bool operator()(const __half a) const { return __hisnan(a); } -}; - -template -struct _isinf { - __device__ bool operator()(T a) const { return isinf(a); } -}; - -template <> -struct _isinf<__half> { - __device__ bool operator()(const __half a) const { return __hisinf(a); } -}; - -template -void check_nan_inf(const T *data_ptr, int dsize, bool check_nan_inf, - std::string file, int line, cudaStream_t stream) { - // check_nan_inf = 0 for checking nan - // check_nan_inf = 1 for checking inf - bool res = false; - std::string msg = file + "(" + std::to_string(line) + "): "; - if (check_nan_inf) { - msg += "nan."; - res = thrust::transform_reduce(thrust::cuda::par.on(stream), data_ptr, - data_ptr + dsize, _isnan(), false, - thrust::logical_or()); - } else { - msg += "inf."; - res = thrust::transform_reduce(thrust::cuda::par.on(stream), data_ptr, - data_ptr + dsize, _isinf(), false, - thrust::logical_or()); - } - if (res) { - throw std::runtime_error(msg); - } - std::cout << msg << " [check pass]." << std::endl; -} - -template void check_nan_inf(const float *data_ptr, int dsize, - bool check_nan_inf, std::string file, - int line, cudaStream_t stream); - -template void check_nan_inf<__half>(const __half *data_ptr, int dsize, - bool check_nan_inf, std::string file, - int line, cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu deleted file mode 100644 index ce0b017f1..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu +++ /dev/null @@ -1,1002 +0,0 @@ -#include -#include - -#include "kernels.h" - -#include - - -namespace cg = cooperative_groups; - -curandStatePhilox4_32_10_t *curandstate; - -/** - * @brief element-wise activation function on device, like Relu, Gelu - * - * @tparam enum class ActivationType, kRelu, kGelu - * @tparam input type - * @param any shape of float and __half2 - * @return same shape and type with input - */ -template -__forceinline__ __device__ T activation_kernel(T x); - -template <> -__device__ float activation_kernel(float x) { - float cdf = - 0.5f * - (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); - return x * cdf; -} - -template <> -__device__ __half2 -activation_kernel(__half2 val) { - __half2 val_pow3 = __hmul2(val, __hmul2(val, val)); - float2 tmp_pow = __half22float2(val_pow3); - float2 tmp = __half22float2(val); - - tmp.x = - 0.5f * - (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x)))); - tmp.y = - 0.5f * - (1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y)))); - return __hmul2(val, __float22half2_rn(tmp)); -} - -template <> -__device__ float activation_kernel(float x) { - return fmaxf(x, 0); -} - -template <> -__device__ __half2 -activation_kernel(__half2 x) { - return __floats2half2_rn(fmaxf(0.f, __half2float(x.x)), - fmaxf(0.f, __half2float(x.y))); -} - -/** - * @brief element-wise activation backward function on device - * - * @tparam enum class ActivationType - * @tparam input type - * @param any shape of float and __half2 - * @return same shape of input - */ -template -__forceinline__ __device__ T activation_bwd_kernel(T grad, T x); - -template <> -__device__ float activation_bwd_kernel(float grad, - float x) { - const float sqrt_param = 0.79788456080286535587989211986876f; - const float mul_param = 0.044715; - - float x2mul = x * x * mul_param; - float tan_h = tanhf(sqrt_param * (x + x * x2mul)); - float dg1 = 0.5f * (1.0f + tan_h); - float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); - float dg3 = dg2 * 3 * x2mul; - return grad * (dg1 + dg2 + dg3); -} - -template <> -__device__ __half activation_bwd_kernel( - __half grad, __half x_half) { - float x = __half2float(x_half); - const float sqrt_param = 0.79788456080286535587989211986876f; - const float mul_param = 0.044715; - - float x2mul = x * x * mul_param; - float tan_h = tanhf(sqrt_param * (x + x * x2mul)); - float dg1 = 0.5f * (1.0f + tan_h); - float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); - float dg3 = dg2 * 3 * x2mul; - return grad * __float2half(dg1 + dg2 + dg3); -} - -template <> -__device__ float activation_bwd_kernel(float grad, - float x) { - return x > 0.f ? grad : 0.f; -} - -template <> -__device__ __half -activation_bwd_kernel(__half grad, __half x) { - const __half half_zero = __float2half(0.f); - return x > half_zero ? grad : half_zero; -} - -template <> -__device__ __half2 activation_bwd_kernel( - __half2 grad2, __half2 x_half2) { - const __half half_zero = __float2half(0.f); - return __floats2half2_rn(x_half2.x > half_zero ? grad2.x : half_zero, - x_half2.y > half_zero ? grad2.y : half_zero); -} - -/** - * @brief init curand states in global memory - * - * @thread grid_dim * block*dim to suuport any size of states - * @param state persistant curand states - * @param seed seed to init states - * @return void - */ -__global__ void curand_init_kernel(curandStatePhilox4_32_10_t *state, - int seed) { - /* Each thread gets same seed, a different sequence - number, no offset */ - int id = threadIdx.x + blockIdx.x * blockDim.x; - curand_init(seed, id, 0, &state[id]); -} - -void launch_curand_init(int total_count, int dim, cudaStream_t stream) { - cudaMalloc(&curandstate, total_count * sizeof(curandStatePhilox4_32_10_t)); - int grid_dim = total_count >> 9; - curand_init_kernel<<>>( - curandstate, std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count()); -} - -/** - * @brief element-wise dropout, store dropped position in mask, it's not - * in-place - * - * @thread - * gridDim.x = total_count / 1024 - * blockDim.x = 1024 - * - * @param total_count total elements - * @param ratio drop ratio - * @param out any size of float and __half - * @param in same with out - * @param mask uint8 type, same size with out - * @param seed seed to curand - * @return void - */ -__global__ void ls_dropout_kernel(const int total_count, const float ratio, - float *__restrict__ out, - const float *__restrict__ in, - uint8_t *__restrict__ mask, const int seed) { - const float scale = 1.f / (1.f - ratio); - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 4 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - uint8_t m[4]; - - float4 *out4 = reinterpret_cast(out); - const float4 *data4 = reinterpret_cast(in); - uint32_t *mask4 = reinterpret_cast(mask); - float4 rand = curand_uniform4(&state); - - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - - uint32_t *m4 = reinterpret_cast(m); - mask4[i] = m4[0]; - - float4 input4 = data4[i]; - float4 res4; - res4.x = input4.x * scale * m[0]; - res4.y = input4.y * scale * m[1]; - res4.z = input4.z * scale * m[2]; - res4.w = input4.w * scale * m[3]; - out4[i] = res4; -} - -__global__ void ls_dropout_kernel(const int total_count, const float ratio, - __half *__restrict__ out, - const __half *__restrict__ in, - uint8_t *__restrict__ mask, const int seed) { - const float scale = 1.f / (1.f - ratio); - - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 8 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - - const float4 *vals_float4 = reinterpret_cast(in); - float4 *outs_float4 = reinterpret_cast(out); - uint64_t *mask8 = reinterpret_cast(mask); - - uint8_t m[8]; - float4 rand = curand_uniform4(&state); - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - rand = curand_uniform4(&state); - m[4] = (uint8_t)(rand.x > ratio); - m[5] = (uint8_t)(rand.y > ratio); - m[6] = (uint8_t)(rand.z > ratio); - m[7] = (uint8_t)(rand.w > ratio); - uint64_t *m8 = reinterpret_cast(m); - mask8[i] = *m8; - - float4 val_float4 = vals_float4[i]; - float4 out_float4; - __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); - __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); - __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]); - __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]); - __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]); - __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]); - out_half2[0] = __hmul2(val_half2[0], scale_mask_1); - out_half2[1] = __hmul2(val_half2[1], scale_mask_2); - out_half2[2] = __hmul2(val_half2[2], scale_mask_3); - out_half2[3] = __hmul2(val_half2[3], scale_mask_4); - outs_float4[i] = out_float4; -} - -/** - * @brief element-wise dropout backward with dropout mask, it's - * not in-place - * - * @thread - * gridDim.x = total_count / 1024 - * blockDim.x = 1024 - * - * @param total_count total elements - * @param ratio drop ratio - * @param in any size of float and __half - * @param mask uint8 type, same size with in - * @return void - */ -__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, - float *out, const float *in, - const uint8_t *__restrict__ mask) { - const float scale = 1.f / (1.f - ratio); - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 4 >= total_count) return; - - uint8_t m[4]; - - float4 *out4 = reinterpret_cast(out); - const float4 *in4 = reinterpret_cast(in); - const uint32_t *mask4 = reinterpret_cast(mask); - - uint32_t *m4 = reinterpret_cast(m); - m4[0] = mask4[i]; - - float4 input4 = in4[i]; - float4 res4; - res4.x = input4.x * scale * static_cast(m[0]); - res4.y = input4.y * scale * static_cast(m[1]); - res4.z = input4.z * scale * static_cast(m[2]); - res4.w = input4.w * scale * static_cast(m[3]); - out4[i] = res4; -} - -__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, - __half *out, const __half *in, - const uint8_t *__restrict__ mask) { - const __half scale = 1.f / (1.f - ratio); - - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 8 >= total_count) return; - - float4 *out4 = reinterpret_cast(out); - const float4 *vals_float4 = reinterpret_cast(in); - const uint64_t *mask8 = reinterpret_cast(mask); - - uint8_t m[8]; - uint64_t *m8 = reinterpret_cast(m); - m8[0] = mask8[i]; - - float4 val_float4 = vals_float4[i]; - float4 out_float4; - __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); - __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); - __half2 scale_mask_1 = - __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1])); - __half2 scale_mask_2 = - __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3])); - __half2 scale_mask_3 = - __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5])); - __half2 scale_mask_4 = - __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7])); - out_half2[0] = __hmul2(val_half2[0], scale_mask_1); - out_half2[1] = __hmul2(val_half2[1], scale_mask_2); - out_half2[2] = __hmul2(val_half2[2], scale_mask_3); - out_half2[3] = __hmul2(val_half2[3], scale_mask_4); - out4[i] = out_float4; -} - -template <> -void launch_ls_dropout(float *out, const float *vals, uint8_t *mask, - int total_count, float ratio, cudaStream_t stream, - bool backward) { - int grid_dim = total_count >> 12; - if (!backward) { - ls_dropout_kernel<<>>( - total_count, ratio, out, vals, mask, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count()); - } else { - ls_dropout_bwd_kernel<<>>(total_count, ratio, - out, vals, mask); - } -} - -template <> -void launch_ls_dropout<__half>(__half *out, const __half *vals, uint8_t *mask, - int total_count, float ratio, - cudaStream_t stream, bool backward) { - int grid_dim = total_count >> 13; - if (!backward) { - ls_dropout_kernel<<>>( - total_count, ratio, out, vals, mask, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count()); - } else { - ls_dropout_bwd_kernel<<>>(total_count, ratio, - out, vals, mask); - } -} - -/** - * @brief fused bias, dropout, and residual at the end of Attention and FFN, - * store dropped position in mask, it's not in-place - * - * @thread - * gridDim.x = total_count / 1024 - * blockDim.x = 1024 - * - * @param total_count total elements - * @param ratio drop ratio - * @param out [batch_size, seq_len, hidden_size], float and __half - * @param in [batch_size, seq_len, hidden_size], float and __half - * @param mask [batch_size, seq_len, hidden_size], uint8 type - * @param bias [hidden_size], ffn bias - * @param residual [batch_size, seq_len, hidden_size], float and __half - * @param seed seed to curand - * @param hidden_size hidden size - * @return void - */ -__global__ void ls_dropout_res_bias_kernel( - const int total_count, const float ratio, float *__restrict__ out, - const float *__restrict__ in, uint8_t *__restrict__ mask, - const float *__restrict__ bias, const float *__restrict__ residual, - const int seed, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 4 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - uint8_t m[4]; - - float4 *out4 = reinterpret_cast(out); - const float4 *data4 = reinterpret_cast(in); - const float4 *residual4 = reinterpret_cast(residual); - const float4 *bias4 = reinterpret_cast(bias); - uint32_t *mask4 = reinterpret_cast(mask); - float4 rand = curand_uniform4(&state); - - m[0] = static_cast(rand.x > ratio); - m[1] = static_cast(rand.y > ratio); - m[2] = static_cast(rand.z > ratio); - m[3] = static_cast(rand.w > ratio); - - int bias_i = i % (hidden_size >> 2); - uint32_t *m4 = reinterpret_cast(m); - mask4[i] = m4[0]; - const float4 input4 = data4[i]; - const float4 b4 = __ldg(&bias4[bias_i]); - const float4 res4 = residual4[i]; - float4 output4; - - output4.x = (input4.x + b4.x) * scale * m[0] + res4.x; - output4.y = (input4.y + b4.y) * scale * m[1] + res4.y; - output4.z = (input4.z + b4.z) * scale * m[2] + res4.z; - output4.w = (input4.w + b4.w) * scale * m[3] + res4.w; - - out4[i] = output4; -} - -__global__ void ls_dropout_res_bias_kernel( - const int total_count, const float ratio, __half *__restrict__ out, - const __half *__restrict__ in, uint8_t *__restrict__ mask, - const __half *__restrict__ bias, const __half *__restrict__ residual, - const int seed, const int hidden_size) { - const __half scale = 1. / (1. - ratio); - - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 8 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - - const float4 *vals_float4 = reinterpret_cast(in); - float4 *outs_float4 = reinterpret_cast(out); - const float4 *residual4 = reinterpret_cast(residual); - const float4 *bias4 = reinterpret_cast(bias); - uint64_t *mask8 = reinterpret_cast(mask); - - uint8_t m[8]; - float4 rand = curand_uniform4(&state); - m[0] = static_cast(rand.x > ratio); - m[1] = static_cast(rand.y > ratio); - m[2] = static_cast(rand.z > ratio); - m[3] = static_cast(rand.w > ratio); - rand = curand_uniform4(&state); - m[4] = static_cast(rand.x > ratio); - m[5] = static_cast(rand.y > ratio); - m[6] = static_cast(rand.z > ratio); - m[7] = static_cast(rand.w > ratio); - uint64_t *m8 = reinterpret_cast(m); - mask8[i] = m8[0]; - - int bias_i = i % (hidden_size >> 3); - float4 val_float4 = vals_float4[i]; - const float4 b4 = __ldg(&bias4[bias_i]); - const float4 res4 = residual4[i]; - float4 out_float4; - - __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); - __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); - const __half2 *b_half2 = reinterpret_cast(&b4); - const __half2 *res_half2 = reinterpret_cast(&res4); - __half2 scale_mask_1 = - __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1])); - __half2 scale_mask_2 = - __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3])); - __half2 scale_mask_3 = - __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5])); - __half2 scale_mask_4 = - __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7])); - out_half2[0] = - __hfma2(__hadd2(val_half2[0], b_half2[0]), scale_mask_1, res_half2[0]); - out_half2[1] = - __hfma2(__hadd2(val_half2[1], b_half2[1]), scale_mask_2, res_half2[1]); - out_half2[2] = - __hfma2(__hadd2(val_half2[2], b_half2[2]), scale_mask_3, res_half2[2]); - out_half2[3] = - __hfma2(__hadd2(val_half2[3], b_half2[3]), scale_mask_4, res_half2[3]); - outs_float4[i] = out_float4; -} - -template <> -void launch_ls_dropout_res_bias(float *out, const float *vals, - uint8_t *mask, const float *bias, - const float *residual, int total_count, - int dim, float ratio, - cudaStream_t stream) { - int grid_dim = total_count >> 12; - ls_dropout_res_bias_kernel<<>>( - total_count, ratio, out, vals, mask, bias, residual, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -template <> -void launch_ls_dropout_res_bias<__half>(__half *out, const __half *vals, - uint8_t *mask, const __half *bias, - const __half *residual, int total_count, - int dim, float ratio, - cudaStream_t stream) { - int grid_dim = total_count >> 13; - ls_dropout_res_bias_kernel<<>>( - total_count, ratio, out, vals, mask, bias, residual, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -/** - * @brief fused bias and dropout backward at the end of Attention and FFN - * - * @thread - * gridDim.x = hidden_size / 8 - * blockDim.x = 8 - * blockDim.y = 1024 / 8 = 128 - * - * @param row_size batch_size * seq_len - * @param ratio dropout ratio - * @param in_grad [batch_size, seq_len, hidden_size], input grad - * @param bias_grad [hidden_size], bias grad - * @param out_grad [batch_size, seq_len, hidden_size], output grad - * @param mask [batch_size, seq_len, hidden_size], dropout mask - * @param hidden_size - * @return void - */ -__global__ void ls_dropout_bias_bwd_kernel( - const int row_size, const float ratio, float *__restrict__ in_grad, - float *__restrict__ bias_grad, const float *__restrict__ out_grad, - const uint8_t *__restrict__ mask, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - // every block generate 8 bias result - __shared__ float tile[8][129]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8); - int stride = hidden_size * 128; - float local_sum = 0; - - int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); - for (int r = threadIdx.y; r < row_size; r += 128) { - float val = out_grad[idx]; - val *= scale * static_cast(mask[idx]); - local_sum += val; - in_grad[idx] = val; - idx += stride; - } - - tile[threadIdx.x][threadIdx.y] = local_sum; - __syncthreads(); - - float sum = 0; - int tid = threadIdx.y * blockDim.x + threadIdx.x; - int x = tid >> 7; - int y = tid & (127); - if (y < 32) { -#pragma unroll - for (int i = 0; i < 4; i++) { - sum += tile[x][y + i * 32]; - } - } - __syncthreads(); - - for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i); - - if (y == 0) tile[0][x] = sum; - __syncthreads(); - - if (threadIdx.x < 8) { - int pos = flat_2dim(blockIdx.x, threadIdx.x, 8); - bias_grad[pos] = tile[0][threadIdx.x]; - } -} - -__global__ void ls_dropout_bias_bwd_kernel( - const int row_size, const float ratio, __half *__restrict__ in_grad, - __half *__restrict__ bias_grad, const __half *__restrict__ out_grad, - const uint8_t *__restrict__ mask, const int hidden_size) { - const __half2 scale = __float2half2_rn(1.f / (1.f - ratio)); - __shared__ __half2 tile[8][129]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad); - const __half2 *out_grad2 = reinterpret_cast(out_grad); - __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad); - - int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8); - int stride = hidden_size * 128; - __half2 local_sum = __float2half2_rn(0.f); - - int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); - for (int r = threadIdx.y; r < row_size; r += 128) { - __half2 val = out_grad2[idx]; - __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]); - val *= scale * m2; - local_sum += val; - in_grad2[idx] = val; - idx += stride; - } - - tile[threadIdx.x][threadIdx.y] = local_sum; - __syncthreads(); - - __half2 sum = __float2half2_rn(0.f); - int tid = threadIdx.y * blockDim.x + threadIdx.x; - int x = tid >> 7; - int y = tid & (127); - if (y < 32) { -#pragma unroll - for (int i = 0; i < 4; i++) { - sum += tile[x][y + i * 32]; - } - } - __syncthreads(); - - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); - - if (y == 0) tile[0][x] = sum; - __syncthreads(); - - if (threadIdx.x < 8) { - int pos = flat_2dim(blockIdx.x, threadIdx.x, 8); - bias_grad2[pos] = tile[0][threadIdx.x]; - } -} - -template -void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad, - const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream) { - dim3 grid_dim((dim - 1) / 8 + 1); - dim3 block_dim(8, 128); - ls_dropout_bias_bwd_kernel<<>>( - row_size, ratio, in_grad, bias_grad, out_grad, mask, dim); -} - -template <> -void launch_ls_dropout_bias_bwd(__half *in_grad, __half *bias_grad, - const __half *out_grad, const uint8_t *mask, - int row_size, int dim, float ratio, - cudaStream_t stream) { - dim >>= 1; - dim3 grid_dim((dim - 1) / 8 + 1); - dim3 block_dim(8, 128); - ls_dropout_bias_bwd_kernel<<>>( - row_size, ratio, in_grad, bias_grad, out_grad, mask, dim); -} - -template void launch_ls_dropout_bias_bwd(float *in_grad, float *bias_grad, - const float *out_grad, - const uint8_t *mask, int row_size, - int dim, float ratio, - cudaStream_t stream); - -/** - * @brief fused bias, activation, and dropout at the end of first ffn - * - * @thread - * gridDim.x = hidden_size / 8 - * blockDim.x = 8 - * blockDim.y = 1024 / 8 = 128 - * - * @tparam act_type activation function, like kRelu, kGelu - * @param total_count total elements - * @param ratio drop ratio - * @param out [batch_size, seq_len, hidden_size], float and __half - * @param in [batch_size, seq_len, hidden_size], float and __half - * @param mask [batch_size, seq_len, hidden_size], uint8 type - * @param bias [hidden_size], ffn bias - * @param seed seed to curand - * @param hidden_size - * @return void - */ -template -__global__ void ls_dropout_act_bias_kernel( - const int total_count, const float ratio, float *__restrict__ out, - const float *__restrict__ in, uint8_t *__restrict__ mask, - const float *__restrict__ bias, const int seed, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 4 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - uint8_t m[4]; - - float4 *out4 = reinterpret_cast(out); - const float4 *data4 = reinterpret_cast(in); - const float4 *bias4 = reinterpret_cast(bias); - uint32_t *mask4 = reinterpret_cast(mask); - float4 rand = curand_uniform4(&state); - - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - - int bias_i = i % (hidden_size >> 2); - uint32_t *m4 = reinterpret_cast(m); - mask4[i] = m4[0]; - const float4 input4 = data4[i]; - const float4 b4 = __ldg(&bias4[bias_i]); - float4 output4; - - output4.x = - activation_kernel(input4.x + b4.x) * scale * m[0]; - output4.y = - activation_kernel(input4.y + b4.y) * scale * m[1]; - output4.z = - activation_kernel(input4.z + b4.z) * scale * m[2]; - output4.w = - activation_kernel(input4.w + b4.w) * scale * m[3]; - - out4[i] = output4; -} - -template -__global__ void ls_dropout_act_bias_kernel( - const int total_count, const float ratio, __half *__restrict__ out, - const __half *__restrict__ in, uint8_t *__restrict__ mask, - const __half *__restrict__ bias, const int seed, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 8 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - - const float4 *vals_float4 = reinterpret_cast(in); - float4 *outs_float4 = reinterpret_cast(out); - const float4 *bias4 = reinterpret_cast(bias); - uint64_t *mask8 = reinterpret_cast(mask); - - uint8_t m[8]; - float4 rand = curand_uniform4(&state); - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - rand = curand_uniform4(&state); - m[4] = (uint8_t)(rand.x > ratio); - m[5] = (uint8_t)(rand.y > ratio); - m[6] = (uint8_t)(rand.z > ratio); - m[7] = (uint8_t)(rand.w > ratio); - uint64_t *m8 = reinterpret_cast(m); - mask8[i] = *m8; - - int bias_i = i % (hidden_size >> 3); - float4 val_float4 = vals_float4[i]; - const float4 b4 = __ldg(&bias4[bias_i]); - float4 out_float4; - - __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); - __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); - const __half2 *b_half2 = reinterpret_cast(&b4); - - __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]); - __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]); - __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]); - __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]); - out_half2[0] = __hmul2( - activation_kernel(__hadd2(val_half2[0], b_half2[0])), - scale_mask_1); - out_half2[1] = __hmul2( - activation_kernel(__hadd2(val_half2[1], b_half2[1])), - scale_mask_2); - out_half2[2] = __hmul2( - activation_kernel(__hadd2(val_half2[2], b_half2[2])), - scale_mask_3); - out_half2[3] = __hmul2( - activation_kernel(__hadd2(val_half2[3], b_half2[3])), - scale_mask_4); - outs_float4[i] = out_float4; -} - -template <> -void launch_ls_dropout_act_bias( - float *out, const float *vals, uint8_t *mask, const float *bias, - int total_count, int dim, float ratio, cudaStream_t stream) { - int grid_dim = total_count >> 10; - ls_dropout_act_bias_kernel - <<>>( - total_count, ratio, out, vals, mask, bias, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -template <> -void launch_ls_dropout_act_bias( - __half *out, const __half *vals, uint8_t *mask, const __half *bias, - int total_count, int dim, float ratio, cudaStream_t stream) { - int grid_dim = total_count >> 11; - ls_dropout_act_bias_kernel - <<>>( - total_count, ratio, out, vals, mask, bias, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -template <> -void launch_ls_dropout_act_bias( - float *out, const float *vals, uint8_t *mask, const float *bias, - int total_count, int dim, float ratio, cudaStream_t stream) { - int grid_dim = total_count >> 10; - ls_dropout_act_bias_kernel - <<>>( - total_count, ratio, out, vals, mask, bias, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -template <> -void launch_ls_dropout_act_bias( - __half *out, const __half *vals, uint8_t *mask, const __half *bias, - int total_count, int dim, float ratio, cudaStream_t stream) { - int grid_dim = total_count >> 11; - ls_dropout_act_bias_kernel - <<>>( - total_count, ratio, out, vals, mask, bias, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -/** - * @brief fused bias, activation, and dropout backward - * - * @thread - * gridDim.x = total_count / 1024 - * blockDim.x = 1024 - * - * @tparam act_type kRelu - * @param row_size batch_size * seq_len - * @param ratio dropout ratio - * @param in_grad [batch_size, seq_len, hidden_size], input grad - * @param bias_grad [hidden_size], bias grad - * @param out_grad [batch_size, seq_len, hidden_size], output grad - * @param mask [batch_size, seq_len, hidden_size], dropout mask - * @param hidden_size - * @return void - */ -template -__global__ void ls_dropout_act_bias_bwd_kernel( - const int row_size, const float ratio, T *in_grad, - T *__restrict__ bias_grad, const T *__restrict__ input, - const T *__restrict__ bias, const T *out_grad, - const uint8_t *__restrict__ mask, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - __shared__ float tile[WARP_SIZE][WARP_SIZE + 1]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); - - int stride = hidden_size * WARP_SIZE; - float local_sum = 0; - - int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); - if (col_idx < hidden_size) { - for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) { - float val = out_grad[idx]; - float in = input[idx]; - float b = bias[idx % hidden_size]; - val = activation_bwd_kernel( - val * scale * static_cast(mask[idx]), in + b); - local_sum += val; - in_grad[idx] = val; - idx += stride; - } - } - - tile[threadIdx.x][threadIdx.y] = local_sum; - __syncthreads(); - float sum = tile[threadIdx.y][threadIdx.x]; - __syncthreads(); - - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); - - if (threadIdx.x == 0) tile[0][threadIdx.y] = sum; - __syncthreads(); - - if (threadIdx.y == 0) { - int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); - bias_grad[pos] = tile[0][threadIdx.x]; - } -} - -// @brief fused bias, activation, and dropout backward -// It is deprecated for precision reason. Keep it for future optimization. -// -// template -// __global__ void ls_dropout_act_bias_bwd_kernel( -// const int row_size, const float ratio, __half * in_grad, -// __half *__restrict__ bias_grad, const __half *__restrict__ input, const -// __half *__restrict__ bias, const __half * out_grad, const uint8_t -// *__restrict__ mask, const int hidden_size) { -// const __half2 scale = __float2half2_rn(1.f / (1.f - ratio)); -// __shared__ __half2 tile[WARP_SIZE][WARP_SIZE + 1]; - -// cg::thread_block b = cg::this_thread_block(); -// cg::thread_block_tile g = cg::tiled_partition(b); - -// __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad); -// __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad); -// const __half2 *out_grad2 = reinterpret_cast(out_grad); -// const __half2 *input2 = reinterpret_cast(input); -// const __half2 *bias2 = reinterpret_cast(bias); - -// int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); - -// int stride = hidden_size * WARP_SIZE; -// __half2 local_sum = __float2half2_rn(0.f); - -// int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); -// if (col_idx < hidden_size) { -// for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) { -// __half2 val = out_grad2[idx]; -// __half2 in2 = input2[idx]; -// __half2 b2 = bias2[idx % hidden_size ]; -// __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]); -// val = activation_bwd_kernel(val * scale -// * -// m2, -// in2+b2); -// local_sum += val; -// in_grad2[idx] = val; -// idx += stride; -// } -// } - -// tile[threadIdx.x][threadIdx.y] = local_sum; -// __syncthreads(); -// __half2 sum = tile[threadIdx.y][threadIdx.x]; -// __syncthreads(); - -// for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); - -// if (threadIdx.x == 0) tile[0][threadIdx.y] = sum; -// __syncthreads(); - -// if (threadIdx.y == 0) { -// int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); -// bias_grad2[pos] = tile[0][threadIdx.x]; -// } -// } - -template -void launch_ls_dropout_act_bias_bwd(T *in_grad, T *bias_grad, const T *input, - const T *bias, const T *out_grad, - const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream) { - dim3 grid_dim((dim - 1) / WARP_SIZE + 1); - dim3 block_dim(WARP_SIZE, WARP_SIZE); - ls_dropout_act_bias_bwd_kernel<<>>( - row_size, ratio, in_grad, bias_grad, input, bias, out_grad, mask, dim); -} - -// template <> -// void launch_ls_dropout_act_bias_bwd( -// __half *in_grad, __half *bias_grad,const __half *input, const __half -// *bias, const __half *out_grad, const uint8_t *mask, int row_size, int -// dim, float ratio, cudaStream_t stream) { -// dim >>= 1; -// dim3 grid_dim((dim - 1) / WARP_SIZE + 1); -// dim3 block_dim(WARP_SIZE, WARP_SIZE); -// ls_dropout_act_bias_bwd_kernel -// <<>>(row_size, ratio, in_grad, -// bias_grad, -// input, bias,out_grad, mask, dim); -// } - -template void launch_ls_dropout_act_bias_bwd( - float *in_grad, float *bias_grad, const float *input, const float *bias, - const float *out_grad, const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); - -template void launch_ls_dropout_act_bias_bwd( - __half *in_grad, __half *bias_grad, const __half *input, const __half *bias, - const __half *out_grad, const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); - -template void launch_ls_dropout_act_bias_bwd( - float *in_grad, float *bias_grad, const float *input, const float *bias, - const float *out_grad, const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); - -template void launch_ls_dropout_act_bias_bwd( - __half *in_grad, __half *bias_grad, const __half *input, const __half *bias, - const __half *out_grad, const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu deleted file mode 100644 index 625b02cd2..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu +++ /dev/null @@ -1,232 +0,0 @@ -#include - -#include "kernels.h" - -namespace cg = cooperative_groups; - -/** -@brief: fuse_transpose_bias -Calculate the sum of elements in each column of the matrix. - -@thread -gridDim.x = ceil(cols / WARP_SIZE) -blockDim.x = WARP_SIZE -blockDim.y = WARP_SIZE - -@param -inp: [rows, cols] -out: [cols] -rows: the number of rows in the matrix -cols: the number of cols in the matrix -*/ -template -__global__ void column_sum_reduce(const T *__restrict__ inp, - T *__restrict__ out, int rows, int cols) { - __shared__ float tile[WARP_SIZE][WARP_SIZE]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); - int y_stride = cols * WARP_SIZE; - float localSum = 0; - - // Loop across matrix row - // TODO: optimize to log complexity - if (idx < cols) { - int offset = flat_2dim(threadIdx.y, idx, cols); - for (int r = threadIdx.y; r < rows; r += WARP_SIZE) { - localSum += (float)inp[offset]; - offset += y_stride; - } - } - - // The sum of a row in tile is equal to the sum of a col in original matrix - tile[threadIdx.x][threadIdx.y] = localSum; - - __syncthreads(); - - // Sum the shared buffer. - // The change of threadIdx.x is continuous - float sum = tile[threadIdx.y][threadIdx.x]; - - __syncthreads(); - - // Calculate the sum of a row in tile - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); - - if (threadIdx.x == 0) { - int pos = flat_2dim(blockIdx.x, threadIdx.y, WARP_SIZE); - if (pos < cols) out[pos] = sum; - } -} - -// [r, c] -> [c] -template <> -void launch_fuse_transpose_bias_kernel(const float *inp, float *out, - int rows, int cols, - cudaStream_t stream) { - dim3 grid_dim((cols - 1) / WARP_SIZE + 1); - dim3 block_dim(WARP_SIZE, WARP_SIZE); - - column_sum_reduce - <<>>(inp, out, rows, cols); -} - -template <> -void launch_fuse_transpose_bias_kernel<__half>(const __half *inp, __half *out, - int rows, int cols, - cudaStream_t stream) { - dim3 grid_dim((cols - 1) / WARP_SIZE + 1); - dim3 block_dim(WARP_SIZE, WARP_SIZE); - - column_sum_reduce<__half> - <<>>(inp, out, rows, cols); -} - -/** -@brief: fused_add2 -Add two matrix inp1 and inp2 to out. - -@thread -gridDim.x = batch_size * seq_len -blockDim.x = min(hidden_dim, MAX_THREADS) - -@param -inp1: [batch_size, seq_len, hidden_dim] -inp2: [batch_size, seq_len, hidden_dim] -out: [batch_size, seq_len, hidden_dim] -batch_size: the size of the current batch -seq_len: the sequence length of the current batch -hidden_dim: dim of the hidden tensor -*/ -template -__global__ void fused_add2_kernel(T *out, const T *inp1, const T *inp2, - int hidden_dim); - -template <> -__global__ void fused_add2_kernel(float *out, const float *inp1, - const float *inp2, int hidden_dim) { - int row_id = blockIdx.x; - int offset = flat_2dim(row_id, 0, hidden_dim); - - const float4 *inp1_4 = reinterpret_cast(inp1); - const float4 *inp2_4 = reinterpret_cast(inp2); - float4 *out_4 = reinterpret_cast(out); - float4 vinp1; - float4 vinp2; - float4 val; - - for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - vinp1 = inp1_4[offset + i]; - vinp2 = inp2_4[offset + i]; - val.x = vinp1.x + vinp2.x; - val.y = vinp1.y + vinp2.y; - val.z = vinp1.z + vinp2.z; - val.w = vinp1.w + vinp2.w; - out_4[offset + i] = val; - } -} - -template <> -__global__ void fused_add2_kernel<__half>(__half *out, const __half *inp1, - const __half *inp2, int hidden_dim) { - int row_id = blockIdx.x; - int offset = flat_2dim(row_id, 0, hidden_dim); - - const float4 *inp1_4 = reinterpret_cast(inp1); - const float4 *inp2_4 = reinterpret_cast(inp2); - float4 *out_4 = reinterpret_cast(out); - float4 vinp1; - float4 vinp2; - float4 val; - __half2 *h2_inp1 = reinterpret_cast<__half2 *>(&vinp1); - __half2 *h2_inp2 = reinterpret_cast<__half2 *>(&vinp2); - __half2 *h2_val = reinterpret_cast<__half2 *>(&val); - - for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - vinp1 = inp1_4[offset + i]; - vinp2 = inp2_4[offset + i]; - h2_val[0] = __hadd2(h2_inp1[0], h2_inp2[0]); - h2_val[1] = __hadd2(h2_inp1[1], h2_inp2[1]); - h2_val[2] = __hadd2(h2_inp1[2], h2_inp2[2]); - h2_val[3] = __hadd2(h2_inp1[3], h2_inp2[3]); - out_4[offset + i] = val; - } -} - -//[b, s, h] -> [b, s, h] -template <> -void launch_fused_add2(float *out, const float *inp1, const float *inp2, - int batch_size, int seq_len, int hidden_dim, - cudaStream_t &stream) { - hidden_dim >>= 2; - - dim3 grid_dim(batch_size * seq_len); - dim3 block_dim(min(hidden_dim, MAX_THREADS)); - - fused_add2_kernel<<>>(out, inp1, inp2, - hidden_dim); -} - -template <> -void launch_fused_add2<__half>(__half *out, const __half *inp1, - const __half *inp2, int batch_size, int seq_len, - int hidden_dim, cudaStream_t &stream) { - hidden_dim >>= 3; - - dim3 grid_dim(batch_size * seq_len); - dim3 block_dim(min(hidden_dim, MAX_THREADS)); - - fused_add2_kernel<<>>(out, inp1, inp2, - hidden_dim); -} - -template -__global__ void kernel_concat3_dim1(const T *inp1, const T *inp2, T *output, - int sz0, int sz2, int sz1_1, int sz1_2) { - int nele = sz0 * sz2 * (sz1_1 + sz1_2); - int idx = flat_2dim(blockIdx.x, threadIdx.x, blockDim.x); - if (idx >= nele) { - return; - } - float4 *dst_ptr = (float4 *)output + idx; - int idx2 = idx % sz2; - idx = idx / sz2; - int idx1 = idx % (sz1_1 + sz1_2); - int idx0 = idx / (sz1_1 + sz1_2); - float4 *src_ptr = nullptr; - int sz1 = 0; - if (idx1 < sz1_1) { - sz1 = sz1_1; - src_ptr = (float4 *)inp1; - } else { - idx1 -= sz1_1; - sz1 = sz1_2; - src_ptr = (float4 *)inp2; - } - src_ptr += flat_3dim(idx0, idx1, idx2, sz1, sz2); - dst_ptr[0] = src_ptr[0]; -} - -template <> -void launch_concat3_dim1(const float *inp1, const float *inp2, - float *output, int sz0, int sz2, int sz1_1, - int sz1_2, cudaStream_t stream) { - sz2 >>= 2; - int nele = sz0 * sz2 * (sz1_1 + sz1_2); - int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; - kernel_concat3_dim1<<>>( - inp1, inp2, output, sz0, sz2, sz1_1, sz1_2); -} - -template <> -void launch_concat3_dim1<__half>(const __half *inp1, const __half *inp2, - __half *output, int sz0, int sz2, int sz1_1, - int sz1_2, cudaStream_t stream) { - sz2 >>= 3; - int nele = sz0 * sz2 * (sz1_1 + sz1_2); - int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; - kernel_concat3_dim1<<>>( - inp1, inp2, output, sz0, sz2, sz1_1, sz1_2); -} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/context.h b/colossalai/kernel/cuda_native/csrc/kernels/include/context.h deleted file mode 100644 index f7d75f38c..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/context.h +++ /dev/null @@ -1,36 +0,0 @@ -#pragma once - -#include -#include - -#include -#include - -#include "cuda_util.h" - -class Context { - public: - Context() : _stream(nullptr) { - CHECK_GPU_ERROR(cublasCreate(&_cublasHandle)); - } - - virtual ~Context() {} - - static Context &Instance() { - static Context _ctx; - return _ctx; - } - - void set_stream(cudaStream_t stream) { - _stream = stream; - CHECK_GPU_ERROR(cublasSetStream(_cublasHandle, _stream)); - } - - cudaStream_t get_stream() { return _stream; } - - cublasHandle_t get_cublashandle() { return _cublasHandle; } - - private: - cudaStream_t _stream; - cublasHandle_t _cublasHandle; -}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/cross_entropy_layer.h b/colossalai/kernel/cuda_native/csrc/kernels/include/cross_entropy_layer.h deleted file mode 100644 index f4e9befc6..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/cross_entropy_layer.h +++ /dev/null @@ -1,46 +0,0 @@ -#pragma once - -#include -#include -#include - -#include - -#include "cuda_util.h" - -template -class CrossEntropyLayer { - public: - CrossEntropyLayer(float epsilon, int padding_idx, int max_batch_tokens); - - virtual ~CrossEntropyLayer(); - - void Forward(const T *inputs_ptr, const int *targets_ptr, float *outputs_ptr, - float *nll_loss_ptr); - - void Backward(const float *grad_outputs_ptr, const T *inputs_ptr, - const int *targets_ptr, T *grad_inputs_ptr); - - void set_cur_batch_shape(int batch_size, int seq_len, int vocab_size); - - private: - void allocate_mem_buffer() { - // allocate local gpu memory - _loss_buffer = cuda_malloc(_max_batch_tokens * 2); - } - - void free_mem_buffer() { - // free local gpu memory - cuda_free(_loss_buffer); - } - - const int _padding_idx; - const float _epsilon; - const int _max_batch_tokens; - - size_t _batch_size; - size_t _seq_len; - size_t _vocab_size; - - float *_loss_buffer; -}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/cublas_wrappers.h b/colossalai/kernel/cuda_native/csrc/kernels/include/cublas_wrappers.h deleted file mode 100644 index 90255152b..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/cublas_wrappers.h +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright 2021 The LightSeq Team - Copyright Microsoft DeepSpeed - This file is adapted from Microsoft DeepSpeed - Licensed under the MIT License. -*/ -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const float *alpha, const float *beta, const float *A, - const float *B, float *C, - cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT); - -int cublas_gemm_ex(cublasHandle_t handle, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const float *alpha, const float *beta, const __half *A, - const __half *B, __half *C, - cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP); - -int cublas_strided_batched_gemm(cublasHandle_t handle, int m, int n, int k, - const float *alpha, const float *beta, - const float *A, const float *B, float *C, - cublasOperation_t op_A, cublasOperation_t op_B, - int stride_A, int stride_B, int stride_C, - int batch, - cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT); - -int cublas_strided_batched_gemm( - cublasHandle_t handle, int m, int n, int k, const float *alpha, - const float *beta, const __half *A, const __half *B, __half *C, - cublasOperation_t op_A, cublasOperation_t op_B, int stride_A, int stride_B, - int stride_C, int batch, - cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h b/colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h deleted file mode 100644 index 1595257be..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h +++ /dev/null @@ -1,34 +0,0 @@ -#pragma once - -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -template -void check_gpu_error(T result, char const *const func, const char *const file, - int const line); - -#define CHECK_GPU_ERROR(val) check_gpu_error((val), #val, __FILE__, __LINE__) - -template -void print_vec(const T *outv, std::string outn, int num_output_ele); - -template -T *cuda_malloc(size_t ele_num); - -void cuda_free(void *pdata); - -template -void check_nan_inf(const T *data_ptr, int dsize, bool check_nan_inf, - std::string file, int line, cudaStream_t stream); - -#define CHECK_NAN_INF(ptr, size, stream) \ - check_nan_inf((ptr), (size), true, __FILE__, __LINE__, (stream)); \ - check_nan_inf((ptr), (size), false, __FILE__, __LINE__, (stream)) diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h b/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h deleted file mode 100644 index 025fbf3f8..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h +++ /dev/null @@ -1,96 +0,0 @@ -#pragma once - -#include -#include -#include - -#include - -#include "kernels.h" - -template -class Dropout { - public: - struct Config { - float ratio; - bool training; - - Config(float r) : ratio(r), training(true) {} - float RATIO() const { return training ? ratio : 0.0; } - }; - - Dropout(const Config &config, size_t max_ele_num) - : _config(config), _mask(nullptr) { - _mask = cuda_malloc(max_ele_num); - } - - virtual ~Dropout() { cuda_free(_mask); } - - // after attention softmax - void dropout(T *output, const T *input, int count, cudaStream_t stream, - bool bwd = false) { - launch_ls_dropout(output, input, _mask, count, _config.RATIO(), stream, - bwd); - } - - void d_dropout(T *d_inp_out, int count, cudaStream_t stream) { - launch_ls_dropout(d_inp_out, d_inp_out, _mask, count, _config.RATIO(), - stream, true); - } - - // transformer layer's postprocessing dropout, after attn or ffn module, - // before residual add. - void bias_dropout_residual(T *output, const T *input, const T *residual, - const T *bias, int rows, int cols, - cudaStream_t stream) { - launch_ls_dropout_res_bias(output, input, _mask, bias, residual, - rows * cols, cols, _config.RATIO(), stream); - } - - void d_bias_dropout_residual(T *d_input, T *d_bias, const T *d_output, - int rows, int cols, cudaStream_t stream) { - launch_ls_dropout_bias_bwd(d_input, d_bias, d_output, _mask, rows, cols, - _config.RATIO(), stream); - } - - // dropout inside ffn. - void bias_act_dropout(T *output, const T *input, const T *bias, int rows, - int cols, std::string activation_fn, - cudaStream_t stream) { - if (activation_fn == "relu") { - launch_ls_dropout_act_bias( - output, input, _mask, bias, rows * cols, cols, _config.RATIO(), - stream); - } else if (activation_fn == "gelu") { - launch_ls_dropout_act_bias( - output, input, _mask, bias, rows * cols, cols, _config.RATIO(), - stream); - } else { - throw std::runtime_error("not supported activation: " + activation_fn); - } - } - - void d_bias_act_dropout(T *d_inp_out, T *d_bias_out, const T *input, - const T *bias, int rows, int cols, - std::string activation_fn, cudaStream_t stream) { - if (activation_fn == "relu") { - launch_ls_dropout_act_bias_bwd( - d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols, - _config.RATIO(), stream); - } else if (activation_fn == "gelu") { - launch_ls_dropout_act_bias_bwd( - d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols, - _config.RATIO(), stream); - } else { - throw std::runtime_error("not supported activation: " + activation_fn); - } - } - - bool HasDropout() const { return _config.RATIO() > 0.0; } - - void SetTrainingMode(bool training) { _config.training = training; } - - private: - uint8_t *_mask; - Config _config; -}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h b/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h deleted file mode 100644 index 8186da1ee..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h +++ /dev/null @@ -1,69 +0,0 @@ -#pragma once - -/* Copyright 2021 The LightSeq Team - Copyright Microsoft DeepSpeed - This file is adapted from Microsoft DeepSpeed - Licensed under the MIT License. -*/ -#include -#include -#include - -#include - -#include "cublas_wrappers.h" -#include "kernels.h" - -template -class FeedForward { - public: - struct Config { - int outputSize; - int inputSize; - std::array gemm_algos; - Config(int outputs, int inputs) - : outputSize(outputs), - inputSize(inputs), - gemm_algos(std::array({99, 99, 99})) {} - }; - - FeedForward(Config config) : config_(config) {} - - ~FeedForward() {} - - void Forward(int bsz, const T *input_ptr, const T *weights, T *out, - cublasHandle_t &_cublasHandle) { - float alpha = T(1.); - float beta = T(0.); - - cublas_gemm_ex(_cublasHandle, CUBLAS_OP_T, CUBLAS_OP_N, config_.outputSize, - bsz, config_.inputSize, &alpha, &beta, weights, input_ptr, - out, cublasGemmAlgo_t(config_.gemm_algos[0])); - } - void Backward(int bsz, const T *out_grad, const T *input_ptr, - const T *weights, T *weights_grad, T *bias_grad, - cublasHandle_t &_cublasHandle, cudaStream_t &stream, - T *inp_grad_out = nullptr, T *out_grad_trans_out = nullptr, - bool compute_bias = true) { - float alpha = (T)1.0, beta = (T)0.0; - cublas_gemm_ex(_cublasHandle, CUBLAS_OP_N, CUBLAS_OP_T, config_.inputSize, - config_.outputSize, bsz, &alpha, &beta, input_ptr, out_grad, - weights_grad, cublasGemmAlgo_t(config_.gemm_algos[1])); - - cublas_gemm_ex(_cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N, config_.inputSize, - bsz, config_.outputSize, &alpha, &beta, weights, out_grad, - inp_grad_out, cublasGemmAlgo_t(config_.gemm_algos[2])); - if (compute_bias) { - launch_fuse_transpose_bias_kernel(out_grad, bias_grad, bsz, - config_.outputSize, stream); - } - } - - void reset_size(int outputSize, int inputSize) { - config_.outputSize = outputSize; - config_.inputSize = inputSize; - } - - private: - Config config_; -}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h b/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h deleted file mode 100644 index 735e1363c..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h +++ /dev/null @@ -1,275 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include - -#include - -#define MAX_THREADS 1024 -#define WARP_SIZE 32 - -enum class ActivationType { kRelu, kGelu }; - -void launch_curand_init(int total_count, int dim, cudaStream_t stream); - -template -void launch_layer_norm(T *ln_res, T *vars, T *means, const T *inp, - const T *scale, const T *bias, int batch_size, - int hidden_dim, cudaStream_t stream); - -template -void launch_ln_bw(T *gamma_grad, T *betta_grad, T *inp_grad, const T *out_grad, - const T *residual_grad, const T *inp_or_out, const T *gamma, - const T *betta, const T *vars, const T *means, int batch, - int hidden_dim, cudaStream_t stream[2]); - -template -void launch_attn_softmax(T *vals, const T *attn_mask, int batch_size, int heads, - int from_len, int to_len, bool mask_future, - cudaStream_t stream); - -template -void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows, - int softmax_len, cudaStream_t stream); - -// [b, s, h] -> [b, nh, s, ad] -template -void launch_transform_0213(T *output, const T *vals, int batch_size, - int seq_length, int hidden_dim, int nhead, - cudaStream_t stream); - -// [b, s, 3, h] -> [3, b, nh, s, ad] -template -void launch_bias_add_transform_20314(T *output, const T *input, const T *bias, - int dim_0, int dim_1, int dim_2, int dim_3, - int dim_4, cudaStream_t stream); - -// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad] -template -void launch_transform4d_0213(T *output, const T *vals, int batch_size, - int seq_len, int hidden_dim, int nhead, - int trans_count, cudaStream_t stream); - -template -void launch_ls_dropout(T *out, const T *vals, uint8_t *mask, int total_count, - float ratio, cudaStream_t stream, bool backward = false); - -template -void launch_ls_dropout_res_bias(T *out, const T *vals, uint8_t *mask, - const T *bias, const T *residual, - int total_count, int dim, float ratio, - cudaStream_t stream); - -template -void launch_ls_dropout_act_bias(T *out, const T *vals, uint8_t *mask, - const T *bias, int total_count, int dim, - float ratio, cudaStream_t stream); - -template -void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad, - const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); - -template -void launch_ls_dropout_act_bias_bwd(T *in_grad, T *bias_grad, const T *input, - const T *bias, const T *out_grad, - const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); - -template -void launch_fuse_transpose_bias_kernel(const T *inp, T *out, int rows, int cols, - cudaStream_t stream); - -void launch_param_update(const float *input, __half *output, int size, - cudaStream_t stream); - -template -void launch_concat3_dim1(const T *inp1, const T *inp2, T *output, int sz0, - int sz2, int sz1_1, int sz1_2, cudaStream_t stream); - -template -void launch_fused_add2(T *out, const T *inp1, const T *inp2, int batch_size, - int seq_len, int hidden_size, cudaStream_t &stream); - -template -void launch_cross_entropy_fw(const T *inputs_ptr, const int *targets_ptr, - float *outputs_ptr, float *nll_loss_ptr, - float *loss_buffer, const int padding_idx, - const float epsilon, const int batch_size, - const int seq_len, const int vocab_size, - cudaStream_t stream); - -template -void launch_cross_entropy_bw(const float *grad_outputs_ptr, const T *inputs_ptr, - const int *targets_ptr, T *grad_inputs_ptr, - const int padding_idx, const float epsilon, - const int batch_size, const int seq_len, - const int vocab_size, cudaStream_t stream); - -template -void launch_lookup_scale_pos_dropout( - T *output, const int *input, const T *embeddings, const T *pos_embeddings, - uint8_t *dropout_mask, int batch_size, int seq_len, int embedding_dim, - int padding_idx, float dropout_ratio, int step, cudaStream_t &stream); - -template -void launch_d_lookup_scale_pos_dropout( - T *grad_embeddings, const T *grad_output, const int *input, - const uint8_t *dropout_mask, int batch_size, int seq_len, int embedding_dim, - int vocab_size, int padding_idx, float dropout_ratio, cudaStream_t &stream); - -/* Convert 2-dim tensor index into vector index */ -__forceinline__ __host__ __device__ int flat_2dim(int id1, int id2, int dim2) { - return id1 * dim2 + id2; -} - -/* Convert 3-dim tensor index into vector index */ -__forceinline__ __host__ __device__ int flat_3dim(int id1, int id2, int id3, - int dim2, int dim3) { - return id1 * dim2 * dim3 + id2 * dim3 + id3; -} - -/* Convert 4-dim tensor index into vector index */ -__forceinline__ __host__ __device__ int flat_4dim(int id1, int id2, int id3, - int id4, int dim2, int dim3, - int dim4) { - // return id1*(dim2*dim3*dim4) + id2*(dim3*dim4) + id3*dim4 + id4; - int res = id4; - - int ld = dim4; - res += id3 * ld; - - ld *= dim3; - res += id2 * ld; - - ld *= dim2; - res += id1 * ld; - - return res; -} - -/* Convert 5-dim tensor index into vector index */ -__forceinline__ __host__ __device__ int flat_5dim(int id1, int id2, int id3, - int id4, int id5, int dim2, - int dim3, int dim4, - int dim5) { - // return id1*(dim2*dim3*dim4*dim5) + id2*(dim3*dim4*dim5) + id3*(dim4*dim5) + - // id4*dim5 + dim5; - int res = id5; - - int ld = dim5; - res += id4 * ld; - - ld *= dim4; - res += id3 * ld; - - ld *= dim3; - res += id2 * ld; - - ld *= dim2; - res += id1 * ld; - - return res; -} - -/* Convert 6-dim tensor index into vector index */ -__forceinline__ __host__ __device__ int flat_6dim(int id1, int id2, int id3, - int id4, int id5, int id6, - int dim2, int dim3, int dim4, - int dim5, int dim6) { - // return id1*(dim2*dim3*dim4*dim5*dim6) + id2*(dim3*dim4*dim5*dim6) + - // id3*(dim4*dim5*dim6) + id4*(dim5*dim6) + id5*dim6 + id6; - int res = id6; - - int ld = dim6; - res += id5 * ld; - - ld *= dim5; - res += id4 * ld; - - ld *= dim4; - res += id3 * ld; - - ld *= dim3; - res += id2 * ld; - - ld *= dim2; - res += id1 * ld; - - return res; -} - -/* Convert vector index to 6-dim tensor index */ -__forceinline__ __host__ __device__ void decompose_6dim( - int src, int dim1, int dim2, int dim3, int dim4, int dim5, int *id0, - int *id1, int *id2, int *id3, int *id4, int *id5) { - *id5 = src % dim5; - src /= dim5; - - *id4 = src % dim4; - src /= dim4; - - *id3 = src % dim3; - src /= dim3; - - *id2 = src % dim2; - src /= dim2; - - *id1 = src % dim1; - *id0 = src / dim1; -} - -/* Convert vector index to 5-dim tensor index */ -__forceinline__ __host__ __device__ void decompose_5dim(int src, int dim1, - int dim2, int dim3, - int dim4, int *id0, - int *id1, int *id2, - int *id3, int *id4) { - *id4 = src % dim4; - src /= dim4; - - *id3 = src % dim3; - src /= dim3; - - *id2 = src % dim2; - src /= dim2; - - *id1 = src % dim1; - *id0 = src / dim1; -} - -/* Convert vector index to 4-dim tensor index */ -__forceinline__ __host__ __device__ void decompose_4dim(int src, int dim1, - int dim2, int dim3, - int *id0, int *id1, - int *id2, int *id3) { - *id3 = src % dim3; - src /= dim3; - - *id2 = src % dim2; - src /= dim2; - - *id1 = src % dim1; - *id0 = src / dim1; -} - -/* Convert vector index to 3-dim tensor index */ -__forceinline__ __host__ __device__ void decompose_3dim(int src, int dim1, - int dim2, int *id0, - int *id1, int *id2) { - *id2 = src % dim2; - src /= dim2; - - *id1 = src % dim1; - *id0 = src / dim1; -} - -/* Convert vector index to 2-dim tensor index */ -__forceinline__ __host__ __device__ void decompose_2dim(int src, int dim1, - int *id0, int *id1) { - *id1 = src % dim1; - *id0 = src / dim1; -} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/ls_cub.cuh b/colossalai/kernel/cuda_native/csrc/kernels/include/ls_cub.cuh deleted file mode 100644 index 4f65e7b54..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/ls_cub.cuh +++ /dev/null @@ -1,12 +0,0 @@ -// copied from https://github.com/dmlc/dgl/pull/2758 -#ifndef DGL_ARRAY_CUDA_DGL_CUB_CUH_ -#define DGL_ARRAY_CUDA_DGL_CUB_CUH_ - -#define CUB_NS_PREFIX namespace ls { -#define CUB_NS_POSTFIX } -#include "cub/cub.cuh" -#include "cub/util_allocator.cuh" -#undef CUB_NS_POSTFIX -#undef CUB_NS_PREFIX - -#endif diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h b/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h deleted file mode 100644 index a7767e187..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h +++ /dev/null @@ -1,65 +0,0 @@ -#pragma once - -#include -#include -#include - -#include - -#include "kernels.h" - -using namespace std; - -template -class Normalize_Layer { - public: - struct Config { - uint32_t hidden_dim; - bool use_mean; - Config(uint32_t hidden_dim, bool use_mean = false) - : hidden_dim(hidden_dim), use_mean(use_mean) {} - }; - - Normalize_Layer(Config config, size_t max_rows) - : config_(config), vars_(nullptr), means_(nullptr) { - vars_ = cuda_malloc(max_rows); - if (config_.use_mean) { - means_ = cuda_malloc(max_rows); - } - } - - ~Normalize_Layer() { - cuda_free(vars_); - cuda_free(means_); - } - - void Forward(T *ln_res, const T *inp, const T *gamma, const T *betta, - int batch_size, cudaStream_t stream) { - launch_layer_norm(ln_res, vars_, means_, inp, gamma, betta, batch_size, - config_.hidden_dim, stream); - } - - /* - residual_grad, inp_or_out, betta should be treated carefully. - inp_or_out = input if use_mean else output - residual_grad, betta can be nullptr. - residual_grad will be added to dinp if it is not nullptr - which is useful in transformer layer when pre-ln - betta are only used to compute xhat, - (use_mean == false) ^ (betta == nullptr) should be true - */ - void Backward(T *gamma_grad, T *betta_grad, T *inp_grad, const T *out_grad, - const T *residual_grad, const T *inp_or_out, const T *gamma, - const T *betta, int batch_size, cudaStream_t stream[2]) { - launch_ln_bw(gamma_grad, betta_grad, inp_grad, out_grad, residual_grad, - inp_or_out, gamma, betta, vars_, means_, batch_size, - config_.hidden_dim, stream); - } - - inline bool use_mean() const { return config_.use_mean; } - - private: - Config config_; - T *vars_; - T *means_; -}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h b/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h deleted file mode 100644 index b917abaf0..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h +++ /dev/null @@ -1,42 +0,0 @@ -#pragma once - -#include -#include -#include - -#include - -#include "kernels.h" - -using namespace std; - -template -class Softmax { - public: - struct Config { - size_t nhead; - Config(size_t nhead) : nhead(nhead) {} - }; - - Softmax(Config config) : config_(config) {} - - ~Softmax() {} - - void Forward(T *vals, const T *attn_mask, int batch_size, int from_len, - int to_len, cudaStream_t &stream, bool mask_future = true) { - launch_attn_softmax(vals, attn_mask, batch_size, config_.nhead, from_len, - to_len, mask_future, stream); - } - - void Backward(T *out_grad, const T *soft_out, int batch_size, int from_len, - int to_len, cudaStream_t stream) { - launch_attn_softmax_bw(out_grad, soft_out, - batch_size * config_.nhead * from_len, to_len, - stream); - } - - void reset_size(size_t nhead) { config_.nhead = nhead; } - - private: - Config config_; -}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/strided_batch_gemm.h b/colossalai/kernel/cuda_native/csrc/kernels/include/strided_batch_gemm.h deleted file mode 100644 index d386650e8..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/strided_batch_gemm.h +++ /dev/null @@ -1,100 +0,0 @@ -/* Copyright 2021 The LightSeq Team - Copyright Microsoft DeepSpeed - This file is adapted from Microsoft DeepSpeed - Licensed under the MIT License. -*/ -#pragma once - -#include -#include -#include - -#include - -#include "cublas_wrappers.h" - -template -class StridedBatchGemm { - public: - struct Config { - int m; - int n; - int k; - float alpha; - float beta; - cublasOperation_t op_A; - cublasOperation_t op_B; - std::array gemm_algos; - - Config(float param_alpha, float param_beta, cublasOperation_t opA, - cublasOperation_t opB) - : alpha(param_alpha), - beta(param_beta), - op_A(opA), - op_B(opB), - gemm_algos(std::array({99, 99, 99})) {} - void SetConfig(int mm, int nn, int kk) { - m = mm; - n = nn; - k = kk; - } - }; - - StridedBatchGemm(const Config &config) : _config(config) {} - - virtual ~StridedBatchGemm() {} - - void Forward(int bsz, T *output, const T *_buffer_a, const T *_buffer_b, - cublasHandle_t handle) { - int stride_a = _config.m * _config.k; - int stride_b = _config.n * _config.k; - int stride_c = _config.m * _config.n; - - cublas_strided_batched_gemm( - handle, _config.m, _config.n, _config.k, &_config.alpha, &_config.beta, - _buffer_a, _buffer_b, output, _config.op_A, _config.op_B, stride_a, - stride_b, stride_c, bsz, cublasGemmAlgo_t(_config.gemm_algos[0])); - } - - void Backward(int bsz, const T *d_output, const T *_buffer_a, - const T *_buffer_b, cublasHandle_t handle, - T *inpGradA = nullptr, T *inpGradB = nullptr) { - int mb = (_config.op_A == CUBLAS_OP_T ? _config.k : _config.m); - int kb = (_config.op_A == CUBLAS_OP_T ? _config.m : _config.k); - - int stride_a = mb * _config.n; - int stride_b = _config.n * kb; - int stride_c = _config.m * _config.k; - - // B need to transpose. - cublasOperation_t op_b = - (_config.op_B == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T); - - // Calculate d_A. - cublas_strided_batched_gemm( - handle, mb, kb, _config.n, &_config.alpha, &_config.beta, - (_config.op_A == CUBLAS_OP_T ? _buffer_b : d_output), - (_config.op_A == CUBLAS_OP_T ? d_output : _buffer_b), inpGradA, - CUBLAS_OP_N, op_b, stride_a, stride_b, stride_c, bsz, - cublasGemmAlgo_t(_config.gemm_algos[1])); - - // A need to transpose. - cublasOperation_t op_a = - (_config.op_A == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T); - - stride_a = _config.m * _config.k; - stride_b = _config.m * _config.n; - stride_c = _config.n * _config.k; - - // Calculate d_B. - cublas_strided_batched_gemm( - handle, _config.k, _config.n, _config.m, &_config.alpha, &_config.beta, - _buffer_a, d_output, inpGradB, op_a, CUBLAS_OP_N, stride_a, stride_b, - stride_c, bsz, cublasGemmAlgo_t(_config.gemm_algos[2])); - } - - inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); } - - private: - Config _config; -}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu deleted file mode 100644 index e2f1869b1..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu +++ /dev/null @@ -1,1172 +0,0 @@ -#include - -#include "block_reduce.h" -#include "kernels.h" - -namespace cg = cooperative_groups; -const float LN_EPSILON = 1e-8f; -#define TILE_DIM 32 - -template -__forceinline__ __device__ T add_eps(T x) { - return fabsf(x) > LN_EPSILON ? x : (x < 0 ? -LN_EPSILON : LN_EPSILON); -} - -/** -@brief: ker_layer_norm -Standard layer normalization. -It will not only output the layer norm result, - but also outputs variance. - may also output means, depends on whether - the means argument is nullptr - -@thread -gridDim.x = batch_size * seq_len -blockDim.x = hidden_size - -@param -ln_res: [batch_size* seq_len, hidden_size], ln result. -vars: [batch_size* seq_len], variance per token -means: [batch_size* seq_len], means per token, can be nullput -inp: [batch_size * seq_len, hidden_size], ln input. -scale: [hidden_size], ln scale -bias: [hidden_size], ln bias -*/ -template -__global__ void ker_layer_norm(T *ln_res, T *vars, T *means, const T *inp, - const T *scale, const T *bias, int hidden_size) { - // step 0. compute local sum - float l_sum = 0; - float l_square_sum = 0; - const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size; - for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float4 val = inp_f4[idx]; - l_sum += val.x + val.y + val.z + val.w; - l_square_sum += - val.x * val.x + val.y * val.y + val.z * val.z + val.w * val.w; - } - - // step 1. compute reduce sum - float mean_dim = float(hidden_size) * 4.f; - float reduce_val[2] = {l_sum, l_square_sum}; - blockReduce(reduce_val); - __shared__ float s_mean, s_var; - if (threadIdx.x == 0) { - s_mean = reduce_val[0] / mean_dim; - if (means != nullptr) { - means[blockIdx.x] = s_mean; - } - s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; - vars[blockIdx.x] = s_var; - s_var = rsqrtf(s_var); - } - __syncthreads(); - - // step 2. layer norm result - float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size; - for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float4 vscale = __ldg((const float4 *)scale + idx); - float4 vbias = __ldg((const float4 *)bias + idx); - float4 val = inp_f4[idx]; - val.x = (val.x - s_mean) * s_var * vscale.x + vbias.x; - val.y = (val.y - s_mean) * s_var * vscale.y + vbias.y; - val.z = (val.z - s_mean) * s_var * vscale.z + vbias.z; - val.w = (val.w - s_mean) * s_var * vscale.w + vbias.w; - output_f4[idx] = val; - } -} - -template <> -__global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars, - __half *means, const __half *inp, - const __half *scale, const __half *bias, - int hidden_size) { - // step 0. compute local sum - float l_sum = 0; - float l_square_sum = 0; - const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size; - for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float4 val_f4 = inp_f4[idx]; - __half2 *val_h2 = (__half2 *)(&val_f4); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 val_f2 = __half22float2(val_h2[i]); - l_sum += val_f2.x + val_f2.y; - l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y; - } - } - - // step 1. compute reduce sum - float mean_dim = float(hidden_size) * 8.f; - float reduce_val[2] = {l_sum, l_square_sum}; - blockReduce(reduce_val); - __shared__ float s_mean, s_var; - if (threadIdx.x == 0) { - s_mean = reduce_val[0] / mean_dim; - if (means != nullptr) { - means[blockIdx.x] = s_mean; - } - s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; - vars[blockIdx.x] = s_var; - s_var = rsqrtf(s_var); - } - __syncthreads(); - - // step 2. layer norm result - float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size; - for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - // load scale, bias, input - float4 scale_f4 = __ldg((const float4 *)scale + idx); - __half2 *scale_h2 = (__half2 *)(&scale_f4); - float4 bias_f4 = __ldg((const float4 *)bias + idx); - __half2 *bias_h2 = (__half2 *)(&bias_f4); - float4 val_f4 = inp_f4[idx]; - __half2 *val_h2 = (__half2 *)(&val_f4); - -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 scale_f2 = __half22float2(scale_h2[i]); - float2 bias_f2 = __half22float2(bias_h2[i]); - float2 val_f2 = __half22float2(val_h2[i]); - val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; - val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; - val_h2[i] = __float22half2_rn(val_f2); - } - output_f4[idx] = val_f4; - } -} - -// __global__ void ker_layer_norm_x2(__half *ln_res, __half *vars, -// __half *means, const __half *inp, -// const __half *scale, const __half -// *bias, int hidden_size) { -// // step 0. compute local sum -// float l_sum = 0; -// float l_square_sum = 0; -// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * 2 * hidden_size; -// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * -// 2) { -// float4 val_f4 = inp_f4[idx]; -// float4 val_f4_1 = inp_f4[idx+1]; -// __half2 *val_h2 = (__half2 *)(&val_f4); -// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); -// #pragma unroll -// for (int i = 0; i < 4; i++) { -// float2 val_f2 = __half22float2(val_h2[i]); -// float2 val_f2_1 = __half22float2(val_h2_1[i]); -// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y; -// l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y + val_f2_1.x -// * val_f2_1.x + val_f2_1.y * val_f2_1.y; -// } -// } - -// // step 1. compute reduce sum -// float mean_dim = float(hidden_size) * 8.f * 2; -// float reduce_val[2] = {l_sum, l_square_sum}; -// blockReduce(reduce_val); -// __shared__ float s_mean, s_var; -// if (threadIdx.x == 0) { -// s_mean = reduce_val[0] / mean_dim; -// if (means != nullptr) { -// means[blockIdx.x] = s_mean; -// } -// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; -// vars[blockIdx.x] = s_var; -// s_var = rsqrtf(s_var); -// } -// __syncthreads(); - -// // step 2. layer norm result -// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 2; -// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * -// 2) { -// // load scale, bias, input -// float4 scale_f4 = __ldg((const float4 *)scale + idx); -// __half2 *scale_h2 = (__half2 *)(&scale_f4); -// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1); -// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1); -// float4 bias_f4 = __ldg((const float4 *)bias + idx); -// __half2 *bias_h2 = (__half2 *)(&bias_f4); -// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1); -// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1); -// float4 val_f4 = inp_f4[idx]; -// __half2 *val_h2 = (__half2 *)(&val_f4); -// float4 val_f4_1 = inp_f4[idx+1]; -// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); - -// #pragma unroll -// for (int i = 0; i < 4; i++) { -// float2 scale_f2 = __half22float2(scale_h2[i]); -// float2 scale_f2_1 = __half22float2(scale_h2_1[i]); -// float2 bias_f2 = __half22float2(bias_h2[i]); -// float2 bias_f2_1 = __half22float2(bias_h2_1[i]); -// float2 val_f2 = __half22float2(val_h2[i]); -// float2 val_f2_1 = __half22float2(val_h2_1[i]); -// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; -// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; -// val_h2[i] = __float22half2_rn(val_f2); -// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + -// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y -// + bias_f2_1.y; val_h2_1[i] = __float22half2_rn(val_f2_1); -// } -// output_f4[idx] = val_f4; -// output_f4[idx+1] = val_f4_1; -// } -// } - -// __global__ void ker_layer_norm_x4(__half *ln_res, __half *vars, -// __half *means, const __half *inp, -// const __half *scale, const __half -// *bias, int hidden_size) { -// // step 0. compute local sum -// float l_sum = 0; -// float l_square_sum = 0; -// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size * 4; -// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * -// 4) { -// float4 val_f4 = inp_f4[idx]; -// float4 val_f4_1 = inp_f4[idx+1]; -// float4 val_f4_2 = inp_f4[idx+2]; -// float4 val_f4_3 = inp_f4[idx+3]; -// __half2 *val_h2 = (__half2 *)(&val_f4); -// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); -// __half2 *val_h2_2 = (__half2 *)(&val_f4_2); -// __half2 *val_h2_3 = (__half2 *)(&val_f4_3); -// #pragma unroll -// for (int i = 0; i < 4; i++) { -// float2 val_f2 = __half22float2(val_h2[i]); -// float2 val_f2_1 = __half22float2(val_h2_1[i]); -// float2 val_f2_2 = __half22float2(val_h2_2[i]); -// float2 val_f2_3 = __half22float2(val_h2_3[i]); -// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y + val_f2_2.x + -// val_f2_2.y + val_f2_3.x + val_f2_3.y; l_square_sum += val_f2.x * -// val_f2.x + val_f2.y * val_f2.y; l_square_sum += val_f2_1.x * val_f2_1.x -// + val_f2_1.y * val_f2_1.y; l_square_sum += val_f2_2.x * val_f2_2.x + -// val_f2_2.y * val_f2_2.y; l_square_sum += val_f2_3.x * val_f2_3.x + -// val_f2_3.y * val_f2_3.y; -// } -// } - -// // step 1. compute reduce sum -// float mean_dim = float(hidden_size) * 8.f * 4; -// float reduce_val[2] = {l_sum, l_square_sum}; -// blockReduce(reduce_val); -// __shared__ float s_mean, s_var; -// if (threadIdx.x == 0) { -// s_mean = reduce_val[0] / mean_dim; -// if (means != nullptr) { -// means[blockIdx.x] = s_mean; -// } -// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; -// vars[blockIdx.x] = s_var; -// s_var = rsqrtf(s_var); -// } -// __syncthreads(); - -// // step 2. layer norm result -// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 4; -// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * -// 4) { -// // load scale, bias, input -// float4 scale_f4 = __ldg((const float4 *)scale + idx); -// __half2 *scale_h2 = (__half2 *)(&scale_f4); -// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1); -// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1); -// float4 scale_f4_2 = __ldg((const float4 *)scale + idx + 2); -// __half2 *scale_h2_2 = (__half2 *)(&scale_f4_2); -// float4 scale_f4_3 = __ldg((const float4 *)scale + idx + 3); -// __half2 *scale_h2_3 = (__half2 *)(&scale_f4_3); -// float4 bias_f4 = __ldg((const float4 *)bias + idx); -// __half2 *bias_h2 = (__half2 *)(&bias_f4); -// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1); -// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1); -// float4 bias_f4_2 = __ldg((const float4 *)bias + idx + 2); -// __half2 *bias_h2_2 = (__half2 *)(&bias_f4_2); -// float4 bias_f4_3 = __ldg((const float4 *)bias + idx + 3); -// __half2 *bias_h2_3 = (__half2 *)(&bias_f4_3); -// float4 val_f4 = inp_f4[idx]; -// __half2 *val_h2 = (__half2 *)(&val_f4); -// float4 val_f4_1 = inp_f4[idx+1]; -// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); -// float4 val_f4_2 = inp_f4[idx+2]; -// __half2 *val_h2_2 = (__half2 *)(&val_f4_2); -// float4 val_f4_3 = inp_f4[idx+3]; -// __half2 *val_h2_3 = (__half2 *)(&val_f4_3); - -// #pragma unroll -// for (int i = 0; i < 4; i++) { -// float2 scale_f2 = __half22float2(scale_h2[i]); -// float2 scale_f2_1 = __half22float2(scale_h2_1[i]); -// float2 scale_f2_2 = __half22float2(scale_h2_2[i]); -// float2 scale_f2_3 = __half22float2(scale_h2_3[i]); -// float2 bias_f2 = __half22float2(bias_h2[i]); -// float2 bias_f2_1 = __half22float2(bias_h2_1[i]); -// float2 bias_f2_2 = __half22float2(bias_h2_2[i]); -// float2 bias_f2_3 = __half22float2(bias_h2_3[i]); -// float2 val_f2 = __half22float2(val_h2[i]); -// float2 val_f2_1 = __half22float2(val_h2_1[i]); -// float2 val_f2_2 = __half22float2(val_h2_2[i]); -// float2 val_f2_3 = __half22float2(val_h2_3[i]); -// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; -// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; -// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + -// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y -// + bias_f2_1.y; val_f2_2.x = (val_f2_2.x - s_mean) * s_var * -// scale_f2_2.x + bias_f2_2.x; val_f2_2.y = (val_f2_2.y - s_mean) * s_var -// * scale_f2_2.y + bias_f2_2.y; val_f2_3.x = (val_f2_3.x - s_mean) * -// s_var * scale_f2_3.x + bias_f2_3.x; val_f2_3.y = (val_f2_3.y - s_mean) -// * s_var * scale_f2_3.y + bias_f2_3.y; val_h2[i] = -// __float22half2_rn(val_f2); val_h2_1[i] = __float22half2_rn(val_f2_1); -// val_h2_2[i] = __float22half2_rn(val_f2_2); -// val_h2_3[i] = __float22half2_rn(val_f2_3); -// } -// output_f4[idx] = val_f4; -// output_f4[idx+1] = val_f4_1; -// output_f4[idx+2] = val_f4_2; -// output_f4[idx+3] = val_f4_3; -// } -// } - -template <> -void launch_layer_norm(float *ln_res, float *vars, float *means, - const float *inp, const float *scale, - const float *bias, int batch_size, int hidden_dim, - cudaStream_t stream) { - if (hidden_dim % 4 != 0) { - throw std::runtime_error("violate hidden_dim % 4 = 0"); - } - hidden_dim >>= 2; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - dim3 grid_dim(batch_size); - dim3 block_dim(nthread); - - ker_layer_norm<<>>( - ln_res, vars, means, inp, scale, bias, hidden_dim); -} - -template <> -void launch_layer_norm<__half>(__half *ln_res, __half *vars, __half *means, - const __half *inp, const __half *scale, - const __half *bias, int batch_size, - int hidden_dim, cudaStream_t stream) { - if (hidden_dim % 8 != 0) { - throw std::runtime_error("violate hidden_dim % 8 = 0"); - } - hidden_dim >>= 3; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - dim3 grid_dim(batch_size); - dim3 block_dim(nthread); - - ker_layer_norm<__half><<>>( - ln_res, vars, means, inp, scale, bias, hidden_dim); - // if (hidden_dim % 8 != 0) { - // throw std::runtime_error("violate hidden_dim % 8 = 0"); - // } - // hidden_dim >>= 3; - - // if (hidden_dim * 8 < 8192) { - // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - // dim3 grid_dim(batch_size); - // dim3 block_dim(nthread); - // ker_layer_norm<__half><<>>( - // ln_res, vars, means, inp, scale, bias, hidden_dim); - // } else if (hidden_dim * 8 >= 8192 && hidden_dim * 8 <= 8192 * 2) { - // hidden_dim >>= 1; - // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - // dim3 grid_dim(batch_size); - // dim3 block_dim(nthread); - // ker_layer_norm_x2<<>>( - // ln_res, vars, means, inp, scale, bias, hidden_dim); - // } else if (hidden_dim * 8 > 8192 * 2 && hidden_dim * 8 <= 8192 * 4) { - // hidden_dim >>= 2; - // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - // dim3 grid_dim(batch_size); - // dim3 block_dim(nthread); - // ker_layer_norm_x4<<>>( - // ln_res, vars, means, inp, scale, bias, hidden_dim); - // } else { - // throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768"); - // } -} - -/** -@brief: ker_ln_bw_dgamma_dbetta -Layer norm backword kernel, compute the gradient of gamma and betta. -dbetta = sum(dout, dim=0) -dgamma = sum(xhat * dout, dim=0) -xhat = (input - mean) * rsqrt(var) or - (output - betta) / gamma - - -@thread -gridDim.x = hidden_size / 32 -blockDim.x = 32 -blockDim.y = 32 - -@param -gamma_grad: [hidden_size], gradient of gamma -betta_grad: [hidden_size], gradient of betta -out_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output -inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr - ln input if means is not nullptr -gamma: [hidden_size], gamma of ln, - used to compute xhat, maybe nullptr -betta: [hidden_size], betta of ln, - used to compute xhat, maybe nullptr -vars: [batch_size * seq_len], variance of ln forward, - used to compute xhat, maybe nullptr -means: [batch_size * seq_len], mean of ln forward, - used to compute xhat, maybe nullptr -(gamma && betta) ^ (vars && means) should be true -*/ -template -__global__ void ker_ln_bw_dgamma_dbetta(T *gamma_grad, T *betta_grad, - const T *out_grad, const T *inp_or_out, - const T *gamma, const T *betta, - const T *vars, const T *means, int rows, - int width) { - __shared__ float betta_buffer[TILE_DIM][TILE_DIM]; - __shared__ float gamma_buffer[TILE_DIM][TILE_DIM]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int idx = blockDim.x * blockIdx.x + threadIdx.x; - int offset = threadIdx.y * width + idx; - int y_stride = width * TILE_DIM; - - // Loop across inp height - float dbetta = 0; - float dgamma = 0; - float dout, val; - if (idx < width) { - if (means == nullptr) { - float vbetta = (float)betta[idx]; - float vgamma = (float)gamma[idx]; - for (int r = threadIdx.y; r < rows; r += TILE_DIM) { - dout = (float)out_grad[offset]; - // inp_or_out is output - val = (float)inp_or_out[offset]; - dbetta += dout; - dgamma += ((val - vbetta) / add_eps(vgamma) * dout); - offset += y_stride; - } - } else { - for (int r = threadIdx.y; r < rows; r += TILE_DIM) { - dout = (float)out_grad[offset]; - // inp_or_out is input - val = (float)inp_or_out[offset]; - dbetta += dout; - dgamma += ((val - (float)means[r]) * - rsqrtf((float)vars[r] + LN_EPSILON) * dout); - offset += y_stride; - } - } - } - - // Sum the shared buffer. - betta_buffer[threadIdx.x][threadIdx.y] = dbetta; - gamma_buffer[threadIdx.x][threadIdx.y] = dgamma; - __syncthreads(); - float s1 = betta_buffer[threadIdx.y][threadIdx.x]; - float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; - __syncthreads(); - - for (int i = 1; i < TILE_DIM; i <<= 1) { - s1 += g.shfl_down(s1, i); - s2 += g.shfl_down(s2, i); - } - - int pos = blockIdx.x * TILE_DIM + threadIdx.y; - if (threadIdx.x == 0 && idx < width) { - betta_grad[pos] = s1; - gamma_grad[pos] = s2; - } -} - -/** -@brief: ker_ln_bw_dinp -Layer norm backword kernel, compute the gradient of input. -dinp = (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / hidden_dim) - * rsqrt(var) -xhat = (input - mean) * rsqrt(var) if mean is not nullptr - (output - betta) / gamma if mean is nullptr -dxhat = dout * gamma - - -@thread -gridDim.x = batch_size * seq_len -blockDim.x = hidden_size - -@param -inp_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output -out_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output -residual_grad: [batch_size * seq_len, hidden_size], gradient of residual input, - usually appear in pre-layer-norm for transformer layer, maybe nullptr -inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr - ln input if means is not nullptr -gamma: [hidden_size], gamma of ln, - used to compute xhat and dxhat -betta: [hidden_size], betta of ln, - used to compute xhat, maybe nullptr -vars: [batch_size * seq_len], variance of ln forward, - used to compute xhat and dinp -means: [batch_size * seq_len], mean of ln forward, - used to compute xhat, maybe nullptr -*/ -template -__global__ void ker_ln_bw_dinp(T *inp_grad, const T *out_grad, - const T *residual_grad, const T *inp_or_out, - const T *gamma, const T *betta, const T *vars, - const T *means, int hidden_dim) { - int offset = blockIdx.x * hidden_dim + threadIdx.x; - float4 dxhat, xhat; - float var_rsqrt; - - if (threadIdx.x < hidden_dim) { - // step 0. dxhat = dout * gamma - dxhat = ((const float4 *)out_grad)[offset]; - float4 vgamma = ((const float4 *)gamma)[threadIdx.x]; - dxhat.x *= vgamma.x; - dxhat.y *= vgamma.y; - dxhat.z *= vgamma.z; - dxhat.w *= vgamma.w; - - /* - step 1. xhat = (output - betta) / gamma or - (input - mean) * rsqrtf(var) - */ - xhat = ((const float4 *)inp_or_out)[offset]; - var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); - if (means == nullptr) { - // inp_or_out is output, xhat = (output - betta) / gamma - float4 vbetta = ((const float4 *)betta)[threadIdx.x]; - xhat.x = (xhat.x - vbetta.x) / add_eps(vgamma.x); - xhat.y = (xhat.y - vbetta.y) / add_eps(vgamma.y); - xhat.z = (xhat.z - vbetta.z) / add_eps(vgamma.z); - xhat.w = (xhat.w - vbetta.w) / add_eps(vgamma.w); - } else { - // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) - float fmean = (float)means[blockIdx.x]; - xhat.x = (xhat.x - fmean) * var_rsqrt; - xhat.y = (xhat.y - fmean) * var_rsqrt; - xhat.z = (xhat.z - fmean) * var_rsqrt; - xhat.w = (xhat.w - fmean) * var_rsqrt; - } - } - - /* step2. block reduce sum for dxhat and dxhat*xhat */ - float reduce_val[2] = {0.f, 0.f}; - if (threadIdx.x < hidden_dim) { - reduce_val[0] = dxhat.x + dxhat.y + dxhat.z + dxhat.w; - reduce_val[1] = dxhat.x * xhat.x + dxhat.y * xhat.y + dxhat.z * xhat.z + - dxhat.w * xhat.w; - } - blockReduce(reduce_val); - __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; - if (threadIdx.x == 0) { - float mean_dim = hidden_dim * 4; - s_sum_dxhat = reduce_val[0] / mean_dim; - s_sum_dxhat_xhat = reduce_val[1] / mean_dim; - } - __syncthreads(); - - /* - step3. compute input gradient - (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) - */ - if (threadIdx.x >= hidden_dim) { - return; - } - dxhat.x = (dxhat.x - s_sum_dxhat - xhat.x * s_sum_dxhat_xhat) * var_rsqrt; - dxhat.y = (dxhat.y - s_sum_dxhat - xhat.y * s_sum_dxhat_xhat) * var_rsqrt; - dxhat.z = (dxhat.z - s_sum_dxhat - xhat.z * s_sum_dxhat_xhat) * var_rsqrt; - dxhat.w = (dxhat.w - s_sum_dxhat - xhat.w * s_sum_dxhat_xhat) * var_rsqrt; - if (residual_grad) { - // Add the residual grad, - // usually in pre-layer-norm for transformer layer - float4 dresidual = ((const float4 *)residual_grad)[offset]; - dxhat.x += dresidual.x; - dxhat.y += dresidual.y; - dxhat.z += dresidual.z; - dxhat.w += dresidual.w; - } - ((float4 *)inp_grad)[offset] = dxhat; -} - -template <> -__global__ void ker_ln_bw_dinp<__half>(__half *inp_grad, const __half *out_grad, - const __half *residual_grad, - const __half *inp_or_out, - const __half *gamma, const __half *betta, - const __half *vars, const __half *means, - int hidden_dim) { - int offset = blockIdx.x * hidden_dim + threadIdx.x; - - float2 dxhat[4], xhat[4]; - float var_rsqrt; - float4 vtmp; - __half2 *tmp_h2; - float reduce_val[2] = {0.f, 0.f}; - - if (threadIdx.x < hidden_dim) { - // step 0. dxhat = dout * gamma - vtmp = ((const float4 *)out_grad)[offset]; - tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); - float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x]; - __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vdout = __half22float2(tmp_h2[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - dxhat[i].x = vdout.x * vgamma.x; - dxhat[i].y = vdout.y * vgamma.y; - reduce_val[0] += dxhat[i].x + dxhat[i].y; - } - - /* - step 1. xhat = (output - betta) / gamma or - (input - mean) * rsqrtf(var) - */ - vtmp = ((const float4 *)inp_or_out)[offset]; - var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); - if (means == nullptr) { - // inp_or_out is output, xhat = (output - betta) / gamma - float4 vbetta = ((const float4 *)betta)[threadIdx.x]; - __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vout = __half22float2(tmp_h2[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vbetta = __half22float2(betta_h2[i]); - xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); - xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - } - } else { - // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) - float fmean = (float)means[blockIdx.x]; -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vinp = __half22float2(tmp_h2[i]); - xhat[i].x = (vinp.x - fmean) * var_rsqrt; - xhat[i].y = (vinp.y - fmean) * var_rsqrt; - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - } - } - } - - /* step2. block reduce sum for dxhat and dxhat*xhat */ - blockReduce(reduce_val); - __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; - if (threadIdx.x == 0) { - float mean_dim = hidden_dim * 8; - s_sum_dxhat = reduce_val[0] / mean_dim; - s_sum_dxhat_xhat = reduce_val[1] / mean_dim; - } - __syncthreads(); - - /* - step3. compute input gradient - (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) - */ - if (threadIdx.x >= hidden_dim) { - return; - } - if (residual_grad) { - // Add the residual grad, - // usually in pre-layer-norm for transformer layer - float4 dresidual = ((const float4 *)residual_grad)[offset]; - __half *hdres = reinterpret_cast<__half *>(&dresidual); -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i])); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i + 1])); - } - } else { -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - } - } - ((float4 *)inp_grad)[offset] = vtmp; -} - -__global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad, - const __half *residual_grad, - const __half *inp_or_out, const __half *gamma, - const __half *betta, const __half *vars, - const __half *means, int hidden_dim) { - int offset = blockIdx.x * hidden_dim * 2 + threadIdx.x * 2; - - float2 dxhat[4], xhat[4]; - float2 dxhat_1[4], xhat_1[4]; - float var_rsqrt; - float4 vtmp, vtmp_1; - __half2 *tmp_h2; - __half2 *tmp_h2_1; - float reduce_val[2] = {0.f, 0.f}; - - if (threadIdx.x < hidden_dim) { - // step 0. dxhat = dout * gamma - vtmp = ((const float4 *)out_grad)[offset]; - vtmp_1 = ((const float4 *)out_grad)[offset + 1]; - tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); - tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1); - float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 2]; - float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 2 + 1]; - __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); - __half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vdout = __half22float2(tmp_h2[i]); - float2 vdout_1 = __half22float2(tmp_h2_1[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vgamma_1 = __half22float2(gamma_h2_1[i]); - dxhat[i].x = vdout.x * vgamma.x; - dxhat[i].y = vdout.y * vgamma.y; - dxhat_1[i].x = vdout_1.x * vgamma_1.x; - dxhat_1[i].y = vdout_1.y * vgamma_1.y; - reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y; - } - - /* - step 1. xhat = (output - betta) / gamma or - (input - mean) * rsqrtf(var) - */ - vtmp = ((const float4 *)inp_or_out)[offset]; - vtmp_1 = ((const float4 *)inp_or_out)[offset + 1]; - var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); - if (means == nullptr) { - // inp_or_out is output, xhat = (output - betta) / gamma - float4 vbetta = ((const float4 *)betta)[2 * threadIdx.x]; - float4 vbetta_1 = ((const float4 *)betta)[2 * threadIdx.x + 1]; - __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); - __half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vout = __half22float2(tmp_h2[i]); - float2 vout_1 = __half22float2(tmp_h2_1[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vgamma_1 = __half22float2(gamma_h2_1[i]); - float2 vbetta = __half22float2(betta_h2[i]); - float2 vbetta_1 = __half22float2(betta_h2_1[i]); - xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); - xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x); - xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); - xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y); - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += - xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; - } - } else { - // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) - float fmean = (float)means[blockIdx.x]; -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vinp = __half22float2(tmp_h2[i]); - float2 vinp_1 = __half22float2(tmp_h2_1[i]); - xhat[i].x = (vinp.x - fmean) * var_rsqrt; - xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt; - xhat[i].y = (vinp.y - fmean) * var_rsqrt; - xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt; - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += - xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; - } - } - } - - /* step2. block reduce sum for dxhat and dxhat*xhat */ - blockReduce(reduce_val); - __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; - if (threadIdx.x == 0) { - float mean_dim = hidden_dim * 8 * 2; - s_sum_dxhat = reduce_val[0] / mean_dim; - s_sum_dxhat_xhat = reduce_val[1] / mean_dim; - } - __syncthreads(); - - /* - step3. compute input gradient - (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) - */ - if (threadIdx.x >= hidden_dim) { - return; - } - if (residual_grad) { - // Add the residual grad, - // usually in pre-layer-norm for transformer layer - float4 dresidual = ((const float4 *)residual_grad)[offset]; - float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1]; - __half *hdres = reinterpret_cast<__half *>(&dresidual); - __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1); -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i])); - tmp_h2_1[i].x = __float2half( - (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i])); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i + 1])); - tmp_h2_1[i].y = __float2half( - (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i + 1])); - } - } else { -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_1[i].x = __float2half( - (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_1[i].y = __float2half( - (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - } - } - ((float4 *)inp_grad)[offset] = vtmp; - ((float4 *)inp_grad)[offset + 1] = vtmp_1; -} - -__global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad, - const __half *residual_grad, - const __half *inp_or_out, const __half *gamma, - const __half *betta, const __half *vars, - const __half *means, int hidden_dim) { - int offset = blockIdx.x * hidden_dim * 4 + threadIdx.x * 4; - - float2 dxhat[4], xhat[4]; - float2 dxhat_1[4], xhat_1[4]; - float2 dxhat_2[4], xhat_2[4]; - float2 dxhat_3[4], xhat_3[4]; - float var_rsqrt; - float4 vtmp, vtmp_1, vtmp_2, vtmp_3; - __half2 *tmp_h2; - __half2 *tmp_h2_1; - __half2 *tmp_h2_2; - __half2 *tmp_h2_3; - float reduce_val[2] = {0.f, 0.f}; - - if (threadIdx.x < hidden_dim) { - // step 0. dxhat = dout * gamma - vtmp = ((const float4 *)out_grad)[offset]; - vtmp_1 = ((const float4 *)out_grad)[offset + 1]; - vtmp_2 = ((const float4 *)out_grad)[offset + 2]; - vtmp_3 = ((const float4 *)out_grad)[offset + 3]; - tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); - tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1); - tmp_h2_2 = reinterpret_cast<__half2 *>(&vtmp_2); - tmp_h2_3 = reinterpret_cast<__half2 *>(&vtmp_3); - float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 4]; - float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 4 + 1]; - float4 gamma_f4_2 = ((const float4 *)gamma)[threadIdx.x * 4 + 2]; - float4 gamma_f4_3 = ((const float4 *)gamma)[threadIdx.x * 4 + 3]; - __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); - __half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1); - __half2 *gamma_h2_2 = reinterpret_cast<__half2 *>(&gamma_f4_2); - __half2 *gamma_h2_3 = reinterpret_cast<__half2 *>(&gamma_f4_3); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vdout = __half22float2(tmp_h2[i]); - float2 vdout_1 = __half22float2(tmp_h2_1[i]); - float2 vdout_2 = __half22float2(tmp_h2_2[i]); - float2 vdout_3 = __half22float2(tmp_h2_3[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vgamma_1 = __half22float2(gamma_h2_1[i]); - float2 vgamma_2 = __half22float2(gamma_h2_2[i]); - float2 vgamma_3 = __half22float2(gamma_h2_3[i]); - dxhat[i].x = vdout.x * vgamma.x; - dxhat[i].y = vdout.y * vgamma.y; - dxhat_1[i].x = vdout_1.x * vgamma_1.x; - dxhat_1[i].y = vdout_1.y * vgamma_1.y; - dxhat_2[i].x = vdout_2.x * vgamma_2.x; - dxhat_2[i].y = vdout_2.y * vgamma_2.y; - dxhat_3[i].x = vdout_3.x * vgamma_3.x; - dxhat_3[i].y = vdout_3.y * vgamma_3.y; - reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y + - dxhat_2[i].x + dxhat_2[i].y + dxhat_3[i].x + - dxhat_3[i].y; - } - - /* - step 1. xhat = (output - betta) / gamma or - (input - mean) * rsqrtf(var) - */ - vtmp = ((const float4 *)inp_or_out)[offset]; - vtmp_1 = ((const float4 *)inp_or_out)[offset + 1]; - vtmp_2 = ((const float4 *)inp_or_out)[offset + 2]; - vtmp_3 = ((const float4 *)inp_or_out)[offset + 3]; - var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); - if (means == nullptr) { - // inp_or_out is output, xhat = (output - betta) / gamma - float4 vbetta = ((const float4 *)betta)[4 * threadIdx.x]; - float4 vbetta_1 = ((const float4 *)betta)[4 * threadIdx.x + 1]; - float4 vbetta_2 = ((const float4 *)betta)[4 * threadIdx.x + 2]; - float4 vbetta_3 = ((const float4 *)betta)[4 * threadIdx.x + 3]; - __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); - __half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1); - __half2 *betta_h2_2 = reinterpret_cast<__half2 *>(&vbetta_2); - __half2 *betta_h2_3 = reinterpret_cast<__half2 *>(&vbetta_3); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vout = __half22float2(tmp_h2[i]); - float2 vout_1 = __half22float2(tmp_h2_1[i]); - float2 vout_2 = __half22float2(tmp_h2_2[i]); - float2 vout_3 = __half22float2(tmp_h2_3[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vgamma_1 = __half22float2(gamma_h2_1[i]); - float2 vgamma_2 = __half22float2(gamma_h2_2[i]); - float2 vgamma_3 = __half22float2(gamma_h2_3[i]); - float2 vbetta = __half22float2(betta_h2[i]); - float2 vbetta_1 = __half22float2(betta_h2_1[i]); - float2 vbetta_2 = __half22float2(betta_h2_2[i]); - float2 vbetta_3 = __half22float2(betta_h2_3[i]); - xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); - xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x); - xhat_2[i].x = (vout_2.x - vbetta_2.x) / add_eps(vgamma_2.x); - xhat_3[i].x = (vout_3.x - vbetta_3.x) / add_eps(vgamma_3.x); - xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); - xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y); - xhat_2[i].y = (vout_2.y - vbetta_2.y) / add_eps(vgamma_2.y); - xhat_3[i].y = (vout_3.y - vbetta_3.y) / add_eps(vgamma_3.y); - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += - xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; - reduce_val[1] += - xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; - reduce_val[1] += - xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; - } - } else { - // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) - float fmean = (float)means[blockIdx.x]; -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vinp = __half22float2(tmp_h2[i]); - float2 vinp_1 = __half22float2(tmp_h2_1[i]); - float2 vinp_2 = __half22float2(tmp_h2_2[i]); - float2 vinp_3 = __half22float2(tmp_h2_3[i]); - xhat[i].x = (vinp.x - fmean) * var_rsqrt; - xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt; - xhat_2[i].x = (vinp_2.x - fmean) * var_rsqrt; - xhat_3[i].x = (vinp_3.x - fmean) * var_rsqrt; - xhat[i].y = (vinp.y - fmean) * var_rsqrt; - xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt; - xhat_2[i].y = (vinp_2.y - fmean) * var_rsqrt; - xhat_3[i].y = (vinp_3.y - fmean) * var_rsqrt; - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += - xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; - reduce_val[1] += - xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; - reduce_val[1] += - xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; - } - } - } - - /* step2. block reduce sum for dxhat and dxhat*xhat */ - blockReduce(reduce_val); - __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; - if (threadIdx.x == 0) { - float mean_dim = hidden_dim * 8 * 4; - s_sum_dxhat = reduce_val[0] / mean_dim; - s_sum_dxhat_xhat = reduce_val[1] / mean_dim; - } - __syncthreads(); - - /* - step3. compute input gradient - (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) - */ - if (threadIdx.x >= hidden_dim) { - return; - } - if (residual_grad) { - // Add the residual grad, - // usually in pre-layer-norm for transformer layer - float4 dresidual = ((const float4 *)residual_grad)[offset]; - float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1]; - float4 dresidual_2 = ((const float4 *)residual_grad)[offset + 2]; - float4 dresidual_3 = ((const float4 *)residual_grad)[offset + 3]; - __half *hdres = reinterpret_cast<__half *>(&dresidual); - __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1); - __half *hdres_2 = reinterpret_cast<__half *>(&dresidual_2); - __half *hdres_3 = reinterpret_cast<__half *>(&dresidual_3); -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i])); - tmp_h2_1[i].x = __float2half( - (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i])); - tmp_h2_2[i].x = __float2half( - (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_2[2 * i])); - tmp_h2_3[i].x = __float2half( - (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_3[2 * i])); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i + 1])); - tmp_h2_1[i].y = __float2half( - (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i + 1])); - tmp_h2_2[i].y = __float2half( - (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i + 1])); - tmp_h2_3[i].y = __float2half( - (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i + 1])); - } - } else { -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_1[i].x = __float2half( - (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_2[i].x = __float2half( - (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_3[i].x = __float2half( - (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_1[i].y = __float2half( - (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_2[i].y = __float2half( - (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_3[i].y = __float2half( - (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - } - } - ((float4 *)inp_grad)[offset] = vtmp; - ((float4 *)inp_grad)[offset + 1] = vtmp_1; - ((float4 *)inp_grad)[offset + 2] = vtmp_2; - ((float4 *)inp_grad)[offset + 3] = vtmp_3; -} - -/** -Layer norm backword, - compute the gradient of gamma, betta and input. -dbetta = sum(dout, dim=0) -xhat = (input - mean) * rsqrt(var) if mean is not nullptr - (output - betta) / gamma if mean is nullptr -dgamma = sum(xhat * dout, dim=0) -dxhat = dout * gamma -dinp = (dxhat - (sum(dxhat, 1) + xhat * sum(dxhat * xhat, 1)) / hidden_dim) - * rsqrt(var) - -residual_grad, means, betta can be nullptr. -residual_grad will be added to dinp if it is not nullptr - which is useful in transformer layer when pre-ln -means and betta are only used to compute xhat, - (means == nullptr) ^ (betta == nullptr) should be true -*/ -template <> -void launch_ln_bw(float *gamma_grad, float *betta_grad, float *inp_grad, - const float *out_grad, const float *residual_grad, - const float *inp_or_out, const float *gamma, - const float *betta, const float *vars, - const float *means, int batch, int hidden_dim, - cudaStream_t stream[2]) { - // compute grad of gamma and betta - dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM); - dim3 block_dim(TILE_DIM, TILE_DIM); - ker_ln_bw_dgamma_dbetta<<>>( - gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means, - batch, hidden_dim); - - // compute grad of input - if (hidden_dim % 4 != 0 || hidden_dim > 4096) { - throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 4096"); - } - hidden_dim >>= 2; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - ker_ln_bw_dinp<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means, - hidden_dim); -} - -template <> -void launch_ln_bw<__half>(__half *gamma_grad, __half *betta_grad, - __half *inp_grad, const __half *out_grad, - const __half *residual_grad, const __half *inp_or_out, - const __half *gamma, const __half *betta, - const __half *vars, const __half *means, int batch, - int hidden_dim, cudaStream_t stream[2]) { - // compute grad of gamma and betta - dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM); - dim3 block_dim(TILE_DIM, TILE_DIM); - ker_ln_bw_dgamma_dbetta<__half><<>>( - gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means, - batch, hidden_dim); - - // compute grad of input - if (hidden_dim % 8 != 0) { - throw std::runtime_error("hidden_dim % 8 != 0"); - } - hidden_dim >>= 3; - - if (hidden_dim * 8 <= 8192) { - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - ker_ln_bw_dinp<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, - means, hidden_dim); - } else if (hidden_dim * 8 > 8192 && hidden_dim * 8 <= 8192 * 2) { - hidden_dim >>= 1; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - ker_ln_bw_dinp_x2<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, - means, hidden_dim); - } else if (hidden_dim * 8 > 2 * 8192 && hidden_dim * 8 <= 8192 * 4) { - hidden_dim >>= 2; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - ker_ln_bw_dinp_x4<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, - means, hidden_dim); - } else { - throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768"); - } -} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu deleted file mode 100644 index 3862a699d..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu +++ /dev/null @@ -1,365 +0,0 @@ -#include -#include - -#include -#include - -#include "block_reduce.h" -#include "kernels.h" - -namespace cg = cooperative_groups; -const float EPSILON = 1e-8f; - -/** -@brief: softmax_kernel -Softmax forward kernel for - enc-self-attn, dec-self-attn, encdec-attn - -@thread -gridDim.x = dynamic -gridDim.y = batch_size -gridDim.z = nhead -blockDim.x = from_len - -@param -inp: [batch_size, nhead, from_len, to_len], softmax input. -attn_mask: [batch_size, to_len], padding tokens are -inf, - non padding tokens are 0. - attn_mask!=nullptr for enc-self-attn and enc-dec-attn - attn_mask=nullptr and mask_future=ture for dec-self-attn training - attn_mask=nullptr and mask_future=false for dec-self-attn infer -*/ -template -__global__ void ker_attn_softmax(T *inp, const T *attn_mask, int from_len, - int to_len, bool mask_future) { - int batch_id = blockIdx.y; - int head_id = blockIdx.z; - const int nhead = gridDim.z; - const int token_per_reduce = 1; - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - T mval[ele_per_thread]; - if (attn_mask) { - attn_mask += batch_id * to_len; - BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG); - } - - inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len); - for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len; - token_id += gridDim.x * token_per_reduce) { - T inp_val[token_per_reduce][ele_per_thread]; - for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { - BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len, - REDUCE_FLOAT_INF_NEG); - } - - /* step 1. compute max */ - // thread local max - float val[token_per_reduce][ele_per_thread]; - float l_max[token_per_reduce]; - for (int i = 0; i < token_per_reduce; i++) { - l_max[i] = REDUCE_FLOAT_INF_NEG; - for (int j = 0; j < ele_per_thread; j++) { - if (attn_mask) { - val[i][j] = (float)inp_val[i][j] + (float)mval[j]; - } else { - if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) { - val[i][j] = REDUCE_FLOAT_INF_NEG; - } else { - val[i][j] = (float)inp_val[i][j]; - } - } - l_max[i] = fmaxf(l_max[i], val[i][j]); - } - } - // block reduce max - blockReduce(l_max); - // write shared - __shared__ float s_max[token_per_reduce]; - if (threadIdx.x == 0) { - for (int i = 0; i < token_per_reduce; i++) { - s_max[i] = l_max[i]; - } - } - __syncthreads(); - - /* step 2. compute sum */ - // thread local sum - float l_sum[token_per_reduce]; - for (int i = 0; i < token_per_reduce; i++) { - l_sum[i] = 0.f; - for (int j = 0; j < ele_per_thread; j++) { - val[i][j] = __expf(val[i][j] - s_max[i]); - l_sum[i] += val[i][j]; - } - } - // block reduce sum - blockReduce(l_sum); - // write shared - __shared__ float s_sum[token_per_reduce]; - if (threadIdx.x == 0) { - for (int i = 0; i < token_per_reduce; i++) { - s_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON); - } - } - __syncthreads(); - - /* step 3. compute final result */ - for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { - for (int j = 0; j < ele_per_thread; j++) { - inp_val[i][j] = (T)(val[i][j] * s_sum[i]); - } - BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], - to_len); - } - } // blockIdx.x -} - -template -__global__ void ker_attn_softmax_lt32(T *inp, const T *attn_mask, int from_len, - int to_len, bool mask_future) { - int batch_id = blockIdx.y; - int head_id = blockIdx.z; - const int nhead = gridDim.z; - const int token_per_reduce = 1; - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - T mval[ele_per_thread]; - if (attn_mask) { - attn_mask += batch_id * to_len; - BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG); - } - - inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len); - for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len; - token_id += gridDim.x * token_per_reduce) { - T inp_val[token_per_reduce][ele_per_thread]; - for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { - BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len, - REDUCE_FLOAT_INF_NEG); - } - - /* step 1. compute max */ - // thread local max - float val[token_per_reduce][ele_per_thread]; - float l_max[token_per_reduce]; - for (int i = 0; i < token_per_reduce; i++) { - l_max[i] = REDUCE_FLOAT_INF_NEG; - for (int j = 0; j < ele_per_thread; j++) { - if (attn_mask) { - val[i][j] = (float)inp_val[i][j] + (float)mval[j]; - } else { - if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) { - val[i][j] = REDUCE_FLOAT_INF_NEG; - } else { - val[i][j] = (float)inp_val[i][j]; - } - } - l_max[i] = fmaxf(l_max[i], val[i][j]); - } - } - // warp reduce max - warpReduce(l_max); - - /* step 2. compute sum */ - // thread local sum - float l_sum[token_per_reduce]; - for (int i = 0; i < token_per_reduce; i++) { - l_sum[i] = 0.f; - for (int j = 0; j < ele_per_thread; j++) { - val[i][j] = __expf(val[i][j] - l_max[i]); - l_sum[i] += val[i][j]; - } - } - // warp reduce sum - warpReduce(l_sum); - - /* step 3. compute final result */ - for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { - l_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON); - for (int j = 0; j < ele_per_thread; j++) { - inp_val[i][j] = (T)(val[i][j] * l_sum[i]); - } - BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], - to_len); - } - } // blockIdx.x -} - -/* - attn_mask!=nullptr for enc-self-attn and enc-dec-attn - attn_mask=nullptr and mask_future=ture for dec-self-attn training - attn_mask=nullptr and mask_future=false for dec-self-attn infer -*/ -template <> -void launch_attn_softmax(float *inp, const float *attn_mask, - int batch_size, int nhead, int from_len, - int to_len, bool mask_future, - cudaStream_t stream) { - dim3 grid_dim(1, batch_size, nhead); - if (to_len <= 32) { - ker_attn_softmax_lt32<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 64) { - ker_attn_softmax_lt32<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 128) { - grid_dim.x = 16; - ker_attn_softmax<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 256) { - grid_dim.x = 32; - ker_attn_softmax<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 512) { - grid_dim.x = 64; - ker_attn_softmax<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else { - throw std::runtime_error( - "Sequence length greater than 512 is currently not supported"); - } -} - -template <> -void launch_attn_softmax<__half>(__half *inp, const __half *attn_mask, - int batch_size, int nhead, int from_len, - int to_len, bool mask_future, - cudaStream_t stream) { - dim3 grid_dim(1, batch_size, nhead); - if (to_len <= 32) { - ker_attn_softmax_lt32<__half, 32, 1><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 64) { - ker_attn_softmax_lt32<__half, 32, 2><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 128) { - grid_dim.x = 8; - ker_attn_softmax<__half, 64, 2><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 256) { - grid_dim.x = 16; - ker_attn_softmax<__half, 128, 2><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 512) { - grid_dim.x = 32; - ker_attn_softmax<__half, 256, 2><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else { - throw std::runtime_error( - "Sequence length greater than 512 is currently not supported"); - } -} - -/** -@brief: ker_attn_softmax_bw -Softmax backward in self attention. - -@thread -gridDim.x = batch_size * nhead * seq_len / warps_per_block -blockDim.x = WARP_SIZE -blockDim.y = warps_per_block - -@param -grad: [batch_size, nhead, seq_len, seq_len], output grad. -output: [batch_size, nhead, seq_len, seq_len], output of softmax forward. -*/ -template -__global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length) { - int batch_idx = blockIdx.x * blockDim.y + threadIdx.y; - int offset = batch_idx * softmax_length + threadIdx.x; - - grad += offset; - inp += offset; - - T grad_reg[ITERATIONS]; - T inp_reg[ITERATIONS]; - float sum = 0.0; - -#pragma unroll - for (int i = 0; i < ITERATIONS; ++i) { - int curr_idx = threadIdx.x + i * WARP_SIZE; - if (curr_idx < softmax_length) { - grad_reg[i] = grad[i * WARP_SIZE]; - inp_reg[i] = inp[i * WARP_SIZE]; - sum += (float)grad_reg[i] * (float)inp_reg[i]; - } - } - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i); - -#pragma unroll - for (int i = 0; i < ITERATIONS; ++i) { - int curr_idx = threadIdx.x + i * WARP_SIZE; - if (curr_idx < softmax_length) - grad[i * WARP_SIZE] = (T)((float)inp_reg[i] * ((float)grad_reg[i] - sum)); - } -} - -template -void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows, - int softmax_len, cudaStream_t stream) { - const int warps_per_block = 4; - // rows = batch_size * nhead * from_len - dim3 grid_dim(rows / warps_per_block); - dim3 block_dim(WARP_SIZE, warps_per_block); - - if (softmax_len <= 32) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 64) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 128) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 256) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 384) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 512) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 768) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 1024) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 2048) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else - throw std::runtime_error( - std::string( - "Special sequence length found in softmax backward, seq_len: ") + - std::to_string(softmax_len)); -} - -template void launch_attn_softmax_bw<__half>(__half *out_grad, - const __half *soft_inp, int rows, - int softmax_len, - cudaStream_t stream); -template void launch_attn_softmax_bw(float *out_grad, - const float *soft_inp, int rows, - int softmax_len, - cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu deleted file mode 100644 index 04de3c092..000000000 --- a/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu +++ /dev/null @@ -1,314 +0,0 @@ -#include -#include -#include - -#include "kernels.h" - -using namespace cub; - -/** -@brief: transform_0213 -Split the attention heads and reshape input -during backward progress of encoder self-attention - -@thread -gridDim.x = batch_size -gridDim.y = seq_len -blockDim.x = min(hidden_dim, MAX_THREADS) - -@param -input: [batch_size, seq_len, hidden_dim] -output: [batch_size, nhead, seq_len, head_dim] -batch_size: the size of the current batch -seq_len: the sequence length of the current batch -hidden_dim: dim of the hidden tensor -nhead: number of attention heads -*/ - -template -__global__ void transform_0213(T *output, const T *input, int hidden_dim, - int head_dim); - -template <> -__global__ void transform_0213(float *output, const float *input, - int hidden_dim, int head_dim) { - int batch_id = blockIdx.x; - int token_id = blockIdx.y; - int seq_len = gridDim.y; - int nhead = hidden_dim / head_dim; - - // [b, s, h] - int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim); - // [b, nh, s, ad] - int trg_offset = - flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim); - - const float4 *input4 = reinterpret_cast(input); - float4 *res4 = reinterpret_cast(output); - float4 vinput4; - - for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - vinput4 = input4[src_offset + i]; - - int head_id = i / head_dim; - int dim_id = i % head_dim; - int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim); - res4[trg_offset + cur_trg_offset] = vinput4; - } -} - -template <> -__global__ void transform_0213<__half>(__half *output, const __half *input, - int hidden_dim, int head_dim) { - int batch_id = blockIdx.x; - int token_id = blockIdx.y; - int seq_len = gridDim.y; - int nhead = hidden_dim / head_dim; - - // [b, s, h] - int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim); - // [b, nh, s, ad] - int trg_offset = - flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim); - - const float4 *input4 = reinterpret_cast(input); - float4 *res4 = reinterpret_cast(output); - float4 vinput4; - - for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - vinput4 = input4[src_offset + i]; - - int head_id = i / head_dim; - int dim_id = i % head_dim; - int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim); - res4[trg_offset + cur_trg_offset] = vinput4; - } -} - -// [b, s, h] -> [b, nh, s, ad] -template <> -void launch_transform_0213(float *output, const float *input, - int batch_size, int seq_len, int hidden_dim, - int nhead, cudaStream_t stream) { - hidden_dim >>= 2; - int head_dim = hidden_dim / nhead; - - dim3 grid_dim(batch_size, seq_len); - dim3 block_dim(min(hidden_dim, MAX_THREADS)); - - transform_0213 - <<>>(output, input, hidden_dim, head_dim); -} - -template <> -void launch_transform_0213<__half>(__half *output, const __half *input, - int batch_size, int seq_len, int hidden_dim, - int nhead, cudaStream_t stream) { - hidden_dim >>= 3; - int head_dim = hidden_dim / nhead; - - dim3 grid_dim(batch_size, seq_len); - dim3 block_dim(min(hidden_dim, MAX_THREADS)); - - transform_0213<__half> - <<>>(output, input, hidden_dim, head_dim); -} - -/** -@brief: bias_add_transform_20314 -Add bias to input, transform from -[0, 1, 2, 3, 4] to [2, 0, 3, 1, 4] - -@thread -gridDim.x = dim_0 -gridDim.y = dim_1 -gridDim.z = dim_2 -blockDim.x = min(dim_3 * dim_4, MAX_THREADS) - -@param -input: [dim_0, dim_1, dim_2, dim_3, dim_4] -bias: [dim_2, dim_3, dim_4] -output: [dim_2, dim_0, dim_3, dim_1, dim_4] -*/ -template -__global__ void bias_add_transform_20314(T *output, const T *input, - const T *bias, int dim_3, int dim_4); - -template <> -__global__ void bias_add_transform_20314(float *output, - const float *input, - const float *bias, int dim_3, - int dim_4) { - int id0 = blockIdx.x; - int id1 = blockIdx.y; - int id2 = blockIdx.z; - int dim_0 = gridDim.x; - int dim_1 = gridDim.y; - int dim_2 = gridDim.z; - int dim_34 = dim_3 * dim_4; - - int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34); - int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4); - int bias_offset = flat_2dim(id2, 0, dim_34); - - const float4 *qkv4 = reinterpret_cast(input); - const float4 *bias4 = reinterpret_cast(bias); - float4 *res4 = reinterpret_cast(output); - float4 vqkv4; - float4 vbias4; - float4 vres4; - - for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) { - vqkv4 = qkv4[src_offset + i]; - vbias4 = bias4[bias_offset + i]; - vres4.x = vqkv4.x + vbias4.x; - vres4.y = vqkv4.y + vbias4.y; - vres4.z = vqkv4.z + vbias4.z; - vres4.w = vqkv4.w + vbias4.w; - - int id3 = i / dim_4; - int id4 = i % dim_4; - int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4); - res4[trg_offset + cur_trg_offset] = vres4; - } -} - -template <> -__global__ void bias_add_transform_20314<__half>(__half *output, - const __half *input, - const __half *bias, int dim_3, - int dim_4) { - int id0 = blockIdx.x; - int id1 = blockIdx.y; - int id2 = blockIdx.z; - int dim_0 = gridDim.x; - int dim_1 = gridDim.y; - int dim_2 = gridDim.z; - int dim_34 = dim_3 * dim_4; - - int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34); - int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4); - int bias_offset = flat_2dim(id2, 0, dim_34); - - const float4 *qkv4 = reinterpret_cast(input); - const float4 *bias4 = reinterpret_cast(bias); - float4 *res4 = reinterpret_cast(output); - float4 vqkv4; - float4 vbias4; - float4 vres4; - __half2 *h2_qkv = reinterpret_cast<__half2 *>(&vqkv4); - __half2 *h2_bias = reinterpret_cast<__half2 *>(&vbias4); - __half2 *h2_res = reinterpret_cast<__half2 *>(&vres4); - - for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) { - vqkv4 = qkv4[src_offset + i]; - vbias4 = bias4[bias_offset + i]; - h2_res[0] = __hadd2(h2_qkv[0], h2_bias[0]); - h2_res[1] = __hadd2(h2_qkv[1], h2_bias[1]); - h2_res[2] = __hadd2(h2_qkv[2], h2_bias[2]); - h2_res[3] = __hadd2(h2_qkv[3], h2_bias[3]); - - int id3 = i / dim_4; - int id4 = i % dim_4; - int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4); - res4[trg_offset + cur_trg_offset] = vres4; - } -} - -// [b, s, 3, h] -> [3, b, nh, s, ad] -template <> -void launch_bias_add_transform_20314(float *output, const float *input, - const float *bias, int dim_0, - int dim_1, int dim_2, int dim_3, - int dim_4, cudaStream_t stream) { - dim_4 >>= 2; - - dim3 grid_dim(dim_0, dim_1, dim_2); - dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS)); - - bias_add_transform_20314 - <<>>(output, input, bias, dim_3, dim_4); -} - -template <> -void launch_bias_add_transform_20314<__half>(__half *output, - const __half *input, - const __half *bias, int dim_0, - int dim_1, int dim_2, int dim_3, - int dim_4, cudaStream_t stream) { - dim_4 >>= 3; - - dim3 grid_dim(dim_0, dim_1, dim_2); - dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS)); - - bias_add_transform_20314<__half> - <<>>(output, input, bias, dim_3, dim_4); -} - -/** -@brief: transform4d_0213 -Reshape the input matrix to merge the heads - -@thread -gridDim.x = (num_all + max_block_thread - 1) / max_block_thread -blockDim.x = max_block_thread - -@param -input: [trans_count, batch_size, nhead, seq_len, head_dim] -output: [batch_size, seq_len, trans_count, nhead, head_dim] -batch_size: the size of the current batch -seq_len: the sequence length of the current batch -hidden_dim: dim of the hidden tensor -nhead: number of attention heads -trans_count: 1 or 3, the count of matrice need to be transformed -*/ -template -__global__ void transform4d_0213(T *output, const T *input, int batch_size, - int seq_len, int trans_count, int nhead, - int head_dim, int num_all) { - int offset = blockIdx.x * blockDim.x + threadIdx.x; - if (offset >= num_all) { - return; - } - int trans_id, batch_id, head_id, token_id, dim_id; - decompose_5dim(offset, batch_size, nhead, seq_len, head_dim, &trans_id, - &batch_id, &head_id, &token_id, &dim_id); - // [b, s, tc, nh, ad] - int trg_offset = flat_5dim(batch_id, token_id, trans_id, head_id, dim_id, - seq_len, trans_count, nhead, head_dim); - - const float4 *input4 = reinterpret_cast(input); - float4 *res4 = reinterpret_cast(output); - res4[trg_offset] = input4[offset]; -} - -// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad] -template <> -void launch_transform4d_0213(float *output, const float *input, - int batch_size, int seq_len, int hidden_dim, - int nhead, int trans_count, - cudaStream_t stream) { - hidden_dim >>= 2; - int head_dim = hidden_dim / nhead; - int num_all = batch_size * seq_len * trans_count * hidden_dim; - int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS; - - transform4d_0213<<>>( - output, input, batch_size, seq_len, trans_count, nhead, head_dim, - num_all); -} - -template <> -void launch_transform4d_0213<__half>(__half *output, const __half *input, - int batch_size, int seq_len, - int hidden_dim, int nhead, int trans_count, - cudaStream_t stream) { - hidden_dim >>= 3; - int head_dim = hidden_dim / nhead; - int num_all = batch_size * seq_len * trans_count * hidden_dim; - int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS; - - transform4d_0213<__half><<>>( - output, input, batch_size, seq_len, trans_count, nhead, head_dim, - num_all); -} diff --git a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp deleted file mode 100644 index d08f3dbc7..000000000 --- a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp +++ /dev/null @@ -1,406 +0,0 @@ -#include "multihead_attention_1d.h" - -#include -#include -#include - -#if TORCH_VERSION_MAJOR > 1 || \ - (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13) -#include -#else -#include -#endif -#include - -#include "context.h" -#include "kernels.h" - -template -MultiHeadAttention::MultiHeadAttention(int layer_id, int max_batch_tokens, - int max_seq_len, int hidden_size, - int num_heads, - float attn_prob_dropout_ratio, - float hidden_output_dropout_ratio, - bool pre_or_postLayerNorm) - : _layer_id(layer_id), - _max_batch_tokens(max_batch_tokens), - _max_seq_len(max_seq_len), - _hidden_size(hidden_size), - _heads(num_heads), - _training(true), - _pre_or_postLayerNorm(pre_or_postLayerNorm), - _qkv_linear( - typename FeedForward::Config(3 * hidden_size, hidden_size)), - _attn_out_linear( - typename FeedForward::Config(hidden_size, hidden_size)), - _attn_ln(typename Normalize_Layer::Config(hidden_size, false), - _max_batch_tokens), - _softmax(typename Softmax::Config(num_heads)), - _attn_prob_dropout(typename Dropout::Config(attn_prob_dropout_ratio), - _max_batch_tokens * _heads * _max_seq_len), - _attn_dropout(typename Dropout::Config(hidden_output_dropout_ratio), - _max_batch_tokens * _hidden_size), - _attn_scores(typename StridedBatchGemm::Config( - (T(1.0) / T(sqrt(_hidden_size / _heads))), T(0.0), CUBLAS_OP_T, - CUBLAS_OP_N)), - _attn_context(typename StridedBatchGemm::Config( - T(1.0), T(0.0), CUBLAS_OP_N, CUBLAS_OP_N)) { - assert(_hidden_size % _heads == 0); -} - -template -MultiHeadAttention::~MultiHeadAttention() { - free_mem_buffer(); -} - -template -void MultiHeadAttention::attn_layer_fw(const T *input_ptr, - const T *input_mask_ptr, - T *output_ptr, T *buffer) { - T *q_tf_ptr = _qkv_ptr; - T *k_tf_ptr = q_tf_ptr + _batch_dim / pg_size; - T *v_tf_ptr = k_tf_ptr + _batch_dim / pg_size; - - if (_pre_or_postLayerNorm) { - _attn_ln.Forward(_gemmQKV_inp_ptr, input_ptr, _attn_nw_ptr, _attn_nb_ptr, - _batch_tokens, _stream); - } - const T *gemmQKV_inp_ptr = - _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr; - _qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size); - _qkv_linear.Forward(_batch_tokens, gemmQKV_inp_ptr, _attn_qkvw_ptr, buffer, - _cublasHandle); - - launch_bias_add_transform_20314(q_tf_ptr, buffer, _attn_qkvb_ptr, - _batch_size, _seq_len, 3, _heads / pg_size, - _hidden_size / _heads, _stream); - - // attention scores, q*k - _attn_scores.Forward(_batch_heads, _soft_out_ptr, k_tf_ptr, q_tf_ptr, - _cublasHandle); - - // Softmax + Mask - _softmax.reset_size(_heads / pg_size); - _softmax.Forward(_soft_out_ptr, input_mask_ptr, _batch_size, _seq_len, - _seq_len, _stream, true); - - // attn prob dropout. - _attn_prob_dropout.dropout(_ctx_bufB_ptr, _soft_out_ptr, - _batch_heads * _seq_len * _seq_len, _stream); - - // attention context, score * v - _attn_context.Forward(_batch_heads, buffer, v_tf_ptr, _ctx_bufB_ptr, - _cublasHandle); - - // [b, nh, s, ad] -> [b, s, nh, ad] - launch_transform4d_0213(_attn_o_inp_ptr, buffer, _batch_size, _seq_len, - _hidden_size / pg_size, _heads / pg_size, 1, - _stream); - - _attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size); - _attn_out_linear.Forward(_batch_tokens, _attn_o_inp_ptr, _attn_ow_ptr, - output_ptr, _cublasHandle); - - // allreduce - if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) { - } else { - auto data_type = torch::kFloat; - if (typeid(T) != typeid(float)) { - data_type = torch::kHalf; - } - auto output_tensor = torch::from_blob( - output_ptr, {int(_batch_size), int(_seq_len), int(_hidden_size)}, - torch::TensorOptions(torch::kCUDA).dtype(data_type)); - std::vector allreduce_tensors = {output_tensor}; - auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions()); - work->wait(); - } - - _attn_dropout.bias_dropout_residual(output_ptr, output_ptr, input_ptr, - _attn_ob_ptr, _batch_tokens, _hidden_size, - _stream); - if (!_pre_or_postLayerNorm) { - // in-place ln since ln-input will not be used in post-ln mode - _attn_ln.Forward(output_ptr, output_ptr, _attn_nw_ptr, _attn_nb_ptr, - _batch_tokens, _stream); - } -} - -template -void MultiHeadAttention::Forward(const T *input_ptr, const T *input_mask_ptr, - T *out_ptr) { - _stream = Context::Instance().get_stream(); - _cublasHandle = Context::Instance().get_cublashandle(); - T *attn_buffer = _shared_mem_ptr; // 3 * _batch_dim - - attn_layer_fw(input_ptr, input_mask_ptr, out_ptr, attn_buffer); -} - -template -void MultiHeadAttention::attn_layer_bw(const T *input_ptr, - const T *input_mask_ptr, - const T *output_ptr, - const T *grad_output_ptr, - T *grad_input_ptr, T *buffer) { - cudaStream_t streams[2] = {_stream, _stream}; - - const T *q_tf_ptr = _qkv_ptr; - const T *k_tf_ptr = q_tf_ptr + _batch_dim / pg_size; - const T *v_tf_ptr = k_tf_ptr + _batch_dim / pg_size; - // batch_dim = batch_size * seq_len * hidden_size - // buffer size: batch_dim * 3 + max(batch_dim * 3, - // batch_size * head_num * seq_len * seq_len) - T *grad_residual_ptr = buffer; - buffer += _batch_dim; - - T *grad_input_buf_ptr = buffer; // batch_dim - T *grad_qkv_5d_ptr = buffer; // batch_dim * 3 - buffer += 3 * _batch_dim / pg_size; - - T *grad_qkv_4d_ptr = buffer; // batch_dim * 3 - T *grad_softmax_ptr = buffer; // batch_size * head_num * seq_len * seq_len - // buffer += max(3 * _batch_dim, - // batch_size * head_num * seq_len * seq_len); - - if (_pre_or_postLayerNorm) { - _attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr, - grad_output_ptr, _batch_tokens, - _hidden_size, _stream); - } else { - _attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_residual_ptr, - grad_output_ptr, nullptr, output_ptr, _attn_nw_ptr, - _attn_nb_ptr, _batch_tokens, streams); - _attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr, - grad_residual_ptr, _batch_tokens, - _hidden_size, _stream); - } - - // bw of output project - _attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size); - _attn_out_linear.Backward(_batch_tokens, grad_input_ptr, _attn_o_inp_ptr, - _attn_ow_ptr, _grad_attn_ow_ptr, _grad_attn_ob_ptr, - _cublasHandle, _stream, grad_input_buf_ptr, nullptr, - false); - launch_transform_0213(grad_input_ptr, grad_input_buf_ptr, _batch_size, - _seq_len, _hidden_size / pg_size, _heads / pg_size, - _stream); - - // bw of score * v - _attn_context.Backward( - _batch_heads, grad_input_ptr, v_tf_ptr, _ctx_bufB_ptr, _cublasHandle, - grad_qkv_5d_ptr + 2 * _batch_dim / pg_size, grad_softmax_ptr); - - _attn_prob_dropout.d_dropout(grad_softmax_ptr, - _batch_heads * _seq_len * _seq_len, _stream); - - _softmax.reset_size(_heads / pg_size); - _softmax.Backward(grad_softmax_ptr, _soft_out_ptr, _batch_size, _seq_len, - _seq_len, _stream); - - // bw of q * k - _attn_scores.Backward(_batch_heads, grad_softmax_ptr, k_tf_ptr, q_tf_ptr, - _cublasHandle, grad_qkv_5d_ptr + _batch_dim / pg_size, - grad_qkv_5d_ptr); - - // [3, b, nh, s, ad] -> [b, s, 3, h] - launch_transform4d_0213(grad_qkv_4d_ptr, grad_qkv_5d_ptr, _batch_size, - _seq_len, _hidden_size / pg_size, _heads / pg_size, - 3, _stream); - - const T *gemmQKV_inp_ptr = - _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr; - _qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size); - _qkv_linear.Backward(_batch_tokens, grad_qkv_4d_ptr, gemmQKV_inp_ptr, - _attn_qkvw_ptr, _grad_attn_qkvw_ptr, _grad_attn_qkvb_ptr, - _cublasHandle, _stream, grad_input_buf_ptr, nullptr, - true); - - // allreduce - if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) { - } else { - auto data_type = torch::kFloat; - if (typeid(T) != typeid(float)) { - data_type = torch::kHalf; - } - auto grad_input_tensor = - torch::from_blob(grad_input_buf_ptr, - {int(_batch_size), int(_seq_len), int(_hidden_size)}, - torch::TensorOptions(torch::kCUDA).dtype(data_type)); - std::vector allreduce_tensors = {grad_input_tensor}; - auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions()); - work->wait(); - } - - if (_pre_or_postLayerNorm) { - _attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_input_ptr, - grad_input_buf_ptr, grad_output_ptr, gemmQKV_inp_ptr, - _attn_nw_ptr, _attn_nb_ptr, _batch_tokens, streams); - } else { - // FIXME later - launch_fused_add2(grad_input_ptr, grad_input_buf_ptr, grad_residual_ptr, - _batch_size, _seq_len, _hidden_size, _stream); - } -} - -template -void MultiHeadAttention::Backward(const T *grad_output_ptr, - const T *input_ptr, const T *output_ptr, - const T *input_mask_ptr, - T *grad_input_ptr) { - _stream = Context::Instance().get_stream(); - _cublasHandle = Context::Instance().get_cublashandle(); - T *buffer = _shared_mem_ptr; - - /* - buffer size needed by attn bw: - 4 * _batch_dim + max(3 * _batch_dim, - _batch_size * _head_num * _seq_len * _seq_len); - */ - attn_layer_bw(input_ptr, input_mask_ptr, output_ptr, grad_output_ptr, - grad_input_ptr, buffer); -} - -template -void MultiHeadAttention::SetTrainingMode(bool training) { - // Dropout will be skipped when not in training model. - _attn_prob_dropout.SetTrainingMode(training); - _attn_dropout.SetTrainingMode(training); -} - -template -T *MultiHeadAttention::_shared_mem_ptr = nullptr; - -template class MultiHeadAttention; -template class MultiHeadAttention<__half>; - -// x is torch::Tensor -#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) \ - AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ - CHECK_CONTIGUOUS(x) - -static std::unordered_map> s_multihead_attention; - -template -int create_multihead_attention(int layer_id, int max_batch_tokens, - int max_seq_len, int hidden_dim, int num_heads, - float attn_prob_dropout_ratio, - float hidden_dropout_ratio, - bool pre_or_postLayerNorm, - c10::intrusive_ptr pg_) { - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - Context::Instance().set_stream(stream); - auto layer = std::make_shared>( - layer_id, max_batch_tokens, max_seq_len, hidden_dim, num_heads, - attn_prob_dropout_ratio, hidden_dropout_ratio, pre_or_postLayerNorm); - - layer->SetPG(pg_); - - s_multihead_attention[layer_id] = layer; - - std::string dtype = (std::is_same::value) ? "half" : "float"; - - return 0; -} - -template -std::vector multihead_attention_fw( - int layer_id, const torch::Tensor &input, const torch::Tensor &input_mask, - const torch::Tensor &in_proj_weight, const torch::Tensor &in_proj_bias, - const torch::Tensor &out_proj_weight, const torch::Tensor &out_proj_bias, - const torch::Tensor &norm_weight, const torch::Tensor &norm_bias, - bool training_mode, bool prelayernorm) { - CHECK_INPUT(input); - CHECK_INPUT(input_mask); - - const T *input_ptr = (const T *)input.data_ptr(); - const T *input_mask_ptr = (const T *)input_mask.data_ptr(); - - auto output = torch::empty_like(input); - T *out_ptr = (T *)output.data_ptr(); - - std::shared_ptr> layer = - std::static_pointer_cast>( - s_multihead_attention[layer_id]); - layer->set_cur_batch_shape(input.size(0), input.size(1)); - layer->SetTrainingMode(training_mode); - - layer->_attn_qkvw_ptr = (const T *)in_proj_weight.data_ptr(); - layer->_attn_qkvb_ptr = (const T *)in_proj_bias.data_ptr(); - layer->_attn_ow_ptr = (const T *)out_proj_weight.data_ptr(); - layer->_attn_ob_ptr = (const T *)out_proj_bias.data_ptr(); - layer->_attn_nw_ptr = (const T *)norm_weight.data_ptr(); - layer->_attn_nb_ptr = (const T *)norm_bias.data_ptr(); - - layer->Forward(input_ptr, input_mask_ptr, out_ptr); - - return {output}; -} - -template -std::vector multihead_attention_bw( - int layer_id, const torch::Tensor &grad_dec_output, - const torch::Tensor &output, const torch::Tensor &input, - const torch::Tensor &input_mask, const torch::Tensor &in_proj_weight, - const torch::Tensor &in_proj_bias, const torch::Tensor &out_proj_weight, - const torch::Tensor &out_proj_bias, const torch::Tensor &norm_weight, - const torch::Tensor &norm_bias) { - auto g_output = grad_dec_output.contiguous(); - CHECK_INPUT(g_output); - CHECK_INPUT(output); - CHECK_INPUT(input); - CHECK_INPUT(input_mask); - - auto grad_input = torch::empty_like(input); - auto grad_in_proj_weight = torch::empty_like(in_proj_weight); - auto grad_in_proj_bias = torch::empty_like(in_proj_bias); - auto grad_out_proj_weight = torch::empty_like(out_proj_weight); - auto grad_out_proj_bias = torch::empty_like(out_proj_bias); - auto grad_norm_weight = torch::empty_like(norm_weight); - auto grad_norm_bias = torch::empty_like(norm_bias); - - // inputs. - const T *grad_dec_output_ptr = (const T *)g_output.data_ptr(); - const T *input_ptr = (const T *)input.data_ptr(); - const T *output_ptr = (const T *)output.data_ptr(); - const T *input_mask_ptr = (const T *)input_mask.data_ptr(); - - // outputs. - T *grad_input_ptr = (T *)grad_input.data_ptr(); - - std::shared_ptr> layer = - std::static_pointer_cast>( - s_multihead_attention[layer_id]); - layer->set_cur_batch_shape(g_output.size(0), g_output.size(1)); - - layer->_grad_attn_qkvw_ptr = (T *)grad_in_proj_weight.data_ptr(); - layer->_grad_attn_qkvb_ptr = (T *)grad_in_proj_bias.data_ptr(); - layer->_grad_attn_ow_ptr = (T *)grad_out_proj_weight.data_ptr(); - layer->_grad_attn_ob_ptr = (T *)grad_out_proj_bias.data_ptr(); - layer->_grad_attn_nw_ptr = (T *)grad_norm_weight.data_ptr(); - layer->_grad_attn_nb_ptr = (T *)grad_norm_bias.data_ptr(); - - layer->Backward(grad_dec_output_ptr, input_ptr, output_ptr, input_mask_ptr, - grad_input_ptr); - - return {grad_input, grad_in_proj_weight, grad_in_proj_bias, - grad_out_proj_weight, grad_out_proj_bias, grad_norm_weight, - grad_norm_bias}; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("multihead_attention_fw_fp32", &multihead_attention_fw, - "Multi-head Attention forward with fp32 (CUDA)"); - m.def("multihead_attention_fw_fp16", &multihead_attention_fw<__half>, - "Multi-head Attention forward with fp16 (CUDA)"); - m.def("multihead_attention_bw_fp32", &multihead_attention_bw, - "Multi-head Attention backward with fp32 (CUDA)"); - m.def("multihead_attention_bw_fp16", &multihead_attention_bw<__half>, - "Multi-head Attention backward with fp16 (CUDA)"); - m.def("create_multihead_attention_fp32", &create_multihead_attention, - "Create Multi-head Attention with fp32 (CUDA)"); - m.def("create_multihead_attention_fp16", &create_multihead_attention<__half>, - "Create Multi-head Attention with fp16 (CUDA)"); -} diff --git a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h b/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h deleted file mode 100644 index 6505eb31f..000000000 --- a/colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h +++ /dev/null @@ -1,167 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include - -#if TORCH_VERSION_MAJOR > 1 || \ - (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13) -#include -#else -#include -#endif - -#include -#include - -#include "cuda_util.h" -#include "dropout.h" -#include "feed_forward.h" -#include "normalize_layer.h" -#include "softmax.h" -#include "strided_batch_gemm.h" - -template -class MultiHeadAttention { - public: - MultiHeadAttention(int layer_id, int max_batch_tokens, int _max_seq_len, - int hidden_size, int num_heads, float attn_dropout_ratio, - float hidden_output_dropout_ratio, - bool pre_or_postLayerNorm); - - virtual ~MultiHeadAttention(); - - void Forward(const T *input_ptr, const T *input_mask_ptr, T *out_ptr); - - void Backward(const T *grad_output_ptr, const T *input_ptr, - const T *output_ptr, const T *input_mask_ptr, - T *grad_input_ptr); - - void attn_layer_fw(const T *input_ptr, const T *input_mask_ptr, T *output_ptr, - T *buffer); - - void attn_layer_bw(const T *input_ptr, const T *input_mask_ptr, - const T *output_ptr, const T *grad_output_ptr, - T *grad_input_attn_layer_bwptr, T *buffer); - - void set_cur_batch_shape(int batch_size, int seq_len) { - _batch_size = batch_size; - _seq_len = seq_len; - _batch_tokens = batch_size * seq_len; - _batch_heads = batch_size * _heads / pg_size; - _batch_dim = _batch_tokens * _hidden_size; - _attn_scores.SetConfig(_seq_len, _seq_len, _hidden_size / _heads); - _attn_context.SetConfig(_hidden_size / _heads, _seq_len, _seq_len); - } - - void SetTrainingMode(bool training); - inline bool IsTrainingMode() const { return _training; } - - void SetPG(c10::intrusive_ptr pg_) { - pg = pg_; - pg_size = 1; - if (pg != c10::detail::UniqueVoidPtr()) { - pg_size = pg->getSize(); - } - allocate_mem_buffer(); - } - - // weights ptr - const T *_attn_qkvw_ptr; - const T *_attn_qkvb_ptr; - const T *_attn_ow_ptr; - const T *_attn_ob_ptr; - const T *_attn_nw_ptr; - const T *_attn_nb_ptr; - - // grads ptr - T *_grad_attn_qkvw_ptr; - T *_grad_attn_qkvb_ptr; - T *_grad_attn_ow_ptr; - T *_grad_attn_ob_ptr; - T *_grad_attn_nw_ptr; - T *_grad_attn_nb_ptr; - - private: - void allocate_mem_buffer() { - // allocate local gpu memory - if (_pre_or_postLayerNorm) { - _gemmQKV_inp_ptr = cuda_malloc(_max_batch_tokens * _hidden_size); - } else { - _gemmQKV_inp_ptr = nullptr; - } - - _qkv_ptr = cuda_malloc(_max_batch_tokens * _hidden_size * 3); - _soft_out_ptr = - cuda_malloc(_max_batch_tokens * _heads / pg_size * _max_seq_len); - _ctx_bufB_ptr = - cuda_malloc(_max_batch_tokens * _heads / pg_size * _max_seq_len); - _attn_o_inp_ptr = cuda_malloc(_max_batch_tokens * _hidden_size); - - // buffer size needed by attn bw - size_t smem_size = - 4 * _max_batch_tokens * _hidden_size / pg_size + - std::max(3 * _max_batch_tokens * _hidden_size / pg_size, - _max_batch_tokens * _heads / pg_size * _max_seq_len); - - if (!_shared_mem_ptr) { - cuda_free(_shared_mem_ptr); - _shared_mem_ptr = cuda_malloc(smem_size); - } - } - - void free_mem_buffer() { - // free local gpu memory - cuda_free(_gemmQKV_inp_ptr); - cuda_free(_qkv_ptr); - cuda_free(_soft_out_ptr); - cuda_free(_ctx_bufB_ptr); - cuda_free(_attn_o_inp_ptr); - - // free shared gpu memory between layers - cuda_free(_shared_mem_ptr); - _shared_mem_ptr = nullptr; - } - - // const parameter between batch - const size_t _layer_id; - const size_t _hidden_size; - const size_t _heads; - const size_t _max_batch_tokens; - const size_t _max_seq_len; - const bool _pre_or_postLayerNorm; - // dynamic parameter between batch - size_t _batch_size; - size_t _seq_len; - size_t _batch_tokens; - size_t _batch_heads; - size_t _batch_dim; - bool _training; - - cublasHandle_t _cublasHandle; - cudaStream_t _stream; - - // layers - FeedForward _qkv_linear; - FeedForward _attn_out_linear; - Normalize_Layer _attn_ln; - Softmax _softmax; - Dropout _attn_prob_dropout; - Dropout _attn_dropout; - StridedBatchGemm _attn_scores; - StridedBatchGemm _attn_context; - - // local GPU memory - T *_gemmQKV_inp_ptr; - T *_qkv_ptr; - T *_soft_out_ptr; - T *_ctx_bufB_ptr; - T *_attn_o_inp_ptr; - // shared GPU memory between layer - static T *_shared_mem_ptr; - - c10::intrusive_ptr pg; - int pg_size; -}; diff --git a/colossalai/kernel/cuda_native/csrc/smoothquant/binding.cpp b/colossalai/kernel/cuda_native/csrc/smoothquant/binding.cpp deleted file mode 100644 index 844427294..000000000 --- a/colossalai/kernel/cuda_native/csrc/smoothquant/binding.cpp +++ /dev/null @@ -1,8 +0,0 @@ -#include - -#include "linear.h" - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("linear_silu_a8_w8_bfp32_ofp32", &linear_silu_a8_w8_bfp32_ofp32, - "Linear SiLU (INT8)"); -} diff --git a/colossalai/kernel/cuda_native/csrc/smoothquant/linear.cu b/colossalai/kernel/cuda_native/csrc/smoothquant/linear.cu deleted file mode 100644 index a30d02a4c..000000000 --- a/colossalai/kernel/cuda_native/csrc/smoothquant/linear.cu +++ /dev/null @@ -1,162 +0,0 @@ -// modified from https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/kernels/linear.cu - -#include "linear.h" -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8 - torch::Tensor weight, // INT8 - torch::Tensor bias, // FP32 - float alpha, // FP32 - float beta // FP32 -) { - auto M = input.size(0); - auto N = weight.size(0); - auto K = input.size(1); - - using ElementOutput = float; - using ElementAccumulator = int32_t; - using ElementComputeEpilogue = float; - using ElementInputA = int8_t; // <- data type of elements in input matrix A - using ElementInputB = int8_t; // <- data type of elements in input matrix B - - // The code section below describes matrix layout of input and output - // matrices. Column Major for Matrix A, Row Major for Matrix B and Row Major - // for Matrix C - using LayoutInputA = cutlass::layout::RowMajor; - using LayoutInputB = cutlass::layout::ColumnMajor; - using LayoutOutput = cutlass::layout::RowMajor; - -#if CUDA_ARCH >= 800 - using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu< - ElementOutput, // <- data type of output matrix - 128 / cutlass::sizeof_bits< - ElementOutput>::value, // <- this is the number of elements per - // vectorized memory access. For half - // precision, it's 8 elements. This - // becomes the vector width of math - // instructions in epilogue too - ElementAccumulator, // <- data type of accumulator - ElementComputeEpilogue // <- data type for alpha in linear combination - // function - >; - using Gemm = cutlass::gemm::device::Gemm< - int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, - ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, - cutlass::gemm::GemmShape<256, 128, 64>, - cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, - EpilogueOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; -#elif CUDA_ARCH >= 750 - using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu< - ElementOutput, // <- data type of output matrix - 128 / cutlass::sizeof_bits< - ElementOutput>::value, // <- this is the number of elements per - // vectorized memory access. For half - // precision, it's 8 elements. This - // becomes the vector width of math - // instructions in epilogue too - ElementAccumulator, // <- data type of accumulator - ElementComputeEpilogue // <- data type for alpha in linear combination - // function - >; - - using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration< - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, - ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>; - using Gemm = cutlass::gemm::device::Gemm< - int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, - ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, - DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape, - DefaultGemmCfg::InstructionShape, - EpilogueOp>; -#elif CUDA_ARCH >= 700 - #define USE_TORCH_SILU - using DefaultGemmCfg = cutlass::gemm::device::DefaultGemmConfiguration< - cutlass::arch::OpClassSimt, cutlass::arch::Sm70, - ElementInputA, ElementInputB, ElementOutput, ElementAccumulator>; - using Gemm = cutlass::gemm::device::Gemm< - int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, - ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, - cutlass::arch::OpClassSimt, cutlass::arch::Sm70, - DefaultGemmCfg::ThreadblockShape, DefaultGemmCfg::WarpShape, - DefaultGemmCfg::InstructionShape, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, 1, ElementAccumulator, ElementComputeEpilogue>>; -#else - #error "Unsupported cuda arch" -#endif - - auto input_size = cutlass::MatrixCoord(M, K); - auto weight_size = cutlass::MatrixCoord(K, N); - auto output_size = cutlass::MatrixCoord(M, N); - - auto device = input.device(); - // use the broadcasted bias as the output - auto out = bias.to(device).view({1, -1}).repeat({M, 1}); - - // constexpr int kSparse = Gemm::kSparse; - // How many elements of A are covered per ElementE - // constexpr int kElementsPerElementE = Gemm::kElementsPerElementE; - // The size of individual meta data - // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits; - cutlass::gemm::GemmCoord problem_size(M, N, K); - - cutlass::TensorRef input_ref( - input.data_ptr(), LayoutInputA::packed(input_size)); - cutlass::TensorRef weight_ref( - weight.data_ptr(), LayoutInputB::packed(weight_size)); - cutlass::TensorRef out_ref( - out.data_ptr(), LayoutOutput::packed(output_size)); - - typename Gemm::Arguments arguments{ - problem_size, // <- problem size of matrix multiplication - input_ref, // <- reference to matrix A on device - weight_ref, // <- reference to matrix B on device - out_ref, // <- reference to matrix C on device - out_ref, // <- reference to matrix D on device - {alpha, beta}, 1}; - Gemm gemm_op; - - // Using the arguments, query for extra workspace required for matrix - // multiplication computation - size_t workspace_size = Gemm::get_workspace_size(arguments); - - // Allocate workspace memory - cutlass::device_memory::allocation workspace(workspace_size); - - // Check the problem size is supported or not - cutlass::Status status = gemm_op.can_implement(arguments); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot implement"); - } - - // Initialize CUTLASS kernel with arguments and workspace pointer - status = gemm_op.initialize(arguments, workspace.get()); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot initialize"); - } - - status = gemm_op(); - if (status != cutlass::Status::kSuccess) { - throw std::runtime_error("cutlass cannot run"); - } -#ifdef USE_TORCH_SILU -#undef USE_TORCH_SILU - out = torch::silu(out); -#endif - return out; -} diff --git a/colossalai/kernel/cuda_native/csrc/smoothquant/linear.h b/colossalai/kernel/cuda_native/csrc/smoothquant/linear.h deleted file mode 100644 index b62a27f3f..000000000 --- a/colossalai/kernel/cuda_native/csrc/smoothquant/linear.h +++ /dev/null @@ -1,12 +0,0 @@ -#include -#include - -#include -#include - -torch::Tensor linear_silu_a8_w8_bfp32_ofp32(torch::Tensor input, // INT8 - torch::Tensor weight, // INT8 - torch::Tensor bias, // FP32 - float alpha, // FP32 - float beta // FP32 -); diff --git a/colossalai/kernel/cuda_native/multihead_attention.py b/colossalai/kernel/cuda_native/multihead_attention.py deleted file mode 100644 index 87afc1862..000000000 --- a/colossalai/kernel/cuda_native/multihead_attention.py +++ /dev/null @@ -1,338 +0,0 @@ -import math -from dataclasses import dataclass - -import torch -from torch import nn -from torch.autograd import Function - - -def check_config(config): - if config.hidden_size % config.nhead != 0: - raise Exception("hidden_size % nhead != 0") - - factor = 8 if config.fp16 else 4 - upbound = factor * 1024 * 4 - if config.hidden_size > upbound: - # as required by ln backward kernel currently - raise Exception(f"hidden_size > {upbound}") - - head_dim = config.hidden_size // config.nhead - if head_dim % factor != 0: - # as required by reshape kernel - raise Exception(f"head_dim({head_dim}) % {factor} != 0") - - -def calc_offset(sizes): - offsets = [0] - tmp = 0 - for x in sizes: - tmp += x - offsets.append(tmp) - return offsets - - -colossal_multihead_attention = None - - -@dataclass -class Config: - max_batch_tokens: int # max batch token numbers - max_seq_len: int # max sequence length - hidden_size: int # size of transformer hidden layers - nhead: int # number of heads in attention - attn_prob_dropout_ratio: float # attention score dropout ratio - hidden_dropout_ratio: float # dropout ration before residual - norm_first: bool # norm_first - fp16: bool # fp16 precision - - -class MultiHeadAttention1DFunc(Function): - @staticmethod - def forward( - ctx, - input, - input_mask, - in_proj_weight, - in_proj_bias, - out_proj_weight, - out_proj_bias, - norm_weight, - norm_bias, - config, - ): - cuda_module = colossal_multihead_attention - forward_func = ( - cuda_module.multihead_attention_fw_fp16 if config.fp16 else cuda_module.multihead_attention_fw_fp32 - ) - if config.fp16: - input = input.to(torch.half) - input_mask = input_mask.to(torch.half) - - (output,) = forward_func( - config.layer_id, - input, - input_mask, - in_proj_weight, - in_proj_bias, - out_proj_weight, - out_proj_bias, - norm_weight, - norm_bias, - config.training, - config.norm_first, - ) - - if config.is_grad_enabled and config.training: - ctx.save_for_backward( - output, - input, - input_mask, - in_proj_weight, - in_proj_bias, - out_proj_weight, - out_proj_bias, - norm_weight, - norm_bias, - ) - ctx.config = config - return output - - @staticmethod - def backward(ctx, grad_output): - assert ctx.config.training - - cuda_module = colossal_multihead_attention - backward_func = ( - cuda_module.multihead_attention_bw_fp16 if ctx.config.fp16 else cuda_module.multihead_attention_bw_fp32 - ) - - ( - output, - input, - input_mask, - in_proj_weight, - in_proj_bias, - out_proj_weight, - out_proj_bias, - norm_weight, - norm_bias, - ) = ctx.saved_tensors - - grad_input = None - grad_in_proj_weight = None - grad_in_proj_bias = None - grad_out_proj_weight = None - grad_out_proj_bias = None - grad_norm_weight = None - grad_norm_bias = None - - if ctx.config.fp16: - grad_output = grad_output.to(torch.half) - output = output.to(torch.half) - input = input.to(torch.half) - input_mask = input_mask.to(torch.half) - ( - grad_input, - grad_in_proj_weight, - grad_in_proj_bias, - grad_out_proj_weight, - grad_out_proj_bias, - grad_norm_weight, - grad_norm_bias, - ) = backward_func( - ctx.config.layer_id, - grad_output, - output, - input, - input_mask, - in_proj_weight, - in_proj_bias, - out_proj_weight, - out_proj_bias, - norm_weight, - norm_bias, - ) - - return ( - grad_input, - None, - grad_in_proj_weight, - grad_in_proj_bias, - grad_out_proj_weight, - grad_out_proj_bias, - grad_norm_weight, - grad_norm_bias, - None, - ) - - -class MultiHeadAttention(nn.Module): - """Initialize the MultiHeadAttention. - - Static variable: - - layer_id: The layer-index counter starting from 0 and incrementing by 1 every time a layer object is instantiated, - e.g. if a model has 24 transformer layers, layer_id goes from 0 to 23. - - Arguments: - hidden_size: Total dimension of hidden_size. - nhead: Number of parallel attention heads. - batch_size: Batch Size for one forward - max_seq_len: Max length of input sequence - dropout: Dropout probability - norm_first: perform LayerNorms before attention - """ - - layer_id = 0 - - def __init__(self, hidden_size, nhead, batch_size, max_seq_len, dropout=0.0, norm_first=False, fp16=True, pg=None): - super(MultiHeadAttention, self).__init__() - - self.config = Config( - batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout, dropout, norm_first, fp16 - ) - check_config(self.config) - self.pg = pg - self.pg_size = 1 - if self.pg: - self.pg_size = pg.size() - self.config.layer_id = MultiHeadAttention.layer_id - MultiHeadAttention.layer_id = MultiHeadAttention.layer_id + 1 - - # Load cuda modules if needed - global colossal_multihead_attention - if colossal_multihead_attention is None: - from colossalai.kernel.op_builder import MultiHeadAttnBuilder - - multihead_attention = MultiHeadAttnBuilder().load() - colossal_multihead_attention = multihead_attention - - # create the layer in cuda kernels. - cuda_module = colossal_multihead_attention - create_layer_func = ( - cuda_module.create_multihead_attention_fp16 - if self.config.fp16 - else cuda_module.create_multihead_attention_fp32 - ) - - create_layer_func( - self.config.layer_id, - self.config.max_batch_tokens, - self.config.max_seq_len, - self.config.hidden_size, - self.config.nhead, - self.config.attn_prob_dropout_ratio, - self.config.hidden_dropout_ratio, - self.config.norm_first, - self.pg, - ) - - hs = self.config.hidden_size - - self.precision = torch.float32 - if self.config.fp16: - self.precision = torch.half - - self.hs_per_rank = int(hs / self.pg_size) - - self.in_proj_weight = nn.Parameter(torch.Tensor(3, self.hs_per_rank, hs)) - self.in_proj_bias = nn.Parameter(torch.Tensor(3, self.hs_per_rank)) - self.out_proj_weight = nn.Parameter(torch.Tensor(hs, self.hs_per_rank)) - self.out_proj_bias = nn.Parameter(torch.Tensor(hs)) - self.norm_weight = nn.Parameter(torch.Tensor(hs)) - self.norm_bias = nn.Parameter(torch.Tensor(hs)) - - self.reset_parameters() - torch.cuda.empty_cache() - - def calc_bound(self, w): - fan_in, _ = nn.init._calculate_fan_in_and_fan_out(w) - bound = 1.0 / math.sqrt(fan_in) - return bound - - def reset_parameters(self): - hs = self.config.hidden_size - - nn.init.zeros_(self.out_proj_bias) - - nn.init.ones_(self.norm_weight) - nn.init.zeros_(self.norm_bias) - - if self.pg_size > 1: - rank_in_pg = torch.distributed.get_rank(self.pg) - attn_qkvw_global = torch.empty(hs * 3, hs) - attn_qkvb_global = torch.empty(hs * 3) - nn.init.xavier_uniform_(attn_qkvw_global, 1.0 / math.sqrt(2.0)) - bound = self.calc_bound(attn_qkvw_global) - nn.init.uniform_(attn_qkvb_global, -bound, bound) - - attn_qkvw_global = attn_qkvw_global.cuda() - attn_qkvb_global = attn_qkvb_global.cuda() - torch.distributed.broadcast(attn_qkvw_global, src=0, group=self.pg) - torch.distributed.broadcast(attn_qkvb_global, src=0, group=self.pg) - attn_qkvw_global = attn_qkvw_global.cpu() - attn_qkvb_global = attn_qkvb_global.cpu() - - with torch.no_grad(): - self.in_proj_weight.copy_( - attn_qkvw_global.view(3, hs, hs)[ - :, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size), : - ] - ) - self.in_proj_bias.copy_( - attn_qkvb_global.view(3, hs)[ - :, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size) - ] - ) - - attn_ow_global = torch.empty(hs, hs) - nn.init.xavier_uniform_(attn_ow_global, 1.0) - attn_ow_global = attn_ow_global.cuda() - torch.distributed.broadcast(attn_ow_global, src=0, group=self.pg) - attn_ow_global = attn_ow_global.cpu() - with torch.no_grad(): - self.out_proj_weight.copy_( - attn_ow_global[:, int(hs * rank_in_pg / self.pg_size) : int(hs * (rank_in_pg + 1) / self.pg_size)] - ) - - else: - attn_qkvw = self.in_proj_weight.view(-1, hs) - nn.init.xavier_uniform_(attn_qkvw, 1.0 / math.sqrt(2.0)) - bound = self.calc_bound(attn_qkvw) - nn.init.uniform_(self.in_proj_bias, -bound, bound) - - nn.init.xavier_uniform_(self.out_proj_weight, 1.0) - - def state_dict(self, destination=None, prefix="", keep_vars=False): - destination = torch.nn.Module.state_dict(self, destination=destination, prefix=prefix, keep_vars=keep_vars) - return destination - - def forward(self, hidden_states, encoder_padding_mask): - self.config.training = self.training - self.config.is_grad_enabled = torch.is_grad_enabled() - hidden_states = hidden_states.contiguous() - encoder_padding_mask = (encoder_padding_mask * -1e8).type_as(hidden_states).contiguous() - - bs, sl, dim = hidden_states.size() - if bs * sl > self.config.max_batch_tokens: - raise ValueError(f"Batch token numbers {bs * sl} exceeds the limit {self.config.max_batch_tokens}.") - if sl > self.config.max_seq_len: - raise ValueError(f"Sequence length {sl} exceeds the limit {self.config.max_seq_len}.") - if len(encoder_padding_mask.size()) == 1: - assert bs == 1 and sl == encoder_padding_mask.size(0) - else: - assert bs == encoder_padding_mask.size(0) and sl == encoder_padding_mask.size(1) - - output = MultiHeadAttention1DFunc.apply( - hidden_states, - encoder_padding_mask, - self.in_proj_weight, - self.in_proj_bias, - self.out_proj_weight, - self.out_proj_bias, - self.norm_weight, - self.norm_bias, - self.config, - ) - - return output.to(self.precision) diff --git a/colossalai/kernel/extensions b/colossalai/kernel/extensions new file mode 120000 index 000000000..e8eb45a54 --- /dev/null +++ b/colossalai/kernel/extensions @@ -0,0 +1 @@ +../../extensions \ No newline at end of file diff --git a/colossalai/kernel/extensions/__init__.py b/colossalai/kernel/extensions/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/colossalai/kernel/extensions/base_extension.py b/colossalai/kernel/extensions/base_extension.py deleted file mode 100644 index 8905dbf13..000000000 --- a/colossalai/kernel/extensions/base_extension.py +++ /dev/null @@ -1,21 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Callable - - -class BaseExtension(ABC): - @abstractmethod - def requires_build(self) -> bool: - pass - - @abstractmethod - def build(self) -> None: - pass - - @abstractmethod - def load(self) -> Callable: - pass - - def fetch(self) -> Callable: - if self.requires_build: - self.build() - return self.load() diff --git a/colossalai/kernel/extensions/cpu_adam/__init__.py b/colossalai/kernel/extensions/cpu_adam/__init__.py deleted file mode 100644 index b14f3a978..000000000 --- a/colossalai/kernel/extensions/cpu_adam/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .arm_extension import ArmCPUAdamExtension -from .x86_extension import X86CPUAdamExtension - -__all__ = ["ArmCPUAdamExtension", "X86CPUAdamExtension"] diff --git a/colossalai/kernel/extensions/cpu_adam/arm_extension.py b/colossalai/kernel/extensions/cpu_adam/arm_extension.py deleted file mode 100644 index 9868059bf..000000000 --- a/colossalai/kernel/extensions/cpu_adam/arm_extension.py +++ /dev/null @@ -1,53 +0,0 @@ -from ..base_extension import BaseExtension -from ..extension_builder import ExtensionBuilder - - -class ArmCPUAdamExtension(BaseExtension): - def __init__(self) -> None: - super().__init__() - self.kernel_builder = ArmCPUAdamBuilder() - self._requires_build = False - - @property - def requires_build(self) -> bool: - return self._requires_build - - def build(self): - self.kernel_builder.build() - self._requires_build = True - - def load(self): - return self.kernel_builder.load() - - -class ArmCPUAdamBuilder(ExtensionBuilder): - NAME = "arm_cpu_adam" - PREBUILT_IMPORT_PATH = "colossalai._C.arm_cpu_adam" - ext_type = "cpu" - - def __init__(self): - super().__init__(name=ArmCPUAdamBuilder.NAME, prebuilt_import_path=ArmCPUAdamBuilder.PREBUILT_IMPORT_PATH) - self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] - - # necessary 4 functions - def sources_files(self): - ret = [ - self.csrc_abs_path("cpu_adam_arm.cpp"), - ] - return ret - - def include_dirs(self): - return [self.csrc_abs_path("includes")] - - def cxx_flags(self): - extra_cxx_flags = [ - "-std=c++14", - "-std=c++17", - "-g", - "-Wno-reorder", - "-fopenmp", - ] - return ["-O3"] + self.version_dependent_macros + extra_cxx_flags - - def nvcc_flags(self): - return [] diff --git a/colossalai/kernel/extensions/cpu_adam/x86_extension.py b/colossalai/kernel/extensions/cpu_adam/x86_extension.py deleted file mode 100644 index 687c91f35..000000000 --- a/colossalai/kernel/extensions/cpu_adam/x86_extension.py +++ /dev/null @@ -1,65 +0,0 @@ -from ..base_extension import BaseExtension -from ..extension_builder import ExtensionBuilder -from ..utils import append_nvcc_threads - - -class X86CPUAdamExtension(BaseExtension): - def __init__(self) -> None: - super().__init__() - self.kernel_builder = X86CPUAdamBuilder() - self._requires_build = False - - @property - def requires_build(self) -> bool: - return self._requires_build - - def build(self): - self.kernel_builder.build() - self._requires_build = True - - def load(self): - return self.kernel_builder.load() - - -class X86CPUAdamBuilder(ExtensionBuilder): - NAME = "cpu_adam" - PREBUILT_IMPORT_PATH = "colossalai._C.cpu_adam" - - def __init__(self): - super().__init__(name=X86CPUAdamBuilder.NAME, prebuilt_import_path=X86CPUAdamBuilder.PREBUILT_IMPORT_PATH) - self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] - - # necessary 4 functions - def sources_files(self): - ret = [ - self.csrc_abs_path("cpu_adam.cpp"), - ] - return ret - - def include_dirs(self): - return [self.csrc_abs_path("includes"), self.get_cuda_home_include()] - - def cxx_flags(self): - extra_cxx_flags = [ - "-std=c++14", - "-std=c++17", - "-lcudart", - "-lcublas", - "-g", - "-Wno-reorder", - "-fopenmp", - "-march=native", - ] - return ["-O3"] + self.version_dependent_macros + extra_cxx_flags - - def nvcc_flags(self): - extra_cuda_flags = [ - "-std=c++14", - "-std=c++17", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_HALF2_OPERATORS__", - "-DTHRUST_IGNORE_CUB_VERSION_CHECK", - ] - ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags - return append_nvcc_threads(ret) diff --git a/colossalai/kernel/extensions/extension_builder.py b/colossalai/kernel/extensions/extension_builder.py deleted file mode 100644 index 5849fcfa6..000000000 --- a/colossalai/kernel/extensions/extension_builder.py +++ /dev/null @@ -1,243 +0,0 @@ -# This code has been adapted from the DeepSpeed library. -# Copyright (c) Microsoft Corporation. - -# Licensed under the MIT License. -import importlib -import os -import time -from abc import ABC, abstractmethod -from pathlib import Path -from typing import List, Optional, Union - -from .utils import check_cuda_availability, check_system_pytorch_cuda_match, print_rank_0 - - -class ExtensionBuilder(ABC): - """ - Builder is the base class to build extensions for PyTorch. - - Args: - name (str): the name of the kernel to be built - prebuilt_import_path (str): the path where the extension is installed during pip install - """ - - ext_type: str = "cuda" - - def __init__(self, name: str, prebuilt_import_path: str): - self.name = name - self.prebuilt_import_path = prebuilt_import_path - self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] - - # we store the op as an attribute to avoid repeated building and loading - self.cached_op_module = None - - assert prebuilt_import_path.startswith( - "colossalai._C" - ), f"The prebuilt_import_path should start with colossalai._C, but got {self.prebuilt_import_path}" - - def relative_to_abs_path(self, code_path: str) -> str: - """ - This function takes in a path relative to the colossalai root directory and return the absolute path. - """ - op_builder_module_path = Path(__file__).parent - - # if we install from source - # the current file path will be op_builder/builder.py - # if we install via pip install colossalai - # the current file path will be colossalai/kernel/op_builder/builder.py - # this is because that the op_builder inside colossalai is a symlink - # this symlink will be replaced with actual files if we install via pypi - # thus we cannot tell the colossalai root directory by checking whether the op_builder - # is a symlink, we can only tell whether it is inside or outside colossalai - if str(op_builder_module_path).endswith("colossalai/kernel/op_builder"): - root_path = op_builder_module_path.parent.parent - elif str(op_builder_module_path).endswith("colossalai/kernel/extensions"): - root_path = op_builder_module_path.parent.parent - else: - root_path = op_builder_module_path.parent.joinpath("colossalai") - - code_abs_path = root_path.joinpath(code_path) - return str(code_abs_path) - - def get_cuda_home_include(self): - """ - return include path inside the cuda home. - """ - from torch.utils.cpp_extension import CUDA_HOME - - if CUDA_HOME is None: - raise RuntimeError("CUDA_HOME is None, please set CUDA_HOME to compile C++/CUDA kernels in ColossalAI.") - cuda_include = os.path.join(CUDA_HOME, "include") - return cuda_include - - def csrc_abs_path(self, path): - return os.path.join(self.relative_to_abs_path("kernel/cuda_native/csrc"), path) - - # functions must be overrided begin - @abstractmethod - def sources_files(self) -> List[str]: - """ - This function should return a list of source files for extensions. - """ - raise NotImplementedError - - @abstractmethod - def include_dirs(self) -> List[str]: - """ - This function should return a list of include files for extensions. - """ - - @abstractmethod - def cxx_flags(self) -> List[str]: - """ - This function should return a list of cxx compilation flags for extensions. - """ - - @abstractmethod - def nvcc_flags(self) -> List[str]: - """ - This function should return a list of nvcc compilation flags for extensions. - """ - - # functions must be overrided over - def strip_empty_entries(self, args): - """ - Drop any empty strings from the list of compile and link flags - """ - return [x for x in args if len(x) > 0] - - def import_op(self): - """ - This function will import the op module by its string name. - """ - return importlib.import_module(self.prebuilt_import_path) - - def check_runtime_build_environment(self): - """ - Check whether the system environment is ready for extension compilation. - """ - try: - from torch.utils.cpp_extension import CUDA_HOME - - TORCH_AVAILABLE = True - except ImportError: - TORCH_AVAILABLE = False - CUDA_HOME = None - - if not TORCH_AVAILABLE: - raise ModuleNotFoundError( - "PyTorch is not found. You need to install PyTorch first in order to build CUDA extensions" - ) - - if CUDA_HOME is None: - raise RuntimeError( - "CUDA_HOME is not found. You need to export CUDA_HOME environment variable or install CUDA Toolkit first in order to build CUDA extensions" - ) - - # make sure CUDA is available for compilation during - cuda_available = check_cuda_availability() - if not cuda_available: - raise RuntimeError("CUDA is not available on your system as torch.cuda.is_available() returns False.") - - # make sure system CUDA and pytorch CUDA match, an error will raised inside the function if not - check_system_pytorch_cuda_match(CUDA_HOME) - - def build(self, verbose: Optional[bool] = None): - """ - If the kernel is not built during pip install, it will build the kernel. - If the kernel is built during runtime, it will be stored in `~/.cache/colossalai/torch_extensions/`. If the - kernel is built during pip install, it can be accessed through `colossalai._C`. - - Warning: do not load this kernel repeatedly during model execution as it could slow down the training process. - - Args: - verbose (bool, optional): show detailed info. Defaults to True. - """ - if verbose is None: - verbose = os.environ.get("CAI_KERNEL_VERBOSE", "0") == "1" - try: - # if the kernel has been pre-built during installation - # we just directly import it - op_module = self.import_op() - if verbose: - print_rank_0( - f"[extension] OP {self.prebuilt_import_path} has been compiled ahead of time, skip building." - ) - except ImportError: - # check environment - if self.ext_type == "cuda": - self.check_runtime_build_environment() - - # time the kernel compilation - start_build = time.time() - - # construct the build directory - import torch - from torch.utils.cpp_extension import load - - torch_version_major = torch.__version__.split(".")[0] - torch_version_minor = torch.__version__.split(".")[1] - torch_cuda_version = torch.version.cuda - home_directory = os.path.expanduser("~") - extension_directory = f".cache/colossalai/torch_extensions/torch{torch_version_major}.{torch_version_minor}_cu{torch_cuda_version}" - build_directory = os.path.join(home_directory, extension_directory) - Path(build_directory).mkdir(parents=True, exist_ok=True) - - if verbose: - print_rank_0(f"[extension] Compiling or loading the JIT-built {self.name} kernel during runtime now") - - # load the kernel - op_module = load( - name=self.name, - sources=self.strip_empty_entries(self.sources_files()), - extra_include_paths=self.strip_empty_entries(self.include_dirs()), - extra_cflags=self.cxx_flags(), - extra_cuda_cflags=self.nvcc_flags(), - extra_ldflags=[], - build_directory=build_directory, - verbose=verbose, - ) - - build_duration = time.time() - start_build - - # log jit compilation time - if verbose: - print_rank_0(f"[extension] Time to compile or load {self.name} op: {build_duration} seconds") - - # cache the built/loaded kernel - self.cached_op_module = op_module - - def load(self, verbose: Optional[bool] = None): - """ - load the kernel during runtime. - - Args: - verbose (bool, optional): show detailed info. Defaults to True. - """ - # if the kernel has be compiled and cached, we directly use it - assert self.cached_op_module is not None, "Please build the kernel first before loading it." - return self.cached_op_module - - def builder(self) -> Union["CUDAExtension", "CppExtension"]: - """ - get a CUDAExtension instance used for setup.py - """ - from torch.utils.cpp_extension import CppExtension, CUDAExtension - - if self.ext_type == "cpp": - return CppExtension( - name=self.prebuilt_import_path, - sources=self.strip_empty_entries(self.sources_files()), - include_dirs=self.strip_empty_entries(self.include_dirs()), - extra_compile_args=self.strip_empty_entries(self.cxx_flags()), - ) - - return CUDAExtension( - name=self.prebuilt_import_path, - sources=self.strip_empty_entries(self.sources_files()), - include_dirs=self.strip_empty_entries(self.include_dirs()), - extra_compile_args={ - "cxx": self.strip_empty_entries(self.cxx_flags()), - "nvcc": self.strip_empty_entries(self.nvcc_flags()), - }, - ) diff --git a/colossalai/kernel/extensions/flash_attention/__init__.py b/colossalai/kernel/extensions/flash_attention/__init__.py deleted file mode 100644 index 79c6935d2..000000000 --- a/colossalai/kernel/extensions/flash_attention/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -from .cuda_flash_attn_2_extension import HAS_FLASH_ATTN, CudaFlashAttnExtension -from .cuda_memory_efficient_attn_extension import HAS_MEM_EFF_ATTN, CudaMemoryEfficentAttnExtension -from .npu_sdpa_attn_extension import NpuSdpaAttnExtension -from .npu_triangle_attn_extension import HAS_NPU_TRIANGLE_ATTENTION, NpuTriangleAttnExtension -from .utils import AttnMaskType, Repad, SeqLenInfo, Unpad - -__all__ = [ - "CudaFlashAttnExtension", - "CudaMemoryEfficentAttnExtension", - "NpuSdpaAttnExtension", - "NpuTriangleAttnExtension", - "HAS_FLASH_ATTN", - "HAS_MEM_EFF_ATTN", - "HAS_NPU_TRIANGLE_ATTENTION", - "Unpad", - "AttnMaskType", - "Repad", - "SeqLenInfo", -] diff --git a/colossalai/kernel/extensions/flash_attention/cuda_flash_attn_2_extension.py b/colossalai/kernel/extensions/flash_attention/cuda_flash_attn_2_extension.py deleted file mode 100644 index 99c353606..000000000 --- a/colossalai/kernel/extensions/flash_attention/cuda_flash_attn_2_extension.py +++ /dev/null @@ -1,100 +0,0 @@ -from typing import Optional - -import torch - -from ..base_extension import BaseExtension -from ..utils import print_rank_0 -from .utils import SeqLenInfo - - -def is_ampere_or_better_gpu(): - # Check Ampere GPUs or newer - if torch.cuda.is_available(): - device = torch.device("cuda") - properties = torch.cuda.get_device_properties(device) - if properties.major >= 8: # Ampere GPUs or newer - return True - return False - - -HAS_FLASH_ATTN = False -ERROR_MSG = None -if is_ampere_or_better_gpu(): - try: - from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func - - HAS_FLASH_ATTN = True - except ImportError: - ERROR_MSG = "ImportError: please install flash_attn from https://github.com/HazyResearch/flash-attention" -else: - ERROR_MSG = "ImportError: FlashAttention only supports Ampere GPUs or newer." - - -if HAS_FLASH_ATTN: - - def flash_attention( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - seq_len_info_q: SeqLenInfo, - seq_len_info_kv: SeqLenInfo, - origin_attn_mask: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, - dropout_p: float = 0.0, - scale: float = None, - causal: bool = False, - padded: bool = False, - ): - """ - Arguments: - q: (batch, q_seqlen, nheads, headdim) - k: (batch, kv_seqlen, nheads, headdim) - v: (batch, kv_seqlen, nheads, headdim) - batch_size: int. - seq_len: int. - dropout_p: float. Dropout probability. - sm_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - Return: - attn_out: (batch, q_seqlen, nheads, headdim). - """ - if padded: - if seq_len_info_kv == None: - seq_len_info_kv = seq_len_info_q - - attn_out = flash_attn_varlen_func( - q, - k, - v, - seq_len_info_q.cu_seqlens, - seq_len_info_kv.cu_seqlens, - seq_len_info_q.max_seqlen, - seq_len_info_kv.max_seqlen, - dropout_p, - scale, - causal, - ) - else: - attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal) - return attn_out - - -class CudaFlashAttnExtension(BaseExtension): - def __init__(self) -> None: - super().__init__() - - @property - def requires_build(self): - return False - - def build(self): - pass - - def is_available(self): - if HAS_FLASH_ATTN == False: - print_rank_0(ERROR_MSG) - return HAS_FLASH_ATTN - - def load(self): - return flash_attention diff --git a/colossalai/kernel/extensions/flash_attention/cuda_memory_efficient_attn_extension.py b/colossalai/kernel/extensions/flash_attention/cuda_memory_efficient_attn_extension.py deleted file mode 100644 index 4954ab5b1..000000000 --- a/colossalai/kernel/extensions/flash_attention/cuda_memory_efficient_attn_extension.py +++ /dev/null @@ -1,91 +0,0 @@ -from typing import Optional - -import torch - -from ..base_extension import BaseExtension -from ..utils import print_rank_0 -from .utils import SeqLenInfo - -HAS_MEM_EFF_ATTN = False -try: - from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention - from xformers.ops.fmha.attn_bias import ( - BlockDiagonalCausalMask, - BlockDiagonalMask, - LowerTriangularMask, - LowerTriangularMaskWithTensorBias, - ) - - HAS_MEM_EFF_ATTN = True -except ImportError: - pass - -if HAS_MEM_EFF_ATTN: - """ - A general attention module using the flash attention kernels from xformers: - https://github.com/facebookresearch/xformers/tree/main/xformers/ops/fmha - """ - - allow_alibi = True - for op in MemoryEfficientAttentionCutlassOp: - allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES) - - def mem_eff_attention( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - seq_len_info_q: SeqLenInfo, - seq_len_info_kv: SeqLenInfo, - origin_attn_mask: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, - dropout_p: float = 0.0, - scale: float = None, - causal: bool = False, - padded: bool = False, - ): - attn_bias = None - if padded: # bert style - if not causal: - attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens) - else: - attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens) - elif causal: # gpt style - attn_bias = LowerTriangularMask() - - if bias is not None: # alibi / relative position embedding - assert allow_alibi, "flash attention with bias is not supported in this system." - assert causal, "attention with bias is only supported for causal attention so far." - attn_bias = attn_bias.add_bias(bias) - - if padded: - q = q.unsqueeze(0) - k = k.unsqueeze(0) - v = v.unsqueeze(0) - - out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale) - - # shape: (b*s, n, d) - if padded: - out = out.squeeze(0) - - return out - - -class CudaMemoryEfficentAttnExtension(BaseExtension): - def __init__(self) -> None: - super().__init__() - - @property - def requires_build(self) -> bool: - return False - - def build(self): - pass - - def is_available(self): - if HAS_MEM_EFF_ATTN == False: - print_rank_0("ImportError: please install xformers from https://github.com/facebookresearch/xformers") - return HAS_MEM_EFF_ATTN - - def load(self): - return mem_eff_attention diff --git a/colossalai/kernel/extensions/flash_attention/npu_sdpa_attn_extension.py b/colossalai/kernel/extensions/flash_attention/npu_sdpa_attn_extension.py deleted file mode 100644 index 7dc9d9b9b..000000000 --- a/colossalai/kernel/extensions/flash_attention/npu_sdpa_attn_extension.py +++ /dev/null @@ -1,60 +0,0 @@ -import torch -from einops import rearrange - -from ..base_extension import BaseExtension - - -def npu_sdpa_attention( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - seq_len_info_q=None, - seq_len_info_kv=None, - origin_attn_mask: torch.Tensor = None, - dropout_p: float = 0.0, - scale: float = 1.0, - causal=None, - padded=None, -): - """ - The scaled dot product attention. - - Arguments: - q: (batch, q_seqlen, nheads, headdim) - k: (batch, kv_seqlen, nheads, headdim) - v: (batch, kv_seqlen, nheads, headdim) - batch_size: int. - seq_len: int. - dropout_p: float. Dropout probability. - scale: float. The scaling of QK^T before applying softmax. - Default to 1. - Return: - attn_out: (batch, q_seqlen, nheads, headdim). - """ - q, k, v = [rearrange(x, "b s h d -> b h s d").contiguous() for x in (q, k, v)] - output = torch.nn.functional.scaled_dot_product_attention( - q, - k, - v, - attn_mask=origin_attn_mask, - dropout_p=dropout_p, - is_causal=origin_attn_mask is None, - scale=scale, - ) - output = rearrange(output, "b h s d -> b s (h d)") - return output - - -class NpuSdpaAttnExtension(BaseExtension): - def __init__(self) -> None: - super().__init__() - - @property - def requires_build(self) -> bool: - return False - - def build(self): - pass - - def load(self): - return npu_sdpa_attention diff --git a/colossalai/kernel/extensions/flash_attention/npu_triangle_attn_extension.py b/colossalai/kernel/extensions/flash_attention/npu_triangle_attn_extension.py deleted file mode 100644 index a760f56a1..000000000 --- a/colossalai/kernel/extensions/flash_attention/npu_triangle_attn_extension.py +++ /dev/null @@ -1,141 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2023, HUAWEI CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import torch -from einops import rearrange - -from ..base_extension import BaseExtension -from ..utils import print_rank_0 - -HAS_NPU_TRIANGLE_ATTENTION = False -try: - from torch_npu import npu_confusion_transpose, npu_scaled_masked_softmax - - HAS_NPU_TRIANGLE_ATTENTION = True -except ImportError: - pass - - -if HAS_NPU_TRIANGLE_ATTENTION: - - def npu_triangle_attention( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - seq_len_info_q=None, - seq_len_info_kv=None, - origin_attn_mask: torch.Tensor = None, - dropout_p: float = 0.0, - scale: float = 1.0, - causal=None, - padded=None, - block_size=512, - ): - """ - The triangle attention reduces the attention calculation of the mask - part by dividing the q, k, and v matrices into blocks - - Arguments: - block_size: The size of the inverted triangle block, the default is 512, - the smaller the block_size, the more calculations will be reduced, - but the number of small operators will be increased - masked_softmax_func: mask function to be applied. - dropout_func: dropout function to be applied. - """ - - def compute_attn(q_layer, k_layer, v_layer, mask_tmp): - # [b, hn, q_size, hd] * [b, hn, hd, kv_size] -> [b, hn, q_size, kv_size] - cur_sim = torch.matmul(q_layer, k_layer) - attention_probs = npu_scaled_masked_softmax(cur_sim, mask_tmp) - # attention dropout - if dropout_p > 0: - attention_probs = torch.nn.functional.dropout( - attention_probs, p=dropout_p, training=attention_probs.require_grad - ) - # [b, hn, q_size, kv_size] * [b, hn, kv_size, hd] -> [b, hn, q_size, hd] - context_layer_tmp = torch.matmul(attention_probs, v_layer) - return context_layer_tmp - - q, k, v = [rearrange(x, "b s h d -> b h s d") for x in (q, k, v)] - origin_attn_mask = origin_attn_mask.to(torch.bool) - # input shape: [b, hn, sq, hd] - bsz, head_num, sequence_len, head_dim = k.shape - sparse_groups = sequence_len // block_size - # Determine whether blocks size can be divided by sequence_length - divisible_flag = sequence_len == block_size * sparse_groups - k = k.transpose(2, 3).contiguous() - if divisible_flag: - q_tmp_layers = torch.chunk(q, sparse_groups, 2) - k_tmp_layers = torch.chunk(k, sparse_groups, 3) - v_tmp_layers = torch.chunk(v, sparse_groups, 2) - else: - seq_tmp = block_size * sparse_groups - q_last = q[:, :, seq_tmp:, :].contiguous() - mask_last = origin_attn_mask[:, :, seq_tmp:, :].contiguous() - q_tmp_layers = torch.chunk(q[:, :, :seq_tmp, :], sparse_groups, 2) - k_tmp_layers = torch.chunk(k[:, :, :, :seq_tmp], sparse_groups, 3) - v_tmp_layers = torch.chunk(v[:, :, :seq_tmp, :], sparse_groups, 2) - context_list_tmp, k_tmp, v_tmp = [], (), () - for i in range(sparse_groups): - # compute slice shape of q k v for each loop - q_begin, q_end = i * block_size, (i + 1) * block_size - kv_begin, kv_end = 0, (i + 1) * block_size - q_tmp = q_tmp_layers[i] - # slice k and v - if i == 0: - k_tmp = k_tmp_layers[i].contiguous() - v_tmp = v_tmp_layers[i].contiguous() - else: - k_tmp = torch.cat((k_tmp, k_tmp_layers[i]), -1).contiguous() - v_tmp = torch.cat((v_tmp, v_tmp_layers[i]), -2).contiguous() - - mask_tmp = origin_attn_mask[:, :, q_begin:q_end, kv_begin:kv_end].contiguous() - context_layer_tmp = compute_attn(q_tmp, k_tmp, v_tmp, mask_tmp) - context_list_tmp.append(context_layer_tmp) - - if not divisible_flag: - # circumstances that cannot be divisible - context_layer_tmp = compute_attn(q_last, k, v, mask_last) - context_list_tmp.append(context_layer_tmp) - context_layer = torch.cat(context_list_tmp, 2) - new_context_layer_shape = (bsz, sequence_len, head_num * head_dim) - context_layer = npu_confusion_transpose(context_layer, [0, 2, 1, 3], [*new_context_layer_shape], True) - # ========================= - # Context layer. [b, sq, hp] - # ========================= - return context_layer - - -class NpuTriangleAttnExtension(BaseExtension): - def __init__(self) -> None: - super().__init__() - - @property - def requires_build(self) -> bool: - return False - - def build(self): - pass - - def is_available(self): - if HAS_NPU_TRIANGLE_ATTENTION == False: - print_rank_0( - "ImportError: please install latest torch_npu with 'npu_confusion_transpose' and 'npu_scaled_masked_softmax' api." - ) - return HAS_NPU_TRIANGLE_ATTENTION - - def load(self): - return npu_triangle_attention diff --git a/colossalai/kernel/extensions/flash_attention/utils.py b/colossalai/kernel/extensions/flash_attention/utils.py deleted file mode 100644 index 06fef491f..000000000 --- a/colossalai/kernel/extensions/flash_attention/utils.py +++ /dev/null @@ -1,91 +0,0 @@ -import enum -from dataclasses import dataclass -from typing import Iterable, Tuple - -import torch -import torch.nn.functional as F -from einops import rearrange - -from colossalai.accelerator import get_accelerator - - -class Unpad(torch.autograd.Function): - """ - Adapted from - https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py - """ - - @staticmethod - def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor): - ctx.save_for_backward(indices) - # [b, s, ...] - assert tensor.ndim >= 3 - ctx.bsz = tensor.shape[0] - out = rearrange(tensor, "b s ... -> (b s) ...") - ctx.shape = out.shape - # [ntokens, ...] - return out[indices] - - @staticmethod - def backward(ctx, grad_output): - (indices,) = ctx.saved_tensors - # [ntokens, ...] - grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device) - grad[indices] = grad_output - grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz) - # [b, s, ...] - return grad, None - - -class Repad(torch.autograd.Function): - """ - Adapted from - https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py - """ - - @staticmethod - def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int): - ctx.save_for_backward(indices) - # [ntokens, ...] - tensor = tensor - out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device) - # [b*s, ...] - out[indices] = tensor - return out - - @staticmethod - def backward(ctx, grad_output): - (indices,) = ctx.saved_tensors - # [b*s, ...] - grad = grad_output[indices] - # [ntokens, ...] - return grad, None, None, None - - -@dataclass -class SeqLenInfo: - seqlens: Iterable[int] = None - indices: torch.Tensor = None - max_seqlen: int = None - cu_seqlens: torch.Tensor = None - - @staticmethod - def materialize( - attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_accelerator().get_current_device() - ): - if attn_mask is not None: - indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device) - seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten() - else: - batch_size, tgt_len = size[0], size[1] - indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device) - seqlens = torch.LongTensor([tgt_len] * batch_size, device=device) - max_seqlen = max(seqlens) - cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device) - return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens) - - -class AttnMaskType(enum.Enum): - padding = 1 - causal = 2 - paddedcausal = 3 diff --git a/colossalai/kernel/kernel_loader.py b/colossalai/kernel/kernel_loader.py new file mode 100644 index 000000000..148c3e3fc --- /dev/null +++ b/colossalai/kernel/kernel_loader.py @@ -0,0 +1,109 @@ +import warnings +from typing import List + +from .extensions import ( + CpuAdamArmExtension, + CpuAdamX86Extension, + FlashAttentionDaoCudaExtension, + FlashAttentionNpuExtension, + FlashAttentionXformersCudaExtension, + FusedOptimizerCudaExtension, + LayerNormCudaExtension, + MoeCudaExtension, + ScaledMaskedSoftmaxCudaExtension, + ScaledUpperTriangleMaskedSoftmaxCudaExtension, +) +from .extensions.base_extension import _Extension + +__all__ = [ + "KernelLoader", + "CPUAdamLoader", + "LayerNormLoader", + "MoeLoader", + "FusedOptimizerLoader", + "ScaledMaskedSoftmaxLoader", + "ScaledUpperTriangleMaskedSoftmaxLoader", +] + + +class KernelLoader: + """ + An abstract class which offers encapsulation to the kernel loading process. + + Usage: + kernel_loader = KernelLoader() + kernel = kernel_loader.load() + """ + + REGISTRY: List[_Extension] = [] + + @classmethod + def register_extension(cls, extension: _Extension): + """ + This classmethod is an extension point which allows users to register their customized + kernel implementations to the loader. + + Args: + extension (_Extension): the extension to be registered. + """ + cls.REGISTRY.append(extension) + + def load(self, ext_name: str = None): + """ + Load the kernel according to the current machine. + + Args: + ext_name (str): the name of the extension to be loaded. If not specified, the loader + will try to look for an kernel available on the current machine. + """ + exts = [ext_cls() for ext_cls in self.__class__.REGISTRY] + + # look for exts which can be built/loaded on the current machine + + if ext_name: + usable_exts = list(filter(lambda ext: ext.name == ext_name, exts)) + else: + usable_exts = [] + for ext in exts: + if ext.is_hardware_available(): + # make sure the machine is compatible during kernel loading + ext.assert_hardware_compatible() + usable_exts.append(ext) + + assert len(usable_exts) != 0, f"No usable kernel found for {self.__class__.__name__} on the current machine." + + if len(usable_exts) > 1: + # if more than one usable kernel is found, we will try to load the kernel with the highest priority + usable_exts = sorted(usable_exts, key=lambda ext: ext.priority, reverse=True) + warnings.warn( + f"More than one kernel is available, loading the kernel with the highest priority - {usable_exts[0].__class__.__name__}" + ) + return usable_exts[0].load() + + +class CPUAdamLoader(KernelLoader): + REGISTRY = [CpuAdamX86Extension, CpuAdamArmExtension] + + +class LayerNormLoader(KernelLoader): + REGISTRY = [LayerNormCudaExtension] + + +class MoeLoader(KernelLoader): + REGISTRY = [MoeCudaExtension] + + +class FusedOptimizerLoader(KernelLoader): + REGISTRY = [FusedOptimizerCudaExtension] + + +class ScaledMaskedSoftmaxLoader(KernelLoader): + REGISTRY = [ScaledMaskedSoftmaxCudaExtension] + + +class ScaledUpperTriangleMaskedSoftmaxLoader(KernelLoader): + REGISTRY = [ScaledUpperTriangleMaskedSoftmaxCudaExtension] + + +class FlashAttentionLoader(KernelLoader): + REGISTRY = [FlashAttentionNpuExtension, FlashAttentionDaoCudaExtension, FlashAttentionXformersCudaExtension] diff --git a/colossalai/kernel/op_builder b/colossalai/kernel/op_builder deleted file mode 120000 index db4f9c335..000000000 --- a/colossalai/kernel/op_builder +++ /dev/null @@ -1 +0,0 @@ -../../op_builder \ No newline at end of file diff --git a/colossalai/legacy/amp/naive_amp/_fp16_optimizer.py b/colossalai/legacy/amp/naive_amp/_fp16_optimizer.py index 97ec57fbd..d2dceb50b 100644 --- a/colossalai/legacy/amp/naive_amp/_fp16_optimizer.py +++ b/colossalai/legacy/amp/naive_amp/_fp16_optimizer.py @@ -7,7 +7,7 @@ from torch.distributed import ProcessGroup from torch.optim import Optimizer from colossalai.amp.naive_amp.grad_scaler import BaseGradScaler -from colossalai.kernel.op_builder import FusedOptimBuilder +from colossalai.kernel.kernel_loader import FusedOptimizerLoader from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.utils import clip_grad_norm_fp32, copy_tensor_parallel_attributes @@ -28,7 +28,7 @@ def load_fused_optim(): global fused_optim if fused_optim is None: - fused_optim = FusedOptimBuilder().load() + fused_optim = FusedOptimizerLoader().load() def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None): diff --git a/colossalai/legacy/nn/layer/parallel_1d/layers.py b/colossalai/legacy/nn/layer/parallel_1d/layers.py index 36cb09d32..b38e1c433 100644 --- a/colossalai/legacy/nn/layer/parallel_1d/layers.py +++ b/colossalai/legacy/nn/layer/parallel_1d/layers.py @@ -11,7 +11,6 @@ from torch import Tensor from torch.nn.parameter import Parameter from colossalai.accelerator import get_accelerator -from colossalai.kernel import LayerNorm from colossalai.legacy.communication import broadcast from colossalai.legacy.context import ParallelMode, seed from colossalai.legacy.context.parallel_context import global_context as gpc @@ -23,6 +22,7 @@ from colossalai.legacy.utils.checkpointing import ( partition_tensor_parallel_state_dict, ) from colossalai.nn import init as init +from colossalai.nn.layer.layernorm import MixedFusedLayerNorm as LayerNorm from ..base_layer import ParallelLayer from ..colossalai_layer._utils import ColossalaiModule diff --git a/colossalai/legacy/nn/layer/parallel_sequence/layers.py b/colossalai/legacy/nn/layer/parallel_sequence/layers.py index 063b0cd8e..445b7e4cd 100644 --- a/colossalai/legacy/nn/layer/parallel_sequence/layers.py +++ b/colossalai/legacy/nn/layer/parallel_sequence/layers.py @@ -8,13 +8,12 @@ import torch.nn as nn import torch.nn.functional as F from torch.nn import Parameter -from colossalai.kernel import FusedScaleMaskSoftmax -from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType from colossalai.legacy.context import seed from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.parallel_sequence._operation import RingAV, RingQK from colossalai.legacy.registry import LAYERS +from colossalai.nn.layer.scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax @LAYERS.register_module diff --git a/colossalai/legacy/utils/common.py b/colossalai/legacy/utils/common.py index 671bcc3d6..76ec08e96 100644 --- a/colossalai/legacy/utils/common.py +++ b/colossalai/legacy/utils/common.py @@ -96,9 +96,9 @@ def _calc_l2_norm(grads): global fused_optim if fused_optim is None: - from colossalai.kernel.op_builder import FusedOptimBuilder + from colossalai.kernel.kernel_loader import FusedOptimizerLoader - fused_optim = FusedOptimBuilder().load() + fused_optim = FusedOptimizerLoader().load() norm = 0.0 if len(grads) > 0: diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py index c71e6c1f4..34342436f 100644 --- a/colossalai/moe/_operation.py +++ b/colossalai/moe/_operation.py @@ -11,9 +11,9 @@ MOE_KERNEL = None def load_moe(): global MOE_KERNEL - from colossalai.kernel.op_builder import MOEBuilder + from colossalai.kernel.kernel_loader import MoeLoader - MOE_KERNEL = MOEBuilder().load() + MOE_KERNEL = MoeLoader().load() class AllGather(torch.autograd.Function): @@ -145,14 +145,8 @@ class AllToAll(torch.autograd.Function): class HierarchicalAllToAll(torch.autograd.Function): - @staticmethod - def forward( - ctx: Any, - inputs: Tensor, - groups: Tuple[ProcessGroup, ProcessGroup], - src_rank: int - ) -> Tensor: + def forward(ctx: Any, inputs: Tensor, groups: Tuple[ProcessGroup, ProcessGroup], src_rank: int) -> Tensor: """ Returns: outputs: Tensor @@ -276,8 +270,9 @@ class MoeCombine(torch.autograd.Function): if tokens_grad.dtype != torch.float32: tokens_grad = tokens_grad.to(torch.float32) - d_expert, d_logits = MOE_KERNEL.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, tokens_grad, expert_tokens, logits, - mask, dest_idx) + d_expert, d_logits = MOE_KERNEL.combine_backward( + ctx.s, ctx.e, ctx.c, ctx.h, tokens_grad, expert_tokens, logits, mask, dest_idx + ) if d_expert.dtype != ctx.dtype: d_expert = d_expert.to(ctx.dtype) diff --git a/colossalai/kernel/flash_attention_loader.py b/colossalai/nn/layer/colo_attention.py similarity index 58% rename from colossalai/kernel/flash_attention_loader.py rename to colossalai/nn/layer/colo_attention.py index 3d0cd3975..0b7011e8e 100644 --- a/colossalai/kernel/flash_attention_loader.py +++ b/colossalai/nn/layer/colo_attention.py @@ -1,75 +1,97 @@ +import enum import math -from collections import OrderedDict -from typing import Optional +import warnings +from dataclasses import dataclass +from typing import Iterable, Optional, Tuple import torch +import torch.nn.functional as F from einops import rearrange from colossalai.accelerator import get_accelerator - -from .base_kernel_loader import BaseKernelLoader -from .extensions.flash_attention import ( - AttnMaskType, - CudaFlashAttnExtension, - CudaMemoryEfficentAttnExtension, - NpuSdpaAttnExtension, - NpuTriangleAttnExtension, - Repad, - SeqLenInfo, - Unpad, -) -from .extensions.utils import print_rank_0 +from colossalai.kernel.kernel_loader import FlashAttentionLoader -class FlashAttentionLoader(BaseKernelLoader): +@dataclass +class SeqLenInfo: + seqlens: Iterable[int] = None + indices: torch.Tensor = None + max_seqlen: int = None + cu_seqlens: torch.Tensor = None + + @staticmethod + def materialize( + attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_accelerator().get_current_device() + ): + if attn_mask is not None: + indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device) + seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten() + else: + batch_size, tgt_len = size[0], size[1] + indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device) + seqlens = torch.LongTensor([tgt_len] * batch_size, device=device) + max_seqlen = max(seqlens) + cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device) + return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens) + + +class AttnMaskType(enum.Enum): + padding = 1 + causal = 2 + paddedcausal = 3 + + +class Unpad(torch.autograd.Function): """ - FlashAttention Loader - - options: cuda flashh attention, cuda memory effcient attention, npu sdpa attention, npu triangle attention - - Args: - q: (batch, q_seqlen, nheads, headdim) - k: (batch, kv_seqlen, nheads, headdim) - v: (batch, kv_seqlen, nheads, headdim) - batch_size: int. - seq_len: int. - dropout_p: float. Dropout probability. - sm_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - Return: - attn_out: (batch, q_seqlen, nheads, headdim). + Adapted from + https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py """ - def __init__(self): - super().__init__( - # extension name must start with the accelerator name. E.g. npu_xxx, cuda_xxx - extension_map=OrderedDict( - cuda_flash_attn=CudaFlashAttnExtension, - cuda_memory_efficent_attn=CudaMemoryEfficentAttnExtension, - npu_sdpa_attn=NpuSdpaAttnExtension, - npu_triangle_attn=NpuTriangleAttnExtension, - ), - supported_device=["cuda", "npu"], - ) + @staticmethod + def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor): + ctx.save_for_backward(indices) + # [b, s, ...] + assert tensor.ndim >= 3 + ctx.bsz = tensor.shape[0] + out = rearrange(tensor, "b s ... -> (b s) ...") + ctx.shape = out.shape + # [ntokens, ...] + return out[indices] - def fetch_kernel(self, backend: str = None): - if backend is not None: - if not self._extension_map[backend]().is_available(): - raise Exception(f"{backend} is not available for flash attention.") - return self._extension_map[backend]().fetch() + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + # [ntokens, ...] + grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device) + grad[indices] = grad_output + grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz) + # [b, s, ...] + return grad, None - kernel = None - accelerator_name = get_accelerator().name - assert accelerator_name in self._supported_device, f"{accelerator_name} is not supported for flash attention." - for extension_name, extension in self._extension_map.items(): - if extension_name.startswith(accelerator_name): - if extension().is_available(): - kernel = extension().fetch() - break - if kernel is None: - raise Exception("No extension for flash attention is supported") - return kernel + +class Repad(torch.autograd.Function): + """ + Adapted from + https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py + """ + + @staticmethod + def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int): + ctx.save_for_backward(indices) + # [ntokens, ...] + tensor = tensor + out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device) + # [b*s, ...] + out[indices] = tensor + return out + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + # [b*s, ...] + grad = grad_output[indices] + # [ntokens, ...] + return grad, None, None, None class ColoAttention(torch.nn.Module): @@ -84,7 +106,7 @@ class ColoAttention(torch.nn.Module): self.scale = 1 / math.sqrt(embed_dim // num_heads) self.dropout = dropout - self.attn = FlashAttentionLoader().fetch_kernel() + self.attn = FlashAttentionLoader().load() @staticmethod def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: @@ -120,8 +142,10 @@ class ColoAttention(torch.nn.Module): if self.attn.__name__ == "flash_attention" and ( query.dtype not in [torch.float16, torch.bfloat16] or bias != None ): - print_rank_0("flash attention is not applicable, switch to memory effcient attention") - self.attn = FlashAttentionLoader().fetch_kernel(backend="cuda_memory_efficent_attn") + warnings.warn( + f"flash-attn expects fp16 or bf16 but got {query.dtype}, switching to xformers' implementation." + ) + self.attn = FlashAttentionLoader().load(ext_name="flash_attention_xformers_cuda") padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1 causal = attn_mask_type is not None and attn_mask_type.value > 1 diff --git a/colossalai/kernel/cuda_native/layer_norm.py b/colossalai/nn/layer/layernorm.py similarity index 95% rename from colossalai/kernel/cuda_native/layer_norm.py rename to colossalai/nn/layer/layernorm.py index c7d2a3a45..1db48faee 100644 --- a/colossalai/kernel/cuda_native/layer_norm.py +++ b/colossalai/nn/layer/layernorm.py @@ -9,7 +9,7 @@ from torch.cuda.amp import custom_bwd, custom_fwd from torch.nn import init from torch.nn.parameter import Parameter -from colossalai.kernel.op_builder.layernorm import LayerNormBuilder +from colossalai.kernel.kernel_loader import LayerNormLoader try: from colossalai._C import layer_norm @@ -29,7 +29,7 @@ class FusedLayerNormAffineFunction(torch.autograd.Function): global layer_norm if layer_norm is None: - layer_norm = LayerNormBuilder().load() + layer_norm = LayerNormLoader().load() output, mean, invvar = layer_norm.forward_affine(input_, ctx.normalized_shape, weight_, bias_, ctx.eps) ctx.layernorm_op = layer_norm ctx.save_for_backward(input_, weight_, bias_, mean, invvar) diff --git a/colossalai/nn/layer/scaled_softmax.py b/colossalai/nn/layer/scaled_softmax.py new file mode 100644 index 000000000..a8d72ddd9 --- /dev/null +++ b/colossalai/nn/layer/scaled_softmax.py @@ -0,0 +1,184 @@ +# This code from NVIDIA Megatron: +# with minor changes. + +import enum + +import torch +import torch.nn as nn + +from colossalai.kernel.kernel_loader import ScaledMaskedSoftmaxLoader, ScaledUpperTriangleMaskedSoftmaxLoader + + +class AttnMaskType(enum.Enum): + padding = 1 + causal = 2 + paddedcausal = 3 + + +class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): + """ + Fused operation which performs following three operations in sequence + + 1. Scale the tensor. + 2. Apply upper triangular mask (typically used in gpt models). + 3. Perform softmax. + """ + + @staticmethod + def forward(ctx, inputs, scale): + global scaled_upper_triang_masked_softmax + if scaled_upper_triang_masked_softmax: + scaled_upper_triang_masked_softmax = ScaledUpperTriangleMaskedSoftmaxLoader().load() + + scale_t = torch.tensor([scale]) + softmax_results = scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0]) + + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + softmax_results, scale_t = ctx.saved_tensors + input_grads = scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0]) + + return input_grads, None + + +class ScaledMaskedSoftmax(torch.autograd.Function): + """ + Fused operation which performs following three operations in sequence + + 1. Scale the tensor. + 2. Apply the mask. + 3. Perform softmax. + """ + + @staticmethod + def forward(ctx, inputs, mask, scale): + scale_t = torch.tensor([scale]) + + # build and load kernel if not pre-built + global scaled_masked_softmax + if scaled_masked_softmax is None: + scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load() + + softmax_results = scaled_masked_softmax.forward(inputs, mask, scale_t[0]) + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + softmax_results, scale_t = ctx.saved_tensors + + input_grads = scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0]) + return input_grads, None, None, None + + +class FusedScaleMaskSoftmax(nn.Module): + """ + Fused operation: scaling + mask + softmax + + Arguments: + input_in_fp16: Flag to indicate if input in fp16 data format. + input_in_bf16: Flag to indicate if input in bf16 data format. + attn_mask_type: Attention mask type (pad or causal) + scaled_masked_softmax_fusion: Flag to indicate user want to use softmax fusion + mask_func: Mask function to be applied. + softmax_in_fp32: If True, softmax in performed at fp32 precision. + scale: Scaling factor used in input tensor scaling. + """ + + def __init__( + self, + input_in_fp16, + input_in_bf16, + attn_mask_type, + scaled_masked_softmax_fusion, + mask_func, + softmax_in_fp32, + scale, + ): + super(FusedScaleMaskSoftmax, self).__init__() + self.input_in_fp16 = input_in_fp16 + self.input_in_bf16 = input_in_bf16 + assert not ( + self.input_in_fp16 and self.input_in_bf16 + ), "both fp16 and bf16 flags cannot be active at the same time." + self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 + self.attn_mask_type = attn_mask_type + self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion + self.mask_func = mask_func + self.softmax_in_fp32 = softmax_in_fp32 + self.scale = scale + assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled" + + def forward(self, input, mask): + # [b, np, sq, sk] + assert input.dim() == 4 + + if self.is_kernel_available(mask, *input.size()): + return self.forward_fused_softmax(input, mask) + else: + return self.forward_torch_softmax(input, mask) + + def is_kernel_available(self, mask, b, np, sq, sk): + attn_batches = b * np + + if ( + self.scaled_masked_softmax_fusion # user want to fuse + and self.input_in_float16 # input must be fp16 + and mask is not None # mask tensor must not be None + and 16 < sk <= 2048 # sk must be 16 ~ 2048 + and sq % 4 == 0 # sq must be divisor of 4 + and attn_batches % 4 == 0 # np * b must be divisor of 4 + ): + if 0 <= sk <= 2048: + batch_per_block = self.get_batch_per_block(sq, sk, b, np) + + if self.attn_mask_type.value > 1: + if attn_batches % batch_per_block == 0: + return True + else: + if sq % batch_per_block == 0: + return True + return False + + def forward_fused_softmax(self, input, mask): + b, np, sq, sk = input.size() + scale = self.scale if self.scale is not None else 1.0 + + if self.attn_mask_type.value > 1: + assert sq == sk, "causal mask is only for self attention" + + # input is 3D tensor (attn_batches, sq, sk) + input = input.view(-1, sq, sk) + probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale) + return probs.view(b, np, sq, sk) + else: + # input is 4D tensor (b, np, sq, sk) + return ScaledMaskedSoftmax.apply(input, mask, scale) + + def forward_torch_softmax(self, input, mask): + if self.input_in_float16 and self.softmax_in_fp32: + input = input.float() + + if self.scale is not None: + input = input * self.scale + mask_output = self.mask_func(input, mask) if mask is not None else input + probs = torch.nn.Softmax(dim=-1)(mask_output) + + if self.input_in_float16 and self.softmax_in_fp32: + if self.input_in_fp16: + probs = probs.half() + else: + probs = probs.bfloat16() + + return probs + + def get_batch_per_block(self, sq, sk, b, np): + # build and load kernel if not pre-built + global scaled_masked_softmax + if scaled_masked_softmax is None: + scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load() + + return scaled_masked_softmax.get_batch_per_block(sq, sk, b, np) diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py index b2f67cae6..5be629fb2 100644 --- a/colossalai/nn/optimizer/cpu_adam.py +++ b/colossalai/nn/optimizer/cpu_adam.py @@ -3,7 +3,7 @@ from typing import Optional import torch -from colossalai.kernel import CPUAdamLoader +from colossalai.kernel.kernel_loader import CPUAdamLoader from .nvme_optimizer import NVMeOptimizer diff --git a/colossalai/nn/optimizer/fused_adam.py b/colossalai/nn/optimizer/fused_adam.py index fcdd3257d..aeb5cc91b 100644 --- a/colossalai/nn/optimizer/fused_adam.py +++ b/colossalai/nn/optimizer/fused_adam.py @@ -70,9 +70,9 @@ class FusedAdam(torch.optim.Optimizer): self.adamw_mode = 1 if adamw_mode else 0 self.set_grad_none = set_grad_none if multi_tensor_applier.available: - from colossalai.kernel.op_builder import FusedOptimBuilder + from colossalai.kernel.kernel_loader import FusedOptimizerLoader - fused_optim = FusedOptimBuilder().load() + fused_optim = FusedOptimizerLoader().load() # Skip buffer self._dummy_overflow_buf = torch.cuda.IntTensor([0]) diff --git a/colossalai/nn/optimizer/fused_lamb.py b/colossalai/nn/optimizer/fused_lamb.py index 3e1d5a7ba..da8d1608a 100644 --- a/colossalai/nn/optimizer/fused_lamb.py +++ b/colossalai/nn/optimizer/fused_lamb.py @@ -77,9 +77,9 @@ class FusedLAMB(torch.optim.Optimizer): ) super(FusedLAMB, self).__init__(params, defaults) if multi_tensor_applier.available: - from colossalai.kernel.op_builder import FusedOptimBuilder + from colossalai.kernel.kernel_loader import FusedOptimizerLoader - fused_optim = FusedOptimBuilder().load() + fused_optim = FusedOptimizerLoader().load() self.multi_tensor_l2norm = fused_optim.multi_tensor_l2norm # Skip buffer diff --git a/colossalai/nn/optimizer/fused_sgd.py b/colossalai/nn/optimizer/fused_sgd.py index 95a635420..3fae9bbca 100644 --- a/colossalai/nn/optimizer/fused_sgd.py +++ b/colossalai/nn/optimizer/fused_sgd.py @@ -72,9 +72,9 @@ class FusedSGD(Optimizer): self.wd_after_momentum = wd_after_momentum if multi_tensor_applier.available: - from colossalai.kernel.op_builder import FusedOptimBuilder + from colossalai.kernel.kernel_loader import FusedOptimizerLoader - fused_optim = FusedOptimBuilder().load() + fused_optim = FusedOptimizerLoader().load() # Skip buffer self._dummy_overflow_buf = torch.tensor( diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py index d34fd601a..c9c1f81bf 100644 --- a/colossalai/nn/optimizer/hybrid_adam.py +++ b/colossalai/nn/optimizer/hybrid_adam.py @@ -2,7 +2,7 @@ from typing import Any, Optional import torch -from colossalai.kernel.op_builder import FusedOptimBuilder +from colossalai.kernel.kernel_loader import FusedOptimizerLoader from colossalai.utils import multi_tensor_applier from .cpu_adam import CPUAdam @@ -85,7 +85,7 @@ class HybridAdam(CPUAdam): nvme_offload_dir, ) if torch.cuda.is_available(): - fused_optim = FusedOptimBuilder().load() + fused_optim = FusedOptimizerLoader().load() self.gpu_adam_op = fused_optim.multi_tensor_adam self._dummy_overflow_buf = torch.cuda.IntTensor([0]) diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index 791c5764e..a4ace5e1b 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -10,6 +10,7 @@ from colossalai.accelerator import get_accelerator from colossalai.interface import OptimizerWrapper from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.utils import get_current_device from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device from .base import PipelineSchedule diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index b4cd77fbd..bf2f01b10 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -10,6 +10,7 @@ from colossalai.accelerator import get_accelerator from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata from colossalai.pipeline.stage_manager import PipelineStageManager +from colossalai.utils import get_current_device from ._utils import ( detach, diff --git a/colossalai/shardformer/modeling/blip2.py b/colossalai/shardformer/modeling/blip2.py index 3522264ad..d5c10541a 100644 --- a/colossalai/shardformer/modeling/blip2.py +++ b/colossalai/shardformer/modeling/blip2.py @@ -62,7 +62,7 @@ def forward_fn(): def get_blip2_flash_attention_forward(): from transformers.models.blip_2.modeling_blip_2 import Blip2Attention - from colossalai.kernel import ColoAttention + from colossalai.nn.layer.colo_attention import ColoAttention def forward( self: Blip2Attention, diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py index 0e469b7dd..d13bd3492 100644 --- a/colossalai/shardformer/modeling/chatglm2.py +++ b/colossalai/shardformer/modeling/chatglm2.py @@ -14,7 +14,7 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLM def get_flash_core_attention_forward(): - from colossalai.kernel import AttnMaskType, ColoAttention + from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention from .chatglm2_6b.modeling_chatglm import CoreAttention diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 9ab51b90e..055e3096d 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -719,7 +719,7 @@ class GPT2PipelineForwards: def get_gpt2_flash_attention_forward(): from transformers.models.gpt2.modeling_gpt2 import GPT2Attention - from colossalai.kernel import AttnMaskType, ColoAttention + from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention def split_heads(tensor, num_heads, attn_head_size): """ diff --git a/colossalai/shardformer/modeling/gptj.py b/colossalai/shardformer/modeling/gptj.py index ad51bf2c7..22b0f7a90 100644 --- a/colossalai/shardformer/modeling/gptj.py +++ b/colossalai/shardformer/modeling/gptj.py @@ -530,7 +530,7 @@ class GPTJPipelineForwards: def get_gptj_flash_attention_forward(): from transformers.models.gptj.modeling_gptj import GPTJAttention - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention def split_heads(tensor, num_attention_heads, attn_head_size, rotary): """ diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index a3cc11fb5..e10a7ed7d 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -1,5 +1,5 @@ import warnings -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -420,7 +420,7 @@ class LlamaPipelineForwards: def get_llama_flash_attention_forward(shard_config: ShardConfig): from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb - from colossalai.kernel import AttnMaskType, ColoAttention + from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention llama_version = 2 try: diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index 1ddb26c25..0da1a35a0 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -6,7 +6,7 @@ import torch def get_mistral_flash_attention_forward(): from transformers.models.mistral.modeling_mistral import MistralAttention, apply_rotary_pos_emb, repeat_kv - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention def forward( self: MistralAttention, diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py index 625b78bd8..7f6cbbbcf 100644 --- a/colossalai/shardformer/modeling/opt.py +++ b/colossalai/shardformer/modeling/opt.py @@ -514,7 +514,7 @@ class OPTPipelineForwards: def get_opt_flash_attention_forward(): from transformers.models.opt.modeling_opt import OPTAttention - from colossalai.kernel import AttnMaskType, ColoAttention + from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention def forward( self: OPTAttention, diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py index ca3574253..ab141a74a 100644 --- a/colossalai/shardformer/modeling/vit.py +++ b/colossalai/shardformer/modeling/vit.py @@ -336,7 +336,7 @@ def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManag def get_vit_flash_self_attention_forward(): from transformers.models.vit.modeling_vit import ViTSelfAttention - from colossalai.kernel import ColoAttention + from colossalai.nn.layer.colo_attention import ColoAttention def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_size) -> torch.Tensor: new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size) diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py index f67f6cd63..cb8b45ae7 100644 --- a/colossalai/shardformer/modeling/whisper.py +++ b/colossalai/shardformer/modeling/whisper.py @@ -26,7 +26,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager def get_whisper_flash_attention_forward(): from transformers.models.whisper.modeling_whisper import WhisperAttention - from colossalai.kernel import AttnMaskType, ColoAttention + from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int): return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous() diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py index 9d33e4668..cdba46709 100644 --- a/colossalai/utils/__init__.py +++ b/colossalai/utils/__init__.py @@ -4,6 +4,7 @@ from .common import ( disposable, ensure_path_exists, free_storage, + get_current_device, is_ddp_ignored, set_seed, ) @@ -22,5 +23,6 @@ __all__ = [ "_cast_float", "free_storage", "set_seed", + "get_current_device", "is_ddp_ignored", ] diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py index c43caaff4..4a1889eb5 100644 --- a/colossalai/utils/common.py +++ b/colossalai/utils/common.py @@ -10,6 +10,15 @@ from typing import Callable import numpy as np import torch +from colossalai.accelerator import get_accelerator + + +def get_current_device(): + """ + A wrapper function for accelerator's API for backward compatibility. + """ + return get_accelerator().get_current_device() + def ensure_path_exists(filename: str): # ensure the path exists diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py index 7a9f58701..cad2622f2 100644 --- a/colossalai/zero/gemini/chunk/chunk.py +++ b/colossalai/zero/gemini/chunk/chunk.py @@ -190,8 +190,10 @@ class Chunk: def device_type(self) -> str: if self.chunk_temp is not None: return self.chunk_temp.device.type - else: + elif self.is_gathered or self.cuda_shard is not None: return get_accelerator().name + else: + return "cpu" @property def payload(self) -> torch.Tensor: diff --git a/examples/tutorial/sequence_parallel/model/bert.py b/examples/tutorial/sequence_parallel/model/bert.py index 7b0e93d95..64260374a 100644 --- a/examples/tutorial/sequence_parallel/model/bert.py +++ b/examples/tutorial/sequence_parallel/model/bert.py @@ -3,13 +3,13 @@ import inspect import torch import torch.nn as nn -from colossalai.kernel import LayerNorm from colossalai.legacy.context import ParallelMode from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.nn.layer.wrapper import PipelineSharedModuleWrapper from colossalai.legacy.pipeline.utils import partition_uniform from colossalai.logging import get_dist_logger +from colossalai.nn.layer.layernorm import MixedFusedLayerNorm as LayerNorm from .layers import BertDualHead, BertLayer, Embedding, PreProcessor, VocabEmbedding from .layers.init_method import init_normal, output_init_normal diff --git a/examples/tutorial/sequence_parallel/model/layers/head.py b/examples/tutorial/sequence_parallel/model/layers/head.py index 75afeee60..ff81ace39 100644 --- a/examples/tutorial/sequence_parallel/model/layers/head.py +++ b/examples/tutorial/sequence_parallel/model/layers/head.py @@ -3,9 +3,9 @@ import torch.nn as nn import torch.nn.functional as F from loss_func.cross_entropy import vocab_cross_entropy -from colossalai.kernel import LayerNorm from colossalai.legacy.context import ParallelMode from colossalai.legacy.core import global_context as gpc +from colossalai.nn.layer.layernorm import MixedFusedLayerNorm as LayerNorm from .linear import Linear from .pooler import Pooler diff --git a/examples/tutorial/sequence_parallel/train.py b/examples/tutorial/sequence_parallel/train.py index e9ceb8d70..f25fc8189 100644 --- a/examples/tutorial/sequence_parallel/train.py +++ b/examples/tutorial/sequence_parallel/train.py @@ -8,12 +8,12 @@ from lr_scheduler import AnnealingLR from model.bert import BertForPretrain, build_pipeline_bert import colossalai -from colossalai.kernel import LayerNorm from colossalai.legacy.amp import AMP_TYPE from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.utils import is_using_pp from colossalai.logging import get_dist_logger +from colossalai.nn.layer.layernorm import MixedFusedLayerNorm as LayerNorm from colossalai.nn.optimizer import FusedAdam from colossalai.utils import MultiTimer diff --git a/extensions/__init__.py b/extensions/__init__.py new file mode 100644 index 000000000..9343cadda --- /dev/null +++ b/extensions/__init__.py @@ -0,0 +1,36 @@ +from .cpu_adam import CpuAdamArmExtension, CpuAdamX86Extension +from .flash_attention import ( + FlashAttentionDaoCudaExtension, + FlashAttentionNpuExtension, + FlashAttentionXformersCudaExtension, +) +from .layernorm import LayerNormCudaExtension +from .moe import MoeCudaExtension +from .optimizer import FusedOptimizerCudaExtension +from .softmax import ScaledMaskedSoftmaxCudaExtension, ScaledUpperTriangleMaskedSoftmaxCudaExtension + +ALL_EXTENSIONS = [ + CpuAdamArmExtension, + CpuAdamX86Extension, + LayerNormCudaExtension, + MoeCudaExtension, + FusedOptimizerCudaExtension, + ScaledMaskedSoftmaxCudaExtension, + ScaledUpperTriangleMaskedSoftmaxCudaExtension, + FlashAttentionDaoCudaExtension, + FlashAttentionXformersCudaExtension, + FlashAttentionNpuExtension, +] + +__all__ = [ + "CpuAdamArmExtension", + "CpuAdamX86Extension", + "LayerNormCudaExtension", + "MoeCudaExtension", + "FusedOptimizerCudaExtension", + "ScaledMaskedSoftmaxCudaExtension", + "ScaledUpperTriangleMaskedSoftmaxCudaExtension", + "FlashAttentionDaoCudaExtension", + "FlashAttentionXformersCudaExtension", + "FlashAttentionNpuExtension", +] diff --git a/extensions/base_extension.py b/extensions/base_extension.py new file mode 100644 index 000000000..a72b20cb3 --- /dev/null +++ b/extensions/base_extension.py @@ -0,0 +1,82 @@ +import hashlib +import os +from abc import ABC, abstractmethod +from typing import Union + +__all__ = ["_Extension"] + + +class _Extension(ABC): + def __init__(self, name: str, support_aot: bool, support_jit: bool, priority: int = 1): + self._name = name + self._support_aot = support_aot + self._support_jit = support_jit + self.priority = priority + + @property + def name(self): + return self._name + + @property + def support_aot(self): + return self._support_aot + + @property + def support_jit(self): + return self._support_jit + + @staticmethod + def get_jit_extension_folder_path(): + """ + Kernels which are compiled during runtime will be stored in the same cache folder for reuse. + The folder is in the path ~/.cache/colossalai/torch_extensions/. + The name of the follows a common format: + torch._- + + The suffix is the hash value of the path of the `colossalai` file. + """ + import torch + + import colossalai + from colossalai.accelerator import get_accelerator + + # get torch version + torch_version_major = torch.__version__.split(".")[0] + torch_version_minor = torch.__version__.split(".")[1] + + # get device version + device_name = get_accelerator().name + device_version = get_accelerator().get_version() + + # use colossalai's file path as hash + hash_suffix = hashlib.sha256(colossalai.__file__.encode()).hexdigest() + + # concat + home_directory = os.path.expanduser("~") + extension_directory = f".cache/colossalai/torch_extensions/torch{torch_version_major}.{torch_version_minor}_{device_name}-{device_version}-{hash_suffix}" + cache_directory = os.path.join(home_directory, extension_directory) + return cache_directory + + @abstractmethod + def is_hardware_available(self) -> bool: + """ + Check if the hardware required by the kernel is available. + """ + + @abstractmethod + def assert_hardware_compatible(self) -> bool: + """ + Check if the hardware required by the kernel is compatible. + """ + + @abstractmethod + def build_aot(self) -> Union["CppExtension", "CUDAExtension"]: + pass + + @abstractmethod + def build_jit(self) -> None: + pass + + @abstractmethod + def load(self): + pass diff --git a/extensions/cpp_extension.py b/extensions/cpp_extension.py new file mode 100644 index 000000000..b4c40c9f1 --- /dev/null +++ b/extensions/cpp_extension.py @@ -0,0 +1,134 @@ +import importlib +import os +import time +from abc import abstractmethod +from pathlib import Path +from typing import List + +from .base_extension import _Extension + +__all__ = ["_CppExtension"] + + +class _CppExtension(_Extension): + def __init__(self, name: str, priority: int = 1): + super().__init__(name, support_aot=True, support_jit=True, priority=priority) + + # we store the op as an attribute to avoid repeated building and loading + self.cached_op = None + + # build-related variables + self.prebuilt_module_path = "colossalai._C" + self.prebuilt_import_path = f"{self.prebuilt_module_path}.{self.name}" + self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] + + def csrc_abs_path(self, path): + return os.path.join(self.relative_to_abs_path("csrc"), path) + + def relative_to_abs_path(self, code_path: str) -> str: + """ + This function takes in a path relative to the colossalai root directory and return the absolute path. + """ + + # get the current file path + # iteratively check the parent directory + # if the parent directory is "extensions", then the current file path is the root directory + # otherwise, the current file path is inside the root directory + current_file_path = Path(__file__) + while True: + if current_file_path.name == "extensions": + break + else: + current_file_path = current_file_path.parent + extension_module_path = current_file_path + code_abs_path = extension_module_path.joinpath(code_path) + return str(code_abs_path) + + # functions must be overrided over + def strip_empty_entries(self, args): + """ + Drop any empty strings from the list of compile and link flags + """ + return [x for x in args if len(x) > 0] + + def import_op(self): + """ + This function will import the op module by its string name. + """ + return importlib.import_module(self.prebuilt_import_path) + + def build_aot(self) -> "CppExtension": + from torch.utils.cpp_extension import CppExtension + + return CppExtension( + name=self.prebuilt_import_path, + sources=self.strip_empty_entries(self.sources_files()), + include_dirs=self.strip_empty_entries(self.include_dirs()), + extra_compile_args=self.strip_empty_entries(self.cxx_flags()), + ) + + def build_jit(self) -> None: + from torch.utils.cpp_extension import load + + build_directory = _Extension.get_jit_extension_folder_path() + build_directory = Path(build_directory) + build_directory.mkdir(parents=True, exist_ok=True) + + # check if the kernel has been built + compiled_before = False + kernel_file_path = build_directory.joinpath(f"{self.name}.o") + if kernel_file_path.exists(): + compiled_before = True + + # load the kernel + if compiled_before: + print(f"[extension] Loading the JIT-built {self.name} kernel during runtime now") + else: + print(f"[extension] Compiling the JIT {self.name} kernel during runtime now") + + build_start = time.time() + op_kernel = load( + name=self.name, + sources=self.strip_empty_entries(self.sources_files()), + extra_include_paths=self.strip_empty_entries(self.include_dirs()), + extra_cflags=self.cxx_flags(), + extra_ldflags=[], + build_directory=str(build_directory), + ) + build_duration = time.time() - build_start + + if compiled_before: + print(f"[extension] Time taken to load {self.name} op: {build_duration} seconds") + else: + print(f"[extension] Time taken to compile {self.name} op: {build_duration} seconds") + + return op_kernel + + # functions must be overrided begin + @abstractmethod + def sources_files(self) -> List[str]: + """ + This function should return a list of source files for extensions. + """ + + @abstractmethod + def include_dirs(self) -> List[str]: + """ + This function should return a list of include files for extensions. + """ + + @abstractmethod + def cxx_flags(self) -> List[str]: + """ + This function should return a list of cxx compilation flags for extensions. + """ + + def load(self): + try: + op_kernel = self.import_op() + except ImportError: + # if import error occurs, it means that the kernel is not pre-built + # so we build it jit + op_kernel = self.build_jit() + + return op_kernel diff --git a/extensions/cpu_adam/__init__.py b/extensions/cpu_adam/__init__.py new file mode 100644 index 000000000..cfd26a6a2 --- /dev/null +++ b/extensions/cpu_adam/__init__.py @@ -0,0 +1,5 @@ +from .cpu_adam_arm import CpuAdamArmExtension +from .cpu_adam_x86 import CpuAdamX86Extension + +__all__ = ['CpuAdamArmExtension', 'CpuAdamX86Extension'] + diff --git a/extensions/cpu_adam/cpu_adam_arm.py b/extensions/cpu_adam/cpu_adam_arm.py new file mode 100644 index 000000000..35bff3b55 --- /dev/null +++ b/extensions/cpu_adam/cpu_adam_arm.py @@ -0,0 +1,41 @@ +import platform + +from ..cpp_extension import _CppExtension + + +class CpuAdamArmExtension(_CppExtension): + def __init__(self): + super().__init__(name="cpu_adam_arm") + + def is_hardware_available(self) -> bool: + # only arm allowed + return platform.machine() == "aarch64" + + def assert_hardware_compatible(self) -> None: + arch = platform.machine() + assert ( + arch == "aarch64" + ), f"[extension] The {self.name} kernel requires the CPU architecture to be aarch64 but got {arch}" + + # necessary 4 functions + def sources_files(self): + ret = [ + self.csrc_abs_path("arm/cpu_adam_arm.cpp"), + ] + return ret + + def include_dirs(self): + return [] + + def cxx_flags(self): + extra_cxx_flags = [ + "-std=c++14", + "-std=c++17", + "-g", + "-Wno-reorder", + "-fopenmp", + ] + return ["-O3"] + self.version_dependent_macros + extra_cxx_flags + + def nvcc_flags(self): + return [] diff --git a/op_builder/cpu_adam.py b/extensions/cpu_adam/cpu_adam_x86.py similarity index 60% rename from op_builder/cpu_adam.py rename to extensions/cpu_adam/cpu_adam_x86.py index 7988aae4b..a38194167 100644 --- a/op_builder/cpu_adam.py +++ b/extensions/cpu_adam/cpu_adam_x86.py @@ -1,19 +1,27 @@ -from .builder import Builder -from .utils import append_nvcc_threads +import platform + +from ..cuda_extension import _CudaExtension +from ..utils import append_nvcc_threads -class CPUAdamBuilder(Builder): - NAME = "cpu_adam" - PREBUILT_IMPORT_PATH = "colossalai._C.cpu_adam" - +class CpuAdamX86Extension(_CudaExtension): def __init__(self): - super().__init__(name=CPUAdamBuilder.NAME, prebuilt_import_path=CPUAdamBuilder.PREBUILT_IMPORT_PATH) - self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] + super().__init__(name="cpu_adam_x86") + + def is_hardware_available(self) -> bool: + return platform.machine() == "x86_64" and super().is_hardware_available() + + def assert_hardware_compatible(self) -> None: + arch = platform.machine() + assert ( + arch == "x86_64" + ), f"[extension] The {self.name} kernel requires the CPU architecture to be x86_64 but got {arch}" + super().assert_hardware_compatible() # necessary 4 functions def sources_files(self): ret = [ - self.csrc_abs_path("cpu_adam.cpp"), + self.csrc_abs_path("cuda/cpu_adam.cpp"), ] return ret diff --git a/colossalai/kernel/cuda_native/__init__.py b/extensions/csrc/__init__.py similarity index 100% rename from colossalai/kernel/cuda_native/__init__.py rename to extensions/csrc/__init__.py diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam_arm.cpp b/extensions/csrc/arm/cpu_adam_arm.cpp similarity index 100% rename from colossalai/kernel/cuda_native/csrc/cpu_adam_arm.cpp rename to extensions/csrc/arm/cpu_adam_arm.cpp diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam_arm.h b/extensions/csrc/arm/cpu_adam_arm.h similarity index 100% rename from colossalai/kernel/cuda_native/csrc/cpu_adam_arm.h rename to extensions/csrc/arm/cpu_adam_arm.h diff --git a/colossalai/kernel/cuda_native/csrc/colossal_C_frontend.cpp b/extensions/csrc/cuda/colossal_C_frontend.cpp similarity index 100% rename from colossalai/kernel/cuda_native/csrc/colossal_C_frontend.cpp rename to extensions/csrc/cuda/colossal_C_frontend.cpp diff --git a/colossalai/kernel/cuda_native/csrc/compat.h b/extensions/csrc/cuda/compat.h similarity index 100% rename from colossalai/kernel/cuda_native/csrc/compat.h rename to extensions/csrc/cuda/compat.h diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.cpp b/extensions/csrc/cuda/cpu_adam.cpp similarity index 100% rename from colossalai/kernel/cuda_native/csrc/cpu_adam.cpp rename to extensions/csrc/cuda/cpu_adam.cpp diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.h b/extensions/csrc/cuda/cpu_adam.h similarity index 100% rename from colossalai/kernel/cuda_native/csrc/cpu_adam.h rename to extensions/csrc/cuda/cpu_adam.h diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/block_reduce.h b/extensions/csrc/cuda/include/block_reduce.h similarity index 100% rename from colossalai/kernel/cuda_native/csrc/kernels/include/block_reduce.h rename to extensions/csrc/cuda/include/block_reduce.h diff --git a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp b/extensions/csrc/cuda/layer_norm_cuda.cpp similarity index 100% rename from colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp rename to extensions/csrc/cuda/layer_norm_cuda.cpp diff --git a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu b/extensions/csrc/cuda/layer_norm_cuda_kernel.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu rename to extensions/csrc/cuda/layer_norm_cuda_kernel.cu diff --git a/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp b/extensions/csrc/cuda/moe_cuda.cpp similarity index 100% rename from colossalai/kernel/cuda_native/csrc/moe_cuda.cpp rename to extensions/csrc/cuda/moe_cuda.cpp diff --git a/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu b/extensions/csrc/cuda/moe_cuda_kernel.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu rename to extensions/csrc/cuda/moe_cuda_kernel.cu diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu b/extensions/csrc/cuda/multi_tensor_adam.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu rename to extensions/csrc/cuda/multi_tensor_adam.cu diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_apply.cuh b/extensions/csrc/cuda/multi_tensor_apply.cuh similarity index 100% rename from colossalai/kernel/cuda_native/csrc/multi_tensor_apply.cuh rename to extensions/csrc/cuda/multi_tensor_apply.cuh diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu b/extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu rename to extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu b/extensions/csrc/cuda/multi_tensor_lamb.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu rename to extensions/csrc/cuda/multi_tensor_lamb.cu diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu b/extensions/csrc/cuda/multi_tensor_scale_kernel.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu rename to extensions/csrc/cuda/multi_tensor_scale_kernel.cu diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu b/extensions/csrc/cuda/multi_tensor_sgd_kernel.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu rename to extensions/csrc/cuda/multi_tensor_sgd_kernel.cu diff --git a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp b/extensions/csrc/cuda/scaled_masked_softmax.cpp similarity index 100% rename from colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp rename to extensions/csrc/cuda/scaled_masked_softmax.cpp diff --git a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h b/extensions/csrc/cuda/scaled_masked_softmax.h similarity index 100% rename from colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h rename to extensions/csrc/cuda/scaled_masked_softmax.h diff --git a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu b/extensions/csrc/cuda/scaled_masked_softmax_cuda.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/scaled_masked_softmax_cuda.cu rename to extensions/csrc/cuda/scaled_masked_softmax_cuda.cu diff --git a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.cpp b/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.cpp similarity index 100% rename from colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.cpp rename to extensions/csrc/cuda/scaled_upper_triang_masked_softmax.cpp diff --git a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h b/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h similarity index 100% rename from colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h rename to extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h diff --git a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu b/extensions/csrc/cuda/scaled_upper_triang_masked_softmax_cuda.cu similarity index 100% rename from colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax_cuda.cu rename to extensions/csrc/cuda/scaled_upper_triang_masked_softmax_cuda.cu diff --git a/colossalai/kernel/cuda_native/csrc/type_shim.h b/extensions/csrc/cuda/type_shim.h similarity index 100% rename from colossalai/kernel/cuda_native/csrc/type_shim.h rename to extensions/csrc/cuda/type_shim.h diff --git a/colossalai/kernel/cuda_native/scaled_softmax.py b/extensions/csrc/scaled_softmax.py similarity index 94% rename from colossalai/kernel/cuda_native/scaled_softmax.py rename to extensions/csrc/scaled_softmax.py index 26a5bce16..7c220d60d 100644 --- a/colossalai/kernel/cuda_native/scaled_softmax.py +++ b/extensions/csrc/scaled_softmax.py @@ -6,8 +6,7 @@ import enum import torch import torch.nn as nn -from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder -from colossalai.kernel.op_builder.scaled_upper_triangle_masked_softmax import ScaledUpperTrainglemaskedSoftmaxBuilder +from colossalai.kernel.kernel_loader import ScaledMaskedSoftmaxLoader, ScaledUpperTriangleMaskedSoftmaxLoader try: from colossalai._C import scaled_masked_softmax, scaled_upper_triang_masked_softmax @@ -35,7 +34,7 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): def forward(ctx, inputs, scale): global scaled_upper_triang_masked_softmax if scaled_upper_triang_masked_softmax: - scaled_upper_triang_masked_softmax = ScaledUpperTrainglemaskedSoftmaxBuilder().load() + scaled_upper_triang_masked_softmax = ScaledUpperTriangleMaskedSoftmaxLoader().load() scale_t = torch.tensor([scale]) softmax_results = scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0]) @@ -67,7 +66,7 @@ class ScaledMaskedSoftmax(torch.autograd.Function): # build and load kernel if not pre-built global scaled_masked_softmax if scaled_masked_softmax is None: - scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load() + scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load() softmax_results = scaled_masked_softmax.forward(inputs, mask, scale_t[0]) ctx.save_for_backward(softmax_results, scale_t) diff --git a/extensions/cuda_extension.py b/extensions/cuda_extension.py new file mode 100644 index 000000000..188e61b60 --- /dev/null +++ b/extensions/cuda_extension.py @@ -0,0 +1,106 @@ +import os +from abc import abstractmethod +from typing import List + +from .cpp_extension import _CppExtension +from .utils import check_pytorch_version, check_system_pytorch_cuda_match, set_cuda_arch_list + +__all__ = ["_CudaExtension"] + +# Some constants for installation checks +MIN_PYTORCH_VERSION_MAJOR = 1 +MIN_PYTORCH_VERSION_MINOR = 10 + + +class _CudaExtension(_CppExtension): + @abstractmethod + def nvcc_flags(self) -> List[str]: + """ + This function should return a list of nvcc compilation flags for extensions. + """ + + def is_hardware_available(self) -> bool: + # cuda extension can only be built if cuda is availabe + try: + import torch + + cuda_available = torch.cuda.is_available() + except: + cuda_available = False + return cuda_available + + def assert_hardware_compatible(self) -> None: + from torch.utils.cpp_extension import CUDA_HOME + + if not CUDA_HOME: + raise AssertionError( + "[extension] CUDA_HOME is not found. You need to export CUDA_HOME environment variable or install CUDA Toolkit first in order to build/load CUDA extensions" + ) + check_system_pytorch_cuda_match(CUDA_HOME) + check_pytorch_version(MIN_PYTORCH_VERSION_MAJOR, MIN_PYTORCH_VERSION_MINOR) + + def get_cuda_home_include(self): + """ + return include path inside the cuda home. + """ + from torch.utils.cpp_extension import CUDA_HOME + + if CUDA_HOME is None: + raise RuntimeError("CUDA_HOME is None, please set CUDA_HOME to compile C++/CUDA kernels in ColossalAI.") + cuda_include = os.path.join(CUDA_HOME, "include") + return cuda_include + + def build_jit(self) -> None: + from torch.utils.cpp_extension import CUDA_HOME, load + + set_cuda_arch_list(CUDA_HOME) + + # get build dir + build_directory = _Extension.get_jit_extension_folder_path() + build_directory = Path(build_directory) + build_directory.mkdir(parents=True, exist_ok=True) + + # check if the kernel has been built + compiled_before = False + kernel_file_path = build_directory.joinpath(f"{self.name}.o") + if kernel_file_path.exists(): + compiled_before = True + + # load the kernel + if compiled_before: + print(f"[extension] Loading the JIT-built {self.name} kernel during runtime now") + else: + print(f"[extension] Compiling the JIT {self.name} kernel during runtime now") + + build_start = time.time() + op_kernel = load( + name=self.name, + sources=self.strip_empty_entries(self.sources_files()), + extra_include_paths=self.strip_empty_entries(self.include_dirs()), + extra_cflags=self.cxx_flags(), + extra_cuda_cflags=self.nvcc_flags(), + extra_ldflags=[], + build_directory=str(build_directory), + ) + build_duration = time.time() - build_start + + if compiled_before: + print(f"[extension] Time taken to load {self.name} op: {build_duration} seconds") + else: + print(f"[extension] Time taken to compile {self.name} op: {build_duration} seconds") + + return op_kernel + + def build_aot(self) -> "CUDAExtension": + from torch.utils.cpp_extension import CUDA_HOME, CUDAExtension + + set_cuda_arch_list(CUDA_HOME) + return CUDAExtension( + name=self.prebuilt_import_path, + sources=self.strip_empty_entries(self.sources_files()), + include_dirs=self.strip_empty_entries(self.include_dirs()), + extra_compile_args={ + "cxx": self.strip_empty_entries(self.cxx_flags()), + "nvcc": self.strip_empty_entries(self.nvcc_flags()), + }, + ) diff --git a/extensions/flash_attention/__init__.py b/extensions/flash_attention/__init__.py new file mode 100644 index 000000000..18abb6191 --- /dev/null +++ b/extensions/flash_attention/__init__.py @@ -0,0 +1,20 @@ +from .flash_attention_dao_cuda import FlashAttentionDaoCudaExtension +from .flash_attention_npu import FlashAttentionNpuExtension +from .flash_attention_xformers_cuda import FlashAttentionXformersCudaExtension + +try: + import flash_attention # noqa + + HAS_FLASH_ATTN = True +except: + HAS_FLASH_ATTN = False + +try: + import xformers # noqa + + HAS_MEM_EFF_ATTN = True +except: + HAS_MEM_EFF_ATTN = False + + +__all__ = ["FlashAttentionDaoCudaExtension", "FlashAttentionXformersCudaExtension", "FlashAttentionNpuExtension"] diff --git a/extensions/flash_attention/flash_attention_dao_cuda.py b/extensions/flash_attention/flash_attention_dao_cuda.py new file mode 100644 index 000000000..6e2a9a880 --- /dev/null +++ b/extensions/flash_attention/flash_attention_dao_cuda.py @@ -0,0 +1,93 @@ +from ..base_extension import _Extension + + +class FlashAttentionDaoCudaExtension(_Extension): + def __init__(self): + super().__init__(name="flash_attention_dao_cuda", support_aot=False, support_jit=False, priority=10) + + def is_hardware_available(self) -> bool: + # cuda extension can only be built if cuda is availabe + try: + import torch + + cuda_available = torch.cuda.is_available() + except: + cuda_available = False + return cuda_available + + def assert_hardware_compatible(self) -> bool: + pass + + def build_aot(self) -> None: + raise NotImplementedError( + "We rely on the third-party flash-attn library for flash attention (https://github.com/Dao-AILab/flash-attention). Please install flash-attn via 'pip install flash-attn --no-build-isolation'." + ) + + def build_jit(self) -> None: + raise NotImplementedError( + "We rely on the third-party flash-attn library for flash attention (https://github.com/Dao-AILab/flash-attention). Please install flash-attn via 'pip install flash-attn --no-build-isolation'" + ) + + def load(self): + try: + from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func + except ImportError: + raise ModuleNotFoundError( + ( + "We rely on the third-party flash-attn library for flash attention. Please install flash-attn via 'pip install flash-attn --no-build-isolation'" + ) + ) + + from typing import Optional + + import torch + + def flash_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_len_info_q: "SeqLenInfo", + seq_len_info_kv: "SeqLenInfo", + origin_attn_mask: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + scale: float = None, + causal: bool = False, + padded: bool = False, + ): + """ + Arguments: + q: (batch, q_seqlen, nheads, headdim) + k: (batch, kv_seqlen, nheads, headdim) + v: (batch, kv_seqlen, nheads, headdim) + batch_size: int. + seq_len: int. + dropout_p: float. Dropout probability. + sm_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + Return: + attn_out: (batch, q_seqlen, nheads, headdim). + """ + # check if the input is in allowed dtypes + if padded: + if seq_len_info_kv == None: + seq_len_info_kv = seq_len_info_q + + attn_out = flash_attn_varlen_func( + q, + k, + v, + seq_len_info_q.cu_seqlens, + seq_len_info_kv.cu_seqlens, + seq_len_info_q.max_seqlen, + seq_len_info_kv.max_seqlen, + dropout_p, + scale, + causal, + ) + else: + attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal) + return attn_out + + return flash_attention diff --git a/extensions/flash_attention/flash_attention_npu.py b/extensions/flash_attention/flash_attention_npu.py new file mode 100644 index 000000000..58d0f9306 --- /dev/null +++ b/extensions/flash_attention/flash_attention_npu.py @@ -0,0 +1,73 @@ +from ..base_extension import _Extension + + +class FlashAttentionNpuExtension(_Extension): + def __init__(self): + super().__init__(name="flash_attention_npu", support_aot=False, support_jit=False) + + def is_hardware_available(self) -> bool: + try: + import torch_npu # noqa + + return True + except: + return False + + def assert_hardware_compatible(self) -> bool: + pass + + def build_aot(self) -> None: + raise NotImplementedError( + "Flash Attention NPU does not require ahead-of-time compilation. Please use it by installing torch_npu." + ) + + def build_jit(self) -> None: + raise NotImplementedError( + "Flash Attention NPU does not require just-in-time compilation. Please use it by installing torch_npu." + ) + + def load(self): + import torch + from einops import rearrange + + def npu_sdpa_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_len_info_q=None, + seq_len_info_kv=None, + origin_attn_mask: torch.Tensor = None, + dropout_p: float = 0.0, + scale: float = 1.0, + causal=None, + padded=None, + ): + """ + The scaled dot product attention. + + Arguments: + q: (batch, q_seqlen, nheads, headdim) + k: (batch, kv_seqlen, nheads, headdim) + v: (batch, kv_seqlen, nheads, headdim) + batch_size: int. + seq_len: int. + dropout_p: float. Dropout probability. + scale: float. The scaling of QK^T before applying softmax. + Default to 1. + Return: + attn_out: (batch, q_seqlen, nheads, headdim). + """ + q, k, v = [rearrange(x, "b s h d -> b h s d").contiguous() for x in (q, k, v)] + output = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=origin_attn_mask, + dropout_p=dropout_p, + is_causal=origin_attn_mask is None, + scale=scale, + ) + output = rearrange(output, "b h s d -> b s (h d)") + return output + + return npu_sdpa_attention diff --git a/extensions/flash_attention/flash_attention_xformers_cuda.py b/extensions/flash_attention/flash_attention_xformers_cuda.py new file mode 100644 index 000000000..737b8599f --- /dev/null +++ b/extensions/flash_attention/flash_attention_xformers_cuda.py @@ -0,0 +1,94 @@ +from ..base_extension import _Extension + + +class FlashAttentionXformersCudaExtension(_Extension): + def __init__(self): + super().__init__(name="flash_attention_xformers_cuda", support_aot=False, support_jit=False) + + def is_hardware_available(self) -> bool: + # cuda extension can only be built if cuda is availabe + try: + import torch + + cuda_available = torch.cuda.is_available() + except: + cuda_available = False + return cuda_available + + def assert_hardware_compatible(self) -> bool: + pass + + def build_aot(self) -> None: + raise NotImplementedError( + "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme." + ) + + def build_jit(self) -> None: + raise NotImplementedError( + "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme." + ) + + def load(self): + try: + from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention + from xformers.ops.fmha.attn_bias import ( + BlockDiagonalCausalMask, + BlockDiagonalMask, + LowerTriangularMask, + LowerTriangularMaskWithTensorBias, + ) + except ImportError: + raise ModuleNotFoundError( + ( + "We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme." + ) + ) + from typing import Optional + + import torch + + allow_alibi = True + for op in MemoryEfficientAttentionCutlassOp: + allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES) + + def mem_eff_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_len_info_q: "SeqLenInfo", + seq_len_info_kv: "SeqLenInfo", + origin_attn_mask: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + scale: float = None, + causal: bool = False, + padded: bool = False, + ): + attn_bias = None + if padded: # bert style + if not causal: + attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens) + else: + attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens) + elif causal: # gpt style + attn_bias = LowerTriangularMask() + + if bias is not None: # alibi / relative position embedding + assert allow_alibi, "flash attention with bias is not supported in this system." + assert causal, "attention with bias is only supported for causal attention so far." + attn_bias = attn_bias.add_bias(bias) + + if padded: + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + + out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale) + + # shape: (b*s, n, d) + if padded: + out = out.squeeze(0) + + return out + + return mem_eff_attention diff --git a/extensions/layernorm/__init__.py b/extensions/layernorm/__init__.py new file mode 100644 index 000000000..9d1bd2d01 --- /dev/null +++ b/extensions/layernorm/__init__.py @@ -0,0 +1,3 @@ +from .layernorm_cuda import LayerNormCudaExtension + +__all__ = ["LayerNormCudaExtension"] \ No newline at end of file diff --git a/extensions/layernorm/layernorm_cuda.py b/extensions/layernorm/layernorm_cuda.py new file mode 100644 index 000000000..db5f2fce1 --- /dev/null +++ b/extensions/layernorm/layernorm_cuda.py @@ -0,0 +1,24 @@ +from ..cuda_extension import _CudaExtension +from ..utils import append_nvcc_threads, get_cuda_cc_flag + + +class LayerNormCudaExtension(_CudaExtension): + def __init__(self): + super().__init__(name="layernorm_cuda") + + def sources_files(self): + ret = [self.csrc_abs_path(fname) for fname in ["cuda/layer_norm_cuda.cpp", "cuda/layer_norm_cuda_kernel.cu"]] + return ret + + def include_dirs(self): + ret = [self.get_cuda_home_include()] + return ret + + def cxx_flags(self): + return ["-O3"] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = ["-maxrregcount=50"] + extra_cuda_flags.extend(get_cuda_cc_flag()) + ret = ["-O3", "--use_fast_math"] + extra_cuda_flags + self.version_dependent_macros + return append_nvcc_threads(ret) diff --git a/extensions/moe/__init__.py b/extensions/moe/__init__.py new file mode 100644 index 000000000..962084d4b --- /dev/null +++ b/extensions/moe/__init__.py @@ -0,0 +1,3 @@ +from .moe_cuda import MoeCudaExtension + +__all__ = ['MoeCudaExtension'] \ No newline at end of file diff --git a/op_builder/moe.py b/extensions/moe/moe_cuda.py similarity index 56% rename from op_builder/moe.py rename to extensions/moe/moe_cuda.py index 6f8028b17..52883e97f 100644 --- a/op_builder/moe.py +++ b/extensions/moe/moe_cuda.py @@ -1,20 +1,17 @@ -from .builder import Builder -from .utils import append_nvcc_threads, get_cuda_cc_flag +from ..cuda_extension import _CudaExtension +from ..utils import append_nvcc_threads, get_cuda_cc_flag -class MOEBuilder(Builder): - NAME = "moe" - PREBUILT_IMPORT_PATH = "colossalai._C.moe" - +class MoeCudaExtension(_CudaExtension): def __init__(self): - super().__init__(name=MOEBuilder.NAME, prebuilt_import_path=MOEBuilder.PREBUILT_IMPORT_PATH) + super().__init__(name="moe_cuda") def include_dirs(self): - ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] + ret = [self.csrc_abs_path("cuda/include"), self.get_cuda_home_include()] return ret def sources_files(self): - ret = [self.csrc_abs_path(fname) for fname in ["moe_cuda.cpp", "moe_cuda_kernel.cu"]] + ret = [self.csrc_abs_path(fname) for fname in ["cuda/moe_cuda.cpp", "cuda/moe_cuda_kernel.cu"]] return ret def cxx_flags(self): diff --git a/extensions/optimizer/__init__.py b/extensions/optimizer/__init__.py new file mode 100644 index 000000000..9c8e87cae --- /dev/null +++ b/extensions/optimizer/__init__.py @@ -0,0 +1,3 @@ +from .fused_optimizer_cuda import FusedOptimizerCudaExtension + +__all__ = ['FusedOptimizerCudaExtension'] \ No newline at end of file diff --git a/extensions/optimizer/fused_optimizer_cuda.py b/extensions/optimizer/fused_optimizer_cuda.py new file mode 100644 index 000000000..e065cf34a --- /dev/null +++ b/extensions/optimizer/fused_optimizer_cuda.py @@ -0,0 +1,34 @@ +from ..cuda_extension import _CudaExtension +from ..utils import get_cuda_cc_flag + + +class FusedOptimizerCudaExtension(_CudaExtension): + def __init__(self): + super().__init__(name="fused_optim_cuda") + + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) + for fname in [ + "cuda/colossal_C_frontend.cpp", + "cuda/multi_tensor_sgd_kernel.cu", + "cuda/multi_tensor_scale_kernel.cu", + "cuda/multi_tensor_adam.cu", + "cuda/multi_tensor_l2norm_kernel.cu", + "cuda/multi_tensor_lamb.cu", + ] + ] + return ret + + def include_dirs(self): + ret = [self.get_cuda_home_include()] + return ret + + def cxx_flags(self): + version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] + return ["-O3"] + version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = ["-lineinfo"] + extra_cuda_flags.extend(get_cuda_cc_flag()) + return ["-O3", "--use_fast_math"] + extra_cuda_flags diff --git a/extensions/softmax/__init__.py b/extensions/softmax/__init__.py new file mode 100644 index 000000000..9bc50c6cd --- /dev/null +++ b/extensions/softmax/__init__.py @@ -0,0 +1,4 @@ +from .scaled_masked_softmax_cuda import ScaledMaskedSoftmaxCudaExtension +from .scaled_upper_triangle_masked_softmax_cuda import ScaledUpperTriangleMaskedSoftmaxCudaExtension + +__all__ = ['ScaledMaskedSoftmaxCudaExtension', 'ScaledUpperTriangleMaskedSoftmaxCudaExtension'] \ No newline at end of file diff --git a/op_builder/scaled_masked_softmax.py b/extensions/softmax/scaled_masked_softmax_cuda.py similarity index 50% rename from op_builder/scaled_masked_softmax.py rename to extensions/softmax/scaled_masked_softmax_cuda.py index d9239a80e..5b4208dba 100644 --- a/op_builder/scaled_masked_softmax.py +++ b/extensions/softmax/scaled_masked_softmax_cuda.py @@ -1,23 +1,20 @@ -from .builder import Builder -from .utils import append_nvcc_threads +from ..cuda_extension import _CudaExtension +from ..utils import append_nvcc_threads -class ScaledMaskedSoftmaxBuilder(Builder): - NAME = "scaled_masked_softmax" - PREBUILT_IMPORT_PATH = "colossalai._C.scaled_masked_softmax" - +class ScaledMaskedSoftmaxCudaExtension(_CudaExtension): def __init__(self): - super().__init__( - name=ScaledMaskedSoftmaxBuilder.NAME, prebuilt_import_path=ScaledMaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH - ) + super().__init__(name="scaled_masked_softmax_cuda") - # necessary 4 functions def sources_files(self): - ret = [self.csrc_abs_path(fname) for fname in ["scaled_masked_softmax.cpp", "scaled_masked_softmax_cuda.cu"]] + ret = [ + self.csrc_abs_path(fname) + for fname in ["cuda/scaled_masked_softmax.cpp", "cuda/scaled_masked_softmax_cuda.cu"] + ] return ret def include_dirs(self): - return [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] + return [self.get_cuda_home_include()] def cxx_flags(self): return ["-O3"] + self.version_dependent_macros diff --git a/extensions/softmax/scaled_upper_triangle_masked_softmax_cuda.py b/extensions/softmax/scaled_upper_triangle_masked_softmax_cuda.py new file mode 100644 index 000000000..d4f27a921 --- /dev/null +++ b/extensions/softmax/scaled_upper_triangle_masked_softmax_cuda.py @@ -0,0 +1,34 @@ +from ..cuda_extension import _CudaExtension +from ..utils import append_nvcc_threads, get_cuda_cc_flag + + +class ScaledUpperTriangleMaskedSoftmaxCudaExtension(_CudaExtension): + def __init__(self): + super().__init__(name="scaled_upper_triangle_masked_softmax_cuda") + + def include_dirs(self): + return [self.get_cuda_home_include()] + + def sources_files(self): + ret = [ + self.csrc_abs_path(fname) + for fname in [ + "cuda/scaled_upper_triang_masked_softmax.cpp", + "cuda/scaled_upper_triang_masked_softmax_cuda.cu", + ] + ] + return ret + + def cxx_flags(self): + return ["-O3"] + self.version_dependent_macros + + def nvcc_flags(self): + extra_cuda_flags = [ + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + ] + extra_cuda_flags.extend(get_cuda_cc_flag()) + ret = ["-O3", "--use_fast_math"] + extra_cuda_flags + return append_nvcc_threads(ret) diff --git a/extensions/triton_extension.py b/extensions/triton_extension.py new file mode 100644 index 000000000..b3f61644e --- /dev/null +++ b/extensions/triton_extension.py @@ -0,0 +1,21 @@ +from .base_extension import _Extension + +__all__ = ["_TritonExtension"] + + +class _TritonExtension(_Extension): + def __init__(self, name: str, priority: int = 1): + super().__init__(name, support_aot=False, support_jit=True, priority=priority) + + def is_hardware_compatible(self) -> bool: + # cuda extension can only be built if cuda is availabe + try: + import torch + + cuda_available = torch.cuda.is_available() + except: + cuda_available = False + return cuda_available + + def load(self): + return self.build_jit() diff --git a/colossalai/kernel/extensions/utils.py b/extensions/utils.py similarity index 100% rename from colossalai/kernel/extensions/utils.py rename to extensions/utils.py diff --git a/op_builder/README.md b/op_builder/README.md deleted file mode 100644 index 9c33a4a32..000000000 --- a/op_builder/README.md +++ /dev/null @@ -1,32 +0,0 @@ -# Build PyTorch Extensions - -## Overview - -Building PyTorch extensions can be a difficult task for users not from the system background. It is definitely frustrating if the users encounter many strange technical jargons when install Colossal-AI. Therefore, we will provide two methods of building the PyTorch extensions for the users. - -1. Build CUDA extensions when running `pip install` if `CUDA_EXT=1` -2. Build the extension during runtime - -The first method is more suitable for users who are familiar with CUDA environment configurations. The second method is for those who are not as they only need to build the kernel which is required by their program. - -These two methods have different advantages and disadvantages. -Method 1 is good because it allows the user to build all kernels during installation and directly import the kernel. They don't need to care about kernel building when running their program. However, installation may fail if they don't know how to configure their environments and this leads to much frustration. -Method 2 is good because it allows the user to only build the kernel they actually need, such that there is a lower probability that they encounter environment issue. However, it may slow down their program due to the first build and subsequence load. - -## PyTorch Extensions in Colossal-AI - -The project [DeepSpeed](https://github.com/microsoft/DeepSpeed) has proposed a [solution](https://github.com/microsoft/DeepSpeed/tree/master/op_builder) to support kernel-build during either installation or runtime. -We have adapted from DeepSpeed's solution to build extensions. The extension build requires two main functions from PyTorch: - -1. `torch.utils.cpp_extension.CUDAExtension`: used to build extensions in `setup.py` during `pip install`. -2. `torch.utils.cpp_extension.load`: used to build and load extension during runtime - -Please note that the extension build by `CUDAExtension` cannot be loaded by the `load` function and `load` will run its own build again (correct me if I am wrong). - -Based on the DeepSpeed's work, we have make several modifications and improvements: - -1. All pre-built kernels (those installed with `setup.py`) will be found in `colossalai._C` -2. All runtime-built kernels will be found in the default torch extension path, i.e. ~/.cache/colossalai/torch_extensions. (If we put the built kernels in the installed site-package directory, this will make pip uninstall incomplete) -3. Once a kernel is loaded, we will cache it in the builder to avoid repeated kernel loading. - -When loading the built kernel, we will first check if the pre-built one exists. If not, the runtime build will be triggered. diff --git a/op_builder/__init__.py b/op_builder/__init__.py deleted file mode 100644 index 21e216437..000000000 --- a/op_builder/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -from .arm_cpu_adam import ArmCPUAdamBuilder -from .cpu_adam import CPUAdamBuilder -from .fused_optim import FusedOptimBuilder -from .layernorm import LayerNormBuilder -from .moe import MOEBuilder -from .multi_head_attn import MultiHeadAttnBuilder -from .scaled_masked_softmax import ScaledMaskedSoftmaxBuilder -from .scaled_upper_triangle_masked_softmax import ScaledUpperTrainglemaskedSoftmaxBuilder - -ALL_OPS = { - "cpu_adam": CPUAdamBuilder, - "fused_optim": FusedOptimBuilder, - "moe": MOEBuilder, - "multi_head_attn": MultiHeadAttnBuilder, - "scaled_masked_softmax": ScaledMaskedSoftmaxBuilder, - "scaled_upper_triangle_masked_softmax": ScaledUpperTrainglemaskedSoftmaxBuilder, - "layernorm": LayerNormBuilder, -} - -__all__ = [ - "ALL_OPS", - "CPUAdamBuilder", - "FusedOptimBuilder", - "MultiHeadAttnBuilder", - "ScaledMaskedSoftmaxBuilder", - "ScaledUpperTrainglemaskedSoftmaxBuilder", - "MOEBuilder", - "MultiTensorSGDBuilder", - "MultiTensorAdamBuilder", - "MultiTensorLambBuilder", - "MultiTensorScaleBuilder", - "MultiTensorL2NormBuilder", - "ArmCPUAdamBuilder", -] diff --git a/op_builder/arm_cpu_adam.py b/op_builder/arm_cpu_adam.py deleted file mode 100644 index 18dd519fa..000000000 --- a/op_builder/arm_cpu_adam.py +++ /dev/null @@ -1,34 +0,0 @@ -from .builder import Builder - - -class ArmCPUAdamBuilder(Builder): - NAME = "arm_cpu_adam" - PREBUILT_IMPORT_PATH = "colossalai._C.arm_cpu_adam" - ext_type = "cpu" - - def __init__(self): - super().__init__(name=ArmCPUAdamBuilder.NAME, prebuilt_import_path=ArmCPUAdamBuilder.PREBUILT_IMPORT_PATH) - self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] - - # necessary 4 functions - def sources_files(self): - ret = [ - self.csrc_abs_path("cpu_adam_arm.cpp"), - ] - return ret - - def include_dirs(self): - return [self.csrc_abs_path("includes")] - - def cxx_flags(self): - extra_cxx_flags = [ - "-std=c++14", - "-std=c++17", - "-g", - "-Wno-reorder", - "-fopenmp", - ] - return ["-O3"] + self.version_dependent_macros + extra_cxx_flags - - def nvcc_flags(self): - return [] diff --git a/op_builder/builder.py b/op_builder/builder.py deleted file mode 100644 index d804cb160..000000000 --- a/op_builder/builder.py +++ /dev/null @@ -1,236 +0,0 @@ -# This code has been adapted from the DeepSpeed library. -# Copyright (c) Microsoft Corporation. - -# Licensed under the MIT License. -import importlib -import os -import time -from abc import ABC, abstractmethod -from pathlib import Path -from typing import List, Optional, Union - -from .utils import check_cuda_availability, check_system_pytorch_cuda_match, print_rank_0 - - -class Builder(ABC): - """ - Builder is the base class to build extensions for PyTorch. - - Args: - name (str): the name of the kernel to be built - prebuilt_import_path (str): the path where the extension is installed during pip install - """ - - ext_type: str = "cuda" - - def __init__(self, name: str, prebuilt_import_path: str): - self.name = name - self.prebuilt_import_path = prebuilt_import_path - self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] - - # we store the op as an attribute to avoid repeated building and loading - self.cached_op_module = None - - assert prebuilt_import_path.startswith( - "colossalai._C" - ), f"The prebuilt_import_path should start with colossalai._C, but got {self.prebuilt_import_path}" - - def relative_to_abs_path(self, code_path: str) -> str: - """ - This function takes in a path relative to the colossalai root directory and return the absolute path. - """ - op_builder_module_path = Path(__file__).parent - - # if we install from source - # the current file path will be op_builder/builder.py - # if we install via pip install colossalai - # the current file path will be colossalai/kernel/op_builder/builder.py - # this is because that the op_builder inside colossalai is a symlink - # this symlink will be replaced with actual files if we install via pypi - # thus we cannot tell the colossalai root directory by checking whether the op_builder - # is a symlink, we can only tell whether it is inside or outside colossalai - if str(op_builder_module_path).endswith("colossalai/kernel/op_builder"): - root_path = op_builder_module_path.parent.parent - else: - root_path = op_builder_module_path.parent.joinpath("colossalai") - - code_abs_path = root_path.joinpath(code_path) - return str(code_abs_path) - - def get_cuda_home_include(self): - """ - return include path inside the cuda home. - """ - from torch.utils.cpp_extension import CUDA_HOME - - if CUDA_HOME is None: - raise RuntimeError("CUDA_HOME is None, please set CUDA_HOME to compile C++/CUDA kernels in ColossalAI.") - cuda_include = os.path.join(CUDA_HOME, "include") - return cuda_include - - def csrc_abs_path(self, path): - return os.path.join(self.relative_to_abs_path("kernel/cuda_native/csrc"), path) - - # functions must be overrided begin - @abstractmethod - def sources_files(self) -> List[str]: - """ - This function should return a list of source files for extensions. - """ - raise NotImplementedError - - @abstractmethod - def include_dirs(self) -> List[str]: - """ - This function should return a list of include files for extensions. - """ - - @abstractmethod - def cxx_flags(self) -> List[str]: - """ - This function should return a list of cxx compilation flags for extensions. - """ - - @abstractmethod - def nvcc_flags(self) -> List[str]: - """ - This function should return a list of nvcc compilation flags for extensions. - """ - - # functions must be overrided over - def strip_empty_entries(self, args): - """ - Drop any empty strings from the list of compile and link flags - """ - return [x for x in args if len(x) > 0] - - def import_op(self): - """ - This function will import the op module by its string name. - """ - return importlib.import_module(self.prebuilt_import_path) - - def check_runtime_build_environment(self): - """ - Check whether the system environment is ready for extension compilation. - """ - try: - from torch.utils.cpp_extension import CUDA_HOME - - TORCH_AVAILABLE = True - except ImportError: - TORCH_AVAILABLE = False - CUDA_HOME = None - - if not TORCH_AVAILABLE: - raise ModuleNotFoundError( - "PyTorch is not found. You need to install PyTorch first in order to build CUDA extensions" - ) - - if CUDA_HOME is None: - raise RuntimeError( - "CUDA_HOME is not found. You need to export CUDA_HOME environment variable or install CUDA Toolkit first in order to build CUDA extensions" - ) - - # make sure CUDA is available for compilation during - cuda_available = check_cuda_availability() - if not cuda_available: - raise RuntimeError("CUDA is not available on your system as torch.cuda.is_available() returns False.") - - # make sure system CUDA and pytorch CUDA match, an error will raised inside the function if not - check_system_pytorch_cuda_match(CUDA_HOME) - - def load(self, verbose: Optional[bool] = None): - """ - load the kernel during runtime. If the kernel is not built during pip install, it will build the kernel. - If the kernel is built during runtime, it will be stored in `~/.cache/colossalai/torch_extensions/`. If the - kernel is built during pip install, it can be accessed through `colossalai._C`. - - Warning: do not load this kernel repeatedly during model execution as it could slow down the training process. - - Args: - verbose (bool, optional): show detailed info. Defaults to True. - """ - if verbose is None: - verbose = os.environ.get("CAI_KERNEL_VERBOSE", "0") == "1" - # if the kernel has be compiled and cached, we directly use it - if self.cached_op_module is not None: - return self.cached_op_module - - try: - # if the kernel has been pre-built during installation - # we just directly import it - op_module = self.import_op() - if verbose: - print_rank_0( - f"[extension] OP {self.prebuilt_import_path} has been compiled ahead of time, skip building." - ) - except ImportError: - # check environment - if self.ext_type == "cuda": - self.check_runtime_build_environment() - - # time the kernel compilation - start_build = time.time() - - # construct the build directory - import torch - from torch.utils.cpp_extension import load - - torch_version_major = torch.__version__.split(".")[0] - torch_version_minor = torch.__version__.split(".")[1] - torch_cuda_version = torch.version.cuda - home_directory = os.path.expanduser("~") - extension_directory = f".cache/colossalai/torch_extensions/torch{torch_version_major}.{torch_version_minor}_cu{torch_cuda_version}" - build_directory = os.path.join(home_directory, extension_directory) - Path(build_directory).mkdir(parents=True, exist_ok=True) - - if verbose: - print_rank_0(f"[extension] Compiling or loading the JIT-built {self.name} kernel during runtime now") - - # load the kernel - op_module = load( - name=self.name, - sources=self.strip_empty_entries(self.sources_files()), - extra_include_paths=self.strip_empty_entries(self.include_dirs()), - extra_cflags=self.cxx_flags(), - extra_cuda_cflags=self.nvcc_flags(), - extra_ldflags=[], - build_directory=build_directory, - verbose=verbose, - ) - - build_duration = time.time() - start_build - - # log jit compilation time - if verbose: - print_rank_0(f"[extension] Time to compile or load {self.name} op: {build_duration} seconds") - - # cache the built/loaded kernel - self.cached_op_module = op_module - - return op_module - - def builder(self) -> Union["CUDAExtension", "CppExtension"]: - """ - get a CUDAExtension instance used for setup.py - """ - from torch.utils.cpp_extension import CppExtension, CUDAExtension - - if self.ext_type == "cpp": - return CppExtension( - name=self.prebuilt_import_path, - sources=self.strip_empty_entries(self.sources_files()), - include_dirs=self.strip_empty_entries(self.include_dirs()), - extra_compile_args=self.strip_empty_entries(self.cxx_flags()), - ) - - return CUDAExtension( - name=self.prebuilt_import_path, - sources=self.strip_empty_entries(self.sources_files()), - include_dirs=self.strip_empty_entries(self.include_dirs()), - extra_compile_args={ - "cxx": self.strip_empty_entries(self.cxx_flags()), - "nvcc": self.strip_empty_entries(self.nvcc_flags()), - }, - ) diff --git a/op_builder/fused_optim.py b/op_builder/fused_optim.py deleted file mode 100644 index 3baa0880d..000000000 --- a/op_builder/fused_optim.py +++ /dev/null @@ -1,37 +0,0 @@ -from .builder import Builder -from .utils import get_cuda_cc_flag - - -class FusedOptimBuilder(Builder): - NAME = "fused_optim" - PREBUILT_IMPORT_PATH = "colossalai._C.fused_optim" - - def __init__(self): - super().__init__(name=FusedOptimBuilder.NAME, prebuilt_import_path=FusedOptimBuilder.PREBUILT_IMPORT_PATH) - - def sources_files(self): - ret = [ - self.csrc_abs_path(fname) - for fname in [ - "colossal_C_frontend.cpp", - "multi_tensor_sgd_kernel.cu", - "multi_tensor_scale_kernel.cu", - "multi_tensor_adam.cu", - "multi_tensor_l2norm_kernel.cu", - "multi_tensor_lamb.cu", - ] - ] - return ret - - def include_dirs(self): - ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] - return ret - - def cxx_flags(self): - version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] - return ["-O3"] + version_dependent_macros - - def nvcc_flags(self): - extra_cuda_flags = ["-lineinfo"] - extra_cuda_flags.extend(get_cuda_cc_flag()) - return ["-O3", "--use_fast_math"] + extra_cuda_flags diff --git a/op_builder/gptq.py b/op_builder/gptq.py deleted file mode 100644 index a17801f87..000000000 --- a/op_builder/gptq.py +++ /dev/null @@ -1,56 +0,0 @@ -import re - -import torch - -from .builder import Builder -from .utils import append_nvcc_threads - - -class GPTQBuilder(Builder): - NAME = "cu_gptq" - PREBUILT_IMPORT_PATH = "colossalai._C.cu_gptq" - - def __init__(self): - super().__init__(name=GPTQBuilder.NAME, prebuilt_import_path=GPTQBuilder.PREBUILT_IMPORT_PATH) - - def include_dirs(self): - ret = [self.csrc_abs_path("gptq"), self.get_cuda_home_include()] - return ret - - def sources_files(self): - ret = [ - self.csrc_abs_path(fname) - for fname in [ - "gptq/linear_gptq.cpp", - "gptq/column_remap.cu", - "gptq/cuda_buffers.cu", - "gptq/q4_matmul.cu", - "gptq/q4_matrix.cu", - ] - ] - return ret - - def cxx_flags(self): - return ["-O3"] + self.version_dependent_macros - - def nvcc_flags(self): - extra_cuda_flags = [ - "-v", - "-std=c++14", - "-std=c++17", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_HALF2_OPERATORS__", - "-DTHRUST_IGNORE_CUB_VERSION_CHECK", - "-lcublas", - ] - - for arch in torch.cuda.get_arch_list(): - res = re.search(r"sm_(\d+)", arch) - if res: - arch_cap = res[1] - if int(arch_cap) >= 80: - extra_cuda_flags.extend(["-gencode", f"arch=compute_{arch_cap},code={arch}"]) - - ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags - return append_nvcc_threads(ret) diff --git a/op_builder/layernorm.py b/op_builder/layernorm.py deleted file mode 100644 index 2684c6ddb..000000000 --- a/op_builder/layernorm.py +++ /dev/null @@ -1,27 +0,0 @@ -from .builder import Builder -from .utils import append_nvcc_threads, get_cuda_cc_flag - - -class LayerNormBuilder(Builder): - NAME = "layernorm" - PREBUILT_IMPORT_PATH = "colossalai._C.layernorm" - - def __init__(self): - super().__init__(name=LayerNormBuilder.NAME, prebuilt_import_path=LayerNormBuilder.PREBUILT_IMPORT_PATH) - - def sources_files(self): - ret = [self.csrc_abs_path(fname) for fname in ["layer_norm_cuda.cpp", "layer_norm_cuda_kernel.cu"]] - return ret - - def include_dirs(self): - ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] - return ret - - def cxx_flags(self): - return ["-O3"] + self.version_dependent_macros - - def nvcc_flags(self): - extra_cuda_flags = ["-maxrregcount=50"] - extra_cuda_flags.extend(get_cuda_cc_flag()) - ret = ["-O3", "--use_fast_math"] + extra_cuda_flags + self.version_dependent_macros - return append_nvcc_threads(ret) diff --git a/op_builder/multi_head_attn.py b/op_builder/multi_head_attn.py deleted file mode 100644 index cb8fc489c..000000000 --- a/op_builder/multi_head_attn.py +++ /dev/null @@ -1,46 +0,0 @@ -from .builder import Builder -from .utils import append_nvcc_threads, get_cuda_cc_flag - - -class MultiHeadAttnBuilder(Builder): - NAME = "multihead_attention" - PREBUILT_IMPORT_PATH = "colossalai._C.multihead_attention" - - def __init__(self): - super().__init__(name=MultiHeadAttnBuilder.NAME, prebuilt_import_path=MultiHeadAttnBuilder.PREBUILT_IMPORT_PATH) - - def include_dirs(self): - ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] - return ret - - def sources_files(self): - ret = [ - self.csrc_abs_path(fname) - for fname in [ - "multihead_attention_1d.cpp", - "kernels/cublas_wrappers.cu", - "kernels/transform_kernels.cu", - "kernels/dropout_kernels.cu", - "kernels/normalize_kernels.cu", - "kernels/softmax_kernels.cu", - "kernels/general_kernels.cu", - "kernels/cuda_util.cu", - ] - ] - return ret - - def cxx_flags(self): - return ["-O3"] + self.version_dependent_macros - - def nvcc_flags(self): - extra_cuda_flags = [ - "-std=c++14", - "-std=c++17", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_HALF2_OPERATORS__", - "-DTHRUST_IGNORE_CUB_VERSION_CHECK", - ] - extra_cuda_flags.extend(get_cuda_cc_flag()) - ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags - return append_nvcc_threads(ret) diff --git a/op_builder/scaled_upper_triangle_masked_softmax.py b/op_builder/scaled_upper_triangle_masked_softmax.py deleted file mode 100644 index 1445230ac..000000000 --- a/op_builder/scaled_upper_triangle_masked_softmax.py +++ /dev/null @@ -1,37 +0,0 @@ -from .builder import Builder -from .utils import append_nvcc_threads, get_cuda_cc_flag - - -class ScaledUpperTrainglemaskedSoftmaxBuilder(Builder): - NAME = "scaled_upper_triangle_masked_softmax" - PREBUILT_IMPORT_PATH = "colossalai._C.scaled_upper_triangle_masked_softmax" - - def __init__(self): - super().__init__( - name=ScaledUpperTrainglemaskedSoftmaxBuilder.NAME, - prebuilt_import_path=ScaledUpperTrainglemaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH, - ) - - def include_dirs(self): - return [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] - - def sources_files(self): - ret = [ - self.csrc_abs_path(fname) - for fname in ["scaled_upper_triang_masked_softmax.cpp", "scaled_upper_triang_masked_softmax_cuda.cu"] - ] - return ret - - def cxx_flags(self): - return ["-O3"] + self.version_dependent_macros - - def nvcc_flags(self): - extra_cuda_flags = [ - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - ] - extra_cuda_flags.extend(get_cuda_cc_flag()) - ret = ["-O3", "--use_fast_math"] + extra_cuda_flags - return append_nvcc_threads(ret) diff --git a/op_builder/smoothquant.py b/op_builder/smoothquant.py deleted file mode 100644 index d562a4c4f..000000000 --- a/op_builder/smoothquant.py +++ /dev/null @@ -1,52 +0,0 @@ -import torch - -from .builder import Builder -from .utils import append_nvcc_threads - - -class SmoothquantBuilder(Builder): - NAME = "cu_smoothquant" - PREBUILT_IMPORT_PATH = "colossalai._C.cu_smoothquant" - - def __init__(self): - super().__init__(name=SmoothquantBuilder.NAME, prebuilt_import_path=SmoothquantBuilder.PREBUILT_IMPORT_PATH) - - def include_dirs(self): - ret = [self.csrc_abs_path("smoothquant"), self.get_cuda_home_include()] - return ret - - def sources_files(self): - ret = [ - self.csrc_abs_path(fname) - for fname in [ - "smoothquant/binding.cpp", - "smoothquant/linear.cu", - ] - ] - return ret - - def cxx_flags(self): - return ["-O3"] + self.version_dependent_macros - - def nvcc_flags(self): - compute_capability = torch.cuda.get_device_capability() - cuda_arch = compute_capability[0] * 100 + compute_capability[1] * 10 - - extra_cuda_flags = [ - "-v", - f"-DCUDA_ARCH={cuda_arch}", - "-std=c++17", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_HALF2_OPERATORS__", - "-DTHRUST_IGNORE_CUB_VERSION_CHECK", - ] - - ret = ["-O3", "--use_fast_math"] + self.version_dependent_macros + extra_cuda_flags - return append_nvcc_threads(ret) - - def builder(self): - try: - super().builder() - except: - warnings.warn("build smoothquant lib not successful") diff --git a/op_builder/utils.py b/op_builder/utils.py deleted file mode 100644 index 3f75f952d..000000000 --- a/op_builder/utils.py +++ /dev/null @@ -1,229 +0,0 @@ -import os -import re -import subprocess -import warnings -from typing import List - - -def print_rank_0(message: str) -> None: - """ - Print on only one process to avoid spamming. - """ - try: - import torch.distributed as dist - - if not dist.is_initialized(): - is_main_rank = True - else: - is_main_rank = dist.get_rank() == 0 - except ImportError: - is_main_rank = True - - if is_main_rank: - print(message) - - -def get_cuda_version_in_pytorch() -> List[int]: - """ - This function returns the CUDA version in the PyTorch build. - - Returns: - The CUDA version required by PyTorch, in the form of tuple (major, minor). - """ - import torch - - try: - torch_cuda_major = torch.version.cuda.split(".")[0] - torch_cuda_minor = torch.version.cuda.split(".")[1] - except: - raise ValueError( - "[extension] Cannot retrieve the CUDA version in the PyTorch binary given by torch.version.cuda" - ) - return torch_cuda_major, torch_cuda_minor - - -def get_cuda_bare_metal_version(cuda_dir) -> List[int]: - """ - Get the System CUDA version from nvcc. - - Args: - cuda_dir (str): the directory for CUDA Toolkit. - - Returns: - The CUDA version required by PyTorch, in the form of tuple (major, minor). - """ - nvcc_path = os.path.join(cuda_dir, "bin/nvcc") - - if cuda_dir is None: - raise ValueError( - f"[extension] The argument cuda_dir is None, but expected to be a string. Please make sure your have exported the environment variable CUDA_HOME correctly." - ) - - # check for nvcc path - if not os.path.exists(nvcc_path): - raise FileNotFoundError( - f"[extension] The nvcc compiler is not found in {nvcc_path}, please make sure you have set the correct value for CUDA_HOME." - ) - - # parse the nvcc -v output to obtain the system cuda version - try: - raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) - output = raw_output.split() - release_idx = output.index("release") + 1 - release = output[release_idx].split(".") - bare_metal_major = release[0] - bare_metal_minor = release[1][0] - except: - raise ValueError( - f"[extension] Failed to parse the nvcc output to obtain the system CUDA bare metal version. The output for 'nvcc -v' is \n{raw_output}" - ) - - return bare_metal_major, bare_metal_minor - - -def check_system_pytorch_cuda_match(cuda_dir): - bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir) - torch_cuda_major, torch_cuda_minor = get_cuda_version_in_pytorch() - - if bare_metal_major != torch_cuda_major: - raise Exception( - f"[extension] Failed to build PyTorch extension because the detected CUDA version ({bare_metal_major}.{bare_metal_minor}) " - f"mismatches the version that was used to compile PyTorch ({torch_cuda_major}.{torch_cuda_minor})." - "Please make sure you have set the CUDA_HOME correctly and installed the correct PyTorch in https://pytorch.org/get-started/locally/ ." - ) - - if bare_metal_minor != torch_cuda_minor: - warnings.warn( - f"[extension] The CUDA version on the system ({bare_metal_major}.{bare_metal_minor}) does not match with the version ({torch_cuda_major}.{torch_cuda_minor}) torch was compiled with. " - "The mismatch is found in the minor version. As the APIs are compatible, we will allow compilation to proceed. " - "If you encounter any issue when using the built kernel, please try to build it again with fully matched CUDA versions" - ) - return True - - -def get_pytorch_version() -> List[int]: - """ - This functions finds the PyTorch version. - - Returns: - A tuple of integers in the form of (major, minor, patch). - """ - import torch - - torch_version = torch.__version__.split("+")[0] - TORCH_MAJOR = int(torch_version.split(".")[0]) - TORCH_MINOR = int(torch_version.split(".")[1]) - TORCH_PATCH = int(torch_version.split(".")[2], 16) - return TORCH_MAJOR, TORCH_MINOR, TORCH_PATCH - - -def check_pytorch_version(min_major_version, min_minor_version) -> bool: - """ - Compare the current PyTorch version with the minium required version. - - Args: - min_major_version (int): the minimum major version of PyTorch required - min_minor_version (int): the minimum minor version of PyTorch required - - Returns: - A boolean value. The value is True if the current pytorch version is acceptable and False otherwise. - """ - # get pytorch version - torch_major, torch_minor, _ = get_pytorch_version() - - # if the - if torch_major < min_major_version or (torch_major == min_major_version and torch_minor < min_minor_version): - raise RuntimeError( - f"[extension] Colossal-AI requires Pytorch {min_major_version}.{min_minor_version} or newer.\n" - "The latest stable release can be obtained from https://pytorch.org/get-started/locally/" - ) - - -def check_cuda_availability(): - """ - Check if CUDA is available on the system. - - Returns: - A boolean value. True if CUDA is available and False otherwise. - """ - import torch - - return torch.cuda.is_available() - - -def set_cuda_arch_list(cuda_dir): - """ - This function sets the PyTorch TORCH_CUDA_ARCH_LIST variable for ahead-of-time extension compilation. - Ahead-of-time compilation occurs when CUDA_EXT=1 is set when running 'pip install'. - """ - cuda_available = check_cuda_availability() - - # we only need to set this when CUDA is not available for cross-compilation - if not cuda_available: - warnings.warn( - "\n[extension] PyTorch did not find available GPUs on this system.\n" - "If your intention is to cross-compile, this is not an error.\n" - "By default, Colossal-AI will cross-compile for \n" - "1. Pascal (compute capabilities 6.0, 6.1, 6.2),\n" - "2. Volta (compute capability 7.0)\n" - "3. Turing (compute capability 7.5),\n" - "4. Ampere (compute capability 8.0, 8.6)if the CUDA version is >= 11.0\n" - "\nIf you wish to cross-compile for a single specific architecture,\n" - 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n' - ) - - if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None: - bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir) - - arch_list = ["6.0", "6.1", "6.2", "7.0", "7.5"] - - if int(bare_metal_major) == 11: - if int(bare_metal_minor) == 0: - arch_list.append("8.0") - else: - arch_list.append("8.0") - arch_list.append("8.6") - - arch_list_str = ";".join(arch_list) - os.environ["TORCH_CUDA_ARCH_LIST"] = arch_list_str - return False - return True - - -def get_cuda_cc_flag() -> List[str]: - """ - This function produces the cc flags for your GPU arch - - Returns: - The CUDA cc flags for compilation. - """ - - # only import torch when needed - # this is to avoid importing torch when building on a machine without torch pre-installed - # one case is to build wheel for pypi release - import torch - - cc_flag = [] - max_arch = "".join(str(i) for i in torch.cuda.get_device_capability()) - for arch in torch.cuda.get_arch_list(): - res = re.search(r"sm_(\d+)", arch) - if res: - arch_cap = res[1] - if int(arch_cap) >= 60 and int(arch_cap) <= int(max_arch): - cc_flag.extend(["-gencode", f"arch=compute_{arch_cap},code={arch}"]) - return cc_flag - - -def append_nvcc_threads(nvcc_extra_args: List[str]) -> List[str]: - """ - This function appends the threads flag to your nvcc args. - - Returns: - The nvcc compilation flags including the threads flag. - """ - from torch.utils.cpp_extension import CUDA_HOME - - bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) - if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2: - return nvcc_extra_args + ["--threads", "4"] - return nvcc_extra_args diff --git a/setup.py b/setup.py index cda1ba7ee..1244bfff0 100644 --- a/setup.py +++ b/setup.py @@ -5,55 +5,23 @@ from typing import List from setuptools import find_packages, setup -from op_builder.utils import ( - check_cuda_availability, - check_pytorch_version, - check_system_pytorch_cuda_match, - get_cuda_bare_metal_version, - get_pytorch_version, - set_cuda_arch_list, -) - try: - from torch.utils.cpp_extension import CUDA_HOME, BuildExtension + import torch # noqa + from torch.utils.cpp_extension import BuildExtension TORCH_AVAILABLE = True except ImportError: TORCH_AVAILABLE = False - CUDA_HOME = None -# Some constants for installation checks -MIN_PYTORCH_VERSION_MAJOR = 1 -MIN_PYTORCH_VERSION_MINOR = 10 THIS_DIR = os.path.dirname(os.path.abspath(__file__)) -BUILD_CUDA_EXT = int(os.environ.get("CUDA_EXT", "0")) == 1 +BUILD_EXT = int(os.environ.get("BUILD_EXT", "0")) == 1 IS_NIGHTLY = int(os.environ.get("NIGHTLY", "0")) == 1 -# a variable to store the op builder -ext_modules = [] - # we do not support windows currently if sys.platform == "win32": raise RuntimeError("Windows is not supported yet. Please try again within the Windows Subsystem for Linux (WSL).") -# check for CUDA extension dependencies -def environment_check_for_cuda_extension_build(): - if not TORCH_AVAILABLE: - raise ModuleNotFoundError( - "[extension] PyTorch is not found while CUDA_EXT=1. You need to install PyTorch first in order to build CUDA extensions" - ) - - if not CUDA_HOME: - raise RuntimeError( - "[extension] CUDA_HOME is not found while CUDA_EXT=1. You need to export CUDA_HOME environment variable or install CUDA Toolkit first in order to build CUDA extensions" - ) - - check_system_pytorch_cuda_match(CUDA_HOME) - check_pytorch_version(MIN_PYTORCH_VERSION_MAJOR, MIN_PYTORCH_VERSION_MINOR) - check_cuda_availability() - - def fetch_requirements(path) -> List[str]: """ This function reads the requirements file. @@ -98,46 +66,35 @@ def get_version() -> str: # write version into version.py with open(version_py_path, "w") as f: f.write(f"__version__ = '{version}'\n") - - # look for pytorch and cuda version - if BUILD_CUDA_EXT: - torch_major, torch_minor, _ = get_pytorch_version() - torch_version = f"{torch_major}.{torch_minor}" - cuda_version = ".".join(get_cuda_bare_metal_version(CUDA_HOME)) - else: - torch_version = None - cuda_version = None - - # write the version into the python file - if torch_version: - f.write(f'torch = "{torch_version}"\n') - else: - f.write("torch = None\n") - - if cuda_version: - f.write(f'cuda = "{cuda_version}"\n') - else: - f.write("cuda = None\n") - return version -if BUILD_CUDA_EXT: - environment_check_for_cuda_extension_build() - set_cuda_arch_list(CUDA_HOME) +if BUILD_EXT: + if not TORCH_AVAILABLE: + raise ModuleNotFoundError( + "[extension] PyTorch is not found while CUDA_EXT=1. You need to install PyTorch first in order to build CUDA extensions" + ) - from op_builder import ALL_OPS + from extensions import ALL_EXTENSIONS op_names = [] + ext_modules = [] - # load all builders - for name, builder_cls in ALL_OPS.items(): - op_names.append(name) - ext_modules.append(builder_cls().builder()) + for ext_cls in ALL_EXTENSIONS: + ext = ext_cls() + if ext.support_aot and ext.is_hardware_available(): + ext.assert_hardware_compatible() + op_names.append(ext.name) + ext_modules.append(ext.build_aot()) # show log - op_name_list = ", ".join(op_names) - print(f"[extension] loaded builders for {op_name_list}") + if len(ext_modules) == 0: + raise RuntimeError("[extension] Could not find any kernel compatible with the current environment.") + else: + op_name_list = ", ".join(op_names) + print(f"[extension] Building extensions{op_name_list}") +else: + ext_modules = [] # always put not nightly branch as the if branch # otherwise github will treat colossalai-nightly as the project name diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index 49fd85ffb..708a1906b 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -7,7 +7,6 @@ from transformers import LlamaForCausalLM from utils import shared_tempdir import colossalai -from colossalai.testing import skip_if_not_enough_gpus from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin from colossalai.lazy import LazyInitContext @@ -17,6 +16,7 @@ from colossalai.testing import ( clear_cache_before_run, parameterize, rerun_if_address_is_in_use, + skip_if_not_enough_gpus, spawn, ) from tests.kit.model_zoo import model_zoo @@ -52,7 +52,12 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b bert_model.config.save_pretrained(save_directory=pretrained_path) extra_dp_size = dist.get_world_size() // (zero_size * tp_size) - plugin = GeminiPlugin(**placement_config, tp_size=tp_size, enable_all_optimization=enable_all_optimization, extra_dp_size=extra_dp_size) + plugin = GeminiPlugin( + **placement_config, + tp_size=tp_size, + enable_all_optimization=enable_all_optimization, + extra_dp_size=extra_dp_size, + ) booster = Booster(plugin=plugin) bert_model, _, _, _, _ = booster.boost(bert_model) model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2 @@ -78,7 +83,14 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha criterion = lambda x: x.mean() enable_all_optimization = True if tp_size > 1 else False extra_dp_size = dist.get_world_size() // (zero_size * tp_size) - plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), tp_size=tp_size, extra_dp_size=extra_dp_size, enable_all_optimization=enable_all_optimization) + plugin = GeminiPlugin( + **placement_config, + precision="fp16", + initial_scale=(2**14), + tp_size=tp_size, + extra_dp_size=extra_dp_size, + enable_all_optimization=enable_all_optimization, + ) booster = Booster(plugin=plugin) model = model_fn() @@ -161,8 +173,13 @@ def run_dist(rank, world_size, port): def test_gemini_ckpIO(): spawn(run_dist, 4) + @pytest.mark.largedist @skip_if_not_enough_gpus(min_gpus=8) @rerun_if_address_is_in_use() def test_gemini_ckpIO_3d(): - spawn(run_dist, 8) \ No newline at end of file + spawn(run_dist, 8) + + +if __name__ == "__main__": + test_gemini_ckpIO() diff --git a/tests/test_optimizer/test_adam_kernel.py b/tests/test_optimizer/test_adam_kernel.py index 2ff4b3016..6d932156a 100644 --- a/tests/test_optimizer/test_adam_kernel.py +++ b/tests/test_optimizer/test_adam_kernel.py @@ -65,9 +65,9 @@ class TorchAdamKernel(AdamKernel): class FusedAdamKernel(AdamKernel): def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None: super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw) - from colossalai.kernel.op_builder import FusedOptimBuilder + from colossalai.kernel.kernel_loader import FusedOptimizerLoader - fused_optim = FusedOptimBuilder().load() + fused_optim = FusedOptimizerLoader().load() self.fused_adam = fused_optim.multi_tensor_adam self.dummy_overflow_buf = torch.cuda.IntTensor([0]) @@ -91,7 +91,7 @@ class FusedAdamKernel(AdamKernel): class CPUAdamKernel(AdamKernel): def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None: super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw) - from colossalai.kernel import CPUAdamLoader + from colossalai.kernel.kernel_loader import CPUAdamLoader cpu_optim = CPUAdamLoader().load() diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py index 30a30e86a..3ec170004 100644 --- a/tests/test_utils/test_flash_attention.py +++ b/tests/test_utils/test_flash_attention.py @@ -8,7 +8,7 @@ from colossalai.kernel.extensions.flash_attention import HAS_FLASH_ATTN, HAS_MEM from colossalai.testing import clear_cache_before_run, parameterize if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN: - from colossalai.kernel import AttnMaskType, ColoAttention + from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention DTYPE = [torch.float16, torch.bfloat16, torch.float32]