diff --git a/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp b/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp index 6aaa15b4e..61c8a7250 100644 --- a/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp +++ b/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp @@ -15,25 +15,24 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h, torch::Tensor logits, torch::Tensor mask, torch::Tensor dest_idx); -std::vector -moe_combine_cuda_backward(int s, int e, int c, int h, torch::Tensor tokens_grad, - torch::Tensor expert_tokens, torch::Tensor logits, - torch::Tensor mask, torch::Tensor dest_idx); +std::vector moe_combine_cuda_backward( + int s, int e, int c, int h, torch::Tensor tokens_grad, + torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx); torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask); -#define CHECK_CUDA(x) \ +#define CHECK_CUDA(x) \ TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) \ +#define CHECK_CONTIGUOUS(x) \ TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ CHECK_CONTIGUOUS(x) torch::Tensor moe_dispatch_forward(int s, int ec, int h, torch::Tensor batch_tokens, torch::Tensor mask, torch::Tensor dest_idx) { - CHECK_INPUT(batch_tokens); CHECK_CUDA(mask); CHECK_CUDA(dest_idx); @@ -45,7 +44,6 @@ torch::Tensor moe_dispatch_backward(int s, int ec, int h, torch::Tensor expert_grad, torch::Tensor mask, torch::Tensor dest_idx) { - CHECK_INPUT(expert_grad); CHECK_CUDA(mask); CHECK_CUDA(dest_idx); @@ -57,7 +55,6 @@ torch::Tensor moe_combine_forward(int s, int e, int c, int h, torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask, torch::Tensor dest_idx) { - CHECK_INPUT(expert_tokens); CHECK_INPUT(logits); CHECK_CUDA(mask); @@ -67,11 +64,12 @@ torch::Tensor moe_combine_forward(int s, int e, int c, int h, dest_idx); } -std::vector -moe_combine_backward(int s, int e, int c, int h, torch::Tensor tokens_grad, - torch::Tensor expert_tokens, torch::Tensor logits, - torch::Tensor mask, torch::Tensor dest_idx) { - +std::vector moe_combine_backward(int s, int e, int c, int h, + torch::Tensor tokens_grad, + torch::Tensor expert_tokens, + torch::Tensor logits, + torch::Tensor mask, + torch::Tensor dest_idx) { CHECK_INPUT(tokens_grad); CHECK_INPUT(logits); CHECK_CUDA(mask);