Browse Source

add silu_and_mul for infer

pull/5433/head
xs_courtesy 9 months ago
parent
commit
95c21498d4
  1. 65
      extensions/csrc/cuda/activation_kernel.cu
  2. 3
      extensions/csrc/cuda/colossal_inference_C_frontend.cpp
  3. 35
      extensions/csrc/cuda/include/mp_type_traits.h
  4. 3
      extensions/csrc/cuda/type_shim.h
  5. 1
      extensions/inference/inference_ops_cuda.py
  6. 33
      tests/test_infer/test_ops/cuda/test_silu_and_mul.py

65
extensions/csrc/cuda/activation_kernel.cu

@ -0,0 +1,65 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <stdio.h>
#include "type_shim.h"
#include "include/mp_type_traits.h"
template<typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
// x * sigmoid(x)
using MT = typename infer::dtype::MPTypeTrait<T>::Type;
return static_cast<T>((static_cast<MT>(x)) / (static_cast<MT>(1.0f) + expf(static_cast<MT>(-x))));
}
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void act_and_mul_kernel(
const scalar_t* __restrict__ ins_data,
scalar_t* __restrict__ outs_data,
const int64_t numel) {
using MT = typename infer::dtype::MPTypeTrait<scalar_t>::Type;
int64_t idx = static_cast<int64_t>(threadIdx.x) + static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
const int64_t grid_size = blockDim.x * gridDim.x;
if(idx > numel) {
return;
}
for(int64_t i = idx; i < numel; i += grid_size) {
scalar_t x = ins_data[i];
scalar_t y = ins_data[i+numel];
outs_data[i] = static_cast<scalar_t>(static_cast<MT>(ACT_FN(x)) * static_cast<MT>(y));
}
}
// Note(LiuYang):This func is designed for calculation mode like
// silu(x[:half_1stdim]) * (x[half_1stdim:])
torch::Tensor silu_and_mul(const torch::Tensor& ins)
{
auto ins_shape = ins.sizes().vec();
ins_shape[0] = ins_shape[0]/2;
auto outs = torch::zeros(ins_shape,ins.options());
auto outs_shape = ins.sizes().vec();
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// Note(Liuyang): numel of ins must be divisible by 2
int64_t numel = ((torch::numel(ins)) >> 1);
// TODO(LiuYang): Maybe we need to implement a function to get launch config
dim3 grid((numel+255)/256);
dim3 block(256);
DISPATCH_FLOAT_HALF_AND_BFLOAT(
ins.scalar_type(),
"silu_and_mul",
act_and_mul_kernel<scalar_t,silu_kernel<scalar_t>><<<grid, block, 0, stream>>>(
ins.data_ptr<scalar_t>(),
outs.data_ptr<scalar_t>(),
numel
);)
AT_CUDA_CHECK(cudaGetLastError());
return outs;
}

3
extensions/csrc/cuda/colossal_inference_C_frontend.cpp

@ -9,7 +9,10 @@ void decode_kv_cache_memcpy(
torch::Tensor& sequence_lengths, // [batch_size]
torch::Tensor& block_tables); // [batch_size, max_seq_len]
torch::Tensor silu_and_mul(const torch::Tensor& ins);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
"Copy the GPU memory of kvcache during the decode stage.");
m.def("silu_and_mul", &silu_and_mul, "Silu with a following multiply");
}

35
extensions/csrc/cuda/include/mp_type_traits.h

@ -0,0 +1,35 @@
#pragma once
#include <ATen/ATen.h>
#include "../type_shim.h"
namespace infer {
namespace dtype {
template <typename T>
class MPTypeTrait {
public:
using Type = float;
};
template <>
class MPTypeTrait<float> {
public:
using Type = float;
};
template <>
class MPTypeTrait<at::Half> {
public:
using Type = float;
};
template <>
class MPTypeTrait<at::BFloat16> {
public:
using Type = float;
};
} // namespace dtype
} // namespace infer

3
extensions/csrc/cuda/type_shim.h

@ -4,6 +4,9 @@
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
Licensed under the MIT License.
*/
#pragma once
#include <ATen/ATen.h>
#include "compat.h"

1
extensions/inference/inference_ops_cuda.py

@ -12,6 +12,7 @@ class InferenceOpsCudaExtension(_CudaExtension):
for fname in [
"cuda/colossal_inference_C_frontend.cpp",
"cuda/decode_kv_cache_memcpy_kernel.cu",
"cuda/activation_kernel.cu",
]
]
return ret

33
tests/test_infer/test_ops/cuda/test_silu_and_mul.py

@ -0,0 +1,33 @@
import pytest
import torch
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device
inference_ops = InferenceOpsLoader().load()
@pytest.mark.parametrize("SHAPE_X", [2])
@pytest.mark.parametrize("SHAPE_Y", [64])
@pytest.mark.parametrize("SHAPE_Z", [11008])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
def test_silu_and_mul(SHAPE_X, SHAPE_Y, SHAPE_Z, dtype):
torch.manual_seed(5)
device = get_current_device()
ref_input = torch.randn(SHAPE_X, SHAPE_Y, SHAPE_Z, dtype=dtype, device=device)
origin_input = ref_input.clone()
act_out = torch.nn.functional.silu(ref_input[0], inplace=True)
ref_out = act_out * ref_input[1]
origin_out = inference_ops.silu_and_mul(origin_input)
if dtype == torch.float32:
assert torch.allclose(origin_out, ref_out, atol=1e-5, rtol=1e-5)
else:
assert torch.allclose(origin_out, ref_out, atol=1e-3, rtol=1e-3)
if __name__ == "__main__":
test_silu_and_mul(2, 64, 11008, torch.float32)
test_silu_and_mul(2, 64, 11008, torch.float16)
Loading…
Cancel
Save