[NFC] polish colossalai/kernel/cuda_native/csrc/type_shim.h code style (#1260)

pull/1298/head
Sze-qq 2022-07-12 17:38:34 +08:00 committed by Frank Lee
parent f660152c73
commit f8b9aaef47
1 changed file with 216 additions and 260 deletions

View File

@ -1,76 +1,63 @@
#include <ATen/ATen.h>

#include "compat.h"
// Dispatches NAME's body (__VA_ARGS__) with `scalar_t` bound to the C++ type
// matching the runtime dtype TYPE. Supports at::Half and at::BFloat16 only;
// any other dtype raises via AT_ERROR. Mirrors the AT_DISPATCH_* macros.
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...)                      \
  switch (TYPE) {                                                      \
    case at::ScalarType::Half: {                                       \
      using scalar_t = at::Half;                                       \
      __VA_ARGS__;                                                     \
      break;                                                           \
    }                                                                  \
    case at::ScalarType::BFloat16: {                                   \
      using scalar_t = at::BFloat16;                                   \
      __VA_ARGS__;                                                     \
      break;                                                           \
    }                                                                  \
    default:                                                           \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
  }
// Two-level dispatch on an input dtype (TYPEIN) and an output dtype (TYPEOUT),
// binding `scalar_t_in` / `scalar_t_out` for the body (__VA_ARGS__).
// Supported combinations:
//   Float -> {Float, Half, BFloat16}
//   Half  -> Half
//   BFloat16 -> BFloat16
// Anything else raises via AT_ERROR (inner switch reports TYPEOUT, outer
// reports TYPEIN).
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
  switch (TYPEIN) {                                                            \
    case at::ScalarType::Float: {                                              \
      using scalar_t_in = float;                                               \
      switch (TYPEOUT) {                                                       \
        case at::ScalarType::Float: {                                          \
          using scalar_t_out = float;                                          \
          __VA_ARGS__;                                                         \
          break;                                                               \
        }                                                                      \
        case at::ScalarType::Half: {                                           \
          using scalar_t_out = at::Half;                                       \
          __VA_ARGS__;                                                         \
          break;                                                               \
        }                                                                      \
        case at::ScalarType::BFloat16: {                                       \
          using scalar_t_out = at::BFloat16;                                   \
          __VA_ARGS__;                                                         \
          break;                                                               \
        }                                                                      \
        default:                                                               \
          AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'");   \
      }                                                                        \
      break;                                                                   \
    }                                                                          \
    case at::ScalarType::Half: {                                               \
      using scalar_t_in = at::Half;                                            \
      using scalar_t_out = at::Half;                                           \
      __VA_ARGS__;                                                             \
      break;                                                                   \
    }                                                                          \
    case at::ScalarType::BFloat16: {                                           \
      using scalar_t_in = at::BFloat16;                                        \
      using scalar_t_out = at::BFloat16;                                       \
      __VA_ARGS__;                                                             \
      break;                                                                   \
    }                                                                          \
    default:                                                                   \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'");        \
  }
// Forward/backward compatibility hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
// Dispatches NAME's body with `scalar_t_<LEVEL>` bound to float or at::Half
// depending on the runtime dtype TYPE. LEVEL is pasted into the alias name so
// the macro can be nested for multiple tensors (e.g. scalar_t_0, scalar_t_1).
// Unsupported dtypes raise via AT_ERROR.
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...)                \
  switch (TYPE) {                                                      \
    case at::ScalarType::Float: {                                      \
      using scalar_t_##LEVEL = float;                                  \
      __VA_ARGS__;                                                     \
      break;                                                           \
    }                                                                  \
    case at::ScalarType::Half: {                                       \
      using scalar_t_##LEVEL = at::Half;                               \
      __VA_ARGS__;                                                     \
      break;                                                           \
    }                                                                  \
    default:                                                           \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
  }
// Like DISPATCH_FLOAT_AND_HALF but additionally accepts at::ScalarType::Byte,
// binding `scalar_t_<LEVEL>` to uint8_t. Unsupported dtypes raise via
// AT_ERROR.
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...)           \
  switch (TYPE) {                                                      \
    case at::ScalarType::Float: {                                      \
      using scalar_t_##LEVEL = float;                                  \
      __VA_ARGS__;                                                     \
      break;                                                           \
    }                                                                  \
    case at::ScalarType::Half: {                                       \
      using scalar_t_##LEVEL = at::Half;                               \
      __VA_ARGS__;                                                     \
      break;                                                           \
    }                                                                  \
    case at::ScalarType::Byte: {                                       \
      using scalar_t_##LEVEL = uint8_t;                                \
      __VA_ARGS__;                                                     \
      break;                                                           \
    }                                                                  \
    default:                                                           \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
  }
// Dispatches NAME's body with `scalar_t_<LEVEL>` bound to double, float, or
// at::Half depending on the runtime dtype TYPE. Unsupported dtypes raise via
// AT_ERROR.
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...)         \
  switch (TYPE) {                                                      \
    case at::ScalarType::Double: {                                     \
      using scalar_t_##LEVEL = double;                                 \
      __VA_ARGS__;                                                     \
      break;                                                           \
    }                                                                  \
    case at::ScalarType::Float: {                                      \
      using scalar_t_##LEVEL = float;                                  \
      __VA_ARGS__;                                                     \
      break;                                                           \
    }                                                                  \
    case at::ScalarType::Half: {                                       \
      using scalar_t_##LEVEL = at::Half;                               \
      __VA_ARGS__;                                                     \
      break;                                                           \
    }                                                                  \
    default:                                                           \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
  }
// Dispatches NAME's body with `scalar_t_<LEVEL>` bound to double or float
// depending on the runtime dtype TYPE. Unsupported dtypes raise via AT_ERROR.
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...)              \
  switch (TYPE) {                                                      \
    case at::ScalarType::Double: {                                     \
      using scalar_t_##LEVEL = double;                                 \
      __VA_ARGS__;                                                     \
      break;                                                           \
    }                                                                  \
    case at::ScalarType::Float: {                                      \
      using scalar_t_##LEVEL = float;                                  \
      __VA_ARGS__;                                                     \
      break;                                                           \
    }                                                                  \
    default:                                                           \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
  }
// Joint dispatch on a gradient dtype (GTYPE) and a parameter dtype (PTYPE),
// binding `g_scalar_t_<LEVEL>` / `p_scalar_t_<LEVEL>` for the body. All four
// {float, at::Half} x {float, at::Half} combinations are supported; anything
// else raises via AT_ERROR.
// Fix vs. previous revision: the error message was missing the leading space
// before "not implemented" and ran GTYPE/PTYPE together with no separator.
#define DISPATCH_FLOAT_AND_HALF_FOR_G_P(GTYPE, PTYPE, LEVEL, NAME, ...)        \
  if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Float) {      \
    using g_scalar_t_##LEVEL = float;                                          \
    using p_scalar_t_##LEVEL = float;                                          \
    __VA_ARGS__;                                                               \
  } else if (GTYPE == at::ScalarType::Float &&                                 \
             PTYPE == at::ScalarType::Half) {                                  \
    using g_scalar_t_##LEVEL = float;                                          \
    using p_scalar_t_##LEVEL = at::Half;                                       \
    __VA_ARGS__;                                                               \
  } else if (GTYPE == at::ScalarType::Half &&                                  \
             PTYPE == at::ScalarType::Float) {                                 \
    using g_scalar_t_##LEVEL = at::Half;                                       \
    using p_scalar_t_##LEVEL = float;                                          \
    __VA_ARGS__;                                                               \
  } else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Half) { \
    using g_scalar_t_##LEVEL = at::Half;                                       \
    using p_scalar_t_##LEVEL = at::Half;                                       \
    __VA_ARGS__;                                                               \
  } else {                                                                     \
    AT_ERROR(#NAME, " not implemented for '", toString(GTYPE), ", ",           \
             toString(PTYPE), "'");                                            \
  }
// Block-wide sum reduction that leaves the result replicated in the first
// `lanes` lanes of warp 0 (lanes is intended to be <= 32).
//
// Preconditions (from the in-code comments and the indexing pattern):
//   * x points to shared scratch with at least blockDim.x*blockDim.y elements;
//   * blockSize (= blockDim.x*blockDim.y) is intended to be a multiple of 32;
//   * every thread of the block must call this (it contains __syncthreads()).
//
// Returns the block sum — valid only in threads with tid < lanes unless
// share_result is true, in which case the result is also published to
// x[0..lanes) and visible to the whole block after the final barrier.
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes(
    T *x, T val, int lanes = 1,
    bool share_result = false)  // lanes is intended to be <= 32.
{
  int tid = threadIdx.x + threadIdx.y * blockDim.x;
  int blockSize =
      blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.

  if (blockSize >= 64) {
    x[tid] = val;
    __syncthreads();
  }

  // Tree-reduce in shared memory down to the final 64 elements.
#pragma unroll
  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
    if (tid < i) x[tid] = x[tid] + x[tid + i];
    __syncthreads();
  }

  T final;

  if (tid < 32) {
    // Fold the remaining 64 values into warp 0, then shuffle-reduce until
    // each of the first `lanes` lanes holds the total.
    if (blockSize >= 64)
      final = x[tid] + x[tid + 32];
    else
      final = val;
      // __SYNCWARP();
#pragma unroll
    for (int i = 16; i >= lanes; i >>= 1)
      final = final + __shfl_down_sync(0xffffffff, final, i);
  }

  if (share_result) {
    if (tid < lanes) x[tid] = final;  // EpilogueOp
    // Make sure the smem result is visible to all warps.
    __syncthreads();
  }

  return final;
}
// Block-wide max-of-absolute-value reduction (fmaxf/fabsf), structured exactly
// like reduce_block_into_lanes: the result lands in the first `lanes` lanes of
// warp 0 (lanes is intended to be <= 32).
//
// Preconditions (from the in-code comments and the indexing pattern):
//   * x points to shared scratch with at least blockDim.x*blockDim.y elements;
//   * blockSize (= blockDim.x*blockDim.y) is intended to be a multiple of 32;
//   * every thread of the block must call this (it contains __syncthreads()).
//
// Returns max(|.|) over the block — valid only in threads with tid < lanes
// unless share_result is true, which also publishes it to x[0..lanes).
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes_max_op(
    T *x, T val, int lanes = 1,
    bool share_result = false)  // lanes is intended to be <= 32.
{
  int tid = threadIdx.x + threadIdx.y * blockDim.x;
  int blockSize =
      blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.

  if (blockSize >= 64) {
    x[tid] = val;
    __syncthreads();
  }

  // Tree-reduce |values| in shared memory down to the final 64 elements.
#pragma unroll
  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
    if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
    __syncthreads();
  }

  T final;

  if (tid < 32) {
    // Fold the remaining 64 values into warp 0, then shuffle-reduce until
    // each of the first `lanes` lanes holds the maximum magnitude.
    if (blockSize >= 64)
      final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
    else
      final = val;
      // __SYNCWARP();
#pragma unroll
    for (int i = 16; i >= lanes; i >>= 1)
      final =
          fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
  }

  if (share_result) {
    if (tid < lanes) x[tid] = final;  // EpilogueOp
    // Make sure the smem result is visible to all warps.
    __syncthreads();
  }

  return final;
}