// NOTE(review): the angle-bracket header names were missing from this file
// as received; the system/toolkit headers below are reconstructed from what
// the code uses (CUB block load/store, cooperative groups, fmaxf, INT_MAX,
// std::runtime_error, std::to_string) -- confirm against the build.
#include <math.h>

#include <climits>
#include <stdexcept>
#include <string>

#include <cooperative_groups.h>
#include <cub/block/block_load.cuh>
#include <cub/block/block_store.cuh>
#include <cub/cub.cuh>

#include "block_reduce.h"
#include "kernels.h"

namespace cg = cooperative_groups;

// Added to the softmax denominator before taking its reciprocal, so a row
// whose exponentials sum to zero never divides by zero.
const float EPSILON = 1e-8f;

/**
@brief: ker_attn_softmax
In-place softmax forward kernel over the last dimension of inp, for
  enc-self-attn, dec-self-attn, encdec-attn.
The launchers below use it when 64 < to_len <= 512.

@thread
gridDim.x = dynamic; blocks stride over the from_len rows of one head
gridDim.y = batch_size
gridDim.z = nhead
blockDim.x = block_dim; each thread holds ele_per_thread consecutive elements
  of a to_len row, so block_dim * ele_per_thread must be >= to_len

@param
inp: [batch_size, nhead, from_len, to_len], softmax input; overwritten with
  the softmax output.
attn_mask: [batch_size, to_len], padding tokens are -inf,
  non padding tokens are 0.
  attn_mask!=nullptr for enc-self-attn and enc-dec-attn
  attn_mask=nullptr and mask_future=true for dec-self-attn training
  attn_mask=nullptr and mask_future=false for dec-self-attn infer
mask_future: only consulted when attn_mask == nullptr; column j of row i is
  excluded when j > i.
*/
template <typename T, int block_dim, int ele_per_thread>
__global__ void ker_attn_softmax(T *inp, const T *attn_mask, int from_len,
                                 int to_len, bool mask_future) {
  int batch_id = blockIdx.y;
  int head_id = blockIdx.z;
  const int nhead = gridDim.z;
  const int token_per_reduce = 1;
  // Blocked arrangement: thread t receives row elements
  // [t * ele_per_thread, (t + 1) * ele_per_thread). The mask_future test in
  // step 1 computes column indices under exactly this assumption.
  typedef cub::BlockLoad<T, block_dim, ele_per_thread,
                         cub::BLOCK_LOAD_VECTORIZE>
      BlockLoad;
  __shared__ typename BlockLoad::TempStorage ts_load;
  typedef cub::BlockStore<T, block_dim, ele_per_thread,
                          cub::BLOCK_STORE_VECTORIZE>
      BlockStore;
  __shared__ typename BlockStore::TempStorage ts_store;

  // The mask row depends only on batch_id, so it is loaded once and reused
  // for every from_len row; slots past to_len get REDUCE_FLOAT_INF_NEG.
  T mval[ele_per_thread];
  if (attn_mask) {
    attn_mask += batch_id * to_len;
    BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG);
  }

  // flat_3dim (project helper) presumably yields the offset of this
  // (batch, head)'s [from_len, to_len] slice -- TODO confirm.
  inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len);
  for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len;
       token_id += gridDim.x * token_per_reduce) {
    T inp_val[token_per_reduce][ele_per_thread];
    for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) {
      BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len,
                              REDUCE_FLOAT_INF_NEG);
    }

    /* step 1. compute max */
    // thread local max; masked columns are forced to REDUCE_FLOAT_INF_NEG
    float val[token_per_reduce][ele_per_thread];
    float l_max[token_per_reduce];
    for (int i = 0; i < token_per_reduce; i++) {
      l_max[i] = REDUCE_FLOAT_INF_NEG;
      for (int j = 0; j < ele_per_thread; j++) {
        if (attn_mask) {
          val[i][j] = (float)inp_val[i][j] + (float)mval[j];
        } else {
          // causal mask: a column index greater than the row index is masked
          if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) {
            val[i][j] = REDUCE_FLOAT_INF_NEG;
          } else {
            val[i][j] = (float)inp_val[i][j];
          }
        }
        l_max[i] = fmaxf(l_max[i], val[i][j]);
      }
    }
    // block reduce max
    blockReduce<ReduceType::kMax, token_per_reduce>(l_max);
    // thread 0 publishes the reduced max to the whole block
    __shared__ float s_max[token_per_reduce];
    if (threadIdx.x == 0) {
      for (int i = 0; i < token_per_reduce; i++) {
        s_max[i] = l_max[i];
      }
    }
    __syncthreads();

    /* step 2. compute sum */
    // thread local sum of exp(x - max), subtracting the max for stability
    float l_sum[token_per_reduce];
    for (int i = 0; i < token_per_reduce; i++) {
      l_sum[i] = 0.f;
      for (int j = 0; j < ele_per_thread; j++) {
        val[i][j] = __expf(val[i][j] - s_max[i]);
        l_sum[i] += val[i][j];
      }
    }
    // block reduce sum
    blockReduce<ReduceType::kSum, token_per_reduce>(l_sum);
    // thread 0 publishes the reciprocal of the denominator to the block
    __shared__ float s_sum[token_per_reduce];
    if (threadIdx.x == 0) {
      for (int i = 0; i < token_per_reduce; i++) {
        s_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON);
      }
    }
    __syncthreads();

    /* step 3. compute final result */
    for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) {
      for (int j = 0; j < ele_per_thread; j++) {
        inp_val[i][j] = (T)(val[i][j] * s_sum[i]);
      }
      BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i],
                                 to_len);
    }
  }  // blockIdx.x
}

/**
@brief: ker_attn_softmax_lt32
Same contract as ker_attn_softmax, for rows that fit in a single warp (the
launchers use it for to_len <= 64 with block_dim == 32). Reductions are
warp-level: warpReduce (block_reduce.h) is assumed to leave the result in
every lane, so no shared-memory broadcast or __syncthreads() is used.
*/
template <typename T, int block_dim, int ele_per_thread>
__global__ void ker_attn_softmax_lt32(T *inp, const T *attn_mask, int from_len,
                                      int to_len, bool mask_future) {
  int batch_id = blockIdx.y;
  int head_id = blockIdx.z;
  const int nhead = gridDim.z;
  const int token_per_reduce = 1;
  // Blocked arrangement, as in ker_attn_softmax; mask_future relies on it.
  typedef cub::BlockLoad<T, block_dim, ele_per_thread,
                         cub::BLOCK_LOAD_VECTORIZE>
      BlockLoad;
  __shared__ typename BlockLoad::TempStorage ts_load;
  typedef cub::BlockStore<T, block_dim, ele_per_thread,
                          cub::BLOCK_STORE_VECTORIZE>
      BlockStore;
  __shared__ typename BlockStore::TempStorage ts_store;

  // Mask row loaded once per block; slots past to_len get
  // REDUCE_FLOAT_INF_NEG.
  T mval[ele_per_thread];
  if (attn_mask) {
    attn_mask += batch_id * to_len;
    BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG);
  }

  inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len);
  for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len;
       token_id += gridDim.x * token_per_reduce) {
    T inp_val[token_per_reduce][ele_per_thread];
    for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) {
      BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len,
                              REDUCE_FLOAT_INF_NEG);
    }

    /* step 1. compute max */
    // thread local max; masked columns are forced to REDUCE_FLOAT_INF_NEG
    float val[token_per_reduce][ele_per_thread];
    float l_max[token_per_reduce];
    for (int i = 0; i < token_per_reduce; i++) {
      l_max[i] = REDUCE_FLOAT_INF_NEG;
      for (int j = 0; j < ele_per_thread; j++) {
        if (attn_mask) {
          val[i][j] = (float)inp_val[i][j] + (float)mval[j];
        } else {
          // causal mask: a column index greater than the row index is masked
          if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) {
            val[i][j] = REDUCE_FLOAT_INF_NEG;
          } else {
            val[i][j] = (float)inp_val[i][j];
          }
        }
        l_max[i] = fmaxf(l_max[i], val[i][j]);
      }
    }
    // warp reduce max
    warpReduce<ReduceType::kMax, token_per_reduce>(l_max);

    /* step 2. compute sum */
    // thread local sum of exp(x - max)
    float l_sum[token_per_reduce];
    for (int i = 0; i < token_per_reduce; i++) {
      l_sum[i] = 0.f;
      for (int j = 0; j < ele_per_thread; j++) {
        val[i][j] = __expf(val[i][j] - l_max[i]);
        l_sum[i] += val[i][j];
      }
    }
    // warp reduce sum
    warpReduce<ReduceType::kSum, token_per_reduce>(l_sum);

    /* step 3. compute final result */
    for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) {
      l_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON);
      for (int j = 0; j < ele_per_thread; j++) {
        inp_val[i][j] = (T)(val[i][j] * l_sum[i]);
      }
      BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i],
                                 to_len);
    }
  }  // blockIdx.x
}

/*
Host launcher for the in-place attention softmax forward, enqueued on stream.
Dispatches on to_len (supported up to 512):
  attn_mask!=nullptr for enc-self-attn and enc-dec-attn
  attn_mask=nullptr and mask_future=true for dec-self-attn training
  attn_mask=nullptr and mask_future=false for dec-self-attn infer
Throws std::runtime_error when to_len > 512. Empty inputs (batch_size, nhead
or from_len <= 0) are a no-op instead of a zero-sized (invalid) launch.
*/
template <>
void launch_attn_softmax<float>(float *inp, const float *attn_mask,
                                int batch_size, int nhead, int from_len,
                                int to_len, bool mask_future,
                                cudaStream_t stream) {
  if (to_len > 512) {
    throw std::runtime_error(
        "Sequence length greater than 512 is currently not supported");
  }
  if (batch_size <= 0 || nhead <= 0 || from_len <= 0) return;

  dim3 grid_dim(1, batch_size, nhead);
  if (to_len <= 32) {
    ker_attn_softmax_lt32<float, 32, 1><<<grid_dim, 32, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
  } else if (to_len <= 64) {
    ker_attn_softmax_lt32<float, 32, 2><<<grid_dim, 32, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
  } else if (to_len <= 128) {
    grid_dim.x = 16;
    ker_attn_softmax<float, 64, 2><<<grid_dim, 64, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
  } else if (to_len <= 256) {
    grid_dim.x = 32;
    ker_attn_softmax<float, 128, 2><<<grid_dim, 128, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
  } else {  // to_len <= 512
    grid_dim.x = 64;
    ker_attn_softmax<float, 256, 2><<<grid_dim, 256, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
  }
}

/*
__half variant of launch_attn_softmax<float>; same contract, with fewer
blocks along gridDim.x per (batch, head).
*/
template <>
void launch_attn_softmax<__half>(__half *inp, const __half *attn_mask,
                                 int batch_size, int nhead, int from_len,
                                 int to_len, bool mask_future,
                                 cudaStream_t stream) {
  if (to_len > 512) {
    throw std::runtime_error(
        "Sequence length greater than 512 is currently not supported");
  }
  if (batch_size <= 0 || nhead <= 0 || from_len <= 0) return;

  dim3 grid_dim(1, batch_size, nhead);
  if (to_len <= 32) {
    ker_attn_softmax_lt32<__half, 32, 1><<<grid_dim, 32, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
  } else if (to_len <= 64) {
    ker_attn_softmax_lt32<__half, 32, 2><<<grid_dim, 32, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
  } else if (to_len <= 128) {
    grid_dim.x = 8;
    ker_attn_softmax<__half, 64, 2><<<grid_dim, 64, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
  } else if (to_len <= 256) {
    grid_dim.x = 16;
    ker_attn_softmax<__half, 128, 2><<<grid_dim, 128, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
  } else {  // to_len <= 512
    grid_dim.x = 32;
    ker_attn_softmax<__half, 256, 2><<<grid_dim, 256, 0, stream>>>(
        inp, attn_mask, from_len, to_len, mask_future);
  }
}

/**
@brief: ker_attn_softmax_bw
Softmax backward in self attention, computed in place on grad:
  grad[j] = inp[j] * (grad[j] - sum_k grad[k] * inp[k])

@thread
gridDim.x = ceil(rows / warps_per_block)
blockDim.x = WARP_SIZE
blockDim.y = warps_per_block
Each warp (one threadIdx.y) owns one softmax row; lane threadIdx.x handles
columns threadIdx.x + i * WARP_SIZE for i < ITERATIONS, so
ITERATIONS * WARP_SIZE must be >= softmax_length.

@param
grad: [batch_size, nhead, seq_len, seq_len], output grad; overwritten with
  the input grad.
inp: [batch_size, nhead, seq_len, seq_len], output of softmax forward.
softmax_length: length of each softmax row.
rows: number of valid rows; warps with a row index >= rows do no loads or
  stores. Defaults to INT_MAX (no guard) for callers that predate it.
*/
template <typename T, int ITERATIONS>
__global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length,
                                    int rows = INT_MAX) {
  int batch_idx = blockIdx.x * blockDim.y + threadIdx.y;
  // Uniform across the warp (all lanes share threadIdx.y). Invalid warps do
  // not return early: they still join the shuffle below with sum == 0, so the
  // warp tile is always fully populated.
  const bool row_valid = batch_idx < rows;

  if (row_valid) {
    // 64-bit offset: rows * softmax_length can exceed INT_MAX.
    size_t offset = (size_t)batch_idx * softmax_length + threadIdx.x;
    grad += offset;
    inp += offset;
  }

  T grad_reg[ITERATIONS];
  T inp_reg[ITERATIONS];
  float sum = 0.f;

#pragma unroll
  for (int i = 0; i < ITERATIONS; ++i) {
    int curr_idx = threadIdx.x + i * WARP_SIZE;
    if (row_valid && curr_idx < softmax_length) {
      grad_reg[i] = grad[i * WARP_SIZE];
      inp_reg[i] = inp[i * WARP_SIZE];
      sum += (float)grad_reg[i] * (float)inp_reg[i];
    }
  }

  // Butterfly all-reduce: after the loop every lane holds the row's dot
  // product sum_k grad[k] * inp[k].
  cg::thread_block b = cg::this_thread_block();
  cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i);

#pragma unroll
  for (int i = 0; i < ITERATIONS; ++i) {
    int curr_idx = threadIdx.x + i * WARP_SIZE;
    if (row_valid && curr_idx < softmax_length)
      grad[i * WARP_SIZE] = (T)((float)inp_reg[i] * ((float)grad_reg[i] - sum));
  }
}

/*
Host launcher for the in-place softmax backward, enqueued on stream.
rows = batch_size * nhead * from_len. Throws std::runtime_error when
softmax_len > 2048; rows <= 0 is a no-op.
*/
template <typename T>
void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows,
                            int softmax_len, cudaStream_t stream) {
  if (softmax_len > 2048) {
    throw std::runtime_error(
        std::string(
            "Special sequence length found in softmax backward, seq_len: ") +
        std::to_string(softmax_len));
  }
  if (rows <= 0) return;

  const int warps_per_block = 4;
  // Round up so the last (rows % warps_per_block) rows are not dropped; the
  // kernel's rows guard idles the surplus warps of the final block.
  dim3 grid_dim((rows + warps_per_block - 1) / warps_per_block);
  dim3 block_dim(WARP_SIZE, warps_per_block);

  // ITERATIONS = ceil(bucket upper bound / WARP_SIZE).
  if (softmax_len <= 32) {
    ker_attn_softmax_bw<T, 1><<<grid_dim, block_dim, 0, stream>>>(
        out_grad, soft_inp, softmax_len, rows);
  } else if (softmax_len <= 64) {
    ker_attn_softmax_bw<T, 2><<<grid_dim, block_dim, 0, stream>>>(
        out_grad, soft_inp, softmax_len, rows);
  } else if (softmax_len <= 128) {
    ker_attn_softmax_bw<T, 4><<<grid_dim, block_dim, 0, stream>>>(
        out_grad, soft_inp, softmax_len, rows);
  } else if (softmax_len <= 256) {
    ker_attn_softmax_bw<T, 8><<<grid_dim, block_dim, 0, stream>>>(
        out_grad, soft_inp, softmax_len, rows);
  } else if (softmax_len <= 384) {
    ker_attn_softmax_bw<T, 12><<<grid_dim, block_dim, 0, stream>>>(
        out_grad, soft_inp, softmax_len, rows);
  } else if (softmax_len <= 512) {
    ker_attn_softmax_bw<T, 16><<<grid_dim, block_dim, 0, stream>>>(
        out_grad, soft_inp, softmax_len, rows);
  } else if (softmax_len <= 768) {
    ker_attn_softmax_bw<T, 24><<<grid_dim, block_dim, 0, stream>>>(
        out_grad, soft_inp, softmax_len, rows);
  } else if (softmax_len <= 1024) {
    ker_attn_softmax_bw<T, 32><<<grid_dim, block_dim, 0, stream>>>(
        out_grad, soft_inp, softmax_len, rows);
  } else {  // softmax_len <= 2048
    ker_attn_softmax_bw<T, 64><<<grid_dim, block_dim, 0, stream>>>(
        out_grad, soft_inp, softmax_len, rows);
  }
}

template void launch_attn_softmax_bw<__half>(__half *out_grad,
                                             const __half *soft_inp, int rows,
                                             int softmax_len,
                                             cudaStream_t stream);
template void launch_attn_softmax_bw<float>(float *out_grad,
                                            const float *soft_inp, int rows,
                                            int softmax_len,
                                            cudaStream_t stream);