[NFC] polish colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu code style (#979)

pull/997/head
Kai Wang (Victor Kai) 2022-05-16 14:20:36 +08:00 committed by binmakeswell
parent f28c021376
commit c50c08dcbb
1 changed files with 16 additions and 30 deletions

View File

@ -1,10 +1,10 @@
#include <cooperative_groups.h>
#include <chrono>
#include <ctime>
#include "kernels.h"
#include <cooperative_groups.h>
namespace cg = cooperative_groups;
curandStatePhilox4_32_10_t *curandstate;
@ -165,8 +165,7 @@ __global__ void ls_dropout_kernel(const int total_count, const float ratio,
const float scale = 1.f / (1.f - ratio);
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 4 >= total_count)
return;
if (i * 4 >= total_count) return;
curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state);
@ -202,8 +201,7 @@ __global__ void ls_dropout_kernel(const int total_count, const float ratio,
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 8 >= total_count)
return;
if (i * 8 >= total_count) return;
curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state);
@ -261,8 +259,7 @@ __global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
const float scale = 1.f / (1.f - ratio);
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 4 >= total_count)
return;
if (i * 4 >= total_count) return;
uint8_t m[4];
@ -289,8 +286,7 @@ __global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 8 >= total_count)
return;
if (i * 8 >= total_count) return;
float4 *out4 = reinterpret_cast<float4 *>(out);
const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
@ -380,8 +376,7 @@ __global__ void ls_dropout_res_bias_kernel(
const float scale = 1.f / (1.f - ratio);
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 4 >= total_count)
return;
if (i * 4 >= total_count) return;
curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state);
@ -424,8 +419,7 @@ __global__ void ls_dropout_res_bias_kernel(
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 8 >= total_count)
return;
if (i * 8 >= total_count) return;
curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state);
@ -565,11 +559,9 @@ __global__ void ls_dropout_bias_bwd_kernel(
}
__syncthreads();
for (int i = 1; i < 32; i <<= 1)
sum += g.shfl_down(sum, i);
for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i);
if (y == 0)
tile[0][x] = sum;
if (y == 0) tile[0][x] = sum;
__syncthreads();
if (threadIdx.x < 8) {
@ -621,11 +613,9 @@ __global__ void ls_dropout_bias_bwd_kernel(
}
__syncthreads();
for (int i = 1; i < WARP_SIZE; i <<= 1)
sum += g.shfl_down(sum, i);
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
if (y == 0)
tile[0][x] = sum;
if (y == 0) tile[0][x] = sum;
__syncthreads();
if (threadIdx.x < 8) {
@ -689,8 +679,7 @@ __global__ void ls_dropout_act_bias_kernel(
const float scale = 1.f / (1.f - ratio);
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 4 >= total_count)
return;
if (i * 4 >= total_count) return;
curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state);
@ -735,8 +724,7 @@ __global__ void ls_dropout_act_bias_kernel(
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 8 >= total_count)
return;
if (i * 8 >= total_count) return;
curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state);
@ -897,11 +885,9 @@ __global__ void ls_dropout_act_bias_bwd_kernel(
float sum = tile[threadIdx.y][threadIdx.x];
__syncthreads();
for (int i = 1; i < WARP_SIZE; i <<= 1)
sum += g.shfl_down(sum, i);
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
if (threadIdx.x == 0)
tile[0][threadIdx.y] = sum;
if (threadIdx.x == 0) tile[0][threadIdx.y] = sum;
__syncthreads();
if (threadIdx.y == 0) {