mirror of https://github.com/hpcaitech/ColossalAI
add silu_and_mul for infer
parent
593a72e4d5
commit
95c21498d4
|
@ -0,0 +1,65 @@
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#include <torch/extension.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
|
#include "type_shim.h"
|
||||||
|
#include "include/mp_type_traits.h"
|
||||||
|
|
||||||
|
// SiLU activation for a single element: x * sigmoid(x) == x / (1 + exp(-x)).
//
// The value is converted to infer::dtype::MPTypeTrait<T>::Type for the
// arithmetic and the quotient is converted back to T on return.
template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
  using MT = typename infer::dtype::MPTypeTrait<T>::Type;

  const MT value = static_cast<MT>(x);
  // Denominator of the sigmoid form: 1 + exp(-x).
  const MT denominator = static_cast<MT>(1.0f) + expf(-value);
  return static_cast<T>(value / denominator);
}
|
||||||
|
|
||||||
|
// Fused "activation, then multiply" element-wise kernel.
//
// ins_data is read as two back-to-back halves of `numel` elements each:
//
//   outs_data[i] = ACT_FN(ins_data[i]) * ins_data[i + numel],  0 <= i < numel
//
// so ins_data must hold at least 2 * numel elements and outs_data at least
// numel elements. The product is formed in
// infer::dtype::MPTypeTrait<scalar_t>::Type and converted back to scalar_t.
//
// Launch layout: only the x dimension of the grid and block is used. Each
// thread starts at its global x index and advances by the total number of
// launched threads, so every element is covered for any grid size.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void act_and_mul_kernel(
    const scalar_t* __restrict__ ins_data,
    scalar_t* __restrict__ outs_data,
    const int64_t numel) {
  using MT = typename infer::dtype::MPTypeTrait<scalar_t>::Type;

  const int64_t idx = static_cast<int64_t>(threadIdx.x) +
                      static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
  // Widen BEFORE multiplying: blockDim.x * gridDim.x is a 32-bit unsigned
  // product that wraps once 2^32 or more threads are launched. A wrapped
  // stride of 0 would make the loop below never advance.
  const int64_t grid_size =
      static_cast<int64_t>(blockDim.x) * static_cast<int64_t>(gridDim.x);

  // Threads whose start index is already >= numel fall straight through the
  // loop, so no separate early-exit guard is required.
  for (int64_t i = idx; i < numel; i += grid_size) {
    const scalar_t x = ins_data[i];          // element from the first half
    const scalar_t y = ins_data[i + numel];  // matching element from the second half
    outs_data[i] = static_cast<scalar_t>(static_cast<MT>(ACT_FN(x)) * static_cast<MT>(y));
  }
}
|
||||||
|
|
||||||
|
// Note(LiuYang): This function is designed for the calculation pattern
//   silu(x[:half_1stdim]) * (x[half_1stdim:])
//
// Splits `ins` into two equal halves along dim 0, applies SiLU to the first
// half and multiplies it element-wise with the second half.
//
// Args:
//   ins: CUDA tensor with at least one dimension whose dim-0 size is even.
//        The element type must be one accepted by
//        DISPATCH_FLOAT_HALF_AND_BFLOAT.
// Returns:
//   A new tensor with the same options as `ins` and the same shape except
//   that dim 0 is halved.
// Raises:
//   c10::Error (via TORCH_CHECK) if `ins` is not a CUDA tensor, is 0-dim, or
//   has an odd dim-0 size; also if the kernel launch reports an error.
torch::Tensor silu_and_mul(const torch::Tensor& ins)
{
  TORCH_CHECK(ins.is_cuda(), "silu_and_mul: input must be a CUDA tensor");
  TORCH_CHECK(ins.dim() >= 1,
              "silu_and_mul: input must have at least one dimension");
  // With an odd dim 0 the output (shape[0] / 2 rows) would be smaller than
  // numel(ins) / 2 and the kernel would write past the end of the output.
  TORCH_CHECK(ins.size(0) % 2 == 0,
              "silu_and_mul: size of dim 0 must be even, got ", ins.size(0));

  // The kernel indexes the raw buffer as two back-to-back halves, which only
  // matches the dim-0 split for contiguous storage. contiguous() returns the
  // tensor itself when it is already contiguous.
  const torch::Tensor ins_contig = ins.contiguous();

  auto outs_shape = ins_contig.sizes().vec();
  outs_shape[0] = outs_shape[0] / 2;
  // Every output element is written by the kernel, so zero-filling is not
  // needed.
  auto outs = torch::empty(outs_shape, ins_contig.options());

  // Number of output elements == half the number of input elements.
  const int64_t numel = outs.numel();
  if (numel == 0) {
    // A grid with 0 blocks is an invalid launch configuration.
    return outs;
  }

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // TODO(LiuYang): Maybe we need to implement a function to get launch config
  dim3 grid((numel + 255) / 256);
  dim3 block(256);

  DISPATCH_FLOAT_HALF_AND_BFLOAT(
      ins_contig.scalar_type(),
      "silu_and_mul",
      act_and_mul_kernel<scalar_t,silu_kernel<scalar_t>><<<grid, block, 0, stream>>>(
          ins_contig.data_ptr<scalar_t>(),
          outs.data_ptr<scalar_t>(),
          numel
      );)

  // Kernel launches do not return a status; pick up launch errors here.
  AT_CUDA_CHECK(cudaGetLastError());
  return outs;
}
|
|
@ -9,7 +9,10 @@ void decode_kv_cache_memcpy(
|
||||||
torch::Tensor& sequence_lengths, // [batch_size]
|
torch::Tensor& sequence_lengths, // [batch_size]
|
||||||
torch::Tensor& block_tables); // [batch_size, max_seq_len]
|
torch::Tensor& block_tables); // [batch_size, max_seq_len]
|
||||||
|
|
||||||
|
torch::Tensor silu_and_mul(const torch::Tensor& ins);
|
||||||
|
|
||||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
|
m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
|
||||||
"Copy the GPU memory of kvcache during the decode stage.");
|
"Copy the GPU memory of kvcache during the decode stage.");
|
||||||
|
m.def("silu_and_mul", &silu_and_mul, "Silu with a following multiply");
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,35 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <ATen/ATen.h>
|
||||||
|
|
||||||
|
#include "../type_shim.h"
|
||||||
|
|
||||||
|
namespace infer {
|
||||||
|
namespace dtype {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class MPTypeTrait {
|
||||||
|
public:
|
||||||
|
using Type = float;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
class MPTypeTrait<float> {
|
||||||
|
public:
|
||||||
|
using Type = float;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
class MPTypeTrait<at::Half> {
|
||||||
|
public:
|
||||||
|
using Type = float;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
class MPTypeTrait<at::BFloat16> {
|
||||||
|
public:
|
||||||
|
using Type = float;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace dtype
|
||||||
|
} // namespace infer
|
|
@ -4,6 +4,9 @@
|
||||||
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
|
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
|
||||||
Licensed under the MIT License.
|
Licensed under the MIT License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
#include <ATen/ATen.h>
|
#include <ATen/ATen.h>
|
||||||
|
|
||||||
#include "compat.h"
|
#include "compat.h"
|
||||||
|
|
|
@ -12,6 +12,7 @@ class InferenceOpsCudaExtension(_CudaExtension):
|
||||||
for fname in [
|
for fname in [
|
||||||
"cuda/colossal_inference_C_frontend.cpp",
|
"cuda/colossal_inference_C_frontend.cpp",
|
||||||
"cuda/decode_kv_cache_memcpy_kernel.cu",
|
"cuda/decode_kv_cache_memcpy_kernel.cu",
|
||||||
|
"cuda/activation_kernel.cu",
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
return ret
|
return ret
|
||||||
|
|
|
@ -0,0 +1,33 @@
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||||
|
from colossalai.utils import get_current_device
|
||||||
|
|
||||||
|
inference_ops = InferenceOpsLoader().load()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("SHAPE_X", [2])
@pytest.mark.parametrize("SHAPE_Y", [64])
@pytest.mark.parametrize("SHAPE_Z", [11008])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
def test_silu_and_mul(SHAPE_X, SHAPE_Y, SHAPE_Z, dtype):
    """Check ``inference_ops.silu_and_mul`` against a PyTorch reference.

    A random ``(SHAPE_X, SHAPE_Y, SHAPE_Z)`` tensor is generated; the
    reference is ``silu(t[0]) * t[1]`` and the op under test receives an
    untouched clone of the same tensor. Results must agree within 1e-5
    (float32) or 1e-3 (any other dtype) for both ``atol`` and ``rtol``.
    """
    torch.manual_seed(5)
    device = get_current_device()
    reference_src = torch.randn(SHAPE_X, SHAPE_Y, SHAPE_Z, dtype=dtype, device=device)
    # Clone first: the reference path below overwrites reference_src[0] in place.
    kernel_src = reference_src.clone()

    activated = torch.nn.functional.silu(reference_src[0], inplace=True)
    expected = activated * reference_src[1]

    actual = inference_ops.silu_and_mul(kernel_src)

    tolerance = 1e-5 if dtype == torch.float32 else 1e-3
    assert torch.allclose(actual, expected, atol=tolerance, rtol=tolerance)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Running the file directly exercises the same shape once per dtype.
    for _dtype in (torch.float32, torch.float16):
        test_silu_and_mul(2, 64, 11008, _dtype)
|
Loading…
Reference in New Issue