diff --git a/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu index d992e7e14..5eec0d662 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu @@ -7,8 +7,7 @@ namespace cg = cooperative_groups; const float LN_EPSILON = 1e-8f; #define TILE_DIM 32 -template -__forceinline__ __device__ T add_eps(T x) { +template __forceinline__ __device__ T add_eps(T x) { return fabsf(x) > LN_EPSILON ? x : (x < 0 ? -LN_EPSILON : LN_EPSILON); } @@ -138,13 +137,14 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars, // __global__ void ker_layer_norm_x2(__half *ln_res, __half *vars, // __half *means, const __half *inp, -// const __half *scale, const __half *bias, -// int hidden_size) { +// const __half *scale, const __half +// *bias, int hidden_size) { // // step 0. compute local sum // float l_sum = 0; // float l_square_sum = 0; // const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * 2 * hidden_size; -// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * 2) { +// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * +// 2) { // float4 val_f4 = inp_f4[idx]; // float4 val_f4_1 = inp_f4[idx+1]; // __half2 *val_h2 = (__half2 *)(&val_f4); @@ -154,7 +154,8 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars, // float2 val_f2 = __half22float2(val_h2[i]); // float2 val_f2_1 = __half22float2(val_h2_1[i]); // l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y; -// l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y + val_f2_1.x * val_f2_1.x + val_f2_1.y * val_f2_1.y; +// l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y + val_f2_1.x +// * val_f2_1.x + val_f2_1.y * val_f2_1.y; // } // } @@ -176,7 +177,8 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars, // // step 2. layer norm result // float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 2; -// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * 2) { +// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * +// 2) { // // load scale, bias, input // float4 scale_f4 = __ldg((const float4 *)scale + idx); // __half2 *scale_h2 = (__half2 *)(&scale_f4); @@ -202,9 +204,9 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars, // val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; // val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; // val_h2[i] = __float22half2_rn(val_f2); -// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + bias_f2_1.x; -// val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y + bias_f2_1.y; -// val_h2_1[i] = __float22half2_rn(val_f2_1); +// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + +// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y +// + bias_f2_1.y; val_h2_1[i] = __float22half2_rn(val_f2_1); // } // output_f4[idx] = val_f4; // output_f4[idx+1] = val_f4_1; @@ -213,13 +215,14 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars, // __global__ void ker_layer_norm_x4(__half *ln_res, __half *vars, // __half *means, const __half *inp, -// const __half *scale, const __half *bias, -// int hidden_size) { +// const __half *scale, const __half +// *bias, int hidden_size) { // // step 0. compute local sum // float l_sum = 0; // float l_square_sum = 0; // const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size * 4; -// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * 4) { +// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * +// 4) { // float4 val_f4 = inp_f4[idx]; // float4 val_f4_1 = inp_f4[idx+1]; // float4 val_f4_2 = inp_f4[idx+2]; @@ -234,11 +237,12 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars, // float2 val_f2_1 = __half22float2(val_h2_1[i]); // float2 val_f2_2 = __half22float2(val_h2_2[i]); // float2 val_f2_3 = __half22float2(val_h2_3[i]); -// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y + val_f2_2.x + val_f2_2.y + val_f2_3.x + val_f2_3.y; -// l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y; -// l_square_sum += val_f2_1.x * val_f2_1.x + val_f2_1.y * val_f2_1.y; -// l_square_sum += val_f2_2.x * val_f2_2.x + val_f2_2.y * val_f2_2.y; -// l_square_sum += val_f2_3.x * val_f2_3.x + val_f2_3.y * val_f2_3.y; +// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y + val_f2_2.x + +// val_f2_2.y + val_f2_3.x + val_f2_3.y; l_square_sum += val_f2.x * +// val_f2.x + val_f2.y * val_f2.y; l_square_sum += val_f2_1.x * val_f2_1.x +// + val_f2_1.y * val_f2_1.y; l_square_sum += val_f2_2.x * val_f2_2.x + +// val_f2_2.y * val_f2_2.y; l_square_sum += val_f2_3.x * val_f2_3.x + +// val_f2_3.y * val_f2_3.y; // } // } @@ -260,7 +264,8 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars, // // step 2. layer norm result // float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 4; -// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * 4) { +// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * +// 4) { // // load scale, bias, input // float4 scale_f4 = __ldg((const float4 *)scale + idx); // __half2 *scale_h2 = (__half2 *)(&scale_f4); @@ -303,14 +308,14 @@ __global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars, // float2 val_f2_3 = __half22float2(val_h2_3[i]); // val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; // val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; -// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + bias_f2_1.x; -// val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y + bias_f2_1.y; -// val_f2_2.x = (val_f2_2.x - s_mean) * s_var * scale_f2_2.x + bias_f2_2.x; -// val_f2_2.y = (val_f2_2.y - s_mean) * s_var * scale_f2_2.y + bias_f2_2.y; -// val_f2_3.x = (val_f2_3.x - s_mean) * s_var * scale_f2_3.x + bias_f2_3.x; -// val_f2_3.y = (val_f2_3.y - s_mean) * s_var * scale_f2_3.y + bias_f2_3.y; -// val_h2[i] = __float22half2_rn(val_f2); -// val_h2_1[i] = __float22half2_rn(val_f2_1); +// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + +// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y +// + bias_f2_1.y; val_f2_2.x = (val_f2_2.x - s_mean) * s_var * +// scale_f2_2.x + bias_f2_2.x; val_f2_2.y = (val_f2_2.y - s_mean) * s_var +// * scale_f2_2.y + bias_f2_2.y; val_f2_3.x = (val_f2_3.x - s_mean) * +// s_var * scale_f2_3.x + bias_f2_3.x; val_f2_3.y = (val_f2_3.y - s_mean) +// * s_var * scale_f2_3.y + bias_f2_3.y; val_h2[i] = +// __float22half2_rn(val_f2); val_h2_1[i] = __float22half2_rn(val_f2_1); // val_h2_2[i] = __float22half2_rn(val_f2_2); // val_h2_3[i] = __float22half2_rn(val_f2_3); // } @@ -414,11 +419,10 @@ means: [batch_size * seq_len], mean of ln forward, (gamma && betta) ^ (vars && means) should be true */ template -__global__ void ker_ln_bw_dgamma_dbetta(T *gamma_grad, T *betta_grad, - const T *out_grad, const T *inp_or_out, - const T *gamma, const T *betta, - const T *vars, const T *means, int rows, - int width) { +__global__ void +ker_ln_bw_dgamma_dbetta(T *gamma_grad, T *betta_grad, const T *out_grad, + const T *inp_or_out, const T *gamma, const T *betta, + const T *vars, const T *means, int rows, int width) { __shared__ float betta_buffer[TILE_DIM][TILE_DIM]; __shared__ float gamma_buffer[TILE_DIM][TILE_DIM]; @@ -698,11 +702,10 @@ __global__ void ker_ln_bw_dinp<__half>(__half *inp_grad, const __half *out_grad, } __global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad, - const __half *residual_grad, - const __half *inp_or_out, - const __half *gamma, const __half *betta, - const __half *vars, const __half *means, - int hidden_dim) { + const __half *residual_grad, + const __half *inp_or_out, const __half *gamma, + const __half *betta, const __half *vars, + const __half *means, int hidden_dim) { int offset = blockIdx.x * hidden_dim * 2 + threadIdx.x * 2; float2 dxhat[4], xhat[4]; @@ -762,7 +765,8 @@ __global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad, xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y); reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; + reduce_val[1] += + xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; } } else { // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) @@ -776,7 +780,8 @@ __global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad, xhat[i].y = (vinp.y - fmean) * var_rsqrt; xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt; reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; + reduce_val[1] += + xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; } } } @@ -802,7 +807,7 @@ __global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad, // Add the residual grad, // usually in pre-layer-norm for transformer layer float4 dresidual = ((const float4 *)residual_grad)[offset]; - float4 dresidual_1 = ((const float4 *)residual_grad)[offset+1]; + float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1]; __half *hdres = reinterpret_cast<__half *>(&dresidual); __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1); #pragma unroll @@ -846,11 +851,10 @@ __global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad, } __global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad, - const __half *residual_grad, - const __half *inp_or_out, - const __half *gamma, const __half *betta, - const __half *vars, const __half *means, - int hidden_dim) { + const __half *residual_grad, + const __half *inp_or_out, const __half *gamma, + const __half *betta, const __half *vars, + const __half *means, int hidden_dim) { int offset = blockIdx.x * hidden_dim * 4 + threadIdx.x * 4; float2 dxhat[4], xhat[4]; @@ -901,8 +905,9 @@ __global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad, dxhat_2[i].y = vdout_2.y * vgamma_2.y; dxhat_3[i].x = vdout_3.x * vgamma_3.x; dxhat_3[i].y = vdout_3.y * vgamma_3.y; - reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y + dxhat_2[i].x + - dxhat_2[i].y + dxhat_3[i].x + dxhat_3[i].y; + reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y + + dxhat_2[i].x + dxhat_2[i].y + dxhat_3[i].x + + dxhat_3[i].y; } /* @@ -947,9 +952,12 @@ __global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad, xhat_2[i].y = (vout_2.y - vbetta_2.y) / add_eps(vgamma_2.y); xhat_3[i].y = (vout_3.y - vbetta_3.y) / add_eps(vgamma_3.y); reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; - reduce_val[1] += xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; - reduce_val[1] += xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; + reduce_val[1] += + xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; + reduce_val[1] += + xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; + reduce_val[1] += + xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; } } else { // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) @@ -969,9 +977,12 @@ __global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad, xhat_2[i].y = (vinp_2.y - fmean) * var_rsqrt; xhat_3[i].y = (vinp_3.y - fmean) * var_rsqrt; reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; - reduce_val[1] += xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; - reduce_val[1] += xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; + reduce_val[1] += + xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; + reduce_val[1] += + xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; + reduce_val[1] += + xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; } } } @@ -997,9 +1008,9 @@ __global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad, // Add the residual grad, // usually in pre-layer-norm for transformer layer float4 dresidual = ((const float4 *)residual_grad)[offset]; - float4 dresidual_1 = ((const float4 *)residual_grad)[offset+1]; - float4 dresidual_2 = ((const float4 *)residual_grad)[offset+2]; - float4 dresidual_3 = ((const float4 *)residual_grad)[offset+3]; + float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1]; + float4 dresidual_2 = ((const float4 *)residual_grad)[offset + 2]; + float4 dresidual_3 = ((const float4 *)residual_grad)[offset + 3]; __half *hdres = reinterpret_cast<__half *>(&dresidual); __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1); __half *hdres_2 = reinterpret_cast<__half *>(&dresidual_2); @@ -1139,22 +1150,21 @@ void launch_ln_bw<__half>(__half *gamma_grad, __half *betta_grad, if (hidden_dim * 8 <= 8192) { int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); ker_ln_bw_dinp<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means, - hidden_dim); + inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, + means, hidden_dim); } else if (hidden_dim * 8 > 8192 && hidden_dim * 8 <= 8192 * 2) { hidden_dim >>= 1; int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); ker_ln_bw_dinp_x2<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means, - hidden_dim); + inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, + means, hidden_dim); } else if (hidden_dim * 8 > 2 * 8192 && hidden_dim * 8 <= 8192 * 4) { hidden_dim >>= 2; int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); ker_ln_bw_dinp_x4<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means, - hidden_dim); + inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, + means, hidden_dim); } else { throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768"); } } -