mirror of https://github.com/hpcaitech/ColossalAI
add some comments
parent
388e043930
commit
5724b9e31e
|
@ -37,6 +37,8 @@ __global__ void act_and_mul_kernel(
|
||||||
// silu(x[:half_1stdim]) * (x[half_1stdim:])
|
// silu(x[:half_1stdim]) * (x[half_1stdim:])
|
||||||
torch::Tensor silu_and_mul(const torch::Tensor& ins)
|
torch::Tensor silu_and_mul(const torch::Tensor& ins)
|
||||||
{
|
{
|
||||||
|
// Note(LiuYang): According to torch doc, vec() may cost a lot, but I did't find a better api
|
||||||
|
// to manipulate ins_shape which is IntArrayRef
|
||||||
auto ins_shape = ins.sizes().vec();
|
auto ins_shape = ins.sizes().vec();
|
||||||
|
|
||||||
ins_shape[0] = ins_shape[0]/2;
|
ins_shape[0] = ins_shape[0]/2;
|
||||||
|
@ -44,18 +46,21 @@ torch::Tensor silu_and_mul(const torch::Tensor& ins)
|
||||||
ins_shape.erase(ins_shape.begin());
|
ins_shape.erase(ins_shape.begin());
|
||||||
}
|
}
|
||||||
auto outs = torch::zeros(ins_shape,ins.options());
|
auto outs = torch::zeros(ins_shape,ins.options());
|
||||||
auto outs_shape = ins.sizes().vec();
|
|
||||||
|
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|
||||||
// Note(Liuyang): numel of ins must be divisible by 2
|
// Note(Liuyang): numel of ins must be divisible by 2
|
||||||
int64_t numel = ((torch::numel(ins)) >> 1);
|
int64_t numel = ((torch::numel(ins)) >> 1);
|
||||||
|
|
||||||
// TODO(LiuYang): Maybe we need to implement a function to get launch config
|
// Note(LiuYang): For better performance for special case of which input is [2, 64, 11008], now
|
||||||
colossalAI::cuda::utils::NVGPUDevInfo dev_info(0);
|
// I comment this part code,because it also cost a little time to calculate a better config
|
||||||
auto config = colossalAI::cuda::utils::GetGPULaunchConfig1D(dev_info,numel,1);
|
// colossalAI::cuda::utils::NVGPUDevInfo dev_info(0);
|
||||||
dim3 grid = config.grid;
|
// auto config = colossalAI::cuda::utils::GetGPULaunchConfig1D(dev_info,numel,1);
|
||||||
dim3 block = config.block;
|
// dim3 grid = config.grid;
|
||||||
|
// dim3 block = config.block;
|
||||||
|
|
||||||
|
dim3 grid((numel+255)/256);
|
||||||
|
dim3 block(256);
|
||||||
|
|
||||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||||
ins.scalar_type(),
|
ins.scalar_type(),
|
||||||
|
|
Loading…
Reference in New Issue