[NFC] polish colossalai/kernel/cuda_native/csrc/type_shim.h code style (#1260)

pull/1298/head
Sze-qq 2022-07-12 17:38:34 +08:00 committed by Frank Lee
parent f660152c73
commit f8b9aaef47
1 changed file with 216 additions and 260 deletions

View File

@ -1,18 +1,15 @@
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include "compat.h" #include "compat.h"
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \ #define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
switch(TYPE) \ switch (TYPE) { \
{ \ case at::ScalarType::Half: { \
case at::ScalarType::Half: \
{ \
using scalar_t = at::Half; \ using scalar_t = at::Half; \
__VA_ARGS__; \ __VA_ARGS__; \
break; \ break; \
} \ } \
case at::ScalarType::BFloat16: \ case at::ScalarType::BFloat16: { \
{ \
using scalar_t = at::BFloat16; \ using scalar_t = at::BFloat16; \
__VA_ARGS__; \ __VA_ARGS__; \
break; \ break; \
@ -21,30 +18,22 @@
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
} }
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \ #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch(TYPEIN) \ switch (TYPEIN) { \
{ \ case at::ScalarType::Float: { \
case at::ScalarType::Float: \
{ \
using scalar_t_in = float; \ using scalar_t_in = float; \
switch(TYPEOUT) \ switch (TYPEOUT) { \
{ \ case at::ScalarType::Float: { \
case at::ScalarType::Float: \
{ \
using scalar_t_out = float; \ using scalar_t_out = float; \
__VA_ARGS__; \ __VA_ARGS__; \
break; \ break; \
} \ } \
case at::ScalarType::Half: \ case at::ScalarType::Half: { \
{ \
using scalar_t_out = at::Half; \ using scalar_t_out = at::Half; \
__VA_ARGS__; \ __VA_ARGS__; \
break; \ break; \
} \ } \
case at::ScalarType::BFloat16: \ case at::ScalarType::BFloat16: { \
{ \
using scalar_t_out = at::BFloat16; \ using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \ __VA_ARGS__; \
break; \ break; \
@ -54,15 +43,13 @@
} \ } \
break; \ break; \
} \ } \
case at::ScalarType::Half: \ case at::ScalarType::Half: { \
{ \
using scalar_t_in = at::Half; \ using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \ using scalar_t_out = at::Half; \
__VA_ARGS__; \ __VA_ARGS__; \
break; \ break; \
} \ } \
case at::ScalarType::BFloat16: \ case at::ScalarType::BFloat16: { \
{ \
using scalar_t_in = at::BFloat16; \ using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \ using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \ __VA_ARGS__; \
@ -86,16 +73,13 @@
// }; // };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \ #define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) \ switch (TYPE) { \
{ \ case at::ScalarType::Float: { \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \ using scalar_t_##LEVEL = float; \
__VA_ARGS__; \ __VA_ARGS__; \
break; \ break; \
} \ } \
case at::ScalarType::Half: \ case at::ScalarType::Half: { \
{ \
using scalar_t_##LEVEL = at::Half; \ using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \ __VA_ARGS__; \
break; \ break; \
@ -105,22 +89,18 @@
} }
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \ #define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
switch (TYPE) \ switch (TYPE) { \
{ \ case at::ScalarType::Float: { \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \ using scalar_t_##LEVEL = float; \
__VA_ARGS__; \ __VA_ARGS__; \
break; \ break; \
} \ } \
case at::ScalarType::Half: \ case at::ScalarType::Half: { \
{ \
using scalar_t_##LEVEL = at::Half; \ using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \ __VA_ARGS__; \
break; \ break; \
} \ } \
case at::ScalarType::Byte: \ case at::ScalarType::Byte: { \
{ \
using scalar_t_##LEVEL = uint8_t; \ using scalar_t_##LEVEL = uint8_t; \
__VA_ARGS__; \ __VA_ARGS__; \
break; \ break; \
@ -130,22 +110,18 @@
} }
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \ #define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) \ switch (TYPE) { \
{ \ case at::ScalarType::Double: { \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \ using scalar_t_##LEVEL = double; \
__VA_ARGS__; \ __VA_ARGS__; \
break; \ break; \
} \ } \
case at::ScalarType::Float: \ case at::ScalarType::Float: { \
{ \
using scalar_t_##LEVEL = float; \ using scalar_t_##LEVEL = float; \
__VA_ARGS__; \ __VA_ARGS__; \
break; \ break; \
} \ } \
case at::ScalarType::Half: \ case at::ScalarType::Half: { \
{ \
using scalar_t_##LEVEL = at::Half; \ using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \ __VA_ARGS__; \
break; \ break; \
@ -155,16 +131,13 @@
} }
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \ #define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) \ switch (TYPE) { \
{ \ case at::ScalarType::Double: { \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \ using scalar_t_##LEVEL = double; \
__VA_ARGS__; \ __VA_ARGS__; \
break; \ break; \
} \ } \
case at::ScalarType::Float: \ case at::ScalarType::Float: { \
{ \
using scalar_t_##LEVEL = float; \ using scalar_t_##LEVEL = float; \
__VA_ARGS__; \ __VA_ARGS__; \
break; \ break; \
@ -174,62 +147,52 @@
} }
// Dispatches on a pair of scalar types and expands the given statement(s)
// with two type aliases in scope.
//
//   GTYPE, PTYPE  at::ScalarType values compared against Float and Half; all
//                 four Float/Half combinations are handled. (Presumably the
//                 "g" and "p" stand for gradient and parameter -- confirm
//                 against callers.) Each argument is evaluated more than once.
//   LEVEL         token pasted onto the alias names, so the expanded code sees
//                 g_scalar_t_<LEVEL> and p_scalar_t_<LEVEL>.
//   NAME          stringified and used as the prefix of the error message.
//   ...           statement(s) expanded inside the selected branch.
//
// Any other type combination is reported through AT_ERROR.
//
// The expansion is a bare if / else-if chain and is deliberately NOT wrapped
// in do { } while (0), so existing call sites that omit a trailing semicolon
// keep compiling. The final line carries no continuation backslash, so the
// macro cannot swallow the line that follows it.
#define DISPATCH_FLOAT_AND_HALF_FOR_G_P(GTYPE, PTYPE, LEVEL, NAME, ...)      \
  if ((GTYPE) == at::ScalarType::Float &&                                    \
      (PTYPE) == at::ScalarType::Float) {                                    \
    using g_scalar_t_##LEVEL = float;                                        \
    using p_scalar_t_##LEVEL = float;                                        \
    __VA_ARGS__;                                                             \
  } else if ((GTYPE) == at::ScalarType::Float &&                             \
             (PTYPE) == at::ScalarType::Half) {                              \
    using g_scalar_t_##LEVEL = float;                                        \
    using p_scalar_t_##LEVEL = at::Half;                                     \
    __VA_ARGS__;                                                             \
  } else if ((GTYPE) == at::ScalarType::Half &&                              \
             (PTYPE) == at::ScalarType::Float) {                             \
    using g_scalar_t_##LEVEL = at::Half;                                     \
    using p_scalar_t_##LEVEL = float;                                        \
    __VA_ARGS__;                                                             \
  } else if ((GTYPE) == at::ScalarType::Half &&                              \
             (PTYPE) == at::ScalarType::Half) {                              \
    using g_scalar_t_##LEVEL = at::Half;                                     \
    using p_scalar_t_##LEVEL = at::Half;                                     \
    __VA_ARGS__;                                                             \
  } else {                                                                   \
    AT_ERROR(#NAME, " not implemented for '", toString(GTYPE), ", ",         \
             toString(PTYPE), "'");                                          \
  }
template <typename T> template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes(T *x, __device__ __forceinline__ T reduce_block_into_lanes(
T val, T *x, T val, int lanes = 1,
int lanes = 1,
bool share_result = false) // lanes is intended to be <= 32. bool share_result = false) // lanes is intended to be <= 32.
{ {
int tid = threadIdx.x + threadIdx.y * blockDim.x; int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32. int blockSize =
blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64) if (blockSize >= 64) {
{
x[tid] = val; x[tid] = val;
__syncthreads(); __syncthreads();
} }
#pragma unroll #pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1) for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
{ if (tid < i) x[tid] = x[tid] + x[tid + i];
if (tid < i)
x[tid] = x[tid] + x[tid + i];
__syncthreads(); __syncthreads();
} }
T final; T final;
if (tid < 32) if (tid < 32) {
{
if (blockSize >= 64) if (blockSize >= 64)
final = x[tid] + x[tid + 32]; final = x[tid] + x[tid + 32];
else else
@ -241,10 +204,8 @@ __device__ __forceinline__ T reduce_block_into_lanes(T *x,
final = final + __shfl_down_sync(0xffffffff, final, i); final = final + __shfl_down_sync(0xffffffff, final, i);
} }
if (share_result) if (share_result) {
{ if (tid < lanes) x[tid] = final; // EpilogueOp
if (tid < lanes)
x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps. // Make sure the smem result is visible to all warps.
__syncthreads(); __syncthreads();
} }
@ -253,32 +214,28 @@ __device__ __forceinline__ T reduce_block_into_lanes(T *x,
} }
template <typename T> template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes_max_op(T *x, __device__ __forceinline__ T reduce_block_into_lanes_max_op(
T val, T *x, T val, int lanes = 1,
int lanes = 1,
bool share_result = false) // lanes is intended to be <= 32. bool share_result = false) // lanes is intended to be <= 32.
{ {
int tid = threadIdx.x + threadIdx.y * blockDim.x; int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32. int blockSize =
blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64) if (blockSize >= 64) {
{
x[tid] = val; x[tid] = val;
__syncthreads(); __syncthreads();
} }
#pragma unroll #pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1) for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
{ if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
if (tid < i)
x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
__syncthreads(); __syncthreads();
} }
T final; T final;
if (tid < 32) if (tid < 32) {
{
if (blockSize >= 64) if (blockSize >= 64)
final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32])); final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
else else
@ -287,13 +244,12 @@ __device__ __forceinline__ T reduce_block_into_lanes_max_op(T *x,
#pragma unroll #pragma unroll
for (int i = 16; i >= lanes; i >>= 1) for (int i = 16; i >= lanes; i >>= 1)
final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i))); final =
fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
} }
if (share_result) if (share_result) {
{ if (tid < lanes) x[tid] = final; // EpilogueOp
if (tid < lanes)
x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps. // Make sure the smem result is visible to all warps.
__syncthreads(); __syncthreads();
} }