From 5eb5ff1464311ac16c29307d03a3c076aced7e03 Mon Sep 17 00:00:00 2001 From: xs_courtesy Date: Fri, 8 Mar 2024 15:41:14 +0800 Subject: [PATCH] refactor code --- .../{cuda/type_shim.h => common/micros.h} | 97 ++----------------- .../{cuda/include => common}/mp_type_traits.h | 10 +- extensions/csrc/cuda/activation_kernel.cu | 8 +- extensions/csrc/cuda/compat.h | 10 -- .../cuda/decode_kv_cache_memcpy_kernel.cu | 2 +- extensions/csrc/cuda/include/block_reduce.h | 87 +++++++++++++++++ extensions/csrc/cuda/layer_norm_cuda.cpp | 2 +- .../csrc/cuda/layer_norm_cuda_kernel.cu | 2 +- extensions/csrc/cuda/multi_tensor_adam.cu | 2 +- extensions/csrc/cuda/multi_tensor_apply.cuh | 2 +- .../csrc/cuda/multi_tensor_l2norm_kernel.cu | 3 +- extensions/csrc/cuda/multi_tensor_lamb.cu | 2 +- .../csrc/cuda/multi_tensor_scale_kernel.cu | 2 +- .../csrc/cuda/multi_tensor_sgd_kernel.cu | 2 +- .../csrc/cuda/scaled_masked_softmax_cuda.cu | 2 +- ...scaled_upper_triang_masked_softmax_cuda.cu | 2 +- 16 files changed, 117 insertions(+), 118 deletions(-) rename extensions/csrc/{cuda/type_shim.h => common/micros.h} (87%) rename extensions/csrc/{cuda/include => common}/mp_type_traits.h (75%) diff --git a/extensions/csrc/cuda/type_shim.h b/extensions/csrc/common/micros.h similarity index 87% rename from extensions/csrc/cuda/type_shim.h rename to extensions/csrc/common/micros.h index 7be3fab1b..c2241029f 100644 --- a/extensions/csrc/cuda/type_shim.h +++ b/extensions/csrc/common/micros.h @@ -9,7 +9,15 @@ #include -#include "compat.h" +#ifndef TORCH_CHECK +#define TORCH_CHECK AT_CHECK +#endif + +#ifdef VERSION_GE_1_3 +#define DATA_PTR data_ptr +#else +#define DATA_PTR data +#endif #define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \ switch (TYPE) { \ @@ -214,90 +222,3 @@ AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), \ "'"); \ } - -template -__device__ __forceinline__ T reduce_block_into_lanes( - T *x, T val, int lanes = 1, - bool share_result = false) // lanes is intended to be <= 32. -{ - int tid = threadIdx.x + threadIdx.y * blockDim.x; - int blockSize = - blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32. - - if (blockSize >= 64) { - x[tid] = val; - __syncthreads(); - } - -#pragma unroll - for (int i = (blockSize >> 1); i >= 64; i >>= 1) { - if (tid < i) x[tid] = x[tid] + x[tid + i]; - __syncthreads(); - } - - T final; - - if (tid < 32) { - if (blockSize >= 64) - final = x[tid] + x[tid + 32]; - else - final = val; - // __SYNCWARP(); - -#pragma unroll - for (int i = 16; i >= lanes; i >>= 1) - final = final + __shfl_down_sync(0xffffffff, final, i); - } - - if (share_result) { - if (tid < lanes) x[tid] = final; // EpilogueOp - // Make sure the smem result is visible to all warps. - __syncthreads(); - } - - return final; -} - -template -__device__ __forceinline__ T reduce_block_into_lanes_max_op( - T *x, T val, int lanes = 1, - bool share_result = false) // lanes is intended to be <= 32. -{ - int tid = threadIdx.x + threadIdx.y * blockDim.x; - int blockSize = - blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32. - - if (blockSize >= 64) { - x[tid] = val; - __syncthreads(); - } - -#pragma unroll - for (int i = (blockSize >> 1); i >= 64; i >>= 1) { - if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i])); - __syncthreads(); - } - - T final; - - if (tid < 32) { - if (blockSize >= 64) - final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32])); - else - final = val; - // __SYNCWARP(); - -#pragma unroll - for (int i = 16; i >= lanes; i >>= 1) - final = - fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i))); - } - - if (share_result) { - if (tid < lanes) x[tid] = final; // EpilogueOp - // Make sure the smem result is visible to all warps. - __syncthreads(); - } - - return final; -} diff --git a/extensions/csrc/cuda/include/mp_type_traits.h b/extensions/csrc/common/mp_type_traits.h similarity index 75% rename from extensions/csrc/cuda/include/mp_type_traits.h rename to extensions/csrc/common/mp_type_traits.h index 6b3ae9c1b..8ede2d448 100644 --- a/extensions/csrc/cuda/include/mp_type_traits.h +++ b/extensions/csrc/common/mp_type_traits.h @@ -2,10 +2,10 @@ #include -#include "../type_shim.h" +#include "micros.h" -namespace infer { -namespace dtype { +namespace colossalAI { +namespace common { template class MPTypeTrait { @@ -31,5 +31,5 @@ class MPTypeTrait { using Type = float; }; -} // namespace dtype -} // namespace infer +} // namespace common +} // namespace colossalAI diff --git a/extensions/csrc/cuda/activation_kernel.cu b/extensions/csrc/cuda/activation_kernel.cu index 4121b67fc..5213a2313 100644 --- a/extensions/csrc/cuda/activation_kernel.cu +++ b/extensions/csrc/cuda/activation_kernel.cu @@ -2,13 +2,13 @@ #include #include -#include "type_shim.h" -#include "include/mp_type_traits.h" +#include "../common/micros.h" +#include "../common/mp_type_traits.h" template __device__ __forceinline__ T silu_kernel(const T& x) { // x * sigmoid(x) - using MT = typename infer::dtype::MPTypeTrait::Type; + using MT = typename colossalAI::common::MPTypeTrait::Type; return static_cast((static_cast(x)) / (static_cast(1.0f) + expf(static_cast(-x)))); } @@ -17,7 +17,7 @@ __global__ void act_and_mul_kernel( const scalar_t* __restrict__ ins_data, scalar_t* __restrict__ outs_data, const int64_t numel) { - using MT = typename infer::dtype::MPTypeTrait::Type; + using MT = typename colossalAI::common::MPTypeTrait::Type; int64_t idx = static_cast(threadIdx.x) + static_cast(blockIdx.x) * static_cast(blockDim.x); const int64_t grid_size = blockDim.x * gridDim.x; diff --git a/extensions/csrc/cuda/compat.h b/extensions/csrc/cuda/compat.h index a62beef91..e69de29bb 100644 --- a/extensions/csrc/cuda/compat.h +++ b/extensions/csrc/cuda/compat.h @@ -1,10 +0,0 @@ -// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h -#ifndef TORCH_CHECK -#define TORCH_CHECK AT_CHECK -#endif - -#ifdef VERSION_GE_1_3 -#define DATA_PTR data_ptr -#else -#define DATA_PTR data -#endif diff --git a/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu b/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu index 86db90c8b..15e613e35 100644 --- a/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu +++ b/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu @@ -2,7 +2,7 @@ #include #include -#include "type_shim.h" +#include "../common/micros.h" template __global__ void decode_kv_cache_memcpy_kernel( diff --git a/extensions/csrc/cuda/include/block_reduce.h b/extensions/csrc/cuda/include/block_reduce.h index 38103c173..86409136b 100644 --- a/extensions/csrc/cuda/include/block_reduce.h +++ b/extensions/csrc/cuda/include/block_reduce.h @@ -310,3 +310,90 @@ __inline__ __device__ void blockReduce(float *pval) { } warpReduce(pval); } + +template +__device__ __forceinline__ T reduce_block_into_lanes( + T *x, T val, int lanes = 1, + bool share_result = false) // lanes is intended to be <= 32. +{ + int tid = threadIdx.x + threadIdx.y * blockDim.x; + int blockSize = + blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32. + + if (blockSize >= 64) { + x[tid] = val; + __syncthreads(); + } + +#pragma unroll + for (int i = (blockSize >> 1); i >= 64; i >>= 1) { + if (tid < i) x[tid] = x[tid] + x[tid + i]; + __syncthreads(); + } + + T final; + + if (tid < 32) { + if (blockSize >= 64) + final = x[tid] + x[tid + 32]; + else + final = val; + // __SYNCWARP(); + +#pragma unroll + for (int i = 16; i >= lanes; i >>= 1) + final = final + __shfl_down_sync(0xffffffff, final, i); + } + + if (share_result) { + if (tid < lanes) x[tid] = final; // EpilogueOp + // Make sure the smem result is visible to all warps. + __syncthreads(); + } + + return final; +} + +template +__device__ __forceinline__ T reduce_block_into_lanes_max_op( + T *x, T val, int lanes = 1, + bool share_result = false) // lanes is intended to be <= 32. +{ + int tid = threadIdx.x + threadIdx.y * blockDim.x; + int blockSize = + blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32. + + if (blockSize >= 64) { + x[tid] = val; + __syncthreads(); + } + +#pragma unroll + for (int i = (blockSize >> 1); i >= 64; i >>= 1) { + if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i])); + __syncthreads(); + } + + T final; + + if (tid < 32) { + if (blockSize >= 64) + final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32])); + else + final = val; + // __SYNCWARP(); + +#pragma unroll + for (int i = 16; i >= lanes; i >>= 1) + final = + fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i))); + } + + if (share_result) { + if (tid < lanes) x[tid] = final; // EpilogueOp + // Make sure the smem result is visible to all warps. + __syncthreads(); + } + + return final; +} diff --git a/extensions/csrc/cuda/layer_norm_cuda.cpp b/extensions/csrc/cuda/layer_norm_cuda.cpp index 15a07bb0c..3439e5e71 100644 --- a/extensions/csrc/cuda/layer_norm_cuda.cpp +++ b/extensions/csrc/cuda/layer_norm_cuda.cpp @@ -7,7 +7,7 @@ #include #include -#include "compat.h" +#include "../common/micros.h" namespace { diff --git a/extensions/csrc/cuda/layer_norm_cuda_kernel.cu b/extensions/csrc/cuda/layer_norm_cuda_kernel.cu index 72b84d6ca..17d5b10f4 100644 --- a/extensions/csrc/cuda/layer_norm_cuda_kernel.cu +++ b/extensions/csrc/cuda/layer_norm_cuda_kernel.cu @@ -9,7 +9,7 @@ #include "ATen/AccumulateType.h" #include "ATen/cuda/CUDAContext.h" #include "ATen/cuda/DeviceUtils.cuh" -#include "type_shim.h" +#include "../common/micros.h" template __device__ void cuWelfordOnlineSum(const U curr, U& mu, U& sigma2, U& count) { diff --git a/extensions/csrc/cuda/multi_tensor_adam.cu b/extensions/csrc/cuda/multi_tensor_adam.cu index 9cc3ae1ea..b7793b364 100644 --- a/extensions/csrc/cuda/multi_tensor_adam.cu +++ b/extensions/csrc/cuda/multi_tensor_adam.cu @@ -15,7 +15,7 @@ #include #include "multi_tensor_apply.cuh" -#include "type_shim.h" +#include "../common/micros.h" #define BLOCK_SIZE 512 #define ILP 4 diff --git a/extensions/csrc/cuda/multi_tensor_apply.cuh b/extensions/csrc/cuda/multi_tensor_apply.cuh index ec55dd320..01a858661 100644 --- a/extensions/csrc/cuda/multi_tensor_apply.cuh +++ b/extensions/csrc/cuda/multi_tensor_apply.cuh @@ -12,7 +12,7 @@ #include #include -#include "compat.h" +#include "../common/micros.h" // #include diff --git a/extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu b/extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu index 85f935152..57a79f7a8 100644 --- a/extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu +++ b/extensions/csrc/cuda/multi_tensor_l2norm_kernel.cu @@ -11,7 +11,8 @@ #include #include "multi_tensor_apply.cuh" -#include "type_shim.h" +#include "../common/micros.h" +#include "include/block_reduce.h" #define BLOCK_SIZE 512 #define ILP 4 diff --git a/extensions/csrc/cuda/multi_tensor_lamb.cu b/extensions/csrc/cuda/multi_tensor_lamb.cu index 63771cf40..50dfc56bc 100644 --- a/extensions/csrc/cuda/multi_tensor_lamb.cu +++ b/extensions/csrc/cuda/multi_tensor_lamb.cu @@ -10,7 +10,7 @@ #include #include "multi_tensor_apply.cuh" -#include "type_shim.h" +#include "../common/micros.h" #define BLOCK_SIZE 512 #define ILP 4 diff --git a/extensions/csrc/cuda/multi_tensor_scale_kernel.cu b/extensions/csrc/cuda/multi_tensor_scale_kernel.cu index 2f58a0f16..0dec1d5d1 100644 --- a/extensions/csrc/cuda/multi_tensor_scale_kernel.cu +++ b/extensions/csrc/cuda/multi_tensor_scale_kernel.cu @@ -10,7 +10,7 @@ #include #include "multi_tensor_apply.cuh" -#include "type_shim.h" +#include "../common/micros.h" #define BLOCK_SIZE 512 #define ILP 4 diff --git a/extensions/csrc/cuda/multi_tensor_sgd_kernel.cu b/extensions/csrc/cuda/multi_tensor_sgd_kernel.cu index 7f48dbd5d..d0cf786f8 100644 --- a/extensions/csrc/cuda/multi_tensor_sgd_kernel.cu +++ b/extensions/csrc/cuda/multi_tensor_sgd_kernel.cu @@ -7,7 +7,7 @@ #include #include -#include "compat.h" +#include "../common/micros.h" #include "multi_tensor_apply.cuh" #define BLOCK_SIZE 512 diff --git a/extensions/csrc/cuda/scaled_masked_softmax_cuda.cu b/extensions/csrc/cuda/scaled_masked_softmax_cuda.cu index 41781ebc7..2f968d30f 100644 --- a/extensions/csrc/cuda/scaled_masked_softmax_cuda.cu +++ b/extensions/csrc/cuda/scaled_masked_softmax_cuda.cu @@ -10,7 +10,7 @@ #include #include "scaled_masked_softmax.h" -#include "type_shim.h" +#include "../common/micros.h" namespace multihead_attn { namespace fused_softmax { diff --git a/extensions/csrc/cuda/scaled_upper_triang_masked_softmax_cuda.cu b/extensions/csrc/cuda/scaled_upper_triang_masked_softmax_cuda.cu index 62c56e6f7..d9550dc2c 100644 --- a/extensions/csrc/cuda/scaled_upper_triang_masked_softmax_cuda.cu +++ b/extensions/csrc/cuda/scaled_upper_triang_masked_softmax_cuda.cu @@ -10,7 +10,7 @@ #include #include "scaled_upper_triang_masked_softmax.h" -#include "type_shim.h" +#include "../common/micros.h" namespace multihead_attn { namespace fused_softmax {