mirror of https://github.com/hpcaitech/ColossalAI
[Inference/Feat] Add convert_fp8 op for fp8 test in the future (#5706)
* add convert_fp8 op for fp8 test in the future * rerun cipull/5714/head
parent
bfad39357b
commit
50104ab340
@ -0,0 +1,127 @@
|
|||||||
|
#include <torch/extension.h>
|
||||||
|
#include <ATen/cuda/Exceptions.h>
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
|
#include "common/micros.h"
|
||||||
|
#include "utils/vec_copy.h"
|
||||||
|
#include "funcs/cast_functor.h"
|
||||||
|
|
||||||
|
|
||||||
|
using colossalAI::cuda::utils::copy;
|
||||||
|
using colossalAI::cuda::utils::get_vec_size;
|
||||||
|
using colossalAI::funcs::CastFunctor;
|
||||||
|
|
||||||
|
template <typename InT, typename OutT, int VecSize>
|
||||||
|
__global__ void convert_fp8_kernel(const InT* ins_data, OutT* outs_data, int numel, int tail)
|
||||||
|
{
|
||||||
|
int64_t idx = static_cast<int64_t>(threadIdx.x) + static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
|
||||||
|
const int64_t grid_size = blockDim.x * gridDim.x;
|
||||||
|
if(idx > numel + tail) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for(int64_t i = idx; i < numel; i += grid_size) {
|
||||||
|
copy<InT, OutT, VecSize>(ins_data + i * VecSize, outs_data + i * VecSize);
|
||||||
|
}
|
||||||
|
// Tail process
|
||||||
|
if(threadIdx.x == 0)
|
||||||
|
{
|
||||||
|
for(int i = 0; i < tail; ++i)
|
||||||
|
{
|
||||||
|
outs_data[i + numel * VecSize] = CastFunctor<InT, OutT>()(ins_data[i + numel * VecSize]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename InT, typename OutT>
|
||||||
|
void apply_convert_fp8(torch::Tensor& input, torch::Tensor& output)
|
||||||
|
{
|
||||||
|
const int kVecSize = get_vec_size<InT>(input);
|
||||||
|
const int kNumel = torch::numel(input);
|
||||||
|
|
||||||
|
const int kVecNumel = (kNumel >> static_cast<int>(std::log2(kVecSize)));
|
||||||
|
const int kTail = kNumel & (kVecSize - 1);
|
||||||
|
int grid_size = kVecNumel ? (kVecNumel + 255) / 256 : 1;
|
||||||
|
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|
||||||
|
dim3 grid(grid_size);
|
||||||
|
dim3 block(256);
|
||||||
|
|
||||||
|
#define _(VEC_SIZE) \
|
||||||
|
convert_fp8_kernel<InT, OutT, VEC_SIZE> \
|
||||||
|
<<<grid, block, 0, stream>>> \
|
||||||
|
(reinterpret_cast<const InT*>(input.data_ptr()), \
|
||||||
|
reinterpret_cast<OutT*>(output.data_ptr()), \
|
||||||
|
kVecNumel, \
|
||||||
|
kTail)
|
||||||
|
|
||||||
|
switch (kVecSize)
|
||||||
|
{
|
||||||
|
case 1:
|
||||||
|
_(1);
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
_(2);
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
_(4);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
#undef _
|
||||||
|
AT_CUDA_CHECK(cudaGetLastError());
|
||||||
|
}
|
||||||
|
|
||||||
|
void convert_fp8(torch::Tensor& input, torch::Tensor& output)
|
||||||
|
{
|
||||||
|
TORCH_CHECK(input.scalar_type() == at::ScalarType::Byte || output.scalar_type() == at::ScalarType::Byte, "Data type of Input or Output should be torch.uint8 for convert_fp8!");
|
||||||
|
TORCH_CHECK(input.scalar_type() != output.scalar_type(), "Data type of input and output are the same!");
|
||||||
|
TORCH_CHECK(input.scalar_type() == at::ScalarType::Byte ||
|
||||||
|
input.scalar_type() == at::ScalarType::Float ||
|
||||||
|
input.scalar_type() == at::ScalarType::Half ||
|
||||||
|
input.scalar_type() == at::ScalarType::BFloat16, "Unsupported dtype of input!");
|
||||||
|
TORCH_CHECK(output.scalar_type() == at::ScalarType::Byte ||
|
||||||
|
output.scalar_type() == at::ScalarType::Float ||
|
||||||
|
output.scalar_type() == at::ScalarType::Half ||
|
||||||
|
output.scalar_type() == at::ScalarType::BFloat16, "Unsupported dtype of output!");
|
||||||
|
TORCH_CHECK(input.sizes() == output.sizes(), "Shape of input and output should be the same!");
|
||||||
|
|
||||||
|
#define _(InT, OutT) \
|
||||||
|
apply_convert_fp8<InT, OutT>(input, output)
|
||||||
|
|
||||||
|
|
||||||
|
if(input.scalar_type() == at::ScalarType::Byte)
|
||||||
|
{
|
||||||
|
if(output.scalar_type() == at::ScalarType::Float)
|
||||||
|
{
|
||||||
|
_(uint8_t, float);
|
||||||
|
}
|
||||||
|
else if(output.scalar_type() == at::ScalarType::Half)
|
||||||
|
{
|
||||||
|
_(uint8_t, half);
|
||||||
|
}
|
||||||
|
else if(output.scalar_type() == at::ScalarType::BFloat16)
|
||||||
|
{
|
||||||
|
_(uint8_t, __nv_bfloat16);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
if(input.scalar_type() == at::ScalarType::Float)
|
||||||
|
{
|
||||||
|
_(float, uint8_t);
|
||||||
|
}
|
||||||
|
else if(input.scalar_type() == at::ScalarType::Half)
|
||||||
|
{
|
||||||
|
_(half, uint8_t);
|
||||||
|
}
|
||||||
|
else if(input.scalar_type() == at::ScalarType::BFloat16)
|
||||||
|
{
|
||||||
|
_(__nv_bfloat16, uint8_t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#undef _
|
||||||
|
}
|
@ -0,0 +1,57 @@
|
|||||||
|
import random
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||||
|
from colossalai.utils import get_current_device
|
||||||
|
|
||||||
|
inference_ops = InferenceOpsLoader().load()
|
||||||
|
|
||||||
|
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||||
|
NUM_TOKENS = [42] # Arbitrary values for testing
|
||||||
|
NUM_LAYERS = [1] # Arbitrary values for testing
|
||||||
|
NUM_HEADS = [8] # Arbitrary values for testing
|
||||||
|
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
|
||||||
|
BLOCK_SIZES = [8, 16, 32]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(True, reason="FP8 conversion still needs improvement, now we skip it's relative test!")
|
||||||
|
@pytest.mark.parametrize("num_heads", [8])
|
||||||
|
@pytest.mark.parametrize("head_size", [64, 80, 96, 112, 128, 256])
|
||||||
|
@pytest.mark.parametrize("block_size", [8, 16, 32])
|
||||||
|
@pytest.mark.parametrize("num_blocks", [1024, 10000])
|
||||||
|
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16, torch.float])
|
||||||
|
@pytest.mark.parametrize("seed", [0])
|
||||||
|
@torch.inference_mode()
|
||||||
|
def test_fp8_conversion(
|
||||||
|
num_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
block_size: int,
|
||||||
|
num_blocks: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
seed: int,
|
||||||
|
) -> None:
|
||||||
|
random.seed(seed)
|
||||||
|
torch.random.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
|
||||||
|
device = get_current_device()
|
||||||
|
|
||||||
|
low = -224.0
|
||||||
|
high = 224.0
|
||||||
|
shape = (num_blocks, num_heads, head_size, block_size)
|
||||||
|
cache = torch.empty(shape, dtype=dtype, device=device)
|
||||||
|
cache.uniform_(low, high)
|
||||||
|
|
||||||
|
cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
|
||||||
|
inference_ops.convert_fp8(cache, cache_fp8)
|
||||||
|
|
||||||
|
converted_cache = torch.empty_like(cache)
|
||||||
|
inference_ops.convert_fp8(cache_fp8, converted_cache)
|
||||||
|
|
||||||
|
assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_fp8_conversion(8, 64, 8, 1024, torch.half, 0)
|
Loading…
Reference in new issue