You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ColossalAI/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu

1161 lines
47 KiB

#include "block_reduce.h"
#include "kernels.h"
#include <cooperative_groups.h>
namespace cg = cooperative_groups;
// Small positive constant: added to the variance before rsqrtf in the layer
// norm kernels below, and used by add_eps as the minimum divisor magnitude.
const float LN_EPSILON = 1e-8f;
// Edge length of the square shared-memory tiles (and the cooperative-groups
// tile width) used by ker_ln_bw_dgamma_dbetta.
#define TILE_DIM 32
// Returns x unchanged when |x| exceeds LN_EPSILON; otherwise returns
// -LN_EPSILON for negative x and +LN_EPSILON for everything else (including
// zero), so the result can be used as a divisor.
template <typename T>
__forceinline__ __device__ T add_eps(T x) {
  if (fabsf(x) > LN_EPSILON) {
    return x;
  }
  return x < 0 ? -LN_EPSILON : LN_EPSILON;
}
/**
@brief: ker_layer_norm
Standard layer normalization.
It will not only output the layer norm result,
but also outputs variance.
may also output means, depends on whether
the means argument is nullptr.
Rows are read and written as float4, so the hidden_size argument counts
float4 elements (4 scalars each); mean and variance are taken over
hidden_size * 4 scalars.

@thread
gridDim.x = batch_size * seq_len
blockDim.x = hidden_size (threads loop over the row with stride blockDim.x)

@param
ln_res: [batch_size * seq_len, hidden_size], ln result.
vars: [batch_size * seq_len], variance per token (LN_EPSILON already added)
means: [batch_size * seq_len], means per token, can be nullptr
inp: [batch_size * seq_len, hidden_size], ln input.
scale: [hidden_size], ln scale
bias: [hidden_size], ln bias
hidden_size: row length in float4 units
*/
template <typename T>
__global__ void ker_layer_norm(T *ln_res, T *vars, T *means, const T *inp,
                               const T *scale, const T *bias, int hidden_size) {
  // step 0. compute local sum
  float l_sum = 0;
  float l_square_sum = 0;
  const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size;
  for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float4 val = inp_f4[idx];
    l_sum += val.x + val.y + val.z + val.w;
    l_square_sum +=
        val.x * val.x + val.y * val.y + val.z * val.z + val.w * val.w;
  }

  // step 1. compute reduce sum
  float mean_dim = float(hidden_size) * 4.f;
  float reduce_val[2] = {l_sum, l_square_sum};
  blockReduce<ReduceType::kSum, 2>(reduce_val);
  __shared__ float s_mean, s_var;
  if (threadIdx.x == 0) {
    s_mean = reduce_val[0] / mean_dim;
    if (means != nullptr) {
      means[blockIdx.x] = s_mean;
    }
    // E[x^2] - E[x]^2 can round to a small negative number when the row is
    // (nearly) constant; LN_EPSILON is too small to compensate, and a negative
    // argument makes rsqrtf return NaN. Clamp the raw variance at zero first.
    s_var = fmaxf(reduce_val[1] / mean_dim - s_mean * s_mean, 0.f) + LN_EPSILON;
    vars[blockIdx.x] = s_var;
    // from here on s_var holds 1 / sqrt(var + eps)
    s_var = rsqrtf(s_var);
  }
  __syncthreads();

  // step 2. layer norm result
  float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size;
  for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float4 vscale = __ldg((const float4 *)scale + idx);
    float4 vbias = __ldg((const float4 *)bias + idx);
    float4 val = inp_f4[idx];
    val.x = (val.x - s_mean) * s_var * vscale.x + vbias.x;
    val.y = (val.y - s_mean) * s_var * vscale.y + vbias.y;
    val.z = (val.z - s_mean) * s_var * vscale.z + vbias.z;
    val.w = (val.w - s_mean) * s_var * vscale.w + vbias.w;
    output_f4[idx] = val;
  }
}
/**
@brief: ker_layer_norm<__half>
Half-precision specialization of the layer norm forward kernel.
Each float4 load carries 8 __half values, so hidden_size counts float4
elements and the statistics are taken over hidden_size * 8 scalars.
Values are unpacked to float for all arithmetic and packed back to __half
for the store. Writes vars (with LN_EPSILON added) and, when means is not
nullptr, means, one entry per block.
*/
template <>
__global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars,
                                       __half *means, const __half *inp,
                                       const __half *scale, const __half *bias,
                                       int hidden_size) {
  // step 0. compute local sum
  float l_sum = 0;
  float l_square_sum = 0;
  const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size;
  for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float4 val_f4 = inp_f4[idx];
    __half2 *val_h2 = (__half2 *)(&val_f4);
#pragma unroll
    for (int i = 0; i < 4; i++) {
      float2 val_f2 = __half22float2(val_h2[i]);
      l_sum += val_f2.x + val_f2.y;
      l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y;
    }
  }

  // step 1. compute reduce sum
  float mean_dim = float(hidden_size) * 8.f;
  float reduce_val[2] = {l_sum, l_square_sum};
  blockReduce<ReduceType::kSum, 2>(reduce_val);
  __shared__ float s_mean, s_var;
  if (threadIdx.x == 0) {
    s_mean = reduce_val[0] / mean_dim;
    if (means != nullptr) {
      means[blockIdx.x] = s_mean;
    }
    // E[x^2] - E[x]^2 can round to a small negative number when the row is
    // (nearly) constant; LN_EPSILON is too small to compensate, and a negative
    // argument makes rsqrtf return NaN. Clamp the raw variance at zero first.
    s_var = fmaxf(reduce_val[1] / mean_dim - s_mean * s_mean, 0.f) + LN_EPSILON;
    vars[blockIdx.x] = s_var;
    // from here on s_var holds 1 / sqrt(var + eps)
    s_var = rsqrtf(s_var);
  }
  __syncthreads();

  // step 2. layer norm result
  float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size;
  for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    // load scale, bias, input
    float4 scale_f4 = __ldg((const float4 *)scale + idx);
    __half2 *scale_h2 = (__half2 *)(&scale_f4);
    float4 bias_f4 = __ldg((const float4 *)bias + idx);
    __half2 *bias_h2 = (__half2 *)(&bias_f4);
    float4 val_f4 = inp_f4[idx];
    __half2 *val_h2 = (__half2 *)(&val_f4);
#pragma unroll
    for (int i = 0; i < 4; i++) {
      float2 scale_f2 = __half22float2(scale_h2[i]);
      float2 bias_f2 = __half22float2(bias_h2[i]);
      float2 val_f2 = __half22float2(val_h2[i]);
      val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x;
      val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y;
      // results overwrite the input registers and are stored as one float4
      val_h2[i] = __float22half2_rn(val_f2);
    }
    output_f4[idx] = val_f4;
  }
}
// __global__ void ker_layer_norm_x2(__half *ln_res, __half *vars,
// __half *means, const __half *inp,
// const __half *scale, const __half *bias,
// int hidden_size) {
// // step 0. compute local sum
// float l_sum = 0;
// float l_square_sum = 0;
// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * 2 * hidden_size;
// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * 2) {
// float4 val_f4 = inp_f4[idx];
// float4 val_f4_1 = inp_f4[idx+1];
// __half2 *val_h2 = (__half2 *)(&val_f4);
// __half2 *val_h2_1 = (__half2 *)(&val_f4_1);
// #pragma unroll
// for (int i = 0; i < 4; i++) {
// float2 val_f2 = __half22float2(val_h2[i]);
// float2 val_f2_1 = __half22float2(val_h2_1[i]);
// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y;
// l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y + val_f2_1.x * val_f2_1.x + val_f2_1.y * val_f2_1.y;
// }
// }
// // step 1. compute reduce sum
// float mean_dim = float(hidden_size) * 8.f * 2;
// float reduce_val[2] = {l_sum, l_square_sum};
// blockReduce<ReduceType::kSum, 2>(reduce_val);
// __shared__ float s_mean, s_var;
// if (threadIdx.x == 0) {
// s_mean = reduce_val[0] / mean_dim;
// if (means != nullptr) {
// means[blockIdx.x] = s_mean;
// }
// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON;
// vars[blockIdx.x] = s_var;
// s_var = rsqrtf(s_var);
// }
// __syncthreads();
// // step 2. layer norm result
// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 2;
// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * 2) {
// // load scale, bias, input
// float4 scale_f4 = __ldg((const float4 *)scale + idx);
// __half2 *scale_h2 = (__half2 *)(&scale_f4);
// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1);
// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1);
// float4 bias_f4 = __ldg((const float4 *)bias + idx);
// __half2 *bias_h2 = (__half2 *)(&bias_f4);
// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1);
// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1);
// float4 val_f4 = inp_f4[idx];
// __half2 *val_h2 = (__half2 *)(&val_f4);
// float4 val_f4_1 = inp_f4[idx+1];
// __half2 *val_h2_1 = (__half2 *)(&val_f4_1);
// #pragma unroll
// for (int i = 0; i < 4; i++) {
// float2 scale_f2 = __half22float2(scale_h2[i]);
// float2 scale_f2_1 = __half22float2(scale_h2_1[i]);
// float2 bias_f2 = __half22float2(bias_h2[i]);
// float2 bias_f2_1 = __half22float2(bias_h2_1[i]);
// float2 val_f2 = __half22float2(val_h2[i]);
// float2 val_f2_1 = __half22float2(val_h2_1[i]);
// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x;
// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y;
// val_h2[i] = __float22half2_rn(val_f2);
// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + bias_f2_1.x;
// val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y + bias_f2_1.y;
// val_h2_1[i] = __float22half2_rn(val_f2_1);
// }
// output_f4[idx] = val_f4;
// output_f4[idx+1] = val_f4_1;
// }
// }
// __global__ void ker_layer_norm_x4(__half *ln_res, __half *vars,
// __half *means, const __half *inp,
// const __half *scale, const __half *bias,
// int hidden_size) {
// // step 0. compute local sum
// float l_sum = 0;
// float l_square_sum = 0;
// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size * 4;
// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * 4) {
// float4 val_f4 = inp_f4[idx];
// float4 val_f4_1 = inp_f4[idx+1];
// float4 val_f4_2 = inp_f4[idx+2];
// float4 val_f4_3 = inp_f4[idx+3];
// __half2 *val_h2 = (__half2 *)(&val_f4);
// __half2 *val_h2_1 = (__half2 *)(&val_f4_1);
// __half2 *val_h2_2 = (__half2 *)(&val_f4_2);
// __half2 *val_h2_3 = (__half2 *)(&val_f4_3);
// #pragma unroll
// for (int i = 0; i < 4; i++) {
// float2 val_f2 = __half22float2(val_h2[i]);
// float2 val_f2_1 = __half22float2(val_h2_1[i]);
// float2 val_f2_2 = __half22float2(val_h2_2[i]);
// float2 val_f2_3 = __half22float2(val_h2_3[i]);
// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y + val_f2_2.x + val_f2_2.y + val_f2_3.x + val_f2_3.y;
// l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y;
// l_square_sum += val_f2_1.x * val_f2_1.x + val_f2_1.y * val_f2_1.y;
// l_square_sum += val_f2_2.x * val_f2_2.x + val_f2_2.y * val_f2_2.y;
// l_square_sum += val_f2_3.x * val_f2_3.x + val_f2_3.y * val_f2_3.y;
// }
// }
// // step 1. compute reduce sum
// float mean_dim = float(hidden_size) * 8.f * 4;
// float reduce_val[2] = {l_sum, l_square_sum};
// blockReduce<ReduceType::kSum, 2>(reduce_val);
// __shared__ float s_mean, s_var;
// if (threadIdx.x == 0) {
// s_mean = reduce_val[0] / mean_dim;
// if (means != nullptr) {
// means[blockIdx.x] = s_mean;
// }
// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON;
// vars[blockIdx.x] = s_var;
// s_var = rsqrtf(s_var);
// }
// __syncthreads();
// // step 2. layer norm result
// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 4;
// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * 4) {
// // load scale, bias, input
// float4 scale_f4 = __ldg((const float4 *)scale + idx);
// __half2 *scale_h2 = (__half2 *)(&scale_f4);
// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1);
// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1);
// float4 scale_f4_2 = __ldg((const float4 *)scale + idx + 2);
// __half2 *scale_h2_2 = (__half2 *)(&scale_f4_2);
// float4 scale_f4_3 = __ldg((const float4 *)scale + idx + 3);
// __half2 *scale_h2_3 = (__half2 *)(&scale_f4_3);
// float4 bias_f4 = __ldg((const float4 *)bias + idx);
// __half2 *bias_h2 = (__half2 *)(&bias_f4);
// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1);
// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1);
// float4 bias_f4_2 = __ldg((const float4 *)bias + idx + 2);
// __half2 *bias_h2_2 = (__half2 *)(&bias_f4_2);
// float4 bias_f4_3 = __ldg((const float4 *)bias + idx + 3);
// __half2 *bias_h2_3 = (__half2 *)(&bias_f4_3);
// float4 val_f4 = inp_f4[idx];
// __half2 *val_h2 = (__half2 *)(&val_f4);
// float4 val_f4_1 = inp_f4[idx+1];
// __half2 *val_h2_1 = (__half2 *)(&val_f4_1);
// float4 val_f4_2 = inp_f4[idx+2];
// __half2 *val_h2_2 = (__half2 *)(&val_f4_2);
// float4 val_f4_3 = inp_f4[idx+3];
// __half2 *val_h2_3 = (__half2 *)(&val_f4_3);
// #pragma unroll
// for (int i = 0; i < 4; i++) {
// float2 scale_f2 = __half22float2(scale_h2[i]);
// float2 scale_f2_1 = __half22float2(scale_h2_1[i]);
// float2 scale_f2_2 = __half22float2(scale_h2_2[i]);
// float2 scale_f2_3 = __half22float2(scale_h2_3[i]);
// float2 bias_f2 = __half22float2(bias_h2[i]);
// float2 bias_f2_1 = __half22float2(bias_h2_1[i]);
// float2 bias_f2_2 = __half22float2(bias_h2_2[i]);
// float2 bias_f2_3 = __half22float2(bias_h2_3[i]);
// float2 val_f2 = __half22float2(val_h2[i]);
// float2 val_f2_1 = __half22float2(val_h2_1[i]);
// float2 val_f2_2 = __half22float2(val_h2_2[i]);
// float2 val_f2_3 = __half22float2(val_h2_3[i]);
// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x;
// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y;
// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + bias_f2_1.x;
// val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y + bias_f2_1.y;
// val_f2_2.x = (val_f2_2.x - s_mean) * s_var * scale_f2_2.x + bias_f2_2.x;
// val_f2_2.y = (val_f2_2.y - s_mean) * s_var * scale_f2_2.y + bias_f2_2.y;
// val_f2_3.x = (val_f2_3.x - s_mean) * s_var * scale_f2_3.x + bias_f2_3.x;
// val_f2_3.y = (val_f2_3.y - s_mean) * s_var * scale_f2_3.y + bias_f2_3.y;
// val_h2[i] = __float22half2_rn(val_f2);
// val_h2_1[i] = __float22half2_rn(val_f2_1);
// val_h2_2[i] = __float22half2_rn(val_f2_2);
// val_h2_3[i] = __float22half2_rn(val_f2_3);
// }
// output_f4[idx] = val_f4;
// output_f4[idx+1] = val_f4_1;
// output_f4[idx+2] = val_f4_2;
// output_f4[idx+3] = val_f4_3;
// }
// }
/**
@brief: launch_layer_norm<float>
Host-side launcher for ker_layer_norm<float>: one block per row, and
hidden_dim / 4 threads rounded up to a multiple of 32 and capped at
MAX_THREADS.

@param
batch_size: number of rows to normalize (used as gridDim.x)
hidden_dim: row length in floats, must be a multiple of 4
stream: CUDA stream the kernel is enqueued on

Throws std::runtime_error when hidden_dim is not a multiple of 4.
Does nothing when batch_size or hidden_dim is not positive.
*/
template <>
void launch_layer_norm<float>(float *ln_res, float *vars, float *means,
                              const float *inp, const float *scale,
                              const float *bias, int batch_size, int hidden_dim,
                              cudaStream_t stream) {
  if (hidden_dim % 4 != 0) {
    throw std::runtime_error("violate hidden_dim % 4 = 0");
  }
  // A zero-sized grid or block is an invalid launch configuration and would
  // leave a pending CUDA error behind; there is no work to do in that case.
  if (batch_size <= 0 || hidden_dim <= 0) {
    return;
  }
  // the kernel works on float4 elements
  hidden_dim >>= 2;
  int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
  dim3 grid_dim(batch_size);
  dim3 block_dim(nthread);

  ker_layer_norm<float><<<grid_dim, block_dim, 0, stream>>>(
      ln_res, vars, means, inp, scale, bias, hidden_dim);
}
/**
@brief: launch_layer_norm<__half>
Host-side launcher for ker_layer_norm<__half>: one block per row, and
hidden_dim / 8 threads rounded up to a multiple of 32 and capped at
MAX_THREADS.

@param
batch_size: number of rows to normalize (used as gridDim.x)
hidden_dim: row length in __half elements, must be a multiple of 8
stream: CUDA stream the kernel is enqueued on

Throws std::runtime_error when hidden_dim is not a multiple of 8.
Does nothing when batch_size or hidden_dim is not positive.
*/
template <>
void launch_layer_norm<__half>(__half *ln_res, __half *vars, __half *means,
                               const __half *inp, const __half *scale,
                               const __half *bias, int batch_size,
                               int hidden_dim, cudaStream_t stream) {
  if (hidden_dim % 8 != 0) {
    throw std::runtime_error("violate hidden_dim % 8 = 0");
  }
  // A zero-sized grid or block is an invalid launch configuration and would
  // leave a pending CUDA error behind; there is no work to do in that case.
  if (batch_size <= 0 || hidden_dim <= 0) {
    return;
  }
  // the kernel works on float4 elements, each holding 8 __half values
  hidden_dim >>= 3;
  int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
  dim3 grid_dim(batch_size);
  dim3 block_dim(nthread);

  ker_layer_norm<__half><<<grid_dim, block_dim, 0, stream>>>(
      ln_res, vars, means, inp, scale, bias, hidden_dim);

  // Disabled alternative: dispatch to wider-per-thread kernels
  // (ker_layer_norm_x2 / ker_layer_norm_x4) for large hidden sizes.
  // if (hidden_dim % 8 != 0) {
  //   throw std::runtime_error("violate hidden_dim % 8 = 0");
  // }
  // hidden_dim >>= 3;
  // if (hidden_dim * 8 < 8192) {
  //   int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
  //   dim3 grid_dim(batch_size);
  //   dim3 block_dim(nthread);
  //   ker_layer_norm<__half><<<grid_dim, block_dim, 0, stream>>>(
  //       ln_res, vars, means, inp, scale, bias, hidden_dim);
  // } else if (hidden_dim * 8 >= 8192 && hidden_dim * 8 <= 8192 * 2) {
  //   hidden_dim >>= 1;
  //   int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
  //   dim3 grid_dim(batch_size);
  //   dim3 block_dim(nthread);
  //   ker_layer_norm_x2<<<grid_dim, block_dim, 0, stream>>>(
  //       ln_res, vars, means, inp, scale, bias, hidden_dim);
  // } else if (hidden_dim * 8 > 8192 * 2 && hidden_dim * 8 <= 8192 * 4) {
  //   hidden_dim >>= 2;
  //   int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
  //   dim3 grid_dim(batch_size);
  //   dim3 block_dim(nthread);
  //   ker_layer_norm_x4<<<grid_dim, block_dim, 0, stream>>>(
  //       ln_res, vars, means, inp, scale, bias, hidden_dim);
  // } else {
  //   throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768");
  // }
}
/**
@brief: ker_ln_bw_dgamma_dbetta
Layer norm backward kernel, compute the gradient of gamma and betta.
dbetta = sum(dout, dim=0)
dgamma = sum(xhat * dout, dim=0)
xhat = (input - mean) * rsqrt(var) or
  (output - betta) / gamma

Each block owns TILE_DIM consecutive columns. Thread (x, y) accumulates
column x of the block over rows y, y + TILE_DIM, ...; the per-thread partial
sums are transposed through shared memory so that one warp-sized tile holds
all partials of a single column and can reduce them with shuffles.

@thread
gridDim.x = ceil(hidden_size / 32)
blockDim.x = 32
blockDim.y = 32

@param
gamma_grad: [hidden_size], gradient of gamma
betta_grad: [hidden_size], gradient of betta
out_grad: [batch_size * seq_len, hidden_size], gradient of ln output
inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr
  ln input if means is not nullptr
gamma: [hidden_size], gamma of ln,
  used to compute xhat, maybe nullptr
betta: [hidden_size], betta of ln,
  used to compute xhat, maybe nullptr
vars: [batch_size * seq_len], variance of ln forward,
  used to compute xhat, maybe nullptr
means: [batch_size * seq_len], mean of ln forward,
  used to compute xhat, maybe nullptr
rows: batch_size * seq_len
width: hidden_size
(gamma && betta) ^ (vars && means) should be true
*/
template <typename T>
__global__ void ker_ln_bw_dgamma_dbetta(T *gamma_grad, T *betta_grad,
                                        const T *out_grad, const T *inp_or_out,
                                        const T *gamma, const T *betta,
                                        const T *vars, const T *means, int rows,
                                        int width) {
  // One extra column of padding: the write below walks a column of the tile
  // (threadIdx.x selects the row), which without padding maps all 32 lanes of
  // a warp onto the same shared-memory bank.
  __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
  __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];

  cg::thread_block b = cg::this_thread_block();
  cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);

  int idx = blockDim.x * blockIdx.x + threadIdx.x;
  // 64-bit element offsets: rows * width may exceed the range of int
  size_t offset = (size_t)threadIdx.y * width + idx;
  size_t y_stride = (size_t)width * TILE_DIM;

  // Loop across inp height
  float dbetta = 0;
  float dgamma = 0;
  float dout, val;
  if (idx < width) {
    if (means == nullptr) {
      float vbetta = (float)betta[idx];
      float vgamma = (float)gamma[idx];
      for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
        dout = (float)out_grad[offset];
        // inp_or_out is output
        val = (float)inp_or_out[offset];
        dbetta += dout;
        dgamma += ((val - vbetta) / add_eps(vgamma) * dout);
        offset += y_stride;
      }
    } else {
      for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
        dout = (float)out_grad[offset];
        // inp_or_out is input
        val = (float)inp_or_out[offset];
        dbetta += dout;
        dgamma += ((val - (float)means[r]) *
                   rsqrtf((float)vars[r] + LN_EPSILON) * dout);
        offset += y_stride;
      }
    }
  }

  // Sum the shared buffer.
  betta_buffer[threadIdx.x][threadIdx.y] = dbetta;
  gamma_buffer[threadIdx.x][threadIdx.y] = dgamma;
  __syncthreads();
  // transposed read: this thread now holds a partial sum of column
  // threadIdx.y of the block
  float s1 = betta_buffer[threadIdx.y][threadIdx.x];
  float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
  __syncthreads();

  for (int i = 1; i < TILE_DIM; i <<= 1) {
    s1 += g.shfl_down(s1, i);
    s2 += g.shfl_down(s2, i);
  }

  // After the transpose the column written by this thread is selected by
  // threadIdx.y, so the bounds check must use pos (not idx); otherwise the
  // last block writes past the end when width is not a multiple of TILE_DIM.
  int pos = blockIdx.x * TILE_DIM + threadIdx.y;
  if (threadIdx.x == 0 && pos < width) {
    betta_grad[pos] = s1;
    gamma_grad[pos] = s2;
  }
}
/**
@brief: ker_ln_bw_dinp
Layer norm backward kernel, compute the gradient of the ln input.
dinp = (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / hidden_dim)
  * rsqrt(var)
xhat = (input - mean) * rsqrt(var) if means is not nullptr
  (output - betta) / gamma if means is nullptr
dxhat = dout * gamma

Every thread handles exactly one float4 of its row, so hidden_dim counts
float4 elements (4 scalars each). Threads with threadIdx.x >= hidden_dim
only take part in the block reduction, contributing zeros.

@thread
gridDim.x = batch_size * seq_len
blockDim.x = hidden_size

@param
inp_grad: [batch_size * seq_len, hidden_size], output, gradient of ln input
out_grad: [batch_size * seq_len, hidden_size], gradient of ln output
residual_grad: [batch_size * seq_len, hidden_size], gradient of residual input,
  added to the result when not nullptr
  (usually appears in pre-layer-norm transformer layers)
inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr
  ln input if means is not nullptr
gamma: [hidden_size], gamma of ln,
  used to compute xhat and dxhat
betta: [hidden_size], betta of ln,
  used to compute xhat, maybe nullptr
vars: [batch_size * seq_len], variance of ln forward,
  used to compute xhat and dinp
means: [batch_size * seq_len], mean of ln forward,
  used to compute xhat, maybe nullptr
*/
template <typename T>
__global__ void ker_ln_bw_dinp(T *inp_grad, const T *out_grad,
                               const T *residual_grad, const T *inp_or_out,
                               const T *gamma, const T *betta, const T *vars,
                               const T *means, int hidden_dim) {
  const bool in_range = threadIdx.x < hidden_dim;
  const int pos = blockIdx.x * hidden_dim + threadIdx.x;

  float4 grad_hat, norm;
  float inv_std;
  float partial[2] = {0.f, 0.f};

  if (in_range) {
    // dxhat = dout * gamma
    grad_hat = reinterpret_cast<const float4 *>(out_grad)[pos];
    float4 weight = reinterpret_cast<const float4 *>(gamma)[threadIdx.x];
    grad_hat.x *= weight.x;
    grad_hat.y *= weight.y;
    grad_hat.z *= weight.z;
    grad_hat.w *= weight.w;

    // rebuild xhat from either the saved input or the saved output
    norm = reinterpret_cast<const float4 *>(inp_or_out)[pos];
    inv_std = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON);
    if (means != nullptr) {
      // inp_or_out holds the ln input: xhat = (input - mean) * rsqrt(var)
      float row_mean = (float)means[blockIdx.x];
      norm.x = (norm.x - row_mean) * inv_std;
      norm.y = (norm.y - row_mean) * inv_std;
      norm.z = (norm.z - row_mean) * inv_std;
      norm.w = (norm.w - row_mean) * inv_std;
    } else {
      // inp_or_out holds the ln output: xhat = (output - betta) / gamma
      float4 shift = reinterpret_cast<const float4 *>(betta)[threadIdx.x];
      norm.x = (norm.x - shift.x) / add_eps(weight.x);
      norm.y = (norm.y - shift.y) / add_eps(weight.y);
      norm.z = (norm.z - shift.z) / add_eps(weight.z);
      norm.w = (norm.w - shift.w) / add_eps(weight.w);
    }

    // per-thread contributions to sum(dxhat) and sum(dxhat * xhat)
    partial[0] = grad_hat.x + grad_hat.y + grad_hat.z + grad_hat.w;
    partial[1] = grad_hat.x * norm.x + grad_hat.y * norm.y +
                 grad_hat.z * norm.z + grad_hat.w * norm.w;
  }

  // block-wide sums; every thread of the block has to reach this call
  blockReduce<ReduceType::kSum, 2>(partial);
  __shared__ float s_avg_dxhat, s_avg_dxhat_xhat;
  if (threadIdx.x == 0) {
    float scalars_per_row = hidden_dim * 4;
    s_avg_dxhat = partial[0] / scalars_per_row;
    s_avg_dxhat_xhat = partial[1] / scalars_per_row;
  }
  __syncthreads();

  // padding threads are done once the reduction result is published
  if (!in_range) {
    return;
  }

  // (dxhat - mean(dxhat) - xhat * mean(dxhat * xhat)) * rsqrt(var)
  grad_hat.x = (grad_hat.x - s_avg_dxhat - norm.x * s_avg_dxhat_xhat) * inv_std;
  grad_hat.y = (grad_hat.y - s_avg_dxhat - norm.y * s_avg_dxhat_xhat) * inv_std;
  grad_hat.z = (grad_hat.z - s_avg_dxhat - norm.z * s_avg_dxhat_xhat) * inv_std;
  grad_hat.w = (grad_hat.w - s_avg_dxhat - norm.w * s_avg_dxhat_xhat) * inv_std;

  if (residual_grad) {
    // fold in the gradient flowing through the residual connection
    float4 skip = reinterpret_cast<const float4 *>(residual_grad)[pos];
    grad_hat.x += skip.x;
    grad_hat.y += skip.y;
    grad_hat.z += skip.z;
    grad_hat.w += skip.w;
  }
  reinterpret_cast<float4 *>(inp_grad)[pos] = grad_hat;
}
/**
@brief: ker_ln_bw_dinp<__half>
Half-precision specialization of the layer norm input-gradient kernel.
Every thread handles one float4, i.e. 8 __half values, so hidden_dim counts
float4 elements. Values are unpacked to float for all arithmetic and the
result is packed back into the same float4 register before the store.
Formula and parameters are the same as for the generic ker_ln_bw_dinp.
*/
template <>
__global__ void ker_ln_bw_dinp<__half>(__half *inp_grad, const __half *out_grad,
                                       const __half *residual_grad,
                                       const __half *inp_or_out,
                                       const __half *gamma, const __half *betta,
                                       const __half *vars, const __half *means,
                                       int hidden_dim) {
  const bool in_range = threadIdx.x < hidden_dim;
  const int pos = blockIdx.x * hidden_dim + threadIdx.x;

  float2 grad_hat[4], norm[4];
  float inv_std;
  // 16-byte staging register: holds dout, then inp_or_out, then the result
  float4 packed;
  __half2 *packed_h2 = reinterpret_cast<__half2 *>(&packed);
  float partial[2] = {0.f, 0.f};

  if (in_range) {
    // dxhat = dout * gamma, accumulating sum(dxhat) on the way
    packed = reinterpret_cast<const float4 *>(out_grad)[pos];
    float4 weight_f4 = reinterpret_cast<const float4 *>(gamma)[threadIdx.x];
    __half2 *weight_h2 = reinterpret_cast<__half2 *>(&weight_f4);
#pragma unroll
    for (int k = 0; k < 4; k++) {
      float2 dout = __half22float2(packed_h2[k]);
      float2 w = __half22float2(weight_h2[k]);
      grad_hat[k].x = dout.x * w.x;
      grad_hat[k].y = dout.y * w.y;
      partial[0] += grad_hat[k].x + grad_hat[k].y;
    }

    // rebuild xhat, accumulating sum(dxhat * xhat) on the way
    packed = reinterpret_cast<const float4 *>(inp_or_out)[pos];
    inv_std = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON);
    if (means != nullptr) {
      // inp_or_out holds the ln input: xhat = (input - mean) * rsqrt(var)
      float row_mean = (float)means[blockIdx.x];
#pragma unroll
      for (int k = 0; k < 4; k++) {
        float2 x = __half22float2(packed_h2[k]);
        norm[k].x = (x.x - row_mean) * inv_std;
        norm[k].y = (x.y - row_mean) * inv_std;
        partial[1] += norm[k].x * grad_hat[k].x + norm[k].y * grad_hat[k].y;
      }
    } else {
      // inp_or_out holds the ln output: xhat = (output - betta) / gamma
      float4 shift_f4 = reinterpret_cast<const float4 *>(betta)[threadIdx.x];
      __half2 *shift_h2 = reinterpret_cast<__half2 *>(&shift_f4);
#pragma unroll
      for (int k = 0; k < 4; k++) {
        float2 y = __half22float2(packed_h2[k]);
        float2 w = __half22float2(weight_h2[k]);
        float2 shift = __half22float2(shift_h2[k]);
        norm[k].x = (y.x - shift.x) / add_eps(w.x);
        norm[k].y = (y.y - shift.y) / add_eps(w.y);
        partial[1] += norm[k].x * grad_hat[k].x + norm[k].y * grad_hat[k].y;
      }
    }
  }

  // block-wide sums; every thread of the block has to reach this call
  blockReduce<ReduceType::kSum, 2>(partial);
  __shared__ float s_avg_dxhat, s_avg_dxhat_xhat;
  if (threadIdx.x == 0) {
    float scalars_per_row = hidden_dim * 8;
    s_avg_dxhat = partial[0] / scalars_per_row;
    s_avg_dxhat_xhat = partial[1] / scalars_per_row;
  }
  __syncthreads();

  // padding threads are done once the reduction result is published
  if (!in_range) {
    return;
  }

  // (dxhat - mean(dxhat) - xhat * mean(dxhat * xhat)) * rsqrt(var),
  // plus the residual gradient when one is supplied
  if (residual_grad) {
    float4 skip_f4 = reinterpret_cast<const float4 *>(residual_grad)[pos];
    __half *skip_h = reinterpret_cast<__half *>(&skip_f4);
#pragma unroll
    for (int k = 0; k < 4; k++) {
      packed_h2[k].x = __float2half(
          (grad_hat[k].x - s_avg_dxhat - norm[k].x * s_avg_dxhat_xhat) *
              inv_std +
          __half2float(skip_h[2 * k]));
      packed_h2[k].y = __float2half(
          (grad_hat[k].y - s_avg_dxhat - norm[k].y * s_avg_dxhat_xhat) *
              inv_std +
          __half2float(skip_h[2 * k + 1]));
    }
  } else {
#pragma unroll
    for (int k = 0; k < 4; k++) {
      packed_h2[k].x = __float2half(
          (grad_hat[k].x - s_avg_dxhat - norm[k].x * s_avg_dxhat_xhat) *
          inv_std);
      packed_h2[k].y = __float2half(
          (grad_hat[k].y - s_avg_dxhat - norm[k].y * s_avg_dxhat_xhat) *
          inv_std);
    }
  }
  reinterpret_cast<float4 *>(inp_grad)[pos] = packed;
}
/**
@brief: ker_ln_bw_dinp_x2
Variant of ker_ln_bw_dinp<__half> in which every thread processes two
consecutive float4 elements (2 * 8 __half values) of its row.
Here hidden_dim counts threads' work items: a row spans hidden_dim * 2
float4 elements, i.e. hidden_dim * 8 * 2 scalars (see mean_dim below).
Variables with the _1 suffix belong to the second float4 of the pair.
Threads with threadIdx.x >= hidden_dim only take part in the block
reduction (contributing zeros) and return afterwards.
Parameters have the same meaning as in ker_ln_bw_dinp.
*/
__global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad,
const __half *residual_grad,
const __half *inp_or_out,
const __half *gamma, const __half *betta,
const __half *vars, const __half *means,
int hidden_dim) {
// float4 index of the first of the two elements owned by this thread
int offset = blockIdx.x * hidden_dim * 2 + threadIdx.x * 2;
float2 dxhat[4], xhat[4];
float2 dxhat_1[4], xhat_1[4];
float var_rsqrt;
// staging registers: hold dout first, then inp_or_out, finally the result
float4 vtmp, vtmp_1;
__half2 *tmp_h2;
__half2 *tmp_h2_1;
float reduce_val[2] = {0.f, 0.f};
if (threadIdx.x < hidden_dim) {
// step 0. dxhat = dout * gamma
vtmp = ((const float4 *)out_grad)[offset];
vtmp_1 = ((const float4 *)out_grad)[offset + 1];
// tmp_h2 / tmp_h2_1 alias vtmp / vtmp_1 for the rest of the kernel
tmp_h2 = reinterpret_cast<__half2 *>(&vtmp);
tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1);
float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 2];
float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 2 + 1];
__half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4);
__half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1);
#pragma unroll
for (int i = 0; i < 4; i++) {
float2 vdout = __half22float2(tmp_h2[i]);
float2 vdout_1 = __half22float2(tmp_h2_1[i]);
float2 vgamma = __half22float2(gamma_h2[i]);
float2 vgamma_1 = __half22float2(gamma_h2_1[i]);
dxhat[i].x = vdout.x * vgamma.x;
dxhat[i].y = vdout.y * vgamma.y;
dxhat_1[i].x = vdout_1.x * vgamma_1.x;
dxhat_1[i].y = vdout_1.y * vgamma_1.y;
// reduce_val[0] accumulates this thread's share of sum(dxhat)
reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y;
}
/*
step 1. xhat = (output - betta) / gamma or
(input - mean) * rsqrtf(var)
*/
// the staging registers are reloaded, so tmp_h2 / tmp_h2_1 now see inp_or_out
vtmp = ((const float4 *)inp_or_out)[offset];
vtmp_1 = ((const float4 *)inp_or_out)[offset + 1];
var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON);
if (means == nullptr) {
// inp_or_out is output, xhat = (output - betta) / gamma
float4 vbetta = ((const float4 *)betta)[2 * threadIdx.x];
float4 vbetta_1 = ((const float4 *)betta)[2 * threadIdx.x + 1];
__half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta);
__half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1);
#pragma unroll
for (int i = 0; i < 4; i++) {
float2 vout = __half22float2(tmp_h2[i]);
float2 vout_1 = __half22float2(tmp_h2_1[i]);
float2 vgamma = __half22float2(gamma_h2[i]);
float2 vgamma_1 = __half22float2(gamma_h2_1[i]);
// these float2 locals shadow the float4 vbetta / vbetta_1 declared above
float2 vbetta = __half22float2(betta_h2[i]);
float2 vbetta_1 = __half22float2(betta_h2_1[i]);
xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x);
xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x);
xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y);
xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y);
// reduce_val[1] accumulates this thread's share of sum(dxhat * xhat)
reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y;
reduce_val[1] += xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y;
}
} else {
// inp_or_out is input, xhat = (input - mean) * rsqrtf(var)
float fmean = (float)means[blockIdx.x];
#pragma unroll
for (int i = 0; i < 4; i++) {
float2 vinp = __half22float2(tmp_h2[i]);
float2 vinp_1 = __half22float2(tmp_h2_1[i]);
xhat[i].x = (vinp.x - fmean) * var_rsqrt;
xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt;
xhat[i].y = (vinp.y - fmean) * var_rsqrt;
xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt;
reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y;
reduce_val[1] += xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y;
}
}
}
/* step2. block reduce sum for dxhat and dxhat*xhat */
// called outside the range check so that every thread of the block takes part
blockReduce<ReduceType::kSum, 2>(reduce_val);
__shared__ float s_sum_dxhat, s_sum_dxhat_xhat;
if (threadIdx.x == 0) {
// number of scalars per row: hidden_dim threads * 2 float4 * 8 halves
float mean_dim = hidden_dim * 8 * 2;
// despite the names, the shared values are the sums divided by mean_dim
s_sum_dxhat = reduce_val[0] / mean_dim;
s_sum_dxhat_xhat = reduce_val[1] / mean_dim;
}
__syncthreads();
/*
step3. compute input gradient
(dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var)
*/
// out-of-range threads leave only after the barrier above
if (threadIdx.x >= hidden_dim) {
return;
}
if (residual_grad) {
// Add the residual grad,
// usually in pre-layer-norm for transformer layer
float4 dresidual = ((const float4 *)residual_grad)[offset];
float4 dresidual_1 = ((const float4 *)residual_grad)[offset+1];
__half *hdres = reinterpret_cast<__half *>(&dresidual);
__half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1);
#pragma unroll
for (int i = 0; i < 4; i++) {
// results are written into vtmp / vtmp_1 through the tmp_h2 aliases
tmp_h2[i].x = __float2half(
(dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) *
var_rsqrt +
__half2float(hdres[2 * i]));
tmp_h2_1[i].x = __float2half(
(dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) *
var_rsqrt +
__half2float(hdres_1[2 * i]));
tmp_h2[i].y = __float2half(
(dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) *
var_rsqrt +
__half2float(hdres[2 * i + 1]));
tmp_h2_1[i].y = __float2half(
(dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) *
var_rsqrt +
__half2float(hdres_1[2 * i + 1]));
}
} else {
#pragma unroll
for (int i = 0; i < 4; i++) {
// results are written into vtmp / vtmp_1 through the tmp_h2 aliases
tmp_h2[i].x = __float2half(
(dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) *
var_rsqrt);
tmp_h2_1[i].x = __float2half(
(dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) *
var_rsqrt);
tmp_h2[i].y = __float2half(
(dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) *
var_rsqrt);
tmp_h2_1[i].y = __float2half(
(dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) *
var_rsqrt);
}
}
// store both packed results back as float4
((float4 *)inp_grad)[offset] = vtmp;
((float4 *)inp_grad)[offset + 1] = vtmp_1;
}
__global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad,
const __half *residual_grad,
const __half *inp_or_out,
const __half *gamma, const __half *betta,
const __half *vars, const __half *means,
int hidden_dim) {
int offset = blockIdx.x * hidden_dim * 4 + threadIdx.x * 4;
float2 dxhat[4], xhat[4];
float2 dxhat_1[4], xhat_1[4];
float2 dxhat_2[4], xhat_2[4];
float2 dxhat_3[4], xhat_3[4];
float var_rsqrt;
float4 vtmp, vtmp_1, vtmp_2, vtmp_3;
__half2 *tmp_h2;
__half2 *tmp_h2_1;
__half2 *tmp_h2_2;
__half2 *tmp_h2_3;
float reduce_val[2] = {0.f, 0.f};
if (threadIdx.x < hidden_dim) {
// step 0. dxhat = dout * gamma
vtmp = ((const float4 *)out_grad)[offset];
vtmp_1 = ((const float4 *)out_grad)[offset + 1];
vtmp_2 = ((const float4 *)out_grad)[offset + 2];
vtmp_3 = ((const float4 *)out_grad)[offset + 3];
tmp_h2 = reinterpret_cast<__half2 *>(&vtmp);
tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1);
tmp_h2_2 = reinterpret_cast<__half2 *>(&vtmp_2);
tmp_h2_3 = reinterpret_cast<__half2 *>(&vtmp_3);
float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 4];
float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 4 + 1];
float4 gamma_f4_2 = ((const float4 *)gamma)[threadIdx.x * 4 + 2];
float4 gamma_f4_3 = ((const float4 *)gamma)[threadIdx.x * 4 + 3];
__half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4);
__half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1);
__half2 *gamma_h2_2 = reinterpret_cast<__half2 *>(&gamma_f4_2);
__half2 *gamma_h2_3 = reinterpret_cast<__half2 *>(&gamma_f4_3);
#pragma unroll
for (int i = 0; i < 4; i++) {
float2 vdout = __half22float2(tmp_h2[i]);
float2 vdout_1 = __half22float2(tmp_h2_1[i]);
float2 vdout_2 = __half22float2(tmp_h2_2[i]);
float2 vdout_3 = __half22float2(tmp_h2_3[i]);
float2 vgamma = __half22float2(gamma_h2[i]);
float2 vgamma_1 = __half22float2(gamma_h2_1[i]);
float2 vgamma_2 = __half22float2(gamma_h2_2[i]);
float2 vgamma_3 = __half22float2(gamma_h2_3[i]);
dxhat[i].x = vdout.x * vgamma.x;
dxhat[i].y = vdout.y * vgamma.y;
dxhat_1[i].x = vdout_1.x * vgamma_1.x;
dxhat_1[i].y = vdout_1.y * vgamma_1.y;
dxhat_2[i].x = vdout_2.x * vgamma_2.x;
dxhat_2[i].y = vdout_2.y * vgamma_2.y;
dxhat_3[i].x = vdout_3.x * vgamma_3.x;
dxhat_3[i].y = vdout_3.y * vgamma_3.y;
reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y + dxhat_2[i].x +
dxhat_2[i].y + dxhat_3[i].x + dxhat_3[i].y;
}
/*
step 1. xhat = (output - betta) / gamma or
(input - mean) * rsqrtf(var)
*/
vtmp = ((const float4 *)inp_or_out)[offset];
vtmp_1 = ((const float4 *)inp_or_out)[offset + 1];
vtmp_2 = ((const float4 *)inp_or_out)[offset + 2];
vtmp_3 = ((const float4 *)inp_or_out)[offset + 3];
var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON);
if (means == nullptr) {
// inp_or_out is output, xhat = (output - betta) / gamma
float4 vbetta = ((const float4 *)betta)[4 * threadIdx.x];
float4 vbetta_1 = ((const float4 *)betta)[4 * threadIdx.x + 1];
float4 vbetta_2 = ((const float4 *)betta)[4 * threadIdx.x + 2];
float4 vbetta_3 = ((const float4 *)betta)[4 * threadIdx.x + 3];
__half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta);
__half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1);
__half2 *betta_h2_2 = reinterpret_cast<__half2 *>(&vbetta_2);
__half2 *betta_h2_3 = reinterpret_cast<__half2 *>(&vbetta_3);
#pragma unroll
for (int i = 0; i < 4; i++) {
float2 vout = __half22float2(tmp_h2[i]);
float2 vout_1 = __half22float2(tmp_h2_1[i]);
float2 vout_2 = __half22float2(tmp_h2_2[i]);
float2 vout_3 = __half22float2(tmp_h2_3[i]);
float2 vgamma = __half22float2(gamma_h2[i]);
float2 vgamma_1 = __half22float2(gamma_h2_1[i]);
float2 vgamma_2 = __half22float2(gamma_h2_2[i]);
float2 vgamma_3 = __half22float2(gamma_h2_3[i]);
float2 vbetta = __half22float2(betta_h2[i]);
float2 vbetta_1 = __half22float2(betta_h2_1[i]);
float2 vbetta_2 = __half22float2(betta_h2_2[i]);
float2 vbetta_3 = __half22float2(betta_h2_3[i]);
xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x);
xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x);
xhat_2[i].x = (vout_2.x - vbetta_2.x) / add_eps(vgamma_2.x);
xhat_3[i].x = (vout_3.x - vbetta_3.x) / add_eps(vgamma_3.x);
xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y);
xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y);
xhat_2[i].y = (vout_2.y - vbetta_2.y) / add_eps(vgamma_2.y);
xhat_3[i].y = (vout_3.y - vbetta_3.y) / add_eps(vgamma_3.y);
reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y;
reduce_val[1] += xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y;
reduce_val[1] += xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y;
reduce_val[1] += xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y;
}
} else {
// inp_or_out is input, xhat = (input - mean) * rsqrtf(var)
float fmean = (float)means[blockIdx.x];
#pragma unroll
for (int i = 0; i < 4; i++) {
float2 vinp = __half22float2(tmp_h2[i]);
float2 vinp_1 = __half22float2(tmp_h2_1[i]);
float2 vinp_2 = __half22float2(tmp_h2_2[i]);
float2 vinp_3 = __half22float2(tmp_h2_3[i]);
xhat[i].x = (vinp.x - fmean) * var_rsqrt;
xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt;
xhat_2[i].x = (vinp_2.x - fmean) * var_rsqrt;
xhat_3[i].x = (vinp_3.x - fmean) * var_rsqrt;
xhat[i].y = (vinp.y - fmean) * var_rsqrt;
xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt;
xhat_2[i].y = (vinp_2.y - fmean) * var_rsqrt;
xhat_3[i].y = (vinp_3.y - fmean) * var_rsqrt;
reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y;
reduce_val[1] += xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y;
reduce_val[1] += xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y;
reduce_val[1] += xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y;
}
}
}
/* step2. block reduce sum for dxhat and dxhat*xhat */
blockReduce<ReduceType::kSum, 2>(reduce_val);
__shared__ float s_sum_dxhat, s_sum_dxhat_xhat;
if (threadIdx.x == 0) {
float mean_dim = hidden_dim * 8 * 4;
s_sum_dxhat = reduce_val[0] / mean_dim;
s_sum_dxhat_xhat = reduce_val[1] / mean_dim;
}
__syncthreads();
/*
step3. compute input gradient
(dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var)
*/
if (threadIdx.x >= hidden_dim) {
return;
}
if (residual_grad) {
// Add the residual grad,
// usually in pre-layer-norm for transformer layer
float4 dresidual = ((const float4 *)residual_grad)[offset];
float4 dresidual_1 = ((const float4 *)residual_grad)[offset+1];
float4 dresidual_2 = ((const float4 *)residual_grad)[offset+2];
float4 dresidual_3 = ((const float4 *)residual_grad)[offset+3];
__half *hdres = reinterpret_cast<__half *>(&dresidual);
__half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1);
__half *hdres_2 = reinterpret_cast<__half *>(&dresidual_2);
__half *hdres_3 = reinterpret_cast<__half *>(&dresidual_3);
#pragma unroll
for (int i = 0; i < 4; i++) {
tmp_h2[i].x = __float2half(
(dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) *
var_rsqrt +
__half2float(hdres[2 * i]));
tmp_h2_1[i].x = __float2half(
(dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) *
var_rsqrt +
__half2float(hdres_1[2 * i]));
tmp_h2_2[i].x = __float2half(
(dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) *
var_rsqrt +
__half2float(hdres_2[2 * i]));
tmp_h2_3[i].x = __float2half(
(dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) *
var_rsqrt +
__half2float(hdres_3[2 * i]));
tmp_h2[i].y = __float2half(
(dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) *
var_rsqrt +
__half2float(hdres[2 * i + 1]));
tmp_h2_1[i].y = __float2half(
(dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) *
var_rsqrt +
__half2float(hdres_1[2 * i + 1]));
tmp_h2_2[i].y = __float2half(
(dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) *
var_rsqrt +
__half2float(hdres_1[2 * i + 1]));
tmp_h2_3[i].y = __float2half(
(dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) *
var_rsqrt +
__half2float(hdres_1[2 * i + 1]));
}
} else {
#pragma unroll
for (int i = 0; i < 4; i++) {
tmp_h2[i].x = __float2half(
(dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) *
var_rsqrt);
tmp_h2_1[i].x = __float2half(
(dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) *
var_rsqrt);
tmp_h2_2[i].x = __float2half(
(dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) *
var_rsqrt);
tmp_h2_3[i].x = __float2half(
(dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) *
var_rsqrt);
tmp_h2[i].y = __float2half(
(dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) *
var_rsqrt);
tmp_h2_1[i].y = __float2half(
(dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) *
var_rsqrt);
tmp_h2_2[i].y = __float2half(
(dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) *
var_rsqrt);
tmp_h2_3[i].y = __float2half(
(dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) *
var_rsqrt);
}
}
((float4 *)inp_grad)[offset] = vtmp;
((float4 *)inp_grad)[offset + 1] = vtmp_1;
((float4 *)inp_grad)[offset + 2] = vtmp_2;
((float4 *)inp_grad)[offset + 3] = vtmp_3;
}
/**
Layer norm backward,
compute the gradient of gamma, betta and input.
dbetta = sum(dout, dim=0)
xhat = (input - mean) * rsqrt(var) if mean is not nullptr
(output - betta) / gamma if mean is nullptr
dgamma = sum(xhat * dout, dim=0)
dxhat = dout * gamma
dinp = (dxhat - (sum(dxhat, 1) + xhat * sum(dxhat * xhat, 1)) / hidden_dim)
* rsqrt(var)
residual_grad, means, betta can be nullptr.
residual_grad will be added to dinp if it is not nullptr
which is useful in transformer layer when pre-ln
means and betta are only used to compute xhat,
(means == nullptr) ^ (betta == nullptr) should be true
*/
// float specialization of launch_ln_bw.
// Enqueues the gamma/betta gradient kernel on stream[0] and the input
// gradient kernel on stream[1].
// Throws std::runtime_error when hidden_dim is not a multiple of 4 or is
// larger than 4096. The check runs before any kernel is enqueued, so a
// rejected call launches no work at all.
template <>
void launch_ln_bw<float>(float *gamma_grad, float *betta_grad, float *inp_grad,
                         const float *out_grad, const float *residual_grad,
                         const float *inp_or_out, const float *gamma,
                         const float *betta, const float *vars,
                         const float *means, int batch, int hidden_dim,
                         cudaStream_t stream[2]) {
  // Validate first: otherwise an unsupported size would leave the
  // dgamma/dbetta kernel running on stream[0] while dinp is never computed.
  if (hidden_dim % 4 != 0 || hidden_dim > 4096) {
    throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 4096");
  }

  // compute grad of gamma and betta
  dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM);
  dim3 block_dim(TILE_DIM, TILE_DIM);
  ker_ln_bw_dgamma_dbetta<float><<<grid_dim, block_dim, 0, stream[0]>>>(
      gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means,
      batch, hidden_dim);

  // compute grad of input
  // The kernel is given hidden_dim / 4 (presumably one float4 per thread; the
  // kernel body is not visible here - confirm). The division is exact thanks
  // to the check above.
  hidden_dim >>= 2;
  // One block per row; thread count rounded up to a whole number of warps.
  int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
  ker_ln_bw_dinp<<<batch, nthread, 0, stream[1]>>>(
      inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means,
      hidden_dim);
}
// __half specialization of launch_ln_bw.
// Enqueues the gamma/betta gradient kernel on stream[0] and one of three
// input gradient kernels on stream[1], chosen by row width:
//   hidden_dim <= 8192           -> ker_ln_bw_dinp     (hidden_dim / 8 passed)
//   8192  < hidden_dim <= 16384  -> ker_ln_bw_dinp_x2  (hidden_dim / 16 passed)
//   16384 < hidden_dim <= 32768  -> ker_ln_bw_dinp_x4  (hidden_dim / 32 passed)
// Throws std::runtime_error when hidden_dim is not a multiple of 8, is larger
// than 32768, or is not divisible by the factor required by the selected
// kernel. All checks run before any kernel is enqueued.
template <>
void launch_ln_bw<__half>(__half *gamma_grad, __half *betta_grad,
                          __half *inp_grad, const __half *out_grad,
                          const __half *residual_grad, const __half *inp_or_out,
                          const __half *gamma, const __half *betta,
                          const __half *vars, const __half *means, int batch,
                          int hidden_dim, cudaStream_t stream[2]) {
  // Validate first: otherwise an unsupported size would leave the
  // dgamma/dbetta kernel running on stream[0] while dinp is never computed.
  if (hidden_dim % 8 != 0) {
    throw std::runtime_error("hidden_dim % 8 != 0");
  }
  if (hidden_dim > 8192 * 4) {
    throw std::runtime_error("hidden_dim > 32768");
  }
  // The x2 / x4 kernels only receive the shifted width, so any remainder
  // dropped by the shift would make them work on a truncated row.
  if (hidden_dim > 8192 && hidden_dim <= 8192 * 2 && hidden_dim % 16 != 0) {
    throw std::runtime_error("hidden_dim > 8192 && hidden_dim % 16 != 0");
  }
  if (hidden_dim > 8192 * 2 && hidden_dim % 32 != 0) {
    throw std::runtime_error("hidden_dim > 16384 && hidden_dim % 32 != 0");
  }

  // compute grad of gamma and betta
  dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM);
  dim3 block_dim(TILE_DIM, TILE_DIM);
  ker_ln_bw_dgamma_dbetta<__half><<<grid_dim, block_dim, 0, stream[0]>>>(
      gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means,
      batch, hidden_dim);

  // compute grad of input
  const int row_width = hidden_dim;  // original width in __half elements
  hidden_dim >>= 3;                  // width in groups of 8 halves
  if (row_width <= 8192) {
    // One block per row; thread count rounded up to a whole number of warps.
    int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
    ker_ln_bw_dinp<<<batch, nthread, 0, stream[1]>>>(
        inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars,
        means, hidden_dim);
  } else if (row_width <= 8192 * 2) {
    hidden_dim >>= 1;  // exact: row_width % 16 == 0 was checked above
    int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
    ker_ln_bw_dinp_x2<<<batch, nthread, 0, stream[1]>>>(
        inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars,
        means, hidden_dim);
  } else {
    // row_width <= 32768 is guaranteed by the range check above.
    hidden_dim >>= 2;  // exact: row_width % 32 == 0 was checked above
    int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
    ker_ln_bw_dinp_x4<<<batch, nthread, 0, stream[1]>>>(
        inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars,
        means, hidden_dim);
  }
}