pull/673/head
coder-chin 3 years ago committed by binmakeswell
parent e014144c44
commit 5835631218

@ -304,7 +304,8 @@ __global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length) {
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i);
for (int i = 1; i < WARP_SIZE; i <<= 1)
sum += g.shfl_xor(sum, i);
#pragma unroll
for (int i = 0; i < ITERATIONS; ++i) {

Loading…
Cancel
Save