mirror of https://github.com/hpcaitech/ColossalAI
Merge pull request #5435 from Courtesy-Xs/add_gpu_launch_config
Add query and other componentspull/5445/head
commit
21e1e3645c
|
@ -0,0 +1,20 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "common/nvgpu_dev_info.h"
|
||||||
|
#include "target.h"
|
||||||
|
|
||||||
|
namespace colossalAI {
|
||||||
|
namespace common {
|
||||||
|
|
||||||
|
template <typename Ret>
|
||||||
|
class DevInfoMgr final {
|
||||||
|
public:
|
||||||
|
static std::unique_ptr<Ret> GetDevInfo(int device_num) const {
|
||||||
|
return std::make_unique<Ret>(device_num);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace common
|
||||||
|
} // namespace colossalAI
|
|
@ -9,7 +9,15 @@
|
||||||
|
|
||||||
#include <ATen/ATen.h>
|
#include <ATen/ATen.h>
|
||||||
|
|
||||||
#include "compat.h"
|
#ifndef TORCH_CHECK
|
||||||
|
#define TORCH_CHECK AT_CHECK
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef VERSION_GE_1_3
|
||||||
|
#define DATA_PTR data_ptr
|
||||||
|
#else
|
||||||
|
#define DATA_PTR data
|
||||||
|
#endif
|
||||||
|
|
||||||
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
|
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
|
||||||
switch (TYPE) { \
|
switch (TYPE) { \
|
||||||
|
@ -214,90 +222,3 @@
|
||||||
AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), \
|
AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), \
|
||||||
"'"); \
|
"'"); \
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__device__ __forceinline__ T reduce_block_into_lanes(
|
|
||||||
T *x, T val, int lanes = 1,
|
|
||||||
bool share_result = false) // lanes is intended to be <= 32.
|
|
||||||
{
|
|
||||||
int tid = threadIdx.x + threadIdx.y * blockDim.x;
|
|
||||||
int blockSize =
|
|
||||||
blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
|
|
||||||
|
|
||||||
if (blockSize >= 64) {
|
|
||||||
x[tid] = val;
|
|
||||||
__syncthreads();
|
|
||||||
}
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
|
|
||||||
if (tid < i) x[tid] = x[tid] + x[tid + i];
|
|
||||||
__syncthreads();
|
|
||||||
}
|
|
||||||
|
|
||||||
T final;
|
|
||||||
|
|
||||||
if (tid < 32) {
|
|
||||||
if (blockSize >= 64)
|
|
||||||
final = x[tid] + x[tid + 32];
|
|
||||||
else
|
|
||||||
final = val;
|
|
||||||
// __SYNCWARP();
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 16; i >= lanes; i >>= 1)
|
|
||||||
final = final + __shfl_down_sync(0xffffffff, final, i);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (share_result) {
|
|
||||||
if (tid < lanes) x[tid] = final; // EpilogueOp
|
|
||||||
// Make sure the smem result is visible to all warps.
|
|
||||||
__syncthreads();
|
|
||||||
}
|
|
||||||
|
|
||||||
return final;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__device__ __forceinline__ T reduce_block_into_lanes_max_op(
|
|
||||||
T *x, T val, int lanes = 1,
|
|
||||||
bool share_result = false) // lanes is intended to be <= 32.
|
|
||||||
{
|
|
||||||
int tid = threadIdx.x + threadIdx.y * blockDim.x;
|
|
||||||
int blockSize =
|
|
||||||
blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
|
|
||||||
|
|
||||||
if (blockSize >= 64) {
|
|
||||||
x[tid] = val;
|
|
||||||
__syncthreads();
|
|
||||||
}
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
|
|
||||||
if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
|
|
||||||
__syncthreads();
|
|
||||||
}
|
|
||||||
|
|
||||||
T final;
|
|
||||||
|
|
||||||
if (tid < 32) {
|
|
||||||
if (blockSize >= 64)
|
|
||||||
final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
|
|
||||||
else
|
|
||||||
final = val;
|
|
||||||
// __SYNCWARP();
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 16; i >= lanes; i >>= 1)
|
|
||||||
final =
|
|
||||||
fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (share_result) {
|
|
||||||
if (tid < lanes) x[tid] = final; // EpilogueOp
|
|
||||||
// Make sure the smem result is visible to all warps.
|
|
||||||
__syncthreads();
|
|
||||||
}
|
|
||||||
|
|
||||||
return final;
|
|
||||||
}
|
|
|
@ -2,10 +2,10 @@
|
||||||
|
|
||||||
#include <ATen/ATen.h>
|
#include <ATen/ATen.h>
|
||||||
|
|
||||||
#include "../type_shim.h"
|
#include "micros.h"
|
||||||
|
|
||||||
namespace infer {
|
namespace colossalAI {
|
||||||
namespace dtype {
|
namespace common {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
class MPTypeTrait {
|
class MPTypeTrait {
|
||||||
|
@ -31,5 +31,5 @@ class MPTypeTrait<at::BFloat16> {
|
||||||
using Type = float;
|
using Type = float;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace dtype
|
} // namespace common
|
||||||
} // namespace infer
|
} // namespace colossalAI
|
|
@ -0,0 +1,134 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <exception>
|
||||||
|
#include <iostream>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace colossalAI {
|
||||||
|
namespace common {
|
||||||
|
|
||||||
|
class Target {
|
||||||
|
public:
|
||||||
|
enum class OS : int {
|
||||||
|
Unk = -1,
|
||||||
|
Linux,
|
||||||
|
Windows,
|
||||||
|
};
|
||||||
|
enum class Arch : int {
|
||||||
|
Unk = -1,
|
||||||
|
X86,
|
||||||
|
Arm,
|
||||||
|
NVGPU,
|
||||||
|
AMDGPU,
|
||||||
|
Ascend,
|
||||||
|
};
|
||||||
|
enum class BitLen : int {
|
||||||
|
Unk = -1,
|
||||||
|
k32,
|
||||||
|
k64,
|
||||||
|
};
|
||||||
|
|
||||||
|
explicit Target(OS os, Arch arch, BitLen bitlen)
|
||||||
|
: os_(os), arch_(arch), bitlen_(bitlen) {}
|
||||||
|
|
||||||
|
bool defined() const {
|
||||||
|
return (os_ != OS::Unk) && (arch_ != Arch::Unk) && (bitlen_ != BitLen::Unk);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string str() const {
|
||||||
|
std::string s{"OS: "};
|
||||||
|
switch (os_) {
|
||||||
|
case OS::Unk:
|
||||||
|
s += "Unk";
|
||||||
|
break;
|
||||||
|
case OS::Linux:
|
||||||
|
s += "Linux";
|
||||||
|
break;
|
||||||
|
case OS::Windows:
|
||||||
|
s += "Windows";
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::invalid_argument("Invalid OS type!");
|
||||||
|
}
|
||||||
|
s += "\t";
|
||||||
|
s += "Arch: ";
|
||||||
|
|
||||||
|
switch (arch_) {
|
||||||
|
case Arch::Unk:
|
||||||
|
s += "Unk";
|
||||||
|
break;
|
||||||
|
case Arch::X86:
|
||||||
|
s += "X86";
|
||||||
|
break;
|
||||||
|
case Arch::Arm:
|
||||||
|
s += "Arm";
|
||||||
|
break;
|
||||||
|
case Arch::NVGPU:
|
||||||
|
s += "NVGPU";
|
||||||
|
break;
|
||||||
|
case Arch::AMDGPU:
|
||||||
|
s += "AMDGPU";
|
||||||
|
break;
|
||||||
|
case Arch::Ascend:
|
||||||
|
s += "Ascend";
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::invalid_argument("Invalid Arch type!");
|
||||||
|
}
|
||||||
|
s += "\t";
|
||||||
|
s += "BitLen: ";
|
||||||
|
|
||||||
|
switch (bitlen_) {
|
||||||
|
case BitLen::Unk:
|
||||||
|
s += "Unk";
|
||||||
|
break;
|
||||||
|
case BitLen::k32:
|
||||||
|
s += "k32";
|
||||||
|
break;
|
||||||
|
case BitLen::k64:
|
||||||
|
s += "k64";
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::invalid_argument("Invalid target bit length!");
|
||||||
|
}
|
||||||
|
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
|
||||||
|
OS os() const { return os_; }
|
||||||
|
Arch arch() const { return arch_; }
|
||||||
|
BitLen bitlen() const { return bitlen_; }
|
||||||
|
|
||||||
|
static Target DefaultX86Target();
|
||||||
|
static Target DefaultArmTarget();
|
||||||
|
static Target DefaultRocmTarget();
|
||||||
|
static Target DefaultAscendTarget();
|
||||||
|
|
||||||
|
static Target DefaultCUDATarget() {
|
||||||
|
return Target(OS::Linux, Arch::CUDA, BitLen::k64);
|
||||||
|
}
|
||||||
|
|
||||||
|
friend std::ostream& operator<<(std::ostream& os, const Target& target);
|
||||||
|
friend bool operator==(const Target& lhs, const Target& rhs);
|
||||||
|
friend bool operator!=(const Target& lhs, const Target& rhs);
|
||||||
|
|
||||||
|
private:
|
||||||
|
OS os_{OS::Unk};
|
||||||
|
Arch arch_{Arch::Unk};
|
||||||
|
BitLen bitlen_{BitLen::Unk};
|
||||||
|
};
|
||||||
|
|
||||||
|
std::ostream& operator<<(std::ostream& os, const Target& target) {
|
||||||
|
std::cout << target.str() << std::endl;
|
||||||
|
}
|
||||||
|
bool operator==(const Target& lhs, const Target& rhs) {
|
||||||
|
return (lhs.os_ == rhs.os_) && (lhs.arch_ == rhs.arch_) &&
|
||||||
|
(lhs.bitlen_ == rhs.bitlen_);
|
||||||
|
}
|
||||||
|
bool operator!=(const Target& lhs, const Target& rhs) {
|
||||||
|
return (lhs.os_ != rhs.os_) && (lhs.arch_ != rhs.arch_) &&
|
||||||
|
(lhs.bitlen_ != rhs.bitlen_);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace common
|
||||||
|
} // namespace colossalAI
|
|
@ -2,13 +2,13 @@
|
||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
||||||
#include "type_shim.h"
|
#include "../common/micros.h"
|
||||||
#include "include/mp_type_traits.h"
|
#include "../common/mp_type_traits.h"
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
__device__ __forceinline__ T silu_kernel(const T& x) {
|
__device__ __forceinline__ T silu_kernel(const T& x) {
|
||||||
// x * sigmoid(x)
|
// x * sigmoid(x)
|
||||||
using MT = typename infer::dtype::MPTypeTrait<T>::Type;
|
using MT = typename colossalAI::common::MPTypeTrait<T>::Type;
|
||||||
return static_cast<T>((static_cast<MT>(x)) / (static_cast<MT>(1.0f) + expf(static_cast<MT>(-x))));
|
return static_cast<T>((static_cast<MT>(x)) / (static_cast<MT>(1.0f) + expf(static_cast<MT>(-x))));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -17,7 +17,7 @@ __global__ void act_and_mul_kernel(
|
||||||
const scalar_t* __restrict__ ins_data,
|
const scalar_t* __restrict__ ins_data,
|
||||||
scalar_t* __restrict__ outs_data,
|
scalar_t* __restrict__ outs_data,
|
||||||
const int64_t numel) {
|
const int64_t numel) {
|
||||||
using MT = typename infer::dtype::MPTypeTrait<scalar_t>::Type;
|
using MT = typename colossalAI::common::MPTypeTrait<scalar_t>::Type;
|
||||||
|
|
||||||
int64_t idx = static_cast<int64_t>(threadIdx.x) + static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
|
int64_t idx = static_cast<int64_t>(threadIdx.x) + static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
|
||||||
const int64_t grid_size = blockDim.x * gridDim.x;
|
const int64_t grid_size = blockDim.x * gridDim.x;
|
||||||
|
|
|
@ -1,10 +0,0 @@
|
||||||
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h
|
|
||||||
#ifndef TORCH_CHECK
|
|
||||||
#define TORCH_CHECK AT_CHECK
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#ifdef VERSION_GE_1_3
|
|
||||||
#define DATA_PTR data_ptr
|
|
||||||
#else
|
|
||||||
#define DATA_PTR data
|
|
||||||
#endif
|
|
|
@ -2,7 +2,7 @@
|
||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
||||||
#include "type_shim.h"
|
#include "../common/micros.h"
|
||||||
|
|
||||||
template<typename scalar_t>
|
template<typename scalar_t>
|
||||||
__global__ void decode_kv_cache_memcpy_kernel(
|
__global__ void decode_kv_cache_memcpy_kernel(
|
||||||
|
|
|
@ -310,3 +310,90 @@ __inline__ __device__ void blockReduce<ReduceType::kMax, 4>(float *pval) {
|
||||||
}
|
}
|
||||||
warpReduce<ReduceType::kMax, num>(pval);
|
warpReduce<ReduceType::kMax, num>(pval);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__device__ __forceinline__ T reduce_block_into_lanes(
|
||||||
|
T *x, T val, int lanes = 1,
|
||||||
|
bool share_result = false) // lanes is intended to be <= 32.
|
||||||
|
{
|
||||||
|
int tid = threadIdx.x + threadIdx.y * blockDim.x;
|
||||||
|
int blockSize =
|
||||||
|
blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
|
||||||
|
|
||||||
|
if (blockSize >= 64) {
|
||||||
|
x[tid] = val;
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
|
||||||
|
if (tid < i) x[tid] = x[tid] + x[tid + i];
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
T final;
|
||||||
|
|
||||||
|
if (tid < 32) {
|
||||||
|
if (blockSize >= 64)
|
||||||
|
final = x[tid] + x[tid + 32];
|
||||||
|
else
|
||||||
|
final = val;
|
||||||
|
// __SYNCWARP();
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 16; i >= lanes; i >>= 1)
|
||||||
|
final = final + __shfl_down_sync(0xffffffff, final, i);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (share_result) {
|
||||||
|
if (tid < lanes) x[tid] = final; // EpilogueOp
|
||||||
|
// Make sure the smem result is visible to all warps.
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
return final;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__device__ __forceinline__ T reduce_block_into_lanes_max_op(
|
||||||
|
T *x, T val, int lanes = 1,
|
||||||
|
bool share_result = false) // lanes is intended to be <= 32.
|
||||||
|
{
|
||||||
|
int tid = threadIdx.x + threadIdx.y * blockDim.x;
|
||||||
|
int blockSize =
|
||||||
|
blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
|
||||||
|
|
||||||
|
if (blockSize >= 64) {
|
||||||
|
x[tid] = val;
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
|
||||||
|
if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
T final;
|
||||||
|
|
||||||
|
if (tid < 32) {
|
||||||
|
if (blockSize >= 64)
|
||||||
|
final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
|
||||||
|
else
|
||||||
|
final = val;
|
||||||
|
// __SYNCWARP();
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 16; i >= lanes; i >>= 1)
|
||||||
|
final =
|
||||||
|
fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (share_result) {
|
||||||
|
if (tid < lanes) x[tid] = final; // EpilogueOp
|
||||||
|
// Make sure the smem result is visible to all warps.
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
return final;
|
||||||
|
}
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "compat.h"
|
#include "../common/micros.h"
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
|
|
@ -9,7 +9,7 @@
|
||||||
#include "ATen/AccumulateType.h"
|
#include "ATen/AccumulateType.h"
|
||||||
#include "ATen/cuda/CUDAContext.h"
|
#include "ATen/cuda/CUDAContext.h"
|
||||||
#include "ATen/cuda/DeviceUtils.cuh"
|
#include "ATen/cuda/DeviceUtils.cuh"
|
||||||
#include "type_shim.h"
|
#include "../common/micros.h"
|
||||||
|
|
||||||
template <typename U>
|
template <typename U>
|
||||||
__device__ void cuWelfordOnlineSum(const U curr, U& mu, U& sigma2, U& count) {
|
__device__ void cuWelfordOnlineSum(const U curr, U& mu, U& sigma2, U& count) {
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
|
|
||||||
#include "multi_tensor_apply.cuh"
|
#include "multi_tensor_apply.cuh"
|
||||||
#include "type_shim.h"
|
#include "../common/micros.h"
|
||||||
|
|
||||||
#define BLOCK_SIZE 512
|
#define BLOCK_SIZE 512
|
||||||
#define ILP 4
|
#define ILP 4
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
|
||||||
#include "compat.h"
|
#include "../common/micros.h"
|
||||||
|
|
||||||
// #include <iostream>
|
// #include <iostream>
|
||||||
|
|
||||||
|
|
|
@ -11,7 +11,8 @@
|
||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
|
|
||||||
#include "multi_tensor_apply.cuh"
|
#include "multi_tensor_apply.cuh"
|
||||||
#include "type_shim.h"
|
#include "../common/micros.h"
|
||||||
|
#include "include/block_reduce.h"
|
||||||
|
|
||||||
#define BLOCK_SIZE 512
|
#define BLOCK_SIZE 512
|
||||||
#define ILP 4
|
#define ILP 4
|
||||||
|
|
|
@ -10,7 +10,7 @@
|
||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
|
|
||||||
#include "multi_tensor_apply.cuh"
|
#include "multi_tensor_apply.cuh"
|
||||||
#include "type_shim.h"
|
#include "../common/micros.h"
|
||||||
|
|
||||||
#define BLOCK_SIZE 512
|
#define BLOCK_SIZE 512
|
||||||
#define ILP 4
|
#define ILP 4
|
||||||
|
|
|
@ -10,7 +10,7 @@
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
#include "multi_tensor_apply.cuh"
|
#include "multi_tensor_apply.cuh"
|
||||||
#include "type_shim.h"
|
#include "../common/micros.h"
|
||||||
|
|
||||||
#define BLOCK_SIZE 512
|
#define BLOCK_SIZE 512
|
||||||
#define ILP 4
|
#define ILP 4
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
#include "compat.h"
|
#include "../common/micros.h"
|
||||||
#include "multi_tensor_apply.cuh"
|
#include "multi_tensor_apply.cuh"
|
||||||
|
|
||||||
#define BLOCK_SIZE 512
|
#define BLOCK_SIZE 512
|
||||||
|
|
|
@ -10,7 +10,7 @@
|
||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
|
|
||||||
#include "scaled_masked_softmax.h"
|
#include "scaled_masked_softmax.h"
|
||||||
#include "type_shim.h"
|
#include "../common/micros.h"
|
||||||
|
|
||||||
namespace multihead_attn {
|
namespace multihead_attn {
|
||||||
namespace fused_softmax {
|
namespace fused_softmax {
|
||||||
|
|
|
@ -10,7 +10,7 @@
|
||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
|
|
||||||
#include "scaled_upper_triang_masked_softmax.h"
|
#include "scaled_upper_triang_masked_softmax.h"
|
||||||
#include "type_shim.h"
|
#include "../common/micros.h"
|
||||||
|
|
||||||
namespace multihead_attn {
|
namespace multihead_attn {
|
||||||
namespace fused_softmax {
|
namespace fused_softmax {
|
||||||
|
|
|
@ -0,0 +1,36 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
|
namespace colossalAI {
|
||||||
|
namespace cuda {
|
||||||
|
namespace utils {
|
||||||
|
|
||||||
|
GPULaunchConfig GPUGetGPULaunchConfig1D(int64_t numel, int vec_size);
|
||||||
|
|
||||||
|
// TODO(LiuYang): to be implemented
|
||||||
|
GPULaunchConfig GPUGetGPULaunchConfig2D(int64_t numel, int vec_size);
|
||||||
|
|
||||||
|
// TODO(LiuYang): to be implemented
|
||||||
|
GPULaunchConfig GPUGetGPULaunchConfig3D(int64_t numel, int vec_size);
|
||||||
|
|
||||||
|
class GPULaunchConfig {
|
||||||
|
public:
|
||||||
|
GPULaunchConfig(){};
|
||||||
|
GPULaunchConfig(const dim3& block, const dim3& grid)
|
||||||
|
: block_(block), grid_(grid) {}
|
||||||
|
friend GPULaunchConfig GPUGetGPULaunchConfig1D(int64_t numel, int vec_size);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void set_block(const dim3& dim) { block_ = dim; }
|
||||||
|
void set_grid(const dim3& dim) { grid_ = dim; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
dim3 block_(1, 1, 1);
|
||||||
|
dim3 grid_(1, 1, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace utils
|
||||||
|
} // namespace cuda
|
||||||
|
} // namespace colossalAI
|
|
@ -0,0 +1,12 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
|
#define CUDA_CHECK(func) \
|
||||||
|
{ \
|
||||||
|
auto status = func; \
|
||||||
|
if (status != cudaSuccess) { \
|
||||||
|
LOG(FATAL) << "CUDA Error : " << cudaGetErrorString(status); \
|
||||||
|
} \
|
||||||
|
}
|
|
@ -0,0 +1,45 @@
|
||||||
|
#include "nvgpu_dev_info.h"
|
||||||
|
|
||||||
|
#include <array>
|
||||||
|
|
||||||
|
namespace colossalAI {
|
||||||
|
namespace cuda {
|
||||||
|
namespace utils {
|
||||||
|
|
||||||
|
std::array<int, 3> NVGPUDevInfo::GetMaxGridDims() const {
|
||||||
|
std::array<int, 3> ret;
|
||||||
|
ret[0] = prop_->maxGridSize[0];
|
||||||
|
ret[1] = prop_->maxGridSize[1];
|
||||||
|
ret[2] = prop_->maxGridSize[2];
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::array<int, 3> NVGPUDevInfo::GetMaxBlockDims() const {
|
||||||
|
std::array<int, 3> ret;
|
||||||
|
ret[0] = prop_->maxThreadsDim[0];
|
||||||
|
ret[1] = prop_->maxThreadsDim[1];
|
||||||
|
ret[2] = prop_->maxThreadsDim[2];
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::array<int, 2> NVGPUDevInfo::GetCapability() const {
|
||||||
|
std::array<int, 2> ret;
|
||||||
|
ret[0] = prop_.major;
|
||||||
|
ret[1] = prop_.minor;
|
||||||
|
}
|
||||||
|
|
||||||
|
int NVGPUDevInfo::GetMultiProcessorCount() const {
|
||||||
|
return prop_->multiProcessorCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
int NVGPUDevInfo::GetMaxThreadsPerMultiProcessor() const {
|
||||||
|
return prop_->maxThreadsPerMultiProcessor;
|
||||||
|
}
|
||||||
|
|
||||||
|
int NVGPUDevInfo::GetMaxThreadsPerBlock() const {
|
||||||
|
return prop_->maxThreadsPerBlock;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace utils
|
||||||
|
} // namespace cuda
|
||||||
|
} // namespace colossalAI
|
|
@ -0,0 +1,37 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
|
#include <ostream>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "micros.h"
|
||||||
|
#include "target.h"
|
||||||
|
|
||||||
|
namespace colossalAI {
|
||||||
|
namespace cuda {
|
||||||
|
namespace utils {
|
||||||
|
|
||||||
|
class NVGPUDevInfo {
|
||||||
|
public:
|
||||||
|
explicit NVGPUDevInfo(int device_num) : device_num_(device_num) {
|
||||||
|
CUDA_CALL(cudaGetDeviceProperties(prop_, device));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::array<int, 3> GetMaxGridDims() const;
|
||||||
|
std::array<int, 3> GetMaxBlockDims() const;
|
||||||
|
std::array<int, 2> GetCapability() const;
|
||||||
|
int GetMultiProcessorCount() const;
|
||||||
|
int GetMaxThreadsPerMultiProcessor() const;
|
||||||
|
int GetMaxThreadsPerBlock() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
int device_num_;
|
||||||
|
cudaDeviceProp* prop_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace utils
|
||||||
|
} // namespace cuda
|
||||||
|
} // namespace colossalAI
|
Loading…
Reference in New Issue