mirror of https://github.com/hpcaitech/ColossalAI
傅剑寒 committed 7 months ago (committed by GitHub)
5 changed files with 197 additions and 10 deletions

@@ -0,0 +1,127 @@
#include <torch/extension.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/CUDAContext.h>

#include <cmath>

#include "common/micros.h"
#include "utils/vec_copy.h"
#include "funcs/cast_functor.h"

using colossalAI::cuda::utils::copy;
using colossalAI::cuda::utils::get_vec_size;
using colossalAI::funcs::CastFunctor;

// Converts `numel` groups of `VecSize` elements from InT to OutT via the
// vectorized copy helper; the `tail` leftover scalars that do not fill a
// whole vector are cast one by one by a single thread.
template <typename InT, typename OutT, int VecSize>
__global__ void convert_fp8_kernel(const InT* ins_data, OutT* outs_data, int numel, int tail)
{
  const int64_t idx = static_cast<int64_t>(threadIdx.x) + static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
  const int64_t grid_size = static_cast<int64_t>(blockDim.x) * gridDim.x;

  // Grid-stride loop over the vectorized body of the tensor.
  for (int64_t i = idx; i < numel; i += grid_size) {
    copy<InT, OutT, VecSize>(ins_data + i * VecSize, outs_data + i * VecSize);
  }

  // Tail process: restrict the scalar clean-up to one thread so the same
  // elements are not written redundantly by every block.
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    for (int i = 0; i < tail; ++i) {
      outs_data[i + numel * VecSize] = CastFunctor<InT, OutT>()(ins_data[i + numel * VecSize]);
    }
  }
}

template <typename InT, typename OutT>
void apply_convert_fp8(torch::Tensor& input, torch::Tensor& output)
{
  const int kVecSize = get_vec_size<InT>(input);  // 1, 2, or 4 elements per vectorized copy
  const int kNumel = torch::numel(input);

  // Split the element count into full vector groups plus a scalar tail.
  // kVecSize is a power of two, so shift/mask replace division/modulo:
  // e.g. kNumel = 1002 with kVecSize = 4 gives kVecNumel = 250 and kTail = 2.
  const int kVecNumel = (kNumel >> static_cast<int>(std::log2(kVecSize)));
  const int kTail = kNumel & (kVecSize - 1);
  const int grid_size = kVecNumel ? (kVecNumel + 255) / 256 : 1;

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  dim3 grid(grid_size);
  dim3 block(256);

#define _(VEC_SIZE)                                       \
  convert_fp8_kernel<InT, OutT, VEC_SIZE>                 \
      <<<grid, block, 0, stream>>>(                       \
          reinterpret_cast<const InT*>(input.data_ptr()), \
          reinterpret_cast<OutT*>(output.data_ptr()),     \
          kVecNumel, kTail)

  switch (kVecSize) {
    case 1:
      _(1);
      break;
    case 2:
      _(2);
      break;
    case 4:
      _(4);
      break;
  }
#undef _

  AT_CUDA_CHECK(cudaGetLastError());
}

void convert_fp8(torch::Tensor& input, torch::Tensor& output)
{
  // Exactly one side must be the packed FP8 (uint8) tensor; the other side is
  // a float/half/bfloat16 tensor of the same shape.
  TORCH_CHECK(input.scalar_type() == at::ScalarType::Byte ||
                  output.scalar_type() == at::ScalarType::Byte,
              "Data type of input or output should be torch.uint8 for convert_fp8!");
  TORCH_CHECK(input.scalar_type() != output.scalar_type(),
              "Data types of input and output must differ for convert_fp8!");
  TORCH_CHECK(input.scalar_type() == at::ScalarType::Byte ||
                  input.scalar_type() == at::ScalarType::Float ||
                  input.scalar_type() == at::ScalarType::Half ||
                  input.scalar_type() == at::ScalarType::BFloat16,
              "Unsupported dtype of input!");
  TORCH_CHECK(output.scalar_type() == at::ScalarType::Byte ||
                  output.scalar_type() == at::ScalarType::Float ||
                  output.scalar_type() == at::ScalarType::Half ||
                  output.scalar_type() == at::ScalarType::BFloat16,
              "Unsupported dtype of output!");
  TORCH_CHECK(input.sizes() == output.sizes(),
              "Shapes of input and output should be the same!");

#define _(InT, OutT) apply_convert_fp8<InT, OutT>(input, output)

  if (input.scalar_type() == at::ScalarType::Byte) {
    // FP8 -> higher precision
    if (output.scalar_type() == at::ScalarType::Float) {
      _(uint8_t, float);
    } else if (output.scalar_type() == at::ScalarType::Half) {
      _(uint8_t, half);
    } else if (output.scalar_type() == at::ScalarType::BFloat16) {
      _(uint8_t, __nv_bfloat16);
    }
  } else {
    // higher precision -> FP8
    if (input.scalar_type() == at::ScalarType::Float) {
      _(float, uint8_t);
    } else if (input.scalar_type() == at::ScalarType::Half) {
      _(half, uint8_t);
    } else if (input.scalar_type() == at::ScalarType::BFloat16) {
      _(__nv_bfloat16, uint8_t);
    }
  }

#undef _
}
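
The packing and unpacking itself is delegated to CastFunctor and the vectorized copy helper from funcs/cast_functor.h and utils/vec_copy.h, which are not shown in this diff. Purely as an illustration of what such a cast pair can look like (not the ColossalAI implementation; the struct names below are hypothetical), a float/FP8-E4M3 round trip can be built on the cuda_fp8.h conversion intrinsics available since CUDA 11.8:

// Hypothetical sketch only; the real CastFunctor specializations live in
// funcs/cast_functor.h and may differ (e.g. in the FP8 format they target).
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cstdint>

struct FloatToFp8E4M3 {  // stand-in for CastFunctor<float, uint8_t>
  __device__ uint8_t operator()(float x) const {
    // Saturate to the finite E4M3 range and keep the raw byte.
    return static_cast<uint8_t>(__nv_cvt_float_to_fp8(x, __NV_SATFINITE, __NV_E4M3));
  }
};

struct Fp8E4M3ToFloat {  // stand-in for CastFunctor<uint8_t, float>
  __device__ float operator()(uint8_t x) const {
    // Widen FP8 -> half -> float.
    const __half_raw h = __nv_cvt_fp8_to_halfraw(static_cast<__nv_fp8_storage_t>(x), __NV_E4M3);
    return __half2float(__half(h));
  }
};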
@@ -0,0 +1,57 @@
import random

import pytest
import torch

from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device

inference_ops = InferenceOpsLoader().load()

DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [42]  # Arbitrary values for testing
NUM_LAYERS = [1]  # Arbitrary values for testing
NUM_HEADS = [8]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [8, 16, 32]


@pytest.mark.skipif(True, reason="FP8 conversion still needs improvement, so we skip its related test for now!")
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", [1024, 10000])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", [0])
@torch.inference_mode()
def test_fp8_conversion(
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    device = get_current_device()

    # Fill a KV-cache-shaped tensor with values that fit in the FP8 range.
    low = -224.0
    high = 224.0
    shape = (num_blocks, num_heads, head_size, block_size)
    cache = torch.empty(shape, dtype=dtype, device=device)
    cache.uniform_(low, high)

    # Round trip: high precision -> FP8 (uint8) -> high precision.
    cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
    inference_ops.convert_fp8(cache, cache_fp8)

    converted_cache = torch.empty_like(cache)
    inference_ops.convert_fp8(cache_fp8, converted_cache)

    assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1)


if __name__ == "__main__":
    test_fp8_conversion(8, 64, 8, 1024, torch.half, 0)
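
This diff does not include the binding that exposes convert_fp8 to the Python inference_ops object loaded by InferenceOpsLoader. For orientation only, a PyTorch C++ extension typically registers such a function roughly as below; the module contents shown here are illustrative, not the actual ColossalAI binding.

// Illustrative pybind11 registration for a torch extension; the real binding
// loaded by InferenceOpsLoader may be organized differently.
#include <torch/extension.h>

void convert_fp8(torch::Tensor& input, torch::Tensor& output);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("convert_fp8", &convert_fp8,
        "Convert between an FP8 (uint8) tensor and a float/half/bfloat16 tensor");
}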