|
|
|
@ -9,7 +9,15 @@
|
|
|
|
|
|
|
|
|
|
#include <ATen/ATen.h> |
|
|
|
|
|
|
|
|
|
#include "compat.h" |
|
|
|
|
#ifndef TORCH_CHECK |
|
|
|
|
#define TORCH_CHECK AT_CHECK |
|
|
|
|
#endif |
|
|
|
|
|
|
|
|
|
#ifdef VERSION_GE_1_3 |
|
|
|
|
#define DATA_PTR data_ptr |
|
|
|
|
#else |
|
|
|
|
#define DATA_PTR data |
|
|
|
|
#endif |
|
|
|
|
|
|
|
|
|
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \ |
|
|
|
|
switch (TYPE) { \
|
|
|
|
@ -214,90 +222,3 @@
|
|
|
|
|
AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), \
|
|
|
|
|
"'"); \
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
template <typename T> |
|
|
|
|
__device__ __forceinline__ T reduce_block_into_lanes( |
|
|
|
|
T *x, T val, int lanes = 1, |
|
|
|
|
bool share_result = false) // lanes is intended to be <= 32.
|
|
|
|
|
{ |
|
|
|
|
int tid = threadIdx.x + threadIdx.y * blockDim.x; |
|
|
|
|
int blockSize = |
|
|
|
|
blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
|
|
|
|
|
|
|
|
|
|
if (blockSize >= 64) { |
|
|
|
|
x[tid] = val; |
|
|
|
|
__syncthreads(); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
#pragma unroll |
|
|
|
|
for (int i = (blockSize >> 1); i >= 64; i >>= 1) { |
|
|
|
|
if (tid < i) x[tid] = x[tid] + x[tid + i]; |
|
|
|
|
__syncthreads(); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
T final; |
|
|
|
|
|
|
|
|
|
if (tid < 32) { |
|
|
|
|
if (blockSize >= 64) |
|
|
|
|
final = x[tid] + x[tid + 32]; |
|
|
|
|
else |
|
|
|
|
final = val; |
|
|
|
|
// __SYNCWARP();
|
|
|
|
|
|
|
|
|
|
#pragma unroll |
|
|
|
|
for (int i = 16; i >= lanes; i >>= 1) |
|
|
|
|
final = final + __shfl_down_sync(0xffffffff, final, i); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if (share_result) { |
|
|
|
|
if (tid < lanes) x[tid] = final; // EpilogueOp
|
|
|
|
|
// Make sure the smem result is visible to all warps.
|
|
|
|
|
__syncthreads(); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return final; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
template <typename T> |
|
|
|
|
__device__ __forceinline__ T reduce_block_into_lanes_max_op( |
|
|
|
|
T *x, T val, int lanes = 1, |
|
|
|
|
bool share_result = false) // lanes is intended to be <= 32.
|
|
|
|
|
{ |
|
|
|
|
int tid = threadIdx.x + threadIdx.y * blockDim.x; |
|
|
|
|
int blockSize = |
|
|
|
|
blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
|
|
|
|
|
|
|
|
|
|
if (blockSize >= 64) { |
|
|
|
|
x[tid] = val; |
|
|
|
|
__syncthreads(); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
#pragma unroll |
|
|
|
|
for (int i = (blockSize >> 1); i >= 64; i >>= 1) { |
|
|
|
|
if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i])); |
|
|
|
|
__syncthreads(); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
T final; |
|
|
|
|
|
|
|
|
|
if (tid < 32) { |
|
|
|
|
if (blockSize >= 64) |
|
|
|
|
final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32])); |
|
|
|
|
else |
|
|
|
|
final = val; |
|
|
|
|
// __SYNCWARP();
|
|
|
|
|
|
|
|
|
|
#pragma unroll |
|
|
|
|
for (int i = 16; i >= lanes; i >>= 1) |
|
|
|
|
final = |
|
|
|
|
fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i))); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if (share_result) { |
|
|
|
|
if (tid < lanes) x[tid] = final; // EpilogueOp
|
|
|
|
|
// Make sure the smem result is visible to all warps.
|
|
|
|
|
__syncthreads(); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return final; |
|
|
|
|
} |