From 7696cead8d4bd4f907137d8f34017e8963f7c9b9 Mon Sep 17 00:00:00 2001 From: binmakeswell Date: Wed, 13 Jul 2022 11:57:42 +0800 Subject: [PATCH] Recover kernal files --- .../csrc/kernels/dropout_kernels.cu | 5 +- .../csrc/kernels/include/kernels.h | 27 +- .../csrc/kernels/transform_kernels.cu | 14 +- .../cuda_native/csrc/multi_tensor_apply.cuh | 202 ++-- .../csrc/scaled_masked_softmax.cpp | 78 +- .../cuda_native/csrc/scaled_masked_softmax.h | 836 ++++++++--------- .../csrc/scaled_upper_triang_masked_softmax.h | 882 ++++++++---------- .../kernel/cuda_native/csrc/type_shim.h | 476 +++++----- 8 files changed, 1217 insertions(+), 1303 deletions(-) diff --git a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu index 184106bd2..a39a6dae0 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu @@ -1,10 +1,11 @@ -#include - #include #include #include "kernels.h" +#include + + namespace cg = cooperative_groups; curandStatePhilox4_32_10_t *curandstate; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h b/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h index 735e1363c..fbb9c5465 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h @@ -3,11 +3,10 @@ #include #include #include +#include #include #include -#include - #define MAX_THREADS 1024 #define WARP_SIZE 32 @@ -133,9 +132,8 @@ __forceinline__ __host__ __device__ int flat_3dim(int id1, int id2, int id3, } /* Convert 4-dim tensor index into vector index */ -__forceinline__ __host__ __device__ int flat_4dim(int id1, int id2, int id3, - int id4, int dim2, int dim3, - int dim4) { +__forceinline__ __host__ __device__ int +flat_4dim(int id1, int id2, int id3, int id4, int dim2, int dim3, int dim4) { // return id1*(dim2*dim3*dim4) + id2*(dim3*dim4) + id3*dim4 + id4; int res = id4; @@ -203,9 +201,9 @@ __forceinline__ __host__ __device__ int flat_6dim(int id1, int id2, int id3, } /* Convert vector index to 6-dim tensor index */ -__forceinline__ __host__ __device__ void decompose_6dim( - int src, int dim1, int dim2, int dim3, int dim4, int dim5, int *id0, - int *id1, int *id2, int *id3, int *id4, int *id5) { +__forceinline__ __host__ __device__ void +decompose_6dim(int src, int dim1, int dim2, int dim3, int dim4, int dim5, + int *id0, int *id1, int *id2, int *id3, int *id4, int *id5) { *id5 = src % dim5; src /= dim5; @@ -223,11 +221,9 @@ __forceinline__ __host__ __device__ void decompose_6dim( } /* Convert vector index to 5-dim tensor index */ -__forceinline__ __host__ __device__ void decompose_5dim(int src, int dim1, - int dim2, int dim3, - int dim4, int *id0, - int *id1, int *id2, - int *id3, int *id4) { +__forceinline__ __host__ __device__ void +decompose_5dim(int src, int dim1, int dim2, int dim3, int dim4, int *id0, + int *id1, int *id2, int *id3, int *id4) { *id4 = src % dim4; src /= dim4; @@ -257,9 +253,8 @@ __forceinline__ __host__ __device__ void decompose_4dim(int src, int dim1, } /* Convert vector index to 3-dim tensor index */ -__forceinline__ __host__ __device__ void decompose_3dim(int src, int dim1, - int dim2, int *id0, - int *id1, int *id2) { +__forceinline__ __host__ __device__ void +decompose_3dim(int src, int dim1, int dim2, int *id0, int *id1, int *id2) { *id2 = src % dim2; src /= dim2; diff --git 
a/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu index d389d57e1..d03084b22 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu @@ -135,10 +135,9 @@ __global__ void bias_add_transform_20314(T *output, const T *input, const T *bias, int dim_3, int dim_4); template <> -__global__ void bias_add_transform_20314(float *output, - const float *input, - const float *bias, int dim_3, - int dim_4) { +__global__ void +bias_add_transform_20314(float *output, const float *input, + const float *bias, int dim_3, int dim_4) { int id0 = blockIdx.x; int id1 = blockIdx.y; int id2 = blockIdx.z; @@ -174,10 +173,9 @@ __global__ void bias_add_transform_20314(float *output, } template <> -__global__ void bias_add_transform_20314<__half>(__half *output, - const __half *input, - const __half *bias, int dim_3, - int dim_4) { +__global__ void +bias_add_transform_20314<__half>(__half *output, const __half *input, + const __half *bias, int dim_3, int dim_4) { int id0 = blockIdx.x; int id1 = blockIdx.y; int id2 = blockIdx.z; diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_apply.cuh b/colossalai/kernel/cuda_native/csrc/multi_tensor_apply.cuh index 0d46f513b..9ce411911 100644 --- a/colossalai/kernel/cuda_native/csrc/multi_tensor_apply.cuh +++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_apply.cuh @@ -1,14 +1,13 @@ -// modified from -// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_apply.cuh +// modified from https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_apply.cuh #include #include #include #include -#include #include - #include "compat.h" +#include + // #include // This header is the one-stop shop for all your multi-tensor apply needs. @@ -18,108 +17,117 @@ constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30}; constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320}; template -struct TensorListMetadata { - void *addresses[n][depth_to_max_tensors[n - 1]]; - int sizes[depth_to_max_tensors[n - 1]]; - unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; - int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a - // full int. - int start_tensor_this_launch; +struct TensorListMetadata +{ + void *addresses[n][depth_to_max_tensors[n - 1]]; + int sizes[depth_to_max_tensors[n - 1]]; + unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; + int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a full int. + int start_tensor_this_launch; }; template -__global__ void multi_tensor_apply_kernel(int chunk_size, - volatile int *noop_flag, T tl, - U callable, ArgTypes... args) { - // Hand the chunk information to the user-supplied functor to process however - // it likes. - callable(chunk_size, noop_flag, tl, args...); +__global__ void multi_tensor_apply_kernel( + int chunk_size, + volatile int *noop_flag, + T tl, + U callable, + ArgTypes... args) +{ + // Hand the chunk information to the user-supplied functor to process however it likes. + callable(chunk_size, noop_flag, tl, args...); } template void multi_tensor_apply( - int block_size, int chunk_size, const at::Tensor &noop_flag, - const std::vector> &tensor_lists, T callable, - ArgTypes... 
args) { - TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth"); - int len0 = tensor_lists[0].size(); - TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0"); - auto ref_device = tensor_lists[0][0].device(); - TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda"); - for (int l = 0; l < tensor_lists.size(); - l++) // No range-based for because I need indices - { - TORCH_CHECK(tensor_lists[l].size() == len0, - "Size mismatch among tensor lists"); - for (int t = 0; t < tensor_lists[l].size(); t++) { - // TODO: Print which tensor fails. - bool contiguous_memory = tensor_lists[l][t].is_contiguous(); + int block_size, + int chunk_size, + const at::Tensor &noop_flag, + const std::vector> &tensor_lists, + T callable, + ArgTypes... args) +{ + TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth"); + int len0 = tensor_lists[0].size(); + TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0"); + auto ref_device = tensor_lists[0][0].device(); + TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda"); + for (int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices + { + TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists"); + for (int t = 0; t < tensor_lists[l].size(); t++) + { + // TODO: Print which tensor fails. + bool contiguous_memory = tensor_lists[l][t].is_contiguous(); #ifdef VERSION_GE_1_5 - contiguous_memory = - (contiguous_memory || - tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast)); + contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast)); #endif - TORCH_CHECK(contiguous_memory, "A tensor was not contiguous."); - TORCH_CHECK(tensor_lists[l][t].device() == ref_device, - "A tensor was not on the same device as the first tensor"); - TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), - "Size mismatch"); - } - } - - int ntensors = tensor_lists[0].size(); - - TensorListMetadata tl; - - const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0])); - auto stream = at::cuda::getCurrentCUDAStream(); - - tl.start_tensor_this_launch = 0; - int loc_block_info = 0; - int loc_tensor_info = 0; - for (int t = 0; t < ntensors; t++) { - tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel(); - for (int d = 0; d < depth; d++) - tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr(); - loc_tensor_info++; - - int chunks_this_tensor = - (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; - - for (int chunk = 0; chunk < chunks_this_tensor; chunk++) { - // std::cout << chunks_this_tensor << std::endl; - tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1; - tl.block_to_chunk[loc_block_info] = chunk; - loc_block_info++; - - bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] && - chunk == chunks_this_tensor - 1); - bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]); - bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1); - if (tensors_full || blocks_full || last_chunk) { - // using accscalar_t = acc_type; - multi_tensor_apply_kernel<<>>( - chunk_size, noop_flag.DATA_PTR(), tl, callable, args...); - - AT_CUDA_CHECK(cudaGetLastError()); - - // Reset. The control flow possibilities here make my brain hurt. 
- loc_block_info = 0; - if (chunk == chunks_this_tensor - 1) { - // std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 - // << std::endl; - loc_tensor_info = 0; - tl.start_tensor_this_launch = t + 1; - } else { - // std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3 - // << std::endl; - tl.sizes[0] = tl.sizes[loc_tensor_info - 1]; - for (int d = 0; d < depth; d++) - tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1]; - loc_tensor_info = 1; - tl.start_tensor_this_launch = t; + TORCH_CHECK(contiguous_memory, "A tensor was not contiguous."); + TORCH_CHECK(tensor_lists[l][t].device() == ref_device, "A tensor was not on the same device as the first tensor"); + TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch"); + } + } + + int ntensors = tensor_lists[0].size(); + + TensorListMetadata tl; + + const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0])); + auto stream = at::cuda::getCurrentCUDAStream(); + + tl.start_tensor_this_launch = 0; + int loc_block_info = 0; + int loc_tensor_info = 0; + for (int t = 0; t < ntensors; t++) + { + tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel(); + for (int d = 0; d < depth; d++) + tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr(); + loc_tensor_info++; + + int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size; + + for (int chunk = 0; chunk < chunks_this_tensor; chunk++) + { + // std::cout << chunks_this_tensor << std::endl; + tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1; + tl.block_to_chunk[loc_block_info] = chunk; + loc_block_info++; + + bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] && + chunk == chunks_this_tensor - 1); + bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]); + bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1); + if (tensors_full || blocks_full || last_chunk) + { + // using accscalar_t = acc_type; + multi_tensor_apply_kernel<<>>( + chunk_size, + noop_flag.DATA_PTR(), + tl, + callable, + args...); + + AT_CUDA_CHECK(cudaGetLastError()); + + // Reset. The control flow possibilities here make my brain hurt. 
+ loc_block_info = 0; + if (chunk == chunks_this_tensor - 1) + { + // std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 << std::endl; + loc_tensor_info = 0; + tl.start_tensor_this_launch = t + 1; + } + else + { + // std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3 << std::endl; + tl.sizes[0] = tl.sizes[loc_tensor_info - 1]; + for (int d = 0; d < depth; d++) + tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1]; + loc_tensor_info = 1; + tl.start_tensor_this_launch = t; + } + } } - } } - } } \ No newline at end of file diff --git a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp index 8c2982b0c..4ae3c853c 100644 --- a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp +++ b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp @@ -3,68 +3,82 @@ #include #include - #include namespace multihead_attn { namespace fused_softmax { namespace scaled_masked_softmax { -torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, - float scale_factor); +torch::Tensor fwd_cuda( + torch::Tensor const& input, + torch::Tensor const& mask, + float scale_factor); -torch::Tensor bwd_cuda(torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, - float scale_factor); +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor); -int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, - int attn_heads); +int get_batch_per_block_cuda( + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads); -torch::Tensor fwd(torch::Tensor const& input, torch::Tensor const& mask, - float scale_factor) { +torch::Tensor fwd( + torch::Tensor const& input, + torch::Tensor const& mask, + float scale_factor) { AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || - (input.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); return fwd_cuda(input, mask, scale_factor); } -torch::Tensor bwd(torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, float scale_factor) { +torch::Tensor bwd( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor) { + AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || - (output_grads.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); return bwd_cuda(output_grads, softmax_results, scale_factor); } -int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, - int attn_heads) { - return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, - attn_heads); +int get_batch_per_block( + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads) { + return get_batch_per_block_cuda(query_seq_len, 
key_seq_len, batches, attn_heads); } -} // end namespace scaled_masked_softmax -} // end namespace fused_softmax -} // end namespace multihead_attn +} // end namespace scaled_masked_softmax +} // end namespace fused_softmax +} // end namespace multihead_attn PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, - "Self Multihead Attention scaled, time masked softmax -- Forward."); + m.def("forward", + &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, + "Self Multihead Attention scaled, time masked softmax -- Forward."); - m.def("backward", &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, - "Self Multihead Attention scaled, time masked softmax -- Backward."); + m.def("backward", + &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, + "Self Multihead Attention scaled, time masked softmax -- Backward."); m.def("get_batch_per_block", - &multihead_attn::fused_softmax::scaled_masked_softmax:: - get_batch_per_block, - "Return Batch per block size."); + &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block, + "Return Batch per block size." + ); } diff --git a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h index d3e6f04e6..1583030b8 100644 --- a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h +++ b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h @@ -4,12 +4,12 @@ #pragma once #include -#include #include -#include - #include #include +#include +#include +#include namespace { @@ -17,53 +17,37 @@ template __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); template <> -__device__ __inline__ void copy_vector( - c10::BFloat16 *dst, const c10::BFloat16 *src) { - *dst = *src; -} +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } template <> -__device__ __inline__ void copy_vector( - c10::BFloat16 *dst, const c10::BFloat16 *src) { - *((float2 *)dst) = *((float2 *)src); -} +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } template <> -__device__ __inline__ void copy_vector(c10::Half *dst, - const c10::Half *src) { - *dst = *src; -} +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } template <> -__device__ __inline__ void copy_vector(c10::Half *dst, - const c10::Half *src) { - *((float2 *)dst) = *((float2 *)src); -} +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } template <> -__device__ __inline__ void copy_vector(uint8_t *dst, - const uint8_t *src) { - *dst = *src; -} +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } template <> -__device__ __inline__ void copy_vector(uint8_t *dst, - const uint8_t *src) { - *((half2 *)dst) = *((half2 *)src); -} +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } int log2_ceil(int value) { - int log2_value = 0; - while ((1 << log2_value) < value) ++log2_value; - return log2_value; + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; } -template +template struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { return a + b; } + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } }; -template +template struct Max { 
__device__ __forceinline__ T operator()(T a, T b) const { return a < b ? b : a; @@ -71,468 +55,438 @@ struct Max { }; template -__device__ __forceinline__ T -WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, - unsigned int mask = 0xffffffff) { +__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) +{ #if CUDA_VERSION >= 9000 - return __shfl_xor_sync(mask, value, laneMask, width); + return __shfl_xor_sync(mask, value, laneMask, width); #else - return __shfl_xor(value, laneMask, width); + return __shfl_xor(value, laneMask, width); #endif } -template class ReduceOp> -__device__ __forceinline__ void warp_reduce(acc_t *sum) { - ReduceOp r; -#pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); - sum[i] = r(sum[i], b); +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t* sum) { + ReduceOp r; + #pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); + } } - } } /* - * Extended softmax (from native aten pytorch) with following additional - * features 1) input scaling 2) Explicit masking - */ -template + * Extended softmax (from native aten pytorch) with following additional features + * 1) input scaling + * 2) Explicit masking + */ +template __global__ void scaled_masked_softmax_warp_forward( - output_t *dst, const input_t *src, const uint8_t *mask, const acc_t scale, - int micro_batch_size, int element_count, int pad_batches) { - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_forward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + output_t *dst, + const input_t *src, + const uint8_t *mask, + const acc_t scale, + int micro_batch_size, + int element_count, + int pad_batches) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 
1 : 4; - // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) - int first_batch = - (blockDim.y * - (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + - threadIdx.y) * - WARP_BATCH; - int pad_first_batch = 0; - if (pad_batches != 1) { // bert style - pad_first_batch = - (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * - WARP_BATCH; - } else { // gpt2 style - pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - } + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH; + int pad_first_batch = 0; + if (pad_batches != 1) { // bert style + pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH; + } else { // gpt2 style + pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + } - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; - // there might be multiple batches per warp. compute the index within the - // batch - int local_idx = threadIdx.x; + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; - src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - // load data from global memory - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; - input_t temp_data[ELEMENTS_PER_LDG_STG]; - uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; + uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 
0 : element_count; -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - int itr_idx = i * element_count + it * WARP_SIZE; - copy_vector(temp_data, src + itr_idx); - copy_vector(temp_mask, mask + itr_idx); + if (element_index < batch_element_count) { + int itr_idx = i*element_count+it*WARP_SIZE; + copy_vector(temp_data, src + itr_idx); + copy_vector(temp_mask, mask + itr_idx); -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (temp_mask[element] != 1) { - elements[i][it + element] = (acc_t)temp_data[element] * scale; - } else { - elements[i][it + element] = -10000.0; - } + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (temp_mask[element] != 1) { + elements[i][it + element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -10000.0; + } + } + } else { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } } - } else { -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements[i][it + element] = -std::numeric_limits::infinity(); + } + + // compute max_value + acc_t max_value[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; } - } } - } + warp_reduce(max_value); - // compute max_value - acc_t max_value[WARP_BATCH]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; -#pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - max_value[i] = - (max_value[i] > elements[i][it]) ? 
max_value[i] : elements[i][it]; - } - } - warp_reduce(max_value); - - acc_t sum[WARP_BATCH]{0.0f}; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - elements[i][it] = std::exp((elements[i][it] - max_value[i])); - sum[i] += elements[i][it]; - } - } - warp_reduce(sum); - - // store result - output_t out[ELEMENTS_PER_LDG_STG]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) break; -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = elements[i][it + element] / sum[i]; + acc_t sum[WARP_BATCH] { 0.0f }; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = elements[i][it + element] / sum[i]; + } + copy_vector(dst + i * element_count + it * WARP_SIZE, out); + } else { + break; + } } - copy_vector( - dst + i * element_count + it * WARP_SIZE, out); - } else { - break; - } } - } } -template +template __global__ void scaled_masked_softmax_warp_backward( - output_t *gradInput, input_t *grad, const input_t *output, acc_t scale, - int micro_batch_size, int element_count) { - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_backward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + output_t *gradInput, + input_t *grad, + const input_t *output, + acc_t scale, + int micro_batch_size, + int element_count) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) - int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + + // micro_batch_size might not be a multiple of WARP_BATCH. 
Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; - // there might be multiple batches per warp. compute the index within the - // batch - int local_idx = threadIdx.x; + // the first element to process by the current thread + int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; - // the first element to process by the current thread - int thread_offset = - first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - grad += thread_offset; - output += thread_offset; - gradInput += thread_offset; + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; - // load data from global memory - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; - acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; - input_t temp_grad[ELEMENTS_PER_LDG_STG]; - input_t temp_output[ELEMENTS_PER_LDG_STG]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 
0 : element_count; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector(temp_grad, grad + i * element_count + it * WARP_SIZE); + copy_vector(temp_output, output + i * element_count + it * WARP_SIZE); -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - copy_vector( - temp_grad, grad + i * element_count + it * WARP_SIZE); - copy_vector( - temp_output, output + i * element_count + it * WARP_SIZE); - -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - output_reg[i][it + element] = (acc_t)temp_output[element]; + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + output_reg[i][it + element] = (acc_t)temp_output[element]; + } + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; + } + } } -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - grad_reg[i][it + element] = - (acc_t)temp_grad[element] * output_reg[i][it + element]; + } + + acc_t sum[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; } - } } - } + warp_reduce(sum); - acc_t sum[WARP_BATCH]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] = grad_reg[i][0]; -#pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_reg[i][it]; - } - } - warp_reduce(sum); - -// store result -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) break; -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // compute gradients - output_t out[ELEMENTS_PER_LDG_STG]; -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = - (output_t)(scale * (grad_reg[i][it + element] - - output_reg[i][it + element] * sum[i])); + // store result + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); + } + copy_vector(gradInput + i * element_count + it * WARP_SIZE, out); + } } - copy_vector( - gradInput + i * element_count + it * WARP_SIZE, out); - } } - } } -} // end of anonymous namespace +} // end of anonymous namespace -int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, - int attn_heads) { - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - - int warp_size = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; - - constexpr int threads_per_block = 128; - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - - return batches_per_block; -} - -template -void dispatch_scaled_masked_softmax_forward(output_t *dst, const input_t *src, - const uint8_t *mask, - const input_t scale, - int query_seq_len, int key_seq_len, - int batches, int attn_heads, - int pad_batches) { - TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048); - if (key_seq_len == 0) { - return; - } else { +int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){ int log2_elements = log2_ceil(key_seq_len); const int next_power_of_two = 1 << log2_elements; - int batch_count = batches * attn_heads * query_seq_len; - // This value must match the WARP_SIZE constexpr value computed inside - // softmax_warp_forward. - int warp_size = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside - // softmax_warp_forward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - // use 128 threads per block to maximimize gpu utilization constexpr int threads_per_block = 128; - int warps_per_block = (threads_per_block / warp_size); int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0); - dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_masked_softmax_warp_forward - <<>>( - dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 1: // 2 - scaled_masked_softmax_warp_forward - <<>>( - dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 2: // 4 - scaled_masked_softmax_warp_forward - <<>>( - dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 3: // 8 - scaled_masked_softmax_warp_forward - <<>>( - dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 4: // 16 - scaled_masked_softmax_warp_forward - <<>>( - dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 5: // 32 - scaled_masked_softmax_warp_forward - <<>>( - dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 6: // 64 - scaled_masked_softmax_warp_forward - <<>>( - dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 7: // 128 - scaled_masked_softmax_warp_forward - <<>>( - dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 8: // 256 - scaled_masked_softmax_warp_forward - <<>>( - dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 9: // 512 - scaled_masked_softmax_warp_forward - <<>>( - dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 10: // 1024 - scaled_masked_softmax_warp_forward - <<>>( - dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 11: // 2048 - scaled_masked_softmax_warp_forward - <<>>( - dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - default: - break; - } - } + + return batches_per_block; } -template -void dispatch_scaled_masked_softmax_backward(output_t *grad_input, - input_t *grad, - const input_t *output, - const acc_t scale, 
- int query_seq_len, int key_seq_len, - int batches, int attn_heads) { - TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048); - if (key_seq_len == 0) { - return; - } else { - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - int batch_count = batches * attn_heads * query_seq_len; +template +void dispatch_scaled_masked_softmax_forward( + output_t *dst, + const input_t *src, + const uint8_t *mask, + const input_t scale, + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads, + int pad_batches) +{ + TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 ); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; - // This value must match the WARP_SIZE constexpr value computed inside - // softmax_warp_backward. - int warp_size = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - // This value must match the WARP_BATCH constexpr value computed inside - // softmax_warp_backward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = batch_count / batches_per_block; - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 1: // 2 - scaled_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 2: // 4 - scaled_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 3: // 8 - scaled_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 4: // 16 - scaled_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 5: // 32 - scaled_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 6: // 64 - scaled_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 7: // 128 - scaled_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 8: // 256 - scaled_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 9: // 512 - scaled_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 10: // 1024 - scaled_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 11: // 2048 - 
scaled_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, key_seq_len); - break; - default: - break; + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0); + dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 1: // 2 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 2: // 4 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 3: // 8 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 4: // 16 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 5: // 32 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 6: // 64 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 7: // 128 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 8: // 256 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 9: // 512 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 10: // 1024 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 11: // 2048 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + default: + break; + } + } +} + +template +void dispatch_scaled_masked_softmax_backward( + output_t *grad_input, + input_t *grad, + const input_t *output, + const acc_t scale, + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads) +{ + TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 ); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. + int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = batch_count/batches_per_block; + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 1: // 2 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 2: // 4 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 3: // 8 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 4: // 16 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 5: // 32 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 6: // 64 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 7: // 128 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 8: // 256 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 9: // 512 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 10: // 1024 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 11: // 2048 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + default: + break; + } } - } } diff --git a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h b/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h index 54c8e9133..3af487f9d 100644 --- a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h +++ b/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h @@ -4,12 +4,11 @@ #pragma once #include -#include #include -#include - #include #include +#include +#include namespace { @@ -17,78 +16,53 @@ template __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); template <> -__device__ __inline__ void copy_vector( - c10::BFloat16 *dst, const c10::BFloat16 *src) { - *dst = *src; -} +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } template <> -__device__ __inline__ void copy_vector( - c10::BFloat16 *dst, const c10::BFloat16 *src) { - *((float2 *)dst) = *((float2 *)src); -} +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } template <> -__device__ __inline__ void copy_vector(c10::Half *dst, - const c10::Half *src) { - *dst = *src; -} +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } template <> -__device__ __inline__ void copy_vector(c10::Half *dst, - const c10::Half *src) { - *((float2 *)dst) = 
*((float2 *)src); -} +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } template <> -__device__ __inline__ void copy_vector(uint8_t *dst, - const uint8_t *src) { - *dst = *src; -} - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, - const uint8_t *src) { - *((half2 *)dst) = *((half2 *)src); -} +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } template __device__ __inline__ void copy_zero_vector(Datatype *dst); template <> -__device__ __inline__ void copy_zero_vector( - c10::BFloat16 *dst) { - *dst = 0.0; -} +__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *dst = 0.0; } template <> -__device__ __inline__ void copy_zero_vector( - c10::BFloat16 *dst) { - *((float2 *)dst) = make_float2(0.0f, 0.0f); -} +__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } template <> -__device__ __inline__ void copy_zero_vector(c10::Half *dst) { - *dst = 0.0; -} +__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *dst = 0.0; } template <> -__device__ __inline__ void copy_zero_vector(c10::Half *dst) { - *((float2 *)dst) = make_float2(0.0f, 0.0f); -} +__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } + int log2_ceil(int value) { - int log2_value = 0; - while ((1 << log2_value) < value) ++log2_value; - return log2_value; + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; } -template +template struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { return a + b; } + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } }; -template +template struct Max { __device__ __forceinline__ T operator()(T a, T b) const { return a < b ? 
b : a; @@ -96,505 +70,431 @@ struct Max { }; template -__device__ __forceinline__ T -WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, - unsigned int mask = 0xffffffff) { +__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) +{ #if CUDA_VERSION >= 9000 - return __shfl_xor_sync(mask, value, laneMask, width); + return __shfl_xor_sync(mask, value, laneMask, width); #else - return __shfl_xor(value, laneMask, width); + return __shfl_xor(value, laneMask, width); #endif } -template class ReduceOp> -__device__ __forceinline__ void warp_reduce(acc_t *sum) { - ReduceOp r; -#pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); - sum[i] = r(sum[i], b); +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t* sum) { + ReduceOp r; + #pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); + } } - } } /* - * Extended softmax (from native aten pytorch) with following additional - * features 1) input scaling 2) Implicit time (diagonal masking) + * Extended softmax (from native aten pytorch) with following additional features + * 1) input scaling + * 2) Implicit time (diagonal masking) */ -template +template __global__ void scaled_upper_triang_masked_softmax_warp_forward( - output_t *dst, const input_t *src, const acc_t scale, int micro_batch_size, - int stride, int element_count) { - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_forward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + output_t *dst, + const input_t *src, + const acc_t scale, + int micro_batch_size, + int stride, + int element_count) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - int first_batch = - (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + - blockIdx.x; - int local_seq = blockIdx.x + 1; - int warp_iteration_limit = - (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1) / WARP_SIZE; + int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; + int local_seq = blockIdx.x + 1; + int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE; - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. 
- int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; - // there might be multiple batches per warp. compute the index within the - // batch - int local_idx = threadIdx.x; + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; - src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - // load data from global memory - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; - input_t temp_data[ELEMENTS_PER_LDG_STG]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : local_seq; + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : local_seq; -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - copy_vector( - temp_data, src + i * element_count * stride + it * WARP_SIZE); + if (element_index < batch_element_count) { + copy_vector(temp_data, src + i*element_count*stride + it*WARP_SIZE); -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if ((element_index + element) < batch_element_count) { - elements[i][it + element] = (acc_t)temp_data[element] * scale; - } else { - elements[i][it + element] = -std::numeric_limits::infinity(); - } + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if ((element_index + element) < batch_element_count) { + elements[i][it+element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } else { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } } - } else { -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements[i][it + element] = -std::numeric_limits::infinity(); + } + + // compute max_value + acc_t max_value[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; } - } } - } + warp_reduce(max_value); - // compute max_value - acc_t max_value[WARP_BATCH]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; -#pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - max_value[i] = - (max_value[i] > elements[i][it]) ? 
max_value[i] : elements[i][it]; - } - } - warp_reduce(max_value); - - acc_t sum[WARP_BATCH]{0.0f}; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - if (it < warp_iteration_limit) { - elements[i][it] = std::exp((elements[i][it] - max_value[i])); - sum[i] += elements[i][it]; - } - } - } - warp_reduce(sum); - - // store result - output_t out[ELEMENTS_PER_LDG_STG]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) break; -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < local_seq) { -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < local_seq) { - out[element] = elements[i][it + element] / sum[i]; - } else { - out[element] = 0; - } + acc_t sum[WARP_BATCH] { 0.0f }; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + if (it < warp_iteration_limit) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } + } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < local_seq) { + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < local_seq) { + out[element] = elements[i][it + element] / sum[i]; + } else { + out[element] = 0; + } + } + copy_vector(dst + i * element_count * stride + it * WARP_SIZE, out); + } else if (element_index < element_count) { + copy_zero_vector(dst + i * element_count * stride + it * WARP_SIZE); + } else { + break; + } } - copy_vector( - dst + i * element_count * stride + it * WARP_SIZE, out); - } else if (element_index < element_count) { - copy_zero_vector( - dst + i * element_count * stride + it * WARP_SIZE); - } else { - break; - } } - } } -template +template __global__ void scaled_upper_triang_masked_softmax_warp_backward( - output_t *gradInput, input_t *grad, const input_t *output, acc_t scale, - int micro_batch_size, int stride, int element_count) { - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_backward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + output_t *gradInput, + input_t *grad, + const input_t *output, + acc_t scale, + int micro_batch_size, + int stride, + int element_count) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 
2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - int first_batch = - (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + - blockIdx.x; - int local_seq = blockIdx.x + 1; + int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; + int local_seq = blockIdx.x + 1; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; - // there might be multiple batches per warp. compute the index within the - // batch - int local_idx = threadIdx.x; + // the first element to process by the current thread + int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; - // the first element to process by the current thread - int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - grad += thread_offset; - output += thread_offset; - gradInput += thread_offset; + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : local_seq; - // load data from global memory - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; - acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; - input_t temp_grad[ELEMENTS_PER_LDG_STG]; - input_t temp_output[ELEMENTS_PER_LDG_STG]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 
0 : local_seq; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector(temp_grad, grad + i * element_count * stride + it * WARP_SIZE); + copy_vector(temp_output, output + i * element_count * stride + it * WARP_SIZE); -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - copy_vector( - temp_grad, grad + i * element_count * stride + it * WARP_SIZE); - copy_vector( - temp_output, output + i * element_count * stride + it * WARP_SIZE); - -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < batch_element_count) { - output_reg[i][it + element] = (acc_t)temp_output[element]; - } + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < batch_element_count) { + output_reg[i][it + element] = (acc_t)temp_output[element]; + } + } + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < batch_element_count) { + grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; + } + } + } } -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < batch_element_count) { - grad_reg[i][it + element] = - (acc_t)temp_grad[element] * output_reg[i][it + element]; - } + } + + acc_t sum[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; } - } } - } + warp_reduce(sum); - acc_t sum[WARP_BATCH]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] = grad_reg[i][0]; -#pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_reg[i][it]; - } - } - warp_reduce(sum); - -// store result -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) break; -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // compute gradients - output_t out[ELEMENTS_PER_LDG_STG]; -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = - (output_t)(scale * (grad_reg[i][it + element] - - output_reg[i][it + element] * sum[i])); + // store result + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); + } + copy_vector(gradInput + i * element_count * stride + it * WARP_SIZE, out); + } } - copy_vector( - gradInput + i * element_count * stride + it * WARP_SIZE, out); - } } - } } -} // end of anonymous namespace +} // end of anonymous namespace -template +template void 
dispatch_scaled_upper_triang_masked_softmax_forward( - output_t *dst, const input_t *src, const input_t scale, - int softmax_elements, int softmax_elements_stride, int attn_batches) { - TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048); - if (softmax_elements == 0) { - return; - } else { - int log2_elements = log2_ceil(softmax_elements); - const int next_power_of_two = 1 << log2_elements; - int seq_len = softmax_elements; - int batch_count = attn_batches * seq_len; + output_t *dst, + const input_t *src, + const input_t scale, + int softmax_elements, + int softmax_elements_stride, + int attn_batches) +{ + TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 ); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + int seq_len = softmax_elements; + int batch_count = attn_batches * seq_len; - // This value must match the WARP_SIZE constexpr value computed inside - // softmax_warp_forward. - int warp_size = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - // This value must match the WARP_BATCH constexpr value computed inside - // softmax_warp_forward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); - int blocks_per_seq = attn_batches / batches_per_block; - dim3 blocks(seq_len, blocks_per_seq, 1); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_upper_triang_masked_softmax_warp_forward - <<>>( - dst, src, scale, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 1: // 2 - scaled_upper_triang_masked_softmax_warp_forward - <<>>( - dst, src, scale, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 2: // 4 - scaled_upper_triang_masked_softmax_warp_forward - <<>>( - dst, src, scale, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 3: // 8 - scaled_upper_triang_masked_softmax_warp_forward - <<>>( - dst, src, scale, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 4: // 16 - scaled_upper_triang_masked_softmax_warp_forward - <<>>( - dst, src, scale, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 5: // 32 - scaled_upper_triang_masked_softmax_warp_forward - <<>>( - dst, src, scale, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 6: // 64 - scaled_upper_triang_masked_softmax_warp_forward - <<>>( - dst, src, scale, batch_count, softmax_elements_stride, - 
softmax_elements); - break; - case 7: // 128 - scaled_upper_triang_masked_softmax_warp_forward - <<>>( - dst, src, scale, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 8: // 256 - scaled_upper_triang_masked_softmax_warp_forward - <<>>( - dst, src, scale, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 9: // 512 - scaled_upper_triang_masked_softmax_warp_forward - <<>>( - dst, src, scale, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 10: // 1024 - scaled_upper_triang_masked_softmax_warp_forward - <<>>( - dst, src, scale, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 11: // 2048 - scaled_upper_triang_masked_softmax_warp_forward - <<>>( - dst, src, scale, batch_count, softmax_elements_stride, - softmax_elements); - break; - default: - break; + int blocks_per_seq = attn_batches / batches_per_block; + dim3 blocks(seq_len, blocks_per_seq, 1); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 1: // 2 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 2: // 4 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 3: // 8 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 4: // 16 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 5: // 32 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 6: // 64 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 7: // 128 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 8: // 256 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 9: // 512 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 10: // 1024 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 11: // 2048 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + default: + break; + } } - } } -template +template void dispatch_scaled_upper_triang_masked_softmax_backward( - output_t *grad_input, input_t *grad, const input_t *output, - const acc_t scale, int softmax_elements, int softmax_elements_stride, - int attn_batches) { - TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048); - if (softmax_elements == 0) { - return; - } else { - int log2_elements = log2_ceil(softmax_elements); - const int next_power_of_two = 1 << log2_elements; - int seq_len = softmax_elements; - int batch_count = attn_batches * 
seq_len; + output_t *grad_input, + input_t *grad, + const input_t *output, + const acc_t scale, + int softmax_elements, + int softmax_elements_stride, + int attn_batches) +{ + TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 ); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + int seq_len = softmax_elements; + int batch_count = attn_batches * seq_len; - // This value must match the WARP_SIZE constexpr value computed inside - // softmax_warp_backward. - int warp_size = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - // This value must match the WARP_BATCH constexpr value computed inside - // softmax_warp_backward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); - int blocks_per_seq = attn_batches / batches_per_block; - dim3 blocks(seq_len, blocks_per_seq, 1); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_upper_triang_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 1: // 2 - scaled_upper_triang_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 2: // 4 - scaled_upper_triang_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 3: // 8 - scaled_upper_triang_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 4: // 16 - scaled_upper_triang_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 5: // 32 - scaled_upper_triang_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 6: // 64 - scaled_upper_triang_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 7: // 128 - scaled_upper_triang_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 8: // 256 - scaled_upper_triang_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, - 
softmax_elements_stride, softmax_elements); - break; - case 9: // 512 - scaled_upper_triang_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 10: // 1024 - scaled_upper_triang_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 11: // 2048 - scaled_upper_triang_masked_softmax_warp_backward - <<>>( - grad_input, grad, output, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - default: - break; + int blocks_per_seq = attn_batches / batches_per_block; + dim3 blocks(seq_len, blocks_per_seq, 1); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 1: // 2 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 2: // 4 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 3: // 8 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 4: // 16 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 5: // 32 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 6: // 64 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 7: // 128 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 8: // 256 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 9: // 512 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 10: // 1024 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 11: // 2048 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + default: + break; + } } - } } diff --git a/colossalai/kernel/cuda_native/csrc/type_shim.h b/colossalai/kernel/cuda_native/csrc/type_shim.h index 39de96374..cf83414af 100644 --- a/colossalai/kernel/cuda_native/csrc/type_shim.h +++ b/colossalai/kernel/cuda_native/csrc/type_shim.h @@ -1,63 +1,76 @@ #include - #include "compat.h" -#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) 
\ - switch (TYPE) { \ - case at::ScalarType::Half: { \ - using scalar_t = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::BFloat16: { \ - using scalar_t = at::BFloat16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } + +#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \ + switch(TYPE) \ + { \ + case at::ScalarType::Half: \ + { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + + #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \ - switch (TYPEIN) { \ - case at::ScalarType::Float: { \ - using scalar_t_in = float; \ - switch (TYPEOUT) { \ - case at::ScalarType::Float: { \ - using scalar_t_out = float; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Half: { \ - using scalar_t_out = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::BFloat16: { \ - using scalar_t_out = at::BFloat16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ - } \ - break; \ - } \ - case at::ScalarType::Half: { \ - using scalar_t_in = at::Half; \ - using scalar_t_out = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::BFloat16: { \ - using scalar_t_in = at::BFloat16; \ - using scalar_t_out = at::BFloat16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \ - } + switch(TYPEIN) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_in = float; \ + switch(TYPEOUT) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_out = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ + } \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_in = at::Half; \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_in = at::BFloat16; \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \ + } // Forward/backward compatiblity hack around // https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288 @@ -68,191 +81,222 @@ // TypeShim(const at::Type& type) : payload(type) {} // // Enable trivial conversion to a const at::Type& for pre-3aeb78 // operator const at::Type&(){ return payload; }; -// // Enable dispatch switch statements to take *this directly for post-3aeb78 +// // Enable dispatch switch statements to take *this directly for post-3aeb78 // //operator at::ScalarType(){ return payload.; }; // }; -#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) 
\ - switch (TYPE) { \ - case at::ScalarType::Float: { \ - using scalar_t_##LEVEL = float; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Half: { \ - using scalar_t_##LEVEL = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } +#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \ + switch (TYPE) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } -#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \ - switch (TYPE) { \ - case at::ScalarType::Float: { \ - using scalar_t_##LEVEL = float; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Half: { \ - using scalar_t_##LEVEL = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Byte: { \ - using scalar_t_##LEVEL = uint8_t; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } +#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \ + switch (TYPE) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Byte: \ + { \ + using scalar_t_##LEVEL = uint8_t; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } -#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \ - switch (TYPE) { \ - case at::ScalarType::Double: { \ - using scalar_t_##LEVEL = double; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Float: { \ - using scalar_t_##LEVEL = float; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Half: { \ - using scalar_t_##LEVEL = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } +#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \ + switch (TYPE) \ + { \ + case at::ScalarType::Double: \ + { \ + using scalar_t_##LEVEL = double; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Float: \ + { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } -#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \ - switch (TYPE) { \ - case at::ScalarType::Double: { \ - using scalar_t_##LEVEL = double; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Float: { \ - using scalar_t_##LEVEL = float; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } +#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \ + switch (TYPE) \ + { \ + case at::ScalarType::Double: \ + { \ + using scalar_t_##LEVEL = double; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Float: \ + { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } -#define DISPATCH_FLOAT_AND_HALF_FOR_G_P(GTYPE, PTYPE, LEVEL, NAME, ...) 
\ - if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Float) { \ - using g_scalar_t_##LEVEL = float; \ - using p_scalar_t_##LEVEL = float; \ - __VA_ARGS__; \ - } else if (GTYPE == at::ScalarType::Float && \ - PTYPE == at::ScalarType::Half) { \ - using g_scalar_t_##LEVEL = float; \ - using p_scalar_t_##LEVEL = at::Half; \ - __VA_ARGS__; \ - } else if (GTYPE == at::ScalarType::Half && \ - PTYPE == at::ScalarType::Float) { \ - using g_scalar_t_##LEVEL = at::Half; \ - using p_scalar_t_##LEVEL = float; \ - __VA_ARGS__; \ - } else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Half) { \ - using g_scalar_t_##LEVEL = at::Half; \ - using p_scalar_t_##LEVEL = at::Half; \ - __VA_ARGS__; \ - } else { \ - AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), \ - "'"); \ - } +#define DISPATCH_FLOAT_AND_HALF_FOR_G_P(GTYPE, PTYPE, LEVEL, NAME, ...) \ + if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Float) \ + { \ + using g_scalar_t_##LEVEL = float; \ + using p_scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + } \ + else if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Half) \ + { \ + using g_scalar_t_##LEVEL = float; \ + using p_scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + } \ + else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Float) \ + { \ + using g_scalar_t_##LEVEL = at::Half; \ + using p_scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + } \ + else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Half) \ + { \ + using g_scalar_t_##LEVEL = at::Half; \ + using p_scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + } \ + else \ + { \ + AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), "'"); \ + } \ template -__device__ __forceinline__ T reduce_block_into_lanes( - T *x, T val, int lanes = 1, - bool share_result = false) // lanes is intended to be <= 32. +__device__ __forceinline__ T reduce_block_into_lanes(T *x, + T val, + int lanes = 1, + bool share_result = false) // lanes is intended to be <= 32. { - int tid = threadIdx.x + threadIdx.y * blockDim.x; - int blockSize = - blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32. + int tid = threadIdx.x + threadIdx.y * blockDim.x; + int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32. - if (blockSize >= 64) { - x[tid] = val; - __syncthreads(); - } - -#pragma unroll - for (int i = (blockSize >> 1); i >= 64; i >>= 1) { - if (tid < i) x[tid] = x[tid] + x[tid + i]; - __syncthreads(); - } - - T final; - - if (tid < 32) { if (blockSize >= 64) - final = x[tid] + x[tid + 32]; - else - final = val; - // __SYNCWARP(); + { + x[tid] = val; + __syncthreads(); + } #pragma unroll - for (int i = 16; i >= lanes; i >>= 1) - final = final + __shfl_down_sync(0xffffffff, final, i); - } + for (int i = (blockSize >> 1); i >= 64; i >>= 1) + { + if (tid < i) + x[tid] = x[tid] + x[tid + i]; + __syncthreads(); + } - if (share_result) { - if (tid < lanes) x[tid] = final; // EpilogueOp - // Make sure the smem result is visible to all warps. - __syncthreads(); - } + T final; - return final; + if (tid < 32) + { + if (blockSize >= 64) + final = x[tid] + x[tid + 32]; + else + final = val; + // __SYNCWARP(); + +#pragma unroll + for (int i = 16; i >= lanes; i >>= 1) + final = final + __shfl_down_sync(0xffffffff, final, i); + } + + if (share_result) + { + if (tid < lanes) + x[tid] = final; // EpilogueOp + // Make sure the smem result is visible to all warps. 
+ __syncthreads(); + } + + return final; } template -__device__ __forceinline__ T reduce_block_into_lanes_max_op( - T *x, T val, int lanes = 1, - bool share_result = false) // lanes is intended to be <= 32. +__device__ __forceinline__ T reduce_block_into_lanes_max_op(T *x, + T val, + int lanes = 1, + bool share_result = false) // lanes is intended to be <= 32. { - int tid = threadIdx.x + threadIdx.y * blockDim.x; - int blockSize = - blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32. + int tid = threadIdx.x + threadIdx.y * blockDim.x; + int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32. - if (blockSize >= 64) { - x[tid] = val; - __syncthreads(); - } - -#pragma unroll - for (int i = (blockSize >> 1); i >= 64; i >>= 1) { - if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i])); - __syncthreads(); - } - - T final; - - if (tid < 32) { if (blockSize >= 64) - final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32])); - else - final = val; - // __SYNCWARP(); + { + x[tid] = val; + __syncthreads(); + } #pragma unroll - for (int i = 16; i >= lanes; i >>= 1) - final = - fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i))); - } + for (int i = (blockSize >> 1); i >= 64; i >>= 1) + { + if (tid < i) + x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i])); + __syncthreads(); + } - if (share_result) { - if (tid < lanes) x[tid] = final; // EpilogueOp - // Make sure the smem result is visible to all warps. - __syncthreads(); - } + T final; - return final; + if (tid < 32) + { + if (blockSize >= 64) + final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32])); + else + final = val; + // __SYNCWARP(); + +#pragma unroll + for (int i = 16; i >= lanes; i >>= 1) + final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i))); + } + + if (share_result) + { + if (tid < lanes) + x[tid] = final; // EpilogueOp + // Make sure the smem result is visible to all warps. + __syncthreads(); + } + + return final; } \ No newline at end of file