diff --git a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu
index 9ccf09d76..184106bd2 100644
--- a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu
+++ b/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu
@@ -1,10 +1,10 @@
+#include <cooperative_groups.h>
+
 #include <chrono>
 #include <ctime>

 #include "kernels.h"

-#include <cooperative_groups.h>
-
 namespace cg = cooperative_groups;

 curandStatePhilox4_32_10_t *curandstate;
@@ -165,8 +165,7 @@ __global__ void ls_dropout_kernel(const int total_count, const float ratio,
   const float scale = 1.f / (1.f - ratio);
   int i = blockIdx.x * blockDim.x + threadIdx.x;

-  if (i * 4 >= total_count)
-    return;
+  if (i * 4 >= total_count) return;

   curandStatePhilox4_32_10_t state;
   curand_init(seed, i, 0, &state);
@@ -202,8 +201,7 @@ __global__ void ls_dropout_kernel(const int total_count, const float ratio,

   int i = blockIdx.x * blockDim.x + threadIdx.x;

-  if (i * 8 >= total_count)
-    return;
+  if (i * 8 >= total_count) return;

   curandStatePhilox4_32_10_t state;
   curand_init(seed, i, 0, &state);
@@ -261,8 +259,7 @@ __global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
   const float scale = 1.f / (1.f - ratio);
   int i = blockIdx.x * blockDim.x + threadIdx.x;

-  if (i * 4 >= total_count)
-    return;
+  if (i * 4 >= total_count) return;

   uint8_t m[4];

@@ -289,8 +286,7 @@ __global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,

   int i = blockIdx.x * blockDim.x + threadIdx.x;

-  if (i * 8 >= total_count)
-    return;
+  if (i * 8 >= total_count) return;

   float4 *out4 = reinterpret_cast<float4 *>(out);
   const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
@@ -380,8 +376,7 @@ __global__ void ls_dropout_res_bias_kernel(
   const float scale = 1.f / (1.f - ratio);
   int i = blockIdx.x * blockDim.x + threadIdx.x;

-  if (i * 4 >= total_count)
-    return;
+  if (i * 4 >= total_count) return;

   curandStatePhilox4_32_10_t state;
   curand_init(seed, i, 0, &state);
@@ -424,8 +419,7 @@ __global__ void ls_dropout_res_bias_kernel(

   int i = blockIdx.x * blockDim.x + threadIdx.x;

-  if (i * 8 >= total_count)
-    return;
+  if (i * 8 >= total_count) return;

   curandStatePhilox4_32_10_t state;
   curand_init(seed, i, 0, &state);
@@ -565,11 +559,9 @@ __global__ void ls_dropout_bias_bwd_kernel(
   }
   __syncthreads();

-  for (int i = 1; i < 32; i <<= 1)
-    sum += g.shfl_down(sum, i);
+  for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i);

-  if (y == 0)
-    tile[0][x] = sum;
+  if (y == 0) tile[0][x] = sum;
   __syncthreads();

   if (threadIdx.x < 8) {
@@ -621,11 +613,9 @@ __global__ void ls_dropout_bias_bwd_kernel(
   }
   __syncthreads();

-  for (int i = 1; i < WARP_SIZE; i <<= 1)
-    sum += g.shfl_down(sum, i);
+  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);

-  if (y == 0)
-    tile[0][x] = sum;
+  if (y == 0) tile[0][x] = sum;
   __syncthreads();

   if (threadIdx.x < 8) {
@@ -689,8 +679,7 @@ __global__ void ls_dropout_act_bias_kernel(
   const float scale = 1.f / (1.f - ratio);
   int i = blockIdx.x * blockDim.x + threadIdx.x;

-  if (i * 4 >= total_count)
-    return;
+  if (i * 4 >= total_count) return;

   curandStatePhilox4_32_10_t state;
   curand_init(seed, i, 0, &state);
@@ -735,8 +724,7 @@ __global__ void ls_dropout_act_bias_kernel(

   int i = blockIdx.x * blockDim.x + threadIdx.x;

-  if (i * 8 >= total_count)
-    return;
+  if (i * 8 >= total_count) return;

   curandStatePhilox4_32_10_t state;
   curand_init(seed, i, 0, &state);
@@ -897,11 +885,9 @@ __global__ void ls_dropout_act_bias_bwd_kernel(
   float sum = tile[threadIdx.y][threadIdx.x];
   __syncthreads();

-  for (int i = 1; i < WARP_SIZE; i <<= 1)
-    sum += g.shfl_down(sum, i);
+  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);

-  if (threadIdx.x == 0)
-    tile[0][threadIdx.y] = sum;
+  if (threadIdx.x == 0) tile[0][threadIdx.y] = sum;
   __syncthreads();

   if (threadIdx.y == 0) {
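
For reference, the `g.shfl_down` loops that these hunks reflow onto one line are the standard cooperative-groups warp-level tree reduction: offsets 1, 2, 4, 8, 16 successively fold partial sums downward until lane 0 holds the warp total. Below is a minimal standalone sketch of that pattern, assuming WARP_SIZE is 32 (in the patched file it comes from kernels.h); the kernel name warp_sum_demo and the host driver are illustrative only, not part of this patch.

#include <cooperative_groups.h>

#include <cstdio>

namespace cg = cooperative_groups;

const int WARP_SIZE = 32;  // assumption; the real kernels take it from kernels.h

// Each of the 32 lanes loads one float; the shuffle loop folds partial sums
// down with offsets 1, 2, 4, 8, 16, leaving the warp total in lane 0.
__global__ void warp_sum_demo(const float *in, float *out) {
  cg::thread_block b = cg::this_thread_block();
  cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

  float sum = in[g.thread_rank()];
  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);

  if (g.thread_rank() == 0) *out = sum;
}

int main() {
  float h_in[WARP_SIZE], h_out = 0.f;
  for (int i = 0; i < WARP_SIZE; ++i) h_in[i] = 1.f;

  float *d_in, *d_out;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);

  warp_sum_demo<<<1, WARP_SIZE>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum = %.1f\n", h_out);  // expect 32.0 for 32 ones

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}

Built with nvcc, this should print sum = 32.0; only the looping shuffle line is taken verbatim from the kernels touched above.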