|
|
@ -15,10 +15,10 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
|
|
|
|
torch::Tensor logits, torch::Tensor mask,
|
|
|
|
torch::Tensor logits, torch::Tensor mask,
|
|
|
|
torch::Tensor dest_idx);
|
|
|
|
torch::Tensor dest_idx);
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<torch::Tensor>
|
|
|
|
std::vector<torch::Tensor> moe_combine_cuda_backward(
|
|
|
|
moe_combine_cuda_backward(int s, int e, int c, int h, torch::Tensor tokens_grad,
|
|
|
|
int s, int e, int c, int h, torch::Tensor tokens_grad,
|
|
|
|
torch::Tensor expert_tokens, torch::Tensor logits,
|
|
|
|
torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask,
|
|
|
|
torch::Tensor mask, torch::Tensor dest_idx);
|
|
|
|
torch::Tensor dest_idx);
|
|
|
|
|
|
|
|
|
|
|
|
torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask);
|
|
|
|
torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask);
|
|
|
|
|
|
|
|
|
|
|
@ -33,7 +33,6 @@ torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask);
|
|
|
|
torch::Tensor moe_dispatch_forward(int s, int ec, int h,
|
|
|
|
torch::Tensor moe_dispatch_forward(int s, int ec, int h,
|
|
|
|
torch::Tensor batch_tokens,
|
|
|
|
torch::Tensor batch_tokens,
|
|
|
|
torch::Tensor mask, torch::Tensor dest_idx) {
|
|
|
|
torch::Tensor mask, torch::Tensor dest_idx) {
|
|
|
|
|
|
|
|
|
|
|
|
CHECK_INPUT(batch_tokens);
|
|
|
|
CHECK_INPUT(batch_tokens);
|
|
|
|
CHECK_CUDA(mask);
|
|
|
|
CHECK_CUDA(mask);
|
|
|
|
CHECK_CUDA(dest_idx);
|
|
|
|
CHECK_CUDA(dest_idx);
|
|
|
@ -45,7 +44,6 @@ torch::Tensor moe_dispatch_backward(int s, int ec, int h,
|
|
|
|
torch::Tensor expert_grad,
|
|
|
|
torch::Tensor expert_grad,
|
|
|
|
torch::Tensor mask,
|
|
|
|
torch::Tensor mask,
|
|
|
|
torch::Tensor dest_idx) {
|
|
|
|
torch::Tensor dest_idx) {
|
|
|
|
|
|
|
|
|
|
|
|
CHECK_INPUT(expert_grad);
|
|
|
|
CHECK_INPUT(expert_grad);
|
|
|
|
CHECK_CUDA(mask);
|
|
|
|
CHECK_CUDA(mask);
|
|
|
|
CHECK_CUDA(dest_idx);
|
|
|
|
CHECK_CUDA(dest_idx);
|
|
|
@ -57,7 +55,6 @@ torch::Tensor moe_combine_forward(int s, int e, int c, int h,
|
|
|
|
torch::Tensor expert_tokens,
|
|
|
|
torch::Tensor expert_tokens,
|
|
|
|
torch::Tensor logits, torch::Tensor mask,
|
|
|
|
torch::Tensor logits, torch::Tensor mask,
|
|
|
|
torch::Tensor dest_idx) {
|
|
|
|
torch::Tensor dest_idx) {
|
|
|
|
|
|
|
|
|
|
|
|
CHECK_INPUT(expert_tokens);
|
|
|
|
CHECK_INPUT(expert_tokens);
|
|
|
|
CHECK_INPUT(logits);
|
|
|
|
CHECK_INPUT(logits);
|
|
|
|
CHECK_CUDA(mask);
|
|
|
|
CHECK_CUDA(mask);
|
|
|
@ -67,11 +64,12 @@ torch::Tensor moe_combine_forward(int s, int e, int c, int h,
|
|
|
|
dest_idx);
|
|
|
|
dest_idx);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<torch::Tensor>
|
|
|
|
std::vector<torch::Tensor> moe_combine_backward(int s, int e, int c, int h,
|
|
|
|
moe_combine_backward(int s, int e, int c, int h, torch::Tensor tokens_grad,
|
|
|
|
torch::Tensor tokens_grad,
|
|
|
|
torch::Tensor expert_tokens, torch::Tensor logits,
|
|
|
|
torch::Tensor expert_tokens,
|
|
|
|
torch::Tensor mask, torch::Tensor dest_idx) {
|
|
|
|
torch::Tensor logits,
|
|
|
|
|
|
|
|
torch::Tensor mask,
|
|
|
|
|
|
|
|
torch::Tensor dest_idx) {
|
|
|
|
CHECK_INPUT(tokens_grad);
|
|
|
|
CHECK_INPUT(tokens_grad);
|
|
|
|
CHECK_INPUT(logits);
|
|
|
|
CHECK_INPUT(logits);
|
|
|
|
CHECK_CUDA(mask);
|
|
|
|
CHECK_CUDA(mask);
|
|
|
|