mirror of https://github.com/hpcaitech/ColossalAI
* refactor the compilation mechanism and unify multi-hardware support
* fix a file path bug
* add __init__.py to make pybind an importable package, avoiding the relative-path error caused by the symlink
* delete duplicated macros
* fix a macro bug under gcc

pull/5650/head
傅剑寒 authored 7 months ago · committed by GitHub
64 changed files with 345 additions and 310 deletions
@@ -0,0 +1,60 @@
#pragma once

#if defined(COLOSSAL_WITH_CUDA)
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#endif

namespace colossalAI {
namespace dtype {

struct bfloat164 {
#ifdef COLOSSAL_WITH_CUDA
  __nv_bfloat162 x;
  __nv_bfloat162 y;
#endif
};

struct bfloat168 {
#ifdef COLOSSAL_WITH_CUDA
  __nv_bfloat162 x;
  __nv_bfloat162 y;
  __nv_bfloat162 z;
  __nv_bfloat162 w;
#endif
};

struct half4 {
#ifdef COLOSSAL_WITH_CUDA
  half2 x;
  half2 y;
#endif
};

struct half8 {
#ifdef COLOSSAL_WITH_CUDA
  half2 x;
  half2 y;
  half2 z;
  half2 w;
#endif
};

struct float4_ {
#ifdef COLOSSAL_WITH_CUDA
  float2 x;
  float2 y;
#endif
};

struct float8_ {
#ifdef COLOSSAL_WITH_CUDA
  float2 x;
  float2 y;
  float2 z;
  float2 w;
#endif
};

}  // namespace dtype
}  // namespace colossalAI
@@ -1,36 +0,0 @@
from ..cuda_extension import _CudaExtension
from ..utils import get_cuda_cc_flag


class InferenceOpsCudaExtension(_CudaExtension):
    def __init__(self):
        super().__init__(name="inference_ops_cuda")

    def sources_files(self):
        ret = [
            self.csrc_abs_path(fname)
            for fname in [
                "cuda/pybind/inference.cpp",
                "cuda/decode_kv_cache_memcpy_kernel.cu",
                "cuda/context_kv_cache_memcpy_kernel.cu",
                "cuda/fused_rotary_emb_and_cache_kernel.cu",
                "cuda/activation_kernel.cu",
                "cuda/rms_layernorm_kernel.cu",
                "cuda/get_cos_and_sin_kernel.cu",
                "cuda/flash_decoding_attention_kernel.cu",
            ]
        ]
        return ret

    def include_dirs(self):
        ret = [self.csrc_abs_path("cuda/include"), self.get_cuda_home_include()]
        return ret

    def cxx_flags(self):
        version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
        return ["-O3"] + version_dependent_macros

    def nvcc_flags(self):
        extra_cuda_flags = ["-lineinfo"]
        extra_cuda_flags.extend(get_cuda_cc_flag())
        return ["-O3", "--use_fast_math"] + extra_cuda_flags
@@ -1,4 +1,4 @@
-from ..base_extension import _Extension
+from ...base_extension import _Extension


class FlashAttentionDaoCudaExtension(_Extension):
@@ -1,4 +1,4 @@
-from ..base_extension import _Extension
+from ...base_extension import _Extension


class FlashAttentionNpuExtension(_Extension):
@@ -1,4 +1,4 @@
-from ..base_extension import _Extension
+from ...base_extension import _Extension


class FlashAttentionSdpaCudaExtension(_Extension):
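For context on the one-dot change in the three hunks above: the extension modules now sit one package level deeper, and the commit message notes that an __init__.py was added so the intermediate pybind directory is a real package. Below is a minimal, runnable sketch of how the leading dots map to package levels. All directory and module names in the sketch are illustrative assumptions, not paths taken from this PR.

# Sketch only: builds a throwaway package tree to show why a module two
# levels below the package root needs three leading dots.
import pathlib
import sys
import tempfile

root = pathlib.Path(tempfile.mkdtemp())
pkg = root / "extensions"                      # assumed top-level package name
deep = pkg / "pybind" / "flash_attention"      # assumed new, deeper location
deep.mkdir(parents=True)
for d in (pkg, pkg / "pybind", deep):
    (d / "__init__.py").write_text("")         # every level must be a package
(pkg / "base_extension.py").write_text("class _Extension: ...\n")
(deep / "demo.py").write_text(
    # demo.py lives two levels below the package root: one dot for its own
    # package, plus one dot per level climbed, hence three dots in total.
    "from ...base_extension import _Extension\n"
)

sys.path.insert(0, str(root))
from extensions.pybind.flash_attention.demo import _Extension  # resolves via '...'
print(_Extension)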
@@ -0,0 +1,31 @@
from ...cuda_extension import _CudaExtension
from ...utils import get_cuda_cc_flag


class InferenceOpsCudaExtension(_CudaExtension):
    def __init__(self):
        super().__init__(name="inference_ops_cuda")

    def sources_files(self):
        ret = [
            self.csrc_abs_path(fname)
            for fname in [
                "kernel/cuda/decode_kv_cache_memcpy_kernel.cu",
                "kernel/cuda/context_kv_cache_memcpy_kernel.cu",
                "kernel/cuda/fused_rotary_emb_and_cache_kernel.cu",
                "kernel/cuda/activation_kernel.cu",
                "kernel/cuda/rms_layernorm_kernel.cu",
                "kernel/cuda/get_cos_and_sin_kernel.cu",
                "kernel/cuda/flash_decoding_attention_kernel.cu",
            ]
        ] + [self.pybind_abs_path("inference/inference.cpp")]
        return ret

    def cxx_flags(self):
        version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
        return ["-O3"] + version_dependent_macros

    def nvcc_flags(self):
        extra_cuda_flags = ["-lineinfo"]
        extra_cuda_flags.extend(get_cuda_cc_flag())
        return ["-O3", "--use_fast_math"] + extra_cuda_flags + super().nvcc_flags()