#include "kernels.h" #include namespace cg = cooperative_groups; /** @brief: fuse_transpose_bias Calculate the sum of elements in each column of the matrix. @thread gridDim.x = ceil(cols / WARP_SIZE) blockDim.x = WARP_SIZE blockDim.y = WARP_SIZE @param inp: [rows, cols] out: [cols] rows: the number of rows in the matrix cols: the number of cols in the matrix */ template __global__ void column_sum_reduce(const T *__restrict__ inp, T *__restrict__ out, int rows, int cols) { __shared__ float tile[WARP_SIZE][WARP_SIZE]; cg::thread_block b = cg::this_thread_block(); cg::thread_block_tile g = cg::tiled_partition(b); int idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); int y_stride = cols * WARP_SIZE; float localSum = 0; // Loop across matrix row // TODO: optimize to log complexity if (idx < cols) { int offset = flat_2dim(threadIdx.y, idx, cols); for (int r = threadIdx.y; r < rows; r += WARP_SIZE) { localSum += (float)inp[offset]; offset += y_stride; } } // The sum of a row in tile is equal to the sum of a col in original matrix tile[threadIdx.x][threadIdx.y] = localSum; __syncthreads(); // Sum the shared buffer. // The change of threadIdx.x is continuous float sum = tile[threadIdx.y][threadIdx.x]; __syncthreads(); // Calculate the sum of a row in tile for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); if (threadIdx.x == 0) { int pos = flat_2dim(blockIdx.x, threadIdx.y, WARP_SIZE); if (pos < cols) out[pos] = sum; } } // [r, c] -> [c] template <> void launch_fuse_transpose_bias_kernel(const float *inp, float *out, int rows, int cols, cudaStream_t stream) { dim3 grid_dim((cols - 1) / WARP_SIZE + 1); dim3 block_dim(WARP_SIZE, WARP_SIZE); column_sum_reduce <<>>(inp, out, rows, cols); } template <> void launch_fuse_transpose_bias_kernel<__half>(const __half *inp, __half *out, int rows, int cols, cudaStream_t stream) { dim3 grid_dim((cols - 1) / WARP_SIZE + 1); dim3 block_dim(WARP_SIZE, WARP_SIZE); column_sum_reduce<__half> <<>>(inp, out, rows, cols); } /** @brief: fused_add2 Add two matrix inp1 and inp2 to out. 
/**
@brief: fused_add2
Add the two matrices inp1 and inp2 into out.

@thread
gridDim.x = batch_size * seq_len
blockDim.x = min(hidden_dim, MAX_THREADS)

@param
inp1: [batch_size, seq_len, hidden_dim]
inp2: [batch_size, seq_len, hidden_dim]
out: [batch_size, seq_len, hidden_dim]
batch_size: the size of the current batch
seq_len: the sequence length of the current batch
hidden_dim: dim of the hidden tensor
*/
template <typename T>
__global__ void fused_add2_kernel(T *out, const T *inp1, const T *inp2,
                                  int hidden_dim);

template <>
__global__ void fused_add2_kernel<float>(float *out, const float *inp1,
                                         const float *inp2, int hidden_dim) {
  int row_id = blockIdx.x;
  int offset = flat_2dim(row_id, 0, hidden_dim);

  const float4 *inp1_4 = reinterpret_cast<const float4 *>(inp1);
  const float4 *inp2_4 = reinterpret_cast<const float4 *>(inp2);
  float4 *out_4 = reinterpret_cast<float4 *>(out);
  float4 vinp1;
  float4 vinp2;
  float4 val;

  for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
    vinp1 = inp1_4[offset + i];
    vinp2 = inp2_4[offset + i];
    val.x = vinp1.x + vinp2.x;
    val.y = vinp1.y + vinp2.y;
    val.z = vinp1.z + vinp2.z;
    val.w = vinp1.w + vinp2.w;
    out_4[offset + i] = val;
  }
}

template <>
__global__ void fused_add2_kernel<__half>(__half *out, const __half *inp1,
                                          const __half *inp2, int hidden_dim) {
  int row_id = blockIdx.x;
  int offset = flat_2dim(row_id, 0, hidden_dim);

  const float4 *inp1_4 = reinterpret_cast<const float4 *>(inp1);
  const float4 *inp2_4 = reinterpret_cast<const float4 *>(inp2);
  float4 *out_4 = reinterpret_cast<float4 *>(out);
  float4 vinp1;
  float4 vinp2;
  float4 val;
  // View each float4 register as four __half2 pairs.
  __half2 *h2_inp1 = reinterpret_cast<__half2 *>(&vinp1);
  __half2 *h2_inp2 = reinterpret_cast<__half2 *>(&vinp2);
  __half2 *h2_val = reinterpret_cast<__half2 *>(&val);

  for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
    vinp1 = inp1_4[offset + i];
    vinp2 = inp2_4[offset + i];
    h2_val[0] = __hadd2(h2_inp1[0], h2_inp2[0]);
    h2_val[1] = __hadd2(h2_inp1[1], h2_inp2[1]);
    h2_val[2] = __hadd2(h2_inp1[2], h2_inp2[2]);
    h2_val[3] = __hadd2(h2_inp1[3], h2_inp2[3]);
    out_4[offset + i] = val;
  }
}

// [b, s, h] -> [b, s, h]
template <>
void launch_fused_add2<float>(float *out, const float *inp1, const float *inp2,
                              int batch_size, int seq_len, int hidden_dim,
                              cudaStream_t &stream) {
  // Each thread handles one float4, i.e. four consecutive floats.
  hidden_dim >>= 2;
  dim3 grid_dim(batch_size * seq_len);
  dim3 block_dim(min(hidden_dim, MAX_THREADS));

  fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(out, inp1, inp2,
                                                        hidden_dim);
}

template <>
void launch_fused_add2<__half>(__half *out, const __half *inp1,
                               const __half *inp2, int batch_size, int seq_len,
                               int hidden_dim, cudaStream_t &stream) {
  // Each thread handles one float4, i.e. eight consecutive halves.
  hidden_dim >>= 3;
  dim3 grid_dim(batch_size * seq_len);
  dim3 block_dim(min(hidden_dim, MAX_THREADS));

  fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(out, inp1, inp2,
                                                        hidden_dim);
}
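/*
Usage sketch (illustrative, not part of the original kernels): a residual
connection out = inp1 + inp2 over a [batch_size, seq_len, hidden_dim]
activation. Because both kernels load float4 vectors, hidden_dim must be
divisible by 4 for float (8 for __half) and all three pointers must be
16-byte aligned; cudaMalloc allocations satisfy the alignment. Buffer names
are hypothetical.

  void residual_add_sketch(float *d_out, const float *d_x,
                           const float *d_residual, cudaStream_t stream) {
    int batch_size = 8, seq_len = 128, hidden_dim = 1024;  // 1024 % 4 == 0
    launch_fused_add2<float>(d_out, d_x, d_residual, batch_size, seq_len,
                             hidden_dim, stream);
  }
*/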
/**
@brief: kernel_concat3_dim1
Concatenate two 3-dim tensors along dim 1 (inp1 first, then inp2). Each
element handled here is one float4 vector; the launchers divide sz2 by the
vector width beforehand.

@thread
gridDim.x = ceil(sz0 * (sz1_1 + sz1_2) * sz2 / MAX_THREADS)
blockDim.x = MAX_THREADS

@param
inp1: [sz0, sz1_1, sz2]
inp2: [sz0, sz1_2, sz2]
output: [sz0, sz1_1 + sz1_2, sz2]
*/
template <typename T>
__global__ void kernel_concat3_dim1(const T *inp1, const T *inp2, T *output,
                                    int sz0, int sz2, int sz1_1, int sz1_2) {
  int nele = sz0 * sz2 * (sz1_1 + sz1_2);
  int idx = flat_2dim(blockIdx.x, threadIdx.x, blockDim.x);
  if (idx >= nele) {
    return;
  }
  float4 *dst_ptr = reinterpret_cast<float4 *>(output) + idx;
  // Decompose the flat index into (idx0, idx1, idx2) of the output tensor.
  int idx2 = idx % sz2;
  idx = idx / sz2;
  int idx1 = idx % (sz1_1 + sz1_2);
  int idx0 = idx / (sz1_1 + sz1_2);
  // Pick the source tensor according to the position along dim 1.
  const float4 *src_ptr = nullptr;
  int sz1 = 0;
  if (idx1 < sz1_1) {
    sz1 = sz1_1;
    src_ptr = reinterpret_cast<const float4 *>(inp1);
  } else {
    idx1 -= sz1_1;
    sz1 = sz1_2;
    src_ptr = reinterpret_cast<const float4 *>(inp2);
  }
  src_ptr += flat_3dim(idx0, idx1, idx2, sz1, sz2);
  dst_ptr[0] = src_ptr[0];
}

template <>
void launch_concat3_dim1<float>(const float *inp1, const float *inp2,
                                float *output, int sz0, int sz2, int sz1_1,
                                int sz1_2, cudaStream_t stream) {
  // Each thread copies one float4, i.e. four consecutive floats.
  sz2 >>= 2;
  int nele = sz0 * sz2 * (sz1_1 + sz1_2);
  int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS;
  kernel_concat3_dim1<<<nblock, MAX_THREADS, 0, stream>>>(
      inp1, inp2, output, sz0, sz2, sz1_1, sz1_2);
}

template <>
void launch_concat3_dim1<__half>(const __half *inp1, const __half *inp2,
                                 __half *output, int sz0, int sz2, int sz1_1,
                                 int sz1_2, cudaStream_t stream) {
  // Each thread copies one float4, i.e. eight consecutive halves.
  sz2 >>= 3;
  int nele = sz0 * sz2 * (sz1_1 + sz1_2);
  int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS;
  kernel_concat3_dim1<<<nblock, MAX_THREADS, 0, stream>>>(
      inp1, inp2, output, sz0, sz2, sz1_1, sz1_2);
}
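/*
Usage sketch (illustrative, not part of the original kernels): one plausible
use of concat along dim 1 is appending a newly decoded step to a cached key
or value tensor, i.e. concatenating [batch_heads, cur_len, head_dim] with
[batch_heads, 1, head_dim]. The float4 copies require sz2 (head_dim here) to
be divisible by 4 for float (8 for __half). Buffer names are hypothetical.

  void append_step_sketch(const float *d_cache, const float *d_new_step,
                          float *d_merged, int batch_heads, int cur_len,
                          int head_dim, cudaStream_t stream) {
    launch_concat3_dim1<float>(d_cache, d_new_step, d_merged, batch_heads,
                               head_dim, cur_len, 1, stream);
  }
*/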