mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish colossalai/kernel/cuda_native/csrc/type_shim.h code style (#1260)
parent
f660152c73
commit
f8b9aaef47
|
@ -1,76 +1,63 @@
|
||||||
#include <ATen/ATen.h>
|
#include <ATen/ATen.h>
|
||||||
|
|
||||||
#include "compat.h"
|
#include "compat.h"
|
||||||
|
|
||||||
|
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
|
||||||
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
|
switch (TYPE) { \
|
||||||
switch(TYPE) \
|
case at::ScalarType::Half: { \
|
||||||
{ \
|
using scalar_t = at::Half; \
|
||||||
case at::ScalarType::Half: \
|
__VA_ARGS__; \
|
||||||
{ \
|
break; \
|
||||||
using scalar_t = at::Half; \
|
} \
|
||||||
__VA_ARGS__; \
|
case at::ScalarType::BFloat16: { \
|
||||||
break; \
|
using scalar_t = at::BFloat16; \
|
||||||
} \
|
__VA_ARGS__; \
|
||||||
case at::ScalarType::BFloat16: \
|
break; \
|
||||||
{ \
|
} \
|
||||||
using scalar_t = at::BFloat16; \
|
default: \
|
||||||
__VA_ARGS__; \
|
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||||
break; \
|
}
|
||||||
} \
|
|
||||||
default: \
|
|
||||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
|
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
|
||||||
switch(TYPEIN) \
|
switch (TYPEIN) { \
|
||||||
{ \
|
case at::ScalarType::Float: { \
|
||||||
case at::ScalarType::Float: \
|
using scalar_t_in = float; \
|
||||||
{ \
|
switch (TYPEOUT) { \
|
||||||
using scalar_t_in = float; \
|
case at::ScalarType::Float: { \
|
||||||
switch(TYPEOUT) \
|
using scalar_t_out = float; \
|
||||||
{ \
|
__VA_ARGS__; \
|
||||||
case at::ScalarType::Float: \
|
break; \
|
||||||
{ \
|
} \
|
||||||
using scalar_t_out = float; \
|
case at::ScalarType::Half: { \
|
||||||
__VA_ARGS__; \
|
using scalar_t_out = at::Half; \
|
||||||
break; \
|
__VA_ARGS__; \
|
||||||
} \
|
break; \
|
||||||
case at::ScalarType::Half: \
|
} \
|
||||||
{ \
|
case at::ScalarType::BFloat16: { \
|
||||||
using scalar_t_out = at::Half; \
|
using scalar_t_out = at::BFloat16; \
|
||||||
__VA_ARGS__; \
|
__VA_ARGS__; \
|
||||||
break; \
|
break; \
|
||||||
} \
|
} \
|
||||||
case at::ScalarType::BFloat16: \
|
default: \
|
||||||
{ \
|
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
|
||||||
using scalar_t_out = at::BFloat16; \
|
} \
|
||||||
__VA_ARGS__; \
|
break; \
|
||||||
break; \
|
} \
|
||||||
} \
|
case at::ScalarType::Half: { \
|
||||||
default: \
|
using scalar_t_in = at::Half; \
|
||||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
|
using scalar_t_out = at::Half; \
|
||||||
} \
|
__VA_ARGS__; \
|
||||||
break; \
|
break; \
|
||||||
} \
|
} \
|
||||||
case at::ScalarType::Half: \
|
case at::ScalarType::BFloat16: { \
|
||||||
{ \
|
using scalar_t_in = at::BFloat16; \
|
||||||
using scalar_t_in = at::Half; \
|
using scalar_t_out = at::BFloat16; \
|
||||||
using scalar_t_out = at::Half; \
|
__VA_ARGS__; \
|
||||||
__VA_ARGS__; \
|
break; \
|
||||||
break; \
|
} \
|
||||||
} \
|
default: \
|
||||||
case at::ScalarType::BFloat16: \
|
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
|
||||||
{ \
|
}
|
||||||
using scalar_t_in = at::BFloat16; \
|
|
||||||
using scalar_t_out = at::BFloat16; \
|
|
||||||
__VA_ARGS__; \
|
|
||||||
break; \
|
|
||||||
} \
|
|
||||||
default: \
|
|
||||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward/backward compatibility hack around
|
// Forward/backward compatibility hack around
|
||||||
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
|
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
|
||||||
|
@ -81,222 +68,191 @@
|
||||||
// TypeShim(const at::Type& type) : payload(type) {}
|
// TypeShim(const at::Type& type) : payload(type) {}
|
||||||
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
|
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
|
||||||
// operator const at::Type&(){ return payload; };
|
// operator const at::Type&(){ return payload; };
|
||||||
// // Enable dispatch switch statements to take *this directly for post-3aeb78
|
// // Enable dispatch switch statements to take *this directly for post-3aeb78
|
||||||
// //operator at::ScalarType(){ return payload.; };
|
// //operator at::ScalarType(){ return payload.; };
|
||||||
// };
|
// };
|
||||||
|
|
||||||
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
|
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
|
||||||
switch (TYPE) \
|
switch (TYPE) { \
|
||||||
{ \
|
case at::ScalarType::Float: { \
|
||||||
case at::ScalarType::Float: \
|
using scalar_t_##LEVEL = float; \
|
||||||
{ \
|
__VA_ARGS__; \
|
||||||
using scalar_t_##LEVEL = float; \
|
break; \
|
||||||
__VA_ARGS__; \
|
} \
|
||||||
break; \
|
case at::ScalarType::Half: { \
|
||||||
} \
|
using scalar_t_##LEVEL = at::Half; \
|
||||||
case at::ScalarType::Half: \
|
__VA_ARGS__; \
|
||||||
{ \
|
break; \
|
||||||
using scalar_t_##LEVEL = at::Half; \
|
} \
|
||||||
__VA_ARGS__; \
|
default: \
|
||||||
break; \
|
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||||
} \
|
}
|
||||||
default: \
|
|
||||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
|
||||||
}
|
|
||||||
|
|
||||||
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
|
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
|
||||||
switch (TYPE) \
|
switch (TYPE) { \
|
||||||
{ \
|
case at::ScalarType::Float: { \
|
||||||
case at::ScalarType::Float: \
|
using scalar_t_##LEVEL = float; \
|
||||||
{ \
|
__VA_ARGS__; \
|
||||||
using scalar_t_##LEVEL = float; \
|
break; \
|
||||||
__VA_ARGS__; \
|
} \
|
||||||
break; \
|
case at::ScalarType::Half: { \
|
||||||
} \
|
using scalar_t_##LEVEL = at::Half; \
|
||||||
case at::ScalarType::Half: \
|
__VA_ARGS__; \
|
||||||
{ \
|
break; \
|
||||||
using scalar_t_##LEVEL = at::Half; \
|
} \
|
||||||
__VA_ARGS__; \
|
case at::ScalarType::Byte: { \
|
||||||
break; \
|
using scalar_t_##LEVEL = uint8_t; \
|
||||||
} \
|
__VA_ARGS__; \
|
||||||
case at::ScalarType::Byte: \
|
break; \
|
||||||
{ \
|
} \
|
||||||
using scalar_t_##LEVEL = uint8_t; \
|
default: \
|
||||||
__VA_ARGS__; \
|
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||||
break; \
|
}
|
||||||
} \
|
|
||||||
default: \
|
|
||||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
|
||||||
}
|
|
||||||
|
|
||||||
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
|
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
|
||||||
switch (TYPE) \
|
switch (TYPE) { \
|
||||||
{ \
|
case at::ScalarType::Double: { \
|
||||||
case at::ScalarType::Double: \
|
using scalar_t_##LEVEL = double; \
|
||||||
{ \
|
__VA_ARGS__; \
|
||||||
using scalar_t_##LEVEL = double; \
|
break; \
|
||||||
__VA_ARGS__; \
|
} \
|
||||||
break; \
|
case at::ScalarType::Float: { \
|
||||||
} \
|
using scalar_t_##LEVEL = float; \
|
||||||
case at::ScalarType::Float: \
|
__VA_ARGS__; \
|
||||||
{ \
|
break; \
|
||||||
using scalar_t_##LEVEL = float; \
|
} \
|
||||||
__VA_ARGS__; \
|
case at::ScalarType::Half: { \
|
||||||
break; \
|
using scalar_t_##LEVEL = at::Half; \
|
||||||
} \
|
__VA_ARGS__; \
|
||||||
case at::ScalarType::Half: \
|
break; \
|
||||||
{ \
|
} \
|
||||||
using scalar_t_##LEVEL = at::Half; \
|
default: \
|
||||||
__VA_ARGS__; \
|
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||||
break; \
|
}
|
||||||
} \
|
|
||||||
default: \
|
|
||||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
|
||||||
}
|
|
||||||
|
|
||||||
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
|
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
|
||||||
switch (TYPE) \
|
switch (TYPE) { \
|
||||||
{ \
|
case at::ScalarType::Double: { \
|
||||||
case at::ScalarType::Double: \
|
using scalar_t_##LEVEL = double; \
|
||||||
{ \
|
__VA_ARGS__; \
|
||||||
using scalar_t_##LEVEL = double; \
|
break; \
|
||||||
__VA_ARGS__; \
|
} \
|
||||||
break; \
|
case at::ScalarType::Float: { \
|
||||||
} \
|
using scalar_t_##LEVEL = float; \
|
||||||
case at::ScalarType::Float: \
|
__VA_ARGS__; \
|
||||||
{ \
|
break; \
|
||||||
using scalar_t_##LEVEL = float; \
|
} \
|
||||||
__VA_ARGS__; \
|
default: \
|
||||||
break; \
|
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||||
} \
|
}
|
||||||
default: \
|
|
||||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
|
||||||
}
|
|
||||||
|
|
||||||
#define DISPATCH_FLOAT_AND_HALF_FOR_G_P(GTYPE, PTYPE, LEVEL, NAME, ...) \
|
#define DISPATCH_FLOAT_AND_HALF_FOR_G_P(GTYPE, PTYPE, LEVEL, NAME, ...) \
|
||||||
if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Float) \
|
if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Float) { \
|
||||||
{ \
|
using g_scalar_t_##LEVEL = float; \
|
||||||
using g_scalar_t_##LEVEL = float; \
|
using p_scalar_t_##LEVEL = float; \
|
||||||
using p_scalar_t_##LEVEL = float; \
|
__VA_ARGS__; \
|
||||||
__VA_ARGS__; \
|
} else if (GTYPE == at::ScalarType::Float && \
|
||||||
} \
|
PTYPE == at::ScalarType::Half) { \
|
||||||
else if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Half) \
|
using g_scalar_t_##LEVEL = float; \
|
||||||
{ \
|
using p_scalar_t_##LEVEL = at::Half; \
|
||||||
using g_scalar_t_##LEVEL = float; \
|
__VA_ARGS__; \
|
||||||
using p_scalar_t_##LEVEL = at::Half; \
|
} else if (GTYPE == at::ScalarType::Half && \
|
||||||
__VA_ARGS__; \
|
PTYPE == at::ScalarType::Float) { \
|
||||||
} \
|
using g_scalar_t_##LEVEL = at::Half; \
|
||||||
else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Float) \
|
using p_scalar_t_##LEVEL = float; \
|
||||||
{ \
|
__VA_ARGS__; \
|
||||||
using g_scalar_t_##LEVEL = at::Half; \
|
} else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Half) { \
|
||||||
using p_scalar_t_##LEVEL = float; \
|
using g_scalar_t_##LEVEL = at::Half; \
|
||||||
__VA_ARGS__; \
|
using p_scalar_t_##LEVEL = at::Half; \
|
||||||
} \
|
__VA_ARGS__; \
|
||||||
else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Half) \
|
} else { \
|
||||||
{ \
|
AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), \
|
||||||
using g_scalar_t_##LEVEL = at::Half; \
|
"'"); \
|
||||||
using p_scalar_t_##LEVEL = at::Half; \
|
}
|
||||||
__VA_ARGS__; \
|
|
||||||
} \
|
|
||||||
else \
|
|
||||||
{ \
|
|
||||||
AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), "'"); \
|
|
||||||
} \
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ __forceinline__ T reduce_block_into_lanes(T *x,
|
__device__ __forceinline__ T reduce_block_into_lanes(
|
||||||
T val,
|
T *x, T val, int lanes = 1,
|
||||||
int lanes = 1,
|
bool share_result = false) // lanes is intended to be <= 32.
|
||||||
bool share_result = false) // lanes is intended to be <= 32.
|
|
||||||
{
|
{
|
||||||
int tid = threadIdx.x + threadIdx.y * blockDim.x;
|
int tid = threadIdx.x + threadIdx.y * blockDim.x;
|
||||||
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
|
int blockSize =
|
||||||
|
blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
|
||||||
|
|
||||||
|
if (blockSize >= 64) {
|
||||||
|
x[tid] = val;
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
|
||||||
|
if (tid < i) x[tid] = x[tid] + x[tid + i];
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
T final;
|
||||||
|
|
||||||
|
if (tid < 32) {
|
||||||
if (blockSize >= 64)
|
if (blockSize >= 64)
|
||||||
{
|
final = x[tid] + x[tid + 32];
|
||||||
x[tid] = val;
|
else
|
||||||
__syncthreads();
|
final = val;
|
||||||
}
|
// __SYNCWARP();
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = (blockSize >> 1); i >= 64; i >>= 1)
|
for (int i = 16; i >= lanes; i >>= 1)
|
||||||
{
|
final = final + __shfl_down_sync(0xffffffff, final, i);
|
||||||
if (tid < i)
|
}
|
||||||
x[tid] = x[tid] + x[tid + i];
|
|
||||||
__syncthreads();
|
|
||||||
}
|
|
||||||
|
|
||||||
T final;
|
if (share_result) {
|
||||||
|
if (tid < lanes) x[tid] = final; // EpilogueOp
|
||||||
|
// Make sure the smem result is visible to all warps.
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
if (tid < 32)
|
return final;
|
||||||
{
|
|
||||||
if (blockSize >= 64)
|
|
||||||
final = x[tid] + x[tid + 32];
|
|
||||||
else
|
|
||||||
final = val;
|
|
||||||
// __SYNCWARP();
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 16; i >= lanes; i >>= 1)
|
|
||||||
final = final + __shfl_down_sync(0xffffffff, final, i);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (share_result)
|
|
||||||
{
|
|
||||||
if (tid < lanes)
|
|
||||||
x[tid] = final; // EpilogueOp
|
|
||||||
// Make sure the smem result is visible to all warps.
|
|
||||||
__syncthreads();
|
|
||||||
}
|
|
||||||
|
|
||||||
return final;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ __forceinline__ T reduce_block_into_lanes_max_op(T *x,
|
__device__ __forceinline__ T reduce_block_into_lanes_max_op(
|
||||||
T val,
|
T *x, T val, int lanes = 1,
|
||||||
int lanes = 1,
|
bool share_result = false) // lanes is intended to be <= 32.
|
||||||
bool share_result = false) // lanes is intended to be <= 32.
|
|
||||||
{
|
{
|
||||||
int tid = threadIdx.x + threadIdx.y * blockDim.x;
|
int tid = threadIdx.x + threadIdx.y * blockDim.x;
|
||||||
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
|
int blockSize =
|
||||||
|
blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
|
||||||
|
|
||||||
|
if (blockSize >= 64) {
|
||||||
|
x[tid] = val;
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
|
||||||
|
if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
T final;
|
||||||
|
|
||||||
|
if (tid < 32) {
|
||||||
if (blockSize >= 64)
|
if (blockSize >= 64)
|
||||||
{
|
final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
|
||||||
x[tid] = val;
|
else
|
||||||
__syncthreads();
|
final = val;
|
||||||
}
|
// __SYNCWARP();
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = (blockSize >> 1); i >= 64; i >>= 1)
|
for (int i = 16; i >= lanes; i >>= 1)
|
||||||
{
|
final =
|
||||||
if (tid < i)
|
fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
|
||||||
x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
|
}
|
||||||
__syncthreads();
|
|
||||||
}
|
|
||||||
|
|
||||||
T final;
|
if (share_result) {
|
||||||
|
if (tid < lanes) x[tid] = final; // EpilogueOp
|
||||||
|
// Make sure the smem result is visible to all warps.
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
if (tid < 32)
|
return final;
|
||||||
{
|
|
||||||
if (blockSize >= 64)
|
|
||||||
final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
|
|
||||||
else
|
|
||||||
final = val;
|
|
||||||
// __SYNCWARP();
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 16; i >= lanes; i >>= 1)
|
|
||||||
final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (share_result)
|
|
||||||
{
|
|
||||||
if (tid < lanes)
|
|
||||||
x[tid] = final; // EpilogueOp
|
|
||||||
// Make sure the smem result is visible to all warps.
|
|
||||||
__syncthreads();
|
|
||||||
}
|
|
||||||
|
|
||||||
return final;
|
|
||||||
}
|
}
|
Loading…
Reference in New Issue