ColossalAI/extensions/csrc/cuda/activation_kernel.cu

#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <stdio.h>

#include "../common/micros.h"
#include "../common/mp_type_traits.h"

template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
  // x * sigmoid(x)
  using MT = typename colossalAI::common::MPTypeTrait<T>::Type;
  return static_cast<T>((static_cast<MT>(x)) /
                        (static_cast<MT>(1.0f) + expf(static_cast<MT>(-x))));
}

template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void act_and_mul_kernel(
    const scalar_t* __restrict__ ins_data, scalar_t* __restrict__ outs_data,
    const int64_t numel) {
  using MT = typename colossalAI::common::MPTypeTrait<scalar_t>::Type;

  int64_t idx = static_cast<int64_t>(threadIdx.x) +
                static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
  const int64_t grid_size = static_cast<int64_t>(blockDim.x) * gridDim.x;
  if (idx >= numel) {
    return;
  }

  // Grid-stride loop; the first `numel` elements of ins_data hold x and the
  // following `numel` elements hold y.
  for (int64_t i = idx; i < numel; i += grid_size) {
    scalar_t x = ins_data[i];
    scalar_t y = ins_data[i + numel];
    outs_data[i] = static_cast<scalar_t>(static_cast<MT>(ACT_FN(x)) * static_cast<MT>(y));
  }
}

// Note(LiuYang): This func is designed for calculation mode like
// silu(x[:half_1stdim]) * (x[half_1stdim:])
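// For example (illustrative shapes): for an input of shape [2, 64, 11008], the
// output has shape [64, 11008]; with numel = 64 * 11008 elements per half, the
// flattened result is outs[i] = silu(ins[i]) * ins[i + numel].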
torch::Tensor silu_and_mul(const torch::Tensor& ins) {
  // Note(LiuYang): According to the torch docs, vec() may be costly, but I didn't
  // find a better API to manipulate ins_shape, which is an IntArrayRef
  auto ins_shape = ins.sizes().vec();

  ins_shape[0] = ins_shape[0] / 2;
  if (ins_shape[0] == 1) {
    ins_shape.erase(ins_shape.begin());
  }
  auto outs = torch::zeros(ins_shape, ins.options());

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // Note(LiuYang): numel of ins must be divisible by 2
  int64_t numel = ((torch::numel(ins)) >> 1);
  // Note(LiuYang): For better performance in the special case where the input is
  // [2, 64, 11008], this part of the code is commented out for now, because
  // computing a better launch config also costs a little time.
  // colossalAI::cuda::utils::NVGPUDevInfo dev_info(0);
  // auto config = colossalAI::cuda::utils::GetGPULaunchConfig1D(dev_info,numel,1);
  // dim3 grid = config.grid;
  // dim3 block = config.block;

  dim3 grid((numel + 255) / 256);
  dim3 block(256);
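  // For example, with the input shape [2, 64, 11008] mentioned above, each half
  // holds numel = 64 * 11008 = 704512 elements, so the launch uses
  // grid = (704512 + 255) / 256 = 2752 blocks of 256 threads each.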

  DISPATCH_FLOAT_HALF_AND_BFLOAT(
      ins.scalar_type(), "silu_and_mul",
      act_and_mul_kernel<scalar_t, silu_kernel<scalar_t>>
          <<<grid, block, 0, stream>>>(ins.data_ptr<scalar_t>(),
                                       outs.data_ptr<scalar_t>(), numel);)

  AT_CUDA_CHECK(cudaGetLastError());

  return outs;
}
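
// Usage sketch (a minimal example via the torch C++ API; the shapes and dtype
// below are illustrative assumptions, not dictated by this file):
//   torch::Tensor x = torch::randn({2, 64, 11008},
//       torch::dtype(torch::kHalf).device(torch::kCUDA));
//   torch::Tensor y = silu_and_mul(x);  // y has shape [64, 11008]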