mirror of https://github.com/hpcaitech/ColossalAI
[Inference] Add Reduce Utils (#5537)
* add reduce utils
* add using to delete namespace prefix
pull/5546/head
parent
04aca9e55b
commit
a2878e39f4
@@ -9,16 +9,6 @@
#include <ATen/ATen.h>

#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif

#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif

#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
  switch (TYPE) { \
    case at::ScalarType::Half: { \
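The hunk above removes the DATA_PTR compatibility macro; the remaining hunks in this commit switch each call site to at::Tensor::data_ptr<T>() directly, instead of DATA_PTR<T>() or the older Tensor::data<T>() accessor. A minimal sketch of that pattern, with an illustrative helper name that is not part of the commit:

#include <ATen/ATen.h>

// Before: went through the DATA_PTR macro (or the older Tensor::data<T>()):
//   float* p = t.DATA_PTR<float>();
// After: call the current ATen accessor directly.
float* raw_float_ptr(at::Tensor& t) {
  return t.data_ptr<float>();  // assumes t holds float data
}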
@@ -0,0 +1,32 @@
#pragma once

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <functional>

namespace colossalAI {
namespace cuda {
namespace funcs {

enum class BinaryOpType { kAdd = 0, kMinus, kMul, kDiv, kMax, kMin };

template <typename T, BinaryOpType Op>
struct BinaryOpFunctor;

template <typename T>
struct BinaryOpFunctor<T, BinaryOpType::kAdd>
    : public std::binary_function<T, T, T> {
  __host__ __device__ T operator()(T lhs, T rhs) { return lhs + rhs; }
};

template <typename T>
struct BinaryOpFunctor<T, BinaryOpType::kMax>
    : public std::binary_function<T, T, T> {
  __host__ __device__ T operator()(T lhs, T rhs) { return max(lhs, rhs); }
};

}  // namespace funcs
}  // namespace cuda
}  // namespace colossalAI
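A minimal usage sketch for the new functor header above; the kernel name and shapes are illustrative assumptions, and only BinaryOpFunctor/BinaryOpType are taken from the commit:

#include "op_functor.h"  // assumed include path for the header above

namespace funcs = colossalAI::cuda::funcs;

// Illustrative kernel: applies the kAdd and kMax functor specializations
// elementwise to two arrays of length n.
template <typename T>
__global__ void binary_op_demo(const T* a, const T* b, T* sum_out, T* max_out,
                               int n) {
  funcs::BinaryOpFunctor<T, funcs::BinaryOpType::kAdd> add_op;
  funcs::BinaryOpFunctor<T, funcs::BinaryOpType::kMax> max_op;
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    sum_out[i] = add_op(a[i], b[i]);
    max_out[i] = max_op(a[i], b[i]);
  }
}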
@@ -1,319 +1,100 @@
/* Copyright 2021 The LightSeq Team
   Copyright Tencent/TurboTransformers
   This block_reduce_n is adapted from Tencent/TurboTransformers
*/
#pragma once

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include "../funcs/op_functor.h"

namespace colossalAI {
namespace cuda {
namespace utils {

const float kReduceFloatInfNeg = -100000000.f;
const float kReduceFloatInfPos = 100000000.f;
const int kWarpSize = 32;
const unsigned int kWarpReduceMask = 0xffffffff;

enum class ReduceType { kMax = 0, kSum };
const unsigned int WARP_REDUCE_MASK = 0xffffffff;
const float REDUCE_FLOAT_INF_NEG = -100000000.f;
const float REDUCE_FLOAT_INF_POS = 100000000.f;
const unsigned int WARP_REDUCE_SIZE = 32;

template <typename T, ReduceType rtype>
struct GetOpForReduceType;

template <typename T>
__forceinline__ __device__ T warpReduceSum(T val) {
  for (int mask = (WARP_REDUCE_SIZE >> 1); mask > 0; mask >>= 1)
    val += __shfl_xor_sync(WARP_REDUCE_MASK, val, mask, WARP_REDUCE_SIZE);
  return val;
}
struct GetOpForReduceType<T, ReduceType::kMax> {
  using Op = funcs::BinaryOpFunctor<T, funcs::BinaryOpType::kMax>;
};

/* Calculate the sum of all elements in a block */
template <typename T>
__forceinline__ __device__ T blockReduceSum(T val) {
  static __shared__ T shared[32];
  int lane = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;
struct GetOpForReduceType<T, ReduceType::kSum> {
  using Op = funcs::BinaryOpFunctor<T, funcs::BinaryOpType::kAdd>;
};

  val = warpReduceSum<T>(val);
#define COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, DELTA, WIDTH, OP, LANES) \
  for (int offset = 0; offset < LANES; ++offset) { \
    *(VAL_PTR + offset) = \
        OP(*(VAL_PTR + offset), \
           __shfl_xor_sync(MASK, *(VAL_PTR + offset), DELTA, WIDTH)); \
  }

  if (lane == 0) shared[wid] = val;
  __syncthreads();
#define COLOSSAL_WARP_REDUCE_IMPL(MASK, VAL_PTR, OP, LANES) \
  COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 16, 32, OP, LANES) \
  COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 8, 32, OP, LANES) \
  COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 4, 32, OP, LANES) \
  COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 2, 32, OP, LANES) \
  COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, 1, 32, OP, LANES)

  val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : (T)0.0f;
  val = warpReduceSum<T>(val);
  return val;
#define COLOSSAL_BLOCK_REDUCE_IMPL(DTYPE, MASK, VAL_PTR, OP, LANES, \
                                   DEFAULT_VALUE, REDUCE_TYPE) \
  __shared__ T shm[LANES][32]; \
  int lane_id = threadIdx.x & 0x1f; \
  int warp_id = threadIdx.x >> 5; \
  \
  warp_reduce<DTYPE, REDUCE_TYPE, LANES>(VAL_PTR); \
  if (lane_id == 0) { \
    for (int offset = 0; offset < LANES; ++offset) { \
      shm[offset][warp_id] = *(VAL_PTR + offset); \
    } \
  } \
  __syncthreads(); \
  \
  for (int offset = 0; offset < LANES; ++offset) { \
    *(VAL_PTR + offset) = (threadIdx.x < (blockDim.x >> 5)) \
                              ? shm[offset][lane_id] \
                              : static_cast<T>(DEFAULT_VALUE); \
  } \
  warp_reduce<DTYPE, REDUCE_TYPE, LANES>(VAL_PTR);

template <typename T, ReduceType rtype, int lanes>
__forceinline__ __device__ void warp_reduce(T* pval) {
  typename GetOpForReduceType<T, rtype>::Op op;
  COLOSSAL_WARP_REDUCE_IMPL(kWarpReduceMask, pval, op, lanes);
}

template <ReduceType Rtype, int Num>
__inline__ __device__ void blockReduce(float *pval);

// use template to make code more concise
template <ReduceType Rtype, int Num>
__inline__ __device__ void warpReduce(float *pval);

// static
template <>
__inline__ __device__ void warpReduce<ReduceType::kMax, 1>(float *pval) {
  *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 16, 32));
  *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 8, 32));
  *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 4, 32));
  *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 2, 32));
  *pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 1, 32));
template <typename T, ReduceType rtype>
__forceinline__ __device__ constexpr T GetDefaultValueForBlockReduce() {
  if constexpr (rtype == ReduceType::kSum) {
    return static_cast<T>(0.0f);
  } else if constexpr (rtype == ReduceType::kMax) {
    return static_cast<T>(kReduceFloatInfNeg);
  }
}

template <>
__inline__ __device__ void warpReduce<ReduceType::kMax, 2>(float *pval) {
  float val0_tmp, val1_tmp;
#define WarpReduceMaxOneStep(a, b) \
  val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval), a, b); \
  val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \
  *(pval) = max(val0_tmp, *(pval)); \
  *(pval + 1) = max(val1_tmp, *(pval + 1));

  WarpReduceMaxOneStep(16, 32);
  WarpReduceMaxOneStep(8, 32);
  WarpReduceMaxOneStep(4, 32);
  WarpReduceMaxOneStep(2, 32);
  WarpReduceMaxOneStep(1, 32);
#undef WarpReduceMaxOneStep
template <typename T, ReduceType rtype, int lanes>
__forceinline__ __device__ void block_reduce(T* pval) {
  constexpr T kDefaultValue = GetDefaultValueForBlockReduce<T, rtype>();
  typename GetOpForReduceType<T, rtype>::Op op;
  COLOSSAL_BLOCK_REDUCE_IMPL(T, kWarpReduceMask, pval, op, lanes, kDefaultValue,
                             rtype);
}

template <>
__inline__ __device__ void warpReduce<ReduceType::kSum, 1>(float *pval) {
  *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 16, 32);
  *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 8, 32);
  *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 4, 32);
  *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 2, 32);
  *pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 1, 32);
}

/*
 * Unroll for loop for warpreduce to
 * improve instruction issue efficiency
 * ElemX means there are X numbers to be summed
 */

template <>
__inline__ __device__ void warpReduce<ReduceType::kSum, 2>(float *pval) {
  float val0_tmp, val1_tmp;
#define WarpReduceSumOneStep(a, b) \
  val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b); \
  val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \
  *(pval + 0) += val0_tmp; \
  *(pval + 1) += val1_tmp

  WarpReduceSumOneStep(16, 32);
  WarpReduceSumOneStep(8, 32);
  WarpReduceSumOneStep(4, 32);
  WarpReduceSumOneStep(2, 32);
  WarpReduceSumOneStep(1, 32);

#undef WarpReduceSumOneStep
}

template <>
__inline__ __device__ void warpReduce<ReduceType::kSum, 4>(float *pval) {
  float val0_tmp, val1_tmp, val2_tmp, val3_tmp;
#define WarpReduceSumOneStep(a, b) \
  val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b); \
  val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \
  val2_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 2), a, b); \
  val3_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 3), a, b); \
  *(pval + 0) += val0_tmp; \
  *(pval + 1) += val1_tmp; \
  *(pval + 2) += val2_tmp; \
  *(pval + 3) += val3_tmp

  WarpReduceSumOneStep(16, 32);
  WarpReduceSumOneStep(8, 32);
  WarpReduceSumOneStep(4, 32);
  WarpReduceSumOneStep(2, 32);
  WarpReduceSumOneStep(1, 32);
#undef WarpReduceSumOneStep
}

template <>
__inline__ __device__ void blockReduce<ReduceType::kSum, 1>(float *pval) {
  const int num = 1;
  static __shared__ float shared[num][32];
  int lane_id = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  warpReduce<ReduceType::kSum, num>(pval);

  if (lane_id == 0) {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      shared[i][wid] = *(pval + i);
    }
  }
  __syncthreads();

  if (threadIdx.x < (blockDim.x >> 5)) {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      *(pval + i) = shared[i][lane_id];
    }
  } else {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      *(pval + i) = 0.f;
    }
  }
  warpReduce<ReduceType::kSum, num>(pval);
}

template <>
__inline__ __device__ void blockReduce<ReduceType::kSum, 2>(float *pval) {
  const int num = 2;
  static __shared__ float shared[num][32];
  int lane_id = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  warpReduce<ReduceType::kSum, num>(pval);

  if (lane_id == 0) {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      shared[i][wid] = *(pval + i);
    }
  }
  __syncthreads();

  if (threadIdx.x < (blockDim.x >> 5)) {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      *(pval + i) = shared[i][lane_id];
    }
  } else {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      *(pval + i) = 0.f;
    }
  }
  warpReduce<ReduceType::kSum, num>(pval);
}

template <>
__inline__ __device__ void blockReduce<ReduceType::kSum, 4>(float *pval) {
  const int num = 4;
  static __shared__ float shared[num][32];
  int lane_id = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  warpReduce<ReduceType::kSum, num>(pval);

  if (lane_id == 0) {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      shared[i][wid] = *(pval + i);
    }
  }
  __syncthreads();

  if (threadIdx.x < (blockDim.x >> 5)) {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      *(pval + i) = shared[i][lane_id];
    }
  } else {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      *(pval + i) = 0.f;
    }
  }
  warpReduce<ReduceType::kSum, num>(pval);
}

template <>
__inline__ __device__ void blockReduce<ReduceType::kMax, 1>(float *pval) {
  const int num = 1;
  static __shared__ float shared[num][32];
  int lane_id = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  warpReduce<ReduceType::kMax, num>(pval);

  if (lane_id == 0) {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      shared[i][wid] = *(pval + i);
    }
  }
  __syncthreads();

  if (threadIdx.x < (blockDim.x >> 5)) {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      *(pval + i) = shared[i][lane_id];
    }
  } else {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      *(pval + i) = REDUCE_FLOAT_INF_NEG;
    }
  }
  warpReduce<ReduceType::kMax, num>(pval);
}

template <>
__inline__ __device__ void blockReduce<ReduceType::kMax, 2>(float *pval) {
  const int num = 1;
  static __shared__ float shared[num][32];
  int lane_id = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  warpReduce<ReduceType::kMax, num>(pval);

  if (lane_id == 0) {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      shared[i][wid] = *(pval + i);
    }
  }
  __syncthreads();

  if (threadIdx.x < (blockDim.x >> 5)) {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      *(pval + i) = shared[i][lane_id];
    }
  } else {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      *(pval + i) = REDUCE_FLOAT_INF_NEG;
    }
  }
  warpReduce<ReduceType::kMax, num>(pval);
}

template <>
__inline__ __device__ void blockReduce<ReduceType::kMax, 4>(float *pval) {
  const int num = 1;
  static __shared__ float shared[num][32];
  int lane_id = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  warpReduce<ReduceType::kMax, num>(pval);

  if (lane_id == 0) {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      shared[i][wid] = *(pval + i);
    }
  }
  __syncthreads();

  if (threadIdx.x < (blockDim.x >> 5)) {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      *(pval + i) = shared[i][lane_id];
    }
  } else {
#pragma unroll
    for (int i = 0; i < num; ++i) {
      *(pval + i) = REDUCE_FLOAT_INF_NEG;
    }
  }
  warpReduce<ReduceType::kMax, num>(pval);
}
#undef COLOSSAL_SHFL_FUNCTION
#undef COLOSSAL_WARP_REDUCE_IMPL
#undef COLOSSAL_BLOCK_REDUCE_IMPL

template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes(
    T *x, T val, int lanes = 1,
    T* x, T val, int lanes = 1,
    bool share_result = false)  // lanes is intended to be <= 32.
{
  int tid = threadIdx.x + threadIdx.y * blockDim.x;

@@ -356,7 +137,7 @@ __device__ __forceinline__ T reduce_block_into_lanes(
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes_max_op(
    T *x, T val, int lanes = 1,
    T* x, T val, int lanes = 1,
    bool share_result = false)  // lanes is intended to be <= 32.
{
  int tid = threadIdx.x + threadIdx.y * blockDim.x;

@@ -397,3 +178,7 @@ __device__ __forceinline__ T reduce_block_into_lanes_max_op(
  return final;
}

}  // namespace utils
}  // namespace cuda
}  // namespace colossalAI
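A rough usage sketch for the rewritten block_reduce utilities above; the kernel below is hypothetical (not part of the commit) and assumes blockDim.x is a multiple of 32, mirroring how the MoE and RMSNorm kernels later in this diff call block_reduce on a per-thread partial value:

#include "block_reduce.h"  // assumed include path for the header above

using colossalAI::cuda::utils::block_reduce;
using colossalAI::cuda::utils::ReduceType;

// Hypothetical kernel: each block sums one row of `in` (row length `cols`).
__global__ void row_sum_kernel(const float* in, float* out, int cols) {
  float partial = 0.f;
  for (int i = threadIdx.x; i < cols; i += blockDim.x) {
    partial += in[blockIdx.x * cols + i];
  }
  // Reduce the per-thread partials across the block (lanes = 1 value per thread).
  block_reduce<float, ReduceType::kSum, 1>(&partial);
  if (threadIdx.x == 0) {
    out[blockIdx.x] = partial;  // thread 0 holds the block-wide sum
  }
}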
@@ -606,11 +606,11 @@ void cuda_layer_norm(at::Tensor* output, at::Tensor* mean, at::Tensor* invvar,
  using namespace at;
  DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
      input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel",
      HostApplyLayerNorm(output->DATA_PTR<scalar_t_out>(),
                         mean->DATA_PTR<float>(), invvar->DATA_PTR<float>(),
                         input->DATA_PTR<scalar_t_in>(), n1, n2, epsilon,
                         gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
                         beta != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL);)
      HostApplyLayerNorm(output->data_ptr<scalar_t_out>(),
                         mean->data_ptr<float>(), invvar->data_ptr<float>(),
                         input->data_ptr<scalar_t_in>(), n1, n2, epsilon,
                         gamma != NULL ? gamma->data_ptr<scalar_t_out>() : NULL,
                         beta != NULL ? beta->data_ptr<scalar_t_out>() : NULL);)
}

template <typename T, typename U, typename V>

@@ -633,14 +633,14 @@ void HostLayerNormGradient(const V* dout, const U* mean, const U* invvar,
      {part_size, n2}, input->options().dtype(at::ScalarType::Float));
  at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
  cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
      dout, input->DATA_PTR<T>(), n1, n2, mean, invvar, U(epsilon),
      part_grad_gamma.DATA_PTR<U>(), part_grad_beta.DATA_PTR<U>());
      dout, input->data_ptr<T>(), n1, n2, mean, invvar, U(epsilon),
      part_grad_gamma.data_ptr<U>(), part_grad_beta.data_ptr<U>());

  const dim3 threads3(32, 8, 1);
  const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1);
  const int nshared3 = threads3.x * threads3.y * sizeof(U);
  cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
      part_grad_gamma.DATA_PTR<U>(), part_grad_beta.DATA_PTR<U>(), part_size,
      part_grad_gamma.data_ptr<U>(), part_grad_beta.data_ptr<U>(), part_size,
      n1, n2, grad_gamma, grad_beta);
}

@@ -651,7 +651,7 @@ void HostLayerNormGradient(const V* dout, const U* mean, const U* invvar,
  const dim3 threads1(32, 4, 1);
  int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0;
  cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(
      dout, input->DATA_PTR<T>(), n1, n2, mean, invvar, U(epsilon), gamma,
      dout, input->data_ptr<T>(), n1, n2, mean, invvar, U(epsilon), gamma,
      grad_input);
}

@@ -671,13 +671,13 @@ void cuda_layer_norm_gradient(at::Tensor* dout, at::Tensor* mean,
      input->scalar_type(), gamma->scalar_type(),
      "cuda_layer_norm_gradient_kernel",
      HostLayerNormGradient(
          dout->DATA_PTR<scalar_t_out>(), mean->DATA_PTR<float>(),
          invvar->DATA_PTR<float>(), input, n1, n2,
          dout->data_ptr<scalar_t_out>(), mean->data_ptr<float>(),
          invvar->data_ptr<float>(), input, n1, n2,
          // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
          // if gamma Tensor is NULL on input.
          gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
          gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL, epsilon,
          grad_input->DATA_PTR<scalar_t_in>(),
          gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
          gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);)
          gamma != NULL ? gamma->data_ptr<scalar_t_out>() : NULL,
          gamma != NULL ? beta->data_ptr<scalar_t_out>() : NULL, epsilon,
          grad_input->data_ptr<scalar_t_in>(),
          gamma != NULL ? grad_gamma->data_ptr<scalar_t_out>() : NULL,
          gamma != NULL ? grad_beta->data_ptr<scalar_t_out>() : NULL);)
}
@@ -6,6 +6,10 @@
#include "block_reduce.h"

using colossalAI::cuda::utils::block_reduce;
using colossalAI::cuda::utils::ReduceType;

template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {
  assert(cols % pack_size == 0);

@@ -157,8 +161,7 @@ __device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row,
    BlockStore(ts_store).Store(src_row + idx, grad);
  }

  blockReduce<ReduceType::kSum, 1>(&thread_sum);
  block_reduce<float, ReduceType::kSum, 1>(&thread_sum);

  if (threadIdx.x == 0) *weight_grad = static_cast<T>(thread_sum);
}

@@ -230,7 +233,7 @@ __device__ void moe_cb_two_bwd(T *src_row1, T *src_row2, T *dst_row,
    BlockStore(ts_store).Store(src_row2 + idx, sgrad2);
  }

  blockReduce<ReduceType::kSum, 2>(thread_sum);
  block_reduce<float, ReduceType::kSum, 2>(thread_sum);

  if (threadIdx.x == 0)
    *weight_grad1 = static_cast<T>(thread_sum[0]);

@@ -566,10 +569,10 @@ torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h,
  DISPATCH_FLOAT_AND_HALF(
      batch_tokens.scalar_type(), "moe dispatch forward",
      moe_dpch_fwd_launch<scalar_t>(
          batch_tokens.data<scalar_t>(), res.data<scalar_t>(),
          mask[0].data<int>(), k == 1 ? nullptr : mask[1].data<int>(),
          dest_idx[0].data<int>(),
          k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(), s, h));
          batch_tokens.data_ptr<scalar_t>(), res.data_ptr<scalar_t>(),
          mask[0].data_ptr<int>(), k == 1 ? nullptr : mask[1].data_ptr<int>(),
          dest_idx[0].data_ptr<int>(),
          k == 1 ? dest_idx[0].data_ptr<int>() : dest_idx[1].data_ptr<int>(), s, h));

  return res;
}

@@ -586,10 +589,10 @@ torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h,
  DISPATCH_FLOAT_AND_HALF(
      expert_grad.scalar_type(), "moe dispatch backward",
      moe_dpch_bwd_launch<scalar_t>(
          res.data<scalar_t>(), expert_grad.data<scalar_t>(),
          mask[0].data<int>(), k == 1 ? nullptr : mask[1].data<int>(),
          dest_idx[0].data<int>(),
          k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(), s, h));
          res.data_ptr<scalar_t>(), expert_grad.data_ptr<scalar_t>(),
          mask[0].data_ptr<int>(), k == 1 ? nullptr : mask[1].data_ptr<int>(),
          dest_idx[0].data_ptr<int>(),
          k == 1 ? dest_idx[0].data_ptr<int>() : dest_idx[1].data_ptr<int>(), s, h));

  return res;
}

@@ -609,10 +612,10 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
  DISPATCH_FLOAT_AND_HALF(
      expert_tokens.scalar_type(), "moe combine forward",
      moe_cb_fwd_launch<scalar_t>(
          expert_tokens.data<scalar_t>(), res.data<scalar_t>(),
          logits.data<scalar_t>(), mask[0].data<int>(),
          k == 1 ? nullptr : mask[1].data<int>(), dest_idx[0].data<int>(),
          k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(), s, e, c,
          expert_tokens.data_ptr<scalar_t>(), res.data_ptr<scalar_t>(),
          logits.data_ptr<scalar_t>(), mask[0].data_ptr<int>(),
          k == 1 ? nullptr : mask[1].data_ptr<int>(), dest_idx[0].data_ptr<int>(),
          k == 1 ? dest_idx[0].data_ptr<int>() : dest_idx[1].data_ptr<int>(), s, e, c,
          h));

  return res;

@@ -636,11 +639,11 @@ std::vector<torch::Tensor> moe_combine_cuda_backward(
  DISPATCH_FLOAT_AND_HALF(
      tokens_grad.scalar_type(), "moe combine backward",
      moe_cb_bwd_launch<scalar_t>(
          tokens_grad.data<scalar_t>(), egrad.data<scalar_t>(),
          expert_tokens.data<scalar_t>(), logits.data<scalar_t>(),
          wgrad.data<scalar_t>(), mask[0].data<int>(),
          k == 1 ? nullptr : mask[1].data<int>(), dest_idx[0].data<int>(),
          k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(), s, e, c,
          tokens_grad.data_ptr<scalar_t>(), egrad.data_ptr<scalar_t>(),
          expert_tokens.data_ptr<scalar_t>(), logits.data_ptr<scalar_t>(),
          wgrad.data_ptr<scalar_t>(), mask[0].data_ptr<int>(),
          k == 1 ? nullptr : mask[1].data_ptr<int>(), dest_idx[0].data_ptr<int>(),
          k == 1 ? dest_idx[0].data_ptr<int>() : dest_idx[1].data_ptr<int>(), s, e, c,
          h));

  return {egrad, wgrad};

@@ -653,7 +656,7 @@ torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) {
  const int s = mask.size(0), e = mask.size(1);
  auto res =
      torch::empty({s, e}, torch::dtype(torch::kInt32).device(mask.device()));
  cumsum_launch(mask.data<int>(), res.data<int>(), s, e);
  cumsum_launch(mask.data_ptr<int>(), res.data_ptr<int>(), s, e);

  return res;
}
@@ -104,7 +104,7 @@ void multi_tensor_apply(
    if (tensors_full || blocks_full || last_chunk) {
      // using accscalar_t = acc_type<scalar_t, true>;
      multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
          chunk_size, noop_flag.DATA_PTR<int>(), tl, callable, args...);
          chunk_size, noop_flag.data_ptr<int>(), tl, callable, args...);

      AT_CUDA_CHECK(cudaGetLastError());
@@ -17,6 +17,10 @@
#define BLOCK_SIZE 512
#define ILP 4

using colossalAI::cuda::utils::block_reduce;
using colossalAI::cuda::utils::reduce_block_into_lanes;
using colossalAI::cuda::utils::reduce_block_into_lanes_max_op;

template <typename T>
__device__ __forceinline__ bool is_aligned(T *p) {
  return ((uint64_t)p) % (ILP * sizeof(T)) == 0;

@@ -290,8 +294,8 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
      tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
      multi_tensor_apply<1>(
          BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
          L2NormFunctor<scalar_t_0>(), output.DATA_PTR<float>(),
          per_tensor ? output_per_tensor.DATA_PTR<float>() : nullptr,
          L2NormFunctor<scalar_t_0>(), output.data_ptr<float>(),
          per_tensor ? output_per_tensor.data_ptr<float>() : nullptr,
          per_tensor, max_chunks_per_tensor);)

  AT_CUDA_CHECK(cudaGetLastError());

@@ -304,10 +308,10 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
  auto stream = at::cuda::getCurrentCUDAStream();
  cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(
      output.DATA_PTR<float>(),
      per_tensor ? output_per_tensor.DATA_PTR<float>() : nullptr,
      ret.DATA_PTR<float>(),
      per_tensor ? ret_per_tensor.DATA_PTR<float>() : nullptr, per_tensor,
      output.data_ptr<float>(),
      per_tensor ? output_per_tensor.data_ptr<float>() : nullptr,
      ret.data_ptr<float>(),
      per_tensor ? ret_per_tensor.data_ptr<float>() : nullptr, per_tensor,
      max_chunks_per_tensor);

  return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);

@@ -350,15 +354,15 @@ void multi_tensor_norm_out_cuda(
      tensor_lists[0][0].scalar_type(), 0, "multi_tensor_maxnorm_cuda",
      multi_tensor_apply<1>(
          BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
          MaxNormFunctor<scalar_t_0>(), output.DATA_PTR<float>(),
          output_per_tensor.DATA_PTR<float>(), true, max_chunks_per_tensor);)
          MaxNormFunctor<scalar_t_0>(), output.data_ptr<float>(),
          output_per_tensor.data_ptr<float>(), true, max_chunks_per_tensor);)
  } else {
    DISPATCH_FLOAT_AND_HALF(
        tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
        multi_tensor_apply<1>(
            BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
            L2NormFunctor<scalar_t_0>(), output.DATA_PTR<float>(),
            output_per_tensor.DATA_PTR<float>(), true, max_chunks_per_tensor);)
            L2NormFunctor<scalar_t_0>(), output.data_ptr<float>(),
            output_per_tensor.data_ptr<float>(), true, max_chunks_per_tensor);)
  }
  AT_CUDA_CHECK(cudaGetLastError());

@@ -375,8 +379,8 @@ void multi_tensor_norm_out_cuda(
  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
  auto stream = at::cuda::getCurrentCUDAStream();
  cleanup_v2<<<ntensors, 512, 0, stream>>>(
      output.DATA_PTR<float>(), output_per_tensor.DATA_PTR<float>(),
      ret.DATA_PTR<float>(), out.DATA_PTR<float>(), true, max_chunks_per_tensor,
      output.data_ptr<float>(), output_per_tensor.data_ptr<float>(),
      ret.data_ptr<float>(), out.data_ptr<float>(), true, max_chunks_per_tensor,
      norm_type, alpha, beta);

  return;
@@ -333,7 +333,7 @@ void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
                  beta3,  // 1-beta1 or 1 depends on averaging mode
                  bias_correction1, bias_correction2, epsilon,
                  (adamMode_t)mode, weight_decay,
                  global_grad_norm.DATA_PTR<float>(), max_grad_norm);)
                  global_grad_norm.data_ptr<float>(), max_grad_norm);)

  // Compute update norms
  auto update_norm_tuple =

@@ -346,8 +346,8 @@ void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
      tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
      multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, grad_param_list,
                            LAMBStage2Functor<scalar_t_0>(),
                            std::get<1>(param_norm_tuple).DATA_PTR<float>(),
                            std::get<1>(update_norm_tuple).DATA_PTR<float>(),
                            std::get<1>(param_norm_tuple).data_ptr<float>(),
                            std::get<1>(update_norm_tuple).data_ptr<float>(),
                            lr, weight_decay, use_nvlamb);)

  AT_CUDA_CHECK(cudaGetLastError());
@@ -12,6 +12,9 @@
#include "../common/micros.h"
#include "utils/cuda_type_utils.h"

using colossalAI::cuda::utils::block_reduce;
using colossalAI::cuda::utils::ReduceType;

#define DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(DATA_SIZE, TYPE, NAME, ...) \
  if (DATA_SIZE == 2) { \
    switch (TYPE) { \

@@ -77,7 +80,7 @@ __global__ void rms_layernorm_kernel(
      float v2 = cuda_cast<float>(x_local[cnt].y);
      variance += v1 * v1 + v2 * v2;
    }
    variance = blockReduceSum<float>(variance);
    block_reduce<float, ReduceType::kSum, 1>(&variance);
    if (threadIdx.x == 0) {
      s_variance = rsqrtf(variance / hidden_size + epsilon);
    }

@@ -111,7 +114,7 @@ __global__ void general_rms_layernorm_kernel(
      x_local[cnt] = (float) input[id];
      variance += x_local[cnt] * x_local[cnt];
    }
    variance = blockReduceSum<float>(variance);
    block_reduce<float, ReduceType::kSum, 1>(&variance);
    if (threadIdx.x == 0) {
      s_variance = rsqrtf(variance / hidden_size + epsilon);
    }

@@ -154,7 +157,7 @@ __global__ void fused_add_rms_layernorm_kernel(
      variance += v1 * v1 + v2 * v2;
      residual_ptr[id] = x_local[cnt];
    }
    variance = blockReduceSum<float>(variance);
    block_reduce<float, ReduceType::kSum, 1>(&variance);
    if (threadIdx.x == 0) {
      s_variance = rsqrtf(variance / hidden_size + epsilon);
    }

@@ -190,7 +193,7 @@ __global__ void general_fused_add_rms_layernorm_kernel(
      variance += x_local[cnt] * x_local[cnt];
      residual[id] = (scalar_t) x_local[cnt];
    }
    variance = blockReduceSum<float>(variance);
    block_reduce<float, ReduceType::kSum, 1>(&variance);
    if (threadIdx.x == 0) {
      s_variance = rsqrtf(variance / hidden_size + epsilon);
    }