// ColossalAI/colossalai/kernel/cuda_native/csrc/kernels/include/block_reduce.h
/* Copyright 2021 The LightSeq Team
Copyright Tencent/TurboTransformers
This block_reduce_n is adapted from Tencent/TurboTransformers
*/
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
enum class ReduceType { kMax = 0, kSum };
const unsigned int WARP_REDUCE_MASK = 0xffffffff;  // all 32 lanes participate
const float REDUCE_FLOAT_INF_NEG = -100000000.f;   // sentinel standing in for -inf
const float REDUCE_FLOAT_INF_POS = 100000000.f;    // sentinel standing in for +inf
const unsigned int WARP_REDUCE_SIZE = 32;
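// warpReduceSum below is a shuffle-xor "butterfly" reduction: each step
// exchanges values between lanes whose indices differ by mask (16, 8, 4,
// 2, 1), so after log2(32) = 5 steps every lane holds the warp-wide sum.
// E.g. with 4 lanes holding {a, b, c, d}, mask 2 yields
// {a+c, b+d, a+c, b+d} and mask 1 then gives a+b+c+d in every lane.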
template <typename T>
__forceinline__ __device__ T warpReduceSum(T val) {
for (int mask = (WARP_REDUCE_SIZE >> 1); mask > 0; mask >>= 1)
val += __shfl_xor_sync(WARP_REDUCE_MASK, val, mask, WARP_REDUCE_SIZE);
return val;
}
/* Calculate the sum of all elements in a block.
 * Assumes blockDim.x is a multiple of 32 and at most 1024 (at most 32
 * warps); after the call, the threads of warp 0 hold the block-wide sum. */
template <typename T>
__forceinline__ __device__ T blockReduceSum(T val) {
  static __shared__ T shared[32];  // one partial sum per warp
  int lane = threadIdx.x & 0x1f;   // lane index within the warp
  int wid = threadIdx.x >> 5;      // warp index within the block

  val = warpReduceSum<T>(val);  // step 1: reduce within each warp

  if (lane == 0) shared[wid] = val;  // lane 0 publishes its warp's partial
  __syncthreads();

  // Step 2: warp 0 reloads the per-warp partials (all other threads
  // contribute zero) and reduces them to the block-wide sum.
  val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : (T)0.0f;
  val = warpReduceSum<T>(val);
  return val;
}
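/* Usage sketch (hypothetical kernel, not part of the original header),
 * assuming one block per row and blockDim.x a multiple of 32:
 *
 *   __global__ void row_sum(const float *in, float *out, int cols) {
 *     float val = 0.f;
 *     for (int i = threadIdx.x; i < cols; i += blockDim.x)
 *       val += in[blockIdx.x * cols + i];
 *     val = blockReduceSum<float>(val);  // warp 0 now holds the row sum
 *     if (threadIdx.x == 0) out[blockIdx.x] = val;
 *   }
 */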
template <ReduceType Rtype, int Num>
__inline__ __device__ void blockReduce(float *pval);

// Use template specializations to keep the unrolled reductions concise.
template <ReduceType Rtype, int Num>
__inline__ __device__ void warpReduce(float *pval);
template <>
__inline__ __device__ void warpReduce<ReduceType::kMax, 1>(float *pval) {
*pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 16, 32));
*pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 8, 32));
*pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 4, 32));
*pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 2, 32));
*pval = max(*pval, __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 1, 32));
}
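// The five hard-coded steps above are the warpReduceSum butterfly
// manually unrolled (mask = 16, 8, 4, 2, 1), with max replacing addition.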
template <>
__inline__ __device__ void warpReduce<ReduceType::kMax, 2>(float *pval) {
float val0_tmp, val1_tmp;
#define WarpReduceMaxOneStep(a, b) \
val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval), a, b); \
val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \
*(pval) = max(val0_tmp, *(pval)); \
*(pval + 1) = max(val1_tmp, *(pval + 1));
WarpReduceMaxOneStep(16, 32);
WarpReduceMaxOneStep(8, 32);
WarpReduceMaxOneStep(4, 32);
WarpReduceMaxOneStep(2, 32);
WarpReduceMaxOneStep(1, 32);
#undef WarpReduceMaxOneStep
}
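// A four-element kMax counterpart (a sketch mirroring the kSum, 4 variant
// below, with max in place of addition); blockReduce<ReduceType::kMax, 4>
// relies on such a specialization to reduce all four values.
template <>
__inline__ __device__ void warpReduce<ReduceType::kMax, 4>(float *pval) {
  float val0_tmp, val1_tmp, val2_tmp, val3_tmp;
#define WarpReduceMaxOneStep(a, b)                                 \
  val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b); \
  val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \
  val2_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 2), a, b); \
  val3_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 3), a, b); \
  *(pval + 0) = max(val0_tmp, *(pval + 0));                        \
  *(pval + 1) = max(val1_tmp, *(pval + 1));                        \
  *(pval + 2) = max(val2_tmp, *(pval + 2));                        \
  *(pval + 3) = max(val3_tmp, *(pval + 3))
  WarpReduceMaxOneStep(16, 32);
  WarpReduceMaxOneStep(8, 32);
  WarpReduceMaxOneStep(4, 32);
  WarpReduceMaxOneStep(2, 32);
  WarpReduceMaxOneStep(1, 32);
#undef WarpReduceMaxOneStep
}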
template <>
__inline__ __device__ void warpReduce<ReduceType::kSum, 1>(float *pval) {
*pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 16, 32);
*pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 8, 32);
*pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 4, 32);
*pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 2, 32);
*pval += __shfl_xor_sync(WARP_REDUCE_MASK, *pval, 1, 32);
}
/*
 * Unroll the warpReduce loops to
 * improve instruction issue efficiency.
 * ElemX means there are X values to be summed.
 */
template <>
__inline__ __device__ void warpReduce<ReduceType::kSum, 2>(float *pval) {
float val0_tmp, val1_tmp;
#define WarpReduceSumOneStep(a, b) \
val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b); \
val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \
*(pval + 0) += val0_tmp; \
*(pval + 1) += val1_tmp
WarpReduceSumOneStep(16, 32);
WarpReduceSumOneStep(8, 32);
WarpReduceSumOneStep(4, 32);
WarpReduceSumOneStep(2, 32);
WarpReduceSumOneStep(1, 32);
#undef WarpReduceSumOneStep
}
template <>
__inline__ __device__ void warpReduce<ReduceType::kSum, 4>(float *pval) {
float val0_tmp, val1_tmp, val2_tmp, val3_tmp;
#define WarpReduceSumOneStep(a, b) \
val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b); \
val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \
val2_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 2), a, b); \
val3_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 3), a, b); \
*(pval + 0) += val0_tmp; \
*(pval + 1) += val1_tmp; \
*(pval + 2) += val2_tmp; \
*(pval + 3) += val3_tmp
WarpReduceSumOneStep(16, 32);
WarpReduceSumOneStep(8, 32);
WarpReduceSumOneStep(4, 32);
WarpReduceSumOneStep(2, 32);
WarpReduceSumOneStep(1, 32);
#undef WarpReduceSumOneStep
}
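// Design note: each step issues all of its independent shuffles before the
// dependent adds, giving the scheduler back-to-back instructions to
// overlap; this is the issue-efficiency gain the comment above refers to.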
template <>
__inline__ __device__ void blockReduce<ReduceType::kSum, 1>(float *pval) {
const int num = 1;
static __shared__ float shared[num][32];
int lane_id = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
warpReduce<ReduceType::kSum, num>(pval);
if (lane_id == 0) {
#pragma unroll
for (int i = 0; i < num; ++i) {
shared[i][wid] = *(pval + i);
}
}
__syncthreads();
if (threadIdx.x < (blockDim.x >> 5)) {
#pragma unroll
for (int i = 0; i < num; ++i) {
*(pval + i) = shared[i][lane_id];
}
} else {
#pragma unroll
for (int i = 0; i < num; ++i) {
*(pval + i) = 0.f;
}
}
warpReduce<ReduceType::kSum, num>(pval);
}
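// In every blockReduce specialization, the threads of warp 0 end up holding
// the reduced values; the remaining warps hold the identity element (0.f
// for kSum, REDUCE_FLOAT_INF_NEG for kMax). Callers that need the result in
// all threads typically broadcast it through shared memory, as in the
// sketch at the end of this file.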
template <>
__inline__ __device__ void blockReduce<ReduceType::kSum, 2>(float *pval) {
const int num = 2;
static __shared__ float shared[num][32];
int lane_id = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
warpReduce<ReduceType::kSum, num>(pval);
if (lane_id == 0) {
#pragma unroll
for (int i = 0; i < num; ++i) {
shared[i][wid] = *(pval + i);
}
}
__syncthreads();
if (threadIdx.x < (blockDim.x >> 5)) {
#pragma unroll
for (int i = 0; i < num; ++i) {
*(pval + i) = shared[i][lane_id];
}
} else {
#pragma unroll
for (int i = 0; i < num; ++i) {
*(pval + i) = 0.f;
}
}
warpReduce<ReduceType::kSum, num>(pval);
}
template <>
__inline__ __device__ void blockReduce<ReduceType::kSum, 4>(float *pval) {
const int num = 4;
static __shared__ float shared[num][32];
int lane_id = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
warpReduce<ReduceType::kSum, num>(pval);
if (lane_id == 0) {
#pragma unroll
for (int i = 0; i < num; ++i) {
shared[i][wid] = *(pval + i);
}
}
__syncthreads();
if (threadIdx.x < (blockDim.x >> 5)) {
#pragma unroll
for (int i = 0; i < num; ++i) {
*(pval + i) = shared[i][lane_id];
}
} else {
#pragma unroll
for (int i = 0; i < num; ++i) {
*(pval + i) = 0.f;
}
}
warpReduce<ReduceType::kSum, num>(pval);
}
template <>
__inline__ __device__ void blockReduce<ReduceType::kMax, 1>(float *pval) {
const int num = 1;
static __shared__ float shared[num][32];
int lane_id = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
warpReduce<ReduceType::kMax, num>(pval);
if (lane_id == 0) {
#pragma unroll
for (int i = 0; i < num; ++i) {
shared[i][wid] = *(pval + i);
}
}
__syncthreads();
if (threadIdx.x < (blockDim.x >> 5)) {
#pragma unroll
for (int i = 0; i < num; ++i) {
*(pval + i) = shared[i][lane_id];
}
} else {
#pragma unroll
for (int i = 0; i < num; ++i) {
*(pval + i) = REDUCE_FLOAT_INF_NEG;
}
}
warpReduce<ReduceType::kMax, num>(pval);
}
template <>
__inline__ __device__ void blockReduce<ReduceType::kMax, 2>(float *pval) {
  const int num = 2;
static __shared__ float shared[num][32];
int lane_id = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
warpReduce<ReduceType::kMax, num>(pval);
if (lane_id == 0) {
#pragma unroll
for (int i = 0; i < num; ++i) {
shared[i][wid] = *(pval + i);
}
}
__syncthreads();
if (threadIdx.x < (blockDim.x >> 5)) {
#pragma unroll
for (int i = 0; i < num; ++i) {
*(pval + i) = shared[i][lane_id];
}
} else {
#pragma unroll
for (int i = 0; i < num; ++i) {
*(pval + i) = REDUCE_FLOAT_INF_NEG;
}
}
warpReduce<ReduceType::kMax, num>(pval);
}
template <>
__inline__ __device__ void blockReduce<ReduceType::kMax, 4>(float *pval) {
  const int num = 4;
static __shared__ float shared[num][32];
int lane_id = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
warpReduce<ReduceType::kMax, num>(pval);
if (lane_id == 0) {
#pragma unroll
for (int i = 0; i < num; ++i) {
shared[i][wid] = *(pval + i);
}
}
__syncthreads();
if (threadIdx.x < (blockDim.x >> 5)) {
#pragma unroll
for (int i = 0; i < num; ++i) {
*(pval + i) = shared[i][lane_id];
}
} else {
#pragma unroll
for (int i = 0; i < num; ++i) {
*(pval + i) = REDUCE_FLOAT_INF_NEG;
}
}
warpReduce<ReduceType::kMax, num>(pval);
}
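/* Usage sketch (hypothetical kernel, not part of the original header):
 * a numerically stable softmax row pass, broadcasting each block-wide
 * result through shared memory so every thread can use it.
 *
 *   __global__ void row_softmax(const float *in, float *out, int cols) {
 *     __shared__ float s_max, s_sum;
 *     const float *row = in + blockIdx.x * cols;
 *     float m = REDUCE_FLOAT_INF_NEG;
 *     for (int i = threadIdx.x; i < cols; i += blockDim.x)
 *       m = max(m, row[i]);
 *     blockReduce<ReduceType::kMax, 1>(&m);
 *     if (threadIdx.x == 0) s_max = m;
 *     __syncthreads();
 *     float s = 0.f;
 *     for (int i = threadIdx.x; i < cols; i += blockDim.x)
 *       s += expf(row[i] - s_max);
 *     blockReduce<ReduceType::kSum, 1>(&s);
 *     if (threadIdx.x == 0) s_sum = s;
 *     __syncthreads();
 *     for (int i = threadIdx.x; i < cols; i += blockDim.x)
 *       out[blockIdx.x * cols + i] = expf(row[i] - s_max) / s_sum;
 *   }
 */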