diff --git a/extensions/csrc/common/dev_info_mgr.h b/extensions/csrc/common/dev_info_mgr.h deleted file mode 100644 index 7570666ad..000000000 --- a/extensions/csrc/common/dev_info_mgr.h +++ /dev/null @@ -1,20 +0,0 @@ -#pragma once - -#include - -#include "common/nvgpu_dev_info.h" -#include "target.h" - -namespace colossalAI { -namespace common { - -template -class DevInfoMgr final { - public: - static std::unique_ptr GetDevInfo(int device_num) const { - return std::make_unique(device_num); - } -}; - -} // namespace common -} // namespace colossalAI diff --git a/extensions/csrc/common/target.h b/extensions/csrc/common/target.h index 1c8a508e3..ee3072f62 100644 --- a/extensions/csrc/common/target.h +++ b/extensions/csrc/common/target.h @@ -105,7 +105,7 @@ class Target { static Target DefaultAscendTarget(); static Target DefaultCUDATarget() { - return Target(OS::Linux, Arch::CUDA, BitLen::k64); + return Target(OS::Linux, Arch::NVGPU, BitLen::k64); } friend std::ostream& operator<<(std::ostream& os, const Target& target); diff --git a/extensions/csrc/cuda/activation_kernel.cu b/extensions/csrc/cuda/activation_kernel.cu index e9dc01753..2745e5fbd 100644 --- a/extensions/csrc/cuda/activation_kernel.cu +++ b/extensions/csrc/cuda/activation_kernel.cu @@ -4,6 +4,7 @@ #include "../common/micros.h" #include "../common/mp_type_traits.h" +#include "utils/gpu_launch_config.h" template __device__ __forceinline__ T silu_kernel(const T& x) { @@ -51,8 +52,10 @@ torch::Tensor silu_and_mul(const torch::Tensor& ins) int64_t numel = ((torch::numel(ins)) >> 1); // TODO(LiuYang): Maybe we need to implement a function to get launch config - dim3 grid((numel+255)/256); - dim3 block(256); + colossalAI::cuda::utils::NVGPUDevInfo dev_info(0); + auto config = colossalAI::cuda::utils::GetGPULaunchConfig1D(dev_info,numel,1); + dim3 grid = config.grid; + dim3 block = config.block; DISPATCH_FLOAT_HALF_AND_BFLOAT( ins.scalar_type(), diff --git a/extensions/csrc/cuda/utils/gpu_launch_config.h 
b/extensions/csrc/cuda/utils/gpu_launch_config.h index c7481323a..b953c6587 100644 --- a/extensions/csrc/cuda/utils/gpu_launch_config.h +++ b/extensions/csrc/cuda/utils/gpu_launch_config.h @@ -3,32 +3,74 @@ #include #include +#include "nvgpu_dev_info.h" + namespace colossalAI { namespace cuda { namespace utils { -GPULaunchConfig GPUGetGPULaunchConfig1D(int64_t numel, int vec_size); +struct GPULaunchConfig { + dim3 block{1, 1, 1}; + dim3 grid{1, 1, 1}; +}; + +static GPULaunchConfig GetGPULaunchConfig1D(const NVGPUDevInfo& dev_info, + int64_t numel, int64_t vec_size) { + const int64_t max_threads_per_block = dev_info.GetMaxThreadsPerBlock(); + const int64_t max_blocks_per_grid = dev_info.GetMaxGridDims()[0]; + const int64_t kMinimumSize = 64; + const int64_t kMaximumSize = 512; + int64_t active_threads = (numel + vec_size - 1) / vec_size; + int64_t sm_num = dev_info.GetMultiProcessorCount(); + + // Note(LiuYang): expected threads should be in [64, 128, 256, 512] generally + int64_t expected_threads_per_block = kMaximumSize; -// TODO(LiuYang): to be implemented -GPULaunchConfig GPUGetGPULaunchConfig2D(int64_t numel, int vec_size); + auto RoundUpToPowerOfTwo = [](int64_t x) { + bool is_power_of_two = false; + int64_t ret = 1; + int64_t y = x; + while (y > 0) { + is_power_of_two = ((ret ^ x) == 0); + y = (y >> 1); + ret = (ret << 1); + if (y > 0) is_power_of_two = false; + } + if (is_power_of_two) return x; + return ret; + }; -// TODO(LiuYang): to be implemented -GPULaunchConfig GPUGetGPULaunchConfig3D(int64_t numel, int vec_size); + if ((active_threads / (sm_num << 1)) < max_threads_per_block) { + expected_threads_per_block = + RoundUpToPowerOfTwo(active_threads / (sm_num << 1)); + } else if ((active_threads / (sm_num << 2)) < max_threads_per_block) { + expected_threads_per_block = + RoundUpToPowerOfTwo(active_threads / (sm_num << 2)); + } -class GPULaunchConfig { - public: - GPULaunchConfig(){}; - GPULaunchConfig(const dim3& block, const dim3& grid) - : block_(block), 
grid_(grid) {} - friend GPULaunchConfig GPUGetGPULaunchConfig1D(int64_t numel, int vec_size); + expected_threads_per_block = + std::max(expected_threads_per_block, kMinimumSize); + int64_t expect_block_per_grid = + ((active_threads + expected_threads_per_block - 1) / + expected_threads_per_block); - protected: - void set_block(const dim3& dim) { block_ = dim; } - void set_grid(const dim3& dim) { grid_ = dim; } + if (expect_block_per_grid > max_blocks_per_grid) { + expect_block_per_grid = max_blocks_per_grid; + expected_threads_per_block = + (active_threads + expect_block_per_grid - 1) / expect_block_per_grid; + if (expected_threads_per_block > max_threads_per_block) + throw std::invalid_argument( + "Threads required for current input exceed for current GPU!"); + expected_threads_per_block = + RoundUpToPowerOfTwo(expected_threads_per_block); + expect_block_per_grid = ((active_threads + expected_threads_per_block - 1) / + expected_threads_per_block); + } - private: - dim3 block_(1, 1, 1); - dim3 grid_(1, 1, 1); + GPULaunchConfig config; + config.block.x = expected_threads_per_block; + config.grid.x = expect_block_per_grid; + return config; } } // namespace utils diff --git a/extensions/csrc/cuda/utils/micros.h b/extensions/csrc/cuda/utils/micros.h index 9b410e3d8..8dd8be166 100644 --- a/extensions/csrc/cuda/utils/micros.h +++ b/extensions/csrc/cuda/utils/micros.h @@ -3,10 +3,12 @@ #include #include -#define CUDA_CHECK(func) \ - { \ - auto status = func; \ - if (status != cudaSuccess) { \ - LOG(FATAL) << "CUDA Error : " << cudaGetErrorString(status); \ - } \ +#include <stdexcept> + +#define CUDA_CHECK(func) \ + { \ + auto status = func; \ + if (status != cudaSuccess) { \ + throw std::runtime_error(cudaGetErrorString(status)); \ + } \ } diff --git a/extensions/csrc/cuda/utils/nvgpu_dev_info.cc b/extensions/csrc/cuda/utils/nvgpu_dev_info.cc deleted file mode 100644 index e52abebff..000000000 --- a/extensions/csrc/cuda/utils/nvgpu_dev_info.cc +++ /dev/null @@ -1,45 +0,0 @@ -#include 
"nvgpu_dev_info.h" - -#include - -namespace colossalAI { -namespace cuda { -namespace utils { - -std::array NVGPUDevInfo::GetMaxGridDims() const { - std::array ret; - ret[0] = prop_->maxGridSize[0]; - ret[1] = prop_->maxGridSize[1]; - ret[2] = prop_->maxGridSize[2]; - return ret; -} - -std::array NVGPUDevInfo::GetMaxBlockDims() const { - std::array ret; - ret[0] = prop_->maxThreadsDim[0]; - ret[1] = prop_->maxThreadsDim[1]; - ret[2] = prop_->maxThreadsDim[2]; - return ret; -} - -std::array NVGPUDevInfo::GetCapability() const { - std::array ret; - ret[0] = prop_.major; - ret[1] = prop_.minor; -} - -int NVGPUDevInfo::GetMultiProcessorCount() const { - return prop_->multiProcessorCount; -} - -int NVGPUDevInfo::GetMaxThreadsPerMultiProcessor() const { - return prop_->maxThreadsPerMultiProcessor; -} - -int NVGPUDevInfo::GetMaxThreadsPerBlock() const { - return prop_->maxThreadsPerBlock; -} - -} // namespace utils -} // namespace cuda -} // namespace colossalAI diff --git a/extensions/csrc/cuda/utils/nvgpu_dev_info.h b/extensions/csrc/cuda/utils/nvgpu_dev_info.h index c8c67c908..f4c017e75 100644 --- a/extensions/csrc/cuda/utils/nvgpu_dev_info.h +++ b/extensions/csrc/cuda/utils/nvgpu_dev_info.h @@ -8,7 +8,6 @@ #include #include "micros.h" -#include "target.h" namespace colossalAI { namespace cuda { @@ -17,19 +16,43 @@ namespace utils { class NVGPUDevInfo { public: explicit NVGPUDevInfo(int device_num) : device_num_(device_num) { - CUDA_CALL(cudaGetDeviceProperties(prop_, device)); + CUDA_CHECK(cudaGetDeviceProperties(&prop_, device_num)); } - std::array GetMaxGridDims() const; - std::array GetMaxBlockDims() const; - std::array GetCapability() const; - int GetMultiProcessorCount() const; - int GetMaxThreadsPerMultiProcessor() const; - int GetMaxThreadsPerBlock() const; + std::array GetMaxGridDims() const { + std::array ret; + ret[0] = prop_.maxGridSize[0]; + ret[1] = prop_.maxGridSize[1]; + ret[2] = prop_.maxGridSize[2]; + return ret; + } + + std::array GetMaxBlockDims() 
const { + std::array ret; + ret[0] = prop_.maxThreadsDim[0]; + ret[1] = prop_.maxThreadsDim[1]; + ret[2] = prop_.maxThreadsDim[2]; + return ret; + } + + std::array GetCapability() const { + std::array ret; + ret[0] = prop_.major; + ret[1] = prop_.minor; + return ret; + } + + int GetMultiProcessorCount() const { return prop_.multiProcessorCount; } + + int GetMaxThreadsPerMultiProcessor() const { + return prop_.maxThreadsPerMultiProcessor; + } + + int GetMaxThreadsPerBlock() const { return prop_.maxThreadsPerBlock; } private: int device_num_; - cudaDeviceProp* prop_; + cudaDeviceProp prop_; }; } // namespace utils