mirror of https://github.com/hpcaitech/ColossalAI
xs_courtesy
9 months ago
6 changed files with 140 additions and 0 deletions
@@ -0,0 +1,65 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <stdio.h>

#include "type_shim.h"
#include "include/mp_type_traits.h"

template<typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
  // x * sigmoid(x)
  using MT = typename infer::dtype::MPTypeTrait<T>::Type;
  return static_cast<T>((static_cast<MT>(x)) / (static_cast<MT>(1.0f) + expf(static_cast<MT>(-x))));
}

template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void act_and_mul_kernel(
    const scalar_t* __restrict__ ins_data,
    scalar_t* __restrict__ outs_data,
    const int64_t numel) {
  using MT = typename infer::dtype::MPTypeTrait<scalar_t>::Type;

  int64_t idx = static_cast<int64_t>(threadIdx.x) + static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
  // Compute the stride in 64-bit to avoid unsigned overflow for very large inputs.
  const int64_t grid_size = static_cast<int64_t>(blockDim.x) * gridDim.x;
  if (idx >= numel) {
    return;
  }

  for (int64_t i = idx; i < numel; i += grid_size) {
    scalar_t x = ins_data[i];
    scalar_t y = ins_data[i + numel];
    outs_data[i] = static_cast<scalar_t>(static_cast<MT>(ACT_FN(x)) * static_cast<MT>(y));
  }
}

// Note(LiuYang): This func is designed for calculation mode like
// silu(x[:half_1stdim]) * (x[half_1stdim:])
torch::Tensor silu_and_mul(const torch::Tensor& ins)
{
  auto ins_shape = ins.sizes().vec();

  ins_shape[0] = ins_shape[0] / 2;
  auto outs = torch::zeros(ins_shape, ins.options());

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // Note(Liuyang): numel of ins must be divisible by 2
  int64_t numel = ((torch::numel(ins)) >> 1);

  // TODO(LiuYang): Maybe we need to implement a function to get launch config
  dim3 grid((numel + 255) / 256);
  dim3 block(256);

  DISPATCH_FLOAT_HALF_AND_BFLOAT(
      ins.scalar_type(),
      "silu_and_mul",
      act_and_mul_kernel<scalar_t, silu_kernel<scalar_t>><<<grid, block, 0, stream>>>(
          ins.data_ptr<scalar_t>(),
          outs.data_ptr<scalar_t>(),
          numel
      );)

  AT_CUDA_CHECK(cudaGetLastError());
  return outs;
}
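
The note in the file above describes the contract of silu_and_mul: the input's first dimension stacks two equally sized halves; the first half is passed through SiLU and multiplied element-wise by the second half. For reference, a minimal PyTorch sketch of that semantics (illustrative only, not part of the diff; the function name is made up):

import torch

def silu_and_mul_reference(ins: torch.Tensor) -> torch.Tensor:
    # silu(x[:half_1stdim]) * x[half_1stdim:], matching the CUDA kernel's contract
    x, y = ins.chunk(2, dim=0)
    return torch.nn.functional.silu(x) * y

This is also the reference computation used by the unit test further below.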
@@ -0,0 +1,35 @@
#pragma once

#include <ATen/ATen.h>

#include "../type_shim.h"

namespace infer {
namespace dtype {

template <typename T>
class MPTypeTrait {
 public:
  using Type = float;
};

template <>
class MPTypeTrait<float> {
 public:
  using Type = float;
};

template <>
class MPTypeTrait<at::Half> {
 public:
  using Type = float;
};

template <>
class MPTypeTrait<at::BFloat16> {
 public:
  using Type = float;
};

}  // namespace dtype
}  // namespace infer
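
MPTypeTrait maps every supported storage type (float, at::Half, at::BFloat16) to float as the math type, so the activation is evaluated in single precision and only the final result is cast back to the storage type. A rough Python analogue of that promotion policy, as a sketch only (the dict and helper names here are invented for illustration):

import torch

# Hypothetical analogue of MPTypeTrait: storage dtype -> math dtype
MP_TYPE = {
    torch.float32: torch.float32,
    torch.float16: torch.float32,
    torch.bfloat16: torch.float32,
}

def silu_in_math_dtype(x: torch.Tensor) -> torch.Tensor:
    mt = MP_TYPE.get(x.dtype, torch.float32)
    # promote to the math dtype, compute, then cast back to the storage dtype
    return torch.nn.functional.silu(x.to(mt)).to(x.dtype)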
@@ -0,0 +1,33 @@
import pytest
import torch

from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device

inference_ops = InferenceOpsLoader().load()


@pytest.mark.parametrize("SHAPE_X", [2])
@pytest.mark.parametrize("SHAPE_Y", [64])
@pytest.mark.parametrize("SHAPE_Z", [11008])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
def test_silu_and_mul(SHAPE_X, SHAPE_Y, SHAPE_Z, dtype):
    torch.manual_seed(5)
    device = get_current_device()
    ref_input = torch.randn(SHAPE_X, SHAPE_Y, SHAPE_Z, dtype=dtype, device=device)
    origin_input = ref_input.clone()

    act_out = torch.nn.functional.silu(ref_input[0], inplace=True)
    ref_out = act_out * ref_input[1]

    origin_out = inference_ops.silu_and_mul(origin_input)

    if dtype == torch.float32:
        assert torch.allclose(origin_out, ref_out, atol=1e-5, rtol=1e-5)
    else:
        assert torch.allclose(origin_out, ref_out, atol=1e-3, rtol=1e-3)


if __name__ == "__main__":
    test_silu_and_mul(2, 64, 11008, torch.float32)
    test_silu_and_mul(2, 64, 11008, torch.float16)