// ColossalAI/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu
//
// 313 lines
// 11 KiB
// Plaintext

#include <cub/block/block_load.cuh>
#include <cub/block/block_scan.cuh>
#include <cub/block/block_store.cuh>
#include "kernels.h"
using namespace cub;
/**
@brief: transform_0213
Split the attention heads and reshape the input during the backward
process of encoder self-attention:
  [batch_size, seq_len, hidden_dim] -> [batch_size, nhead, seq_len, head_dim]

Data is moved as float4 (16-byte) vectors. As a result, hidden_dim and
head_dim below are counted in float4 vectors, not in scalar elements. The
launchers pre-shift the element counts right by 2 for float and by 3 for
__half. The body never depends on T, so one generic definition serves every
element type (float and __half use the same code path).

@thread
gridDim.x = batch_size
gridDim.y = seq_len
blockDim.x = min(hidden_dim, MAX_THREADS); threads stride over the hidden
             axis in steps of blockDim.x
@param
output: [batch_size, nhead, seq_len, head_dim]
input: [batch_size, seq_len, hidden_dim]
hidden_dim: hidden size, in float4 vectors
head_dim: per-head size, in float4 vectors (nhead = hidden_dim / head_dim)
Both pointers must be 16-byte aligned because they are accessed as float4.
*/
template <typename T>
__global__ void transform_0213(T *output, const T *input, int hidden_dim,
                               int head_dim) {
  int batch_id = blockIdx.x;
  int token_id = blockIdx.y;
  int seq_len = gridDim.y;
  int nhead = hidden_dim / head_dim;

  // [b, s, h]
  int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim);
  // [b, nh, s, ad]
  int trg_offset =
      flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim);

  const float4 *input4 = reinterpret_cast<const float4 *>(input);
  float4 *res4 = reinterpret_cast<float4 *>(output);

  // int index: every bound and offset here is int, so this avoids a
  // signed/unsigned comparison against hidden_dim.
  for (int i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
    int head_id = i / head_dim;
    int dim_id = i % head_dim;
    int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim);
    res4[trg_offset + cur_trg_offset] = input4[src_offset + i];
  }
}
// [b, s, h] -> [b, nh, s, ad]
// float path: the kernel moves float4 vectors, so the hidden size is
// converted from floats to groups of four before launching.
template <>
void launch_transform_0213<float>(float *output, const float *input,
                                  int batch_size, int seq_len, int hidden_dim,
                                  int nhead, cudaStream_t stream) {
  const int vec_hidden = hidden_dim >> 2;  // floats -> float4 vectors
  const int vec_head = vec_hidden / nhead;
  const dim3 threads(min(vec_hidden, MAX_THREADS));
  const dim3 blocks(batch_size, seq_len);  // one block per (batch, token)
  transform_0213<float>
      <<<blocks, threads, 0, stream>>>(output, input, vec_hidden, vec_head);
}
// __half path: eight halves fit in one float4, so the hidden size is
// converted from halves to float4 vectors before launching.
template <>
void launch_transform_0213<__half>(__half *output, const __half *input,
                                   int batch_size, int seq_len, int hidden_dim,
                                   int nhead, cudaStream_t stream) {
  const int vec_hidden = hidden_dim >> 3;  // halves -> float4 vectors
  const int vec_head = vec_hidden / nhead;
  const dim3 threads(min(vec_hidden, MAX_THREADS));
  const dim3 blocks(batch_size, seq_len);  // one block per (batch, token)
  transform_0213<__half>
      <<<blocks, threads, 0, stream>>>(output, input, vec_hidden, vec_head);
}
/**
@brief: bias_add_transform_20314
Add a bias to the input and permute its axes from
[0, 1, 2, 3, 4] to [2, 0, 3, 1, 4].
Elements are moved as float4 vectors, so dim_4 is counted in float4 vectors.
The launchers shift the caller's dim_4 right by 2 for float and by 3 for
__half.
@thread
gridDim.x = dim_0
gridDim.y = dim_1
gridDim.z = dim_2
blockDim.x = min(dim_3 * dim_4, MAX_THREADS)
@param
input: [dim_0, dim_1, dim_2, dim_3, dim_4]
bias: [dim_2, dim_3, dim_4]
output: [dim_2, dim_0, dim_3, dim_1, dim_4]
*/
template <typename T>
__global__ void bias_add_transform_20314(T *output, const T *input,
                                         const T *bias, int dim_3, int dim_4);

template <>
__global__ void
bias_add_transform_20314<float>(float *output, const float *input,
                                const float *bias, int dim_3, int dim_4) {
  // The block coordinates select one (dim_0, dim_1, dim_2) row; the grid
  // extents give the sizes of those three axes.
  const int i0 = blockIdx.x, i1 = blockIdx.y, i2 = blockIdx.z;
  const int n0 = gridDim.x, n1 = gridDim.y, n2 = gridDim.z;
  const int row_len = dim_3 * dim_4;  // float4 vectors per row

  // Advance each base pointer to this block's row (offsets in float4 units).
  const float4 *src = reinterpret_cast<const float4 *>(input) +
                      flat_4dim(i0, i1, i2, 0, n1, n2, row_len);
  const float4 *bias_row =
      reinterpret_cast<const float4 *>(bias) + flat_2dim(i2, 0, row_len);
  float4 *dst = reinterpret_cast<float4 *>(output) +
                flat_5dim(i2, i0, 0, i1, 0, n0, dim_3, n1, dim_4);

  for (std::size_t k = threadIdx.x; k < row_len; k += blockDim.x) {
    const float4 x = src[k];
    const float4 b = bias_row[k];
    const float4 sum = make_float4(x.x + b.x, x.y + b.y, x.z + b.z, x.w + b.w);
    const int j3 = k / dim_4;
    const int j4 = k % dim_4;
    // Within the output row, dim_3 is strided by dim_1 * dim_4.
    dst[flat_3dim(j3, 0, j4, n1, dim_4)] = sum;
  }
}
template <>
__global__ void
bias_add_transform_20314<__half>(__half *output, const __half *input,
                                 const __half *bias, int dim_3, int dim_4) {
  // The block coordinates select one (dim_0, dim_1, dim_2) row; the grid
  // extents give the sizes of those three axes.
  const int i0 = blockIdx.x, i1 = blockIdx.y, i2 = blockIdx.z;
  const int n0 = gridDim.x, n1 = gridDim.y, n2 = gridDim.z;
  const int row_len = dim_3 * dim_4;  // float4 vectors per row

  // Advance each base pointer to this block's row (offsets in float4 units).
  const float4 *src = reinterpret_cast<const float4 *>(input) +
                      flat_4dim(i0, i1, i2, 0, n1, n2, row_len);
  const float4 *bias_row =
      reinterpret_cast<const float4 *>(bias) + flat_2dim(i2, 0, row_len);
  float4 *dst = reinterpret_cast<float4 *>(output) +
                flat_5dim(i2, i0, 0, i1, 0, n0, dim_3, n1, dim_4);

  for (std::size_t k = threadIdx.x; k < row_len; k += blockDim.x) {
    // Each float4 packs eight halves; add them pairwise as four __half2.
    float4 packed_in = src[k];
    float4 packed_bias = bias_row[k];
    float4 packed_out;
    const __half2 *in_h2 = reinterpret_cast<const __half2 *>(&packed_in);
    const __half2 *bias_h2 = reinterpret_cast<const __half2 *>(&packed_bias);
    __half2 *out_h2 = reinterpret_cast<__half2 *>(&packed_out);
#pragma unroll
    for (int j = 0; j < 4; ++j) {
      out_h2[j] = __hadd2(in_h2[j], bias_h2[j]);
    }
    const int j3 = k / dim_4;
    const int j4 = k % dim_4;
    // Within the output row, dim_3 is strided by dim_1 * dim_4.
    dst[flat_3dim(j3, 0, j4, n1, dim_4)] = packed_out;
  }
}
// [b, s, 3, h] -> [3, b, nh, s, ad]
// float path: the innermost axis is re-expressed in float4 vectors (4 floats).
template <>
void launch_bias_add_transform_20314<float>(float *output, const float *input,
                                            const float *bias, int dim_0,
                                            int dim_1, int dim_2, int dim_3,
                                            int dim_4, cudaStream_t stream) {
  const int vec_dim_4 = dim_4 >> 2;
  const dim3 blocks(dim_0, dim_1, dim_2);
  const dim3 threads(min(dim_3 * vec_dim_4, MAX_THREADS));
  bias_add_transform_20314<float><<<blocks, threads, 0, stream>>>(
      output, input, bias, dim_3, vec_dim_4);
}
// __half path: the innermost axis is re-expressed in float4 vectors
// (8 halves each).
template <>
void launch_bias_add_transform_20314<__half>(__half *output,
                                             const __half *input,
                                             const __half *bias, int dim_0,
                                             int dim_1, int dim_2, int dim_3,
                                             int dim_4, cudaStream_t stream) {
  const int vec_dim_4 = dim_4 >> 3;
  const dim3 blocks(dim_0, dim_1, dim_2);
  const dim3 threads(min(dim_3 * vec_dim_4, MAX_THREADS));
  bias_add_transform_20314<__half><<<blocks, threads, 0, stream>>>(
      output, input, bias, dim_3, vec_dim_4);
}
/**
@brief: transform4d_0213
Reshape the input matrix to merge the heads:
  [trans_count, batch_size, nhead, seq_len, head_dim] ->
  [batch_size, seq_len, trans_count, nhead, head_dim]
Each thread moves one float4, so head_dim and num_all are counted in float4
vectors. The launchers shift hidden_dim right by 2 for float and by 3 for
__half.
@thread
gridDim.x = (num_all + max_block_thread - 1) / max_block_thread
blockDim.x = max_block_thread
@param
output: [batch_size, seq_len, trans_count, nhead, head_dim]
input: [trans_count, batch_size, nhead, seq_len, head_dim]
batch_size: the size of the current batch
seq_len: the sequence length of the current batch
trans_count: 1 or 3, the number of matrices to transform
nhead: number of attention heads
head_dim: per-head size, in float4 vectors
num_all: total number of float4 vectors to move
*/
template <typename T>
__global__ void transform4d_0213(T *output, const T *input, int batch_size,
                                 int seq_len, int trans_count, int nhead,
                                 int head_dim, int num_all) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < num_all) {
    // Source coordinates in [tc, b, nh, s, ad].
    int t, b, h, s, d;
    decompose_5dim(idx, batch_size, nhead, seq_len, head_dim, &t, &b, &h, &s,
                   &d);
    // Destination index in [b, s, tc, nh, ad].
    const int dst =
        flat_5dim(b, s, t, h, d, seq_len, trans_count, nhead, head_dim);
    reinterpret_cast<float4 *>(output)[dst] =
        reinterpret_cast<const float4 *>(input)[idx];
  }
}
// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad]
// float path: sizes are re-expressed in float4 vectors (4 floats each), and a
// 1-D grid covers every vector with MAX_THREADS threads per block.
template <>
void launch_transform4d_0213<float>(float *output, const float *input,
                                    int batch_size, int seq_len, int hidden_dim,
                                    int nhead, int trans_count,
                                    cudaStream_t stream) {
  const int vec_hidden = hidden_dim >> 2;
  const int vec_head = vec_hidden / nhead;
  const int total_vecs = batch_size * seq_len * trans_count * vec_hidden;
  const int blocks = (total_vecs + MAX_THREADS - 1) / MAX_THREADS;
  transform4d_0213<float><<<blocks, MAX_THREADS, 0, stream>>>(
      output, input, batch_size, seq_len, trans_count, nhead, vec_head,
      total_vecs);
}
// __half path: sizes are re-expressed in float4 vectors (8 halves each), and
// a 1-D grid covers every vector with MAX_THREADS threads per block.
template <>
void launch_transform4d_0213<__half>(__half *output, const __half *input,
                                     int batch_size, int seq_len,
                                     int hidden_dim, int nhead, int trans_count,
                                     cudaStream_t stream) {
  const int vec_hidden = hidden_dim >> 3;
  const int vec_head = vec_hidden / nhead;
  const int total_vecs = batch_size * seq_len * trans_count * vec_hidden;
  const int blocks = (total_vecs + MAX_THREADS - 1) / MAX_THREADS;
  transform4d_0213<__half><<<blocks, MAX_THREADS, 0, stream>>>(
      output, input, batch_size, seq_len, trans_count, nhead, vec_head,
      total_vecs);
}