fix format (#572)

pull/673/head
BoxiangW 2022-03-31 15:47:24 +08:00 committed by binmakeswell
parent cfb41297ff
commit dfe423ae42
1 changed files with 73 additions and 63 deletions

View File

@ -7,8 +7,7 @@ namespace cg = cooperative_groups;
const float LN_EPSILON = 1e-8f;
#define TILE_DIM 32
template <typename T>
__forceinline__ __device__ T add_eps(T x) {
template <typename T> __forceinline__ __device__ T add_eps(T x) {
return fabsf(x) > LN_EPSILON ? x : (x < 0 ? -LN_EPSILON : LN_EPSILON);
}
@ -138,13 +137,14 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
// __global__ void ker_layer_norm_x2(__half *ln_res, __half *vars,
// __half *means, const __half *inp,
// const __half *scale, const __half *bias,
// int hidden_size) {
// const __half *scale, const __half
// *bias, int hidden_size) {
// // step 0. compute local sum
// float l_sum = 0;
// float l_square_sum = 0;
// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * 2 * hidden_size;
// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * 2) {
// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x *
// 2) {
// float4 val_f4 = inp_f4[idx];
// float4 val_f4_1 = inp_f4[idx+1];
// __half2 *val_h2 = (__half2 *)(&val_f4);
@ -154,7 +154,8 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
// float2 val_f2 = __half22float2(val_h2[i]);
// float2 val_f2_1 = __half22float2(val_h2_1[i]);
// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y;
// l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y + val_f2_1.x * val_f2_1.x + val_f2_1.y * val_f2_1.y;
// l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y + val_f2_1.x
// * val_f2_1.x + val_f2_1.y * val_f2_1.y;
// }
// }
@ -176,7 +177,8 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
// // step 2. layer norm result
// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 2;
// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * 2) {
// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x *
// 2) {
// // load scale, bias, input
// float4 scale_f4 = __ldg((const float4 *)scale + idx);
// __half2 *scale_h2 = (__half2 *)(&scale_f4);
@ -202,9 +204,9 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x;
// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y;
// val_h2[i] = __float22half2_rn(val_f2);
// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + bias_f2_1.x;
// val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y + bias_f2_1.y;
// val_h2_1[i] = __float22half2_rn(val_f2_1);
// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x +
// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y
// + bias_f2_1.y; val_h2_1[i] = __float22half2_rn(val_f2_1);
// }
// output_f4[idx] = val_f4;
// output_f4[idx+1] = val_f4_1;
@ -213,13 +215,14 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
// __global__ void ker_layer_norm_x4(__half *ln_res, __half *vars,
// __half *means, const __half *inp,
// const __half *scale, const __half *bias,
// int hidden_size) {
// const __half *scale, const __half
// *bias, int hidden_size) {
// // step 0. compute local sum
// float l_sum = 0;
// float l_square_sum = 0;
// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size * 4;
// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * 4) {
// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x *
// 4) {
// float4 val_f4 = inp_f4[idx];
// float4 val_f4_1 = inp_f4[idx+1];
// float4 val_f4_2 = inp_f4[idx+2];
@ -234,11 +237,12 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
// float2 val_f2_1 = __half22float2(val_h2_1[i]);
// float2 val_f2_2 = __half22float2(val_h2_2[i]);
// float2 val_f2_3 = __half22float2(val_h2_3[i]);
// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y + val_f2_2.x + val_f2_2.y + val_f2_3.x + val_f2_3.y;
// l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y;
// l_square_sum += val_f2_1.x * val_f2_1.x + val_f2_1.y * val_f2_1.y;
// l_square_sum += val_f2_2.x * val_f2_2.x + val_f2_2.y * val_f2_2.y;
// l_square_sum += val_f2_3.x * val_f2_3.x + val_f2_3.y * val_f2_3.y;
// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y + val_f2_2.x +
// val_f2_2.y + val_f2_3.x + val_f2_3.y; l_square_sum += val_f2.x *
// val_f2.x + val_f2.y * val_f2.y; l_square_sum += val_f2_1.x * val_f2_1.x
// + val_f2_1.y * val_f2_1.y; l_square_sum += val_f2_2.x * val_f2_2.x +
// val_f2_2.y * val_f2_2.y; l_square_sum += val_f2_3.x * val_f2_3.x +
// val_f2_3.y * val_f2_3.y;
// }
// }
@ -260,7 +264,8 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
// // step 2. layer norm result
// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 4;
// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * 4) {
// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x *
// 4) {
// // load scale, bias, input
// float4 scale_f4 = __ldg((const float4 *)scale + idx);
// __half2 *scale_h2 = (__half2 *)(&scale_f4);
@ -303,14 +308,14 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
// float2 val_f2_3 = __half22float2(val_h2_3[i]);
// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x;
// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y;
// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + bias_f2_1.x;
// val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y + bias_f2_1.y;
// val_f2_2.x = (val_f2_2.x - s_mean) * s_var * scale_f2_2.x + bias_f2_2.x;
// val_f2_2.y = (val_f2_2.y - s_mean) * s_var * scale_f2_2.y + bias_f2_2.y;
// val_f2_3.x = (val_f2_3.x - s_mean) * s_var * scale_f2_3.x + bias_f2_3.x;
// val_f2_3.y = (val_f2_3.y - s_mean) * s_var * scale_f2_3.y + bias_f2_3.y;
// val_h2[i] = __float22half2_rn(val_f2);
// val_h2_1[i] = __float22half2_rn(val_f2_1);
// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x +
// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y
// + bias_f2_1.y; val_f2_2.x = (val_f2_2.x - s_mean) * s_var *
// scale_f2_2.x + bias_f2_2.x; val_f2_2.y = (val_f2_2.y - s_mean) * s_var
// * scale_f2_2.y + bias_f2_2.y; val_f2_3.x = (val_f2_3.x - s_mean) *
// s_var * scale_f2_3.x + bias_f2_3.x; val_f2_3.y = (val_f2_3.y - s_mean)
// * s_var * scale_f2_3.y + bias_f2_3.y; val_h2[i] =
// __float22half2_rn(val_f2); val_h2_1[i] = __float22half2_rn(val_f2_1);
// val_h2_2[i] = __float22half2_rn(val_f2_2);
// val_h2_3[i] = __float22half2_rn(val_f2_3);
// }
@ -414,11 +419,10 @@ means: [batch_size * seq_len], mean of ln forward,
(gamma && betta) ^ (vars && means) should be true
*/
template <typename T>
__global__ void ker_ln_bw_dgamma_dbetta(T *gamma_grad, T *betta_grad,
const T *out_grad, const T *inp_or_out,
const T *gamma, const T *betta,
const T *vars, const T *means, int rows,
int width) {
__global__ void
ker_ln_bw_dgamma_dbetta(T *gamma_grad, T *betta_grad, const T *out_grad,
const T *inp_or_out, const T *gamma, const T *betta,
const T *vars, const T *means, int rows, int width) {
__shared__ float betta_buffer[TILE_DIM][TILE_DIM];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM];
@ -698,11 +702,10 @@ __global__ void ker_ln_bw_dinp<__half>(__half *inp_grad, const __half *out_grad,
}
__global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad,
const __half *residual_grad,
const __half *inp_or_out,
const __half *gamma, const __half *betta,
const __half *vars, const __half *means,
int hidden_dim) {
const __half *residual_grad,
const __half *inp_or_out, const __half *gamma,
const __half *betta, const __half *vars,
const __half *means, int hidden_dim) {
int offset = blockIdx.x * hidden_dim * 2 + threadIdx.x * 2;
float2 dxhat[4], xhat[4];
@ -762,7 +765,8 @@ __global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad,
xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y);
xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y);
reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y;
reduce_val[1] += xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y;
reduce_val[1] +=
xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y;
}
} else {
// inp_or_out is input, xhat = (input - mean) * rsqrtf(var)
@ -776,7 +780,8 @@ __global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad,
xhat[i].y = (vinp.y - fmean) * var_rsqrt;
xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt;
reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y;
reduce_val[1] += xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y;
reduce_val[1] +=
xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y;
}
}
}
@ -802,7 +807,7 @@ __global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad,
// Add the residual grad,
// usually in pre-layer-norm for transformer layer
float4 dresidual = ((const float4 *)residual_grad)[offset];
float4 dresidual_1 = ((const float4 *)residual_grad)[offset+1];
float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1];
__half *hdres = reinterpret_cast<__half *>(&dresidual);
__half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1);
#pragma unroll
@ -846,11 +851,10 @@ __global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad,
}
__global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad,
const __half *residual_grad,
const __half *inp_or_out,
const __half *gamma, const __half *betta,
const __half *vars, const __half *means,
int hidden_dim) {
const __half *residual_grad,
const __half *inp_or_out, const __half *gamma,
const __half *betta, const __half *vars,
const __half *means, int hidden_dim) {
int offset = blockIdx.x * hidden_dim * 4 + threadIdx.x * 4;
float2 dxhat[4], xhat[4];
@ -901,8 +905,9 @@ __global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad,
dxhat_2[i].y = vdout_2.y * vgamma_2.y;
dxhat_3[i].x = vdout_3.x * vgamma_3.x;
dxhat_3[i].y = vdout_3.y * vgamma_3.y;
reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y + dxhat_2[i].x +
dxhat_2[i].y + dxhat_3[i].x + dxhat_3[i].y;
reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y +
dxhat_2[i].x + dxhat_2[i].y + dxhat_3[i].x +
dxhat_3[i].y;
}
/*
@ -947,9 +952,12 @@ __global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad,
xhat_2[i].y = (vout_2.y - vbetta_2.y) / add_eps(vgamma_2.y);
xhat_3[i].y = (vout_3.y - vbetta_3.y) / add_eps(vgamma_3.y);
reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y;
reduce_val[1] += xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y;
reduce_val[1] += xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y;
reduce_val[1] += xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y;
reduce_val[1] +=
xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y;
reduce_val[1] +=
xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y;
reduce_val[1] +=
xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y;
}
} else {
// inp_or_out is input, xhat = (input - mean) * rsqrtf(var)
@ -969,9 +977,12 @@ __global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad,
xhat_2[i].y = (vinp_2.y - fmean) * var_rsqrt;
xhat_3[i].y = (vinp_3.y - fmean) * var_rsqrt;
reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y;
reduce_val[1] += xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y;
reduce_val[1] += xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y;
reduce_val[1] += xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y;
reduce_val[1] +=
xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y;
reduce_val[1] +=
xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y;
reduce_val[1] +=
xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y;
}
}
}
@ -997,9 +1008,9 @@ __global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad,
// Add the residual grad,
// usually in pre-layer-norm for transformer layer
float4 dresidual = ((const float4 *)residual_grad)[offset];
float4 dresidual_1 = ((const float4 *)residual_grad)[offset+1];
float4 dresidual_2 = ((const float4 *)residual_grad)[offset+2];
float4 dresidual_3 = ((const float4 *)residual_grad)[offset+3];
float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1];
float4 dresidual_2 = ((const float4 *)residual_grad)[offset + 2];
float4 dresidual_3 = ((const float4 *)residual_grad)[offset + 3];
__half *hdres = reinterpret_cast<__half *>(&dresidual);
__half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1);
__half *hdres_2 = reinterpret_cast<__half *>(&dresidual_2);
@ -1139,22 +1150,21 @@ void launch_ln_bw<__half>(__half *gamma_grad, __half *betta_grad,
if (hidden_dim * 8 <= 8192) {
int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
ker_ln_bw_dinp<<<batch, nthread, 0, stream[1]>>>(
inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means,
hidden_dim);
inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars,
means, hidden_dim);
} else if (hidden_dim * 8 > 8192 && hidden_dim * 8 <= 8192 * 2) {
hidden_dim >>= 1;
int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
ker_ln_bw_dinp_x2<<<batch, nthread, 0, stream[1]>>>(
inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means,
hidden_dim);
inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars,
means, hidden_dim);
} else if (hidden_dim * 8 > 2 * 8192 && hidden_dim * 8 <= 8192 * 4) {
hidden_dim >>= 2;
int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
ker_ln_bw_dinp_x4<<<batch, nthread, 0, stream[1]>>>(
inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means,
hidden_dim);
inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars,
means, hidden_dim);
} else {
throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768");
}
}