#include <cooperative_groups.h>
#include <curand_kernel.h>

#include <chrono>

#include "kernels.h"

namespace cg = cooperative_groups;

curandStatePhilox4_32_10_t *curandstate;

/**
 * @brief element-wise activation function on device, like Relu, Gelu
 *
 * @tparam enum class ActivationType, kRelu, kGelu
 * @tparam input type
 * @param any shape of float and __half2
 * @return same shape and type as input
 */
template <ActivationType act_type, typename T>
__forceinline__ __device__ T activation_kernel(T x);

template <>
__device__ float activation_kernel<ActivationType::kGelu, float>(float x) {
  float cdf =
      0.5f *
      (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x))));
  return x * cdf;
}

template <>
__device__ __half2
activation_kernel<ActivationType::kGelu, __half2>(__half2 val) {
  __half2 val_pow3 = __hmul2(val, __hmul2(val, val));
  float2 tmp_pow = __half22float2(val_pow3);
  float2 tmp = __half22float2(val);

  tmp.x =
      0.5f *
      (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x))));
  tmp.y =
      0.5f *
      (1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y))));
  return __hmul2(val, __float22half2_rn(tmp));
}

template <>
__device__ float activation_kernel<ActivationType::kRelu, float>(float x) {
  return fmaxf(x, 0);
}

template <>
__device__ __half2
activation_kernel<ActivationType::kRelu, __half2>(__half2 x) {
  return __floats2half2_rn(fmaxf(0.f, __half2float(x.x)),
                           fmaxf(0.f, __half2float(x.y)));
}

/**
 * @brief element-wise activation backward function on device
 *
 * @tparam enum class ActivationType
 * @tparam input type
 * @param any shape of float and __half2
 * @return same shape as input
 */
template <ActivationType act_type, typename T>
__forceinline__ __device__ T activation_bwd_kernel(T grad, T x);

template <>
__device__ float activation_bwd_kernel<ActivationType::kGelu, float>(
    float grad, float x) {
  const float sqrt_param = 0.79788456080286535587989211986876f;
  const float mul_param = 0.044715;

  float x2mul = x * x * mul_param;
  float tan_h = tanhf(sqrt_param * (x + x * x2mul));
  float dg1 = 0.5f * (1.0f + tan_h);
  float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h);
  float dg3 = dg2 * 3 * x2mul;
  return grad * (dg1 + dg2 + dg3);
}

template <>
__device__ __half activation_bwd_kernel<ActivationType::kGelu, __half>(
    __half grad, __half x_half) {
  float x = __half2float(x_half);
  const float sqrt_param = 0.79788456080286535587989211986876f;
  const float mul_param = 0.044715;

  float x2mul = x * x * mul_param;
  float tan_h = tanhf(sqrt_param * (x + x * x2mul));
  float dg1 = 0.5f * (1.0f + tan_h);
  float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h);
  float dg3 = dg2 * 3 * x2mul;
  return grad * __float2half(dg1 + dg2 + dg3);
}

template <>
__device__ float activation_bwd_kernel<ActivationType::kRelu, float>(
    float grad, float x) {
  return x > 0.f ? grad : 0.f;
}

template <>
__device__ __half activation_bwd_kernel<ActivationType::kRelu, __half>(
    __half grad, __half x) {
  const __half half_zero = __float2half(0.f);
  return x > half_zero ? grad : half_zero;
}

template <>
__device__ __half2 activation_bwd_kernel<ActivationType::kRelu, __half2>(
    __half2 grad2, __half2 x_half2) {
  const __half half_zero = __float2half(0.f);
  return __floats2half2_rn(x_half2.x > half_zero ? grad2.x : half_zero,
                           x_half2.y > half_zero ? grad2.y : half_zero);
}
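/*
 * Note: both GELU kernels use the tanh approximation
 *   gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
 * with sqrt(2/pi) = 0.7978845608. The GELU backward kernels return
 * grad * d/dx of that expression, split into the dg1/dg2/dg3 terms above:
 * dg1 is the 0.5*(1+tanh) factor, dg2 and dg3 come from differentiating
 * the tanh argument.
 */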
/**
 * @brief init curand states in global memory
 *
 * @thread grid_dim * block_dim to support any size of states
 * @param state persistent curand states
 * @param seed seed to init states
 * @return void
 */
__global__ void curand_init_kernel(curandStatePhilox4_32_10_t *state,
                                   int seed) {
  /* Each thread gets same seed, a different sequence number, no offset */
  int id = threadIdx.x + blockIdx.x * blockDim.x;
  curand_init(seed, id, 0, &state[id]);
}

void launch_curand_init(int total_count, int dim, cudaStream_t stream) {
  cudaMalloc(&curandstate, total_count * sizeof(curandStatePhilox4_32_10_t));
  int grid_dim = total_count >> 9;
  curand_init_kernel<<<grid_dim + 1, 512, 0, stream>>>(
      curandstate, std::chrono::duration_cast<std::chrono::microseconds>(
                       std::chrono::system_clock::now().time_since_epoch())
                       .count());
}

/**
 * @brief element-wise dropout, store dropped position in mask, it's not
 * in-place
 *
 * @thread
 * gridDim.x = total_count / 1024
 * blockDim.x = 1024
 *
 * @param total_count total elements
 * @param ratio drop ratio
 * @param out any size of float and __half
 * @param in same as out
 * @param mask uint8 type, same size as out
 * @param seed seed to curand
 * @return void
 */
__global__ void ls_dropout_kernel(const int total_count, const float ratio,
                                  float *__restrict__ out,
                                  const float *__restrict__ in,
                                  uint8_t *__restrict__ mask, const int seed) {
  const float scale = 1.f / (1.f - ratio);
  int i = blockIdx.x * blockDim.x + threadIdx.x;

  if (i * 4 >= total_count) return;

  curandStatePhilox4_32_10_t state;
  curand_init(seed, i, 0, &state);
  uint8_t m[4];

  float4 *out4 = reinterpret_cast<float4 *>(out);
  const float4 *data4 = reinterpret_cast<const float4 *>(in);
  uint32_t *mask4 = reinterpret_cast<uint32_t *>(mask);
  float4 rand = curand_uniform4(&state);

  m[0] = (uint8_t)(rand.x > ratio);
  m[1] = (uint8_t)(rand.y > ratio);
  m[2] = (uint8_t)(rand.z > ratio);
  m[3] = (uint8_t)(rand.w > ratio);

  uint32_t *m4 = reinterpret_cast<uint32_t *>(m);
  mask4[i] = m4[0];

  float4 input4 = data4[i];
  float4 res4;
  res4.x = input4.x * scale * m[0];
  res4.y = input4.y * scale * m[1];
  res4.z = input4.z * scale * m[2];
  res4.w = input4.w * scale * m[3];
  out4[i] = res4;
}

__global__ void ls_dropout_kernel(const int total_count, const float ratio,
                                  __half *__restrict__ out,
                                  const __half *__restrict__ in,
                                  uint8_t *__restrict__ mask, const int seed) {
  const float scale = 1.f / (1.f - ratio);

  int i = blockIdx.x * blockDim.x + threadIdx.x;

  if (i * 8 >= total_count) return;

  curandStatePhilox4_32_10_t state;
  curand_init(seed, i, 0, &state);

  const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
  float4 *outs_float4 = reinterpret_cast<float4 *>(out);
  uint64_t *mask8 = reinterpret_cast<uint64_t *>(mask);

  uint8_t m[8];
  float4 rand = curand_uniform4(&state);
  m[0] = (uint8_t)(rand.x > ratio);
  m[1] = (uint8_t)(rand.y > ratio);
  m[2] = (uint8_t)(rand.z > ratio);
  m[3] = (uint8_t)(rand.w > ratio);
  rand = curand_uniform4(&state);
  m[4] = (uint8_t)(rand.x > ratio);
  m[5] = (uint8_t)(rand.y > ratio);
  m[6] = (uint8_t)(rand.z > ratio);
  m[7] = (uint8_t)(rand.w > ratio);
  uint64_t *m8 = reinterpret_cast<uint64_t *>(m);
  mask8[i] = *m8;

  float4 val_float4 = vals_float4[i];
  float4 out_float4;
  __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4);
  __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4);
  __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]);
  __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]);
  __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]);
  __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]);
  out_half2[0] = __hmul2(val_half2[0], scale_mask_1);
  out_half2[1] = __hmul2(val_half2[1], scale_mask_2);
  out_half2[2] = __hmul2(val_half2[2], scale_mask_3);
  out_half2[3] = __hmul2(val_half2[3], scale_mask_4);
  outs_float4[i] = out_float4;
}
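// Note: the forward kernels above access data through float4 loads/stores and
// write each thread's keep-flags as one packed 32-bit word (float path, 4
// flags) or 64-bit word (__half path, 8 flags). This assumes `in`, `out`, and
// `mask` are suitably aligned and `total_count` is a multiple of 4 (float) or
// 8 (__half).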
/**
 * @brief element-wise dropout backward with dropout mask, it's
 * not in-place
 *
 * @thread
 * gridDim.x = total_count / 1024
 * blockDim.x = 1024
 *
 * @param total_count total elements
 * @param ratio drop ratio
 * @param in any size of float and __half
 * @param mask uint8 type, same size as in
 * @return void
 */
__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
                                      float *out, const float *in,
                                      const uint8_t *__restrict__ mask) {
  const float scale = 1.f / (1.f - ratio);
  int i = blockIdx.x * blockDim.x + threadIdx.x;

  if (i * 4 >= total_count) return;

  uint8_t m[4];

  float4 *out4 = reinterpret_cast<float4 *>(out);
  const float4 *in4 = reinterpret_cast<const float4 *>(in);
  const uint32_t *mask4 = reinterpret_cast<const uint32_t *>(mask);

  uint32_t *m4 = reinterpret_cast<uint32_t *>(m);
  m4[0] = mask4[i];

  float4 input4 = in4[i];
  float4 res4;
  res4.x = input4.x * scale * static_cast<float>(m[0]);
  res4.y = input4.y * scale * static_cast<float>(m[1]);
  res4.z = input4.z * scale * static_cast<float>(m[2]);
  res4.w = input4.w * scale * static_cast<float>(m[3]);
  out4[i] = res4;
}

__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
                                      __half *out, const __half *in,
                                      const uint8_t *__restrict__ mask) {
  const __half scale = 1.f / (1.f - ratio);

  int i = blockIdx.x * blockDim.x + threadIdx.x;

  if (i * 8 >= total_count) return;

  float4 *out4 = reinterpret_cast<float4 *>(out);
  const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
  const uint64_t *mask8 = reinterpret_cast<const uint64_t *>(mask);

  uint8_t m[8];
  uint64_t *m8 = reinterpret_cast<uint64_t *>(m);
  m8[0] = mask8[i];

  float4 val_float4 = vals_float4[i];
  float4 out_float4;
  __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4);
  __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4);
  __half2 scale_mask_1 =
      __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1]));
  __half2 scale_mask_2 =
      __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3]));
  __half2 scale_mask_3 =
      __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5]));
  __half2 scale_mask_4 =
      __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7]));
  out_half2[0] = __hmul2(val_half2[0], scale_mask_1);
  out_half2[1] = __hmul2(val_half2[1], scale_mask_2);
  out_half2[2] = __hmul2(val_half2[2], scale_mask_3);
  out_half2[3] = __hmul2(val_half2[3], scale_mask_4);
  out4[i] = out_float4;
}

template <>
void launch_ls_dropout<float>(float *out, const float *vals, uint8_t *mask,
                              int total_count, float ratio,
                              cudaStream_t stream, bool backward) {
  int grid_dim = total_count >> 12;
  if (!backward) {
    ls_dropout_kernel<<<grid_dim + 1, 1024, 0, stream>>>(
        total_count, ratio, out, vals, mask,
        std::chrono::duration_cast<std::chrono::microseconds>(
            std::chrono::system_clock::now().time_since_epoch())
            .count());
  } else {
    ls_dropout_bwd_kernel<<<grid_dim + 1, 1024, 0, stream>>>(total_count,
                                                             ratio, out, vals,
                                                             mask);
  }
}

template <>
void launch_ls_dropout<__half>(__half *out, const __half *vals, uint8_t *mask,
                               int total_count, float ratio,
                               cudaStream_t stream, bool backward) {
  int grid_dim = total_count >> 13;
  if (!backward) {
    ls_dropout_kernel<<<grid_dim + 1, 1024, 0, stream>>>(
        total_count, ratio, out, vals, mask,
        std::chrono::duration_cast<std::chrono::microseconds>(
            std::chrono::system_clock::now().time_since_epoch())
            .count());
  } else {
    ls_dropout_bwd_kernel<<<grid_dim + 1, 1024, 0, stream>>>(total_count,
                                                             ratio, out, vals,
                                                             mask);
  }
}
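// Example call sites (a sketch, assuming hypothetical device buffers
// d_out / d_in / d_mask of n elements, with n a multiple of 8 for __half):
//   launch_ls_dropout<float>(d_out, d_in, d_mask, n, 0.1f, stream, false);
//   launch_ls_dropout<float>(d_in_grad, d_out_grad, d_mask, n, 0.1f, stream,
//                            true);  // backward replays the stored mask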
/**
 * @brief fused bias, dropout, and residual at the end of Attention and FFN,
 * store dropped position in mask, it's not in-place
 *
 * @thread
 * gridDim.x = total_count / 1024
 * blockDim.x = 1024
 *
 * @param total_count total elements
 * @param ratio drop ratio
 * @param out [batch_size, seq_len, hidden_size], float and __half
 * @param in [batch_size, seq_len, hidden_size], float and __half
 * @param mask [batch_size, seq_len, hidden_size], uint8 type
 * @param bias [hidden_size], ffn bias
 * @param residual [batch_size, seq_len, hidden_size], float and __half
 * @param seed seed to curand
 * @param hidden_size hidden size
 * @return void
 */
__global__ void ls_dropout_res_bias_kernel(
    const int total_count, const float ratio, float *__restrict__ out,
    const float *__restrict__ in, uint8_t *__restrict__ mask,
    const float *__restrict__ bias, const float *__restrict__ residual,
    const int seed, const int hidden_size) {
  const float scale = 1.f / (1.f - ratio);
  int i = blockIdx.x * blockDim.x + threadIdx.x;

  if (i * 4 >= total_count) return;

  curandStatePhilox4_32_10_t state;
  curand_init(seed, i, 0, &state);
  uint8_t m[4];

  float4 *out4 = reinterpret_cast<float4 *>(out);
  const float4 *data4 = reinterpret_cast<const float4 *>(in);
  const float4 *residual4 = reinterpret_cast<const float4 *>(residual);
  const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
  uint32_t *mask4 = reinterpret_cast<uint32_t *>(mask);
  float4 rand = curand_uniform4(&state);

  m[0] = static_cast<uint8_t>(rand.x > ratio);
  m[1] = static_cast<uint8_t>(rand.y > ratio);
  m[2] = static_cast<uint8_t>(rand.z > ratio);
  m[3] = static_cast<uint8_t>(rand.w > ratio);

  int bias_i = i % (hidden_size >> 2);
  uint32_t *m4 = reinterpret_cast<uint32_t *>(m);
  mask4[i] = m4[0];
  const float4 input4 = data4[i];
  const float4 b4 = __ldg(&bias4[bias_i]);
  const float4 res4 = residual4[i];
  float4 output4;

  output4.x = (input4.x + b4.x) * scale * m[0] + res4.x;
  output4.y = (input4.y + b4.y) * scale * m[1] + res4.y;
  output4.z = (input4.z + b4.z) * scale * m[2] + res4.z;
  output4.w = (input4.w + b4.w) * scale * m[3] + res4.w;

  out4[i] = output4;
}

__global__ void ls_dropout_res_bias_kernel(
    const int total_count, const float ratio, __half *__restrict__ out,
    const __half *__restrict__ in, uint8_t *__restrict__ mask,
    const __half *__restrict__ bias, const __half *__restrict__ residual,
    const int seed, const int hidden_size) {
  const __half scale = 1. / (1. - ratio);

  int i = blockIdx.x * blockDim.x + threadIdx.x;

  if (i * 8 >= total_count) return;

  curandStatePhilox4_32_10_t state;
  curand_init(seed, i, 0, &state);

  const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
  float4 *outs_float4 = reinterpret_cast<float4 *>(out);
  const float4 *residual4 = reinterpret_cast<const float4 *>(residual);
  const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
  uint64_t *mask8 = reinterpret_cast<uint64_t *>(mask);

  uint8_t m[8];
  float4 rand = curand_uniform4(&state);
  m[0] = static_cast<uint8_t>(rand.x > ratio);
  m[1] = static_cast<uint8_t>(rand.y > ratio);
  m[2] = static_cast<uint8_t>(rand.z > ratio);
  m[3] = static_cast<uint8_t>(rand.w > ratio);
  rand = curand_uniform4(&state);
  m[4] = static_cast<uint8_t>(rand.x > ratio);
  m[5] = static_cast<uint8_t>(rand.y > ratio);
  m[6] = static_cast<uint8_t>(rand.z > ratio);
  m[7] = static_cast<uint8_t>(rand.w > ratio);
  uint64_t *m8 = reinterpret_cast<uint64_t *>(m);
  mask8[i] = m8[0];

  int bias_i = i % (hidden_size >> 3);
  float4 val_float4 = vals_float4[i];
  const float4 b4 = __ldg(&bias4[bias_i]);
  const float4 res4 = residual4[i];
  float4 out_float4;

  __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4);
  __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4);
  const __half2 *b_half2 = reinterpret_cast<const __half2 *>(&b4);
  const __half2 *res_half2 = reinterpret_cast<const __half2 *>(&res4);
  __half2 scale_mask_1 =
      __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1]));
  __half2 scale_mask_2 =
      __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3]));
  __half2 scale_mask_3 =
      __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5]));
  __half2 scale_mask_4 =
      __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7]));
  out_half2[0] =
      __hfma2(__hadd2(val_half2[0], b_half2[0]), scale_mask_1, res_half2[0]);
  out_half2[1] =
      __hfma2(__hadd2(val_half2[1], b_half2[1]), scale_mask_2, res_half2[1]);
  out_half2[2] =
      __hfma2(__hadd2(val_half2[2], b_half2[2]), scale_mask_3, res_half2[2]);
  out_half2[3] =
      __hfma2(__hadd2(val_half2[3], b_half2[3]), scale_mask_4, res_half2[3]);
  outs_float4[i] = out_float4;
}
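// Note: the two kernels above compute out = (in + bias) * mask * scale +
// residual. The bias index wraps modulo hidden_size (in units of 4 floats or
// 8 halves), so a single [hidden_size] bias vector is broadcast over every
// (batch, seq) position, and the residual is added after dropout.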
template <>
void launch_ls_dropout_res_bias<float>(float *out, const float *vals,
                                       uint8_t *mask, const float *bias,
                                       const float *residual, int total_count,
                                       int dim, float ratio,
                                       cudaStream_t stream) {
  int grid_dim = total_count >> 12;
  ls_dropout_res_bias_kernel<<<grid_dim + 1, 1024, 0, stream>>>(
      total_count, ratio, out, vals, mask, bias, residual,
      std::chrono::duration_cast<std::chrono::microseconds>(
          std::chrono::system_clock::now().time_since_epoch())
          .count(),
      dim);
}

template <>
void launch_ls_dropout_res_bias<__half>(__half *out, const __half *vals,
                                        uint8_t *mask, const __half *bias,
                                        const __half *residual,
                                        int total_count, int dim, float ratio,
                                        cudaStream_t stream) {
  int grid_dim = total_count >> 13;
  ls_dropout_res_bias_kernel<<<grid_dim + 1, 1024, 0, stream>>>(
      total_count, ratio, out, vals, mask, bias, residual,
      std::chrono::duration_cast<std::chrono::microseconds>(
          std::chrono::system_clock::now().time_since_epoch())
          .count(),
      dim);
}
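// Note: every forward launch derives the curand seed from the wall-clock time
// in microseconds, so repeated launches draw independent dropout masks; the
// backward kernels never re-sample and instead replay the stored mask.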
/**
 * @brief fused bias and dropout backward at the end of Attention and FFN
 *
 * @thread
 * gridDim.x = hidden_size / 8
 * blockDim.x = 8
 * blockDim.y = 1024 / 8 = 128
 *
 * @param row_size batch_size * seq_len
 * @param ratio dropout ratio
 * @param in_grad [batch_size, seq_len, hidden_size], input grad
 * @param bias_grad [hidden_size], bias grad
 * @param out_grad [batch_size, seq_len, hidden_size], output grad
 * @param mask [batch_size, seq_len, hidden_size], dropout mask
 * @param hidden_size
 * @return void
 */
__global__ void ls_dropout_bias_bwd_kernel(
    const int row_size, const float ratio, float *__restrict__ in_grad,
    float *__restrict__ bias_grad, const float *__restrict__ out_grad,
    const uint8_t *__restrict__ mask, const int hidden_size) {
  const float scale = 1.f / (1.f - ratio);
  // each block generates 8 bias results
  __shared__ float tile[8][129];

  cg::thread_block b = cg::this_thread_block();
  cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

  int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8);
  int stride = hidden_size * 128;
  float local_sum = 0;

  int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
  for (int r = threadIdx.y; r < row_size; r += 128) {
    float val = out_grad[idx];
    val *= scale * static_cast<float>(mask[idx]);
    local_sum += val;
    in_grad[idx] = val;
    idx += stride;
  }

  tile[threadIdx.x][threadIdx.y] = local_sum;
  __syncthreads();

  float sum = 0;
  int tid = threadIdx.y * blockDim.x + threadIdx.x;
  int x = tid >> 7;
  int y = tid & (127);
  if (y < 32) {
#pragma unroll
    for (int i = 0; i < 4; i++) {
      sum += tile[x][y + i * 32];
    }
  }
  __syncthreads();

  for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i);

  if (y == 0) tile[0][x] = sum;
  __syncthreads();

  if (threadIdx.x < 8) {
    int pos = flat_2dim(blockIdx.x, threadIdx.x, 8);
    bias_grad[pos] = tile[0][threadIdx.x];
  }
}

__global__ void ls_dropout_bias_bwd_kernel(
    const int row_size, const float ratio, __half *__restrict__ in_grad,
    __half *__restrict__ bias_grad, const __half *__restrict__ out_grad,
    const uint8_t *__restrict__ mask, const int hidden_size) {
  const __half2 scale = __float2half2_rn(1.f / (1.f - ratio));
  __shared__ __half2 tile[8][129];

  cg::thread_block b = cg::this_thread_block();
  cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

  __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad);
  const __half2 *out_grad2 = reinterpret_cast<const __half2 *>(out_grad);
  __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad);

  int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8);
  int stride = hidden_size * 128;
  __half2 local_sum = __float2half2_rn(0.f);

  int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
  for (int r = threadIdx.y; r < row_size; r += 128) {
    __half2 val = out_grad2[idx];
    __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]);
    val *= scale * m2;
    local_sum += val;
    in_grad2[idx] = val;
    idx += stride;
  }

  tile[threadIdx.x][threadIdx.y] = local_sum;
  __syncthreads();

  __half2 sum = __float2half2_rn(0.f);
  int tid = threadIdx.y * blockDim.x + threadIdx.x;
  int x = tid >> 7;
  int y = tid & (127);
  if (y < 32) {
#pragma unroll
    for (int i = 0; i < 4; i++) {
      sum += tile[x][y + i * 32];
    }
  }
  __syncthreads();

  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);

  if (y == 0) tile[0][x] = sum;
  __syncthreads();

  if (threadIdx.x < 8) {
    int pos = flat_2dim(blockIdx.x, threadIdx.x, 8);
    bias_grad2[pos] = tile[0][threadIdx.x];
  }
}

template <typename T>
void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad,
                                const uint8_t *mask, int row_size, int dim,
                                float ratio, cudaStream_t stream) {
  dim3 grid_dim((dim - 1) / 8 + 1);
  dim3 block_dim(8, 128);
  ls_dropout_bias_bwd_kernel<<<grid_dim, block_dim, 0, stream>>>(
      row_size, ratio, in_grad, bias_grad, out_grad, mask, dim);
}

template <>
void launch_ls_dropout_bias_bwd(__half *in_grad, __half *bias_grad,
                                const __half *out_grad, const uint8_t *mask,
                                int row_size, int dim, float ratio,
                                cudaStream_t stream) {
  dim >>= 1;
  dim3 grid_dim((dim - 1) / 8 + 1);
  dim3 block_dim(8, 128);
  ls_dropout_bias_bwd_kernel<<<grid_dim, block_dim, 0, stream>>>(
      row_size, ratio, in_grad, bias_grad, out_grad, mask, dim);
}

template void launch_ls_dropout_bias_bwd(float *in_grad, float *bias_grad,
                                         const float *out_grad,
                                         const uint8_t *mask, int row_size,
                                         int dim, float ratio,
                                         cudaStream_t stream);
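// Note: the bias-gradient reduction above runs in two stages: each thread
// first accumulates a partial column sum over its strided rows into the
// shared tile, then the 128 partials per column are folded to 32 by strided
// adds and collapsed with a warp shuffle before one thread writes the final
// bias_grad element for that column.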
/**
 * @brief fused bias, activation, and dropout at the end of first ffn
 *
 * @thread
 * gridDim.x = total_count / 1024
 * blockDim.x = 256
 *
 * @tparam act_type activation function, like kRelu, kGelu
 * @param total_count total elements
 * @param ratio drop ratio
 * @param out [batch_size, seq_len, hidden_size], float and __half
 * @param in [batch_size, seq_len, hidden_size], float and __half
 * @param mask [batch_size, seq_len, hidden_size], uint8 type
 * @param bias [hidden_size], ffn bias
 * @param seed seed to curand
 * @param hidden_size
 * @return void
 */
template <ActivationType act_type>
__global__ void ls_dropout_act_bias_kernel(
    const int total_count, const float ratio, float *__restrict__ out,
    const float *__restrict__ in, uint8_t *__restrict__ mask,
    const float *__restrict__ bias, const int seed, const int hidden_size) {
  const float scale = 1.f / (1.f - ratio);
  int i = blockIdx.x * blockDim.x + threadIdx.x;

  if (i * 4 >= total_count) return;

  curandStatePhilox4_32_10_t state;
  curand_init(seed, i, 0, &state);
  uint8_t m[4];

  float4 *out4 = reinterpret_cast<float4 *>(out);
  const float4 *data4 = reinterpret_cast<const float4 *>(in);
  const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
  uint32_t *mask4 = reinterpret_cast<uint32_t *>(mask);
  float4 rand = curand_uniform4(&state);

  m[0] = (uint8_t)(rand.x > ratio);
  m[1] = (uint8_t)(rand.y > ratio);
  m[2] = (uint8_t)(rand.z > ratio);
  m[3] = (uint8_t)(rand.w > ratio);

  int bias_i = i % (hidden_size >> 2);
  uint32_t *m4 = reinterpret_cast<uint32_t *>(m);
  mask4[i] = m4[0];
  const float4 input4 = data4[i];
  const float4 b4 = __ldg(&bias4[bias_i]);
  float4 output4;

  output4.x =
      activation_kernel<act_type, float>(input4.x + b4.x) * scale * m[0];
  output4.y =
      activation_kernel<act_type, float>(input4.y + b4.y) * scale * m[1];
  output4.z =
      activation_kernel<act_type, float>(input4.z + b4.z) * scale * m[2];
  output4.w =
      activation_kernel<act_type, float>(input4.w + b4.w) * scale * m[3];

  out4[i] = output4;
}

template <ActivationType act_type>
__global__ void ls_dropout_act_bias_kernel(
    const int total_count, const float ratio, __half *__restrict__ out,
    const __half *__restrict__ in, uint8_t *__restrict__ mask,
    const __half *__restrict__ bias, const int seed, const int hidden_size) {
  const float scale = 1.f / (1.f - ratio);

  int i = blockIdx.x * blockDim.x + threadIdx.x;

  if (i * 8 >= total_count) return;

  curandStatePhilox4_32_10_t state;
  curand_init(seed, i, 0, &state);

  const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
  float4 *outs_float4 = reinterpret_cast<float4 *>(out);
  const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
  uint64_t *mask8 = reinterpret_cast<uint64_t *>(mask);

  uint8_t m[8];
  float4 rand = curand_uniform4(&state);
  m[0] = (uint8_t)(rand.x > ratio);
  m[1] = (uint8_t)(rand.y > ratio);
  m[2] = (uint8_t)(rand.z > ratio);
  m[3] = (uint8_t)(rand.w > ratio);
  rand = curand_uniform4(&state);
  m[4] = (uint8_t)(rand.x > ratio);
  m[5] = (uint8_t)(rand.y > ratio);
  m[6] = (uint8_t)(rand.z > ratio);
  m[7] = (uint8_t)(rand.w > ratio);
  uint64_t *m8 = reinterpret_cast<uint64_t *>(m);
  mask8[i] = *m8;

  int bias_i = i % (hidden_size >> 3);
  float4 val_float4 = vals_float4[i];
  const float4 b4 = __ldg(&bias4[bias_i]);
  float4 out_float4;
  __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4);
  __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4);
  const __half2 *b_half2 = reinterpret_cast<const __half2 *>(&b4);
  __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]);
  __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]);
  __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]);
  __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]);
  out_half2[0] = __hmul2(
      activation_kernel<act_type, __half2>(__hadd2(val_half2[0], b_half2[0])),
      scale_mask_1);
  out_half2[1] = __hmul2(
      activation_kernel<act_type, __half2>(__hadd2(val_half2[1], b_half2[1])),
      scale_mask_2);
  out_half2[2] = __hmul2(
      activation_kernel<act_type, __half2>(__hadd2(val_half2[2], b_half2[2])),
      scale_mask_3);
  out_half2[3] = __hmul2(
      activation_kernel<act_type, __half2>(__hadd2(val_half2[3], b_half2[3])),
      scale_mask_4);
  outs_float4[i] = out_float4;
}
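// Note: both kernels above fuse out = dropout(activation(in + bias)). The
// __half path adds the bias and applies the scaled mask in fp16, but the GELU
// itself is evaluated in fp32 inside activation_kernel for precision.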
template <>
void launch_ls_dropout_act_bias<ActivationType::kGelu, float>(
    float *out, const float *vals, uint8_t *mask, const float *bias,
    int total_count, int dim, float ratio, cudaStream_t stream) {
  int grid_dim = total_count >> 10;
  ls_dropout_act_bias_kernel<ActivationType::kGelu>
      <<<grid_dim + 1, 256, 0, stream>>>(
          total_count, ratio, out, vals, mask, bias,
          std::chrono::duration_cast<std::chrono::microseconds>(
              std::chrono::system_clock::now().time_since_epoch())
              .count(),
          dim);
}

template <>
void launch_ls_dropout_act_bias<ActivationType::kGelu, __half>(
    __half *out, const __half *vals, uint8_t *mask, const __half *bias,
    int total_count, int dim, float ratio, cudaStream_t stream) {
  int grid_dim = total_count >> 11;
  ls_dropout_act_bias_kernel<ActivationType::kGelu>
      <<<grid_dim + 1, 256, 0, stream>>>(
          total_count, ratio, out, vals, mask, bias,
          std::chrono::duration_cast<std::chrono::microseconds>(
              std::chrono::system_clock::now().time_since_epoch())
              .count(),
          dim);
}

template <>
void launch_ls_dropout_act_bias<ActivationType::kRelu, float>(
    float *out, const float *vals, uint8_t *mask, const float *bias,
    int total_count, int dim, float ratio, cudaStream_t stream) {
  int grid_dim = total_count >> 10;
  ls_dropout_act_bias_kernel<ActivationType::kRelu>
      <<<grid_dim + 1, 256, 0, stream>>>(
          total_count, ratio, out, vals, mask, bias,
          std::chrono::duration_cast<std::chrono::microseconds>(
              std::chrono::system_clock::now().time_since_epoch())
              .count(),
          dim);
}

template <>
void launch_ls_dropout_act_bias<ActivationType::kRelu, __half>(
    __half *out, const __half *vals, uint8_t *mask, const __half *bias,
    int total_count, int dim, float ratio, cudaStream_t stream) {
  int grid_dim = total_count >> 11;
  ls_dropout_act_bias_kernel<ActivationType::kRelu>
      <<<grid_dim + 1, 256, 0, stream>>>(
          total_count, ratio, out, vals, mask, bias,
          std::chrono::duration_cast<std::chrono::microseconds>(
              std::chrono::system_clock::now().time_since_epoch())
              .count(),
          dim);
}

/**
 * @brief fused bias, activation, and dropout backward
 *
 * @thread
 * gridDim.x = hidden_size / WARP_SIZE
 * blockDim.x = WARP_SIZE
 * blockDim.y = WARP_SIZE
 *
 * @tparam act_type kRelu
 * @param row_size batch_size * seq_len
 * @param ratio dropout ratio
 * @param in_grad [batch_size, seq_len, hidden_size], input grad
 * @param bias_grad [hidden_size], bias grad
 * @param out_grad [batch_size, seq_len, hidden_size], output grad
 * @param mask [batch_size, seq_len, hidden_size], dropout mask
 * @param hidden_size
 * @return void
 */
template <ActivationType act_type, typename T>
__global__ void ls_dropout_act_bias_bwd_kernel(
    const int row_size, const float ratio, T *in_grad,
    T *__restrict__ bias_grad, const T *__restrict__ input,
    const T *__restrict__ bias, const T *out_grad,
    const uint8_t *__restrict__ mask, const int hidden_size) {
  const float scale = 1.f / (1.f - ratio);
  __shared__ float tile[WARP_SIZE][WARP_SIZE + 1];

  cg::thread_block b = cg::this_thread_block();
  cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

  int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
  int stride = hidden_size * WARP_SIZE;
  float local_sum = 0;

  int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
  if (col_idx < hidden_size) {
    for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) {
      float val = out_grad[idx];
      float in = input[idx];
      float b = bias[idx % hidden_size];
      val = activation_bwd_kernel<act_type, float>(
          val * scale * static_cast<float>(mask[idx]), in + b);
      local_sum += val;
      in_grad[idx] = val;
      idx += stride;
    }
  }

  tile[threadIdx.x][threadIdx.y] = local_sum;
  __syncthreads();
  float sum = tile[threadIdx.y][threadIdx.x];
  __syncthreads();

  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);

  if (threadIdx.x == 0) tile[0][threadIdx.y] = sum;
  __syncthreads();

  if (threadIdx.y == 0) {
    int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
    bias_grad[pos] = tile[0][threadIdx.x];
  }
}
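// Note: this backward kernel does not keep the forward activation output; it
// reloads `input` and `bias`, recomputes in + bias, and feeds that to
// activation_bwd_kernel together with the already mask-scaled output grad,
// trading a little recompute for not materializing the pre-dropout
// activations.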
// @brief fused bias, activation, and dropout backward
// It is deprecated for precision reasons. Kept for future optimization.
//
// template <ActivationType act_type>
// __global__ void ls_dropout_act_bias_bwd_kernel(
//     const int row_size, const float ratio, __half *in_grad,
//     __half *__restrict__ bias_grad, const __half *__restrict__ input,
//     const __half *__restrict__ bias, const __half *out_grad,
//     const uint8_t *__restrict__ mask, const int hidden_size) {
//   const __half2 scale = __float2half2_rn(1.f / (1.f - ratio));
//   __shared__ __half2 tile[WARP_SIZE][WARP_SIZE + 1];

//   cg::thread_block b = cg::this_thread_block();
//   cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

//   __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad);
//   __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad);
//   const __half2 *out_grad2 = reinterpret_cast<const __half2 *>(out_grad);
//   const __half2 *input2 = reinterpret_cast<const __half2 *>(input);
//   const __half2 *bias2 = reinterpret_cast<const __half2 *>(bias);

//   int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
//   int stride = hidden_size * WARP_SIZE;
//   __half2 local_sum = __float2half2_rn(0.f);

//   int idx = flat_2dim(threadIdx.y, col_idx, hidden_size);
//   if (col_idx < hidden_size) {
//     for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) {
//       __half2 val = out_grad2[idx];
//       __half2 in2 = input2[idx];
//       __half2 b2 = bias2[idx % hidden_size];
//       __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]);
//       val = activation_bwd_kernel<act_type, __half2>(val * scale * m2,
//                                                      in2 + b2);
//       local_sum += val;
//       in_grad2[idx] = val;
//       idx += stride;
//     }
//   }

//   tile[threadIdx.x][threadIdx.y] = local_sum;
//   __syncthreads();
//   __half2 sum = tile[threadIdx.y][threadIdx.x];
//   __syncthreads();

//   for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);

//   if (threadIdx.x == 0) tile[0][threadIdx.y] = sum;
//   __syncthreads();

//   if (threadIdx.y == 0) {
//     int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE);
//     bias_grad2[pos] = tile[0][threadIdx.x];
//   }
// }

template <ActivationType act_type, typename T>
void launch_ls_dropout_act_bias_bwd(T *in_grad, T *bias_grad, const T *input,
                                    const T *bias, const T *out_grad,
                                    const uint8_t *mask, int row_size, int dim,
                                    float ratio, cudaStream_t stream) {
  dim3 grid_dim((dim - 1) / WARP_SIZE + 1);
  dim3 block_dim(WARP_SIZE, WARP_SIZE);
  ls_dropout_act_bias_bwd_kernel<act_type><<<grid_dim, block_dim, 0, stream>>>(
      row_size, ratio, in_grad, bias_grad, input, bias, out_grad, mask, dim);
}

// template <>
// void launch_ls_dropout_act_bias_bwd(
//     __half *in_grad, __half *bias_grad, const __half *input,
//     const __half *bias, const __half *out_grad, const uint8_t *mask,
//     int row_size, int dim, float ratio, cudaStream_t stream) {
//   dim >>= 1;
//   dim3 grid_dim((dim - 1) / WARP_SIZE + 1);
//   dim3 block_dim(WARP_SIZE, WARP_SIZE);
//   ls_dropout_act_bias_bwd_kernel
//       <<<grid_dim, block_dim, 0, stream>>>(row_size, ratio, in_grad,
//                                            bias_grad, input, bias, out_grad,
//                                            mask, dim);
// }
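// Launch geometry for the activation/bias backward: one WARP_SIZE x WARP_SIZE
// thread block covers WARP_SIZE columns of hidden_size and strides over all
// row_size rows, so grid_dim = ceil(hidden_size / WARP_SIZE).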
template void launch_ls_dropout_act_bias_bwd<ActivationType::kRelu, float>(
    float *in_grad, float *bias_grad, const float *input, const float *bias,
    const float *out_grad, const uint8_t *mask, int row_size, int dim,
    float ratio, cudaStream_t stream);

template void launch_ls_dropout_act_bias_bwd<ActivationType::kRelu, __half>(
    __half *in_grad, __half *bias_grad, const __half *input,
    const __half *bias, const __half *out_grad, const uint8_t *mask,
    int row_size, int dim, float ratio, cudaStream_t stream);

template void launch_ls_dropout_act_bias_bwd<ActivationType::kGelu, float>(
    float *in_grad, float *bias_grad, const float *input, const float *bias,
    const float *out_grad, const uint8_t *mask, int row_size, int dim,
    float ratio, cudaStream_t stream);

template void launch_ls_dropout_act_bias_bwd<ActivationType::kGelu, __half>(
    __half *in_grad, __half *bias_grad, const __half *input,
    const __half *bias, const __half *out_grad, const uint8_t *mask,
    int row_size, int dim, float ratio, cudaStream_t stream);