fix format (#572)

pull/673/head
BoxiangW 2022-03-31 15:47:24 +08:00 committed by binmakeswell
parent cfb41297ff
commit dfe423ae42
1 changed files with 73 additions and 63 deletions

View File

@ -7,8 +7,7 @@ namespace cg = cooperative_groups;
const float LN_EPSILON = 1e-8f; const float LN_EPSILON = 1e-8f;
#define TILE_DIM 32 #define TILE_DIM 32
template <typename T> template <typename T> __forceinline__ __device__ T add_eps(T x) {
__forceinline__ __device__ T add_eps(T x) {
return fabsf(x) > LN_EPSILON ? x : (x < 0 ? -LN_EPSILON : LN_EPSILON); return fabsf(x) > LN_EPSILON ? x : (x < 0 ? -LN_EPSILON : LN_EPSILON);
} }
@ -138,13 +137,14 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
// __global__ void ker_layer_norm_x2(__half *ln_res, __half *vars, // __global__ void ker_layer_norm_x2(__half *ln_res, __half *vars,
// __half *means, const __half *inp, // __half *means, const __half *inp,
// const __half *scale, const __half *bias, // const __half *scale, const __half
// int hidden_size) { // *bias, int hidden_size) {
// // step 0. compute local sum // // step 0. compute local sum
// float l_sum = 0; // float l_sum = 0;
// float l_square_sum = 0; // float l_square_sum = 0;
// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * 2 * hidden_size; // const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * 2 * hidden_size;
// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * 2) { // for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x *
// 2) {
// float4 val_f4 = inp_f4[idx]; // float4 val_f4 = inp_f4[idx];
// float4 val_f4_1 = inp_f4[idx+1]; // float4 val_f4_1 = inp_f4[idx+1];
// __half2 *val_h2 = (__half2 *)(&val_f4); // __half2 *val_h2 = (__half2 *)(&val_f4);
@ -154,7 +154,8 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
// float2 val_f2 = __half22float2(val_h2[i]); // float2 val_f2 = __half22float2(val_h2[i]);
// float2 val_f2_1 = __half22float2(val_h2_1[i]); // float2 val_f2_1 = __half22float2(val_h2_1[i]);
// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y; // l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y;
// l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y + val_f2_1.x * val_f2_1.x + val_f2_1.y * val_f2_1.y; // l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y + val_f2_1.x
// * val_f2_1.x + val_f2_1.y * val_f2_1.y;
// } // }
// } // }
@ -176,7 +177,8 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
// // step 2. layer norm result // // step 2. layer norm result
// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 2; // float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 2;
// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * 2) { // for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x *
// 2) {
// // load scale, bias, input // // load scale, bias, input
// float4 scale_f4 = __ldg((const float4 *)scale + idx); // float4 scale_f4 = __ldg((const float4 *)scale + idx);
// __half2 *scale_h2 = (__half2 *)(&scale_f4); // __half2 *scale_h2 = (__half2 *)(&scale_f4);
@ -202,9 +204,9 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; // val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x;
// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; // val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y;
// val_h2[i] = __float22half2_rn(val_f2); // val_h2[i] = __float22half2_rn(val_f2);
// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + bias_f2_1.x; // val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x +
// val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y + bias_f2_1.y; // bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y
// val_h2_1[i] = __float22half2_rn(val_f2_1); // + bias_f2_1.y; val_h2_1[i] = __float22half2_rn(val_f2_1);
// } // }
// output_f4[idx] = val_f4; // output_f4[idx] = val_f4;
// output_f4[idx+1] = val_f4_1; // output_f4[idx+1] = val_f4_1;
@ -213,13 +215,14 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
// __global__ void ker_layer_norm_x4(__half *ln_res, __half *vars, // __global__ void ker_layer_norm_x4(__half *ln_res, __half *vars,
// __half *means, const __half *inp, // __half *means, const __half *inp,
// const __half *scale, const __half *bias, // const __half *scale, const __half
// int hidden_size) { // *bias, int hidden_size) {
// // step 0. compute local sum // // step 0. compute local sum
// float l_sum = 0; // float l_sum = 0;
// float l_square_sum = 0; // float l_square_sum = 0;
// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size * 4; // const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size * 4;
// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * 4) { // for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x *
// 4) {
// float4 val_f4 = inp_f4[idx]; // float4 val_f4 = inp_f4[idx];
// float4 val_f4_1 = inp_f4[idx+1]; // float4 val_f4_1 = inp_f4[idx+1];
// float4 val_f4_2 = inp_f4[idx+2]; // float4 val_f4_2 = inp_f4[idx+2];
@ -234,11 +237,12 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
// float2 val_f2_1 = __half22float2(val_h2_1[i]); // float2 val_f2_1 = __half22float2(val_h2_1[i]);
// float2 val_f2_2 = __half22float2(val_h2_2[i]); // float2 val_f2_2 = __half22float2(val_h2_2[i]);
// float2 val_f2_3 = __half22float2(val_h2_3[i]); // float2 val_f2_3 = __half22float2(val_h2_3[i]);
// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y + val_f2_2.x + val_f2_2.y + val_f2_3.x + val_f2_3.y; // l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y + val_f2_2.x +
// l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y; // val_f2_2.y + val_f2_3.x + val_f2_3.y; l_square_sum += val_f2.x *
// l_square_sum += val_f2_1.x * val_f2_1.x + val_f2_1.y * val_f2_1.y; // val_f2.x + val_f2.y * val_f2.y; l_square_sum += val_f2_1.x * val_f2_1.x
// l_square_sum += val_f2_2.x * val_f2_2.x + val_f2_2.y * val_f2_2.y; // + val_f2_1.y * val_f2_1.y; l_square_sum += val_f2_2.x * val_f2_2.x +
// l_square_sum += val_f2_3.x * val_f2_3.x + val_f2_3.y * val_f2_3.y; // val_f2_2.y * val_f2_2.y; l_square_sum += val_f2_3.x * val_f2_3.x +
// val_f2_3.y * val_f2_3.y;
// } // }
// } // }
@ -260,7 +264,8 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
// // step 2. layer norm result // // step 2. layer norm result
// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 4; // float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 4;
// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * 4) { // for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x *
// 4) {
// // load scale, bias, input // // load scale, bias, input
// float4 scale_f4 = __ldg((const float4 *)scale + idx); // float4 scale_f4 = __ldg((const float4 *)scale + idx);
// __half2 *scale_h2 = (__half2 *)(&scale_f4); // __half2 *scale_h2 = (__half2 *)(&scale_f4);
@ -303,14 +308,14 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
// float2 val_f2_3 = __half22float2(val_h2_3[i]); // float2 val_f2_3 = __half22float2(val_h2_3[i]);
// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; // val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x;
// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; // val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y;
// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + bias_f2_1.x; // val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x +
// val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y + bias_f2_1.y; // bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y
// val_f2_2.x = (val_f2_2.x - s_mean) * s_var * scale_f2_2.x + bias_f2_2.x; // + bias_f2_1.y; val_f2_2.x = (val_f2_2.x - s_mean) * s_var *
// val_f2_2.y = (val_f2_2.y - s_mean) * s_var * scale_f2_2.y + bias_f2_2.y; // scale_f2_2.x + bias_f2_2.x; val_f2_2.y = (val_f2_2.y - s_mean) * s_var
// val_f2_3.x = (val_f2_3.x - s_mean) * s_var * scale_f2_3.x + bias_f2_3.x; // * scale_f2_2.y + bias_f2_2.y; val_f2_3.x = (val_f2_3.x - s_mean) *
// val_f2_3.y = (val_f2_3.y - s_mean) * s_var * scale_f2_3.y + bias_f2_3.y; // s_var * scale_f2_3.x + bias_f2_3.x; val_f2_3.y = (val_f2_3.y - s_mean)
// val_h2[i] = __float22half2_rn(val_f2); // * s_var * scale_f2_3.y + bias_f2_3.y; val_h2[i] =
// val_h2_1[i] = __float22half2_rn(val_f2_1); // __float22half2_rn(val_f2); val_h2_1[i] = __float22half2_rn(val_f2_1);
// val_h2_2[i] = __float22half2_rn(val_f2_2); // val_h2_2[i] = __float22half2_rn(val_f2_2);
// val_h2_3[i] = __float22half2_rn(val_f2_3); // val_h2_3[i] = __float22half2_rn(val_f2_3);
// } // }
@ -414,11 +419,10 @@ means: [batch_size * seq_len], mean of ln forward,
(gamma && betta) ^ (vars && means) should be true (gamma && betta) ^ (vars && means) should be true
*/ */
template <typename T> template <typename T>
__global__ void ker_ln_bw_dgamma_dbetta(T *gamma_grad, T *betta_grad, __global__ void
const T *out_grad, const T *inp_or_out, ker_ln_bw_dgamma_dbetta(T *gamma_grad, T *betta_grad, const T *out_grad,
const T *gamma, const T *betta, const T *inp_or_out, const T *gamma, const T *betta,
const T *vars, const T *means, int rows, const T *vars, const T *means, int rows, int width) {
int width) {
__shared__ float betta_buffer[TILE_DIM][TILE_DIM]; __shared__ float betta_buffer[TILE_DIM][TILE_DIM];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM]; __shared__ float gamma_buffer[TILE_DIM][TILE_DIM];
@ -699,10 +703,9 @@ __global__ void ker_ln_bw_dinp<__half>(__half *inp_grad, const __half *out_grad,
__global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad, __global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad,
const __half *residual_grad, const __half *residual_grad,
const __half *inp_or_out, const __half *inp_or_out, const __half *gamma,
const __half *gamma, const __half *betta, const __half *betta, const __half *vars,
const __half *vars, const __half *means, const __half *means, int hidden_dim) {
int hidden_dim) {
int offset = blockIdx.x * hidden_dim * 2 + threadIdx.x * 2; int offset = blockIdx.x * hidden_dim * 2 + threadIdx.x * 2;
float2 dxhat[4], xhat[4]; float2 dxhat[4], xhat[4];
@ -762,7 +765,8 @@ __global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad,
xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y);
xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y); xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y);
reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y;
reduce_val[1] += xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; reduce_val[1] +=
xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y;
} }
} else { } else {
// inp_or_out is input, xhat = (input - mean) * rsqrtf(var) // inp_or_out is input, xhat = (input - mean) * rsqrtf(var)
@ -776,7 +780,8 @@ __global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad,
xhat[i].y = (vinp.y - fmean) * var_rsqrt; xhat[i].y = (vinp.y - fmean) * var_rsqrt;
xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt; xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt;
reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y;
reduce_val[1] += xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; reduce_val[1] +=
xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y;
} }
} }
} }
@ -802,7 +807,7 @@ __global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad,
// Add the residual grad, // Add the residual grad,
// usually in pre-layer-norm for transformer layer // usually in pre-layer-norm for transformer layer
float4 dresidual = ((const float4 *)residual_grad)[offset]; float4 dresidual = ((const float4 *)residual_grad)[offset];
float4 dresidual_1 = ((const float4 *)residual_grad)[offset+1]; float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1];
__half *hdres = reinterpret_cast<__half *>(&dresidual); __half *hdres = reinterpret_cast<__half *>(&dresidual);
__half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1); __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1);
#pragma unroll #pragma unroll
@ -847,10 +852,9 @@ __global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad,
__global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad, __global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad,
const __half *residual_grad, const __half *residual_grad,
const __half *inp_or_out, const __half *inp_or_out, const __half *gamma,
const __half *gamma, const __half *betta, const __half *betta, const __half *vars,
const __half *vars, const __half *means, const __half *means, int hidden_dim) {
int hidden_dim) {
int offset = blockIdx.x * hidden_dim * 4 + threadIdx.x * 4; int offset = blockIdx.x * hidden_dim * 4 + threadIdx.x * 4;
float2 dxhat[4], xhat[4]; float2 dxhat[4], xhat[4];
@ -901,8 +905,9 @@ __global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad,
dxhat_2[i].y = vdout_2.y * vgamma_2.y; dxhat_2[i].y = vdout_2.y * vgamma_2.y;
dxhat_3[i].x = vdout_3.x * vgamma_3.x; dxhat_3[i].x = vdout_3.x * vgamma_3.x;
dxhat_3[i].y = vdout_3.y * vgamma_3.y; dxhat_3[i].y = vdout_3.y * vgamma_3.y;
reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y + dxhat_2[i].x + reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y +
dxhat_2[i].y + dxhat_3[i].x + dxhat_3[i].y; dxhat_2[i].x + dxhat_2[i].y + dxhat_3[i].x +
dxhat_3[i].y;
} }
/* /*
@ -947,9 +952,12 @@ __global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad,
xhat_2[i].y = (vout_2.y - vbetta_2.y) / add_eps(vgamma_2.y); xhat_2[i].y = (vout_2.y - vbetta_2.y) / add_eps(vgamma_2.y);
xhat_3[i].y = (vout_3.y - vbetta_3.y) / add_eps(vgamma_3.y); xhat_3[i].y = (vout_3.y - vbetta_3.y) / add_eps(vgamma_3.y);
reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y;
reduce_val[1] += xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; reduce_val[1] +=
reduce_val[1] += xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y;
reduce_val[1] += xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; reduce_val[1] +=
xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y;
reduce_val[1] +=
xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y;
} }
} else { } else {
// inp_or_out is input, xhat = (input - mean) * rsqrtf(var) // inp_or_out is input, xhat = (input - mean) * rsqrtf(var)
@ -969,9 +977,12 @@ __global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad,
xhat_2[i].y = (vinp_2.y - fmean) * var_rsqrt; xhat_2[i].y = (vinp_2.y - fmean) * var_rsqrt;
xhat_3[i].y = (vinp_3.y - fmean) * var_rsqrt; xhat_3[i].y = (vinp_3.y - fmean) * var_rsqrt;
reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y;
reduce_val[1] += xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; reduce_val[1] +=
reduce_val[1] += xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y;
reduce_val[1] += xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; reduce_val[1] +=
xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y;
reduce_val[1] +=
xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y;
} }
} }
} }
@ -997,9 +1008,9 @@ __global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad,
// Add the residual grad, // Add the residual grad,
// usually in pre-layer-norm for transformer layer // usually in pre-layer-norm for transformer layer
float4 dresidual = ((const float4 *)residual_grad)[offset]; float4 dresidual = ((const float4 *)residual_grad)[offset];
float4 dresidual_1 = ((const float4 *)residual_grad)[offset+1]; float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1];
float4 dresidual_2 = ((const float4 *)residual_grad)[offset+2]; float4 dresidual_2 = ((const float4 *)residual_grad)[offset + 2];
float4 dresidual_3 = ((const float4 *)residual_grad)[offset+3]; float4 dresidual_3 = ((const float4 *)residual_grad)[offset + 3];
__half *hdres = reinterpret_cast<__half *>(&dresidual); __half *hdres = reinterpret_cast<__half *>(&dresidual);
__half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1); __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1);
__half *hdres_2 = reinterpret_cast<__half *>(&dresidual_2); __half *hdres_2 = reinterpret_cast<__half *>(&dresidual_2);
@ -1139,22 +1150,21 @@ void launch_ln_bw<__half>(__half *gamma_grad, __half *betta_grad,
if (hidden_dim * 8 <= 8192) { if (hidden_dim * 8 <= 8192) {
int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
ker_ln_bw_dinp<<<batch, nthread, 0, stream[1]>>>( ker_ln_bw_dinp<<<batch, nthread, 0, stream[1]>>>(
inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means, inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars,
hidden_dim); means, hidden_dim);
} else if (hidden_dim * 8 > 8192 && hidden_dim * 8 <= 8192 * 2) { } else if (hidden_dim * 8 > 8192 && hidden_dim * 8 <= 8192 * 2) {
hidden_dim >>= 1; hidden_dim >>= 1;
int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
ker_ln_bw_dinp_x2<<<batch, nthread, 0, stream[1]>>>( ker_ln_bw_dinp_x2<<<batch, nthread, 0, stream[1]>>>(
inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means, inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars,
hidden_dim); means, hidden_dim);
} else if (hidden_dim * 8 > 2 * 8192 && hidden_dim * 8 <= 8192 * 4) { } else if (hidden_dim * 8 > 2 * 8192 && hidden_dim * 8 <= 8192 * 4) {
hidden_dim >>= 2; hidden_dim >>= 2;
int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
ker_ln_bw_dinp_x4<<<batch, nthread, 0, stream[1]>>>( ker_ln_bw_dinp_x4<<<batch, nthread, 0, stream[1]>>>(
inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means, inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars,
hidden_dim); means, hidden_dim);
} else { } else {
throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768"); throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768");
} }
} }