From 95c21498d4f6e640e218f4b00349020f4ae7c69a Mon Sep 17 00:00:00 2001
From: xs_courtesy
Date: Thu, 7 Mar 2024 16:57:49 +0800
Subject: [PATCH] add silu_and_mul for infer

---
 extensions/csrc/cuda/activation_kernel.cu     | 65 +++++++++++++++++++
 .../cuda/colossal_inference_C_frontend.cpp    |  3 +
 extensions/csrc/cuda/include/mp_type_traits.h | 35 ++++++++++
 extensions/csrc/cuda/type_shim.h              |  3 +
 extensions/inference/inference_ops_cuda.py    |  1 +
 .../test_ops/cuda/test_silu_and_mul.py        | 33 ++++++++++
 6 files changed, 140 insertions(+)
 create mode 100644 extensions/csrc/cuda/activation_kernel.cu
 create mode 100644 extensions/csrc/cuda/include/mp_type_traits.h
 create mode 100644 tests/test_infer/test_ops/cuda/test_silu_and_mul.py

diff --git a/extensions/csrc/cuda/activation_kernel.cu b/extensions/csrc/cuda/activation_kernel.cu
new file mode 100644
index 000000000..4121b67fc
--- /dev/null
+++ b/extensions/csrc/cuda/activation_kernel.cu
@@ -0,0 +1,65 @@
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#include "type_shim.h"
+#include "include/mp_type_traits.h"
+
+template <typename T>
+__device__ __forceinline__ T silu_kernel(const T& x) {
+  // x * sigmoid(x), computed in the mixed-precision type MT
+  using MT = typename infer::dtype::MPTypeTrait<T>::Type;
+  return static_cast<T>(static_cast<MT>(x) / (static_cast<MT>(1.0f) + expf(static_cast<MT>(-x))));
+}
+
+template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
+__global__ void act_and_mul_kernel(
+  const scalar_t* __restrict__ ins_data,
+  scalar_t* __restrict__ outs_data,
+  const int64_t numel) {
+  using MT = typename infer::dtype::MPTypeTrait<scalar_t>::Type;
+
+  int64_t idx = static_cast<int64_t>(threadIdx.x) + static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
+  const int64_t grid_size = blockDim.x * gridDim.x;
+  if (idx >= numel) {
+    return;
+  }
+
+  // Grid-stride loop: outs[i] = ACT_FN(ins[i]) * ins[i + numel]
+  for (int64_t i = idx; i < numel; i += grid_size) {
+    scalar_t x = ins_data[i];
+    scalar_t y = ins_data[i + numel];
+    outs_data[i] = static_cast<scalar_t>(static_cast<MT>(ACT_FN(x)) * static_cast<MT>(y));
+  }
+}
+
+// Note(LiuYang): This func is designed for calculation mode like
+// silu(x[:half_1stdim]) * (x[half_1stdim:])
+torch::Tensor silu_and_mul(const torch::Tensor& ins)
+{
+  auto ins_shape = ins.sizes().vec();
+
+  ins_shape[0] = ins_shape[0] / 2;
+  auto outs = torch::zeros(ins_shape, ins.options());
+
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // Note(LiuYang): numel of ins must be divisible by 2
+  int64_t numel = ((torch::numel(ins)) >> 1);
+
+  // TODO(LiuYang): Maybe we need to implement a function to get launch config
+  dim3 grid((numel + 255) / 256);
+  dim3 block(256);
+
+  DISPATCH_FLOAT_HALF_AND_BFLOAT(
+    ins.scalar_type(),
+    "silu_and_mul",
+    act_and_mul_kernel<scalar_t, silu_kernel<scalar_t>><<<grid, block, 0, stream>>>(
+      ins.data_ptr<scalar_t>(),
+      outs.data_ptr<scalar_t>(),
+      numel
+    );)
+
+  AT_CUDA_CHECK(cudaGetLastError());
+  return outs;
+}
diff --git a/extensions/csrc/cuda/colossal_inference_C_frontend.cpp b/extensions/csrc/cuda/colossal_inference_C_frontend.cpp
index ae410c14f..cc53d8b88 100644
--- a/extensions/csrc/cuda/colossal_inference_C_frontend.cpp
+++ b/extensions/csrc/cuda/colossal_inference_C_frontend.cpp
@@ -9,7 +9,10 @@ void decode_kv_cache_memcpy(
     torch::Tensor& sequence_lengths,  // [batch_size]
     torch::Tensor& block_tables);     // [batch_size, max_seq_len]
 
+torch::Tensor silu_and_mul(const torch::Tensor& ins);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
         "Copy the GPU memory of kvcache during the decode stage.");
+  m.def("silu_and_mul", &silu_and_mul, "Silu with a following multiply");
 }
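For reference, the computation implemented by the fused kernel above can be sketched in plain PyTorch: the first half of dim 0 is passed through SiLU and multiplied element-wise by the second half, with the math done in float32 (the MPTypeTrait compute type) for half/bfloat16 inputs before casting back. This is a minimal sketch, not part of the patch; the helper name silu_and_mul_ref is hypothetical.

    import torch
    import torch.nn.functional as F

    def silu_and_mul_ref(ins: torch.Tensor) -> torch.Tensor:
        # dim 0 must be even: first half is the gate, second half the multiplier
        half = ins.size(0) // 2
        gate, up = ins[:half], ins[half:]
        # Mirror the kernel: compute in float32, then cast back to the input dtype
        out = F.silu(gate.float()) * up.float()
        return out.to(ins.dtype)
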
diff --git a/extensions/csrc/cuda/include/mp_type_traits.h b/extensions/csrc/cuda/include/mp_type_traits.h
new file mode 100644
index 000000000..6b3ae9c1b
--- /dev/null
+++ b/extensions/csrc/cuda/include/mp_type_traits.h
@@ -0,0 +1,35 @@
+#pragma once
+
+#include <ATen/ATen.h>
+
+#include "../type_shim.h"
+
+namespace infer {
+namespace dtype {
+
+template <typename T>
+class MPTypeTrait {
+ public:
+  using Type = float;
+};
+
+template <>
+class MPTypeTrait<float> {
+ public:
+  using Type = float;
+};
+
+template <>
+class MPTypeTrait<at::Half> {
+ public:
+  using Type = float;
+};
+
+template <>
+class MPTypeTrait<at::BFloat16> {
+ public:
+  using Type = float;
+};
+
+}  // namespace dtype
+}  // namespace infer
diff --git a/extensions/csrc/cuda/type_shim.h b/extensions/csrc/cuda/type_shim.h
index 511631935..7be3fab1b 100644
--- a/extensions/csrc/cuda/type_shim.h
+++ b/extensions/csrc/cuda/type_shim.h
@@ -4,6 +4,9 @@
 This file is adapted from fused adam in NVIDIA/apex, commit a109f85
 Licensed under the MIT License.
 */
+
+#pragma once
+
 #include <ATen/ATen.h>
 
 #include "compat.h"
diff --git a/extensions/inference/inference_ops_cuda.py b/extensions/inference/inference_ops_cuda.py
index 12bec6fab..2858d7160 100644
--- a/extensions/inference/inference_ops_cuda.py
+++ b/extensions/inference/inference_ops_cuda.py
@@ -12,6 +12,7 @@ class InferenceOpsCudaExtension(_CudaExtension):
             for fname in [
                 "cuda/colossal_inference_C_frontend.cpp",
                 "cuda/decode_kv_cache_memcpy_kernel.cu",
+                "cuda/activation_kernel.cu",
             ]
         ]
         return ret
diff --git a/tests/test_infer/test_ops/cuda/test_silu_and_mul.py b/tests/test_infer/test_ops/cuda/test_silu_and_mul.py
new file mode 100644
index 000000000..ced2db7ca
--- /dev/null
+++ b/tests/test_infer/test_ops/cuda/test_silu_and_mul.py
@@ -0,0 +1,33 @@
+import pytest
+import torch
+
+from colossalai.kernel.kernel_loader import InferenceOpsLoader
+from colossalai.utils import get_current_device
+
+inference_ops = InferenceOpsLoader().load()
+
+
+@pytest.mark.parametrize("SHAPE_X", [2])
+@pytest.mark.parametrize("SHAPE_Y", [64])
+@pytest.mark.parametrize("SHAPE_Z", [11008])
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
+def test_silu_and_mul(SHAPE_X, SHAPE_Y, SHAPE_Z, dtype):
+    torch.manual_seed(5)
+    device = get_current_device()
+    ref_input = torch.randn(SHAPE_X, SHAPE_Y, SHAPE_Z, dtype=dtype, device=device)
+    origin_input = ref_input.clone()
+
+    act_out = torch.nn.functional.silu(ref_input[0], inplace=True)
+    ref_out = act_out * ref_input[1]
+
+    origin_out = inference_ops.silu_and_mul(origin_input)
+
+    if dtype == torch.float32:
+        assert torch.allclose(origin_out, ref_out, atol=1e-5, rtol=1e-5)
+    else:
+        assert torch.allclose(origin_out, ref_out, atol=1e-3, rtol=1e-3)
+
+
+if __name__ == "__main__":
+    test_silu_and_mul(2, 64, 11008, torch.float32)
+    test_silu_and_mul(2, 64, 11008, torch.float16)
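
Beyond the unit test, a hypothetical usage sketch (not part of the patch) of how the op might be fed from a SwiGLU-style MLP: the gate and up projections are stacked along a new leading dimension so the kernel computes silu(x[0]) * x[1]. The tensor names and shapes below are illustrative assumptions, not ColossalAI code.

    import torch
    from colossalai.kernel.kernel_loader import InferenceOpsLoader

    inference_ops = InferenceOpsLoader().load()

    # Illustrative shapes: 16 tokens, hidden size 4096, intermediate size 11008.
    hidden = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
    w_gate = torch.randn(4096, 11008, dtype=torch.float16, device="cuda")
    w_up = torch.randn(4096, 11008, dtype=torch.float16, device="cuda")

    # Stack gate/up along a new dim 0 so dim 0 is even, as the kernel expects.
    gate_up = torch.stack((hidden @ w_gate, hidden @ w_up), dim=0)  # [2, 16, 11008]
    out = inference_ops.silu_and_mul(gate_up).squeeze(0)            # [16, 11008]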