add some comments

pull/5456/head
xs_courtesy 2024-03-15 11:18:57 +08:00
parent 388e043930
commit 5724b9e31e
1 changed files with 11 additions and 6 deletions

View File

@ -37,6 +37,8 @@ __global__ void act_and_mul_kernel(
// silu(x[:half_1stdim]) * (x[half_1stdim:]) // silu(x[:half_1stdim]) * (x[half_1stdim:])
torch::Tensor silu_and_mul(const torch::Tensor& ins) torch::Tensor silu_and_mul(const torch::Tensor& ins)
{ {
// Note(LiuYang): According to torch doc, vec() may cost a lot, but I did't find a better api
// to manipulate ins_shape which is IntArrayRef
auto ins_shape = ins.sizes().vec(); auto ins_shape = ins.sizes().vec();
ins_shape[0] = ins_shape[0]/2; ins_shape[0] = ins_shape[0]/2;
@ -44,18 +46,21 @@ torch::Tensor silu_and_mul(const torch::Tensor& ins)
ins_shape.erase(ins_shape.begin()); ins_shape.erase(ins_shape.begin());
} }
auto outs = torch::zeros(ins_shape,ins.options()); auto outs = torch::zeros(ins_shape,ins.options());
auto outs_shape = ins.sizes().vec();
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// Note(Liuyang): numel of ins must be divisible by 2 // Note(Liuyang): numel of ins must be divisible by 2
int64_t numel = ((torch::numel(ins)) >> 1); int64_t numel = ((torch::numel(ins)) >> 1);
// TODO(LiuYang): Maybe we need to implement a function to get launch config // Note(LiuYang): For better performance for special case of which input is [2, 64, 11008], now
colossalAI::cuda::utils::NVGPUDevInfo dev_info(0); // I comment this part codebecause it also cost a little time to calculate a better config
auto config = colossalAI::cuda::utils::GetGPULaunchConfig1D(dev_info,numel,1); // colossalAI::cuda::utils::NVGPUDevInfo dev_info(0);
dim3 grid = config.grid; // auto config = colossalAI::cuda::utils::GetGPULaunchConfig1D(dev_info,numel,1);
dim3 block = config.block; // dim3 grid = config.grid;
// dim3 block = config.block;
dim3 grid((numel+255)/256);
dim3 block(256);
DISPATCH_FLOAT_HALF_AND_BFLOAT( DISPATCH_FLOAT_HALF_AND_BFLOAT(
ins.scalar_type(), ins.scalar_type(),