[NFC] polish colossalai/kernel/cuda_native/csrc/moe_cuda.cpp code style (#942)

pull/997/head
wky 3 years ago committed by binmakeswell
parent c0f373db5d
commit 8ffdc38376

@ -15,25 +15,24 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
torch::Tensor logits, torch::Tensor mask, torch::Tensor logits, torch::Tensor mask,
torch::Tensor dest_idx); torch::Tensor dest_idx);
std::vector<torch::Tensor> std::vector<torch::Tensor> moe_combine_cuda_backward(
moe_combine_cuda_backward(int s, int e, int c, int h, torch::Tensor tokens_grad, int s, int e, int c, int h, torch::Tensor tokens_grad,
torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask,
torch::Tensor mask, torch::Tensor dest_idx); torch::Tensor dest_idx);
torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask); torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask);
#define CHECK_CUDA(x) \ #define CHECK_CUDA(x) \
TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \ #define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \ #define CHECK_INPUT(x) \
CHECK_CUDA(x); \ CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x) CHECK_CONTIGUOUS(x)
torch::Tensor moe_dispatch_forward(int s, int ec, int h, torch::Tensor moe_dispatch_forward(int s, int ec, int h,
torch::Tensor batch_tokens, torch::Tensor batch_tokens,
torch::Tensor mask, torch::Tensor dest_idx) { torch::Tensor mask, torch::Tensor dest_idx) {
CHECK_INPUT(batch_tokens); CHECK_INPUT(batch_tokens);
CHECK_CUDA(mask); CHECK_CUDA(mask);
CHECK_CUDA(dest_idx); CHECK_CUDA(dest_idx);
@ -45,7 +44,6 @@ torch::Tensor moe_dispatch_backward(int s, int ec, int h,
torch::Tensor expert_grad, torch::Tensor expert_grad,
torch::Tensor mask, torch::Tensor mask,
torch::Tensor dest_idx) { torch::Tensor dest_idx) {
CHECK_INPUT(expert_grad); CHECK_INPUT(expert_grad);
CHECK_CUDA(mask); CHECK_CUDA(mask);
CHECK_CUDA(dest_idx); CHECK_CUDA(dest_idx);
@ -57,7 +55,6 @@ torch::Tensor moe_combine_forward(int s, int e, int c, int h,
torch::Tensor expert_tokens, torch::Tensor expert_tokens,
torch::Tensor logits, torch::Tensor mask, torch::Tensor logits, torch::Tensor mask,
torch::Tensor dest_idx) { torch::Tensor dest_idx) {
CHECK_INPUT(expert_tokens); CHECK_INPUT(expert_tokens);
CHECK_INPUT(logits); CHECK_INPUT(logits);
CHECK_CUDA(mask); CHECK_CUDA(mask);
@ -67,11 +64,12 @@ torch::Tensor moe_combine_forward(int s, int e, int c, int h,
dest_idx); dest_idx);
} }
std::vector<torch::Tensor> std::vector<torch::Tensor> moe_combine_backward(int s, int e, int c, int h,
moe_combine_backward(int s, int e, int c, int h, torch::Tensor tokens_grad, torch::Tensor tokens_grad,
torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor expert_tokens,
torch::Tensor mask, torch::Tensor dest_idx) { torch::Tensor logits,
torch::Tensor mask,
torch::Tensor dest_idx) {
CHECK_INPUT(tokens_grad); CHECK_INPUT(tokens_grad);
CHECK_INPUT(logits); CHECK_INPUT(logits);
CHECK_CUDA(mask); CHECK_CUDA(mask);

Loading…
Cancel
Save