diff --git a/extensions/csrc/common/dev_info_mgr.h b/extensions/csrc/common/dev_info_mgr.h deleted file mode 100644 index 7570666ad..000000000 --- a/extensions/csrc/common/dev_info_mgr.h +++ /dev/null @@ -1,20 +0,0 @@ -#pragma once - -#include - -#include "common/nvgpu_dev_info.h" -#include "target.h" - -namespace colossalAI { -namespace common { - -template -class DevInfoMgr final { - public: - static std::unique_ptr GetDevInfo(int device_num) const { - return std::make_unique(device_num); - } -}; - -} // namespace common -} // namespace colossalAI diff --git a/extensions/csrc/common/mp_type_traits.h b/extensions/csrc/common/mp_type_traits.h index 8ede2d448..2a767620a 100644 --- a/extensions/csrc/common/mp_type_traits.h +++ b/extensions/csrc/common/mp_type_traits.h @@ -8,26 +8,22 @@ namespace colossalAI { namespace common { template -class MPTypeTrait { - public: +struct MPTypeTrait { using Type = float; }; template <> -class MPTypeTrait { - public: +struct MPTypeTrait { using Type = float; }; template <> -class MPTypeTrait { - public: +struct MPTypeTrait { using Type = float; }; template <> -class MPTypeTrait { - public: +struct MPTypeTrait { using Type = float; }; diff --git a/extensions/csrc/common/target.h b/extensions/csrc/common/target.h index 1c8a508e3..ee3072f62 100644 --- a/extensions/csrc/common/target.h +++ b/extensions/csrc/common/target.h @@ -105,7 +105,7 @@ class Target { static Target DefaultAscendTarget(); static Target DefaultCUDATarget() { - return Target(OS::Linux, Arch::CUDA, BitLen::k64); + return Target(OS::Linux, Arch::NVGPU, BitLen::k64); } friend std::ostream& operator<<(std::ostream& os, const Target& target); diff --git a/extensions/csrc/common/vector_copy_utils.h b/extensions/csrc/common/vector_copy_utils.h deleted file mode 100644 index 456440cf6..000000000 --- a/extensions/csrc/common/vector_copy_utils.h +++ /dev/null @@ -1,98 +0,0 @@ - -#include -#include - -#include - -#include "string" - -template -__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); - -template <> -__device__ __inline__ void copy_vector( - c10::BFloat16 *dst, const c10::BFloat16 *src) { - *dst = *src; -} - -template <> -__device__ __inline__ void copy_vector( - c10::BFloat16 *dst, const c10::BFloat16 *src) { - *((float *)dst) = *((float *)src); -} - -template <> -__device__ __inline__ void copy_vector( - c10::BFloat16 *dst, const c10::BFloat16 *src) { - *((float2 *)dst) = *((float2 *)src); -} - -template <> -__device__ __inline__ void copy_vector( - c10::BFloat16 *dst, const c10::BFloat16 *src) { - *((float4 *)dst) = *((float4 *)src); -} - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, - const c10::Half *src) { - *dst = *src; -} - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, - const c10::Half *src) { - *((float *)dst) = *((float *)src); -} - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, - const c10::Half *src) { - *((float2 *)dst) = *((float2 *)src); -} - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, - const c10::Half *src) { - *((float4 *)dst) = *((float4 *)src); -} - -template <> -__device__ __inline__ void copy_vector(float *dst, const float *src) { - *dst = *src; -} - -template <> -__device__ __inline__ void copy_vector(float *dst, const float *src) { - *((float2 *)dst) = *((float2 *)src); -} - -template <> -__device__ __inline__ void copy_vector(float *dst, const float *src) { - *((float4 *)dst) = *((float4 *)src); -} - -template <> -__device__ __inline__ void copy_vector(float *dst, const float *src) { - // Since the maximum memory alignment length is 128 bits, we choose float4 - // here. - *((float4 *)dst) = *((float4 *)src); - *((float4 *)(dst + 4)) = *((float4 *)(src + 4)); -} - -template -int get_vec_size(const torch::Tensor &tensor) { - uint64_t address = reinterpret_cast(tensor.data_ptr()); - const int max_aligned_size = 128; - const int dtype_size = sizeof(T) * 8; - - const int vec_size = max_aligned_size / sizeof(T) / 8; - - if (address % (dtype_size * 4) == 0) { - return std::min(4, vec_size); - } else if (address % (dtype_size * 2) == 0) { - return std::min(2, vec_size); - } else { - return 1; - } -} diff --git a/extensions/csrc/cuda/activation_kernel.cu b/extensions/csrc/cuda/activation_kernel.cu index e9dc01753..372b30387 100644 --- a/extensions/csrc/cuda/activation_kernel.cu +++ b/extensions/csrc/cuda/activation_kernel.cu @@ -36,6 +36,8 @@ __global__ void act_and_mul_kernel( // silu(x[:half_1stdim]) * (x[half_1stdim:]) torch::Tensor silu_and_mul(const torch::Tensor& ins) { + // Note(LiuYang): According to torch doc, vec() may cost a lot, but I did't find a better api + // to manipulate ins_shape which is IntArrayRef auto ins_shape = ins.sizes().vec(); ins_shape[0] = ins_shape[0]/2; @@ -43,14 +45,19 @@ torch::Tensor silu_and_mul(const torch::Tensor& ins) ins_shape.erase(ins_shape.begin()); } auto outs = torch::zeros(ins_shape,ins.options()); - auto outs_shape = ins.sizes().vec(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // Note(Liuyang): numel of ins must be divisible by 2 int64_t numel = ((torch::numel(ins)) >> 1); - // TODO(LiuYang): Maybe we need to implement a function to get launch config + // Note(LiuYang): For better performance for special case of which input is [2, 64, 11008], now + // I comment this part codeļ¼Œbecause it also cost a little time to calculate a better config + // colossalAI::cuda::utils::NVGPUDevInfo dev_info(0); + // auto config = colossalAI::cuda::utils::GetGPULaunchConfig1D(dev_info,numel,1); + // dim3 grid = config.grid; + // dim3 block = config.block; + dim3 grid((numel+255)/256); dim3 block(256); diff --git a/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu b/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu index 7eb44ecd0..3b1197a91 100644 --- a/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu +++ b/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu @@ -1,7 +1,7 @@ #include #include -#include "../common/vector_copy_utils.h" +#include "utils/vector_copy_utils.h" #include "../common/micros.h" template diff --git a/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu b/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu index c1db06d3f..697dc7110 100644 --- a/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu +++ b/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu @@ -2,7 +2,7 @@ #include #include -#include "../common/vector_copy_utils.h" +#include "utils/vector_copy_utils.h" #include "../common/micros.h" template diff --git a/extensions/csrc/cuda/rms_layernorm_kernel.cu b/extensions/csrc/cuda/rms_layernorm_kernel.cu index 8b250cb10..50f26510e 100644 --- a/extensions/csrc/cuda/rms_layernorm_kernel.cu +++ b/extensions/csrc/cuda/rms_layernorm_kernel.cu @@ -10,7 +10,7 @@ #include "block_reduce.h" #include "../common/micros.h" -#include "../common/cuda_type_utils.h" +#include "utils/cuda_type_utils.h" #define DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(DATA_SIZE, TYPE, NAME, ...) \ if (DATA_SIZE == 2) { \ diff --git a/extensions/csrc/cuda/scaled_masked_softmax.h b/extensions/csrc/cuda/scaled_masked_softmax.h index d3e6f04e6..cbbe7f36a 100644 --- a/extensions/csrc/cuda/scaled_masked_softmax.h +++ b/extensions/csrc/cuda/scaled_masked_softmax.h @@ -6,52 +6,14 @@ #include #include #include -#include #include #include +#include "utils/vector_copy_utils.h" + namespace { -template -__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); - -template <> -__device__ __inline__ void copy_vector( - c10::BFloat16 *dst, const c10::BFloat16 *src) { - *dst = *src; -} - -template <> -__device__ __inline__ void copy_vector( - c10::BFloat16 *dst, const c10::BFloat16 *src) { - *((float2 *)dst) = *((float2 *)src); -} - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, - const c10::Half *src) { - *dst = *src; -} - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, - const c10::Half *src) { - *((float2 *)dst) = *((float2 *)src); -} - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, - const uint8_t *src) { - *dst = *src; -} - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, - const uint8_t *src) { - *((half2 *)dst) = *((half2 *)src); -} - int log2_ceil(int value) { int log2_value = 0; while ((1 << log2_value) < value) ++log2_value; diff --git a/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h b/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h index 54c8e9133..524ef46c6 100644 --- a/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h +++ b/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h @@ -13,70 +13,6 @@ namespace { -template -__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); - -template <> -__device__ __inline__ void copy_vector( - c10::BFloat16 *dst, const c10::BFloat16 *src) { - *dst = *src; -} - -template <> -__device__ __inline__ void copy_vector( - c10::BFloat16 *dst, const c10::BFloat16 *src) { - *((float2 *)dst) = *((float2 *)src); -} - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, - const c10::Half *src) { - *dst = *src; -} - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, - const c10::Half *src) { - *((float2 *)dst) = *((float2 *)src); -} - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, - const uint8_t *src) { - *dst = *src; -} - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, - const uint8_t *src) { - *((half2 *)dst) = *((half2 *)src); -} - -template -__device__ __inline__ void copy_zero_vector(Datatype *dst); - -template <> -__device__ __inline__ void copy_zero_vector( - c10::BFloat16 *dst) { - *dst = 0.0; -} - -template <> -__device__ __inline__ void copy_zero_vector( - c10::BFloat16 *dst) { - *((float2 *)dst) = make_float2(0.0f, 0.0f); -} - -template <> -__device__ __inline__ void copy_zero_vector(c10::Half *dst) { - *dst = 0.0; -} - -template <> -__device__ __inline__ void copy_zero_vector(c10::Half *dst) { - *((float2 *)dst) = make_float2(0.0f, 0.0f); -} - int log2_ceil(int value) { int log2_value = 0; while ((1 << log2_value) < value) ++log2_value; diff --git a/extensions/csrc/common/cuda_type_utils.h b/extensions/csrc/cuda/utils/cuda_type_utils.h similarity index 100% rename from extensions/csrc/common/cuda_type_utils.h rename to extensions/csrc/cuda/utils/cuda_type_utils.h diff --git a/extensions/csrc/cuda/utils/gpu_launch_config.h b/extensions/csrc/cuda/utils/gpu_launch_config.h index c7481323a..b953c6587 100644 --- a/extensions/csrc/cuda/utils/gpu_launch_config.h +++ b/extensions/csrc/cuda/utils/gpu_launch_config.h @@ -3,32 +3,74 @@ #include #include +#include "nvgpu_dev_info.h" + namespace colossalAI { namespace cuda { namespace utils { -GPULaunchConfig GPUGetGPULaunchConfig1D(int64_t numel, int vec_size); +struct GPULaunchConfig { + dim3 block{1, 1, 1}; + dim3 grid{1, 1, 1}; +}; -// TODO(LiuYang): to be implemented -GPULaunchConfig GPUGetGPULaunchConfig2D(int64_t numel, int vec_size); +static GPULaunchConfig GetGPULaunchConfig1D(const NVGPUDevInfo& dev_info, + int64_t numel, int64_t vec_size) { + const int64_t max_threads_per_block = dev_info.GetMaxThreadsPerBlock(); + const int64_t max_blocks_per_grid = dev_info.GetMaxGridDims()[0]; + const int64_t kMinimumSize = 64; + const int64_t kMaximumSize = 512; + int64_t active_threads = (numel + vec_size - 1) / vec_size; + int64_t sm_num = dev_info.GetMultiProcessorCount(); -// TODO(LiuYang): to be implemented -GPULaunchConfig GPUGetGPULaunchConfig3D(int64_t numel, int vec_size); + // Note(LiuYang): expected threads should be in [64, 128, 256, 512] generally + int64_t expected_threads_per_block = kMaximumSize; -class GPULaunchConfig { - public: - GPULaunchConfig(){}; - GPULaunchConfig(const dim3& block, const dim3& grid) - : block_(block), grid_(grid) {} - friend GPULaunchConfig GPUGetGPULaunchConfig1D(int64_t numel, int vec_size); + auto RoundUpToPowerOfTwo = [](int64_t x) { + bool is_power_of_two = false; + int64_t ret = 1; + int64_t y = x; + while (y > 0) { + is_power_of_two = ((ret ^ x) == 0); + y = (x >> 1); + ret = (ret << 1); + if (y > 0) is_power_of_two = false; + } + if (is_power_of_two) return x; + return ret; + }; - protected: - void set_block(const dim3& dim) { block_ = dim; } - void set_grid(const dim3& dim) { grid_ = dim; } + if ((active_threads / (sm_num << 1)) < max_threads_per_block) { + expected_threads_per_block = + RoundUpToPowerOfTwo(active_threads / (sm_num << 1)); + } else if ((active_threads / (sm_num << 2)) < max_threads_per_block) { + expected_threads_per_block = + RoundUpToPowerOfTwo(active_threads / (sm_num << 2)); + } - private: - dim3 block_(1, 1, 1); - dim3 grid_(1, 1, 1); + expected_threads_per_block = + std::max(expected_threads_per_block, kMinimumSize); + int64_t expect_block_per_grid = + ((active_threads + expected_threads_per_block - 1) / + expected_threads_per_block); + + if (expect_block_per_grid > max_blocks_per_grid) { + expect_block_per_grid = max_blocks_per_grid; + expected_threads_per_block = + (active_threads + expect_block_per_grid - 1) / expect_block_per_grid; + if (expected_threads_per_block > max_threads_per_block) + throw std::invalid_argument( + "Threads required for current input exceed for current GPU!"); + expected_threads_per_block = + RoundUpToPowerOfTwo(expected_threads_per_block); + expect_block_per_grid = ((active_threads + expected_threads_per_block - 1) / + expected_threads_per_block); + } + + GPULaunchConfig config; + config.block.x = expected_threads_per_block; + config.grid.x = expect_block_per_grid; + return config; } } // namespace utils diff --git a/extensions/csrc/cuda/utils/micros.h b/extensions/csrc/cuda/utils/micros.h index 9b410e3d8..8dd8be166 100644 --- a/extensions/csrc/cuda/utils/micros.h +++ b/extensions/csrc/cuda/utils/micros.h @@ -3,10 +3,12 @@ #include #include -#define CUDA_CHECK(func) \ - { \ - auto status = func; \ - if (status != cudaSuccess) { \ - LOG(FATAL) << "CUDA Error : " << cudaGetErrorString(status); \ - } \ +#include + +#define CUDA_CHECK(func) \ + { \ + auto status = func; \ + if (status != cudaSuccess) { \ + throw std::runtime_error(cudaGetErrorString(status)); \ + } \ } diff --git a/extensions/csrc/cuda/utils/nvgpu_dev_info.cc b/extensions/csrc/cuda/utils/nvgpu_dev_info.cc deleted file mode 100644 index e52abebff..000000000 --- a/extensions/csrc/cuda/utils/nvgpu_dev_info.cc +++ /dev/null @@ -1,45 +0,0 @@ -#include "nvgpu_dev_info.h" - -#include - -namespace colossalAI { -namespace cuda { -namespace utils { - -std::array NVGPUDevInfo::GetMaxGridDims() const { - std::array ret; - ret[0] = prop_->maxGridSize[0]; - ret[1] = prop_->maxGridSize[1]; - ret[2] = prop_->maxGridSize[2]; - return ret; -} - -std::array NVGPUDevInfo::GetMaxBlockDims() const { - std::array ret; - ret[0] = prop_->maxThreadsDim[0]; - ret[1] = prop_->maxThreadsDim[1]; - ret[2] = prop_->maxThreadsDim[2]; - return ret; -} - -std::array NVGPUDevInfo::GetCapability() const { - std::array ret; - ret[0] = prop_.major; - ret[1] = prop_.minor; -} - -int NVGPUDevInfo::GetMultiProcessorCount() const { - return prop_->multiProcessorCount; -} - -int NVGPUDevInfo::GetMaxThreadsPerMultiProcessor() const { - return prop_->maxThreadsPerMultiProcessor; -} - -int NVGPUDevInfo::GetMaxThreadsPerBlock() const { - return prop_->maxThreadsPerBlock; -} - -} // namespace utils -} // namespace cuda -} // namespace colossalAI diff --git a/extensions/csrc/cuda/utils/nvgpu_dev_info.h b/extensions/csrc/cuda/utils/nvgpu_dev_info.h index c8c67c908..f4c017e75 100644 --- a/extensions/csrc/cuda/utils/nvgpu_dev_info.h +++ b/extensions/csrc/cuda/utils/nvgpu_dev_info.h @@ -8,7 +8,6 @@ #include #include "micros.h" -#include "target.h" namespace colossalAI { namespace cuda { @@ -17,19 +16,43 @@ namespace utils { class NVGPUDevInfo { public: explicit NVGPUDevInfo(int device_num) : device_num_(device_num) { - CUDA_CALL(cudaGetDeviceProperties(prop_, device)); + CUDA_CHECK(cudaGetDeviceProperties(&prop_, device_num)); } - std::array GetMaxGridDims() const; - std::array GetMaxBlockDims() const; - std::array GetCapability() const; - int GetMultiProcessorCount() const; - int GetMaxThreadsPerMultiProcessor() const; - int GetMaxThreadsPerBlock() const; + std::array GetMaxGridDims() const { + std::array ret; + ret[0] = prop_.maxGridSize[0]; + ret[1] = prop_.maxGridSize[1]; + ret[2] = prop_.maxGridSize[2]; + return ret; + } + + std::array GetMaxBlockDims() const { + std::array ret; + ret[0] = prop_.maxThreadsDim[0]; + ret[1] = prop_.maxThreadsDim[1]; + ret[2] = prop_.maxThreadsDim[2]; + return ret; + } + + std::array GetCapability() const { + std::array ret; + ret[0] = prop_.major; + ret[1] = prop_.minor; + return ret; + } + + int GetMultiProcessorCount() const { return prop_.multiProcessorCount; } + + int GetMaxThreadsPerMultiProcessor() const { + return prop_.maxThreadsPerMultiProcessor; + } + + int GetMaxThreadsPerBlock() const { return prop_.maxThreadsPerBlock; } private: int device_num_; - cudaDeviceProp* prop_; + cudaDeviceProp prop_; }; } // namespace utils diff --git a/extensions/csrc/cuda/utils/vec_type_traits.h b/extensions/csrc/cuda/utils/vec_type_traits.h new file mode 100644 index 000000000..3ddd64df9 --- /dev/null +++ b/extensions/csrc/cuda/utils/vec_type_traits.h @@ -0,0 +1,83 @@ +#pragma once + +#include +#include +#include + +#include + +namespace colossalAI { +namespace cuda { +namespace utils { + +template +struct VecTypeTrait {}; + +template +struct VecTypeTrait { + using Type = T; +}; + +template <> +struct VecTypeTrait { + using Type = float; +}; + +template <> +struct VecTypeTrait { + using Type = float2; +}; + +template <> +struct VecTypeTrait { + using Type = float4; +}; + +template <> +struct VecTypeTrait { + using Type = float; +}; + +template <> +struct VecTypeTrait { + using Type = float2; +}; + +template <> +struct VecTypeTrait { + using Type = float4; +}; + +template <> +struct VecTypeTrait { + using Type = float2; +}; + +template <> +struct VecTypeTrait { + using Type = float4; +}; + +template <> +struct VecTypeTrait { + using Type = float4; +}; + +template <> +struct VecTypeTrait { + using Type = half; +}; + +template <> +struct VecTypeTrait { + using Type = half2; +}; + +template <> +struct VecTypeTrait { + using Type = float2; +}; + +} // namespace utils +} // namespace cuda +} // namespace colossalAI diff --git a/extensions/csrc/cuda/utils/vector_copy_utils.h b/extensions/csrc/cuda/utils/vector_copy_utils.h new file mode 100644 index 000000000..3c3afa0b3 --- /dev/null +++ b/extensions/csrc/cuda/utils/vector_copy_utils.h @@ -0,0 +1,52 @@ + +#pragma once + +#include +#include +#include + +#include "vec_type_traits.h" + +template +__device__ __inline__ void copy_vector(T *dst, const T *src) { + using VT = typename colossalAI::cuda::utils::VecTypeTrait::Type; + // Note(LiuYang): Here static_cast can't be used for cast between two pointer + *(reinterpret_cast(dst)) = *(reinterpret_cast(src)); +} + +template <> +__device__ __inline__ void copy_vector(float *dst, const float *src) { + // Since the maximum memory alignment length is 128 bits, we choose float4 + // here. + *(reinterpret_cast(dst)) = *(reinterpret_cast(src)); + *(reinterpret_cast(dst + 4)) = + *(reinterpret_cast(src + 4)); +} + +template +__device__ __inline__ void copy_zero_vector(T *dst) { + using VT = typename colossalAI::cuda::utils::VecTypeTrait::Type; + *(reinterpret_cast(dst)) = {0.0}; +} + +template +int get_vec_size(const torch::Tensor &tensor) { + uint64_t address = reinterpret_cast(tensor.data_ptr()); + const int max_aligned_size = 128; + const int dtype_size = sizeof(T) * 8; + + const int vec_size = max_aligned_size / sizeof(T) / 8; + + // Note(LiuYang): Performance of situation of which + // vec_size equals to 8 need to be profiled in the future + // if (address % (dtype_size * 8) == 0) { + // return std::min(8, vec_size); + // } + if (address % (dtype_size * 4) == 0) { + return std::min(4, vec_size); + } else if (address % (dtype_size * 2) == 0) { + return std::min(2, vec_size); + } else { + return 1; + } +}