fix format (#564)

pull/673/head
coder-chin 2022-03-31 15:00:50 +08:00 committed by binmakeswell
parent e014144c44
commit 5835631218
1 changed files with 4 additions and 3 deletions

View File

@ -120,7 +120,7 @@ __global__ void ker_attn_softmax(T *inp, const T *attn_mask, int from_len,
BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i],
to_len); to_len);
} }
} // blockIdx.x } // blockIdx.x
} }
template <typename T, int block_dim, int ele_per_thread> template <typename T, int block_dim, int ele_per_thread>
@ -198,7 +198,7 @@ __global__ void ker_attn_softmax_lt32(T *inp, const T *attn_mask, int from_len,
BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i],
to_len); to_len);
} }
} // blockIdx.x } // blockIdx.x
} }
/* /*
@ -304,7 +304,8 @@ __global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length) {
cg::thread_block b = cg::this_thread_block(); cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b); cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i); for (int i = 1; i < WARP_SIZE; i <<= 1)
sum += g.shfl_xor(sum, i);
#pragma unroll #pragma unroll
for (int i = 0; i < ITERATIONS; ++i) { for (int i = 0; i < ITERATIONS; ++i) {