diff --git a/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp b/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp
index 063fbc664..6aaa15b4e 100644
--- a/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp
+++ b/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp
@@ -1,118 +1,99 @@
 #include <torch/extension.h>
 
-torch::Tensor moe_dispatch_cuda_forward(
-    int s, int ec, int h,
-    torch::Tensor batch_tokens,
-    torch::Tensor mask,
-    torch::Tensor dest_idx);
+torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h,
+                                        torch::Tensor batch_tokens,
+                                        torch::Tensor mask,
+                                        torch::Tensor dest_idx);
 
-torch::Tensor moe_dispatch_cuda_backward(
-    int s, int ec, int h,
-    torch::Tensor expert_grad,
-    torch::Tensor mask,
-    torch::Tensor dest_idx);
+torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h,
+                                         torch::Tensor expert_grad,
+                                         torch::Tensor mask,
+                                         torch::Tensor dest_idx);
 
-torch::Tensor moe_combine_cuda_forward(
-    int s, int e, int c, int h,
-    torch::Tensor expert_tokens,
-    torch::Tensor logits,
-    torch::Tensor mask,
-    torch::Tensor dest_idx);
-
-std::vector<torch::Tensor> moe_combine_cuda_backward(
-    int s, int e, int c, int h,
-    torch::Tensor tokens_grad,
-    torch::Tensor expert_tokens,
-    torch::Tensor logits,
-    torch::Tensor mask,
-    torch::Tensor dest_idx);
+torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
+                                       torch::Tensor expert_tokens,
+                                       torch::Tensor logits, torch::Tensor mask,
+                                       torch::Tensor dest_idx);
+
+std::vector<torch::Tensor>
+moe_combine_cuda_backward(int s, int e, int c, int h, torch::Tensor tokens_grad,
+                          torch::Tensor expert_tokens, torch::Tensor logits,
+                          torch::Tensor mask, torch::Tensor dest_idx);
 
 torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask);
 
-#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
-#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
-#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+#define CHECK_CUDA(x) \
+  TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) \
+  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) \
+  CHECK_CUDA(x);       \
+  CHECK_CONTIGUOUS(x)
 
-torch::Tensor moe_dispatch_forward(
-    int s, int ec, int h,
-    torch::Tensor batch_tokens,
-    torch::Tensor mask,
-    torch::Tensor dest_idx) {
+torch::Tensor moe_dispatch_forward(int s, int ec, int h,
+                                   torch::Tensor batch_tokens,
+                                   torch::Tensor mask, torch::Tensor dest_idx) {
 
-    CHECK_INPUT(batch_tokens);
-    CHECK_CUDA(mask);
-    CHECK_CUDA(dest_idx);
+  CHECK_INPUT(batch_tokens);
+  CHECK_CUDA(mask);
+  CHECK_CUDA(dest_idx);
 
-    return moe_dispatch_cuda_forward(
-        s, ec, h,
-        batch_tokens, mask, dest_idx);
+  return moe_dispatch_cuda_forward(s, ec, h, batch_tokens, mask, dest_idx);
 }
 
-torch::Tensor moe_dispatch_backward(
-    int s, int ec, int h,
-    torch::Tensor expert_grad,
-    torch::Tensor mask,
-    torch::Tensor dest_idx) {
+torch::Tensor moe_dispatch_backward(int s, int ec, int h,
+                                    torch::Tensor expert_grad,
+                                    torch::Tensor mask,
+                                    torch::Tensor dest_idx) {
 
-    CHECK_INPUT(expert_grad);
-    CHECK_CUDA(mask);
-    CHECK_CUDA(dest_idx);
+  CHECK_INPUT(expert_grad);
+  CHECK_CUDA(mask);
+  CHECK_CUDA(dest_idx);
 
-    return moe_dispatch_cuda_backward(
-        s, ec, h,
-        expert_grad, mask, dest_idx);
+  return moe_dispatch_cuda_backward(s, ec, h, expert_grad, mask, dest_idx);
 }
 
-torch::Tensor moe_combine_forward(
-    int s, int e, int c, int h,
-    torch::Tensor expert_tokens,
-    torch::Tensor logits,
-    torch::Tensor mask,
-    torch::Tensor dest_idx) {
+torch::Tensor moe_combine_forward(int s, int e, int c, int h,
+                                  torch::Tensor expert_tokens,
+                                  torch::Tensor logits, torch::Tensor mask,
+                                  torch::Tensor dest_idx) {
 
-    CHECK_INPUT(expert_tokens);
-    CHECK_INPUT(logits);
-    CHECK_CUDA(mask);
-    CHECK_CUDA(dest_idx);
+  CHECK_INPUT(expert_tokens);
+  CHECK_INPUT(logits);
+  CHECK_CUDA(mask);
+  CHECK_CUDA(dest_idx);
 
-    return moe_combine_cuda_forward(
-        s, e, c, h,
-        expert_tokens, logits, mask, dest_idx);
+  return moe_combine_cuda_forward(s, e, c, h, expert_tokens, logits, mask,
+                                  dest_idx);
 }
 
-std::vector<torch::Tensor> moe_combine_backward(
-    int s, int e, int c, int h,
-    torch::Tensor tokens_grad,
-    torch::Tensor expert_tokens,
-    torch::Tensor logits,
-    torch::Tensor mask,
-    torch::Tensor dest_idx) {
+std::vector<torch::Tensor>
+moe_combine_backward(int s, int e, int c, int h, torch::Tensor tokens_grad,
+                     torch::Tensor expert_tokens, torch::Tensor logits,
+                     torch::Tensor mask, torch::Tensor dest_idx) {
 
-    CHECK_INPUT(tokens_grad);
-    CHECK_INPUT(logits);
-    CHECK_CUDA(mask);
-    CHECK_CUDA(dest_idx);
+  CHECK_INPUT(tokens_grad);
+  CHECK_INPUT(logits);
+  CHECK_CUDA(mask);
+  CHECK_CUDA(dest_idx);
 
-    return moe_combine_cuda_backward(
-        s, e, c, h,
-        tokens_grad, expert_tokens, logits, mask, dest_idx);
+  return moe_combine_cuda_backward(s, e, c, h, tokens_grad, expert_tokens,
+                                   logits, mask, dest_idx);
 }
 
 torch::Tensor moe_cumsum(torch::Tensor mask) {
-    CHECK_INPUT(mask);
-    return cumsum_sub_one_in_dim0(mask);
+  CHECK_INPUT(mask);
+  return cumsum_sub_one_in_dim0(mask);
 }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.def("cumsum_sub_one", &moe_cumsum,
-          "Fast cumsum operation in dim0");
-    m.def("dispatch_forward", &moe_dispatch_forward,
+  m.def("cumsum_sub_one", &moe_cumsum, "Fast cumsum operation in dim0");
+  m.def("dispatch_forward", &moe_dispatch_forward,
         "Forward operation in MoE dispatch function");
-    m.def("dispatch_backward", &moe_dispatch_backward,
+  m.def("dispatch_backward", &moe_dispatch_backward,
         "Backward operation in MoE dispatch function");
-    m.def("combine_forward", &moe_combine_forward,
+  m.def("combine_forward", &moe_combine_forward,
         "Combine operation in MoE combine function");
-    m.def("combine_backward", &moe_combine_backward,
+  m.def("combine_backward", &moe_combine_backward,
        "Combine operation in MoE combine function");
 }
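Note: the hunk above is a pure clang-format reflow of the pybind11 binding file; no signatures or semantics change. As a minimal smoke test that the reflowed file still builds and binds, one might JIT-compile it from Python and call the simplest exported op, cumsum_sub_one. This is a sketch, not part of the change: the companion kernel filename (moe_cuda_kernel.cu), the module name "moe_cuda", and the assumption that the op computes a dim-0 cumsum minus one (suggested by cumsum_sub_one_in_dim0 and its docstring) are all unverified assumptions.

    import torch
    from torch.utils.cpp_extension import load

    # JIT-build the binding together with its CUDA kernel source.
    # The kernel filename is an assumption; point `sources` at the real files.
    moe = load(
        name="moe_cuda",
        sources=[
            "colossalai/kernel/cuda_native/csrc/moe_cuda.cpp",
            "colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu",
        ],
        verbose=True,
    )

    # cumsum_sub_one requires a contiguous CUDA tensor (enforced by CHECK_INPUT).
    mask = torch.randint(0, 2, (16, 4), dtype=torch.int32, device="cuda")
    out = moe.cumsum_sub_one(mask)

    # If the op is indeed cumsum along dim 0 minus one, this should print True.
    print(torch.equal(out, mask.cumsum(dim=0, dtype=torch.int32) - 1))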