From 7aa35eae6ad8de16021e67340b9d3851466776da Mon Sep 17 00:00:00 2001
From: Maruyama_Aya <38985202+MaruyamaAya@users.noreply.github.com>
Date: Fri, 13 May 2022 15:38:19 +0800
Subject: [PATCH] [NFC] polish colossalai/kernel/cuda_native/csrc/kernels/include/block_reduce.h code style (#938)

---
 .../csrc/kernels/include/block_reduce.h       | 41 ++++++++++---------
 1 file changed, 21 insertions(+), 20 deletions(-)

diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/block_reduce.h b/colossalai/kernel/cuda_native/csrc/kernels/include/block_reduce.h
index c58ed44ba..38103c173 100644
--- a/colossalai/kernel/cuda_native/csrc/kernels/include/block_reduce.h
+++ b/colossalai/kernel/cuda_native/csrc/kernels/include/block_reduce.h
@@ -13,22 +13,23 @@ const float REDUCE_FLOAT_INF_NEG = -100000000.f;
 const float REDUCE_FLOAT_INF_POS = 100000000.f;
 const unsigned int WARP_REDUCE_SIZE = 32;
 
-template <typename T> __forceinline__ __device__ T warpReduceSum(T val) {
+template <typename T>
+__forceinline__ __device__ T warpReduceSum(T val) {
   for (int mask = (WARP_REDUCE_SIZE >> 1); mask > 0; mask >>= 1)
     val += __shfl_xor_sync(WARP_REDUCE_MASK, val, mask, WARP_REDUCE_SIZE);
   return val;
 }
 
 /* Calculate the sum of all elements in a block */
-template <typename T> __forceinline__ __device__ T blockReduceSum(T val) {
+template <typename T>
+__forceinline__ __device__ T blockReduceSum(T val) {
   static __shared__ T shared[32];
   int lane = threadIdx.x & 0x1f;
   int wid = threadIdx.x >> 5;
 
   val = warpReduceSum(val);
 
-  if (lane == 0)
-    shared[wid] = val;
+  if (lane == 0) shared[wid] = val;
   __syncthreads();
 
   val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : (T)0.0f;
@@ -56,10 +57,10 @@ __inline__ __device__ void warpReduce(float *pval) {
 template <>
 __inline__ __device__ void warpReduce(float *pval) {
   float val0_tmp, val1_tmp;
-#define WarpReduceMaxOneStep(a, b)                                            \
-  val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval), a, b);                \
-  val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b);            \
-  *(pval) = max(val0_tmp, *(pval));                                           \
+#define WarpReduceMaxOneStep(a, b)                                  \
+  val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval), a, b);      \
+  val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b);  \
+  *(pval) = max(val0_tmp, *(pval));                                 \
   *(pval + 1) = max(val1_tmp, *(pval + 1));
 
   WarpReduceMaxOneStep(16, 32);
@@ -88,10 +89,10 @@ __inline__ __device__ void warpReduce(float *pval) {
 template <>
 __inline__ __device__ void warpReduce(float *pval) {
   float val0_tmp, val1_tmp;
-#define WarpReduceSumOneStep(a, b)                                            \
-  val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b);            \
-  val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b);            \
-  *(pval + 0) += val0_tmp;                                                    \
+#define WarpReduceSumOneStep(a, b)                                   \
+  val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b);   \
+  val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b);   \
+  *(pval + 0) += val0_tmp;                                           \
   *(pval + 1) += val1_tmp
 
   WarpReduceSumOneStep(16, 32);
@@ -106,14 +107,14 @@ __inline__ __device__ void warpReduce(float *pval) {
 template <>
 __inline__ __device__ void warpReduce(float *pval) {
   float val0_tmp, val1_tmp, val2_tmp, val3_tmp;
-#define WarpReduceSumOneStep(a, b)                                            \
-  val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b);            \
-  val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b);            \
-  val2_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 2), a, b);            \
-  val3_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 3), a, b);            \
-  *(pval + 0) += val0_tmp;                                                    \
-  *(pval + 1) += val1_tmp;                                                    \
-  *(pval + 2) += val2_tmp;                                                    \
+#define WarpReduceSumOneStep(a, b)                                   \
+  val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b);   \
+  val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b);   \
+  val2_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 2), a, b);   \
+  val3_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 3), a, b);   \
+  *(pval + 0) += val0_tmp;                                           \
+  *(pval + 1) += val1_tmp;                                           \
+  *(pval + 2) += val2_tmp;                                           \
   *(pval + 3) += val3_tmp
 
   WarpReduceSumOneStep(16, 32);
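The functions touched above implement a standard two-stage CUDA reduction: each warp first reduces its 32 lanes with __shfl_xor_sync, lane 0 of every warp writes its partial result to shared memory, and the first warp then reduces those partials. As a rough, self-contained illustration of how a blockReduceSum-style helper is typically driven from a kernel (the names warpReduceSumSketch, blockReduceSumSketch, and sumKernelSketch are hypothetical and do not appear in the patched header):

// Illustrative sketch only -- mirrors the warp-shuffle / shared-memory
// pattern in block_reduce.h, but is not part of the patch.
#include <cuda_runtime.h>

__forceinline__ __device__ float warpReduceSumSketch(float val) {
  // Butterfly reduction across the 32 lanes of a warp.
  for (int mask = 16; mask > 0; mask >>= 1)
    val += __shfl_xor_sync(0xffffffff, val, mask, 32);
  return val;
}

__forceinline__ __device__ float blockReduceSumSketch(float val) {
  static __shared__ float shared[32];  // one partial sum per warp
  int lane = threadIdx.x & 0x1f;       // lane index within the warp
  int wid = threadIdx.x >> 5;          // warp index within the block

  val = warpReduceSumSketch(val);      // stage 1: reduce within each warp
  if (lane == 0) shared[wid] = val;    // lane 0 publishes the warp's partial
  __syncthreads();

  // Stage 2: the first warp reduces the per-warp partials.
  val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : 0.0f;
  val = warpReduceSumSketch(val);
  return val;                          // complete sum held by thread 0
}

__global__ void sumKernelSketch(const float *in, float *out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  float v = (i < n) ? in[i] : 0.0f;    // pad out-of-range threads with zero
  float s = blockReduceSumSketch(v);
  if (threadIdx.x == 0) atomicAdd(out, s);  // combine per-block partial sums
}

Launching sumKernelSketch over the input with a zero-initialized accumulator then yields the total, since thread 0 of each block atomically adds its block's partial sum.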