#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <stdio.h>

#include "../common/micros.h"
#include "../common/mp_type_traits.h"

template<typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
  // silu(x) = x * sigmoid(x), computed in the higher-precision type MT.
  using MT = typename colossalAI::common::MPTypeTrait<T>::Type;
  return static_cast<T>(static_cast<MT>(x) /
                        (static_cast<MT>(1.0f) + expf(static_cast<MT>(-x))));
}

template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void act_and_mul_kernel(
    const scalar_t* __restrict__ ins_data,
    scalar_t* __restrict__ outs_data,
    const int64_t numel) {
  using MT = typename colossalAI::common::MPTypeTrait<scalar_t>::Type;

  const int64_t idx = static_cast<int64_t>(threadIdx.x) +
                      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
  const int64_t grid_size = static_cast<int64_t>(blockDim.x) * gridDim.x;
  if (idx >= numel) {
    return;
  }

  // Grid-stride loop: the first half of ins_data is passed through ACT_FN and
  // multiplied element-wise by the second half.
  for (int64_t i = idx; i < numel; i += grid_size) {
    scalar_t x = ins_data[i];
    scalar_t y = ins_data[i + numel];
    outs_data[i] = static_cast<scalar_t>(static_cast<MT>(ACT_FN(x)) * static_cast<MT>(y));
  }
}

// Note(LiuYang): This function is designed for the calculation pattern
// silu(x[:half_1stdim]) * x[half_1stdim:]
torch::Tensor silu_and_mul(const torch::Tensor& ins) {
  // Note(LiuYang): According to the torch docs, vec() may be costly, but I didn't
  // find a better API for manipulating ins_shape, which is an IntArrayRef.
  auto ins_shape = ins.sizes().vec();

  ins_shape[0] = ins_shape[0] / 2;
  if (ins_shape[0] == 1) {
    ins_shape.erase(ins_shape.begin());
  }
  auto outs = torch::zeros(ins_shape, ins.options());

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // Note(LiuYang): numel of ins must be divisible by 2.
  int64_t numel = ((torch::numel(ins)) >> 1);

  // Note(LiuYang): Computing a tuned launch config improves performance for special
  // cases such as an input of [2, 64, 11008], but it is commented out for now because
  // computing the config itself also costs a little time.
  // colossalAI::cuda::utils::NVGPUDevInfo dev_info(0);
  // auto config = colossalAI::cuda::utils::GetGPULaunchConfig1D(dev_info, numel, 1);
  // dim3 grid = config.grid;
  // dim3 block = config.block;

  dim3 grid((numel + 255) / 256);
  dim3 block(256);

  DISPATCH_FLOAT_HALF_AND_BFLOAT(
      ins.scalar_type(),
      "silu_and_mul",
      act_and_mul_kernel<scalar_t, silu_kernel<scalar_t>><<<grid, block, 0, stream>>>(
          ins.data_ptr<scalar_t>(),
          outs.data_ptr<scalar_t>(),
          numel);)

  AT_CUDA_CHECK(cudaGetLastError());
  return outs;
}
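
// The following is a hedged, illustrative sketch and not part of the original file:
// a host-side sanity check that compares the fused kernel against an eager-mode
// reference. It assumes a CUDA device is available and that this translation unit
// is linked into something that can call it; the helper name
// `silu_and_mul_sanity_check` is an assumption, not an API of the surrounding project.
inline void silu_and_mul_sanity_check() {
  // Contiguous input whose first dimension (of size 2) holds the two halves the
  // kernel combines: ins[0] goes through SiLU, ins[1] is the multiplier.
  auto x = torch::randn({2, 64, 11008},
                        torch::dtype(torch::kFloat32).device(torch::kCUDA));
  auto out = silu_and_mul(x);

  // Eager-mode reference: silu(x[0]) * x[1], with silu(v) = v * sigmoid(v).
  auto gate = x[0];
  auto ref = gate * torch::sigmoid(gate) * x[1];

  TORCH_CHECK(torch::allclose(out, ref, /*rtol=*/1e-4, /*atol=*/1e-5),
              "silu_and_mul does not match the eager-mode reference");
}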