Merge branch 'feature/colossal-infer' of https://github.com/hpcaitech/ColossalAI into colossal-infer-cuda-graph

pull/5434/head
Runyu Lu 2024-03-21 14:25:22 +08:00
commit 606603bb88
17 changed files with 253 additions and 313 deletions

View File

@ -1,20 +0,0 @@
#pragma once
#include <memory>
#include "common/nvgpu_dev_info.h"
#include "target.h"
namespace colossalAI {
namespace common {
template <typename Ret>
class DevInfoMgr final {
public:
static std::unique_ptr<Ret> GetDevInfo(int device_num) const {
return std::make_unique<Ret>(device_num);
}
};
} // namespace common
} // namespace colossalAI

View File

@ -8,26 +8,22 @@ namespace colossalAI {
namespace common {
template <typename T>
class MPTypeTrait {
public:
struct MPTypeTrait {
using Type = float;
};
template <>
class MPTypeTrait<float> {
public:
struct MPTypeTrait<float> {
using Type = float;
};
template <>
class MPTypeTrait<at::Half> {
public:
struct MPTypeTrait<at::Half> {
using Type = float;
};
template <>
class MPTypeTrait<at::BFloat16> {
public:
struct MPTypeTrait<at::BFloat16> {
using Type = float;
};
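
Context for the hunk above: the four `class ... { public: ... }` trait definitions are collapsed into `struct` definitions, which keeps the members public without the extra access-specifier line while leaving usage unchanged. A minimal sketch of how such a trait is typically consumed in a kernel (the helper and its name are illustrative, not part of this commit; the include path is an assumption):

#include "common/mp_type_traits.h"  // assumed path of the trait header above

// Illustrative kernel helper: accumulate in the higher-precision type that
// MPTypeTrait selects (float for at::Half / at::BFloat16 inputs).
template <typename scalar_t>
__device__ float sum_in_mp_type(const scalar_t* data, int n) {
  using MT = typename colossalAI::common::MPTypeTrait<scalar_t>::Type;
  MT acc = static_cast<MT>(0.0f);
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    acc += static_cast<MT>(data[i]);
  }
  return static_cast<float>(acc);
}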

View File

@ -105,7 +105,7 @@ class Target {
static Target DefaultAscendTarget();
static Target DefaultCUDATarget() {
return Target(OS::Linux, Arch::CUDA, BitLen::k64);
return Target(OS::Linux, Arch::NVGPU, BitLen::k64);
}
friend std::ostream& operator<<(std::ostream& os, const Target& target);

View File

@ -1,98 +0,0 @@
#include <c10/macros/Macros.h>
#include <cuda_fp16.h>
#include <cfloat>
#include "string"
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(
c10::BFloat16 *dst, const c10::BFloat16 *src) {
*dst = *src;
}
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 2>(
c10::BFloat16 *dst, const c10::BFloat16 *src) {
*((float *)dst) = *((float *)src);
}
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(
c10::BFloat16 *dst, const c10::BFloat16 *src) {
*((float2 *)dst) = *((float2 *)src);
}
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 8>(
c10::BFloat16 *dst, const c10::BFloat16 *src) {
*((float4 *)dst) = *((float4 *)src);
}
template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst,
const c10::Half *src) {
*dst = *src;
}
template <>
__device__ __inline__ void copy_vector<c10::Half, 2>(c10::Half *dst,
const c10::Half *src) {
*((float *)dst) = *((float *)src);
}
template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst,
const c10::Half *src) {
*((float2 *)dst) = *((float2 *)src);
}
template <>
__device__ __inline__ void copy_vector<c10::Half, 8>(c10::Half *dst,
const c10::Half *src) {
*((float4 *)dst) = *((float4 *)src);
}
template <>
__device__ __inline__ void copy_vector<float, 1>(float *dst, const float *src) {
*dst = *src;
}
template <>
__device__ __inline__ void copy_vector<float, 2>(float *dst, const float *src) {
*((float2 *)dst) = *((float2 *)src);
}
template <>
__device__ __inline__ void copy_vector<float, 4>(float *dst, const float *src) {
*((float4 *)dst) = *((float4 *)src);
}
template <>
__device__ __inline__ void copy_vector<float, 8>(float *dst, const float *src) {
// Since the maximum memory alignment length is 128 bits, we choose float4
// here.
*((float4 *)dst) = *((float4 *)src);
*((float4 *)(dst + 4)) = *((float4 *)(src + 4));
}
template <typename T>
int get_vec_size(const torch::Tensor &tensor) {
uint64_t address = reinterpret_cast<uint64_t>(tensor.data_ptr<T>());
const int max_aligned_size = 128;
const int dtype_size = sizeof(T) * 8;
const int vec_size = max_aligned_size / sizeof(T) / 8;
if (address % (dtype_size * 4) == 0) {
return std::min(4, vec_size);
} else if (address % (dtype_size * 2) == 0) {
return std::min(2, vec_size);
} else {
return 1;
}
}
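
The file removed above bundled per-dtype copy_vector specializations with get_vec_size; the commit replaces it with the trait-based utils/vector_copy_utils.h shown further down. For reference, a minimal sketch of how a vectorized elementwise kernel typically consumes copy_vector (kernel name and signature are illustrative only):

// Illustrative only: copy n elements VecSize at a time. Assumes n is a
// multiple of VecSize and that src/dst satisfy the alignment that
// get_vec_size checked on the host.
template <typename scalar_t, int VecSize>
__global__ void copy_kernel_sketch(scalar_t* dst, const scalar_t* src, int n) {
  const int idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
  if (idx < n) {
    copy_vector<scalar_t, VecSize>(dst + idx, src + idx);
  }
}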

View File

@ -36,6 +36,8 @@ __global__ void act_and_mul_kernel(
// silu(x[:half_1stdim]) * (x[half_1stdim:])
torch::Tensor silu_and_mul(const torch::Tensor& ins)
{
    // Note(LiuYang): According to the torch docs, vec() may be costly, but I didn't find a better API
    // for manipulating ins_shape, which is an IntArrayRef
auto ins_shape = ins.sizes().vec();
ins_shape[0] = ins_shape[0]/2;
@ -43,14 +45,19 @@ torch::Tensor silu_and_mul(const torch::Tensor& ins)
ins_shape.erase(ins_shape.begin());
}
auto outs = torch::zeros(ins_shape,ins.options());
auto outs_shape = ins.sizes().vec();
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    // Note(LiuYang): numel of ins must be divisible by 2
int64_t numel = ((torch::numel(ins)) >> 1);
// TODO(LiuYang): Maybe we need to implement a function to get launch config
    // Note(LiuYang): For better performance in the special case where the input is [2, 64, 11008],
    // this part of the code is commented out for now, because computing a better config also costs a little time
// colossalAI::cuda::utils::NVGPUDevInfo dev_info(0);
// auto config = colossalAI::cuda::utils::GetGPULaunchConfig1D(dev_info,numel,1);
// dim3 grid = config.grid;
// dim3 block = config.block;
dim3 grid((numel+255)/256);
dim3 block(256);
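
The fixed launch settled on above is the plain ceiling-division pattern: 256 threads per block and just enough blocks to cover every element. For the [2, 64, 11008] case mentioned in the note, numel is 2 * 64 * 11008 / 2 = 704512, which gives 2752 blocks. A standalone sketch of the same arithmetic (the helper name is illustrative):

// Illustrative helper: ceiling division so the final partial block still
// covers the tail elements.
inline dim3 grid_for_1d_launch(int64_t numel, int threads_per_block = 256) {
  return dim3(static_cast<unsigned int>(
      (numel + threads_per_block - 1) / threads_per_block));
}
// e.g. grid_for_1d_launch(704512) -> dim3(2752), matching (numel + 255) / 256 above.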

View File

@ -1,7 +1,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "../common/vector_copy_utils.h"
#include "utils/vector_copy_utils.h"
#include "../common/micros.h"
template<typename scalar_t, int VecSize>

View File

@ -2,7 +2,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "../common/vector_copy_utils.h"
#include "utils/vector_copy_utils.h"
#include "../common/micros.h"
template <typename scalar_t, int VecSize>

View File

@ -10,7 +10,7 @@
#include "block_reduce.h"
#include "../common/micros.h"
#include "../common/cuda_type_utils.h"
#include "utils/cuda_type_utils.h"
#define DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(DATA_SIZE, TYPE, NAME, ...) \
if (DATA_SIZE == 2) { \

View File

@ -6,52 +6,14 @@
#include <assert.h>
#include <c10/macros/Macros.h>
#include <cuda_fp16.h>
#include <stdint.h>
#include <cfloat>
#include <limits>
#include "utils/vector_copy_utils.h"
namespace {
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(
c10::BFloat16 *dst, const c10::BFloat16 *src) {
*dst = *src;
}
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(
c10::BFloat16 *dst, const c10::BFloat16 *src) {
*((float2 *)dst) = *((float2 *)src);
}
template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst,
const c10::Half *src) {
*dst = *src;
}
template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst,
const c10::Half *src) {
*((float2 *)dst) = *((float2 *)src);
}
template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst,
const uint8_t *src) {
*dst = *src;
}
template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst,
const uint8_t *src) {
*((half2 *)dst) = *((half2 *)src);
}
int log2_ceil(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;

View File

@ -13,70 +13,6 @@
namespace {
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(
c10::BFloat16 *dst, const c10::BFloat16 *src) {
*dst = *src;
}
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(
c10::BFloat16 *dst, const c10::BFloat16 *src) {
*((float2 *)dst) = *((float2 *)src);
}
template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst,
const c10::Half *src) {
*dst = *src;
}
template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst,
const c10::Half *src) {
*((float2 *)dst) = *((float2 *)src);
}
template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst,
const uint8_t *src) {
*dst = *src;
}
template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst,
const uint8_t *src) {
*((half2 *)dst) = *((half2 *)src);
}
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_zero_vector(Datatype *dst);
template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 1>(
c10::BFloat16 *dst) {
*dst = 0.0;
}
template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 4>(
c10::BFloat16 *dst) {
*((float2 *)dst) = make_float2(0.0f, 0.0f);
}
template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 1>(c10::Half *dst) {
*dst = 0.0;
}
template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 4>(c10::Half *dst) {
*((float2 *)dst) = make_float2(0.0f, 0.0f);
}
int log2_ceil(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;

View File

@ -3,32 +3,74 @@
#include <cuda.h>
#include <cuda_runtime.h>
#include "nvgpu_dev_info.h"
namespace colossalAI {
namespace cuda {
namespace utils {
GPULaunchConfig GPUGetGPULaunchConfig1D(int64_t numel, int vec_size);
struct GPULaunchConfig {
dim3 block{1, 1, 1};
dim3 grid{1, 1, 1};
};
// TODO(LiuYang): to be implemented
GPULaunchConfig GPUGetGPULaunchConfig2D(int64_t numel, int vec_size);
static GPULaunchConfig GetGPULaunchConfig1D(const NVGPUDevInfo& dev_info,
int64_t numel, int64_t vec_size) {
const int64_t max_threads_per_block = dev_info.GetMaxThreadsPerBlock();
const int64_t max_blocks_per_grid = dev_info.GetMaxGridDims()[0];
const int64_t kMinimumSize = 64;
const int64_t kMaximumSize = 512;
int64_t active_threads = (numel + vec_size - 1) / vec_size;
int64_t sm_num = dev_info.GetMultiProcessorCount();
// TODO(LiuYang): to be implemented
GPULaunchConfig GPUGetGPULaunchConfig3D(int64_t numel, int vec_size);
// Note(LiuYang): expected threads should be in [64, 128, 256, 512] generally
int64_t expected_threads_per_block = kMaximumSize;
class GPULaunchConfig {
public:
GPULaunchConfig(){};
GPULaunchConfig(const dim3& block, const dim3& grid)
: block_(block), grid_(grid) {}
friend GPULaunchConfig GPUGetGPULaunchConfig1D(int64_t numel, int vec_size);
auto RoundUpToPowerOfTwo = [](int64_t x) {
bool is_power_of_two = false;
int64_t ret = 1;
int64_t y = x;
while (y > 0) {
is_power_of_two = ((ret ^ x) == 0);
y = (y >> 1);
ret = (ret << 1);
if (y > 0) is_power_of_two = false;
}
if (is_power_of_two) return x;
return ret;
};
protected:
void set_block(const dim3& dim) { block_ = dim; }
void set_grid(const dim3& dim) { grid_ = dim; }
if ((active_threads / (sm_num << 1)) < max_threads_per_block) {
expected_threads_per_block =
RoundUpToPowerOfTwo(active_threads / (sm_num << 1));
} else if ((active_threads / (sm_num << 2)) < max_threads_per_block) {
expected_threads_per_block =
RoundUpToPowerOfTwo(active_threads / (sm_num << 2));
}
private:
dim3 block_(1, 1, 1);
dim3 grid_(1, 1, 1);
expected_threads_per_block =
std::max(expected_threads_per_block, kMinimumSize);
int64_t expect_block_per_grid =
((active_threads + expected_threads_per_block - 1) /
expected_threads_per_block);
if (expect_block_per_grid > max_blocks_per_grid) {
expect_block_per_grid = max_blocks_per_grid;
expected_threads_per_block =
(active_threads + expect_block_per_grid - 1) / expect_block_per_grid;
if (expected_threads_per_block > max_threads_per_block)
throw std::invalid_argument(
"Threads required for current input exceed for current GPU!");
expected_threads_per_block =
RoundUpToPowerOfTwo(expected_threads_per_block);
expect_block_per_grid = ((active_threads + expected_threads_per_block - 1) /
expected_threads_per_block);
}
GPULaunchConfig config;
config.block.x = expected_threads_per_block;
config.grid.x = expect_block_per_grid;
return config;
}
} // namespace utils
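
The helper added above derives a 1-D launch configuration from the queried device properties: it targets roughly active_threads / (2 * SM count) (or / (4 * SM count)) threads per block, rounds that up to a power of two, keeps it at least 64 (default 512), and grows the block size instead when the grid limit would otherwise be exceeded. A host-side sketch of how it would be called, mirroring the commented-out block in silu_and_mul (the include path is an assumption; the kernel launch is only indicated):

#include "utils/gpu_launch_config.h"  // assumed header name for this hunk

void launch_with_config_sketch(int64_t numel) {
  colossalAI::cuda::utils::NVGPUDevInfo dev_info(0);  // device 0 assumed
  auto config = colossalAI::cuda::utils::GetGPULaunchConfig1D(
      dev_info, numel, /*vec_size=*/1);
  dim3 grid = config.grid;    // blocks along x
  dim3 block = config.block;  // typically a power of two in [64, 512]
  // my_kernel<<<grid, block>>>(...);
}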

View File

@ -3,10 +3,12 @@
#include <cuda.h>
#include <cuda_runtime.h>
#define CUDA_CHECK(func) \
{ \
auto status = func; \
if (status != cudaSuccess) { \
LOG(FATAL) << "CUDA Error : " << cudaGetErrorString(status); \
} \
#include <exception>
#define CUDA_CHECK(func) \
{ \
auto status = func; \
if (status != cudaSuccess) { \
throw std::runtime_error(cudaGetErrorString(status)); \
} \
}
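
With the new definition, a failing CUDA runtime call surfaces as a C++ exception (which the extension layer can translate) instead of terminating through LOG(FATAL). A minimal usage sketch (the function name is illustrative):

#include <cstdio>
#include <stdexcept>

void cuda_check_sketch() {
  void* buf = nullptr;
  try {
    CUDA_CHECK(cudaMalloc(&buf, 1ull << 20));  // 1 MiB device allocation
    CUDA_CHECK(cudaFree(buf));
  } catch (const std::runtime_error& e) {
    std::printf("CUDA error: %s\n", e.what());
  }
}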

View File

@ -1,45 +0,0 @@
#include "nvgpu_dev_info.h"
#include <array>
namespace colossalAI {
namespace cuda {
namespace utils {
std::array<int, 3> NVGPUDevInfo::GetMaxGridDims() const {
std::array<int, 3> ret;
ret[0] = prop_->maxGridSize[0];
ret[1] = prop_->maxGridSize[1];
ret[2] = prop_->maxGridSize[2];
return ret;
}
std::array<int, 3> NVGPUDevInfo::GetMaxBlockDims() const {
std::array<int, 3> ret;
ret[0] = prop_->maxThreadsDim[0];
ret[1] = prop_->maxThreadsDim[1];
ret[2] = prop_->maxThreadsDim[2];
return ret;
}
std::array<int, 2> NVGPUDevInfo::GetCapability() const {
std::array<int, 2> ret;
ret[0] = prop_.major;
ret[1] = prop_.minor;
}
int NVGPUDevInfo::GetMultiProcessorCount() const {
return prop_->multiProcessorCount;
}
int NVGPUDevInfo::GetMaxThreadsPerMultiProcessor() const {
return prop_->maxThreadsPerMultiProcessor;
}
int NVGPUDevInfo::GetMaxThreadsPerBlock() const {
return prop_->maxThreadsPerBlock;
}
} // namespace utils
} // namespace cuda
} // namespace colossalAI

View File

@ -8,7 +8,6 @@
#include <vector>
#include "micros.h"
#include "target.h"
namespace colossalAI {
namespace cuda {
@ -17,19 +16,43 @@ namespace utils {
class NVGPUDevInfo {
public:
explicit NVGPUDevInfo(int device_num) : device_num_(device_num) {
CUDA_CALL(cudaGetDeviceProperties(prop_, device));
CUDA_CHECK(cudaGetDeviceProperties(&prop_, device_num));
}
std::array<int, 3> GetMaxGridDims() const;
std::array<int, 3> GetMaxBlockDims() const;
std::array<int, 2> GetCapability() const;
int GetMultiProcessorCount() const;
int GetMaxThreadsPerMultiProcessor() const;
int GetMaxThreadsPerBlock() const;
std::array<int, 3> GetMaxGridDims() const {
std::array<int, 3> ret;
ret[0] = prop_.maxGridSize[0];
ret[1] = prop_.maxGridSize[1];
ret[2] = prop_.maxGridSize[2];
return ret;
}
std::array<int, 3> GetMaxBlockDims() const {
std::array<int, 3> ret;
ret[0] = prop_.maxThreadsDim[0];
ret[1] = prop_.maxThreadsDim[1];
ret[2] = prop_.maxThreadsDim[2];
return ret;
}
std::array<int, 2> GetCapability() const {
std::array<int, 2> ret;
ret[0] = prop_.major;
ret[1] = prop_.minor;
return ret;
}
int GetMultiProcessorCount() const { return prop_.multiProcessorCount; }
int GetMaxThreadsPerMultiProcessor() const {
return prop_.maxThreadsPerMultiProcessor;
}
int GetMaxThreadsPerBlock() const { return prop_.maxThreadsPerBlock; }
private:
int device_num_;
cudaDeviceProp* prop_;
cudaDeviceProp prop_;
};
} // namespace utils
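
The refactor above folds the implementations from the deleted nvgpu_dev_info.cc into the header, stores cudaDeviceProp by value instead of through a pointer, and fixes the missing return in GetCapability. A sketch of querying it for device 0 (include path assumed; the function name is illustrative):

#include <cstdio>
#include "utils/nvgpu_dev_info.h"  // assumed include path

void print_dev_info_sketch() {
  colossalAI::cuda::utils::NVGPUDevInfo info(0);
  auto cap = info.GetCapability();        // e.g. {8, 0} on an A100
  auto grid_max = info.GetMaxGridDims();  // {x, y, z} grid limits
  std::printf("sm_%d%d, %d SMs, %d threads/block max, grid.x max %d\n",
              cap[0], cap[1], info.GetMultiProcessorCount(),
              info.GetMaxThreadsPerBlock(), grid_max[0]);
}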

View File

@ -0,0 +1,83 @@
#pragma once
#include <c10/macros/Macros.h>
#include <cuda_fp16.h>
#include <stdint.h>
#include <cfloat>
namespace colossalAI {
namespace cuda {
namespace utils {
template <typename T, int VecSize>
struct VecTypeTrait {};
template <typename T>
struct VecTypeTrait<T, 1> {
using Type = T;
};
template <>
struct VecTypeTrait<c10::BFloat16, 2> {
using Type = float;
};
template <>
struct VecTypeTrait<c10::BFloat16, 4> {
using Type = float2;
};
template <>
struct VecTypeTrait<c10::BFloat16, 8> {
using Type = float4;
};
template <>
struct VecTypeTrait<c10::Half, 2> {
using Type = float;
};
template <>
struct VecTypeTrait<c10::Half, 4> {
using Type = float2;
};
template <>
struct VecTypeTrait<c10::Half, 8> {
using Type = float4;
};
template <>
struct VecTypeTrait<float, 2> {
using Type = float2;
};
template <>
struct VecTypeTrait<float, 4> {
using Type = float4;
};
template <>
struct VecTypeTrait<float, 8> {
using Type = float4;
};
template <>
struct VecTypeTrait<uint8_t, 2> {
using Type = half;
};
template <>
struct VecTypeTrait<uint8_t, 4> {
using Type = half2;
};
template <>
struct VecTypeTrait<uint8_t, 8> {
using Type = float2;
};
} // namespace utils
} // namespace cuda
} // namespace colossalAI
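
The new trait maps an (element type, elements-per-access) pair to a built-in vector type of the same total width, so e.g. eight halves (16 bytes) move as one float4. A few compile-time checks that illustrate the mapping (a separate translation unit is assumed; include paths are assumptions):

#include <type_traits>
#include <cuda_runtime.h>
#include <c10/util/Half.h>
#include "utils/vec_type_traits.h"  // assumed path of the header above

namespace cuu = colossalAI::cuda::utils;

static_assert(std::is_same<cuu::VecTypeTrait<c10::Half, 8>::Type, float4>::value,
              "8 x half = 16 bytes, moved as one float4");
static_assert(std::is_same<cuu::VecTypeTrait<float, 4>::Type, float4>::value,
              "4 x float = 16 bytes, moved as one float4");
static_assert(std::is_same<cuu::VecTypeTrait<float, 1>::Type, float>::value,
              "vec size 1 falls back to the scalar type itself");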

View File

@ -0,0 +1,52 @@
#pragma once
#include <c10/macros/Macros.h>
#include <cuda_fp16.h>
#include <stdint.h>
#include "vec_type_traits.h"
template <typename T, int VecSize>
__device__ __inline__ void copy_vector(T *dst, const T *src) {
using VT = typename colossalAI::cuda::utils::VecTypeTrait<T, VecSize>::Type;
// Note(LiuYang): Here static_cast can't be used to cast between two unrelated pointer types
*(reinterpret_cast<VT *>(dst)) = *(reinterpret_cast<const VT *>(src));
}
template <>
__device__ __inline__ void copy_vector<float, 8>(float *dst, const float *src) {
// Since the maximum memory alignment length is 128 bits, we choose float4
// here.
*(reinterpret_cast<float4 *>(dst)) = *(reinterpret_cast<const float4 *>(src));
*(reinterpret_cast<float4 *>(dst + 4)) =
*(reinterpret_cast<const float4 *>(src + 4));
}
template <typename T, int VecSize>
__device__ __inline__ void copy_zero_vector(T *dst) {
using VT = typename colossalAI::cuda::utils::VecTypeTrait<T, VecSize>::Type;
*(reinterpret_cast<VT *>(dst)) = {0.0};
}
template <typename T>
int get_vec_size(const torch::Tensor &tensor) {
uint64_t address = reinterpret_cast<uint64_t>(tensor.data_ptr<T>());
const int max_aligned_size = 128;
const int dtype_size = sizeof(T) * 8;
const int vec_size = max_aligned_size / sizeof(T) / 8;
// Note(LiuYang): Performance for the case where
// vec_size equals 8 needs to be profiled in the future
// if (address % (dtype_size * 8) == 0) {
// return std::min(8, vec_size);
// }
if (address % (dtype_size * 4) == 0) {
return std::min(4, vec_size);
} else if (address % (dtype_size * 2) == 0) {
return std::min(2, vec_size);
} else {
return 1;
}
}
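
get_vec_size now lives next to the trait it relies on: it caps the vector width per dtype (8 for half, 4 for float) and then returns 4, 2, or 1 depending on the data pointer's alignment, with the vec_size == 8 path left commented out pending profiling. A host-side sketch of how such a helper is typically consumed to pick a kernel instantiation (the dispatch and kernel names are illustrative):

#include <torch/torch.h>

// Illustrative dispatch: choose the kernel instantiation that matches the
// vector width reported for this tensor's data pointer.
void dispatch_sketch(const torch::Tensor& t) {
  const int vec_size = get_vec_size<at::Half>(t);  // 4, 2, or 1
  switch (vec_size) {
    case 4:  /* launch kernel<scalar_t, 4> */ break;
    case 2:  /* launch kernel<scalar_t, 2> */ break;
    default: /* launch kernel<scalar_t, 1> */ break;
  }
}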