From 5bbefeb06ab2fda8cc66434dcad369dc55357f9a Mon Sep 17 00:00:00 2001
From: XYE <92607131+Itok2000u@users.noreply.github.com>
Date: Fri, 13 May 2022 15:56:34 +0800
Subject: [PATCH] [NFC] polish moe_cuda_kernel.cu code style (#940)

Co-authored-by: Xiao Ye
---
 .../cuda_native/csrc/moe_cuda_kernel.cu       | 77 ++++++-------
 1 file changed, 25 insertions(+), 52 deletions(-)

diff --git a/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu b/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu
index ac7f8aba2..0454377a2 100644
--- a/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu
+++ b/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu
@@ -1,12 +1,13 @@
-#include "block_reduce.h"
-#include
 #include
 #include
 #include
+#include
+
+#include "block_reduce.h"
+

 template <typename T, int block_size, int pack_size>
 __device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {
-
   assert(cols % pack_size == 0);
   const int bpack_size = block_size * pack_size;

@@ -28,7 +29,6 @@ __device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {

 template <typename T, int block_size, int pack_size>
 __device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) {
-
   assert(cols % pack_size == 0);
   const int bpack_size = block_size * pack_size;

@@ -51,7 +51,6 @@ __device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) {
 template <typename T, int block_size, int pack_size>
 __device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2,
                                  const int cols) {
-
   assert(cols % pack_size == 0);
   const int bpack_size = block_size * pack_size;

@@ -75,7 +74,6 @@ __device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2,
 template <typename T, int block_size, int pack_size>
 __device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2,
                                  const int cols) {
-
   assert(cols % pack_size == 0);
   const int bpack_size = block_size * pack_size;

@@ -105,7 +103,6 @@ __device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2,
 template <typename T, int block_size, int pack_size>
 __device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight,
                                const int cols) {
-
   assert(cols % pack_size == 0);
   const int bpack_size = block_size * pack_size;

@@ -134,7 +131,6 @@ __device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight,
 template <typename T, int block_size, int pack_size>
 __device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row,
                                T *weight_grad, const T weight, const int cols) {
-
   assert(cols % pack_size == 0);
   const int bpack_size = block_size * pack_size;

@@ -164,15 +160,13 @@ __device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row,

   blockReduce(&thread_sum);

-  if (threadIdx.x == 0)
-    *weight_grad = static_cast<T>(thread_sum);
+  if (threadIdx.x == 0) *weight_grad = static_cast<T>(thread_sum);
 }

 template <typename T, int block_size, int pack_size>
 __device__ void moe_cb_two_fwd(T *src_row1, T *src_row2, T *dst_row,
                                const T weight1, const T weight2,
                                const int cols) {
-
   assert(cols % pack_size == 0);
   const int bpack_size = block_size * pack_size;

@@ -204,7 +198,6 @@ __device__ void moe_cb_two_bwd(T *src_row1, T *src_row2, T *dst_row,
                                T *tks_row1, T *tks_row2, T *weight_grad1,
                                T *weight_grad2, const T weight1,
                                const T weight2, const int cols) {
-
   assert(cols % pack_size == 0);
   const int bpack_size = block_size * pack_size;

@@ -251,7 +244,6 @@ template <typename T, int block_size, int pack_size>
 __device__ void moe_dpch_fwd_selector(T *src_row, T *dst_row1, T *dst_row2,
                                       const int cols, const int indicator1,
                                       const int indicator2) {
-
   if (indicator1 != 0 && indicator2 != 0)
     moe_dpch_two_fwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2,
                                                cols);
@@ -267,7 +259,6 @@ template <typename T, int block_size, int pack_size>
 __device__ void moe_dpch_bwd_selector(T *src_row, T *dst_row1, T *dst_row2,
                                       const int cols, const int indicator1,
                                       const int indicator2) {
-
   if (indicator1 != 0 && indicator2 != 0)
     moe_dpch_two_bwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2,
                                                cols);
@@ -283,7 +274,6 @@ template <typename T, int block_size, int pack_size>
 __global__ void moe_dpch_fwd_kernel(T *batch_tokens, T *expert_input,
                                     int *mask1, int *mask2, int *dest1,
                                     int *dest2, const int h) {
-
   int row = blockIdx.x;
   int indicator2 = mask2 == nullptr ? 0 : mask2[row];
   moe_dpch_fwd_selector<T, block_size, pack_size>(
@@ -295,7 +285,6 @@ template <typename T, int block_size, int pack_size>
 __global__ void moe_dpch_bwd_kernel(T *tokens_grad, T *expert_grad,
                                     int *mask1, int *mask2, int *dest1,
                                     int *dest2, const int h) {
-
   int row = blockIdx.x;
   int indicator2 = mask2 == nullptr ? 0 : mask2[row];
   moe_dpch_bwd_selector<T, block_size, pack_size>(
@@ -310,7 +299,6 @@ __device__ void moe_cb_fwd_selector(T *src_row1, T *src_row2, T *dst_row,
                                     const int cols, const T weight1,
                                     const T weight2, const int indicator1,
                                     const int indicator2) {
-
   if (indicator1 != 0 && indicator2 != 0)
     moe_cb_two_fwd<T, block_size, pack_size>(src_row1, src_row2, dst_row,
                                              weight1, weight2, cols);
@@ -328,7 +316,6 @@ __device__ void moe_cb_bwd_selector(T *src_row1, T *src_row2, T *dst_row,
                                     T *wt_grad1, T *wt_grad2, const T weight1,
                                     const T weight2, const int indicator1,
                                     const int indicator2) {
-
   if (indicator1 != 0 && indicator2 != 0)
     moe_cb_two_bwd<T, block_size, pack_size>(src_row1, src_row2, dst_row,
                                              tks_row1, tks_row2, wt_grad1,
@@ -348,7 +335,6 @@ __global__ void moe_cb_fwd_kernel(T *expert_tokens, T *combine_tokens,
                                   T *logits, int *mask1, int *mask2,
                                   int *dest1, int *dest2, const int e,
                                   const int c, const int h) {
-
   int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
   int indicator2 = mask2 == nullptr ? 0 : mask2[row];
   T *row_log = logits + (row * e);
@@ -363,7 +349,6 @@ __global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks,
                                   T *logits, T *logits_grad, int *mask1,
                                   int *mask2, int *dest1, int *dest2,
                                   const int e, const int c, const int h) {
-
   int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
   int indicator2 = mask2 == nullptr ? 0 : mask2[row];
   T *row_log = logits + (row * e), *row_grad = logits_grad + (row * e);
@@ -379,7 +364,6 @@ __global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks,
 template <int block_size, int pack_size>
 __global__ void cumsum_kernel(int *inputs, int *outputs, const int s,
                               const int e) {
-
   assert(s % pack_size == 0);
   constexpr int bpack_size = block_size * pack_size;
   int tid = threadIdx.x, bid = blockIdx.x, tps = tid * pack_size, last_sum = -1;
@@ -426,8 +410,7 @@ __global__ void cumsum_kernel(int *inputs, int *outputs, const int s,
     }
     __syncthreads();

-    if (tid == 0)
-      temp[0] = temp[block_size];
+    if (tid == 0) temp[0] = temp[block_size];
     __syncthreads();

     if (idx + tps < s) {
@@ -453,7 +436,6 @@ template <typename T>
 void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1,
                          int *mask2, int *dest1, int *dest2, const int s,
                          const int h) {
-
   if (h < 256)
     moe_dpch_fwd_kernel
         <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
@@ -474,7 +456,6 @@ template <typename T>
 void moe_dpch_bwd_launch(T *tokens_grad, T *expert_grad, int *mask1,
                          int *mask2, int *dest1, int *dest2, const int s,
                          const int h) {
-
   if (h < 256)
     moe_dpch_bwd_kernel
         <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
@@ -496,7 +477,6 @@ template <typename T>
 void moe_cb_fwd_launch(T *expert_tokens, T *combine_tokens, T *logits,
                        int *mask1, int *mask2, int *dest1, int *dest2,
                        const int s, const int e, const int c, const int h) {
-
   if (h < 256)
     moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, logits,
                           mask1, mask2, dest1, dest2,
@@ -524,12 +504,11 @@ void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits,
                        T *logits_grad, int *mask1, int *mask2, int *dest1,
                        int *dest2, const int s, const int e, const int c,
                        const int h) {
-
   if (h < 256)
     moe_cb_bwd_kernel<<>>(tokens_grad, expert_grad, tks, logits,
                           logits_grad, mask1, mask2, dest1, dest2,
                           e, c, h);
-  else // if (h < 512)
+  else  // if (h < 512)
     moe_cb_bwd_kernel<<>>(tokens_grad, expert_grad, tks, logits,
                           logits_grad, mask1, mask2, dest1, dest2,
                           e, c, h);
@@ -544,7 +523,6 @@ void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits,
 }

 void cumsum_launch(int *inputs, int *outputs, const int s, const int e) {
-
   if (s <= 256)
     cumsum_kernel<256, 1><<>>(inputs, outputs, s, e);
   else if (s <= 512)
@@ -559,27 +537,26 @@ void cumsum_launch(int *inputs, int *outputs, const int s, const int e) {

 // API FUNCTIONS --------------------------------

-#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...)                      \
-  switch (TYPE) {                                                     \
-  case at::ScalarType::Float: {                                       \
-    using scalar_t = float;                                           \
-    __VA_ARGS__;                                                      \
-    break;                                                            \
-  }                                                                   \
-  case at::ScalarType::Half: {                                        \
-    using scalar_t = at::Half;                                        \
-    __VA_ARGS__;                                                      \
-    break;                                                            \
-  }                                                                   \
-  default:                                                            \
-    AT_ERROR(#NAME, " not implemented yet for specific data type.");  \
+#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...)                             \
+  switch (TYPE) {                                                            \
+  case at::ScalarType::Float: {                                              \
+    using scalar_t = float;                                                  \
+    __VA_ARGS__;                                                             \
+    break;                                                                   \
+  }                                                                          \
+  case at::ScalarType::Half: {                                               \
+    using scalar_t = at::Half;                                               \
+    __VA_ARGS__;                                                             \
+    break;                                                                   \
+  }                                                                          \
+  default:                                                                   \
+    AT_ERROR(#NAME, " not implemented yet for specific data type.");         \
   }

 torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h,
                                         torch::Tensor batch_tokens,
                                         torch::Tensor mask,
                                         torch::Tensor dest_idx) {
-
   assert(h % 16 == 0);
   auto res = torch::zeros(
       {ec, h},
@@ -601,7 +578,6 @@ torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h,
                                          torch::Tensor expert_grad,
                                          torch::Tensor mask,
                                          torch::Tensor dest_idx) {
-
   assert(h % 16 == 0);
   auto res = torch::zeros(
       {s, h}, torch::dtype(expert_grad.dtype()).device(expert_grad.device()));
@@ -622,7 +598,6 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
                                        torch::Tensor expert_tokens,
                                        torch::Tensor logits, torch::Tensor mask,
                                        torch::Tensor dest_idx) {
-
   assert(h % 16 == 0);
   assert(expert_tokens.dtype() == logits.dtype());

@@ -643,11 +618,10 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
   return res;
 }

-std::vector<torch::Tensor>
-moe_combine_cuda_backward(int s, int e, int c, int h, torch::Tensor tokens_grad,
-                          torch::Tensor expert_tokens, torch::Tensor logits,
-                          torch::Tensor mask, torch::Tensor dest_idx) {
-
+std::vector<torch::Tensor> moe_combine_cuda_backward(
+    int s, int e, int c, int h, torch::Tensor tokens_grad,
+    torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask,
+    torch::Tensor dest_idx) {
   assert(h % 16 == 0);
   assert(tokens_grad.dtype() == expert_tokens.dtype());
   assert(expert_tokens.dtype() == logits.dtype());
@@ -673,7 +647,6 @@ moe_combine_cuda_backward(int s, int e, int c, int h, torch::Tensor tokens_grad,
 }

 torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) {
-
   assert(mask.dim() == 2);
   assert(mask.dtype() == torch::kInt32);