|
|
|
@ -1,12 +1,13 @@
|
|
|
|
|
#include "block_reduce.h"
|
|
|
|
|
#include <cub/cub.cuh>
|
|
|
|
|
#include <cuda.h>
|
|
|
|
|
#include <cuda_fp16.h>
|
|
|
|
|
#include <torch/extension.h>
|
|
|
|
|
|
|
|
|
|
#include <cub/cub.cuh>
|
|
|
|
|
|
|
|
|
|
#include "block_reduce.h"
|
|
|
|
|
|
|
|
|
|
template <typename T, int block_size, int pack_size>
|
|
|
|
|
__device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {
|
|
|
|
|
|
|
|
|
|
assert(cols % pack_size == 0);
|
|
|
|
|
const int bpack_size = block_size * pack_size;
|
|
|
|
|
|
|
|
|
@ -28,7 +29,6 @@ __device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {
|
|
|
|
|
|
|
|
|
|
template <typename T, int block_size, int pack_size>
|
|
|
|
|
__device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) {
|
|
|
|
|
|
|
|
|
|
assert(cols % pack_size == 0);
|
|
|
|
|
const int bpack_size = block_size * pack_size;
|
|
|
|
|
|
|
|
|
@ -51,7 +51,6 @@ __device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) {
|
|
|
|
|
template <typename T, int block_size, int pack_size>
|
|
|
|
|
__device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2,
|
|
|
|
|
const int cols) {
|
|
|
|
|
|
|
|
|
|
assert(cols % pack_size == 0);
|
|
|
|
|
const int bpack_size = block_size * pack_size;
|
|
|
|
|
|
|
|
|
@ -75,7 +74,6 @@ __device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2,
|
|
|
|
|
template <typename T, int block_size, int pack_size>
|
|
|
|
|
__device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2,
|
|
|
|
|
const int cols) {
|
|
|
|
|
|
|
|
|
|
assert(cols % pack_size == 0);
|
|
|
|
|
const int bpack_size = block_size * pack_size;
|
|
|
|
|
|
|
|
|
@ -105,7 +103,6 @@ __device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2,
|
|
|
|
|
template <typename T, int block_size, int pack_size>
|
|
|
|
|
__device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight,
|
|
|
|
|
const int cols) {
|
|
|
|
|
|
|
|
|
|
assert(cols % pack_size == 0);
|
|
|
|
|
const int bpack_size = block_size * pack_size;
|
|
|
|
|
|
|
|
|
@ -134,7 +131,6 @@ __device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight,
|
|
|
|
|
template <typename T, int block_size, int pack_size>
|
|
|
|
|
__device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row,
|
|
|
|
|
T *weight_grad, const T weight, const int cols) {
|
|
|
|
|
|
|
|
|
|
assert(cols % pack_size == 0);
|
|
|
|
|
const int bpack_size = block_size * pack_size;
|
|
|
|
|
|
|
|
|
@ -164,15 +160,13 @@ __device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row,
|
|
|
|
|
|
|
|
|
|
blockReduce<ReduceType::kSum, 1>(&thread_sum);
|
|
|
|
|
|
|
|
|
|
if (threadIdx.x == 0)
|
|
|
|
|
*weight_grad = static_cast<T>(thread_sum);
|
|
|
|
|
if (threadIdx.x == 0) *weight_grad = static_cast<T>(thread_sum);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T, int block_size, int pack_size>
|
|
|
|
|
__device__ void moe_cb_two_fwd(T *src_row1, T *src_row2, T *dst_row,
|
|
|
|
|
const T weight1, const T weight2,
|
|
|
|
|
const int cols) {
|
|
|
|
|
|
|
|
|
|
assert(cols % pack_size == 0);
|
|
|
|
|
const int bpack_size = block_size * pack_size;
|
|
|
|
|
|
|
|
|
@ -204,7 +198,6 @@ __device__ void moe_cb_two_bwd(T *src_row1, T *src_row2, T *dst_row,
|
|
|
|
|
T *tks_row1, T *tks_row2, T *weight_grad1,
|
|
|
|
|
T *weight_grad2, const T weight1,
|
|
|
|
|
const T weight2, const int cols) {
|
|
|
|
|
|
|
|
|
|
assert(cols % pack_size == 0);
|
|
|
|
|
const int bpack_size = block_size * pack_size;
|
|
|
|
|
|
|
|
|
@ -251,7 +244,6 @@ template <typename T, int block_size, int pack_size>
|
|
|
|
|
__device__ void moe_dpch_fwd_selector(T *src_row, T *dst_row1, T *dst_row2,
|
|
|
|
|
const int cols, const int indicator1,
|
|
|
|
|
const int indicator2) {
|
|
|
|
|
|
|
|
|
|
if (indicator1 != 0 && indicator2 != 0)
|
|
|
|
|
moe_dpch_two_fwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2,
|
|
|
|
|
cols);
|
|
|
|
@ -267,7 +259,6 @@ template <typename T, int block_size, int pack_size>
|
|
|
|
|
__device__ void moe_dpch_bwd_selector(T *src_row, T *dst_row1, T *dst_row2,
|
|
|
|
|
const int cols, const int indicator1,
|
|
|
|
|
const int indicator2) {
|
|
|
|
|
|
|
|
|
|
if (indicator1 != 0 && indicator2 != 0)
|
|
|
|
|
moe_dpch_two_bwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2,
|
|
|
|
|
cols);
|
|
|
|
@ -283,7 +274,6 @@ template <typename T, int block_size, int pack_size>
|
|
|
|
|
__global__ void moe_dpch_fwd_kernel(T *batch_tokens, T *expert_input,
|
|
|
|
|
int *mask1, int *mask2, int *dest1,
|
|
|
|
|
int *dest2, const int h) {
|
|
|
|
|
|
|
|
|
|
int row = blockIdx.x;
|
|
|
|
|
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
|
|
|
|
|
moe_dpch_fwd_selector<T, block_size, pack_size>(
|
|
|
|
@ -295,7 +285,6 @@ template <typename T, int block_size, int pack_size>
|
|
|
|
|
__global__ void moe_dpch_bwd_kernel(T *tokens_grad, T *expert_grad, int *mask1,
|
|
|
|
|
int *mask2, int *dest1, int *dest2,
|
|
|
|
|
const int h) {
|
|
|
|
|
|
|
|
|
|
int row = blockIdx.x;
|
|
|
|
|
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
|
|
|
|
|
moe_dpch_bwd_selector<T, block_size, pack_size>(
|
|
|
|
@ -310,7 +299,6 @@ __device__ void moe_cb_fwd_selector(T *src_row1, T *src_row2, T *dst_row,
|
|
|
|
|
const int cols, const T weight1,
|
|
|
|
|
const T weight2, const int indicator1,
|
|
|
|
|
const int indicator2) {
|
|
|
|
|
|
|
|
|
|
if (indicator1 != 0 && indicator2 != 0)
|
|
|
|
|
moe_cb_two_fwd<T, block_size, pack_size>(src_row1, src_row2, dst_row,
|
|
|
|
|
weight1, weight2, cols);
|
|
|
|
@ -328,7 +316,6 @@ __device__ void moe_cb_bwd_selector(T *src_row1, T *src_row2, T *dst_row,
|
|
|
|
|
T *wt_grad1, T *wt_grad2, const T weight1,
|
|
|
|
|
const T weight2, const int indicator1,
|
|
|
|
|
const int indicator2) {
|
|
|
|
|
|
|
|
|
|
if (indicator1 != 0 && indicator2 != 0)
|
|
|
|
|
moe_cb_two_bwd<T, block_size, pack_size>(src_row1, src_row2, dst_row,
|
|
|
|
|
tks_row1, tks_row2, wt_grad1,
|
|
|
|
@ -348,7 +335,6 @@ __global__ void moe_cb_fwd_kernel(T *expert_tokens, T *combine_tokens,
|
|
|
|
|
T *logits, int *mask1, int *mask2, int *dest1,
|
|
|
|
|
int *dest2, const int e, const int c,
|
|
|
|
|
const int h) {
|
|
|
|
|
|
|
|
|
|
int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
|
|
|
|
|
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
|
|
|
|
|
T *row_log = logits + (row * e);
|
|
|
|
@ -363,7 +349,6 @@ __global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks,
|
|
|
|
|
T *logits, T *logits_grad, int *mask1,
|
|
|
|
|
int *mask2, int *dest1, int *dest2,
|
|
|
|
|
const int e, const int c, const int h) {
|
|
|
|
|
|
|
|
|
|
int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
|
|
|
|
|
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
|
|
|
|
|
T *row_log = logits + (row * e), *row_grad = logits_grad + (row * e);
|
|
|
|
@ -379,7 +364,6 @@ __global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks,
|
|
|
|
|
template <int block_size, int pack_size>
|
|
|
|
|
__global__ void cumsum_kernel(int *inputs, int *outputs, const int s,
|
|
|
|
|
const int e) {
|
|
|
|
|
|
|
|
|
|
assert(s % pack_size == 0);
|
|
|
|
|
constexpr int bpack_size = block_size * pack_size;
|
|
|
|
|
int tid = threadIdx.x, bid = blockIdx.x, tps = tid * pack_size, last_sum = -1;
|
|
|
|
@ -426,8 +410,7 @@ __global__ void cumsum_kernel(int *inputs, int *outputs, const int s,
|
|
|
|
|
}
|
|
|
|
|
__syncthreads();
|
|
|
|
|
|
|
|
|
|
if (tid == 0)
|
|
|
|
|
temp[0] = temp[block_size];
|
|
|
|
|
if (tid == 0) temp[0] = temp[block_size];
|
|
|
|
|
__syncthreads();
|
|
|
|
|
|
|
|
|
|
if (idx + tps < s) {
|
|
|
|
@ -453,7 +436,6 @@ template <typename T>
|
|
|
|
|
void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1,
|
|
|
|
|
int *mask2, int *dest1, int *dest2, const int s,
|
|
|
|
|
const int h) {
|
|
|
|
|
|
|
|
|
|
if (h < 256)
|
|
|
|
|
moe_dpch_fwd_kernel<T, 32, 4>
|
|
|
|
|
<<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
|
|
|
|
@ -474,7 +456,6 @@ void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1,
|
|
|
|
|
template <typename T>
|
|
|
|
|
void moe_dpch_bwd_launch(T *tokens_grad, T *expert_grad, int *mask1, int *mask2,
|
|
|
|
|
int *dest1, int *dest2, const int s, const int h) {
|
|
|
|
|
|
|
|
|
|
if (h < 256)
|
|
|
|
|
moe_dpch_bwd_kernel<T, 32, 4>
|
|
|
|
|
<<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
|
|
|
|
@ -496,7 +477,6 @@ template <typename T>
|
|
|
|
|
void moe_cb_fwd_launch(T *expert_tokens, T *combine_tokens, T *logits,
|
|
|
|
|
int *mask1, int *mask2, int *dest1, int *dest2,
|
|
|
|
|
const int s, const int e, const int c, const int h) {
|
|
|
|
|
|
|
|
|
|
if (h < 256)
|
|
|
|
|
moe_cb_fwd_kernel<T, 32, 4><<<s, 32>>>(expert_tokens, combine_tokens,
|
|
|
|
|
logits, mask1, mask2, dest1, dest2,
|
|
|
|
@ -524,12 +504,11 @@ void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits,
|
|
|
|
|
T *logits_grad, int *mask1, int *mask2, int *dest1,
|
|
|
|
|
int *dest2, const int s, const int e, const int c,
|
|
|
|
|
const int h) {
|
|
|
|
|
|
|
|
|
|
if (h < 256)
|
|
|
|
|
moe_cb_bwd_kernel<T, 32, 4><<<s, 32>>>(tokens_grad, expert_grad, tks,
|
|
|
|
|
logits, logits_grad, mask1, mask2,
|
|
|
|
|
dest1, dest2, e, c, h);
|
|
|
|
|
else // if (h < 512)
|
|
|
|
|
else // if (h < 512)
|
|
|
|
|
moe_cb_bwd_kernel<T, 64, 4><<<s, 64>>>(tokens_grad, expert_grad, tks,
|
|
|
|
|
logits, logits_grad, mask1, mask2,
|
|
|
|
|
dest1, dest2, e, c, h);
|
|
|
|
@ -544,7 +523,6 @@ void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void cumsum_launch(int *inputs, int *outputs, const int s, const int e) {
|
|
|
|
|
|
|
|
|
|
if (s <= 256)
|
|
|
|
|
cumsum_kernel<256, 1><<<e, 256>>>(inputs, outputs, s, e);
|
|
|
|
|
else if (s <= 512)
|
|
|
|
@ -559,27 +537,26 @@ void cumsum_launch(int *inputs, int *outputs, const int s, const int e) {
|
|
|
|
|
|
|
|
|
|
// API FUNCTIONS --------------------------------
|
|
|
|
|
|
|
|
|
|
#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \
|
|
|
|
|
switch (TYPE) { \
|
|
|
|
|
case at::ScalarType::Float: { \
|
|
|
|
|
using scalar_t = float; \
|
|
|
|
|
__VA_ARGS__; \
|
|
|
|
|
break; \
|
|
|
|
|
} \
|
|
|
|
|
case at::ScalarType::Half: { \
|
|
|
|
|
using scalar_t = at::Half; \
|
|
|
|
|
__VA_ARGS__; \
|
|
|
|
|
break; \
|
|
|
|
|
} \
|
|
|
|
|
default: \
|
|
|
|
|
AT_ERROR(#NAME, " not implemented yet for specific data type."); \
|
|
|
|
|
#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \
|
|
|
|
|
switch (TYPE) { \
|
|
|
|
|
case at::ScalarType::Float: { \
|
|
|
|
|
using scalar_t = float; \
|
|
|
|
|
__VA_ARGS__; \
|
|
|
|
|
break; \
|
|
|
|
|
} \
|
|
|
|
|
case at::ScalarType::Half: { \
|
|
|
|
|
using scalar_t = at::Half; \
|
|
|
|
|
__VA_ARGS__; \
|
|
|
|
|
break; \
|
|
|
|
|
} \
|
|
|
|
|
default: \
|
|
|
|
|
AT_ERROR(#NAME, " not implemented yet for specific data type."); \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h,
|
|
|
|
|
torch::Tensor batch_tokens,
|
|
|
|
|
torch::Tensor mask,
|
|
|
|
|
torch::Tensor dest_idx) {
|
|
|
|
|
|
|
|
|
|
assert(h % 16 == 0);
|
|
|
|
|
auto res = torch::zeros(
|
|
|
|
|
{ec, h},
|
|
|
|
@ -601,7 +578,6 @@ torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h,
|
|
|
|
|
torch::Tensor expert_grad,
|
|
|
|
|
torch::Tensor mask,
|
|
|
|
|
torch::Tensor dest_idx) {
|
|
|
|
|
|
|
|
|
|
assert(h % 16 == 0);
|
|
|
|
|
auto res = torch::zeros(
|
|
|
|
|
{s, h}, torch::dtype(expert_grad.dtype()).device(expert_grad.device()));
|
|
|
|
@ -622,7 +598,6 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
|
|
|
|
|
torch::Tensor expert_tokens,
|
|
|
|
|
torch::Tensor logits, torch::Tensor mask,
|
|
|
|
|
torch::Tensor dest_idx) {
|
|
|
|
|
|
|
|
|
|
assert(h % 16 == 0);
|
|
|
|
|
assert(expert_tokens.dtype() == logits.dtype());
|
|
|
|
|
|
|
|
|
@ -643,11 +618,10 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<torch::Tensor>
|
|
|
|
|
moe_combine_cuda_backward(int s, int e, int c, int h, torch::Tensor tokens_grad,
|
|
|
|
|
torch::Tensor expert_tokens, torch::Tensor logits,
|
|
|
|
|
torch::Tensor mask, torch::Tensor dest_idx) {
|
|
|
|
|
|
|
|
|
|
std::vector<torch::Tensor> moe_combine_cuda_backward(
|
|
|
|
|
int s, int e, int c, int h, torch::Tensor tokens_grad,
|
|
|
|
|
torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask,
|
|
|
|
|
torch::Tensor dest_idx) {
|
|
|
|
|
assert(h % 16 == 0);
|
|
|
|
|
assert(tokens_grad.dtype() == expert_tokens.dtype());
|
|
|
|
|
assert(expert_tokens.dtype() == logits.dtype());
|
|
|
|
@ -673,7 +647,6 @@ moe_combine_cuda_backward(int s, int e, int c, int h, torch::Tensor tokens_grad,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) {
|
|
|
|
|
|
|
|
|
|
assert(mask.dim() == 2);
|
|
|
|
|
assert(mask.dtype() == torch::kInt32);
|
|
|
|
|
|
|
|
|
|