mirror of https://github.com/hpcaitech/ColossalAI
add implementation for GetGPULaunchConfig1D
parent f366a5ea1f
commit 388e043930
@@ -1,20 +0,0 @@
-#pragma once
-
-#include <memory>
-
-#include "common/nvgpu_dev_info.h"
-#include "target.h"
-
-namespace colossalAI {
-namespace common {
-
-template <typename Ret>
-class DevInfoMgr final {
- public:
-  static std::unique_ptr<Ret> GetDevInfo(int device_num) const {
-    return std::make_unique<Ret>(device_num);
-  }
-};
-
-} // namespace common
-} // namespace colossalAI
@@ -105,7 +105,7 @@ class Target {
   static Target DefaultAscendTarget();

   static Target DefaultCUDATarget() {
-    return Target(OS::Linux, Arch::CUDA, BitLen::k64);
+    return Target(OS::Linux, Arch::NVGPU, BitLen::k64);
   }

   friend std::ostream& operator<<(std::ostream& os, const Target& target);
@@ -4,6 +4,7 @@
 #include "../common/micros.h"
 #include "../common/mp_type_traits.h"
+#include "utils/gpu_launch_config.h"

 template<typename T>
 __device__ __forceinline__ T silu_kernel(const T& x) {
@@ -51,8 +52,10 @@ torch::Tensor silu_and_mul(const torch::Tensor& ins)
   int64_t numel = ((torch::numel(ins)) >> 1);

   // TODO(LiuYang): Maybe we need to implement a function to get launch config
-  dim3 grid((numel+255)/256);
-  dim3 block(256);
+  colossalAI::cuda::utils::NVGPUDevInfo dev_info(0);
+  auto config = colossalAI::cuda::utils::GetGPULaunchConfig1D(dev_info, numel, 1);
+  dim3 grid = config.grid;
+  dim3 block = config.block;

   DISPATCH_FLOAT_HALF_AND_BFLOAT(
       ins.scalar_type(),
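The hunk above replaces the hard-coded 256-thread launch with geometry taken from the actual device. A minimal sketch of the resulting launch pattern; scale_kernel and its arguments are hypothetical stand-ins for illustration, not the real ColossalAI kernel:

#include <cuda_runtime.h>

#include "utils/gpu_launch_config.h"

// Hypothetical elementwise kernel; one thread per element.
__global__ void scale_kernel(float* data, float alpha, int64_t n) {
  int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (i < n) data[i] *= alpha;
}

void launch_scale(float* data, float alpha, int64_t n) {
  colossalAI::cuda::utils::NVGPUDevInfo dev_info(0);  // query device 0
  auto config = colossalAI::cuda::utils::GetGPULaunchConfig1D(
      dev_info, n, /*vec_size=*/1);
  // The config's grid/block were sized for n elements on this device.
  scale_kernel<<<config.grid, config.block>>>(data, alpha, n);
}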
@@ -3,32 +3,77 @@
 #include <cuda.h>
 #include <cuda_runtime.h>

+#include <algorithm>
+#include <stdexcept>
+
+#include "nvgpu_dev_info.h"
+
 namespace colossalAI {
 namespace cuda {
 namespace utils {

-GPULaunchConfig GPUGetGPULaunchConfig1D(int64_t numel, int vec_size);
-
-// TODO(LiuYang): to be implemented
-GPULaunchConfig GPUGetGPULaunchConfig2D(int64_t numel, int vec_size);
-
-// TODO(LiuYang): to be implemented
-GPULaunchConfig GPUGetGPULaunchConfig3D(int64_t numel, int vec_size);
-
-class GPULaunchConfig {
- public:
-  GPULaunchConfig(){};
-  GPULaunchConfig(const dim3& block, const dim3& grid)
-      : block_(block), grid_(grid) {}
-  friend GPULaunchConfig GPUGetGPULaunchConfig1D(int64_t numel, int vec_size);
-
- protected:
-  void set_block(const dim3& dim) { block_ = dim; }
-  void set_grid(const dim3& dim) { grid_ = dim; }
-
- private:
-  dim3 block_(1, 1, 1);
-  dim3 grid_(1, 1, 1);
-};
+struct GPULaunchConfig {
+  dim3 block{1, 1, 1};
+  dim3 grid{1, 1, 1};
+};
+
+static GPULaunchConfig GetGPULaunchConfig1D(const NVGPUDevInfo& dev_info,
+                                            int64_t numel, int64_t vec_size) {
+  const int64_t max_threads_per_block = dev_info.GetMaxThreadsPerBlock();
+  const int64_t max_blocks_per_grid = dev_info.GetMaxGridDims()[0];
+  const int64_t kMinimumSize = 64;
+  const int64_t kMaximumSize = 512;
+  int64_t active_threads = (numel + vec_size - 1) / vec_size;
+  int64_t sm_num = dev_info.GetMultiProcessorCount();
+
+  // Note(LiuYang): expected threads should be in [64, 128, 256, 512] generally
+  int64_t expected_threads_per_block = kMaximumSize;
+
+  // Returns x if x is a power of two, otherwise the next power of two above x.
+  auto RoundUpToPowerOfTwo = [](int64_t x) {
+    bool is_power_of_two = false;
+    int64_t ret = 1;
+    int64_t y = x;
+    while (y > 0) {
+      is_power_of_two = ((ret ^ x) == 0);
+      y = (y >> 1);
+      ret = (ret << 1);
+      if (y > 0) is_power_of_two = false;
+    }
+    if (is_power_of_two) return x;
+    return ret;
+  };
+
+  if ((active_threads / (sm_num << 1)) < max_threads_per_block) {
+    expected_threads_per_block =
+        RoundUpToPowerOfTwo(active_threads / (sm_num << 1));
+  } else if ((active_threads / (sm_num << 2)) < max_threads_per_block) {
+    expected_threads_per_block =
+        RoundUpToPowerOfTwo(active_threads / (sm_num << 2));
+  }
+
+  expected_threads_per_block =
+      std::max(expected_threads_per_block, kMinimumSize);
+  int64_t expect_block_per_grid =
+      ((active_threads + expected_threads_per_block - 1) /
+       expected_threads_per_block);
+
+  if (expect_block_per_grid > max_blocks_per_grid) {
+    expect_block_per_grid = max_blocks_per_grid;
+    expected_threads_per_block =
+        (active_threads + expect_block_per_grid - 1) / expect_block_per_grid;
+    if (expected_threads_per_block > max_threads_per_block)
+      throw std::invalid_argument(
+          "Threads required for current input exceed for current GPU!");
+    expected_threads_per_block =
+        RoundUpToPowerOfTwo(expected_threads_per_block);
+    expect_block_per_grid = ((active_threads + expected_threads_per_block - 1) /
+                             expected_threads_per_block);
+  }
+
+  GPULaunchConfig config;
+  config.block.x = expected_threads_per_block;
+  config.grid.x = expect_block_per_grid;
+  return config;
+}

 } // namespace utils
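To make the heuristic concrete, the following standalone sketch replays its arithmetic on the host with hypothetical device limits (1024 threads per block, 108 SMs, roughly an A100). The constants and the simplified power-of-two helper are assumptions for illustration, not part of the commit:

#include <algorithm>
#include <cstdint>
#include <iostream>

// Host-only replay of the GetGPULaunchConfig1D heuristic; no GPU needed.
int main() {
  const int64_t max_threads_per_block = 1024;  // assumed device limit
  const int64_t sm_num = 108;                  // assumed SM count
  const int64_t numel = 1 << 20;
  const int64_t vec_size = 1;

  // Simplified equivalent of RoundUpToPowerOfTwo: smallest power of two >= x.
  auto round_up_pow2 = [](int64_t x) {
    int64_t ret = 1;
    while (ret < x) ret <<= 1;
    return ret;
  };

  int64_t active_threads = (numel + vec_size - 1) / vec_size;
  int64_t threads = 512;  // kMaximumSize default
  if (active_threads / (sm_num << 1) < max_threads_per_block) {
    threads = round_up_pow2(active_threads / (sm_num << 1));
  } else if (active_threads / (sm_num << 2) < max_threads_per_block) {
    threads = round_up_pow2(active_threads / (sm_num << 2));
  }
  threads = std::max<int64_t>(threads, 64);  // kMinimumSize floor

  int64_t blocks = (active_threads + threads - 1) / threads;
  // For numel = 2^20 both per-SM ratios (4854 and 2427) exceed 1024, so the
  // 512-thread default survives: prints "block=512 grid=2048". A small input
  // such as numel = 4096 instead yields 4096/216 = 18 -> 32, floored to 64.
  std::cout << "block=" << threads << " grid=" << blocks << "\n";
}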
@@ -3,10 +3,12 @@
 #include <cuda.h>
 #include <cuda_runtime.h>

+#include <stdexcept>
+
 #define CUDA_CHECK(func)                                           \
   {                                                                \
     auto status = func;                                            \
     if (status != cudaSuccess) {                                   \
-      LOG(FATAL) << "CUDA Error : " << cudaGetErrorString(status); \
+      throw std::runtime_error(cudaGetErrorString(status));        \
     }                                                              \
   }
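Because CUDA_CHECK now throws instead of logging fatally, callers can handle CUDA failures like any other C++ exception. A hypothetical call site, assuming micros.h (the file patched above) is on the include path:

#include <cstdio>
#include <stdexcept>

#include <cuda_runtime.h>

#include "micros.h"  // the CUDA_CHECK macro patched above

int main() {
  void* ptr = nullptr;
  try {
    // Deliberately oversized allocation: cudaMalloc fails, and CUDA_CHECK
    // throws std::runtime_error carrying cudaGetErrorString's message.
    CUDA_CHECK(cudaMalloc(&ptr, size_t{1} << 48));
  } catch (const std::runtime_error& e) {
    std::fprintf(stderr, "CUDA error: %s\n", e.what());
    return 1;
  }
  return 0;
}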
@@ -1,45 +0,0 @@
-#include "nvgpu_dev_info.h"
-
-#include <array>
-
-namespace colossalAI {
-namespace cuda {
-namespace utils {
-
-std::array<int, 3> NVGPUDevInfo::GetMaxGridDims() const {
-  std::array<int, 3> ret;
-  ret[0] = prop_->maxGridSize[0];
-  ret[1] = prop_->maxGridSize[1];
-  ret[2] = prop_->maxGridSize[2];
-  return ret;
-}
-
-std::array<int, 3> NVGPUDevInfo::GetMaxBlockDims() const {
-  std::array<int, 3> ret;
-  ret[0] = prop_->maxThreadsDim[0];
-  ret[1] = prop_->maxThreadsDim[1];
-  ret[2] = prop_->maxThreadsDim[2];
-  return ret;
-}
-
-std::array<int, 2> NVGPUDevInfo::GetCapability() const {
-  std::array<int, 2> ret;
-  ret[0] = prop_.major;
-  ret[1] = prop_.minor;
-}
-
-int NVGPUDevInfo::GetMultiProcessorCount() const {
-  return prop_->multiProcessorCount;
-}
-
-int NVGPUDevInfo::GetMaxThreadsPerMultiProcessor() const {
-  return prop_->maxThreadsPerMultiProcessor;
-}
-
-int NVGPUDevInfo::GetMaxThreadsPerBlock() const {
-  return prop_->maxThreadsPerBlock;
-}
-
-} // namespace utils
-} // namespace cuda
-} // namespace colossalAI
@@ -8,7 +8,6 @@
 #include <vector>

 #include "micros.h"
-#include "target.h"

 namespace colossalAI {
 namespace cuda {

@@ -17,19 +16,43 @@ namespace utils {
 class NVGPUDevInfo {
  public:
   explicit NVGPUDevInfo(int device_num) : device_num_(device_num) {
-    CUDA_CALL(cudaGetDeviceProperties(prop_, device));
+    CUDA_CHECK(cudaGetDeviceProperties(&prop_, device_num));
   }

-  std::array<int, 3> GetMaxGridDims() const;
-  std::array<int, 3> GetMaxBlockDims() const;
-  std::array<int, 2> GetCapability() const;
-  int GetMultiProcessorCount() const;
-  int GetMaxThreadsPerMultiProcessor() const;
-  int GetMaxThreadsPerBlock() const;
+  std::array<int, 3> GetMaxGridDims() const {
+    std::array<int, 3> ret;
+    ret[0] = prop_.maxGridSize[0];
+    ret[1] = prop_.maxGridSize[1];
+    ret[2] = prop_.maxGridSize[2];
+    return ret;
+  }
+
+  std::array<int, 3> GetMaxBlockDims() const {
+    std::array<int, 3> ret;
+    ret[0] = prop_.maxThreadsDim[0];
+    ret[1] = prop_.maxThreadsDim[1];
+    ret[2] = prop_.maxThreadsDim[2];
+    return ret;
+  }
+
+  std::array<int, 2> GetCapability() const {
+    std::array<int, 2> ret;
+    ret[0] = prop_.major;
+    ret[1] = prop_.minor;
+    return ret;
+  }
+
+  int GetMultiProcessorCount() const { return prop_.multiProcessorCount; }
+
+  int GetMaxThreadsPerMultiProcessor() const {
+    return prop_.maxThreadsPerMultiProcessor;
+  }
+
+  int GetMaxThreadsPerBlock() const { return prop_.maxThreadsPerBlock; }

  private:
   int device_num_;
-  cudaDeviceProp* prop_;
+  cudaDeviceProp prop_;
 };

 } // namespace utils
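With the accessors inlined, NVGPUDevInfo is header-only, so the launch-config helper and any other client can query device limits directly. A brief usage sketch, assuming nvgpu_dev_info.h (patched above) is on the include path:

#include <iostream>

#include "nvgpu_dev_info.h"  // the header patched above

int main() {
  colossalAI::cuda::utils::NVGPUDevInfo info(0);  // properties of device 0
  auto cap = info.GetCapability();
  auto grid = info.GetMaxGridDims();
  std::cout << "compute capability: " << cap[0] << "." << cap[1] << "\n"
            << "SMs: " << info.GetMultiProcessorCount() << "\n"
            << "max threads/block: " << info.GetMaxThreadsPerBlock() << "\n"
            << "max grid x-dim: " << grid[0] << "\n";
  return 0;
}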