Optimized MoE layer and fixed some bugs;

Decreased moe tests;

Added FFNExperts and ViTMoE model
pull/394/head
1SAA 2022-02-18 20:42:31 +08:00 committed by Frank Lee
parent 3dba070580
commit 219df6e685
15 changed files with 1552 additions and 203 deletions


@@ -9,6 +9,6 @@ repos:
     hooks:
       - id: flake8
   - repo: https://github.com/pre-commit/mirrors-clang-format
-    rev: v13.0.0
+    rev: v13.0.1
    hooks:
       - id: clang-format


@@ -56,6 +56,7 @@ class MoeEnv:
         self.data_parallel_size = None
         self.model_parallel_size = None
         self.aux_loss = None
+        self.enable_cuda = True

     def setup(self, moe_model_size):
         from .core import global_context as gpc
@@ -71,6 +72,9 @@ class MoeEnv:
     def is_initialized(self):
         return self.model_parallel_size is not None

+    def set_cuda_false(self):
+        self.enable_cuda = False
+
     def reset_loss(self):
         self.aux_loss = 0
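With this flag, the fused CUDA MoE path can be switched off globally even when the extension is installed. A minimal usage sketch (assuming colossalai is importable):

import colossalai
from colossalai.global_variables import moe_env

moe_env.set_cuda_false()  # MoeLayer instances created afterwards fall back to the pure PyTorch routing path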


@@ -5,7 +5,7 @@
 #include "ATen/ATen.h"
 #include "ATen/AccumulateType.h"
 #include "ATen/cuda/CUDAContext.h"
-#include <THC/THCDeviceUtils.cuh>
+#include "ATen/cuda/DeviceUtils.cuh"
 #include <cuda.h>
 #include <cuda_runtime.h>


@ -0,0 +1,118 @@
#include <torch/extension.h>
torch::Tensor moe_dispatch_cuda_forward(
int s, int ec, int h,
torch::Tensor batch_tokens,
torch::Tensor mask,
torch::Tensor dest_idx);
torch::Tensor moe_dispatch_cuda_backward(
int s, int ec, int h,
torch::Tensor expert_grad,
torch::Tensor mask,
torch::Tensor dest_idx);
torch::Tensor moe_combine_cuda_forward(
int s, int e, int c, int h,
torch::Tensor expert_tokens,
torch::Tensor logits,
torch::Tensor mask,
torch::Tensor dest_idx);
std::vector<torch::Tensor> moe_combine_cuda_backward(
int s, int e, int c, int h,
torch::Tensor tokens_grad,
torch::Tensor expert_tokens,
torch::Tensor logits,
torch::Tensor mask,
torch::Tensor dest_idx);
torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask);
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
torch::Tensor moe_dispatch_forward(
int s, int ec, int h,
torch::Tensor batch_tokens,
torch::Tensor mask,
torch::Tensor dest_idx) {
CHECK_INPUT(batch_tokens);
CHECK_CUDA(mask);
CHECK_CUDA(dest_idx);
return moe_dispatch_cuda_forward(
s, ec, h,
batch_tokens, mask, dest_idx);
}
torch::Tensor moe_dispatch_backward(
int s, int ec, int h,
torch::Tensor expert_grad,
torch::Tensor mask,
torch::Tensor dest_idx) {
CHECK_INPUT(expert_grad);
CHECK_CUDA(mask);
CHECK_CUDA(dest_idx);
return moe_dispatch_cuda_backward(
s, ec, h,
expert_grad, mask, dest_idx);
}
torch::Tensor moe_combine_forward(
int s, int e, int c, int h,
torch::Tensor expert_tokens,
torch::Tensor logits,
torch::Tensor mask,
torch::Tensor dest_idx) {
CHECK_INPUT(expert_tokens);
CHECK_INPUT(logits);
CHECK_CUDA(mask);
CHECK_CUDA(dest_idx);
return moe_combine_cuda_forward(
s, e, c, h,
expert_tokens, logits, mask, dest_idx);
}
std::vector<torch::Tensor> moe_combine_backward(
int s, int e, int c, int h,
torch::Tensor tokens_grad,
torch::Tensor expert_tokens,
torch::Tensor logits,
torch::Tensor mask,
torch::Tensor dest_idx) {
CHECK_INPUT(tokens_grad);
CHECK_INPUT(logits);
CHECK_CUDA(mask);
CHECK_CUDA(dest_idx);
return moe_combine_cuda_backward(
s, e, c, h,
tokens_grad, expert_tokens, logits, mask, dest_idx);
}
torch::Tensor moe_cumsum(torch::Tensor mask) {
CHECK_INPUT(mask);
return cumsum_sub_one_in_dim0(mask);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("cumsum_sub_one", &moe_cumsum,
"Fast cumsum operation in dim0");
m.def("dispatch_forward", &moe_dispatch_forward,
"Forward operation in MoE dispatch function");
m.def("dispatch_backward", &moe_dispatch_backward,
"Backward operation in MoE dispatch function");
m.def("combine_forward", &moe_combine_forward,
"Combine operation in MoE combine function");
m.def("combine_backward", &moe_combine_backward,
"Combine operation in MoE combine function");
}


@ -0,0 +1,702 @@
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cub/cub.cuh>
#include "block_reduce.h"
template<typename T, int block_size, int pack_size>
__device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
typedef cub::BlockLoad<T, block_size, pack_size,
cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_size, pack_size,
cub::BLOCK_STORE_VECTORIZE> BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
int tps = threadIdx.x * pack_size; T pack[pack_size];
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
BlockLoad(ts_load).Load(src_row + idx, pack);
BlockStore(ts_store).Store(dst_row + idx, pack);
}
}
template<typename T, int block_size, int pack_size>
__device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
typedef cub::BlockLoad<T, block_size, pack_size,
cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_size, pack_size,
cub::BLOCK_STORE_VECTORIZE> BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
int tps = threadIdx.x * pack_size; T pack[pack_size];
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
BlockLoad(ts_load).Load(dst_row + idx, pack);
BlockStore(ts_store).Store(src_row + idx, pack);
}
}
template<typename T, int block_size, int pack_size>
__device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2, const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
typedef cub::BlockLoad<T, block_size, pack_size,
cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_size, pack_size,
cub::BLOCK_STORE_VECTORIZE> BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
int tps = threadIdx.x * pack_size; T pack[pack_size];
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
BlockLoad(ts_load).Load(src_row + idx, pack);
BlockStore(ts_store).Store(dst_row1 + idx, pack);
BlockStore(ts_store).Store(dst_row2 + idx, pack);
}
}
template<typename T, int block_size, int pack_size>
__device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2, const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
typedef cub::BlockLoad<T, block_size, pack_size,
cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_size, pack_size,
cub::BLOCK_STORE_VECTORIZE> BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
int tps = threadIdx.x * pack_size;
T pack1[pack_size], pack2[pack_size];
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
BlockLoad(ts_load).Load(dst_row1 + idx, pack1);
BlockLoad(ts_load).Load(dst_row2 + idx, pack2);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
pack1[i] += pack2[i];
}
BlockStore(ts_store).Store(src_row + idx, pack1);
}
}
template<typename T, int block_size, int pack_size>
__device__ void moe_cb_one_fwd(
T *src_row, T *dst_row,
const T weight, const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
typedef cub::BlockLoad<T, block_size, pack_size,
cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_size, pack_size,
cub::BLOCK_STORE_VECTORIZE> BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
int tps = threadIdx.x * pack_size; T pack[pack_size];
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
BlockLoad(ts_load).Load(src_row + idx, pack);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
pack[i] *= weight;
}
BlockStore(ts_store).Store(dst_row + idx, pack);
}
}
template<typename T, int block_size, int pack_size>
__device__ void moe_cb_one_bwd(
T *src_row, T *dst_row, T *tks_row, T *weight_grad,
const T weight, const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
typedef cub::BlockLoad<T, block_size, pack_size,
cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_size, pack_size,
cub::BLOCK_STORE_VECTORIZE> BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
int tps = threadIdx.x * pack_size;
T grad[pack_size], tokens[pack_size];
float thread_sum = 0;
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
BlockLoad(ts_load).Load(dst_row + idx, grad);
BlockLoad(ts_load).Load(tks_row + idx, tokens);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
thread_sum += grad[i] * tokens[i];
grad[i] *= weight;
}
BlockStore(ts_store).Store(src_row + idx, grad);
}
blockReduce<ReduceType::kSum, 1>(&thread_sum);
if (threadIdx.x == 0)
*weight_grad = static_cast<T>(thread_sum);
}
template<typename T, int block_size, int pack_size>
__device__ void moe_cb_two_fwd(
T *src_row1, T *src_row2, T *dst_row,
const T weight1, const T weight2, const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
typedef cub::BlockLoad<T, block_size, pack_size,
cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_size, pack_size,
cub::BLOCK_STORE_VECTORIZE> BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
int tps = threadIdx.x * pack_size;
T pack1[pack_size], pack2[pack_size];
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
BlockLoad(ts_load).Load(src_row1 + idx, pack1);
BlockLoad(ts_load).Load(src_row2 + idx, pack2);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
pack1[i] = pack1[i] * weight1 + pack2[i] * weight2;
}
BlockStore(ts_store).Store(dst_row + idx, pack1);
}
}
template<typename T, int block_size, int pack_size>
__device__ void moe_cb_two_bwd(
T *src_row1, T *src_row2, T *dst_row,
T *tks_row1, T *tks_row2, T *weight_grad1, T *weight_grad2,
const T weight1, const T weight2, const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
typedef cub::BlockLoad<T, block_size, pack_size,
cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_size, pack_size,
cub::BLOCK_STORE_VECTORIZE> BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
int tps = threadIdx.x * pack_size;
T grad[pack_size], tokens1[pack_size], tokens2[pack_size],
sgrad1[pack_size], sgrad2[pack_size];
float thread_sum[2] = {0, 0};
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
BlockLoad(ts_load).Load(dst_row + idx, grad);
BlockLoad(ts_load).Load(tks_row1 + idx, tokens1);
BlockLoad(ts_load).Load(tks_row2 + idx, tokens2);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
thread_sum[0] += grad[i] * tokens1[i];
thread_sum[1] += grad[i] * tokens2[i];
sgrad1[i] = weight1 * grad[i];
sgrad2[i] = weight2 * grad[i];
}
BlockStore(ts_store).Store(src_row1 + idx, sgrad1);
BlockStore(ts_store).Store(src_row2 + idx, sgrad2);
}
blockReduce<ReduceType::kSum, 2>(thread_sum);
if (threadIdx.x == 0)
*weight_grad1 = static_cast<T>(thread_sum[0]);
else if (threadIdx.x == 1)
*weight_grad2 = static_cast<T>(thread_sum[1]);
}
// DISPATCH KERNELS --------------------------------
template<typename T, int block_size, int pack_size>
__device__ void moe_dpch_fwd_selector(
T *src_row, T *dst_row1, T *dst_row2, const int cols,
const int indicator1, const int indicator2) {
if (indicator1 != 0 && indicator2 != 0)
moe_dpch_two_fwd<T, block_size, pack_size>(
src_row, dst_row1, dst_row2, cols);
else if (indicator1 != 0)
moe_dpch_one_fwd<T, block_size, pack_size>(
src_row, dst_row1, cols);
else if (indicator2 != 0)
moe_dpch_one_fwd<T, block_size, pack_size>(
src_row, dst_row2, cols);
else
return;
}
template<typename T, int block_size, int pack_size>
__device__ void moe_dpch_bwd_selector(
T *src_row, T *dst_row1, T *dst_row2, const int cols,
const int indicator1, const int indicator2) {
if (indicator1 != 0 && indicator2 != 0)
moe_dpch_two_bwd<T, block_size, pack_size>(
src_row, dst_row1, dst_row2, cols);
else if (indicator1 != 0)
moe_dpch_one_bwd<T, block_size, pack_size>(
src_row, dst_row1, cols);
else if (indicator2 != 0)
moe_dpch_one_bwd<T, block_size, pack_size>(
src_row, dst_row2, cols);
else
return;
}
template<typename T, int block_size, int pack_size>
__global__ void moe_dpch_fwd_kernel(
T *batch_tokens, T *expert_input,
int *mask1, int *mask2,
int *dest1, int *dest2, const int h) {
int row = blockIdx.x;
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
moe_dpch_fwd_selector<T, block_size, pack_size>(
batch_tokens + (row * h),
expert_input + (dest1[row] * h), expert_input + (dest2[row] * h),
h, mask1[row], indicator2);
}
template<typename T, int block_size, int pack_size>
__global__ void moe_dpch_bwd_kernel(
T *tokens_grad, T *expert_grad,
int *mask1, int *mask2,
int *dest1, int *dest2, const int h) {
int row = blockIdx.x;
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
moe_dpch_bwd_selector<T, block_size, pack_size>(
tokens_grad + (row * h),
expert_grad + (dest1[row] * h), expert_grad + (dest2[row] * h),
h, mask1[row], indicator2);
}
// COMBINE KERNELS --------------------------------
template<typename T, int block_size, int pack_size>
__device__ void moe_cb_fwd_selector(
T *src_row1, T *src_row2, T *dst_row, const int cols,
const T weight1, const T weight2,
const int indicator1, const int indicator2) {
if (indicator1 != 0 && indicator2 != 0)
moe_cb_two_fwd<T, block_size, pack_size>(
src_row1, src_row2, dst_row, weight1, weight2, cols);
else if (indicator1 != 0)
moe_cb_one_fwd<T, block_size, pack_size>(
src_row1, dst_row, weight1, cols);
else if (indicator2 != 0)
moe_cb_one_fwd<T, block_size, pack_size>(
src_row2, dst_row, weight2, cols);
else
return;
}
template<typename T, int block_size, int pack_size>
__device__ void moe_cb_bwd_selector(
T *src_row1, T *src_row2, T *dst_row, const int cols,
T *tks_row1, T *tks_row2, T *wt_grad1, T *wt_grad2,
const T weight1, const T weight2,
const int indicator1, const int indicator2) {
if (indicator1 != 0 && indicator2 != 0)
moe_cb_two_bwd<T, block_size, pack_size>(
src_row1, src_row2, dst_row,
tks_row1, tks_row2, wt_grad1, wt_grad2,
weight1, weight2, cols);
else if (indicator1 != 0)
moe_cb_one_bwd<T, block_size, pack_size>(
src_row1, dst_row, tks_row1, wt_grad1, weight1, cols);
else if (indicator2 != 0)
moe_cb_one_bwd<T, block_size, pack_size>(
src_row2, dst_row, tks_row2, wt_grad2, weight2, cols);
else
return;
}
template<typename T, int block_size, int pack_size>
__global__ void moe_cb_fwd_kernel(
T *expert_tokens, T *combine_tokens, T *logits,
int *mask1, int *mask2,
int *dest1, int *dest2,
const int e, const int c, const int h) {
int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
T *row_log = logits + (row * e);
moe_cb_fwd_selector<T, block_size, pack_size>(
expert_tokens + (dest1[row] * h), expert_tokens + (dest2[row] * h),
combine_tokens + (row * h), h,
row_log[eid1], row_log[eid2],
mask1[row], indicator2);
}
template<typename T, int block_size, int pack_size>
__global__ void moe_cb_bwd_kernel(
T *tokens_grad, T *expert_grad, T *tks,
T *logits, T *logits_grad,
int *mask1, int *mask2,
int *dest1, int *dest2,
const int e, const int c, const int h) {
int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
T *row_log = logits + (row * e), *row_grad = logits_grad + (row * e);
moe_cb_bwd_selector<T, block_size, pack_size>(
expert_grad + (dest1[row] * h), expert_grad + (dest2[row] * h),
tokens_grad + (row * h), h,
tks + (dest1[row] * h), tks + (dest2[row] * h),
row_grad + eid1, row_grad + eid2,
row_log[eid1], row_log[eid2],
mask1[row], indicator2);
}
//CUMSUM KERNEL --------------------------------
template<int block_size, int pack_size>
__global__ void cumsum_kernel(
int *inputs, int *outputs,
const int s, const int e) {
assert(s % pack_size == 0);
constexpr int bpack_size = block_size * pack_size;
int tid = threadIdx.x, bid = blockIdx.x,
tps = tid * pack_size, last_sum = -1;
__shared__ int temp[block_size + 1]; int pack[pack_size];
for (int idx = 0; idx < s; idx += bpack_size) {
int offset = 1;
if (idx + tps < s) {
temp[tid] = inputs[tps * e + bid];
#pragma unroll
for (int i = 1; i < pack_size; ++i) {
pack[i] = inputs[(tps + i) * e + bid];
}
#pragma unroll
for (int i = 1; i < pack_size; ++i) {
temp[tid] += pack[i];
}
}
for (int i = block_size >> 1; i > 0; i >>= 1) {
__syncthreads();
if (tid < i) {
int j = offset * (2 * tid + 1) - 1;
temp[j + offset] += temp[j];
}
offset <<= 1;
}
if (tid == 0) {
temp[block_size] = temp[block_size - 1];
temp[block_size - 1] = 0;
}
for (int i = 1; i < block_size; i <<= 1) {
offset >>= 1;
__syncthreads();
if (tid < i) {
int j = offset * (2 * tid + 1) - 1,
k = j + offset, ts = temp[j];
temp[j] = temp[k];
temp[k] += ts;
}
}
__syncthreads();
if (tid == 0)
temp[0] = temp[block_size];
__syncthreads();
if (idx + tps < s) {
temp[tid + 1] += last_sum;
#pragma unroll
for (int i = pack_size - 1; i > 0; --i) {
outputs[(tps + i) * e + bid] = temp[tid + 1];
temp[tid + 1] -= pack[i];
}
outputs[tps * e + bid] = temp[tid + 1];
}
__syncthreads();
last_sum += temp[0];
inputs += bpack_size * e;
outputs += bpack_size * e;
}
}
//LAUNCH FUNCTIONS --------------------------------
template<typename T>
void moe_dpch_fwd_launch(
T *batch_tokens, T *expert_input,
int *mask1, int *mask2,
int *dest1, int *dest2,
const int s, const int h) {
if (h < 256)
moe_dpch_fwd_kernel<T, 32, 4><<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
else if (h < 512)
moe_dpch_fwd_kernel<T, 32, 8><<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
else if (h < 1024)
moe_dpch_fwd_kernel<T, 32, 16><<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
else if (h < 2048)
moe_dpch_fwd_kernel<T, 64, 16><<<s, 64>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
else
moe_dpch_fwd_kernel<T, 128, 16><<<s, 128>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
}
template<typename T>
void moe_dpch_bwd_launch(
T *tokens_grad, T *expert_grad,
int *mask1, int *mask2,
int *dest1, int *dest2,
const int s, const int h) {
if (h < 256)
moe_dpch_bwd_kernel<T, 32, 4><<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
else if (h < 512)
moe_dpch_bwd_kernel<T, 32, 8><<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
else if (h < 1024)
moe_dpch_bwd_kernel<T, 32, 16><<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
else if (h < 2048)
moe_dpch_bwd_kernel<T, 64, 16><<<s, 64>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
else
moe_dpch_bwd_kernel<T, 128, 16><<<s, 128>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
}
template<typename T>
void moe_cb_fwd_launch(
T *expert_tokens, T *combine_tokens, T *logits,
int *mask1, int *mask2,
int *dest1, int *dest2,
const int s, const int e, const int c, const int h) {
if (h < 256)
moe_cb_fwd_kernel<T, 32, 4><<<s, 32>>>
(expert_tokens, combine_tokens, logits, mask1, mask2, dest1, dest2, e, c, h);
else if (h < 512)
moe_cb_fwd_kernel<T, 32, 8><<<s, 32>>>
(expert_tokens, combine_tokens, logits, mask1, mask2, dest1, dest2, e, c, h);
else if (h < 1024)
moe_cb_fwd_kernel<T, 32, 16><<<s, 32>>>
(expert_tokens, combine_tokens, logits, mask1, mask2, dest1, dest2, e, c, h);
else if (h < 2048)
moe_cb_fwd_kernel<T, 64, 16><<<s, 64>>>
(expert_tokens, combine_tokens, logits, mask1, mask2, dest1, dest2, e, c, h);
else
moe_cb_fwd_kernel<T, 128, 16><<<s, 128>>>
(expert_tokens, combine_tokens, logits, mask1, mask2, dest1, dest2, e, c, h);
}
template<typename T>
void moe_cb_bwd_launch(
T *tokens_grad, T *expert_grad, T *tks,
T *logits, T *logits_grad,
int *mask1, int *mask2,
int *dest1, int *dest2,
const int s, const int e, const int c, const int h) {
if (h < 256)
moe_cb_bwd_kernel<T, 32, 4><<<s, 32>>>
(tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, dest1, dest2, e, c, h);
else // if (h < 512)
moe_cb_bwd_kernel<T, 64, 4><<<s, 64>>>
(tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, dest1, dest2, e, c, h);
// else if (h < 1024)
// moe_cb_bwd_kernel<T, 128, 4><<<s, 128>>>
// (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, dest1, dest2, e, c, h);
// else
// moe_cb_bwd_kernel<T, 256, 4><<<s, 256>>>
// (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, dest1, dest2, e, c, h);
}
void cumsum_launch(
int *inputs, int *outputs,
const int s, const int e) {
if (s <= 256)
cumsum_kernel<256, 1><<<e, 256>>>(inputs, outputs, s, e);
else if (s <= 512)
cumsum_kernel<512, 1><<<e, 512>>>(inputs, outputs, s, e);
else if (s <= 1024)
cumsum_kernel<1024, 1><<<e, 1024>>>(inputs, outputs, s, e);
else if (s <= 2048)
cumsum_kernel<1024, 2><<<e, 1024>>>(inputs, outputs, s, e);
else
cumsum_kernel<1024, 4><<<e, 1024>>>(inputs, outputs, s, e);
}
// API FUNCTIONS --------------------------------
#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \
switch (TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented yet for specific data type.");\
}
torch::Tensor moe_dispatch_cuda_forward(
int s, int ec, int h,
torch::Tensor batch_tokens,
torch::Tensor mask,
torch::Tensor dest_idx) {
assert(h % 16 == 0);
auto res = torch::zeros({ec, h},
torch::dtype(batch_tokens.dtype()).device(batch_tokens.device()));
auto k = mask.size(0);
DISPATCH_FLOAT_AND_HALF(
batch_tokens.scalar_type(), "moe dispatch forward",
moe_dpch_fwd_launch<scalar_t>(
batch_tokens.data<scalar_t>(), res.data<scalar_t>(),
mask[0].data<int>(), k == 1 ? nullptr : mask[1].data<int>(),
dest_idx[0].data<int>(), k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(),
s, h)
);
return res;
}
torch::Tensor moe_dispatch_cuda_backward(
int s, int ec, int h,
torch::Tensor expert_grad,
torch::Tensor mask,
torch::Tensor dest_idx) {
assert(h % 16 == 0);
auto res = torch::zeros({s, h},
torch::dtype(expert_grad.dtype()).device(expert_grad.device()));
auto k = mask.size(0);
DISPATCH_FLOAT_AND_HALF(
expert_grad.scalar_type(), "moe dispatch backward",
moe_dpch_bwd_launch<scalar_t>(
res.data<scalar_t>(), expert_grad.data<scalar_t>(),
mask[0].data<int>(), k == 1 ? nullptr : mask[1].data<int>(),
dest_idx[0].data<int>(), k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(),
s, h)
);
return res;
}
torch::Tensor moe_combine_cuda_forward(
int s, int e, int c, int h,
torch::Tensor expert_tokens,
torch::Tensor logits,
torch::Tensor mask,
torch::Tensor dest_idx) {
assert(h % 16 == 0);
assert(expert_tokens.dtype() == logits.dtype());
auto res = torch::zeros({s, h},
torch::dtype(expert_tokens.dtype()).device(expert_tokens.device()));
auto k = mask.size(0);
DISPATCH_FLOAT_AND_HALF(
expert_tokens.scalar_type(), "moe combine forward",
moe_cb_fwd_launch<scalar_t>(
expert_tokens.data<scalar_t>(), res.data<scalar_t>(), logits.data<scalar_t>(),
mask[0].data<int>(), k == 1 ? nullptr : mask[1].data<int>(),
dest_idx[0].data<int>(), k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(),
s, e, c, h)
);
return res;
}
std::vector<torch::Tensor> moe_combine_cuda_backward(
int s, int e, int c, int h,
torch::Tensor tokens_grad,
torch::Tensor expert_tokens,
torch::Tensor logits,
torch::Tensor mask,
torch::Tensor dest_idx) {
assert(h % 16 == 0);
assert(tokens_grad.dtype() == expert_tokens.dtype());
assert(expert_tokens.dtype() == logits.dtype());
auto egrad = torch::zeros({e * c, h},
torch::dtype(tokens_grad.dtype()).device(tokens_grad.device())),
wgrad = torch::zeros({s, e}, torch::dtype(logits.dtype()).device(logits.device()));
auto k = mask.size(0);
DISPATCH_FLOAT_AND_HALF(
tokens_grad.scalar_type(), "moe combine backward",
moe_cb_bwd_launch<scalar_t>(
tokens_grad.data<scalar_t>(), egrad.data<scalar_t>(), expert_tokens.data<scalar_t>(),
logits.data<scalar_t>(), wgrad.data<scalar_t>(),
mask[0].data<int>(), k == 1 ? nullptr : mask[1].data<int>(),
dest_idx[0].data<int>(), k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(),
s, e, c, h)
);
return {egrad, wgrad};
}
torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) {
assert(mask.dim() == 2);
assert(mask.dtype() == torch::kInt32);
const int s = mask.size(0), e = mask.size(1);
auto res = torch::empty({s, e}, torch::dtype(torch::kInt32).device(mask.device()));
cumsum_launch(mask.data<int>(), res.data<int>(), s, e);
return res;
}
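The two kernel families above implement a per-token scatter (dispatch) and a per-token weighted gather (combine) over rows of length h, with dest_idx encoding expert_id * capacity + rank for each token. A rough PyTorch reference of the forward semantics, for orientation only (the helper names are hypothetical and this loop version is far slower than the kernels):

import torch

def dispatch_forward_ref(tokens, mask, dest_idx, ec):
    # tokens: [s, h], mask/dest_idx: [k, s] with k = 1 (top-1) or 2 (top-2)
    out = torch.zeros(ec, tokens.size(1), dtype=tokens.dtype, device=tokens.device)
    for k in range(mask.size(0)):
        for i in range(tokens.size(0)):
            if mask[k, i]:
                out[dest_idx[k, i]] = tokens[i]          # copy the token row into its expert slot
    return out

def combine_forward_ref(expert_tokens, logits, mask, dest_idx):
    # expert_tokens: [e * c, h], logits: [s, e]; each token gathers back its expert outputs
    s, e = logits.shape
    c = expert_tokens.size(0) // e
    out = torch.zeros(s, expert_tokens.size(1), dtype=expert_tokens.dtype, device=expert_tokens.device)
    for k in range(mask.size(0)):
        for i in range(s):
            if mask[k, i]:
                eid = dest_idx[k, i] // c                 # recover the expert id from the slot index
                out[i] += logits[i, eid] * expert_tokens[dest_idx[k, i]]
    return out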


@@ -1,8 +1,5 @@
-from ._operation import AllToAll
-from .layers import Experts, MoeLayer, \
-    NormalNoiseGenerator, Top1Router, Top2Router
+from .experts import Experts, FFNExperts
+from .layers import MoeLayer, Top1Router, Top2Router
+from .utils import NormalNoiseGenerator

-__all__ = [
-    'AllToAll', 'Experts', 'Top1Router', 'Top2Router',
-    'MoeLayer', 'NormalNoiseGenerator'
-]
+__all__ = ['Experts', 'FFNExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator']


@@ -6,16 +6,26 @@ from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from typing import Any, Tuple

+U_CUDA_MODE = False
+try:
+    import colossal_moe_cuda
+    U_CUDA_MODE = True
+except ImportError:
+    print("If you want to activate cuda mode for MoE, please install with cuda_ext!")
+

 class AllToAll(torch.autograd.Function):
     """Dispatches input tensor [e, c, h] to all experts by all_to_all_single
     operation in torch.distributed.
     """

     @staticmethod
     def forward(ctx: Any,
                 inputs: Tensor,
                 parallel_mode: ParallelMode) -> Tensor:
-        ctx.parallel_mode = parallel_mode
+        if ctx is not None:
+            ctx.parallel_mode = parallel_mode

         if not inputs.is_contiguous():
             inputs = inputs.contiguous()
@@ -26,4 +36,79 @@ class AllToAll(torch.autograd.Function):

     @staticmethod
     def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
-        return AllToAll.apply(*grad_outputs, ctx.parallel_mode), None
+        return AllToAll.forward(None, *grad_outputs, ctx.parallel_mode), None
class MoeDispatch(torch.autograd.Function):
@staticmethod
def forward(ctx, tokens, mask, dest_idx, ec):
s = tokens.size(0)
h = tokens.size(1)
expert_input = colossal_moe_cuda.dispatch_forward(s, ec, h, tokens, mask, dest_idx)
ctx.save_for_backward(mask, dest_idx)
ctx.s = s
ctx.h = h
ctx.ec = ec
return expert_input
@staticmethod
def backward(ctx, output_grad):
mask, dest_idx = ctx.saved_tensors
d_tokens = colossal_moe_cuda.dispatch_backward(
ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
return d_tokens, None, None, None
class MoeCombine(torch.autograd.Function):
@staticmethod
def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
assert logits.dtype == torch.float32
s = logits.size(0)
e = logits.size(1)
c = ec // e
h = expert_tokens.size(-1)
fp16_flag = (expert_tokens.dtype == torch.float16)
cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens
ctokens = colossal_moe_cuda.combine_forward(s, e, c, h,
cb_input, logits,
mask, dest_idx)
output = ctokens.to(torch.float16) if fp16_flag else ctokens
ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
ctx.s = s
ctx.e = e
ctx.c = c
ctx.h = h
ctx.fp16_flag = fp16_flag
return output
@staticmethod
def backward(ctx, tokens_grad):
expert_tokens, logits, mask, dest_idx = ctx.saved_tensors
cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 \
else tokens_grad
cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens
d_expert, d_logits = colossal_moe_cuda.combine_backward(
ctx.s, ctx.e, ctx.c, ctx.h,
cb_grad, cb_input, logits, mask, dest_idx)
d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert
return d_expert, d_logits, None, None, None
def moe_cumsum(inputs: Tensor):
dim0 = inputs.size(0)
flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
if flag and U_CUDA_MODE:
return colossal_moe_cuda.cumsum_sub_one(inputs)
else:
return torch.cumsum(inputs, dim=0) - 1
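moe_cumsum turns the per-expert one-hot masks into per-token slot positions inside each expert's buffer; cumsum_sub_one is simply an inclusive cumulative sum along dim 0 minus one. A small sanity check of the fallback branch:

import torch

mask = torch.tensor([[1, 0], [1, 0], [0, 1]], dtype=torch.int32)  # 3 tokens, 2 experts
ranks = torch.cumsum(mask, dim=0) - 1
# ranks = [[0, -1], [1, -1], [1, 0]]: after masking with `mask`, token 0 gets slot 0 of expert 0,
# token 1 gets slot 1 of expert 0, and token 2 gets slot 0 of expert 1.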


@ -0,0 +1,96 @@
import math
import torch
import torch.nn as nn
from colossalai.global_variables import moe_env
from colossalai.context import ParallelMode, seed
from colossalai.utils import get_current_device
class Experts(nn.Module):
"""A wrapper class to create experts. It will create E experts across the
moe model parallel group, where E is the number of experts. Every expert
is an instance of the class 'expert' given in the initialization parameters.
:param expert: The class of all experts
:param num_experts: The number of experts
:param expert_args: Args used to initialize experts
:type num_experts: int
"""
def __init__(self, expert, num_experts, **expert_args):
super().__init__()
assert num_experts % moe_env.model_parallel_size == 0, \
"The number of experts should be divied by moe model size"
num_local_experts = num_experts // moe_env.model_parallel_size
with seed(ParallelMode.MOE_MODEL):
self.experts = nn.ModuleList([expert(**expert_args) for _ in range(num_local_experts)])
self.num_local_experts = num_local_experts
for exp in self.experts:
for param in exp.parameters():
param.__setattr__('moe_param', True)
def forward(self, inputs):
expert_input = torch.chunk(inputs, self.num_local_experts, dim=1)
expert_output = []
for i in range(self.num_local_experts):
expert_output.append(self.experts[i](expert_input[i]))
output = torch.cat(expert_output, dim=1).contiguous()
return output
class FFNExperts(nn.Module):
def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
super().__init__()
assert num_experts % moe_env.model_parallel_size == 0, \
"The number of experts should be divied by moe model size"
num_local_experts = num_experts // moe_env.model_parallel_size
self.w1 = nn.Parameter(torch.empty(num_local_experts, d_model, d_ff, device=get_current_device()))
self.b1 = nn.Parameter(torch.empty(num_local_experts, 1, d_ff, device=get_current_device()))
self.w2 = nn.Parameter(torch.empty(num_local_experts, d_ff, d_model, device=get_current_device()))
self.b2 = nn.Parameter(torch.empty(num_local_experts, 1, d_model, device=get_current_device()))
s1 = math.sqrt(0.1 / d_model)
s2 = math.sqrt(0.1 / d_ff)
nn.init.trunc_normal_(self.w1, std=s1)
nn.init.trunc_normal_(self.b1, std=s1)
nn.init.trunc_normal_(self.w2, std=s2)
nn.init.trunc_normal_(self.b2, std=s2)
self.act = nn.GELU() if activation is None else activation
self.drop = nn.Dropout(p=drop_rate)
for param in self.parameters():
param.__setattr__('moe_param', True)
def forward(self, inputs): # x [g, el, c, h]
el = inputs.size(1)
h = inputs.size(-1)
inputs = inputs.transpose(0, 1)
inshape = inputs.shape
inputs = inputs.reshape(el, -1, h)
out_ff = torch.baddbmm(self.b1, inputs, self.w1)
out_act = self.act(out_ff)
with seed(ParallelMode.TENSOR):
inter = self.drop(out_act)
out_model = torch.baddbmm(self.b2, inter, self.w2)
with seed(ParallelMode.TENSOR):
outputs = self.drop(out_model) # outputs [el, gc, h]
outputs = outputs.reshape(inshape)
outputs = outputs.transpose(0, 1).contiguous()
return outputs
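FFNExperts expects its input already grouped as [moe_parallel_size, local_experts, capacity, d_model] and applies one independent two-layer FFN per local expert through batched matmuls (torch.baddbmm). A minimal shape-check sketch, assuming colossalai has been launched with a MoE model parallel size of 1 and a CUDA device is available:

import torch

num_experts, capacity, d_model, d_ff = 4, 8, 64, 256    # hypothetical sizes
experts = FFNExperts(num_experts, d_model, d_ff, drop_rate=0.1)
x = torch.randn(1, num_experts, capacity, d_model, device=experts.w1.device)
y = experts(x)
assert y.shape == x.shape    # one FFN applied to each expert's slice of tokens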


@@ -3,70 +3,13 @@ import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.cuda.amp import autocast
+import torch.distributed as dist

+from colossalai.core import global_context as gpc
 from colossalai.global_variables import moe_env
-from colossalai.context import ParallelMode, seed
+from colossalai.context import ParallelMode
 from colossalai.utils import get_current_device
-from ._operation import AllToAll
-
-
-class NormalNoiseGenerator:
-    """Generates a random noisy mask for logtis tensor.
-
-    All noise is generated from a normal distribution (0, 1 / E^2), where
-    E = the number of experts.
-
-    :param num_experts: The number of experts
-    :type num_experts: int
-    """
-
-    def __init__(self, num_experts: int):
-        self.normal = torch.distributions.normal.Normal(
-            loc=torch.tensor(0.0, device=get_current_device()),
-            scale=torch.tensor(1.0 / num_experts ** 2, device=get_current_device())
-        ).rsample
-
-    def __call__(self, inputs: torch.Tensor):
-        noisy = self.normal(inputs.shape)
-        return inputs + noisy
-
-
-class Experts(nn.Module):
-    """A wrapper class to create experts. It will create E experts across the
-    moe model parallel group, where E is the number of experts. Every expert
-    is a instence of the class, 'expert' in initialization parameters.
-
-    :param expert: The class of all experts
-    :param num_experts: The number of experts
-    :param expert_args: Args used to initialize experts
-
-    :type num_experts: int
-    """
-
-    def __init__(self, expert, num_experts, **expert_args):
-        super().__init__()
-
-        assert num_experts % moe_env.model_parallel_size == 0, \
-            "The number of experts should be divied by moe model size"
-
-        num_local_experts = num_experts // moe_env.model_parallel_size
-        with seed(ParallelMode.MOE_MODEL):
-            self.experts = nn.ModuleList([
-                expert(**expert_args) for _ in range(num_local_experts)])
-        self.num_local_experts = num_local_experts
-        for exp in self.experts:
-            for param in exp.parameters():
-                param.__setattr__('moe_param', 1)
-
-    def forward(self, inputs):
-        expert_input = torch.chunk(inputs, self.num_local_experts, dim=0)
-        expert_output = []
-
-        for i in range(self.num_local_experts):
-            expert_output.append(self.experts[i](expert_input[i]))
-
-        output = torch.cat(expert_output, dim=0)
-        return output
+from ._operation import U_CUDA_MODE, AllToAll, MoeDispatch, MoeCombine, moe_cumsum
+from .utils import autocast_softmax


 class Top1Router(nn.Module):
@@ -83,63 +26,79 @@ class Top1Router(nn.Module):
     :type noisy_func: Callable, optional
     """

-    def __init__(self,
-                 capacity_factor: float,
-                 min_capacity: int,
-                 noisy_func=None):
+    def __init__(self, capacity_factor: float, min_capacity: int = 0, select_policy: str = "first", noisy_func=None):
         super().__init__()
         self.capacity_factor = capacity_factor
         self.min_capacity = min_capacity
+        self.select_policy = select_policy
         self.noisy_func = noisy_func
-        self.uniform = torch.distributions.uniform.Uniform(
-            low=torch.tensor(0.0, device=get_current_device()),
-            high=torch.tensor(1.0, device=get_current_device())).rsample

-    def get_capacity(self, logits_shape):
-        capacity = math.ceil(self.capacity_factor *
-                             logits_shape[0] / logits_shape[1])
-        if capacity < self.min_capacity:
-            capacity = self.min_capacity
+        assert select_policy in {"first", "random"}
+        if select_policy == "random":
+            self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=get_current_device()),
+                                                               high=torch.tensor(1.0,
+                                                                                 device=get_current_device())).rsample
+
+    def get_capacity(
+        self,
+        logits_shape,
+    ):
+        capacity = math.floor(self.capacity_factor * logits_shape[-2] / logits_shape[-1])
+        capacity += capacity % 2
+        capacity = max(capacity, self.min_capacity)
+        assert capacity > 0
         return capacity

-    def forward(self, inputs):
+    def forward(self, inputs: torch.Tensor, cuda_mode: bool = False):
         if self.noisy_func is not None:
             inputs_noisy = self.noisy_func(inputs)
         else:
             inputs_noisy = inputs

-        logits = F.softmax(inputs, dim=1)
-
-        num_experts = logits.shape[1]
+        logits = autocast_softmax(inputs, dim=-1)
+        num_experts = logits.size(-1)
         capacity = self.get_capacity(logits.shape)

-        expert_idx = torch.argmax(inputs_noisy, dim=1)
-        expert_mask = F.one_hot(expert_idx, num_classes=num_experts)
-        expert_mask_f = expert_mask.float()
-
-        exp_counts = torch.sum(expert_mask, dim=0).detach().to('cpu')
-
-        me = torch.mean(logits, dim=0)
-        ce = torch.mean(expert_mask_f, dim=0)
-        l_aux = torch.sum(me * ce) * num_experts
-        moe_env.add_loss(l_aux)
-
-        rand_mask = expert_mask * self.uniform(logits.shape)
-        _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
-
-        dispatch_mask = \
-            expert_mask * torch.zeros_like(expert_mask).scatter_(0, dispatch_idx, 1)
-
-        locations = torch.cumsum(dispatch_mask, dim=0) - 1
-        locations = torch.sum(dispatch_mask * locations, dim=1)
-        locations = F.one_hot(locations, num_classes=capacity)
-
-        logits = logits * dispatch_mask
-        combine_weights = logits.unsqueeze(2) * locations.unsqueeze(1)
-
-        sec_mask = combine_weights.bool()
-        return combine_weights, sec_mask, exp_counts
+        top1_idx = torch.argmax(inputs_noisy, dim=-1)
+        mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
+
+        if self.training:
+            me = torch.mean(logits, dim=0)
+            ce = torch.mean(mask.float(), dim=0)
+            l_aux = num_experts * torch.sum(me * ce)
+            moe_env.add_loss(l_aux)
+        else:
+            max_num = torch.max(torch.sum(mask, dim=0))
+            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MOE_MODEL))
+            capacity = max_num.item()
+
+        if not self.training:
+            ranks = moe_cumsum(mask)
+        elif self.select_policy == "random":
+            rand_mask = mask * self.uniform(mask.shape)
+            _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
+            mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
+            ranks = moe_cumsum(mask)
+        elif self.select_policy == "first":
+            ranks = moe_cumsum(mask)
+            mask = mask * torch.lt(ranks, capacity)
+        else:
+            raise NotImplementedError("Not support such select policy yet.")
+
+        ranks = torch.sum(mask * ranks, dim=-1)
+
+        if cuda_mode:
+            mask = torch.sum(mask, dim=-1)
+            mask = torch.stack([mask], dim=0).to(torch.int32)
+            dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
+            return logits, mask, dest_idx, num_experts * capacity
+        else:
+            ranks = F.one_hot(ranks, num_classes=capacity)
+            weight = mask * logits.type_as(inputs)
+            combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
+            sec_mask = combine_weights.bool()
+            return combine_weights, sec_mask


 class Top2Router(nn.Module):
@@ -159,53 +118,67 @@ class Top2Router(nn.Module):
         self.noisy_func = noisy_func

     def get_capacity(self, logits_shape):
-        capacity = math.ceil(2 * self.capacity_factor *
-                             logits_shape[0] / logits_shape[1])
+        capacity = math.floor(2 * self.capacity_factor * logits_shape[-2] / logits_shape[-1])
+        capacity += capacity % 2
+        assert capacity > 0
         return capacity

-    def forward(self, inputs):
+    def forward(self, inputs: torch.Tensor, cuda_mode: bool = False):
+        # inputs: [s, h]
         if self.noisy_func is not None:
             inputs = self.noisy_func(inputs)

-        logits = F.softmax(inputs, dim=-1)
+        logits = autocast_softmax(inputs, dim=-1)    # logits: [s, e]
         num_experts = logits.size(-1)
         capacity = self.get_capacity(logits.shape)

-        _, expert_idx = torch.topk(logits, k=2, dim=-1, largest=True, sorted=True)
-        top1_idx = expert_idx[:, 0]
-        top2_idx = expert_idx[:, 1]
-
-        mask1 = F.one_hot(top1_idx, num_classes=num_experts)
-        mask2 = F.one_hot(top2_idx, num_classes=num_experts)
-
-        loss_mask = (mask1 + mask2)
-        exp_counts = torch.sum(loss_mask, dim=0).detach().to('cpu')
-        me = torch.mean(logits, dim=0)
-        ce = torch.mean(loss_mask.float(), dim=0)
-        l_aux = num_experts * torch.sum(me * ce) / 2.0
-        moe_env.add_loss(l_aux)
-
-        locations1 = torch.cumsum(mask1, dim=0) - 1
-        locations2 = torch.cumsum(mask2, dim=0) - 1
-        locations2 += torch.sum(mask1, dim=0, keepdim=True)
-
-        mask1 *= torch.lt(locations1, capacity)
-        mask2 *= torch.lt(locations2, capacity)
-
-        weight1 = mask1 * logits
-        weight2 = mask2 * logits
-
-        locations1 = torch.sum(mask1 * locations1, dim=1)
-        locations2 = torch.sum(mask2 * locations2, dim=1)
-        locations1_sc = F.one_hot(locations1, num_classes=capacity)
-        locations2_sc = F.one_hot(locations2, num_classes=capacity)
-
-        combine_weights1 = weight1.unsqueeze(2) * locations1_sc.unsqueeze(1)
-        combine_weights2 = weight2.unsqueeze(2) * locations2_sc.unsqueeze(1)
-        combine_weights = combine_weights1 + combine_weights2
-        sec_mask = combine_weights.bool()
-
-        return combine_weights, sec_mask, exp_counts
+        top1_idx = torch.argmax(logits, dim=-1)
+        mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
+        logits_except1 = logits.masked_fill(mask1.bool(), float("-inf"))
+        top2_idx = torch.argmax(logits_except1, dim=-1)
+        mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)
+
+        cmask = (mask1 + mask2)    # loss: [s, e]
+        if self.training:
+            me = torch.mean(logits, dim=0)
+            ce = torch.mean(cmask.float(), dim=0)
+            l_aux = num_experts * torch.sum(me * ce) / 2.0
+            moe_env.add_loss(l_aux)
+        else:
+            max_num = torch.max(torch.sum(cmask, dim=0))
+            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MOE_MODEL))
+            capacity = max_num.item()
+
+        rank1 = moe_cumsum(mask1)    # rank1: [s, e]
+        rank2 = moe_cumsum(mask2)
+        rank2 += torch.sum(mask1, dim=-2, keepdim=True)
+
+        mask1 *= torch.lt(rank1, capacity)
+        mask2 *= torch.lt(rank2, capacity)
+
+        rank1 = torch.sum(mask1 * rank1, dim=-1)
+        rank2 = torch.sum(mask2 * rank2, dim=-1)
+
+        if cuda_mode:
+            mask1 = torch.sum(mask1, dim=-1)
+            mask2 = torch.sum(mask2, dim=-1)
+
+            mask = torch.stack([mask1, mask2], dim=0).to(torch.int32)
+            dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32)
+
+            return logits, mask, dest_idx, num_experts * capacity
+        else:
+            weight1 = mask1 * logits.type_as(inputs)
+            weight2 = mask2 * logits.type_as(inputs)
+            rank1_sc = F.one_hot(rank1, num_classes=capacity)
+            rank2_sc = F.one_hot(rank2, num_classes=capacity)
+
+            cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1)
+            cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1)
+            cb_weight = cb_weight1 + cb_weight2
+            sec_mask = cb_weight.bool()
+
+            return cb_weight, sec_mask


 class MoeLayer(nn.Module):
@@ -225,52 +198,47 @@ class MoeLayer(nn.Module):
     :type experts: nn.Module
     """

-    def __init__(self,
-                 dim_model: int,
-                 num_experts: int,
-                 router: nn.Module,
-                 experts: nn.Module):
+    def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: nn.Module):
         super().__init__()
         self.d_model = dim_model
         self.num_experts = num_experts
-        self.gate = nn.Linear(dim_model, num_experts, device=get_current_device())
+        self.gate = nn.Linear(dim_model, num_experts, bias=False, device=get_current_device())
         self.router = router
         self.experts = experts
+        self.cuda_mode = True if U_CUDA_MODE and moe_env.enable_cuda else False

-    def _router_part(self, tokens: torch.Tensor):
-        gate_output = self.gate(tokens)
-        return self.router(gate_output)
-
-    def router_part(self, tokens: torch.Tensor):
-        autocast_context = torch.is_autocast_enabled()
-        if not autocast_context:
-            return self._router_part(tokens)
-        else:
-            with autocast(enabled=False):
-                if tokens.dtype == torch.float16:
-                    input_tokens = tokens.float()
-                else:
-                    input_tokens = tokens
-                return self._router_part(input_tokens)
+    def expert_part(self, expert_input: torch.Tensor):
+        expert_input = AllToAll.apply(expert_input, ParallelMode.MOE_MODEL)
+
+        input_shape = expert_input.shape
+
+        expert_input = expert_input.reshape(moe_env.model_parallel_size,
+                                            self.num_experts // moe_env.model_parallel_size, -1, self.d_model)
+
+        expert_output = self.experts(expert_input)
+        expert_output = expert_output.reshape(input_shape)
+
+        expert_output = AllToAll.apply(expert_output, ParallelMode.MOE_MODEL)
+        return expert_output

     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         tokens = inputs.reshape(-1, self.d_model)
+        gate_output = self.gate(tokens)
+        router_res = self.router(gate_output, self.cuda_mode)

-        combine_weights, sec_mask, exp_counts = self.router_part(tokens)
-
-        sec_mask_f = sec_mask.type_as(inputs)
-        dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
-
-        dispatch_data = AllToAll.apply(dispatch_data, ParallelMode.MOE_MODEL)
-        expert_output = self.experts(dispatch_data)
-        expert_output = AllToAll.apply(expert_output, ParallelMode.MOE_MODEL)
-
-        combine_weights = combine_weights.view(combine_weights.shape[0], -1)
-        expert_output = expert_output.view(-1, expert_output.shape[-1])
-        ret = torch.matmul(combine_weights, expert_output)
+        if self.cuda_mode:
+            logits, mask, dest_idx, ec = router_res
+            expert_input = MoeDispatch.apply(tokens, mask, dest_idx, ec)
+            expert_output = self.expert_part(expert_input)
+            ret = MoeCombine.apply(expert_output, logits, mask, dest_idx, ec)
+        else:
+            combine_weights, sec_mask = router_res
+            sec_mask_f = sec_mask.type_as(inputs)
+            expert_input = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
+            expert_output = self.expert_part(expert_input)
+            combine_weights = combine_weights.view(combine_weights.shape[0], -1)
+            expert_output = expert_output.view(-1, expert_output.shape[-1])
+            ret = torch.matmul(combine_weights, expert_output)

         ret = ret.reshape(inputs.shape)
         return ret
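For intuition on the new get_capacity: the per-expert buffer is rounded to an even size and clamped from below, while in eval mode it is replaced by the all-reduced maximum expert load. A small illustration of the training-time formula used by Top1Router (a throwaway helper, not part of the API):

import math

def top1_capacity(num_tokens, num_experts, capacity_factor=1.0, min_capacity=0):
    capacity = math.floor(capacity_factor * num_tokens / num_experts)
    capacity += capacity % 2                 # bump odd values up to the next even number
    return max(capacity, min_capacity)

print(top1_capacity(32, 4))    # 8 slots per expert
print(top1_capacity(36, 4))    # floor(9) -> 10 after the even-size bump
# Top2Router uses the same formula with an extra factor of 2 for its two routes.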


@ -0,0 +1,32 @@
import torch
import torch.nn.functional as F
from colossalai.utils import get_current_device
class NormalNoiseGenerator:
"""Generates a random noisy mask for logtis tensor.
All noise is generated from a normal distribution (0, 1 / E^2), where
E = the number of experts.
:param num_experts: The number of experts
:type num_experts: int
"""
def __init__(self, num_experts: int):
self.normal = torch.distributions.normal.Normal(
loc=torch.tensor(0.0, device=get_current_device()),
scale=torch.tensor(1.0 / num_experts ** 2, device=get_current_device())
).rsample
def __call__(self, inputs: torch.Tensor):
noisy = self.normal(inputs.shape)
return inputs + noisy
def autocast_softmax(inputs: torch.Tensor, dim: int):
assert inputs.dtype in {torch.float16, torch.float32}
fp16_flag = (inputs.dtype == torch.float16)
sm_input = inputs.to(torch.float32) if fp16_flag else inputs
sm_output = F.softmax(sm_input, dim)
return sm_output
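autocast_softmax always upcasts half-precision inputs and returns the softmax in float32, which is what allows MoeCombine to assert logits.dtype == torch.float32 regardless of the model's precision. A quick illustration, assuming a CUDA device is available:

import torch

x = torch.randn(4, 8, dtype=torch.float16, device='cuda')
probs = autocast_softmax(x, dim=-1)
assert probs.dtype == torch.float32
assert torch.allclose(probs.sum(dim=-1), torch.ones(4, device='cuda'), atol=1e-3)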


@@ -4,7 +4,7 @@ import torch.nn as nn
 from colossalai.context import ParallelMode
 from colossalai.nn.layer import VanillaPatchEmbedding, VanillaClassifier, \
     WrappedDropout as Dropout, WrappedDropPath as DropPath
-from colossalai.nn.layer.moe import Experts, MoeLayer, Top2Router, NormalNoiseGenerator
+from colossalai.nn.layer.moe import FFNExperts, MoeLayer, Top2Router, NormalNoiseGenerator
 from .util import moe_sa_args, moe_mlp_args
 from ..helper import TransformerLayer
 from colossalai.global_variables import moe_env
@@ -81,6 +81,7 @@ class VanillaFFN(nn.Module):

 class Widenet(nn.Module):
     def __init__(self,
                  num_experts: int,
                  capacity_factor: float,
@@ -98,43 +99,33 @@ class Widenet(nn.Module):
                  drop_path: float = 0.):
         super().__init__()

-        embedding = VanillaPatchEmbedding(
-            img_size=img_size,
-            patch_size=patch_size,
-            in_chans=in_chans,
-            embed_size=d_model)
+        embedding = VanillaPatchEmbedding(img_size=img_size,
+                                          patch_size=patch_size,
+                                          in_chans=in_chans,
+                                          embed_size=d_model)
         embed_dropout = Dropout(p=drop_rate, mode=ParallelMode.TENSOR)

         shared_sa = VanillaSelfAttention(**moe_sa_args(
-            d_model=d_model, n_heads=num_heads, d_kv=d_kv,
-            attention_drop=attention_drop, drop_rate=drop_rate))
+            d_model=d_model, n_heads=num_heads, d_kv=d_kv, attention_drop=attention_drop, drop_rate=drop_rate))

         noisy_func = NormalNoiseGenerator(num_experts)
         shared_router = Top2Router(capacity_factor, noisy_func=noisy_func)
-        shared_experts = Experts(expert=VanillaFFN,
-                                 num_experts=num_experts,
-                                 **moe_mlp_args(
-                                     d_model=d_model,
-                                     d_ff=d_ff,
-                                     drop_rate=drop_rate
-                                 ))
+        shared_experts = FFNExperts(num_experts, d_model, d_ff, drop_rate=drop_rate)

         # stochastic depth decay rule
         dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
         blocks = [
-            TransformerLayer(
-                att=shared_sa,
-                ffn=MoeLayer(dim_model=d_model, num_experts=num_experts,
-                             router=shared_router, experts=shared_experts),
-                norm1=nn.LayerNorm(d_model, eps=1e-6),
-                norm2=nn.LayerNorm(d_model, eps=1e-6),
-                droppath=DropPath(p=dpr[i], mode=ParallelMode.TENSOR)
-            )
-            for i in range(depth)
+            TransformerLayer(att=shared_sa,
+                             ffn=MoeLayer(dim_model=d_model,
+                                          num_experts=num_experts,
+                                          router=shared_router,
+                                          experts=shared_experts),
+                             norm1=nn.LayerNorm(d_model, eps=1e-6),
+                             norm2=nn.LayerNorm(d_model, eps=1e-6),
+                             droppath=DropPath(p=dpr[i], mode=ParallelMode.TENSOR)) for i in range(depth)
         ]

         norm = nn.LayerNorm(d_model, eps=1e-6)
-        self.linear = VanillaClassifier(in_features=d_model,
-                                        num_classes=num_classes)
+        self.linear = VanillaClassifier(in_features=d_model, num_classes=num_classes)
         nn.init.zeros_(self.linear.weight)
         nn.init.zeros_(self.linear.bias)
         self.widenet = nn.Sequential(embedding, embed_dropout, *blocks, norm)
@@ -145,3 +136,64 @@ class Widenet(nn.Module):
         x = torch.mean(x, dim=1)
         x = self.linear(x)
         return x
class ViTMoE(nn.Module):
def __init__(self,
num_experts: int,
capacity_factor: float,
img_size: int = 224,
patch_size: int = 16,
in_chans: int = 3,
num_classes: int = 1000,
depth: int = 12,
d_model: int = 768,
num_heads: int = 12,
d_kv: int = 64,
d_ff: int = 3072,
attention_drop: float = 0.,
drop_rate: float = 0.1,
drop_path: float = 0.):
super().__init__()
embedding = VanillaPatchEmbedding(img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_size=d_model)
embed_dropout = Dropout(p=drop_rate, mode=ParallelMode.TENSOR)
noisy_func = NormalNoiseGenerator(num_experts)
router = Top2Router(capacity_factor, noisy_func=noisy_func)
assert depth % 2 == 0
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
blocks = []
for i in range(depth):
sa = VanillaSelfAttention(**moe_sa_args(
d_model=d_model, n_heads=num_heads, d_kv=d_kv, attention_drop=attention_drop, drop_rate=drop_rate))
ffn = VanillaFFN(**moe_mlp_args(
d_model=d_model, d_ff=d_ff, drop_rate=drop_rate)) if i % 2 == 0 else \
MoeLayer(dim_model=d_model, num_experts=num_experts, router=router,
experts=FFNExperts(num_experts, d_model, d_ff, drop_rate=drop_rate))
layer = TransformerLayer(att=sa,
ffn=ffn,
norm1=nn.LayerNorm(d_model, eps=1e-6),
norm2=nn.LayerNorm(d_model, eps=1e-6),
droppath=DropPath(p=dpr[i], mode=ParallelMode.TENSOR))
blocks.append(layer)
norm = nn.LayerNorm(d_model, eps=1e-6)
self.linear = VanillaClassifier(in_features=d_model, num_classes=num_classes)
nn.init.zeros_(self.linear.weight)
nn.init.zeros_(self.linear.bias)
self.vitmoe = nn.Sequential(embedding, embed_dropout, *blocks, norm)
def forward(self, x):
moe_env.reset_loss()
x = self.vitmoe(x)
x = torch.mean(x, dim=1)
x = self.linear(x)
return x
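A minimal instantiation sketch for the new model (hypothetical sizes; assumes colossalai.launch has already set up the MoE parallel group, and the auxiliary-loss weight of 0.01 is only illustrative):

import torch

model = ViTMoE(num_experts=4, capacity_factor=1.2)
images = torch.randn(8, 3, 224, 224, device=next(model.parameters()).device)
logits = model(images)                     # forward resets and re-accumulates moe_env.aux_loss
loss = logits.mean() + 0.01 * moe_env.aux_loss
loss.backward()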


@@ -162,6 +162,10 @@ if build_cuda_ext:
                                        ['scaled_masked_softmax.cpp', 'scaled_masked_softmax_cuda.cu'],
                                        extra_cuda_flags + cc_flag))

+    ext_modules.append(cuda_ext_helper('colossal_moe_cuda',
+                                       ['moe_cuda.cpp', 'moe_cuda_kernel.cu'],
+                                       extra_cuda_flags + cc_flag))
+
     extra_cuda_flags = ['-maxrregcount=50']

     ext_modules.append(cuda_ext_helper('colossal_layer_norm_cuda',


@ -0,0 +1,97 @@
import os
from functools import partial
from pathlib import Path
import pytest
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Top2Router, MoeLayer
from colossalai.global_variables import moe_env
BATCH_SIZE = 32
NUM_EXPERTS = 4
CONFIG = dict(parallel=dict(moe=dict(size=4)))
def check_equal(A, B, atol=1e-06):
assert torch.allclose(A, B, rtol=0, atol=atol) is True
def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
# torch.set_printoptions(precision=30)
torch.backends.cuda.matmul.allow_tf32 = False
local_rank = gpc.get_local_rank(ParallelMode.GLOBAL)
torch.manual_seed(rs + local_rank)
moe_env.reset_loss()
tokens = torch.randn(BATCH_SIZE, hidden_size,
dtype=data_type, device=get_current_device(), requires_grad=True)
# print(f"tokens:\n{tokens}")
router = Top2Router(1)
layer = MoeLayer(hidden_size, NUM_EXPERTS, router, nn.Identity())
if data_type == torch.float16:
layer = layer.half()
layer.cuda_mode = False
old_out = layer(tokens)
# print(f"old output:\n{old_out}")
ech = old_out.shape
grad = torch.randn(ech, device=get_current_device())
old_out.backward(grad)
o_tk_grad = tokens.grad.data.clone()
o_gt_grad = layer.gate.weight.grad.data.clone()
tokens.grad.zero_()
layer.gate.weight.grad.zero_()
layer.cuda_mode = True
new_out = layer(tokens)
# print(torch.max(torch.abs(old_out - new_out)))
if data_type == torch.float32:
check_equal(old_out, new_out)
else:
check_equal(old_out, new_out, 1e-2)
# print(f"forward functions passed")
# print(f"new output:\n{new_out}")
new_out.backward(grad)
n_tk_grad = tokens.grad.data.clone()
n_gt_grad = layer.gate.weight.grad.data.clone()
# print(torch.max(torch.abs(o_tk_grad - n_tk_grad)))
if data_type == torch.float32:
check_equal(o_tk_grad, n_tk_grad)
else:
check_equal(o_tk_grad, n_tk_grad, 1e-2)
# print(f"tokens gradient passed")
# print(torch.max(torch.abs(o_gt_grad - n_gt_grad)))
if data_type == torch.float32:
check_equal(o_gt_grad, n_gt_grad, 5e-05)
else:
check_equal(o_gt_grad, n_gt_grad, 2e-01)
# print(f"linear weight gradient passed")
@pytest.mark.dist
@pytest.mark.parametrize("rs", [131])
@pytest.mark.parametrize("hidden_size", [32, 144])
@pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
def test_moe_top2(rs, hidden_size, data_type):
world_size = 4
run_func = partial(run_routing, world_size=world_size, port=free_port(),
rs=rs, hidden_size=hidden_size, data_type=data_type)
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_moe_top2(2, 256, torch.float16)


@ -0,0 +1,97 @@
from functools import partial
import pytest
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Top1Router, MoeLayer
from colossalai.global_variables import moe_env
BATCH_SIZE = 32
NUM_EXPERTS = 4
CONFIG = dict(parallel=dict(moe=dict(size=4)))
def check_equal(A, B, atol=1e-06):
assert torch.allclose(A, B, rtol=0, atol=atol) is True
def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
# torch.set_printoptions(precision=30)
torch.backends.cuda.matmul.allow_tf32 = False
local_rank = gpc.get_local_rank(ParallelMode.GLOBAL)
torch.manual_seed(rs + local_rank)
moe_env.reset_loss()
tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True)
# print(f"tokens:\n{tokens}")
router = Top1Router(1)
layer = MoeLayer(hidden_size, NUM_EXPERTS, router, nn.Identity())
if data_type == torch.float16:
layer = layer.half()
layer.cuda_mode = False
old_out = layer(tokens)
# print(f"old output:\n{old_out}")
ech = old_out.shape
grad = torch.randn(ech, device=get_current_device())
old_out.backward(grad)
o_tk_grad = tokens.grad.data.clone()
o_gt_grad = layer.gate.weight.grad.data.clone()
tokens.grad.zero_()
layer.gate.weight.grad.zero_()
layer.cuda_mode = True
new_out = layer(tokens)
# print(torch.max(torch.abs(old_out - new_out)))
if data_type == torch.float32:
check_equal(old_out, new_out)
else:
check_equal(old_out, new_out, 1e-2)
# print(f"forward functions passed")
# print(f"new output:\n{new_out}")
new_out.backward(grad)
n_tk_grad = tokens.grad.data.clone()
n_gt_grad = layer.gate.weight.grad.data.clone()
# print(torch.max(torch.abs(o_tk_grad - n_tk_grad)))
if data_type == torch.float32:
check_equal(o_tk_grad, n_tk_grad)
else:
check_equal(o_tk_grad, n_tk_grad, 1e-2)
# print(f"tokens gradient passed")
# print(torch.max(torch.abs(o_gt_grad - n_gt_grad)))
if data_type == torch.float32:
check_equal(o_gt_grad, n_gt_grad, 5e-05)
else:
check_equal(o_gt_grad, n_gt_grad, 2e-01)
# print(f"linear weight gradient passed")
@pytest.mark.skip(reason="Should be activated for detailed tests")
@pytest.mark.parametrize("rs", [2, 42, 60])
@pytest.mark.parametrize("hidden_size", [128, 256, 512, 768, 1024, 2048])
@pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
def test_moe_top1(rs, hidden_size, data_type):
world_size = 4
run_func = partial(run_routing,
world_size=world_size,
port=free_port(),
rs=rs,
hidden_size=hidden_size,
data_type=data_type)
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_moe_top1(60, 512, torch.float16)


@ -0,0 +1,97 @@
from functools import partial
import pytest
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Top2Router, MoeLayer
from colossalai.global_variables import moe_env
BATCH_SIZE = 32
NUM_EXPERTS = 4
CONFIG = dict(parallel=dict(moe=dict(size=4)))
def check_equal(A, B, atol=1e-06):
assert torch.allclose(A, B, rtol=0, atol=atol) is True
def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
# torch.set_printoptions(precision=30)
torch.backends.cuda.matmul.allow_tf32 = False
local_rank = gpc.get_local_rank(ParallelMode.GLOBAL)
torch.manual_seed(rs + local_rank)
moe_env.reset_loss()
tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True)
# print(f"tokens:\n{tokens}")
router = Top2Router(1)
layer = MoeLayer(hidden_size, NUM_EXPERTS, router, nn.Identity())
if data_type == torch.float16:
layer = layer.half()
layer.cuda_mode = False
old_out = layer(tokens)
# print(f"old output:\n{old_out}")
ech = old_out.shape
grad = torch.randn(ech, device=get_current_device())
old_out.backward(grad)
o_tk_grad = tokens.grad.data.clone()
o_gt_grad = layer.gate.weight.grad.data.clone()
tokens.grad.zero_()
layer.gate.weight.grad.zero_()
layer.cuda_mode = True
new_out = layer(tokens)
# print(torch.max(torch.abs(old_out - new_out)))
if data_type == torch.float32:
check_equal(old_out, new_out)
else:
check_equal(old_out, new_out, 1e-2)
# print(f"forward functions passed")
# print(f"new output:\n{new_out}")
new_out.backward(grad)
n_tk_grad = tokens.grad.data.clone()
n_gt_grad = layer.gate.weight.grad.data.clone()
# print(torch.max(torch.abs(o_tk_grad - n_tk_grad)))
if data_type == torch.float32:
check_equal(o_tk_grad, n_tk_grad)
else:
check_equal(o_tk_grad, n_tk_grad, 1e-2)
# print(f"tokens gradient passed")
# print(torch.max(torch.abs(o_gt_grad - n_gt_grad)))
if data_type == torch.float32:
check_equal(o_gt_grad, n_gt_grad, 5e-05)
else:
check_equal(o_gt_grad, n_gt_grad, 2e-01)
# print(f"linear weight gradient passed")
@pytest.mark.skip(reason="Should be activated for detailed tests")
@pytest.mark.parametrize("rs", [2, 42, 60])
@pytest.mark.parametrize("hidden_size", [128, 256, 512, 768, 1024, 2048])
@pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
def test_moe_top2(rs, hidden_size, data_type):
world_size = 4
run_func = partial(run_routing,
world_size=world_size,
port=free_port(),
rs=rs,
hidden_size=hidden_size,
data_type=data_type)
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_moe_top2(2, 256, torch.float16)