mirror of https://github.com/hpcaitech/ColossalAI
fix format (#572)
parent cfb41297ff
commit dfe423ae42
@@ -7,8 +7,7 @@ namespace cg = cooperative_groups;
 const float LN_EPSILON = 1e-8f;
 #define TILE_DIM 32
 
-template <typename T>
-__forceinline__ __device__ T add_eps(T x) {
+template <typename T> __forceinline__ __device__ T add_eps(T x) {
   return fabsf(x) > LN_EPSILON ? x : (x < 0 ? -LN_EPSILON : LN_EPSILON);
 }
 
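The hunk above merely joins the `add_eps` declaration onto one line. For readers skimming the diff: `add_eps` clamps a value's magnitude to at least LN_EPSILON while keeping its sign, so the divisions by gamma in the backward kernels further down can never hit zero. A host-side sketch of the same logic (my illustration, not part of the patch):

#include <cmath>
#include <cstdio>

const float LN_EPSILON = 1e-8f;

// Same clamp as the device-side add_eps: preserve sign, floor the magnitude.
float add_eps_host(float x) {
  return std::fabs(x) > LN_EPSILON ? x : (x < 0 ? -LN_EPSILON : LN_EPSILON);
}

int main() {
  // prints: 0.5 1e-08 -1e-08
  std::printf("%g %g %g\n", add_eps_host(0.5f), add_eps_host(0.0f),
              add_eps_host(-1e-12f));
  return 0;
}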
@@ -138,13 +137,14 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
 
 // __global__ void ker_layer_norm_x2(__half *ln_res, __half *vars,
 //                                   __half *means, const __half *inp,
-//                                   const __half *scale, const __half *bias,
-//                                   int hidden_size) {
+//                                   const __half *scale, const __half
+//                                   *bias, int hidden_size) {
 //   // step 0. compute local sum
 //   float l_sum = 0;
 //   float l_square_sum = 0;
 //   const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * 2 * hidden_size;
-//   for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * 2) {
+//   for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x *
+//   2) {
 //     float4 val_f4 = inp_f4[idx];
 //     float4 val_f4_1 = inp_f4[idx+1];
 //     __half2 *val_h2 = (__half2 *)(&val_f4);
@@ -154,7 +154,8 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
 //       float2 val_f2 = __half22float2(val_h2[i]);
 //       float2 val_f2_1 = __half22float2(val_h2_1[i]);
 //       l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y;
-//       l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y + val_f2_1.x * val_f2_1.x + val_f2_1.y * val_f2_1.y;
+//       l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y + val_f2_1.x
+//       * val_f2_1.x + val_f2_1.y * val_f2_1.y;
 //     }
 //   }
 
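Context for the commented-out kernel: each thread accumulates a running sum and a running sum of squares over its strip of the row, so a single block reduction delivers both statistics at once via Var(x) = E[x²] − E[x]². A minimal sketch of the finishing step (hypothetical helper, not from the file):

// Turn the block-reduced partial sums into layer-norm statistics:
//   mean = sum / n,  var = square_sum / n - mean * mean
__device__ void finalize_ln_stats(float sum, float square_sum, int n,
                                  float *mean, float *var) {
  float m = sum / n;
  *mean = m;
  *var = square_sum / n - m * m;  // E[x^2] - E[x]^2
}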
@@ -176,7 +177,8 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
 
 //   // step 2. layer norm result
 //   float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 2;
-//   for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * 2) {
+//   for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x *
+//   2) {
 //     // load scale, bias, input
 //     float4 scale_f4 = __ldg((const float4 *)scale + idx);
 //     __half2 *scale_h2 = (__half2 *)(&scale_f4);
@@ -202,9 +204,9 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
 //       val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x;
 //       val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y;
 //       val_h2[i] = __float22half2_rn(val_f2);
-//       val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + bias_f2_1.x;
-//       val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y + bias_f2_1.y;
-//       val_h2_1[i] = __float22half2_rn(val_f2_1);
+//       val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x +
+//       bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y
+//       + bias_f2_1.y; val_h2_1[i] = __float22half2_rn(val_f2_1);
 //     }
 //     output_f4[idx] = val_f4;
 //     output_f4[idx+1] = val_f4_1;
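Step 2 applies the affine transform y = (x − mean) · s_var · scale + bias per element. Because `s_var` multiplies instead of divides, it evidently holds the reciprocal standard deviation (an rsqrtf of the variance) rather than the variance itself. As a scalar sketch with illustrative names:

// One element of the layer-norm output; s_var = rsqrtf(var + eps), so the
// usual division by the standard deviation becomes a multiply.
__device__ float ln_transform(float x, float s_mean, float s_var, float gamma,
                              float beta) {
  return (x - s_mean) * s_var * gamma + beta;
}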
@@ -213,13 +215,14 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
 
 // __global__ void ker_layer_norm_x4(__half *ln_res, __half *vars,
 //                                   __half *means, const __half *inp,
-//                                   const __half *scale, const __half *bias,
-//                                   int hidden_size) {
+//                                   const __half *scale, const __half
+//                                   *bias, int hidden_size) {
 //   // step 0. compute local sum
 //   float l_sum = 0;
 //   float l_square_sum = 0;
 //   const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size * 4;
-//   for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * 4) {
+//   for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x *
+//   4) {
 //     float4 val_f4 = inp_f4[idx];
 //     float4 val_f4_1 = inp_f4[idx+1];
 //     float4 val_f4_2 = inp_f4[idx+2];
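A note on the access pattern shared by all of these `__half` kernels: one `float4` load moves 16 bytes, i.e. eight halfs, which are then unpacked as four `__half2` pairs; the x2/x4 variants simply hand each thread two or four such loads per iteration, which is why `idx` and the strides scale by 2 and 4. A self-contained sketch of one load's contribution to the local sum (my reconstruction):

#include <cuda_fp16.h>

// One 16-byte float4 load carries 8 __half values, viewed as 4 __half2.
__device__ float local_sum_of_float4(const float4 *inp_f4, uint idx) {
  float4 val_f4 = inp_f4[idx];  // single vectorized load
  const __half2 *val_h2 = (const __half2 *)(&val_f4);
  float l_sum = 0.f;
#pragma unroll
  for (int i = 0; i < 4; i++) {
    float2 v = __half22float2(val_h2[i]);  // unpack one pair to float
    l_sum += v.x + v.y;
  }
  return l_sum;
}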
@@ -234,11 +237,12 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
 //       float2 val_f2_1 = __half22float2(val_h2_1[i]);
 //       float2 val_f2_2 = __half22float2(val_h2_2[i]);
 //       float2 val_f2_3 = __half22float2(val_h2_3[i]);
-//       l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y + val_f2_2.x + val_f2_2.y + val_f2_3.x + val_f2_3.y;
-//       l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y;
-//       l_square_sum += val_f2_1.x * val_f2_1.x + val_f2_1.y * val_f2_1.y;
-//       l_square_sum += val_f2_2.x * val_f2_2.x + val_f2_2.y * val_f2_2.y;
-//       l_square_sum += val_f2_3.x * val_f2_3.x + val_f2_3.y * val_f2_3.y;
+//       l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y + val_f2_2.x +
+//       val_f2_2.y + val_f2_3.x + val_f2_3.y; l_square_sum += val_f2.x *
+//       val_f2.x + val_f2.y * val_f2.y; l_square_sum += val_f2_1.x * val_f2_1.x
+//       + val_f2_1.y * val_f2_1.y; l_square_sum += val_f2_2.x * val_f2_2.x +
+//       val_f2_2.y * val_f2_2.y; l_square_sum += val_f2_3.x * val_f2_3.x +
+//       val_f2_3.y * val_f2_3.y;
 //     }
 //   }
 
@@ -260,7 +264,8 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
 
 //   // step 2. layer norm result
 //   float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 4;
-//   for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * 4) {
+//   for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x *
+//   4) {
 //     // load scale, bias, input
 //     float4 scale_f4 = __ldg((const float4 *)scale + idx);
 //     __half2 *scale_h2 = (__half2 *)(&scale_f4);
@@ -303,14 +308,14 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
 //       float2 val_f2_3 = __half22float2(val_h2_3[i]);
 //       val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x;
 //       val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y;
-//       val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + bias_f2_1.x;
-//       val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y + bias_f2_1.y;
-//       val_f2_2.x = (val_f2_2.x - s_mean) * s_var * scale_f2_2.x + bias_f2_2.x;
-//       val_f2_2.y = (val_f2_2.y - s_mean) * s_var * scale_f2_2.y + bias_f2_2.y;
-//       val_f2_3.x = (val_f2_3.x - s_mean) * s_var * scale_f2_3.x + bias_f2_3.x;
-//       val_f2_3.y = (val_f2_3.y - s_mean) * s_var * scale_f2_3.y + bias_f2_3.y;
-//       val_h2[i] = __float22half2_rn(val_f2);
-//       val_h2_1[i] = __float22half2_rn(val_f2_1);
+//       val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x +
+//       bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y
+//       + bias_f2_1.y; val_f2_2.x = (val_f2_2.x - s_mean) * s_var *
+//       scale_f2_2.x + bias_f2_2.x; val_f2_2.y = (val_f2_2.y - s_mean) * s_var
+//       * scale_f2_2.y + bias_f2_2.y; val_f2_3.x = (val_f2_3.x - s_mean) *
+//       s_var * scale_f2_3.x + bias_f2_3.x; val_f2_3.y = (val_f2_3.y - s_mean)
+//       * s_var * scale_f2_3.y + bias_f2_3.y; val_h2[i] =
+//       __float22half2_rn(val_f2); val_h2_1[i] = __float22half2_rn(val_f2_1);
 //       val_h2_2[i] = __float22half2_rn(val_f2_2);
 //       val_h2_3[i] = __float22half2_rn(val_f2_3);
 //     }
@@ -414,11 +419,10 @@ means: [batch_size * seq_len], mean of ln forward,
 (gamma && betta) ^ (vars && means) should be true
 */
 template <typename T>
-__global__ void ker_ln_bw_dgamma_dbetta(T *gamma_grad, T *betta_grad,
-                                        const T *out_grad, const T *inp_or_out,
-                                        const T *gamma, const T *betta,
-                                        const T *vars, const T *means, int rows,
-                                        int width) {
+__global__ void
+ker_ln_bw_dgamma_dbetta(T *gamma_grad, T *betta_grad, const T *out_grad,
+                        const T *inp_or_out, const T *gamma, const T *betta,
+                        const T *vars, const T *means, int rows, int width) {
   __shared__ float betta_buffer[TILE_DIM][TILE_DIM];
   __shared__ float gamma_buffer[TILE_DIM][TILE_DIM];
 
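The reformatted signature belongs to the gamma/beta backward kernel: it reduces over all `rows` (batch × sequence positions) for every one of the `width` columns, and the TILE_DIM × TILE_DIM shared buffers let a 32×32 thread tile accumulate partial column sums, then transpose through shared memory so a warp can finish each column's reduction with coalesced accesses. The quantities reduced are dbetta[c] = Σ_r dout[r][c] and dgamma[c] = Σ_r dout[r][c] · xhat[r][c], with xhat recovered from either the output or the input per the (gamma && betta) ^ (vars && means) contract above. A naive single-threaded reference (hypothetical helper; xhat is precomputed here for brevity, whereas the kernel derives it on the fly):

void ln_bw_dgamma_dbetta_ref(float *gamma_grad, float *betta_grad,
                             const float *out_grad, const float *xhat,
                             int rows, int width) {
  for (int c = 0; c < width; c++) {
    float dg = 0.f, db = 0.f;
    for (int r = 0; r < rows; r++) {
      float dout = out_grad[r * width + c];
      db += dout;                        // dbetta: plain column sum
      dg += dout * xhat[r * width + c];  // dgamma: weighted by xhat
    }
    gamma_grad[c] = dg;
    betta_grad[c] = db;
  }
}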
@@ -698,11 +702,10 @@ __global__ void ker_ln_bw_dinp<__half>(__half *inp_grad, const __half *out_grad,
 }
 
 __global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad,
-                                  const __half *residual_grad,
-                                  const __half *inp_or_out,
-                                  const __half *gamma, const __half *betta,
-                                  const __half *vars, const __half *means,
-                                  int hidden_dim) {
+                                  const __half *residual_grad,
+                                  const __half *inp_or_out, const __half *gamma,
+                                  const __half *betta, const __half *vars,
+                                  const __half *means, int hidden_dim) {
   int offset = blockIdx.x * hidden_dim * 2 + threadIdx.x * 2;
 
   float2 dxhat[4], xhat[4];
@@ -762,7 +765,8 @@ __global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad,
         xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y);
         xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y);
         reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y;
-        reduce_val[1] += xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y;
+        reduce_val[1] +=
+            xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y;
       }
     } else {
       // inp_or_out is input, xhat = (input - mean) * rsqrtf(var)
@@ -776,7 +780,8 @@ __global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad,
         xhat[i].y = (vinp.y - fmean) * var_rsqrt;
         xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt;
         reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y;
-        reduce_val[1] += xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y;
+        reduce_val[1] +=
+            xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y;
       }
     }
   }
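Both branches of `ker_ln_bw_dinp_x2` converge on the same two accumulators: reduce_val[0] = Σ_j dxhat_j (built earlier from dout · gamma) and reduce_val[1] = Σ_j dxhat_j · xhat_j. With m normalized elements and var_rsqrt the reciprocal standard deviation, the standard layer-norm input gradient they feed is dinp_j = var_rsqrt · (dxhat_j − (Σ dxhat + xhat_j · Σ dxhat·xhat) / m). A scalar sketch of that finishing formula (my reconstruction, not a line from the patch):

// Layer-norm input gradient for one element, given the two block sums.
__device__ float ln_dinp(float dxhat_j, float xhat_j, float sum_dxhat,
                         float sum_dxhat_xhat, float var_rsqrt, int m) {
  return var_rsqrt * (dxhat_j - (sum_dxhat + xhat_j * sum_dxhat_xhat) / m);
}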
@@ -802,7 +807,7 @@ __global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad,
     // Add the residual grad,
     // usually in pre-layer-norm for transformer layer
     float4 dresidual = ((const float4 *)residual_grad)[offset];
-    float4 dresidual_1 = ((const float4 *)residual_grad)[offset+1];
+    float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1];
     __half *hdres = reinterpret_cast<__half *>(&dresidual);
     __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1);
 #pragma unroll
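The `offset+1` → `offset + 1` fix above sits in the fused residual path: in pre-layer-norm transformer blocks the skip connection passes its gradient through unchanged, so when `residual_grad` is non-null the kernel adds it lane by lane into the layer-norm input gradient before the store. Schematically (illustrative helper only):

// inp_grad = dinp_from_ln + residual_grad, one __half lane at a time.
__device__ __half fuse_residual(__half dinp_ln, __half dres) {
  return __hadd(dinp_ln, dres);
}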
@@ -846,11 +851,10 @@ __global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad,
 }
 
 __global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad,
-                                  const __half *residual_grad,
-                                  const __half *inp_or_out,
-                                  const __half *gamma, const __half *betta,
-                                  const __half *vars, const __half *means,
-                                  int hidden_dim) {
+                                  const __half *residual_grad,
+                                  const __half *inp_or_out, const __half *gamma,
+                                  const __half *betta, const __half *vars,
+                                  const __half *means, int hidden_dim) {
   int offset = blockIdx.x * hidden_dim * 4 + threadIdx.x * 4;
 
   float2 dxhat[4], xhat[4];
@@ -901,8 +905,9 @@ __global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad,
         dxhat_2[i].y = vdout_2.y * vgamma_2.y;
         dxhat_3[i].x = vdout_3.x * vgamma_3.x;
         dxhat_3[i].y = vdout_3.y * vgamma_3.y;
-        reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y + dxhat_2[i].x +
-                         dxhat_2[i].y + dxhat_3[i].x + dxhat_3[i].y;
+        reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y +
+                         dxhat_2[i].x + dxhat_2[i].y + dxhat_3[i].x +
+                         dxhat_3[i].y;
       }
 
   /*
@@ -947,9 +952,12 @@ __global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad,
         xhat_2[i].y = (vout_2.y - vbetta_2.y) / add_eps(vgamma_2.y);
         xhat_3[i].y = (vout_3.y - vbetta_3.y) / add_eps(vgamma_3.y);
         reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y;
-        reduce_val[1] += xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y;
-        reduce_val[1] += xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y;
-        reduce_val[1] += xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y;
+        reduce_val[1] +=
+            xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y;
+        reduce_val[1] +=
+            xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y;
+        reduce_val[1] +=
+            xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y;
       }
     } else {
       // inp_or_out is input, xhat = (input - mean) * rsqrtf(var)
@@ -969,9 +977,12 @@ __global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad,
         xhat_2[i].y = (vinp_2.y - fmean) * var_rsqrt;
         xhat_3[i].y = (vinp_3.y - fmean) * var_rsqrt;
         reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y;
-        reduce_val[1] += xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y;
-        reduce_val[1] += xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y;
-        reduce_val[1] += xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y;
+        reduce_val[1] +=
+            xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y;
+        reduce_val[1] +=
+            xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y;
+        reduce_val[1] +=
+            xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y;
       }
     }
   }
@@ -997,9 +1008,9 @@ __global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad,
     // Add the residual grad,
     // usually in pre-layer-norm for transformer layer
     float4 dresidual = ((const float4 *)residual_grad)[offset];
-    float4 dresidual_1 = ((const float4 *)residual_grad)[offset+1];
-    float4 dresidual_2 = ((const float4 *)residual_grad)[offset+2];
-    float4 dresidual_3 = ((const float4 *)residual_grad)[offset+3];
+    float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1];
+    float4 dresidual_2 = ((const float4 *)residual_grad)[offset + 2];
+    float4 dresidual_3 = ((const float4 *)residual_grad)[offset + 3];
     __half *hdres = reinterpret_cast<__half *>(&dresidual);
     __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1);
     __half *hdres_2 = reinterpret_cast<__half *>(&dresidual_2);
@@ -1139,22 +1150,21 @@ void launch_ln_bw<__half>(__half *gamma_grad, __half *betta_grad,
   if (hidden_dim * 8 <= 8192) {
     int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
     ker_ln_bw_dinp<<<batch, nthread, 0, stream[1]>>>(
-        inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means,
-        hidden_dim);
+        inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars,
+        means, hidden_dim);
   } else if (hidden_dim * 8 > 8192 && hidden_dim * 8 <= 8192 * 2) {
     hidden_dim >>= 1;
     int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
     ker_ln_bw_dinp_x2<<<batch, nthread, 0, stream[1]>>>(
-        inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means,
-        hidden_dim);
+        inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars,
+        means, hidden_dim);
   } else if (hidden_dim * 8 > 2 * 8192 && hidden_dim * 8 <= 8192 * 4) {
     hidden_dim >>= 2;
     int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
     ker_ln_bw_dinp_x4<<<batch, nthread, 0, stream[1]>>>(
-        inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means,
-        hidden_dim);
+        inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars,
+        means, hidden_dim);
   } else {
     throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768");
   }
 }
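On the dispatch in `launch_ln_bw<__half>`: by this point `hidden_dim` appears to count float4 groups of eight halfs, so `hidden_dim * 8` is the true hidden size. The branches choose how many float4 loads each thread owns, and the `>>= 1` / `>>= 2` shifts shrink `hidden_dim` so `nthread` stays within MAX_THREADS; the final branch rejects anything past 8192 · 4 = 32768 elements, matching the error message. A hedged reconstruction of the selection rule:

// hidden_dim counts float4 groups (8 __half each); returns the number of
// float4 loads per thread, or -1 for the runtime_error path above.
int pick_dinp_variant(int hidden_dim) {
  if (hidden_dim * 8 <= 8192) return 1;      // ker_ln_bw_dinp
  if (hidden_dim * 8 <= 8192 * 2) return 2;  // ker_ln_bw_dinp_x2
  if (hidden_dim * 8 <= 8192 * 4) return 4;  // ker_ln_bw_dinp_x4
  return -1;
}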