Optimized MoE layer and fixed some bugs;

Reduced MoE tests;

Added FFNExperts and ViTMoE model
pull/394/head
1SAA 2022-02-18 20:42:31 +08:00 committed by Frank Lee
parent 3dba070580
commit 219df6e685
15 changed files with 1552 additions and 203 deletions

View File

@@ -9,6 +9,6 @@ repos:
hooks:
- id: flake8
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v13.0.0
rev: v13.0.1
hooks:
- id: clang-format

View File

@@ -56,6 +56,7 @@ class MoeEnv:
self.data_parallel_size = None
self.model_parallel_size = None
self.aux_loss = None
self.enable_cuda = True
def setup(self, moe_model_size):
from .core import global_context as gpc
@@ -71,6 +72,9 @@
def is_initialized(self):
return self.model_parallel_size is not None
def set_cuda_false(self):
self.enable_cuda = False
def reset_loss(self):
self.aux_loss = 0
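The flag above is consumed by the rewritten MoeLayer, which snapshots it when deciding whether to use the fused CUDA kernels. A minimal sketch of how it might be used (relying only on the moe_env instance that the rest of this commit imports from colossalai.global_variables):

from colossalai.global_variables import moe_env

# Force the pure-PyTorch dispatch/combine path; call this before constructing
# any MoeLayer, since the layer reads moe_env.enable_cuda at construction time.
moe_env.set_cuda_false()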

View File

@@ -5,7 +5,7 @@
#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"
#include <THC/THCDeviceUtils.cuh>
#include "ATen/cuda/DeviceUtils.cuh"
#include <cuda.h>
#include <cuda_runtime.h>

View File

@@ -0,0 +1,118 @@
#include <torch/extension.h>
torch::Tensor moe_dispatch_cuda_forward(
int s, int ec, int h,
torch::Tensor batch_tokens,
torch::Tensor mask,
torch::Tensor dest_idx);
torch::Tensor moe_dispatch_cuda_backward(
int s, int ec, int h,
torch::Tensor expert_grad,
torch::Tensor mask,
torch::Tensor dest_idx);
torch::Tensor moe_combine_cuda_forward(
int s, int e, int c, int h,
torch::Tensor expert_tokens,
torch::Tensor logits,
torch::Tensor mask,
torch::Tensor dest_idx);
std::vector<torch::Tensor> moe_combine_cuda_backward(
int s, int e, int c, int h,
torch::Tensor tokens_grad,
torch::Tensor expert_tokens,
torch::Tensor logits,
torch::Tensor mask,
torch::Tensor dest_idx);
torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask);
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
torch::Tensor moe_dispatch_forward(
int s, int ec, int h,
torch::Tensor batch_tokens,
torch::Tensor mask,
torch::Tensor dest_idx) {
CHECK_INPUT(batch_tokens);
CHECK_CUDA(mask);
CHECK_CUDA(dest_idx);
return moe_dispatch_cuda_forward(
s, ec, h,
batch_tokens, mask, dest_idx);
}
torch::Tensor moe_dispatch_backward(
int s, int ec, int h,
torch::Tensor expert_grad,
torch::Tensor mask,
torch::Tensor dest_idx) {
CHECK_INPUT(expert_grad);
CHECK_CUDA(mask);
CHECK_CUDA(dest_idx);
return moe_dispatch_cuda_backward(
s, ec, h,
expert_grad, mask, dest_idx);
}
torch::Tensor moe_combine_forward(
int s, int e, int c, int h,
torch::Tensor expert_tokens,
torch::Tensor logits,
torch::Tensor mask,
torch::Tensor dest_idx) {
CHECK_INPUT(expert_tokens);
CHECK_INPUT(logits);
CHECK_CUDA(mask);
CHECK_CUDA(dest_idx);
return moe_combine_cuda_forward(
s, e, c, h,
expert_tokens, logits, mask, dest_idx);
}
std::vector<torch::Tensor> moe_combine_backward(
int s, int e, int c, int h,
torch::Tensor tokens_grad,
torch::Tensor expert_tokens,
torch::Tensor logits,
torch::Tensor mask,
torch::Tensor dest_idx) {
CHECK_INPUT(tokens_grad);
CHECK_INPUT(logits);
CHECK_CUDA(mask);
CHECK_CUDA(dest_idx);
return moe_combine_cuda_backward(
s, e, c, h,
tokens_grad, expert_tokens, logits, mask, dest_idx);
}
torch::Tensor moe_cumsum(torch::Tensor mask) {
CHECK_INPUT(mask);
return cumsum_sub_one_in_dim0(mask);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("cumsum_sub_one", &moe_cumsum,
"Fast cumsum operation in dim0");
m.def("dispatch_forward", &moe_dispatch_forward,
"Forward operation in MoE dispatch function");
m.def("dispatch_backward", &moe_dispatch_backward,
"Backward operation in MoE dispatch function");
m.def("combine_forward", &moe_combine_forward,
"Combine operation in MoE combine function");
m.def("combine_backward", &moe_combine_backward,
"Combine operation in MoE combine function");
}
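Before the kernels themselves, here is a rough sketch of how the bindings above are meant to be driven from Python (shapes are my assumptions, read off the s/ec/e/c/h arguments and the autograd wrappers added in _operation.py later in this commit):

import colossal_moe_cuda  # only importable when the CUDA extension has been built

def dispatch_and_combine(tokens, logits, mask, dest_idx, e, c):
    # tokens: [s, h] contiguous CUDA tensor; mask, dest_idx: [k, s] int32; k is 1 or 2
    s, h = tokens.shape
    ec = e * c  # total expert slots = num_experts * capacity
    expert_input = colossal_moe_cuda.dispatch_forward(s, ec, h, tokens, mask, dest_idx)
    # ... run the experts on expert_input (shape [ec, h]) here ...
    return colossal_moe_cuda.combine_forward(s, e, c, h, expert_input, logits, mask, dest_idx)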

View File

@@ -0,0 +1,702 @@
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cub/cub.cuh>
#include "block_reduce.h"
template<typename T, int block_size, int pack_size>
__device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
typedef cub::BlockLoad<T, block_size, pack_size,
cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_size, pack_size,
cub::BLOCK_STORE_VECTORIZE> BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
int tps = threadIdx.x * pack_size; T pack[pack_size];
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
BlockLoad(ts_load).Load(src_row + idx, pack);
BlockStore(ts_store).Store(dst_row + idx, pack);
}
}
template<typename T, int block_size, int pack_size>
__device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
typedef cub::BlockLoad<T, block_size, pack_size,
cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_size, pack_size,
cub::BLOCK_STORE_VECTORIZE> BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
int tps = threadIdx.x * pack_size; T pack[pack_size];
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
BlockLoad(ts_load).Load(dst_row + idx, pack);
BlockStore(ts_store).Store(src_row + idx, pack);
}
}
template<typename T, int block_size, int pack_size>
__device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2, const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
typedef cub::BlockLoad<T, block_size, pack_size,
cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_size, pack_size,
cub::BLOCK_STORE_VECTORIZE> BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
int tps = threadIdx.x * pack_size; T pack[pack_size];
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
BlockLoad(ts_load).Load(src_row + idx, pack);
BlockStore(ts_store).Store(dst_row1 + idx, pack);
BlockStore(ts_store).Store(dst_row2 + idx, pack);
}
}
template<typename T, int block_size, int pack_size>
__device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2, const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
typedef cub::BlockLoad<T, block_size, pack_size,
cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_size, pack_size,
cub::BLOCK_STORE_VECTORIZE> BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
int tps = threadIdx.x * pack_size;
T pack1[pack_size], pack2[pack_size];
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
BlockLoad(ts_load).Load(dst_row1 + idx, pack1);
BlockLoad(ts_load).Load(dst_row2 + idx, pack2);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
pack1[i] += pack2[i];
}
BlockStore(ts_store).Store(src_row + idx, pack1);
}
}
template<typename T, int block_size, int pack_size>
__device__ void moe_cb_one_fwd(
T *src_row, T *dst_row,
const T weight, const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
typedef cub::BlockLoad<T, block_size, pack_size,
cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_size, pack_size,
cub::BLOCK_STORE_VECTORIZE> BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
int tps = threadIdx.x * pack_size; T pack[pack_size];
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
BlockLoad(ts_load).Load(src_row + idx, pack);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
pack[i] *= weight;
}
BlockStore(ts_store).Store(dst_row + idx, pack);
}
}
template<typename T, int block_size, int pack_size>
__device__ void moe_cb_one_bwd(
T *src_row, T *dst_row, T *tks_row, T *weight_grad,
const T weight, const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
typedef cub::BlockLoad<T, block_size, pack_size,
cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_size, pack_size,
cub::BLOCK_STORE_VECTORIZE> BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
int tps = threadIdx.x * pack_size;
T grad[pack_size], tokens[pack_size];
float thread_sum = 0;
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
BlockLoad(ts_load).Load(dst_row + idx, grad);
BlockLoad(ts_load).Load(tks_row + idx, tokens);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
thread_sum += grad[i] * tokens[i];
grad[i] *= weight;
}
BlockStore(ts_store).Store(src_row + idx, grad);
}
blockReduce<ReduceType::kSum, 1>(&thread_sum);
if (threadIdx.x == 0)
*weight_grad = static_cast<T>(thread_sum);
}
template<typename T, int block_size, int pack_size>
__device__ void moe_cb_two_fwd(
T *src_row1, T *src_row2, T *dst_row,
const T weight1, const T weight2, const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
typedef cub::BlockLoad<T, block_size, pack_size,
cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_size, pack_size,
cub::BLOCK_STORE_VECTORIZE> BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
int tps = threadIdx.x * pack_size;
T pack1[pack_size], pack2[pack_size];
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
BlockLoad(ts_load).Load(src_row1 + idx, pack1);
BlockLoad(ts_load).Load(src_row2 + idx, pack2);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
pack1[i] = pack1[i] * weight1 + pack2[i] * weight2;
}
BlockStore(ts_store).Store(dst_row + idx, pack1);
}
}
template<typename T, int block_size, int pack_size>
__device__ void moe_cb_two_bwd(
T *src_row1, T *src_row2, T *dst_row,
T *tks_row1, T *tks_row2, T *weight_grad1, T *weight_grad2,
const T weight1, const T weight2, const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
typedef cub::BlockLoad<T, block_size, pack_size,
cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
__shared__ typename BlockLoad::TempStorage ts_load;
typedef cub::BlockStore<T, block_size, pack_size,
cub::BLOCK_STORE_VECTORIZE> BlockStore;
__shared__ typename BlockStore::TempStorage ts_store;
int tps = threadIdx.x * pack_size;
T grad[pack_size], tokens1[pack_size], tokens2[pack_size],
sgrad1[pack_size], sgrad2[pack_size];
float thread_sum[2] = {0, 0};
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
BlockLoad(ts_load).Load(dst_row + idx, grad);
BlockLoad(ts_load).Load(tks_row1 + idx, tokens1);
BlockLoad(ts_load).Load(tks_row2 + idx, tokens2);
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
thread_sum[0] += grad[i] * tokens1[i];
thread_sum[1] += grad[i] * tokens2[i];
sgrad1[i] = weight1 * grad[i];
sgrad2[i] = weight2 * grad[i];
}
BlockStore(ts_store).Store(src_row1 + idx, sgrad1);
BlockStore(ts_store).Store(src_row2 + idx, sgrad2);
}
blockReduce<ReduceType::kSum, 2>(thread_sum);
if (threadIdx.x == 0)
*weight_grad1 = static_cast<T>(thread_sum[0]);
else if (threadIdx.x == 1)
*weight_grad2 = static_cast<T>(thread_sum[1]);
}
// DISPATCH KERNELS --------------------------------
template<typename T, int block_size, int pack_size>
__device__ void moe_dpch_fwd_selector(
T *src_row, T *dst_row1, T *dst_row2, const int cols,
const int indicator1, const int indicator2) {
if (indicator1 != 0 && indicator2 != 0)
moe_dpch_two_fwd<T, block_size, pack_size>(
src_row, dst_row1, dst_row2, cols);
else if (indicator1 != 0)
moe_dpch_one_fwd<T, block_size, pack_size>(
src_row, dst_row1, cols);
else if (indicator2 != 0)
moe_dpch_one_fwd<T, block_size, pack_size>(
src_row, dst_row2, cols);
else
return;
}
template<typename T, int block_size, int pack_size>
__device__ void moe_dpch_bwd_selector(
T *src_row, T *dst_row1, T *dst_row2, const int cols,
const int indicator1, const int indicator2) {
if (indicator1 != 0 && indicator2 != 0)
moe_dpch_two_bwd<T, block_size, pack_size>(
src_row, dst_row1, dst_row2, cols);
else if (indicator1 != 0)
moe_dpch_one_bwd<T, block_size, pack_size>(
src_row, dst_row1, cols);
else if (indicator2 != 0)
moe_dpch_one_bwd<T, block_size, pack_size>(
src_row, dst_row2, cols);
else
return;
}
template<typename T, int block_size, int pack_size>
__global__ void moe_dpch_fwd_kernel(
T *batch_tokens, T *expert_input,
int *mask1, int *mask2,
int *dest1, int *dest2, const int h) {
int row = blockIdx.x;
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
moe_dpch_fwd_selector<T, block_size, pack_size>(
batch_tokens + (row * h),
expert_input + (dest1[row] * h), expert_input + (dest2[row] * h),
h, mask1[row], indicator2);
}
template<typename T, int block_size, int pack_size>
__global__ void moe_dpch_bwd_kernel(
T *tokens_grad, T *expert_grad,
int *mask1, int *mask2,
int *dest1, int *dest2, const int h) {
int row = blockIdx.x;
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
moe_dpch_bwd_selector<T, block_size, pack_size>(
tokens_grad + (row * h),
expert_grad + (dest1[row] * h), expert_grad + (dest2[row] * h),
h, mask1[row], indicator2);
}
// COMBINE KERNELS --------------------------------
template<typename T, int block_size, int pack_size>
__device__ void moe_cb_fwd_selector(
T *src_row1, T *src_row2, T *dst_row, const int cols,
const T weight1, const T weight2,
const int indicator1, const int indicator2) {
if (indicator1 != 0 && indicator2 != 0)
moe_cb_two_fwd<T, block_size, pack_size>(
src_row1, src_row2, dst_row, weight1, weight2, cols);
else if (indicator1 != 0)
moe_cb_one_fwd<T, block_size, pack_size>(
src_row1, dst_row, weight1, cols);
else if (indicator2 != 0)
moe_cb_one_fwd<T, block_size, pack_size>(
src_row2, dst_row, weight2, cols);
else
return;
}
template<typename T, int block_size, int pack_size>
__device__ void moe_cb_bwd_selector(
T *src_row1, T *src_row2, T *dst_row, const int cols,
T *tks_row1, T *tks_row2, T *wt_grad1, T *wt_grad2,
const T weight1, const T weight2,
const int indicator1, const int indicator2) {
if (indicator1 != 0 && indicator2 != 0)
moe_cb_two_bwd<T, block_size, pack_size>(
src_row1, src_row2, dst_row,
tks_row1, tks_row2, wt_grad1, wt_grad2,
weight1, weight2, cols);
else if (indicator1 != 0)
moe_cb_one_bwd<T, block_size, pack_size>(
src_row1, dst_row, tks_row1, wt_grad1, weight1, cols);
else if (indicator2 != 0)
moe_cb_one_bwd<T, block_size, pack_size>(
src_row2, dst_row, tks_row2, wt_grad2, weight2, cols);
else
return;
}
template<typename T, int block_size, int pack_size>
__global__ void moe_cb_fwd_kernel(
T *expert_tokens, T *combine_tokens, T *logits,
int *mask1, int *mask2,
int *dest1, int *dest2,
const int e, const int c, const int h) {
int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
T *row_log = logits + (row * e);
moe_cb_fwd_selector<T, block_size, pack_size>(
expert_tokens + (dest1[row] * h), expert_tokens + (dest2[row] * h),
combine_tokens + (row * h), h,
row_log[eid1], row_log[eid2],
mask1[row], indicator2);
}
template<typename T, int block_size, int pack_size>
__global__ void moe_cb_bwd_kernel(
T *tokens_grad, T *expert_grad, T *tks,
T *logits, T *logits_grad,
int *mask1, int *mask2,
int *dest1, int *dest2,
const int e, const int c, const int h) {
int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
T *row_log = logits + (row * e), *row_grad = logits_grad + (row * e);
moe_cb_bwd_selector<T, block_size, pack_size>(
expert_grad + (dest1[row] * h), expert_grad + (dest2[row] * h),
tokens_grad + (row * h), h,
tks + (dest1[row] * h), tks + (dest2[row] * h),
row_grad + eid1, row_grad + eid2,
row_log[eid1], row_log[eid2],
mask1[row], indicator2);
}
//CUMSUM KERNEL --------------------------------
template<int block_size, int pack_size>
__global__ void cumsum_kernel(
int *inputs, int *outputs,
const int s, const int e) {
assert(s % pack_size == 0);
constexpr int bpack_size = block_size * pack_size;
int tid = threadIdx.x, bid = blockIdx.x,
tps = tid * pack_size, last_sum = -1;
__shared__ int temp[block_size + 1]; int pack[pack_size];
for (int idx = 0; idx < s; idx += bpack_size) {
int offset = 1;
if (idx + tps < s) {
temp[tid] = inputs[tps * e + bid];
#pragma unroll
for (int i = 1; i < pack_size; ++i) {
pack[i] = inputs[(tps + i) * e + bid];
}
#pragma unroll
for (int i = 1; i < pack_size; ++i) {
temp[tid] += pack[i];
}
}
for (int i = block_size >> 1; i > 0; i >>= 1) {
__syncthreads();
if (tid < i) {
int j = offset * (2 * tid + 1) - 1;
temp[j + offset] += temp[j];
}
offset <<= 1;
}
if (tid == 0) {
temp[block_size] = temp[block_size - 1];
temp[block_size - 1] = 0;
}
for (int i = 1; i < block_size; i <<= 1) {
offset >>= 1;
__syncthreads();
if (tid < i) {
int j = offset * (2 * tid + 1) - 1,
k = j + offset, ts = temp[j];
temp[j] = temp[k];
temp[k] += ts;
}
}
__syncthreads();
if (tid == 0)
temp[0] = temp[block_size];
__syncthreads();
if (idx + tps < s) {
temp[tid + 1] += last_sum;
#pragma unroll
for (int i = pack_size - 1; i > 0; --i) {
outputs[(tps + i) * e + bid] = temp[tid + 1];
temp[tid + 1] -= pack[i];
}
outputs[tps * e + bid] = temp[tid + 1];
}
__syncthreads();
last_sum += temp[0];
inputs += bpack_size * e;
outputs += bpack_size * e;
}
}
//LAUNCH FUNCTIONS --------------------------------
template<typename T>
void moe_dpch_fwd_launch(
T *batch_tokens, T *expert_input,
int *mask1, int *mask2,
int *dest1, int *dest2,
const int s, const int h) {
if (h < 256)
moe_dpch_fwd_kernel<T, 32, 4><<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
else if (h < 512)
moe_dpch_fwd_kernel<T, 32, 8><<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
else if (h < 1024)
moe_dpch_fwd_kernel<T, 32, 16><<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
else if (h < 2048)
moe_dpch_fwd_kernel<T, 64, 16><<<s, 64>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
else
moe_dpch_fwd_kernel<T, 128, 16><<<s, 128>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
}
template<typename T>
void moe_dpch_bwd_launch(
T *tokens_grad, T *expert_grad,
int *mask1, int *mask2,
int *dest1, int *dest2,
const int s, const int h) {
if (h < 256)
moe_dpch_bwd_kernel<T, 32, 4><<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
else if (h < 512)
moe_dpch_bwd_kernel<T, 32, 8><<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
else if (h < 1024)
moe_dpch_bwd_kernel<T, 32, 16><<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
else if (h < 2048)
moe_dpch_bwd_kernel<T, 64, 16><<<s, 64>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
else
moe_dpch_bwd_kernel<T, 128, 16><<<s, 128>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
}
template<typename T>
void moe_cb_fwd_launch(
T *expert_tokens, T *combine_tokens, T *logits,
int *mask1, int *mask2,
int *dest1, int *dest2,
const int s, const int e, const int c, const int h) {
if (h < 256)
moe_cb_fwd_kernel<T, 32, 4><<<s, 32>>>
(expert_tokens, combine_tokens, logits, mask1, mask2, dest1, dest2, e, c, h);
else if (h < 512)
moe_cb_fwd_kernel<T, 32, 8><<<s, 32>>>
(expert_tokens, combine_tokens, logits, mask1, mask2, dest1, dest2, e, c, h);
else if (h < 1024)
moe_cb_fwd_kernel<T, 32, 16><<<s, 32>>>
(expert_tokens, combine_tokens, logits, mask1, mask2, dest1, dest2, e, c, h);
else if (h < 2048)
moe_cb_fwd_kernel<T, 64, 16><<<s, 64>>>
(expert_tokens, combine_tokens, logits, mask1, mask2, dest1, dest2, e, c, h);
else
moe_cb_fwd_kernel<T, 128, 16><<<s, 128>>>
(expert_tokens, combine_tokens, logits, mask1, mask2, dest1, dest2, e, c, h);
}
template<typename T>
void moe_cb_bwd_launch(
T *tokens_grad, T *expert_grad, T *tks,
T *logits, T *logits_grad,
int *mask1, int *mask2,
int *dest1, int *dest2,
const int s, const int e, const int c, const int h) {
if (h < 256)
moe_cb_bwd_kernel<T, 32, 4><<<s, 32>>>
(tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, dest1, dest2, e, c, h);
else // if (h < 512)
moe_cb_bwd_kernel<T, 64, 4><<<s, 64>>>
(tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, dest1, dest2, e, c, h);
// else if (h < 1024)
// moe_cb_bwd_kernel<T, 128, 4><<<s, 128>>>
// (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, dest1, dest2, e, c, h);
// else
// moe_cb_bwd_kernel<T, 256, 4><<<s, 256>>>
// (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, dest1, dest2, e, c, h);
}
void cumsum_launch(
int *inputs, int *outputs,
const int s, const int e) {
if (s <= 256)
cumsum_kernel<256, 1><<<e, 256>>>(inputs, outputs, s, e);
else if (s <= 512)
cumsum_kernel<512, 1><<<e, 512>>>(inputs, outputs, s, e);
else if (s <= 1024)
cumsum_kernel<1024, 1><<<e, 1024>>>(inputs, outputs, s, e);
else if (s <= 2048)
cumsum_kernel<1024, 2><<<e, 1024>>>(inputs, outputs, s, e);
else
cumsum_kernel<1024, 4><<<e, 1024>>>(inputs, outputs, s, e);
}
// API FUNCTIONS --------------------------------
#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \
switch (TYPE) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented yet for the given data type.");\
}
torch::Tensor moe_dispatch_cuda_forward(
int s, int ec, int h,
torch::Tensor batch_tokens,
torch::Tensor mask,
torch::Tensor dest_idx) {
assert(h % 16 == 0);
auto res = torch::zeros({ec, h},
torch::dtype(batch_tokens.dtype()).device(batch_tokens.device()));
auto k = mask.size(0);
DISPATCH_FLOAT_AND_HALF(
batch_tokens.scalar_type(), "moe dispatch forward",
moe_dpch_fwd_launch<scalar_t>(
batch_tokens.data<scalar_t>(), res.data<scalar_t>(),
mask[0].data<int>(), k == 1 ? nullptr : mask[1].data<int>(),
dest_idx[0].data<int>(), k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(),
s, h)
);
return res;
}
torch::Tensor moe_dispatch_cuda_backward(
int s, int ec, int h,
torch::Tensor expert_grad,
torch::Tensor mask,
torch::Tensor dest_idx) {
assert(h % 16 == 0);
auto res = torch::zeros({s, h},
torch::dtype(expert_grad.dtype()).device(expert_grad.device()));
auto k = mask.size(0);
DISPATCH_FLOAT_AND_HALF(
expert_grad.scalar_type(), "moe dispatch backward",
moe_dpch_bwd_launch<scalar_t>(
res.data<scalar_t>(), expert_grad.data<scalar_t>(),
mask[0].data<int>(), k == 1 ? nullptr : mask[1].data<int>(),
dest_idx[0].data<int>(), k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(),
s, h)
);
return res;
}
torch::Tensor moe_combine_cuda_forward(
int s, int e, int c, int h,
torch::Tensor expert_tokens,
torch::Tensor logits,
torch::Tensor mask,
torch::Tensor dest_idx) {
assert(h % 16 == 0);
assert(expert_tokens.dtype() == logits.dtype());
auto res = torch::zeros({s, h},
torch::dtype(expert_tokens.dtype()).device(expert_tokens.device()));
auto k = mask.size(0);
DISPATCH_FLOAT_AND_HALF(
expert_tokens.scalar_type(), "moe combine forward",
moe_cb_fwd_launch<scalar_t>(
expert_tokens.data<scalar_t>(), res.data<scalar_t>(), logits.data<scalar_t>(),
mask[0].data<int>(), k == 1 ? nullptr : mask[1].data<int>(),
dest_idx[0].data<int>(), k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(),
s, e, c, h)
);
return res;
}
std::vector<torch::Tensor> moe_combine_cuda_backward(
int s, int e, int c, int h,
torch::Tensor tokens_grad,
torch::Tensor expert_tokens,
torch::Tensor logits,
torch::Tensor mask,
torch::Tensor dest_idx) {
assert(h % 16 == 0);
assert(tokens_grad.dtype() == expert_tokens.dtype());
assert(expert_tokens.dtype() == logits.dtype());
auto egrad = torch::zeros({e * c, h},
torch::dtype(tokens_grad.dtype()).device(tokens_grad.device())),
wgrad = torch::zeros({s, e}, torch::dtype(logits.dtype()).device(logits.device()));
auto k = mask.size(0);
DISPATCH_FLOAT_AND_HALF(
tokens_grad.scalar_type(), "moe combine backward",
moe_cb_bwd_launch<scalar_t>(
tokens_grad.data<scalar_t>(), egrad.data<scalar_t>(), expert_tokens.data<scalar_t>(),
logits.data<scalar_t>(), wgrad.data<scalar_t>(),
mask[0].data<int>(), k == 1 ? nullptr : mask[1].data<int>(),
dest_idx[0].data<int>(), k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(),
s, e, c, h)
);
return {egrad, wgrad};
}
torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) {
assert(mask.dim() == 2);
assert(mask.dtype() == torch::kInt32);
const int s = mask.size(0), e = mask.size(1);
auto res = torch::empty({s, e}, torch::dtype(torch::kInt32).device(mask.device()));
cumsum_launch(mask.data<int>(), res.data<int>(), s, e);
return res;
}
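For readers who want the semantics without the CUB plumbing, a minimal pure-PyTorch reference for what the two forward kernels compute (my own sketch, not part of the extension; it ignores fp16 handling and the vectorized packing):

import torch

def dispatch_forward_ref(tokens, mask, dest_idx, ec):
    # tokens: [s, h]; mask, dest_idx: [k, s] int32; returns [ec, h]
    out = torch.zeros(ec, tokens.size(1), dtype=tokens.dtype, device=tokens.device)
    for m, d in zip(mask, dest_idx):
        sel = m.bool()
        out[d[sel].long()] = tokens[sel]  # copy each kept token into its expert slot
    return out

def combine_forward_ref(expert_tokens, logits, mask, dest_idx, c):
    # expert_tokens: [e*c, h]; logits: [s, e]; returns [s, h]
    s, h = logits.size(0), expert_tokens.size(1)
    out = torch.zeros(s, h, dtype=expert_tokens.dtype, device=expert_tokens.device)
    for m, d in zip(mask, dest_idx):
        sel = m.bool()
        eid = (d[sel] // c).long()                   # expert id = slot index // capacity
        w = logits[sel].gather(1, eid.unsqueeze(1))  # routing weight of that expert
        out[sel] += w * expert_tokens[d[sel].long()]
    return out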

View File

@@ -1,8 +1,5 @@
from ._operation import AllToAll
from .layers import Experts, MoeLayer, \
NormalNoiseGenerator, Top1Router, Top2Router
from .experts import Experts, FFNExperts
from .layers import MoeLayer, Top1Router, Top2Router
from .utils import NormalNoiseGenerator
__all__ = [
'AllToAll', 'Experts', 'Top1Router', 'Top2Router',
'MoeLayer', 'NormalNoiseGenerator'
]
__all__ = ['Experts', 'FFNExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator']

View File

@@ -6,16 +6,26 @@ from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from typing import Any, Tuple
U_CUDA_MODE = False
try:
import colossal_moe_cuda
U_CUDA_MODE = True
except ImportError:
print("If you want to activate cuda mode for MoE, please install with cuda_ext!")
class AllToAll(torch.autograd.Function):
"""Dispatches input tensor [e, c, h] to all experts by all_to_all_single
operation in torch.distributed.
"""
@staticmethod
def forward(ctx: Any,
inputs: Tensor,
parallel_mode: ParallelMode) -> Tensor:
ctx.parallel_mode = parallel_mode
if ctx is not None:
ctx.parallel_mode = parallel_mode
if not inputs.is_contiguous():
inputs = inputs.contiguous()
@@ -26,4 +36,79 @@ class AllToAll(torch.autograd.Function):
@staticmethod
def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
return AllToAll.apply(*grad_outputs, ctx.parallel_mode), None
return AllToAll.forward(None, *grad_outputs, ctx.parallel_mode), None
class MoeDispatch(torch.autograd.Function):
@staticmethod
def forward(ctx, tokens, mask, dest_idx, ec):
s = tokens.size(0)
h = tokens.size(1)
expert_input = colossal_moe_cuda.dispatch_forward(s, ec, h, tokens, mask, dest_idx)
ctx.save_for_backward(mask, dest_idx)
ctx.s = s
ctx.h = h
ctx.ec = ec
return expert_input
@staticmethod
def backward(ctx, output_grad):
mask, dest_idx = ctx.saved_tensors
d_tokens = colossal_moe_cuda.dispatch_backward(
ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
return d_tokens, None, None, None
class MoeCombine(torch.autograd.Function):
@staticmethod
def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
assert logits.dtype == torch.float32
s = logits.size(0)
e = logits.size(1)
c = ec // e
h = expert_tokens.size(-1)
fp16_flag = (expert_tokens.dtype == torch.float16)
cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens
ctokens = colossal_moe_cuda.combine_forward(s, e, c, h,
cb_input, logits,
mask, dest_idx)
output = ctokens.to(torch.float16) if fp16_flag else ctokens
ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
ctx.s = s
ctx.e = e
ctx.c = c
ctx.h = h
ctx.fp16_flag = fp16_flag
return output
@staticmethod
def backward(ctx, tokens_grad):
expert_tokens, logits, mask, dest_idx = ctx.saved_tensors
cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 \
else tokens_grad
cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens
d_expert, d_logits = colossal_moe_cuda.combine_backward(
ctx.s, ctx.e, ctx.c, ctx.h,
cb_grad, cb_input, logits, mask, dest_idx)
d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert
return d_expert, d_logits, None, None, None
def moe_cumsum(inputs: Tensor):
dim0 = inputs.size(0)
flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
if flag and U_CUDA_MODE:
return colossal_moe_cuda.cumsum_sub_one(inputs)
else:
return torch.cumsum(inputs, dim=0) - 1
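A small worked example of what moe_cumsum yields on a routing mask (shown via the pure-PyTorch fallback; the CUDA kernel expects an int32 CUDA tensor but produces the same values):

mask = torch.tensor([[1, 0], [0, 1], [1, 0], [1, 0]], dtype=torch.int32)
ranks = moe_cumsum(mask)
# ranks == [[0, -1], [0, 0], [1, 0], [2, 0]]: every 1 in column j receives its
# 0-based slot among the tokens routed to expert j; other positions hold the
# running count minus one and are masked out by the callers.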

View File

@@ -0,0 +1,96 @@
import math
import torch
import torch.nn as nn
from colossalai.global_variables import moe_env
from colossalai.context import ParallelMode, seed
from colossalai.utils import get_current_device
class Experts(nn.Module):
"""A wrapper class to create experts. It will create E experts across the
moe model parallel group, where E is the number of experts. Every expert
is an instance of the class given by the 'expert' initialization parameter.
:param expert: The class of all experts
:param num_experts: The number of experts
:param expert_args: Args used to initialize experts
:type num_experts: int
"""
def __init__(self, expert, num_experts, **expert_args):
super().__init__()
assert num_experts % moe_env.model_parallel_size == 0, \
"The number of experts should be divied by moe model size"
num_local_experts = num_experts // moe_env.model_parallel_size
with seed(ParallelMode.MOE_MODEL):
self.experts = nn.ModuleList([expert(**expert_args) for _ in range(num_local_experts)])
self.num_local_experts = num_local_experts
for exp in self.experts:
for param in exp.parameters():
param.__setattr__('moe_param', True)
def forward(self, inputs):
expert_input = torch.chunk(inputs, self.num_local_experts, dim=1)
expert_output = []
for i in range(self.num_local_experts):
expert_output.append(self.experts[i](expert_input[i]))
output = torch.cat(expert_output, dim=1).contiguous()
return output
class FFNExperts(nn.Module):
def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
super().__init__()
assert num_experts % moe_env.model_parallel_size == 0, \
"The number of experts should be divied by moe model size"
num_local_experts = num_experts // moe_env.model_parallel_size
self.w1 = nn.Parameter(torch.empty(num_local_experts, d_model, d_ff, device=get_current_device()))
self.b1 = nn.Parameter(torch.empty(num_local_experts, 1, d_ff, device=get_current_device()))
self.w2 = nn.Parameter(torch.empty(num_local_experts, d_ff, d_model, device=get_current_device()))
self.b2 = nn.Parameter(torch.empty(num_local_experts, 1, d_model, device=get_current_device()))
s1 = math.sqrt(0.1 / d_model)
s2 = math.sqrt(0.1 / d_ff)
nn.init.trunc_normal_(self.w1, std=s1)
nn.init.trunc_normal_(self.b1, std=s1)
nn.init.trunc_normal_(self.w2, std=s2)
nn.init.trunc_normal_(self.b2, std=s2)
self.act = nn.GELU() if activation is None else activation
self.drop = nn.Dropout(p=drop_rate)
for param in self.parameters():
param.__setattr__('moe_param', True)
def forward(self, inputs): # x [g, el, c, h]
el = inputs.size(1)
h = inputs.size(-1)
inputs = inputs.transpose(0, 1)
inshape = inputs.shape
inputs = inputs.reshape(el, -1, h)
out_ff = torch.baddbmm(self.b1, inputs, self.w1)
out_act = self.act(out_ff)
with seed(ParallelMode.TENSOR):
inter = self.drop(out_act)
out_model = torch.baddbmm(self.b2, inter, self.w2)
with seed(ParallelMode.TENSOR):
outputs = self.drop(out_model) # outputs [el, gc, h]
outputs = outputs.reshape(inshape)
outputs = outputs.transpose(0, 1).contiguous()
return outputs
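A shape-level sketch of driving FFNExperts directly (hypothetical sizes; it assumes a launched Colossal-AI context in which moe_env.model_parallel_size == 1, so all four experts live on the local rank):

experts = FFNExperts(num_experts=4, d_model=64, d_ff=256, drop_rate=0.1)
x = torch.randn(2, 4, 8, 64, device=get_current_device())  # [g, el, c, h]: groups, local experts, capacity, hidden
y = experts(x)  # each expert applies its own baddbmm FFN; output keeps the shape [2, 4, 8, 64]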

View File

@@ -3,70 +3,13 @@ import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
import torch.distributed as dist
from colossalai.core import global_context as gpc
from colossalai.global_variables import moe_env
from colossalai.context import ParallelMode, seed
from colossalai.context import ParallelMode
from colossalai.utils import get_current_device
from ._operation import AllToAll
class NormalNoiseGenerator:
"""Generates a random noisy mask for logtis tensor.
All noise is generated from a normal distribution (0, 1 / E^2), where
E = the number of experts.
:param num_experts: The number of experts
:type num_experts: int
"""
def __init__(self, num_experts: int):
self.normal = torch.distributions.normal.Normal(
loc=torch.tensor(0.0, device=get_current_device()),
scale=torch.tensor(1.0 / num_experts ** 2, device=get_current_device())
).rsample
def __call__(self, inputs: torch.Tensor):
noisy = self.normal(inputs.shape)
return inputs + noisy
class Experts(nn.Module):
"""A wrapper class to create experts. It will create E experts across the
moe model parallel group, where E is the number of experts. Every expert
is an instance of the class given by the 'expert' initialization parameter.
:param expert: The class of all experts
:param num_experts: The number of experts
:param expert_args: Args used to initialize experts
:type num_experts: int
"""
def __init__(self, expert, num_experts, **expert_args):
super().__init__()
assert num_experts % moe_env.model_parallel_size == 0, \
"The number of experts should be divied by moe model size"
num_local_experts = num_experts // moe_env.model_parallel_size
with seed(ParallelMode.MOE_MODEL):
self.experts = nn.ModuleList([
expert(**expert_args) for _ in range(num_local_experts)])
self.num_local_experts = num_local_experts
for exp in self.experts:
for param in exp.parameters():
param.__setattr__('moe_param', 1)
def forward(self, inputs):
expert_input = torch.chunk(inputs, self.num_local_experts, dim=0)
expert_output = []
for i in range(self.num_local_experts):
expert_output.append(self.experts[i](expert_input[i]))
output = torch.cat(expert_output, dim=0)
return output
from ._operation import U_CUDA_MODE, AllToAll, MoeDispatch, MoeCombine, moe_cumsum
from .utils import autocast_softmax
class Top1Router(nn.Module):
@@ -83,63 +26,79 @@ class Top1Router(nn.Module):
:type noisy_func: Callable, optional
"""
def __init__(self,
capacity_factor: float,
min_capacity: int,
noisy_func=None):
def __init__(self, capacity_factor: float, min_capacity: int = 0, select_policy: str = "first", noisy_func=None):
super().__init__()
self.capacity_factor = capacity_factor
self.min_capacity = min_capacity
self.select_policy = select_policy
self.noisy_func = noisy_func
self.uniform = torch.distributions.uniform.Uniform(
low=torch.tensor(0.0, device=get_current_device()),
high=torch.tensor(1.0, device=get_current_device())).rsample
def get_capacity(self, logits_shape):
capacity = math.ceil(self.capacity_factor *
logits_shape[0] / logits_shape[1])
if capacity < self.min_capacity:
capacity = self.min_capacity
assert select_policy in {"first", "random"}
if select_policy == "random":
self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=get_current_device()),
high=torch.tensor(1.0,
device=get_current_device())).rsample
def get_capacity(
self,
logits_shape,
):
capacity = math.floor(self.capacity_factor * logits_shape[-2] / logits_shape[-1])
capacity += capacity % 2
capacity = max(capacity, self.min_capacity)
assert capacity > 0
return capacity
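A quick worked instance of the capacity rule above, assuming capacity_factor=1.2 with 512 tokens routed over 4 experts (so logits_shape[-2:] is (512, 4)):

capacity = math.floor(1.2 * 512 / 4)  # 153
capacity += capacity % 2              # 154: rounded up to an even number of slots
capacity = max(capacity, 0)           # min_capacity defaults to 0, so 154 slots per expert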
def forward(self, inputs):
def forward(self, inputs: torch.Tensor, cuda_mode: bool = False):
if self.noisy_func is not None:
inputs_noisy = self.noisy_func(inputs)
else:
inputs_noisy = inputs
logits = F.softmax(inputs, dim=1)
num_experts = logits.shape[1]
logits = autocast_softmax(inputs, dim=-1)
num_experts = logits.size(-1)
capacity = self.get_capacity(logits.shape)
expert_idx = torch.argmax(inputs_noisy, dim=1)
expert_mask = F.one_hot(expert_idx, num_classes=num_experts)
expert_mask_f = expert_mask.float()
top1_idx = torch.argmax(inputs_noisy, dim=-1)
mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
exp_counts = torch.sum(expert_mask, dim=0).detach().to('cpu')
if self.training:
me = torch.mean(logits, dim=0)
ce = torch.mean(mask.float(), dim=0)
l_aux = num_experts * torch.sum(me * ce)
moe_env.add_loss(l_aux)
else:
max_num = torch.max(torch.sum(mask, dim=0))
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MOE_MODEL))
capacity = max_num.item()
me = torch.mean(logits, dim=0)
ce = torch.mean(expert_mask_f, dim=0)
l_aux = torch.sum(me * ce) * num_experts
moe_env.add_loss(l_aux)
if not self.training:
ranks = moe_cumsum(mask)
elif self.select_policy == "random":
rand_mask = mask * self.uniform(mask.shape)
_, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
ranks = moe_cumsum(mask)
elif self.select_policy == "first":
ranks = moe_cumsum(mask)
mask = mask * torch.lt(ranks, capacity)
else:
raise NotImplementedError("Such select policy is not supported yet.")
rand_mask = expert_mask * self.uniform(logits.shape)
_, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
ranks = torch.sum(mask * ranks, dim=-1)
dispatch_mask = \
expert_mask * torch.zeros_like(expert_mask).scatter_(0, dispatch_idx, 1)
locations = torch.cumsum(dispatch_mask, dim=0) - 1
locations = torch.sum(dispatch_mask * locations, dim=1)
locations = F.one_hot(locations, num_classes=capacity)
logits = logits * dispatch_mask
combine_weights = logits.unsqueeze(2) * locations.unsqueeze(1)
sec_mask = combine_weights.bool()
return combine_weights, sec_mask, exp_counts
if cuda_mode:
mask = torch.sum(mask, dim=-1)
mask = torch.stack([mask], dim=0).to(torch.int32)
dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
return logits, mask, dest_idx, num_experts * capacity
else:
ranks = F.one_hot(ranks, num_classes=capacity)
weight = mask * logits.type_as(inputs)
combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
sec_mask = combine_weights.bool()
return combine_weights, sec_mask
class Top2Router(nn.Module):
@@ -159,53 +118,67 @@ class Top2Router(nn.Module):
self.noisy_func = noisy_func
def get_capacity(self, logits_shape):
capacity = math.ceil(2 * self.capacity_factor *
logits_shape[0] / logits_shape[1])
capacity = math.floor(2 * self.capacity_factor * logits_shape[-2] / logits_shape[-1])
capacity += capacity % 2
assert capacity > 0
return capacity
def forward(self, inputs):
def forward(self, inputs: torch.Tensor, cuda_mode: bool = False):
# inputs: [s, e]
if self.noisy_func is not None:
inputs = self.noisy_func(inputs)
logits = F.softmax(inputs, dim=-1)
logits = autocast_softmax(inputs, dim=-1) # logits: [s, e]
num_experts = logits.size(-1)
capacity = self.get_capacity(logits.shape)
_, expert_idx = torch.topk(logits, k=2, dim=-1, largest=True, sorted=True)
top1_idx = expert_idx[:, 0]
top2_idx = expert_idx[:, 1]
top1_idx = torch.argmax(logits, dim=-1)
mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
logits_except1 = logits.masked_fill(mask1.bool(), float("-inf"))
top2_idx = torch.argmax(logits_except1, dim=-1)
mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)
mask1 = F.one_hot(top1_idx, num_classes=num_experts)
mask2 = F.one_hot(top2_idx, num_classes=num_experts)
cmask = (mask1 + mask2) # loss: [s, e]
if self.training:
me = torch.mean(logits, dim=0)
ce = torch.mean(cmask.float(), dim=0)
l_aux = num_experts * torch.sum(me * ce) / 2.0
moe_env.add_loss(l_aux)
else:
max_num = torch.max(torch.sum(cmask, dim=0))
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MOE_MODEL))
capacity = max_num.item()
loss_mask = (mask1 + mask2)
exp_counts = torch.sum(loss_mask, dim=0).detach().to('cpu')
me = torch.mean(logits, dim=0)
ce = torch.mean(loss_mask.float(), dim=0)
l_aux = num_experts * torch.sum(me * ce) / 2.0
moe_env.add_loss(l_aux)
rank1 = moe_cumsum(mask1) # rank1: [s, e]
rank2 = moe_cumsum(mask2)
rank2 += torch.sum(mask1, dim=-2, keepdim=True)
locations1 = torch.cumsum(mask1, dim=0) - 1
locations2 = torch.cumsum(mask2, dim=0) - 1
locations2 += torch.sum(mask1, dim=0, keepdim=True)
mask1 *= torch.lt(rank1, capacity)
mask2 *= torch.lt(rank2, capacity)
mask1 *= torch.lt(locations1, capacity)
mask2 *= torch.lt(locations2, capacity)
rank1 = torch.sum(mask1 * rank1, dim=-1)
rank2 = torch.sum(mask2 * rank2, dim=-1)
weight1 = mask1 * logits
weight2 = mask2 * logits
if cuda_mode:
mask1 = torch.sum(mask1, dim=-1)
mask2 = torch.sum(mask2, dim=-1)
locations1 = torch.sum(mask1 * locations1, dim=1)
locations2 = torch.sum(mask2 * locations2, dim=1)
locations1_sc = F.one_hot(locations1, num_classes=capacity)
locations2_sc = F.one_hot(locations2, num_classes=capacity)
mask = torch.stack([mask1, mask2], dim=0).to(torch.int32)
dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32)
combine_weights1 = weight1.unsqueeze(2) * locations1_sc.unsqueeze(1)
combine_weights2 = weight2.unsqueeze(2) * locations2_sc.unsqueeze(1)
combine_weights = combine_weights1 + combine_weights2
sec_mask = combine_weights.bool()
return logits, mask, dest_idx, num_experts * capacity
else:
weight1 = mask1 * logits.type_as(inputs)
weight2 = mask2 * logits.type_as(inputs)
rank1_sc = F.one_hot(rank1, num_classes=capacity)
rank2_sc = F.one_hot(rank2, num_classes=capacity)
return combine_weights, sec_mask, exp_counts
cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1)
cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1)
cb_weight = cb_weight1 + cb_weight2
sec_mask = cb_weight.bool()
return cb_weight, sec_mask
class MoeLayer(nn.Module):
@@ -225,52 +198,47 @@ class MoeLayer(nn.Module):
:type experts: nn.Module
"""
def __init__(self,
dim_model: int,
num_experts: int,
router: nn.Module,
experts: nn.Module):
def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: nn.Module):
super().__init__()
self.d_model = dim_model
self.num_experts = num_experts
self.gate = nn.Linear(dim_model, num_experts, device=get_current_device())
self.gate = nn.Linear(dim_model, num_experts, bias=False, device=get_current_device())
self.router = router
self.experts = experts
self.cuda_mode = True if U_CUDA_MODE and moe_env.enable_cuda else False
def _router_part(self, tokens: torch.Tensor):
gate_output = self.gate(tokens)
return self.router(gate_output)
def expert_part(self, expert_input: torch.Tensor):
expert_input = AllToAll.apply(expert_input, ParallelMode.MOE_MODEL)
def router_part(self, tokens: torch.Tensor):
autocast_context = torch.is_autocast_enabled()
if not autocast_context:
return self._router_part(tokens)
else:
with autocast(enabled=False):
if tokens.dtype == torch.float16:
input_tokens = tokens.float()
else:
input_tokens = tokens
return self._router_part(input_tokens)
input_shape = expert_input.shape
expert_input = expert_input.reshape(moe_env.model_parallel_size,
self.num_experts // moe_env.model_parallel_size, -1, self.d_model)
expert_output = self.experts(expert_input)
expert_output = expert_output.reshape(input_shape)
expert_output = AllToAll.apply(expert_output, ParallelMode.MOE_MODEL)
return expert_output
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
tokens = inputs.reshape(-1, self.d_model)
gate_output = self.gate(tokens)
router_res = self.router(gate_output, self.cuda_mode)
combine_weights, sec_mask, exp_counts = self.router_part(tokens)
if self.cuda_mode:
logits, mask, dest_idx, ec = router_res
expert_input = MoeDispatch.apply(tokens, mask, dest_idx, ec)
expert_output = self.expert_part(expert_input)
ret = MoeCombine.apply(expert_output, logits, mask, dest_idx, ec)
else:
combine_weights, sec_mask = router_res
sec_mask_f = sec_mask.type_as(inputs)
expert_input = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
expert_output = self.expert_part(expert_input)
combine_weights = combine_weights.view(combine_weights.shape[0], -1)
expert_output = expert_output.view(-1, expert_output.shape[-1])
ret = torch.matmul(combine_weights, expert_output)
sec_mask_f = sec_mask.type_as(inputs)
dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
dispatch_data = AllToAll.apply(dispatch_data, ParallelMode.MOE_MODEL)
expert_output = self.experts(dispatch_data)
expert_output = AllToAll.apply(expert_output, ParallelMode.MOE_MODEL)
combine_weights = combine_weights.view(combine_weights.shape[0], -1)
expert_output = expert_output.view(-1, expert_output.shape[-1])
ret = torch.matmul(combine_weights, expert_output)
ret = ret.reshape(inputs.shape)
return ret
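Putting the pieces together, a minimal usage sketch of the rewritten layer (mirroring the pattern in the new tests below; the Top2Router constructor arguments are assumed from its docstring, and colossalai.launch with a MoE parallel config must have run before the layer is built):

from colossalai.nn.layer.moe import FFNExperts, MoeLayer, Top2Router, NormalNoiseGenerator

router = Top2Router(capacity_factor=1.2, noisy_func=NormalNoiseGenerator(num_experts=4))
experts = FFNExperts(num_experts=4, d_model=256, d_ff=1024, drop_rate=0.1)
layer = MoeLayer(dim_model=256, num_experts=4, router=router, experts=experts)

tokens = torch.randn(32, 256, device=get_current_device())
out = layer(tokens)  # [32, 256]; the fused CUDA path is taken only when the extension is built and enabled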

View File

@@ -0,0 +1,32 @@
import torch
import torch.nn.functional as F
from colossalai.utils import get_current_device
class NormalNoiseGenerator:
"""Generates a random noisy mask for logtis tensor.
All noise is generated from a normal distribution (0, 1 / E^2), where
E = the number of experts.
:param num_experts: The number of experts
:type num_experts: int
"""
def __init__(self, num_experts: int):
self.normal = torch.distributions.normal.Normal(
loc=torch.tensor(0.0, device=get_current_device()),
scale=torch.tensor(1.0 / num_experts ** 2, device=get_current_device())
).rsample
def __call__(self, inputs: torch.Tensor):
noisy = self.normal(inputs.shape)
return inputs + noisy
def autocast_softmax(inputs: torch.Tensor, dim: int):
assert inputs.dtype in {torch.float16, torch.float32}
fp16_flag = (inputs.dtype == torch.float16)
sm_input = inputs.to(torch.float32) if fp16_flag else inputs
sm_output = F.softmax(sm_input, dim)
return sm_output
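The helper always computes the softmax in float32 and returns the float32 result; callers such as the routers cast back with type_as when they need the original dtype. A small illustration:

x = torch.randn(4, 8, dtype=torch.float16)
p = autocast_softmax(x, dim=-1)
assert p.dtype == torch.float32
assert torch.allclose(p.sum(dim=-1), torch.ones(4))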

View File

@@ -4,7 +4,7 @@ import torch.nn as nn
from colossalai.context import ParallelMode
from colossalai.nn.layer import VanillaPatchEmbedding, VanillaClassifier, \
WrappedDropout as Dropout, WrappedDropPath as DropPath
from colossalai.nn.layer.moe import Experts, MoeLayer, Top2Router, NormalNoiseGenerator
from colossalai.nn.layer.moe import FFNExperts, MoeLayer, Top2Router, NormalNoiseGenerator
from .util import moe_sa_args, moe_mlp_args
from ..helper import TransformerLayer
from colossalai.global_variables import moe_env
@@ -81,6 +81,7 @@ class VanillaFFN(nn.Module):
class Widenet(nn.Module):
def __init__(self,
num_experts: int,
capacity_factor: float,
@@ -98,43 +99,33 @@ class Widenet(nn.Module):
drop_path: float = 0.):
super().__init__()
embedding = VanillaPatchEmbedding(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_size=d_model)
embedding = VanillaPatchEmbedding(img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_size=d_model)
embed_dropout = Dropout(p=drop_rate, mode=ParallelMode.TENSOR)
shared_sa = VanillaSelfAttention(**moe_sa_args(
d_model=d_model, n_heads=num_heads, d_kv=d_kv,
attention_drop=attention_drop, drop_rate=drop_rate))
d_model=d_model, n_heads=num_heads, d_kv=d_kv, attention_drop=attention_drop, drop_rate=drop_rate))
noisy_func = NormalNoiseGenerator(num_experts)
shared_router = Top2Router(capacity_factor, noisy_func=noisy_func)
shared_experts = Experts(expert=VanillaFFN,
num_experts=num_experts,
**moe_mlp_args(
d_model=d_model,
d_ff=d_ff,
drop_rate=drop_rate
))
shared_experts = FFNExperts(num_experts, d_model, d_ff, drop_rate=drop_rate)
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
blocks = [
TransformerLayer(
att=shared_sa,
ffn=MoeLayer(dim_model=d_model, num_experts=num_experts,
router=shared_router, experts=shared_experts),
norm1=nn.LayerNorm(d_model, eps=1e-6),
norm2=nn.LayerNorm(d_model, eps=1e-6),
droppath=DropPath(p=dpr[i], mode=ParallelMode.TENSOR)
)
for i in range(depth)
TransformerLayer(att=shared_sa,
ffn=MoeLayer(dim_model=d_model,
num_experts=num_experts,
router=shared_router,
experts=shared_experts),
norm1=nn.LayerNorm(d_model, eps=1e-6),
norm2=nn.LayerNorm(d_model, eps=1e-6),
droppath=DropPath(p=dpr[i], mode=ParallelMode.TENSOR)) for i in range(depth)
]
norm = nn.LayerNorm(d_model, eps=1e-6)
self.linear = VanillaClassifier(in_features=d_model,
num_classes=num_classes)
self.linear = VanillaClassifier(in_features=d_model, num_classes=num_classes)
nn.init.zeros_(self.linear.weight)
nn.init.zeros_(self.linear.bias)
self.widenet = nn.Sequential(embedding, embed_dropout, *blocks, norm)
@@ -145,3 +136,64 @@ class Widenet(nn.Module):
x = torch.mean(x, dim=1)
x = self.linear(x)
return x
class ViTMoE(nn.Module):
def __init__(self,
num_experts: int,
capacity_factor: float,
img_size: int = 224,
patch_size: int = 16,
in_chans: int = 3,
num_classes: int = 1000,
depth: int = 12,
d_model: int = 768,
num_heads: int = 12,
d_kv: int = 64,
d_ff: int = 3072,
attention_drop: float = 0.,
drop_rate: float = 0.1,
drop_path: float = 0.):
super().__init__()
embedding = VanillaPatchEmbedding(img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_size=d_model)
embed_dropout = Dropout(p=drop_rate, mode=ParallelMode.TENSOR)
noisy_func = NormalNoiseGenerator(num_experts)
router = Top2Router(capacity_factor, noisy_func=noisy_func)
assert depth % 2 == 0
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
blocks = []
for i in range(depth):
sa = VanillaSelfAttention(**moe_sa_args(
d_model=d_model, n_heads=num_heads, d_kv=d_kv, attention_drop=attention_drop, drop_rate=drop_rate))
ffn = VanillaFFN(**moe_mlp_args(
d_model=d_model, d_ff=d_ff, drop_rate=drop_rate)) if i % 2 == 0 else \
MoeLayer(dim_model=d_model, num_experts=num_experts, router=router,
experts=FFNExperts(num_experts, d_model, d_ff, drop_rate=drop_rate))
layer = TransformerLayer(att=sa,
ffn=ffn,
norm1=nn.LayerNorm(d_model, eps=1e-6),
norm2=nn.LayerNorm(d_model, eps=1e-6),
droppath=DropPath(p=dpr[i], mode=ParallelMode.TENSOR))
blocks.append(layer)
norm = nn.LayerNorm(d_model, eps=1e-6)
self.linear = VanillaClassifier(in_features=d_model, num_classes=num_classes)
nn.init.zeros_(self.linear.weight)
nn.init.zeros_(self.linear.bias)
self.vitmoe = nn.Sequential(embedding, embed_dropout, *blocks, norm)
def forward(self, x):
moe_env.reset_loss()
x = self.vitmoe(x)
x = torch.mean(x, dim=1)
x = self.linear(x)
return x
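A usage sketch of the new model (hypothetical sizes; like Widenet, it presumes colossalai.launch has been called with a MoE parallel configuration so that moe_env and the wrapped dropout/drop-path contexts are initialized):

model = ViTMoE(num_experts=8, capacity_factor=1.2, img_size=224, patch_size=16, num_classes=1000)
images = torch.randn(4, 3, 224, 224, device=get_current_device())
logits = model(images)  # [4, 1000]; in training mode the routers accumulate their aux loss into moe_env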

View File

@@ -162,6 +162,10 @@ if build_cuda_ext:
['scaled_masked_softmax.cpp', 'scaled_masked_softmax_cuda.cu'],
extra_cuda_flags + cc_flag))
ext_modules.append(cuda_ext_helper('colossal_moe_cuda',
['moe_cuda.cpp', 'moe_cuda_kernel.cu'],
extra_cuda_flags + cc_flag))
extra_cuda_flags = ['-maxrregcount=50']
ext_modules.append(cuda_ext_helper('colossal_layer_norm_cuda',

View File

@@ -0,0 +1,97 @@
import os
from functools import partial
from pathlib import Path
import pytest
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Top2Router, MoeLayer
from colossalai.global_variables import moe_env
BATCH_SIZE = 32
NUM_EXPERTS = 4
CONFIG = dict(parallel=dict(moe=dict(size=4)))
def check_equal(A, B, atol=1e-06):
assert torch.allclose(A, B, rtol=0, atol=atol) is True
def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
# torch.set_printoptions(precision=30)
torch.backends.cuda.matmul.allow_tf32 = False
local_rank = gpc.get_local_rank(ParallelMode.GLOBAL)
torch.manual_seed(rs + local_rank)
moe_env.reset_loss()
tokens = torch.randn(BATCH_SIZE, hidden_size,
dtype=data_type, device=get_current_device(), requires_grad=True)
# print(f"tokens:\n{tokens}")
router = Top2Router(1)
layer = MoeLayer(hidden_size, NUM_EXPERTS, router, nn.Identity())
if data_type == torch.float16:
layer = layer.half()
layer.cuda_mode = False
old_out = layer(tokens)
# print(f"old output:\n{old_out}")
ech = old_out.shape
grad = torch.randn(ech, device=get_current_device())
old_out.backward(grad)
o_tk_grad = tokens.grad.data.clone()
o_gt_grad = layer.gate.weight.grad.data.clone()
tokens.grad.zero_()
layer.gate.weight.grad.zero_()
layer.cuda_mode = True
new_out = layer(tokens)
# print(torch.max(torch.abs(old_out - new_out)))
if data_type == torch.float32:
check_equal(old_out, new_out)
else:
check_equal(old_out, new_out, 1e-2)
# print(f"forward functions passed")
# print(f"new output:\n{new_out}")
new_out.backward(grad)
n_tk_grad = tokens.grad.data.clone()
n_gt_grad = layer.gate.weight.grad.data.clone()
# print(torch.max(torch.abs(o_tk_grad - n_tk_grad)))
if data_type == torch.float32:
check_equal(o_tk_grad, n_tk_grad)
else:
check_equal(o_tk_grad, n_tk_grad, 1e-2)
# print(f"tokens gradient passed")
# print(torch.max(torch.abs(o_gt_grad - n_gt_grad)))
if data_type == torch.float32:
check_equal(o_gt_grad, n_gt_grad, 5e-05)
else:
check_equal(o_gt_grad, n_gt_grad, 2e-01)
# print(f"linear weight gradient passed")
@pytest.mark.dist
@pytest.mark.parametrize("rs", [131])
@pytest.mark.parametrize("hidden_size", [32, 144])
@pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
def test_moe_top2(rs, hidden_size, data_type):
world_size = 4
run_func = partial(run_routing, world_size=world_size, port=free_port(),
rs=rs, hidden_size=hidden_size, data_type=data_type)
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_moe_top2(2, 256, torch.float16)

View File

@@ -0,0 +1,97 @@
from functools import partial
import pytest
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Top1Router, MoeLayer
from colossalai.global_variables import moe_env
BATCH_SIZE = 32
NUM_EXPERTS = 4
CONFIG = dict(parallel=dict(moe=dict(size=4)))
def check_equal(A, B, atol=1e-06):
assert torch.allclose(A, B, rtol=0, atol=atol) is True
def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
# torch.set_printoptions(precision=30)
torch.backends.cuda.matmul.allow_tf32 = False
local_rank = gpc.get_local_rank(ParallelMode.GLOBAL)
torch.manual_seed(rs + local_rank)
moe_env.reset_loss()
tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True)
# print(f"tokens:\n{tokens}")
router = Top1Router(1)
layer = MoeLayer(hidden_size, NUM_EXPERTS, router, nn.Identity())
if data_type == torch.float16:
layer = layer.half()
layer.cuda_mode = False
old_out = layer(tokens)
# print(f"old output:\n{old_out}")
ech = old_out.shape
grad = torch.randn(ech, device=get_current_device())
old_out.backward(grad)
o_tk_grad = tokens.grad.data.clone()
o_gt_grad = layer.gate.weight.grad.data.clone()
tokens.grad.zero_()
layer.gate.weight.grad.zero_()
layer.cuda_mode = True
new_out = layer(tokens)
# print(torch.max(torch.abs(old_out - new_out)))
if data_type == torch.float32:
check_equal(old_out, new_out)
else:
check_equal(old_out, new_out, 1e-2)
# print(f"forward functions passed")
# print(f"new output:\n{new_out}")
new_out.backward(grad)
n_tk_grad = tokens.grad.data.clone()
n_gt_grad = layer.gate.weight.grad.data.clone()
# print(torch.max(torch.abs(o_tk_grad - n_tk_grad)))
if data_type == torch.float32:
check_equal(o_tk_grad, n_tk_grad)
else:
check_equal(o_tk_grad, n_tk_grad, 1e-2)
# print(f"tokens gradient passed")
# print(torch.max(torch.abs(o_gt_grad - n_gt_grad)))
if data_type == torch.float32:
check_equal(o_gt_grad, n_gt_grad, 5e-05)
else:
check_equal(o_gt_grad, n_gt_grad, 2e-01)
# print(f"linear weight gradient passed")
@pytest.mark.skip(reason="Should be activated for detailed tests")
@pytest.mark.parametrize("rs", [2, 42, 60])
@pytest.mark.parametrize("hidden_size", [128, 256, 512, 768, 1024, 2048])
@pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
def test_moe_top1(rs, hidden_size, data_type):
world_size = 4
run_func = partial(run_routing,
world_size=world_size,
port=free_port(),
rs=rs,
hidden_size=hidden_size,
data_type=data_type)
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_moe_top1(60, 512, torch.float16)

View File

@@ -0,0 +1,97 @@
from functools import partial
import pytest
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Top2Router, MoeLayer
from colossalai.global_variables import moe_env
BATCH_SIZE = 32
NUM_EXPERTS = 4
CONFIG = dict(parallel=dict(moe=dict(size=4)))
def check_equal(A, B, atol=1e-06):
assert torch.allclose(A, B, rtol=0, atol=atol) is True
def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
# torch.set_printoptions(precision=30)
torch.backends.cuda.matmul.allow_tf32 = False
local_rank = gpc.get_local_rank(ParallelMode.GLOBAL)
torch.manual_seed(rs + local_rank)
moe_env.reset_loss()
tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True)
# print(f"tokens:\n{tokens}")
router = Top2Router(1)
layer = MoeLayer(hidden_size, NUM_EXPERTS, router, nn.Identity())
if data_type == torch.float16:
layer = layer.half()
layer.cuda_mode = False
old_out = layer(tokens)
# print(f"old output:\n{old_out}")
ech = old_out.shape
grad = torch.randn(ech, device=get_current_device())
old_out.backward(grad)
o_tk_grad = tokens.grad.data.clone()
o_gt_grad = layer.gate.weight.grad.data.clone()
tokens.grad.zero_()
layer.gate.weight.grad.zero_()
layer.cuda_mode = True
new_out = layer(tokens)
# print(torch.max(torch.abs(old_out - new_out)))
if data_type == torch.float32:
check_equal(old_out, new_out)
else:
check_equal(old_out, new_out, 1e-2)
# print(f"forward functions passed")
# print(f"new output:\n{new_out}")
new_out.backward(grad)
n_tk_grad = tokens.grad.data.clone()
n_gt_grad = layer.gate.weight.grad.data.clone()
# print(torch.max(torch.abs(o_tk_grad - n_tk_grad)))
if data_type == torch.float32:
check_equal(o_tk_grad, n_tk_grad)
else:
check_equal(o_tk_grad, n_tk_grad, 1e-2)
# print(f"tokens gradient passed")
# print(torch.max(torch.abs(o_gt_grad - n_gt_grad)))
if data_type == torch.float32:
check_equal(o_gt_grad, n_gt_grad, 5e-05)
else:
check_equal(o_gt_grad, n_gt_grad, 2e-01)
# print(f"linear weight gradient passed")
@pytest.mark.skip(reason="Should be activated for detailed tests")
@pytest.mark.parametrize("rs", [2, 42, 60])
@pytest.mark.parametrize("hidden_size", [128, 256, 512, 768, 1024, 2048])
@pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
def test_moe_top2(rs, hidden_size, data_type):
world_size = 4
run_func = partial(run_routing,
world_size=world_size,
port=free_port(),
rs=rs,
hidden_size=hidden_size,
data_type=data_type)
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_moe_top2(2, 256, torch.float16)