mirror of https://github.com/hpcaitech/ColossalAI

Optimized MoE layer and fixed some bugs;
Decreased moe tests; Added FFNExperts and ViTMoE model (pull/394/head)

parent 3dba070580
commit 219df6e685
@@ -9,6 +9,6 @@ repos:
    hooks:
      - id: flake8
  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v13.0.0
    rev: v13.0.1
    hooks:
      - id: clang-format
@@ -56,6 +56,7 @@ class MoeEnv:
        self.data_parallel_size = None
        self.model_parallel_size = None
        self.aux_loss = None
        self.enable_cuda = True

    def setup(self, moe_model_size):
        from .core import global_context as gpc

@@ -71,6 +72,9 @@
    def is_initialized(self):
        return self.model_parallel_size is not None

    def set_cuda_false(self):
        self.enable_cuda = False

    def reset_loss(self):
        self.aux_loss = 0
@@ -5,7 +5,7 @@
#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"
#include <THC/THCDeviceUtils.cuh>
#include "ATen/cuda/DeviceUtils.cuh"

#include <cuda.h>
#include <cuda_runtime.h>
@@ -0,0 +1,118 @@
#include <torch/extension.h>


torch::Tensor moe_dispatch_cuda_forward(
    int s, int ec, int h,
    torch::Tensor batch_tokens,
    torch::Tensor mask,
    torch::Tensor dest_idx);

torch::Tensor moe_dispatch_cuda_backward(
    int s, int ec, int h,
    torch::Tensor expert_grad,
    torch::Tensor mask,
    torch::Tensor dest_idx);

torch::Tensor moe_combine_cuda_forward(
    int s, int e, int c, int h,
    torch::Tensor expert_tokens,
    torch::Tensor logits,
    torch::Tensor mask,
    torch::Tensor dest_idx);

std::vector<torch::Tensor> moe_combine_cuda_backward(
    int s, int e, int c, int h,
    torch::Tensor tokens_grad,
    torch::Tensor expert_tokens,
    torch::Tensor logits,
    torch::Tensor mask,
    torch::Tensor dest_idx);

torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask);

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

torch::Tensor moe_dispatch_forward(
    int s, int ec, int h,
    torch::Tensor batch_tokens,
    torch::Tensor mask,
    torch::Tensor dest_idx) {

    CHECK_INPUT(batch_tokens);
    CHECK_CUDA(mask);
    CHECK_CUDA(dest_idx);

    return moe_dispatch_cuda_forward(
        s, ec, h,
        batch_tokens, mask, dest_idx);
}

torch::Tensor moe_dispatch_backward(
    int s, int ec, int h,
    torch::Tensor expert_grad,
    torch::Tensor mask,
    torch::Tensor dest_idx) {

    CHECK_INPUT(expert_grad);
    CHECK_CUDA(mask);
    CHECK_CUDA(dest_idx);

    return moe_dispatch_cuda_backward(
        s, ec, h,
        expert_grad, mask, dest_idx);
}

torch::Tensor moe_combine_forward(
    int s, int e, int c, int h,
    torch::Tensor expert_tokens,
    torch::Tensor logits,
    torch::Tensor mask,
    torch::Tensor dest_idx) {

    CHECK_INPUT(expert_tokens);
    CHECK_INPUT(logits);
    CHECK_CUDA(mask);
    CHECK_CUDA(dest_idx);

    return moe_combine_cuda_forward(
        s, e, c, h,
        expert_tokens, logits, mask, dest_idx);
}

std::vector<torch::Tensor> moe_combine_backward(
    int s, int e, int c, int h,
    torch::Tensor tokens_grad,
    torch::Tensor expert_tokens,
    torch::Tensor logits,
    torch::Tensor mask,
    torch::Tensor dest_idx) {

    CHECK_INPUT(tokens_grad);
    CHECK_INPUT(logits);
    CHECK_CUDA(mask);
    CHECK_CUDA(dest_idx);

    return moe_combine_cuda_backward(
        s, e, c, h,
        tokens_grad, expert_tokens, logits, mask, dest_idx);
}

torch::Tensor moe_cumsum(torch::Tensor mask) {
    CHECK_INPUT(mask);
    return cumsum_sub_one_in_dim0(mask);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("cumsum_sub_one", &moe_cumsum,
          "Fast cumsum operation in dim0");
    m.def("dispatch_forward", &moe_dispatch_forward,
          "Forward operation in MoE dispatch function");
    m.def("dispatch_backward", &moe_dispatch_backward,
          "Backward operation in MoE dispatch function");
    m.def("combine_forward", &moe_combine_forward,
          "Combine operation in MoE combine function");
    m.def("combine_backward", &moe_combine_backward,
          "Combine operation in MoE combine function");
}
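A brief usage sketch of the bindings above, from the Python side. The module name colossal_moe_cuda matches what the Python wrappers later in this commit import; the shapes and the equivalence check are illustrative assumptions, not part of the diff.

# Sketch: assumes the extension was compiled (cuda_ext) and a CUDA device is available.
import torch
import colossal_moe_cuda

mask = torch.randint(0, 2, (16, 4), dtype=torch.int32, device='cuda')   # [s, e] routing mask
ranks = colossal_moe_cuda.cumsum_sub_one(mask)                          # int32, same shape
# Semantically an inclusive cumulative sum along dim 0, minus one:
assert torch.equal(ranks, (torch.cumsum(mask, dim=0) - 1).to(torch.int32))
# CHECK_INPUT / CHECK_CUDA turn CPU or non-contiguous inputs into a RuntimeError,
# e.g. colossal_moe_cuda.cumsum_sub_one(mask.cpu()) fails with "... must be a CUDA tensor".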
@ -0,0 +1,702 @@
|
|||
#include <torch/extension.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cub/cub.cuh>
|
||||
#include "block_reduce.h"
|
||||
|
||||
template<typename T, int block_size, int pack_size>
|
||||
__device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {
|
||||
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
typedef cub::BlockLoad<T, block_size, pack_size,
|
||||
cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
|
||||
typedef cub::BlockStore<T, block_size, pack_size,
|
||||
cub::BLOCK_STORE_VECTORIZE> BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
int tps = threadIdx.x * pack_size; T pack[pack_size];
|
||||
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
|
||||
BlockLoad(ts_load).Load(src_row + idx, pack);
|
||||
BlockStore(ts_store).Store(dst_row + idx, pack);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, int block_size, int pack_size>
|
||||
__device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) {
|
||||
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
typedef cub::BlockLoad<T, block_size, pack_size,
|
||||
cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
|
||||
typedef cub::BlockStore<T, block_size, pack_size,
|
||||
cub::BLOCK_STORE_VECTORIZE> BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
int tps = threadIdx.x * pack_size; T pack[pack_size];
|
||||
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
|
||||
BlockLoad(ts_load).Load(dst_row + idx, pack);
|
||||
BlockStore(ts_store).Store(src_row + idx, pack);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, int block_size, int pack_size>
|
||||
__device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2, const int cols) {
|
||||
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
typedef cub::BlockLoad<T, block_size, pack_size,
|
||||
cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
|
||||
typedef cub::BlockStore<T, block_size, pack_size,
|
||||
cub::BLOCK_STORE_VECTORIZE> BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
int tps = threadIdx.x * pack_size; T pack[pack_size];
|
||||
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
|
||||
BlockLoad(ts_load).Load(src_row + idx, pack);
|
||||
BlockStore(ts_store).Store(dst_row1 + idx, pack);
|
||||
BlockStore(ts_store).Store(dst_row2 + idx, pack);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, int block_size, int pack_size>
|
||||
__device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2, const int cols) {
|
||||
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
typedef cub::BlockLoad<T, block_size, pack_size,
|
||||
cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
|
||||
typedef cub::BlockStore<T, block_size, pack_size,
|
||||
cub::BLOCK_STORE_VECTORIZE> BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
int tps = threadIdx.x * pack_size;
|
||||
T pack1[pack_size], pack2[pack_size];
|
||||
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
|
||||
BlockLoad(ts_load).Load(dst_row1 + idx, pack1);
|
||||
BlockLoad(ts_load).Load(dst_row2 + idx, pack2);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < pack_size; ++i) {
|
||||
pack1[i] += pack2[i];
|
||||
}
|
||||
|
||||
BlockStore(ts_store).Store(src_row + idx, pack1);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, int block_size, int pack_size>
|
||||
__device__ void moe_cb_one_fwd(
|
||||
T *src_row, T *dst_row,
|
||||
const T weight, const int cols) {
|
||||
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
typedef cub::BlockLoad<T, block_size, pack_size,
|
||||
cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
|
||||
typedef cub::BlockStore<T, block_size, pack_size,
|
||||
cub::BLOCK_STORE_VECTORIZE> BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
int tps = threadIdx.x * pack_size; T pack[pack_size];
|
||||
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
|
||||
BlockLoad(ts_load).Load(src_row + idx, pack);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < pack_size; ++i) {
|
||||
pack[i] *= weight;
|
||||
}
|
||||
|
||||
BlockStore(ts_store).Store(dst_row + idx, pack);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, int block_size, int pack_size>
|
||||
__device__ void moe_cb_one_bwd(
|
||||
T *src_row, T *dst_row, T *tks_row, T *weight_grad,
|
||||
const T weight, const int cols) {
|
||||
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
typedef cub::BlockLoad<T, block_size, pack_size,
|
||||
cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
|
||||
typedef cub::BlockStore<T, block_size, pack_size,
|
||||
cub::BLOCK_STORE_VECTORIZE> BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
int tps = threadIdx.x * pack_size;
|
||||
T grad[pack_size], tokens[pack_size];
|
||||
float thread_sum = 0;
|
||||
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
|
||||
BlockLoad(ts_load).Load(dst_row + idx, grad);
|
||||
BlockLoad(ts_load).Load(tks_row + idx, tokens);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < pack_size; ++i) {
|
||||
thread_sum += grad[i] * tokens[i];
|
||||
grad[i] *= weight;
|
||||
}
|
||||
|
||||
BlockStore(ts_store).Store(src_row + idx, grad);
|
||||
}
|
||||
|
||||
blockReduce<ReduceType::kSum, 1>(&thread_sum);
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
*weight_grad = static_cast<T>(thread_sum);
|
||||
}
|
||||
|
||||
template<typename T, int block_size, int pack_size>
|
||||
__device__ void moe_cb_two_fwd(
|
||||
T *src_row1, T *src_row2, T *dst_row,
|
||||
const T weight1, const T weight2, const int cols) {
|
||||
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
typedef cub::BlockLoad<T, block_size, pack_size,
|
||||
cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
|
||||
typedef cub::BlockStore<T, block_size, pack_size,
|
||||
cub::BLOCK_STORE_VECTORIZE> BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
int tps = threadIdx.x * pack_size;
|
||||
T pack1[pack_size], pack2[pack_size];
|
||||
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
|
||||
BlockLoad(ts_load).Load(src_row1 + idx, pack1);
|
||||
BlockLoad(ts_load).Load(src_row2 + idx, pack2);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < pack_size; ++i) {
|
||||
pack1[i] = pack1[i] * weight1 + pack2[i] * weight2;
|
||||
}
|
||||
|
||||
BlockStore(ts_store).Store(dst_row + idx, pack1);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, int block_size, int pack_size>
|
||||
__device__ void moe_cb_two_bwd(
|
||||
T *src_row1, T *src_row2, T *dst_row,
|
||||
T *tks_row1, T *tks_row2, T *weight_grad1, T *weight_grad2,
|
||||
const T weight1, const T weight2, const int cols) {
|
||||
|
||||
assert(cols % pack_size == 0);
|
||||
const int bpack_size = block_size * pack_size;
|
||||
|
||||
typedef cub::BlockLoad<T, block_size, pack_size,
|
||||
cub::BLOCK_LOAD_VECTORIZE> BlockLoad;
|
||||
__shared__ typename BlockLoad::TempStorage ts_load;
|
||||
|
||||
typedef cub::BlockStore<T, block_size, pack_size,
|
||||
cub::BLOCK_STORE_VECTORIZE> BlockStore;
|
||||
__shared__ typename BlockStore::TempStorage ts_store;
|
||||
|
||||
int tps = threadIdx.x * pack_size;
|
||||
T grad[pack_size], tokens1[pack_size], tokens2[pack_size],
|
||||
sgrad1[pack_size], sgrad2[pack_size];
|
||||
float thread_sum[2] = {0, 0};
|
||||
for (int idx = 0; idx + tps < cols; idx += bpack_size) {
|
||||
BlockLoad(ts_load).Load(dst_row + idx, grad);
|
||||
BlockLoad(ts_load).Load(tks_row1 + idx, tokens1);
|
||||
BlockLoad(ts_load).Load(tks_row2 + idx, tokens2);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < pack_size; ++i) {
|
||||
thread_sum[0] += grad[i] * tokens1[i];
|
||||
thread_sum[1] += grad[i] * tokens2[i];
|
||||
sgrad1[i] = weight1 * grad[i];
|
||||
sgrad2[i] = weight2 * grad[i];
|
||||
}
|
||||
|
||||
BlockStore(ts_store).Store(src_row1 + idx, sgrad1);
|
||||
BlockStore(ts_store).Store(src_row2 + idx, sgrad2);
|
||||
}
|
||||
|
||||
blockReduce<ReduceType::kSum, 2>(thread_sum);
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
*weight_grad1 = static_cast<T>(thread_sum[0]);
|
||||
else if (threadIdx.x == 1)
|
||||
*weight_grad2 = static_cast<T>(thread_sum[1]);
|
||||
|
||||
}
|
||||
|
||||
// DISPATCH KERNELS --------------------------------
|
||||
|
||||
template<typename T, int block_size, int pack_size>
|
||||
__device__ void moe_dpch_fwd_selector(
|
||||
T *src_row, T *dst_row1, T *dst_row2, const int cols,
|
||||
const int indicator1, const int indicator2) {
|
||||
|
||||
if (indicator1 != 0 && indicator2 != 0)
|
||||
moe_dpch_two_fwd<T, block_size, pack_size>(
|
||||
src_row, dst_row1, dst_row2, cols);
|
||||
else if (indicator1 != 0)
|
||||
moe_dpch_one_fwd<T, block_size, pack_size>(
|
||||
src_row, dst_row1, cols);
|
||||
else if (indicator2 != 0)
|
||||
moe_dpch_one_fwd<T, block_size, pack_size>(
|
||||
src_row, dst_row2, cols);
|
||||
else
|
||||
return;
|
||||
}
|
||||
|
||||
template<typename T, int block_size, int pack_size>
|
||||
__device__ void moe_dpch_bwd_selector(
|
||||
T *src_row, T *dst_row1, T *dst_row2, const int cols,
|
||||
const int indicator1, const int indicator2) {
|
||||
|
||||
if (indicator1 != 0 && indicator2 != 0)
|
||||
moe_dpch_two_bwd<T, block_size, pack_size>(
|
||||
src_row, dst_row1, dst_row2, cols);
|
||||
else if (indicator1 != 0)
|
||||
moe_dpch_one_bwd<T, block_size, pack_size>(
|
||||
src_row, dst_row1, cols);
|
||||
else if (indicator2 != 0)
|
||||
moe_dpch_one_bwd<T, block_size, pack_size>(
|
||||
src_row, dst_row2, cols);
|
||||
else
|
||||
return;
|
||||
}
|
||||
|
||||
template<typename T, int block_size, int pack_size>
|
||||
__global__ void moe_dpch_fwd_kernel(
|
||||
T *batch_tokens, T *expert_input,
|
||||
int *mask1, int *mask2,
|
||||
int *dest1, int *dest2, const int h) {
|
||||
|
||||
int row = blockIdx.x;
|
||||
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
|
||||
moe_dpch_fwd_selector<T, block_size, pack_size>(
|
||||
batch_tokens + (row * h),
|
||||
expert_input + (dest1[row] * h), expert_input + (dest2[row] * h),
|
||||
h, mask1[row], indicator2);
|
||||
}
|
||||
|
||||
template<typename T, int block_size, int pack_size>
|
||||
__global__ void moe_dpch_bwd_kernel(
|
||||
T *tokens_grad, T *expert_grad,
|
||||
int *mask1, int *mask2,
|
||||
int *dest1, int *dest2, const int h) {
|
||||
|
||||
int row = blockIdx.x;
|
||||
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
|
||||
moe_dpch_bwd_selector<T, block_size, pack_size>(
|
||||
tokens_grad + (row * h),
|
||||
expert_grad + (dest1[row] * h), expert_grad + (dest2[row] * h),
|
||||
h, mask1[row], indicator2);
|
||||
}
|
||||
|
||||
// COMBINE KERNELS --------------------------------
|
||||
|
||||
template<typename T, int block_size, int pack_size>
|
||||
__device__ void moe_cb_fwd_selector(
|
||||
T *src_row1, T *src_row2, T *dst_row, const int cols,
|
||||
const T weight1, const T weight2,
|
||||
const int indicator1, const int indicator2) {
|
||||
|
||||
if (indicator1 != 0 && indicator2 != 0)
|
||||
moe_cb_two_fwd<T, block_size, pack_size>(
|
||||
src_row1, src_row2, dst_row, weight1, weight2, cols);
|
||||
else if (indicator1 != 0)
|
||||
moe_cb_one_fwd<T, block_size, pack_size>(
|
||||
src_row1, dst_row, weight1, cols);
|
||||
else if (indicator2 != 0)
|
||||
moe_cb_one_fwd<T, block_size, pack_size>(
|
||||
src_row2, dst_row, weight2, cols);
|
||||
else
|
||||
return;
|
||||
}
|
||||
|
||||
template<typename T, int block_size, int pack_size>
|
||||
__device__ void moe_cb_bwd_selector(
|
||||
T *src_row1, T *src_row2, T *dst_row, const int cols,
|
||||
T *tks_row1, T *tks_row2, T *wt_grad1, T *wt_grad2,
|
||||
const T weight1, const T weight2,
|
||||
const int indicator1, const int indicator2) {
|
||||
|
||||
if (indicator1 != 0 && indicator2 != 0)
|
||||
moe_cb_two_bwd<T, block_size, pack_size>(
|
||||
src_row1, src_row2, dst_row,
|
||||
tks_row1, tks_row2, wt_grad1, wt_grad2,
|
||||
weight1, weight2, cols);
|
||||
else if (indicator1 != 0)
|
||||
moe_cb_one_bwd<T, block_size, pack_size>(
|
||||
src_row1, dst_row, tks_row1, wt_grad1, weight1, cols);
|
||||
else if (indicator2 != 0)
|
||||
moe_cb_one_bwd<T, block_size, pack_size>(
|
||||
src_row2, dst_row, tks_row2, wt_grad2, weight2, cols);
|
||||
else
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
template<typename T, int block_size, int pack_size>
|
||||
__global__ void moe_cb_fwd_kernel(
|
||||
T *expert_tokens, T *combine_tokens, T *logits,
|
||||
int *mask1, int *mask2,
|
||||
int *dest1, int *dest2,
|
||||
const int e, const int c, const int h) {
|
||||
|
||||
int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
|
||||
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
|
||||
T *row_log = logits + (row * e);
|
||||
moe_cb_fwd_selector<T, block_size, pack_size>(
|
||||
expert_tokens + (dest1[row] * h), expert_tokens + (dest2[row] * h),
|
||||
combine_tokens + (row * h), h,
|
||||
row_log[eid1], row_log[eid2],
|
||||
mask1[row], indicator2);
|
||||
}
|
||||
|
||||
template<typename T, int block_size, int pack_size>
|
||||
__global__ void moe_cb_bwd_kernel(
|
||||
T *tokens_grad, T *expert_grad, T *tks,
|
||||
T *logits, T *logits_grad,
|
||||
int *mask1, int *mask2,
|
||||
int *dest1, int *dest2,
|
||||
const int e, const int c, const int h) {
|
||||
|
||||
int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
|
||||
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
|
||||
T *row_log = logits + (row * e), *row_grad = logits_grad + (row * e);
|
||||
moe_cb_bwd_selector<T, block_size, pack_size>(
|
||||
expert_grad + (dest1[row] * h), expert_grad + (dest2[row] * h),
|
||||
tokens_grad + (row * h), h,
|
||||
tks + (dest1[row] * h), tks + (dest2[row] * h),
|
||||
row_grad + eid1, row_grad + eid2,
|
||||
row_log[eid1], row_log[eid2],
|
||||
mask1[row], indicator2);
|
||||
}
|
||||
|
||||
//CUMSUM KERNEL --------------------------------
|
||||
|
||||
template<int block_size, int pack_size>
|
||||
__global__ void cumsum_kernel(
|
||||
int *inputs, int *outputs,
|
||||
const int s, const int e) {
|
||||
|
||||
assert(s % pack_size == 0);
|
||||
constexpr int bpack_size = block_size * pack_size;
|
||||
int tid = threadIdx.x, bid = blockIdx.x,
|
||||
tps = tid * pack_size, last_sum = -1;
|
||||
__shared__ int temp[block_size + 1]; int pack[pack_size];
|
||||
|
||||
for (int idx = 0; idx < s; idx += bpack_size) {
|
||||
int offset = 1;
|
||||
|
||||
if (idx + tps < s) {
|
||||
temp[tid] = inputs[tps * e + bid];
|
||||
#pragma unroll
|
||||
for (int i = 1; i < pack_size; ++i) {
|
||||
pack[i] = inputs[(tps + i) * e + bid];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 1; i < pack_size; ++i) {
|
||||
temp[tid] += pack[i];
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = block_size >> 1; i > 0; i >>= 1) {
|
||||
__syncthreads();
|
||||
if (tid < i) {
|
||||
int j = offset * (2 * tid + 1) - 1;
|
||||
temp[j + offset] += temp[j];
|
||||
}
|
||||
offset <<= 1;
|
||||
}
|
||||
|
||||
if (tid == 0) {
|
||||
temp[block_size] = temp[block_size - 1];
|
||||
temp[block_size - 1] = 0;
|
||||
}
|
||||
|
||||
for (int i = 1; i < block_size; i <<= 1) {
|
||||
offset >>= 1;
|
||||
__syncthreads();
|
||||
if (tid < i) {
|
||||
int j = offset * (2 * tid + 1) - 1,
|
||||
k = j + offset, ts = temp[j];
|
||||
temp[j] = temp[k];
|
||||
temp[k] += ts;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (tid == 0)
|
||||
temp[0] = temp[block_size];
|
||||
__syncthreads();
|
||||
|
||||
if (idx + tps < s) {
|
||||
temp[tid + 1] += last_sum;
|
||||
#pragma unroll
|
||||
for (int i = pack_size - 1; i > 0; --i) {
|
||||
outputs[(tps + i) * e + bid] = temp[tid + 1];
|
||||
temp[tid + 1] -= pack[i];
|
||||
}
|
||||
outputs[tps * e + bid] = temp[tid + 1];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
last_sum += temp[0];
|
||||
inputs += bpack_size * e;
|
||||
outputs += bpack_size * e;
|
||||
}
|
||||
}
|
||||
|
||||
//LAUNCH FUNCTIONS --------------------------------
|
||||
|
||||
template<typename T>
|
||||
void moe_dpch_fwd_launch(
|
||||
T *batch_tokens, T *expert_input,
|
||||
int *mask1, int *mask2,
|
||||
int *dest1, int *dest2,
|
||||
const int s, const int h) {
|
||||
|
||||
if (h < 256)
|
||||
moe_dpch_fwd_kernel<T, 32, 4><<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
|
||||
else if (h < 512)
|
||||
moe_dpch_fwd_kernel<T, 32, 8><<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
|
||||
else if (h < 1024)
|
||||
moe_dpch_fwd_kernel<T, 32, 16><<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
|
||||
else if (h < 2048)
|
||||
moe_dpch_fwd_kernel<T, 64, 16><<<s, 64>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
|
||||
else
|
||||
moe_dpch_fwd_kernel<T, 128, 16><<<s, 128>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void moe_dpch_bwd_launch(
|
||||
T *tokens_grad, T *expert_grad,
|
||||
int *mask1, int *mask2,
|
||||
int *dest1, int *dest2,
|
||||
const int s, const int h) {
|
||||
|
||||
if (h < 256)
|
||||
moe_dpch_bwd_kernel<T, 32, 4><<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
|
||||
else if (h < 512)
|
||||
moe_dpch_bwd_kernel<T, 32, 8><<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
|
||||
else if (h < 1024)
|
||||
moe_dpch_bwd_kernel<T, 32, 16><<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
|
||||
else if (h < 2048)
|
||||
moe_dpch_bwd_kernel<T, 64, 16><<<s, 64>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
|
||||
else
|
||||
moe_dpch_bwd_kernel<T, 128, 16><<<s, 128>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void moe_cb_fwd_launch(
|
||||
T *expert_tokens, T *combine_tokens, T *logits,
|
||||
int *mask1, int *mask2,
|
||||
int *dest1, int *dest2,
|
||||
const int s, const int e, const int c, const int h) {
|
||||
|
||||
if (h < 256)
|
||||
moe_cb_fwd_kernel<T, 32, 4><<<s, 32>>>
|
||||
(expert_tokens, combine_tokens, logits, mask1, mask2, dest1, dest2, e, c, h);
|
||||
else if (h < 512)
|
||||
moe_cb_fwd_kernel<T, 32, 8><<<s, 32>>>
|
||||
(expert_tokens, combine_tokens, logits, mask1, mask2, dest1, dest2, e, c, h);
|
||||
else if (h < 1024)
|
||||
moe_cb_fwd_kernel<T, 32, 16><<<s, 32>>>
|
||||
(expert_tokens, combine_tokens, logits, mask1, mask2, dest1, dest2, e, c, h);
|
||||
else if (h < 2048)
|
||||
moe_cb_fwd_kernel<T, 64, 16><<<s, 64>>>
|
||||
(expert_tokens, combine_tokens, logits, mask1, mask2, dest1, dest2, e, c, h);
|
||||
else
|
||||
moe_cb_fwd_kernel<T, 128, 16><<<s, 128>>>
|
||||
(expert_tokens, combine_tokens, logits, mask1, mask2, dest1, dest2, e, c, h);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void moe_cb_bwd_launch(
|
||||
T *tokens_grad, T *expert_grad, T *tks,
|
||||
T *logits, T *logits_grad,
|
||||
int *mask1, int *mask2,
|
||||
int *dest1, int *dest2,
|
||||
const int s, const int e, const int c, const int h) {
|
||||
|
||||
if (h < 256)
|
||||
moe_cb_bwd_kernel<T, 32, 4><<<s, 32>>>
|
||||
(tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, dest1, dest2, e, c, h);
|
||||
else // if (h < 512)
|
||||
moe_cb_bwd_kernel<T, 64, 4><<<s, 64>>>
|
||||
(tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, dest1, dest2, e, c, h);
|
||||
// else if (h < 1024)
|
||||
// moe_cb_bwd_kernel<T, 128, 4><<<s, 128>>>
|
||||
// (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, dest1, dest2, e, c, h);
|
||||
// else
|
||||
// moe_cb_bwd_kernel<T, 256, 4><<<s, 256>>>
|
||||
// (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, dest1, dest2, e, c, h);
|
||||
}
|
||||
|
||||
void cumsum_launch(
|
||||
int *inputs, int *outputs,
|
||||
const int s, const int e) {
|
||||
|
||||
if (s <= 256)
|
||||
cumsum_kernel<256, 1><<<e, 256>>>(inputs, outputs, s, e);
|
||||
else if (s <= 512)
|
||||
cumsum_kernel<512, 1><<<e, 512>>>(inputs, outputs, s, e);
|
||||
else if (s <= 1024)
|
||||
cumsum_kernel<1024, 1><<<e, 1024>>>(inputs, outputs, s, e);
|
||||
else if (s <= 2048)
|
||||
cumsum_kernel<1024, 2><<<e, 1024>>>(inputs, outputs, s, e);
|
||||
else
|
||||
cumsum_kernel<1024, 4><<<e, 1024>>>(inputs, outputs, s, e);
|
||||
}
|
||||
|
||||
// API FUNCTIONS --------------------------------
|
||||
|
||||
#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \
|
||||
switch (TYPE) \
|
||||
{ \
|
||||
case at::ScalarType::Float: \
|
||||
{ \
|
||||
using scalar_t = float; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::Half: \
|
||||
{ \
|
||||
using scalar_t = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented yet for specific data type.");\
|
||||
}
|
||||
|
||||
torch::Tensor moe_dispatch_cuda_forward(
|
||||
int s, int ec, int h,
|
||||
torch::Tensor batch_tokens,
|
||||
torch::Tensor mask,
|
||||
torch::Tensor dest_idx) {
|
||||
|
||||
assert(h % 16 == 0);
|
||||
auto res = torch::zeros({ec, h},
|
||||
torch::dtype(batch_tokens.dtype()).device(batch_tokens.device()));
|
||||
auto k = mask.size(0);
|
||||
|
||||
DISPATCH_FLOAT_AND_HALF(
|
||||
batch_tokens.scalar_type(), "moe dispatch forward",
|
||||
moe_dpch_fwd_launch<scalar_t>(
|
||||
batch_tokens.data<scalar_t>(), res.data<scalar_t>(),
|
||||
mask[0].data<int>(), k == 1 ? nullptr : mask[1].data<int>(),
|
||||
dest_idx[0].data<int>(), k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(),
|
||||
s, h)
|
||||
);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
torch::Tensor moe_dispatch_cuda_backward(
|
||||
int s, int ec, int h,
|
||||
torch::Tensor expert_grad,
|
||||
torch::Tensor mask,
|
||||
torch::Tensor dest_idx) {
|
||||
|
||||
assert(h % 16 == 0);
|
||||
auto res = torch::zeros({s, h},
|
||||
torch::dtype(expert_grad.dtype()).device(expert_grad.device()));
|
||||
auto k = mask.size(0);
|
||||
|
||||
DISPATCH_FLOAT_AND_HALF(
|
||||
expert_grad.scalar_type(), "moe dispatch backward",
|
||||
moe_dpch_bwd_launch<scalar_t>(
|
||||
res.data<scalar_t>(), expert_grad.data<scalar_t>(),
|
||||
mask[0].data<int>(), k == 1 ? nullptr : mask[1].data<int>(),
|
||||
dest_idx[0].data<int>(), k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(),
|
||||
s, h)
|
||||
);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
torch::Tensor moe_combine_cuda_forward(
|
||||
int s, int e, int c, int h,
|
||||
torch::Tensor expert_tokens,
|
||||
torch::Tensor logits,
|
||||
torch::Tensor mask,
|
||||
torch::Tensor dest_idx) {
|
||||
|
||||
assert(h % 16 == 0);
|
||||
assert(expert_tokens.dtype() == logits.dtype());
|
||||
|
||||
auto res = torch::zeros({s, h},
|
||||
torch::dtype(expert_tokens.dtype()).device(expert_tokens.device()));
|
||||
auto k = mask.size(0);
|
||||
|
||||
DISPATCH_FLOAT_AND_HALF(
|
||||
expert_tokens.scalar_type(), "moe combine forward",
|
||||
moe_cb_fwd_launch<scalar_t>(
|
||||
expert_tokens.data<scalar_t>(), res.data<scalar_t>(), logits.data<scalar_t>(),
|
||||
mask[0].data<int>(), k == 1 ? nullptr : mask[1].data<int>(),
|
||||
dest_idx[0].data<int>(), k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(),
|
||||
s, e, c, h)
|
||||
);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> moe_combine_cuda_backward(
|
||||
int s, int e, int c, int h,
|
||||
torch::Tensor tokens_grad,
|
||||
torch::Tensor expert_tokens,
|
||||
torch::Tensor logits,
|
||||
torch::Tensor mask,
|
||||
torch::Tensor dest_idx) {
|
||||
|
||||
assert(h % 16 == 0);
|
||||
assert(tokens_grad.dtype() == expert_tokens.dtype());
|
||||
assert(expert_tokens.dtype() == logits.dtype());
|
||||
|
||||
auto egrad = torch::zeros({e * c, h},
|
||||
torch::dtype(tokens_grad.dtype()).device(tokens_grad.device())),
|
||||
wgrad = torch::zeros({s, e}, torch::dtype(logits.dtype()).device(logits.device()));
|
||||
auto k = mask.size(0);
|
||||
|
||||
DISPATCH_FLOAT_AND_HALF(
|
||||
tokens_grad.scalar_type(), "moe combine backward",
|
||||
moe_cb_bwd_launch<scalar_t>(
|
||||
tokens_grad.data<scalar_t>(), egrad.data<scalar_t>(), expert_tokens.data<scalar_t>(),
|
||||
logits.data<scalar_t>(), wgrad.data<scalar_t>(),
|
||||
mask[0].data<int>(), k == 1 ? nullptr : mask[1].data<int>(),
|
||||
dest_idx[0].data<int>(), k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(),
|
||||
s, e, c, h)
|
||||
);
|
||||
|
||||
return {egrad, wgrad};
|
||||
}
|
||||
|
||||
torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) {
|
||||
|
||||
assert(mask.dim() == 2);
|
||||
assert(mask.dtype() == torch::kInt32);
|
||||
|
||||
const int s = mask.size(0), e = mask.size(1);
|
||||
auto res = torch::empty({s, e}, torch::dtype(torch::kInt32).device(mask.device()));
|
||||
cumsum_launch(mask.data<int>(), res.data<int>(), s, e);
|
||||
|
||||
return res;
|
||||
}
|
|
@@ -1,8 +1,5 @@
from ._operation import AllToAll
from .layers import Experts, MoeLayer, \
    NormalNoiseGenerator, Top1Router, Top2Router
from .experts import Experts, FFNExperts
from .layers import MoeLayer, Top1Router, Top2Router
from .utils import NormalNoiseGenerator

__all__ = [
    'AllToAll', 'Experts', 'Top1Router', 'Top2Router',
    'MoeLayer', 'NormalNoiseGenerator'
]
__all__ = ['Experts', 'FFNExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator']
@@ -6,16 +6,26 @@ from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from typing import Any, Tuple

U_CUDA_MODE = False
try:
    import colossal_moe_cuda

    U_CUDA_MODE = True
except ImportError:
    print("If you want to activate cuda mode for MoE, please install with cuda_ext!")


class AllToAll(torch.autograd.Function):
    """Dispatches input tensor [e, c, h] to all experts by all_to_all_single
    operation in torch.distributed.
    """

    @staticmethod
    def forward(ctx: Any,
                inputs: Tensor,
                parallel_mode: ParallelMode) -> Tensor:
        ctx.parallel_mode = parallel_mode
        if ctx is not None:
            ctx.parallel_mode = parallel_mode
        if not inputs.is_contiguous():
            inputs = inputs.contiguous()

@@ -26,4 +36,79 @@ class AllToAll(torch.autograd.Function):

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
        return AllToAll.apply(*grad_outputs, ctx.parallel_mode), None
        return AllToAll.forward(None, *grad_outputs, ctx.parallel_mode), None


class MoeDispatch(torch.autograd.Function):

    @staticmethod
    def forward(ctx, tokens, mask, dest_idx, ec):
        s = tokens.size(0)
        h = tokens.size(1)

        expert_input = colossal_moe_cuda.dispatch_forward(s, ec, h, tokens, mask, dest_idx)

        ctx.save_for_backward(mask, dest_idx)
        ctx.s = s
        ctx.h = h
        ctx.ec = ec

        return expert_input

    @staticmethod
    def backward(ctx, output_grad):
        mask, dest_idx = ctx.saved_tensors
        d_tokens = colossal_moe_cuda.dispatch_backward(
            ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
        return d_tokens, None, None, None


class MoeCombine(torch.autograd.Function):

    @staticmethod
    def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
        assert logits.dtype == torch.float32

        s = logits.size(0)
        e = logits.size(1)
        c = ec // e
        h = expert_tokens.size(-1)

        fp16_flag = (expert_tokens.dtype == torch.float16)
        cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens
        ctokens = colossal_moe_cuda.combine_forward(s, e, c, h,
                                                    cb_input, logits,
                                                    mask, dest_idx)
        output = ctokens.to(torch.float16) if fp16_flag else ctokens

        ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
        ctx.s = s
        ctx.e = e
        ctx.c = c
        ctx.h = h
        ctx.fp16_flag = fp16_flag

        return output

    @staticmethod
    def backward(ctx, tokens_grad):
        expert_tokens, logits, mask, dest_idx = ctx.saved_tensors

        cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 \
            else tokens_grad
        cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens
        d_expert, d_logits = colossal_moe_cuda.combine_backward(
            ctx.s, ctx.e, ctx.c, ctx.h,
            cb_grad, cb_input, logits, mask, dest_idx)
        d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert

        return d_expert, d_logits, None, None, None


def moe_cumsum(inputs: Tensor):
    dim0 = inputs.size(0)
    flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
    if flag and U_CUDA_MODE:
        return colossal_moe_cuda.cumsum_sub_one(inputs)
    else:
        return torch.cumsum(inputs, dim=0) - 1
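A rough round-trip sketch of how MoeDispatch and MoeCombine above fit together. The sizes, the all-ones mask and the hand-written dest_idx are assumptions made purely for illustration, and the colossal_moe_cuda extension must be available:

import torch
# MoeDispatch / MoeCombine as defined above (colossalai.nn.layer.moe._operation)

s, e, c, h = 4, 2, 2, 16                                    # tokens, experts, capacity, hidden size
tokens   = torch.randn(s, h, device='cuda', requires_grad=True)
logits   = torch.rand(s, e, device='cuda')                  # routing weights, fp32 as asserted above
mask     = torch.ones(1, s, dtype=torch.int32, device='cuda')          # every token is routed
dest_idx = (torch.arange(s, dtype=torch.int32, device='cuda') % (e * c)).unsqueeze(0)

expert_in = MoeDispatch.apply(tokens, mask, dest_idx, e * c)           # [e * c, h] expert buffer
out = MoeCombine.apply(expert_in, logits, mask, dest_idx, e * c)       # [s, h] combined tokens
out.sum().backward()                                        # gradients flow back into `tokens`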
@@ -0,0 +1,96 @@
import math

import torch
import torch.nn as nn
from colossalai.global_variables import moe_env
from colossalai.context import ParallelMode, seed
from colossalai.utils import get_current_device


class Experts(nn.Module):
    """A wrapper class to create experts. It will create E experts across the
    moe model parallel group, where E is the number of experts. Every expert
    is an instance of the class `expert` given in the initialization parameters.

    :param expert: The class of all experts
    :param num_experts: The number of experts
    :param expert_args: Args used to initialize experts

    :type num_experts: int
    """

    def __init__(self, expert, num_experts, **expert_args):
        super().__init__()

        assert num_experts % moe_env.model_parallel_size == 0, \
            "The number of experts should be divisible by the moe model parallel size"

        num_local_experts = num_experts // moe_env.model_parallel_size
        with seed(ParallelMode.MOE_MODEL):
            self.experts = nn.ModuleList([expert(**expert_args) for _ in range(num_local_experts)])
        self.num_local_experts = num_local_experts
        for exp in self.experts:
            for param in exp.parameters():
                param.__setattr__('moe_param', True)

    def forward(self, inputs):
        expert_input = torch.chunk(inputs, self.num_local_experts, dim=1)
        expert_output = []

        for i in range(self.num_local_experts):
            expert_output.append(self.experts[i](expert_input[i]))

        output = torch.cat(expert_output, dim=1).contiguous()
        return output


class FFNExperts(nn.Module):

    def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
        super().__init__()

        assert num_experts % moe_env.model_parallel_size == 0, \
            "The number of experts should be divisible by the moe model parallel size"

        num_local_experts = num_experts // moe_env.model_parallel_size

        self.w1 = nn.Parameter(torch.empty(num_local_experts, d_model, d_ff, device=get_current_device()))
        self.b1 = nn.Parameter(torch.empty(num_local_experts, 1, d_ff, device=get_current_device()))

        self.w2 = nn.Parameter(torch.empty(num_local_experts, d_ff, d_model, device=get_current_device()))
        self.b2 = nn.Parameter(torch.empty(num_local_experts, 1, d_model, device=get_current_device()))

        s1 = math.sqrt(0.1 / d_model)
        s2 = math.sqrt(0.1 / d_ff)
        nn.init.trunc_normal_(self.w1, std=s1)
        nn.init.trunc_normal_(self.b1, std=s1)
        nn.init.trunc_normal_(self.w2, std=s2)
        nn.init.trunc_normal_(self.b2, std=s2)

        self.act = nn.GELU() if activation is None else activation
        self.drop = nn.Dropout(p=drop_rate)

        for param in self.parameters():
            param.__setattr__('moe_param', True)

    def forward(self, inputs):    # x [g, el, c, h]

        el = inputs.size(1)
        h = inputs.size(-1)

        inputs = inputs.transpose(0, 1)
        inshape = inputs.shape
        inputs = inputs.reshape(el, -1, h)

        out_ff = torch.baddbmm(self.b1, inputs, self.w1)
        out_act = self.act(out_ff)
        with seed(ParallelMode.TENSOR):
            inter = self.drop(out_act)

        out_model = torch.baddbmm(self.b2, inter, self.w2)
        with seed(ParallelMode.TENSOR):
            outputs = self.drop(out_model)    # outputs [el, gc, h]

        outputs = outputs.reshape(inshape)
        outputs = outputs.transpose(0, 1).contiguous()
        return outputs
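A shape-only sketch of how FFNExperts consumes a dispatched tensor. The sizes are illustrative assumptions, and it presupposes that colossalai has been launched so that moe_env.model_parallel_size is set (taken to be 2 here):

import torch
from colossalai.nn.layer.moe import FFNExperts

ffn = FFNExperts(num_experts=4, d_model=8, d_ff=32, drop_rate=0.1)   # 2 local experts per rank
x = torch.randn(2, 2, 3, 8, device=ffn.w1.device)                    # [g, el, c, h]
y = ffn(x)                                                           # two batched baddbmm calls
assert y.shape == x.shape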
@ -3,70 +3,13 @@ import math
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.cuda.amp import autocast
|
||||
import torch.distributed as dist
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.global_variables import moe_env
|
||||
from colossalai.context import ParallelMode, seed
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.utils import get_current_device
|
||||
from ._operation import AllToAll
|
||||
|
||||
|
||||
class NormalNoiseGenerator:
|
||||
"""Generates a random noisy mask for logtis tensor.
|
||||
|
||||
All noise is generated from a normal distribution (0, 1 / E^2), where
|
||||
E = the number of experts.
|
||||
|
||||
:param num_experts: The number of experts
|
||||
:type num_experts: int
|
||||
"""
|
||||
|
||||
def __init__(self, num_experts: int):
|
||||
self.normal = torch.distributions.normal.Normal(
|
||||
loc=torch.tensor(0.0, device=get_current_device()),
|
||||
scale=torch.tensor(1.0 / num_experts ** 2, device=get_current_device())
|
||||
).rsample
|
||||
|
||||
def __call__(self, inputs: torch.Tensor):
|
||||
noisy = self.normal(inputs.shape)
|
||||
return inputs + noisy
|
||||
|
||||
|
||||
class Experts(nn.Module):
|
||||
"""A wrapper class to create experts. It will create E experts across the
|
||||
moe model parallel group, where E is the number of experts. Every expert
|
||||
is a instence of the class, 'expert' in initialization parameters.
|
||||
|
||||
:param expert: The class of all experts
|
||||
:param num_experts: The number of experts
|
||||
:param expert_args: Args used to initialize experts
|
||||
|
||||
:type num_experts: int
|
||||
"""
|
||||
|
||||
def __init__(self, expert, num_experts, **expert_args):
|
||||
super().__init__()
|
||||
|
||||
assert num_experts % moe_env.model_parallel_size == 0, \
|
||||
"The number of experts should be divied by moe model size"
|
||||
|
||||
num_local_experts = num_experts // moe_env.model_parallel_size
|
||||
with seed(ParallelMode.MOE_MODEL):
|
||||
self.experts = nn.ModuleList([
|
||||
expert(**expert_args) for _ in range(num_local_experts)])
|
||||
self.num_local_experts = num_local_experts
|
||||
for exp in self.experts:
|
||||
for param in exp.parameters():
|
||||
param.__setattr__('moe_param', 1)
|
||||
|
||||
def forward(self, inputs):
|
||||
expert_input = torch.chunk(inputs, self.num_local_experts, dim=0)
|
||||
expert_output = []
|
||||
|
||||
for i in range(self.num_local_experts):
|
||||
expert_output.append(self.experts[i](expert_input[i]))
|
||||
|
||||
output = torch.cat(expert_output, dim=0)
|
||||
return output
|
||||
from ._operation import U_CUDA_MODE, AllToAll, MoeDispatch, MoeCombine, moe_cumsum
|
||||
from .utils import autocast_softmax
|
||||
|
||||
|
||||
class Top1Router(nn.Module):
|
||||
|
@ -83,63 +26,79 @@ class Top1Router(nn.Module):
|
|||
:type noisy_func: Callable, optional
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
capacity_factor: float,
|
||||
min_capacity: int,
|
||||
noisy_func=None):
|
||||
def __init__(self, capacity_factor: float, min_capacity: int = 0, select_policy: str = "first", noisy_func=None):
|
||||
super().__init__()
|
||||
self.capacity_factor = capacity_factor
|
||||
self.min_capacity = min_capacity
|
||||
self.select_policy = select_policy
|
||||
self.noisy_func = noisy_func
|
||||
self.uniform = torch.distributions.uniform.Uniform(
|
||||
low=torch.tensor(0.0, device=get_current_device()),
|
||||
high=torch.tensor(1.0, device=get_current_device())).rsample
|
||||
|
||||
def get_capacity(self, logits_shape):
|
||||
capacity = math.ceil(self.capacity_factor *
|
||||
logits_shape[0] / logits_shape[1])
|
||||
if capacity < self.min_capacity:
|
||||
capacity = self.min_capacity
|
||||
assert select_policy in {"first", "random"}
|
||||
if select_policy == "random":
|
||||
self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=get_current_device()),
|
||||
high=torch.tensor(1.0,
|
||||
device=get_current_device())).rsample
|
||||
|
||||
def get_capacity(
|
||||
self,
|
||||
logits_shape,
|
||||
):
|
||||
capacity = math.floor(self.capacity_factor * logits_shape[-2] / logits_shape[-1])
|
||||
capacity += capacity % 2
|
||||
capacity = max(capacity, self.min_capacity)
|
||||
assert capacity > 0
|
||||
return capacity
|
||||
|
||||
def forward(self, inputs):
|
||||
def forward(self, inputs: torch.Tensor, cuda_mode: bool = False):
|
||||
|
||||
if self.noisy_func is not None:
|
||||
inputs_noisy = self.noisy_func(inputs)
|
||||
else:
|
||||
inputs_noisy = inputs
|
||||
|
||||
logits = F.softmax(inputs, dim=1)
|
||||
|
||||
num_experts = logits.shape[1]
|
||||
logits = autocast_softmax(inputs, dim=-1)
|
||||
num_experts = logits.size(-1)
|
||||
capacity = self.get_capacity(logits.shape)
|
||||
|
||||
expert_idx = torch.argmax(inputs_noisy, dim=1)
|
||||
expert_mask = F.one_hot(expert_idx, num_classes=num_experts)
|
||||
expert_mask_f = expert_mask.float()
|
||||
top1_idx = torch.argmax(inputs_noisy, dim=-1)
|
||||
mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
|
||||
|
||||
exp_counts = torch.sum(expert_mask, dim=0).detach().to('cpu')
|
||||
if self.training:
|
||||
me = torch.mean(logits, dim=0)
|
||||
ce = torch.mean(mask.float(), dim=0)
|
||||
l_aux = num_experts * torch.sum(me * ce)
|
||||
moe_env.add_loss(l_aux)
|
||||
else:
|
||||
max_num = torch.max(torch.sum(mask, dim=0))
|
||||
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MOE_MODEL))
|
||||
capacity = max_num.item()
|
||||
|
||||
me = torch.mean(logits, dim=0)
|
||||
ce = torch.mean(expert_mask_f, dim=0)
|
||||
l_aux = torch.sum(me * ce) * num_experts
|
||||
moe_env.add_loss(l_aux)
|
||||
if not self.training:
|
||||
ranks = moe_cumsum(mask)
|
||||
elif self.select_policy == "random":
|
||||
rand_mask = mask * self.uniform(mask.shape)
|
||||
_, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
|
||||
mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
|
||||
ranks = moe_cumsum(mask)
|
||||
elif self.select_policy == "first":
|
||||
ranks = moe_cumsum(mask)
|
||||
mask = mask * torch.lt(ranks, capacity)
|
||||
else:
|
||||
raise NotImplementedError("Not support such select policy yet.")
|
||||
|
||||
rand_mask = expert_mask * self.uniform(logits.shape)
|
||||
_, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
|
||||
ranks = torch.sum(mask * ranks, dim=-1)
|
||||
|
||||
dispatch_mask = \
|
||||
expert_mask * torch.zeros_like(expert_mask).scatter_(0, dispatch_idx, 1)
|
||||
|
||||
locations = torch.cumsum(dispatch_mask, dim=0) - 1
|
||||
locations = torch.sum(dispatch_mask * locations, dim=1)
|
||||
locations = F.one_hot(locations, num_classes=capacity)
|
||||
|
||||
logits = logits * dispatch_mask
|
||||
combine_weights = logits.unsqueeze(2) * locations.unsqueeze(1)
|
||||
|
||||
sec_mask = combine_weights.bool()
|
||||
return combine_weights, sec_mask, exp_counts
|
||||
if cuda_mode:
|
||||
mask = torch.sum(mask, dim=-1)
|
||||
mask = torch.stack([mask], dim=0).to(torch.int32)
|
||||
dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
|
||||
return logits, mask, dest_idx, num_experts * capacity
|
||||
else:
|
||||
ranks = F.one_hot(ranks, num_classes=capacity)
|
||||
weight = mask * logits.type_as(inputs)
|
||||
combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
|
||||
sec_mask = combine_weights.bool()
|
||||
return combine_weights, sec_mask
|
||||
|
||||
|
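A small worked example of the dest_idx encoding produced by the cuda_mode branch of Top1Router above (the concrete values are invented for illustration):

# 4 tokens, 2 experts, capacity 2
top1_idx = torch.tensor([0, 1, 0, 1])    # expert chosen for each token
ranks    = torch.tensor([0, 0, 1, 1])    # running slot of the token inside that expert (moe_cumsum)
dest_idx = top1_idx * 2 + ranks          # -> [0, 2, 1, 3]
# Row dest_idx[i] of the dispatched [num_experts * capacity, h] buffer receives token i,
# so each expert owns a contiguous block of `capacity` rows.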
||||
class Top2Router(nn.Module):
|
||||
|
@ -159,53 +118,67 @@ class Top2Router(nn.Module):
|
|||
self.noisy_func = noisy_func
|
||||
|
||||
def get_capacity(self, logits_shape):
|
||||
capacity = math.ceil(2 * self.capacity_factor *
|
||||
logits_shape[0] / logits_shape[1])
|
||||
capacity = math.floor(2 * self.capacity_factor * logits_shape[-2] / logits_shape[-1])
|
||||
capacity += capacity % 2
|
||||
assert capacity > 0
|
||||
return capacity
|
||||
|
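For concreteness, a worked instance of the capacity rule above (numbers are illustrative):

# logits shape [s, e] = [1024, 4], capacity_factor = 1.1
# capacity = floor(2 * 1.1 * 1024 / 4) = floor(563.2) = 563
# 563 is odd, so `capacity += capacity % 2` rounds it up to 564 slots per expert.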
||||
def forward(self, inputs):
|
||||
def forward(self, inputs: torch.Tensor, cuda_mode: bool = False):
|
||||
# inputs: [s, h]
|
||||
if self.noisy_func is not None:
|
||||
inputs = self.noisy_func(inputs)
|
||||
|
||||
logits = F.softmax(inputs, dim=-1)
|
||||
logits = autocast_softmax(inputs, dim=-1) # logits: [s, e]
|
||||
num_experts = logits.size(-1)
|
||||
capacity = self.get_capacity(logits.shape)
|
||||
|
||||
_, expert_idx = torch.topk(logits, k=2, dim=-1, largest=True, sorted=True)
|
||||
top1_idx = expert_idx[:, 0]
|
||||
top2_idx = expert_idx[:, 1]
|
||||
top1_idx = torch.argmax(logits, dim=-1)
|
||||
mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
|
||||
logits_except1 = logits.masked_fill(mask1.bool(), float("-inf"))
|
||||
top2_idx = torch.argmax(logits_except1, dim=-1)
|
||||
mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)
|
||||
|
||||
mask1 = F.one_hot(top1_idx, num_classes=num_experts)
|
||||
mask2 = F.one_hot(top2_idx, num_classes=num_experts)
|
||||
cmask = (mask1 + mask2) # loss: [s, e]
|
||||
if self.training:
|
||||
me = torch.mean(logits, dim=0)
|
||||
ce = torch.mean(cmask.float(), dim=0)
|
||||
l_aux = num_experts * torch.sum(me * ce) / 2.0
|
||||
moe_env.add_loss(l_aux)
|
||||
else:
|
||||
max_num = torch.max(torch.sum(cmask, dim=0))
|
||||
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MOE_MODEL))
|
||||
capacity = max_num.item()
|
||||
|
||||
loss_mask = (mask1 + mask2)
|
||||
exp_counts = torch.sum(loss_mask, dim=0).detach().to('cpu')
|
||||
me = torch.mean(logits, dim=0)
|
||||
ce = torch.mean(loss_mask.float(), dim=0)
|
||||
l_aux = num_experts * torch.sum(me * ce) / 2.0
|
||||
moe_env.add_loss(l_aux)
|
||||
rank1 = moe_cumsum(mask1) # rank1: [s, e]
|
||||
rank2 = moe_cumsum(mask2)
|
||||
rank2 += torch.sum(mask1, dim=-2, keepdim=True)
|
||||
|
||||
locations1 = torch.cumsum(mask1, dim=0) - 1
|
||||
locations2 = torch.cumsum(mask2, dim=0) - 1
|
||||
locations2 += torch.sum(mask1, dim=0, keepdim=True)
|
||||
mask1 *= torch.lt(rank1, capacity)
|
||||
mask2 *= torch.lt(rank2, capacity)
|
||||
|
||||
mask1 *= torch.lt(locations1, capacity)
|
||||
mask2 *= torch.lt(locations2, capacity)
|
||||
rank1 = torch.sum(mask1 * rank1, dim=-1)
|
||||
rank2 = torch.sum(mask2 * rank2, dim=-1)
|
||||
|
||||
weight1 = mask1 * logits
|
||||
weight2 = mask2 * logits
|
||||
if cuda_mode:
|
||||
mask1 = torch.sum(mask1, dim=-1)
|
||||
mask2 = torch.sum(mask2, dim=-1)
|
||||
|
||||
locations1 = torch.sum(mask1 * locations1, dim=1)
|
||||
locations2 = torch.sum(mask2 * locations2, dim=1)
|
||||
locations1_sc = F.one_hot(locations1, num_classes=capacity)
|
||||
locations2_sc = F.one_hot(locations2, num_classes=capacity)
|
||||
mask = torch.stack([mask1, mask2], dim=0).to(torch.int32)
|
||||
dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32)
|
||||
|
||||
combine_weights1 = weight1.unsqueeze(2) * locations1_sc.unsqueeze(1)
|
||||
combine_weights2 = weight2.unsqueeze(2) * locations2_sc.unsqueeze(1)
|
||||
combine_weights = combine_weights1 + combine_weights2
|
||||
sec_mask = combine_weights.bool()
|
||||
return logits, mask, dest_idx, num_experts * capacity
|
||||
else:
|
||||
weight1 = mask1 * logits.type_as(inputs)
|
||||
weight2 = mask2 * logits.type_as(inputs)
|
||||
rank1_sc = F.one_hot(rank1, num_classes=capacity)
|
||||
rank2_sc = F.one_hot(rank2, num_classes=capacity)
|
||||
|
||||
return combine_weights, sec_mask, exp_counts
|
||||
cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1)
|
||||
cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1)
|
||||
cb_weight = cb_weight1 + cb_weight2
|
||||
sec_mask = cb_weight.bool()
|
||||
|
||||
return cb_weight, sec_mask
|
||||
|
||||
|
||||
class MoeLayer(nn.Module):
|
||||
|
@ -225,52 +198,47 @@ class MoeLayer(nn.Module):
|
|||
:type experts: nn.Module
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dim_model: int,
|
||||
num_experts: int,
|
||||
router: nn.Module,
|
||||
experts: nn.Module):
|
||||
def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: nn.Module):
|
||||
super().__init__()
|
||||
self.d_model = dim_model
|
||||
self.num_experts = num_experts
|
||||
self.gate = nn.Linear(dim_model, num_experts, device=get_current_device())
|
||||
self.gate = nn.Linear(dim_model, num_experts, bias=False, device=get_current_device())
|
||||
self.router = router
|
||||
self.experts = experts
|
||||
self.cuda_mode = True if U_CUDA_MODE and moe_env.enable_cuda else False
|
||||
|
||||
def _router_part(self, tokens: torch.Tensor):
|
||||
gate_output = self.gate(tokens)
|
||||
return self.router(gate_output)
|
||||
def expert_part(self, expert_input: torch.Tensor):
|
||||
expert_input = AllToAll.apply(expert_input, ParallelMode.MOE_MODEL)
|
||||
|
||||
def router_part(self, tokens: torch.Tensor):
|
||||
autocast_context = torch.is_autocast_enabled()
|
||||
if not autocast_context:
|
||||
return self._router_part(tokens)
|
||||
else:
|
||||
with autocast(enabled=False):
|
||||
if tokens.dtype == torch.float16:
|
||||
input_tokens = tokens.float()
|
||||
else:
|
||||
input_tokens = tokens
|
||||
return self._router_part(input_tokens)
|
||||
input_shape = expert_input.shape
|
||||
|
||||
expert_input = expert_input.reshape(moe_env.model_parallel_size,
|
||||
self.num_experts // moe_env.model_parallel_size, -1, self.d_model)
|
||||
|
||||
expert_output = self.experts(expert_input)
|
||||
expert_output = expert_output.reshape(input_shape)
|
||||
|
||||
expert_output = AllToAll.apply(expert_output, ParallelMode.MOE_MODEL)
|
||||
return expert_output
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
tokens = inputs.reshape(-1, self.d_model)
|
||||
gate_output = self.gate(tokens)
|
||||
router_res = self.router(gate_output, self.cuda_mode)
|
||||
|
||||
combine_weights, sec_mask, exp_counts = self.router_part(tokens)
|
||||
if self.cuda_mode:
|
||||
logits, mask, dest_idx, ec = router_res
|
||||
expert_input = MoeDispatch.apply(tokens, mask, dest_idx, ec)
|
||||
expert_output = self.expert_part(expert_input)
|
||||
ret = MoeCombine.apply(expert_output, logits, mask, dest_idx, ec)
|
||||
else:
|
||||
combine_weights, sec_mask = router_res
|
||||
sec_mask_f = sec_mask.type_as(inputs)
|
||||
expert_input = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
|
||||
expert_output = self.expert_part(expert_input)
|
||||
combine_weights = combine_weights.view(combine_weights.shape[0], -1)
|
||||
expert_output = expert_output.view(-1, expert_output.shape[-1])
|
||||
ret = torch.matmul(combine_weights, expert_output)
|
||||
|
||||
sec_mask_f = sec_mask.type_as(inputs)
|
||||
dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
|
||||
|
||||
dispatch_data = AllToAll.apply(dispatch_data, ParallelMode.MOE_MODEL)
|
||||
|
||||
expert_output = self.experts(dispatch_data)
|
||||
|
||||
expert_output = AllToAll.apply(expert_output, ParallelMode.MOE_MODEL)
|
||||
|
||||
combine_weights = combine_weights.view(combine_weights.shape[0], -1)
|
||||
expert_output = expert_output.view(-1, expert_output.shape[-1])
|
||||
|
||||
ret = torch.matmul(combine_weights, expert_output)
|
||||
ret = ret.reshape(inputs.shape)
|
||||
|
||||
return ret
|
||||
|
|
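A note on the cuda_mode flag used above: it is derived from the global MoE environment extended earlier in this commit, so the fused-kernel path can be switched off for debugging (sketch):

from colossalai.global_variables import moe_env

moe_env.set_cuda_false()
# MoeLayer instances created after this call take the dense matmul dispatch/combine path,
# even when the colossal_moe_cuda extension is installed.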
|
@@ -0,0 +1,32 @@
import torch
import torch.nn.functional as F
from colossalai.utils import get_current_device


class NormalNoiseGenerator:
    """Generates a random noisy mask for logits tensor.

    All noise is generated from a normal distribution (0, 1 / E^2), where
    E = the number of experts.

    :param num_experts: The number of experts
    :type num_experts: int
    """

    def __init__(self, num_experts: int):
        self.normal = torch.distributions.normal.Normal(
            loc=torch.tensor(0.0, device=get_current_device()),
            scale=torch.tensor(1.0 / num_experts ** 2, device=get_current_device())
        ).rsample

    def __call__(self, inputs: torch.Tensor):
        noisy = self.normal(inputs.shape)
        return inputs + noisy


def autocast_softmax(inputs: torch.Tensor, dim: int):
    assert inputs.dtype in {torch.float16, torch.float32}
    fp16_flag = (inputs.dtype == torch.float16)
    sm_input = inputs.to(torch.float32) if fp16_flag else inputs
    sm_output = F.softmax(sm_input, dim)
    return sm_output
@ -4,7 +4,7 @@ import torch.nn as nn
|
|||
from colossalai.context import ParallelMode
|
||||
from colossalai.nn.layer import VanillaPatchEmbedding, VanillaClassifier, \
|
||||
WrappedDropout as Dropout, WrappedDropPath as DropPath
|
||||
from colossalai.nn.layer.moe import Experts, MoeLayer, Top2Router, NormalNoiseGenerator
|
||||
from colossalai.nn.layer.moe import FFNExperts, MoeLayer, Top2Router, NormalNoiseGenerator
|
||||
from .util import moe_sa_args, moe_mlp_args
|
||||
from ..helper import TransformerLayer
|
||||
from colossalai.global_variables import moe_env
|
||||
|
@ -81,6 +81,7 @@ class VanillaFFN(nn.Module):
|
|||
|
||||
|
||||
class Widenet(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
num_experts: int,
|
||||
capacity_factor: float,
|
||||
|
@ -98,43 +99,33 @@ class Widenet(nn.Module):
                 drop_path: float = 0.):
        super().__init__()

        embedding = VanillaPatchEmbedding(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_size=d_model)
        embedding = VanillaPatchEmbedding(img_size=img_size,
                                          patch_size=patch_size,
                                          in_chans=in_chans,
                                          embed_size=d_model)
        embed_dropout = Dropout(p=drop_rate, mode=ParallelMode.TENSOR)

        shared_sa = VanillaSelfAttention(**moe_sa_args(
            d_model=d_model, n_heads=num_heads, d_kv=d_kv,
            attention_drop=attention_drop, drop_rate=drop_rate))
            d_model=d_model, n_heads=num_heads, d_kv=d_kv, attention_drop=attention_drop, drop_rate=drop_rate))

        noisy_func = NormalNoiseGenerator(num_experts)
        shared_router = Top2Router(capacity_factor, noisy_func=noisy_func)
        shared_experts = Experts(expert=VanillaFFN,
                                 num_experts=num_experts,
                                 **moe_mlp_args(
                                     d_model=d_model,
                                     d_ff=d_ff,
                                     drop_rate=drop_rate
                                 ))
        shared_experts = FFNExperts(num_experts, d_model, d_ff, drop_rate=drop_rate)

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
        blocks = [
            TransformerLayer(
                att=shared_sa,
                ffn=MoeLayer(dim_model=d_model, num_experts=num_experts,
                             router=shared_router, experts=shared_experts),
                norm1=nn.LayerNorm(d_model, eps=1e-6),
                norm2=nn.LayerNorm(d_model, eps=1e-6),
                droppath=DropPath(p=dpr[i], mode=ParallelMode.TENSOR)
            )
            for i in range(depth)
            TransformerLayer(att=shared_sa,
                             ffn=MoeLayer(dim_model=d_model,
                                          num_experts=num_experts,
                                          router=shared_router,
                                          experts=shared_experts),
                             norm1=nn.LayerNorm(d_model, eps=1e-6),
                             norm2=nn.LayerNorm(d_model, eps=1e-6),
                             droppath=DropPath(p=dpr[i], mode=ParallelMode.TENSOR)) for i in range(depth)
        ]
        norm = nn.LayerNorm(d_model, eps=1e-6)
        self.linear = VanillaClassifier(in_features=d_model,
                                        num_classes=num_classes)
        self.linear = VanillaClassifier(in_features=d_model, num_classes=num_classes)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)
        self.widenet = nn.Sequential(embedding, embed_dropout, *blocks, norm)
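Pulled out of the constructor for clarity: the shared-expert wiring now replaces the generic Experts(expert=VanillaFFN, ...) wrapper with the FFNExperts module. The hyperparameter values below are illustrative only; the real values come from the Widenet arguments:

num_experts, d_model, d_ff, capacity_factor, drop_rate = 4, 768, 3072, 1.2, 0.1  # assumed values

noisy_func = NormalNoiseGenerator(num_experts)
shared_router = Top2Router(capacity_factor, noisy_func=noisy_func)
shared_experts = FFNExperts(num_experts, d_model, d_ff, drop_rate=drop_rate)
moe_ffn = MoeLayer(dim_model=d_model, num_experts=num_experts,
                   router=shared_router, experts=shared_experts)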
@ -145,3 +136,64 @@ class Widenet(nn.Module):
        x = torch.mean(x, dim=1)
        x = self.linear(x)
        return x


class ViTMoE(nn.Module):

    def __init__(self,
                 num_experts: int,
                 capacity_factor: float,
                 img_size: int = 224,
                 patch_size: int = 16,
                 in_chans: int = 3,
                 num_classes: int = 1000,
                 depth: int = 12,
                 d_model: int = 768,
                 num_heads: int = 12,
                 d_kv: int = 64,
                 d_ff: int = 3072,
                 attention_drop: float = 0.,
                 drop_rate: float = 0.1,
                 drop_path: float = 0.):
        super().__init__()

        embedding = VanillaPatchEmbedding(img_size=img_size,
                                          patch_size=patch_size,
                                          in_chans=in_chans,
                                          embed_size=d_model)
        embed_dropout = Dropout(p=drop_rate, mode=ParallelMode.TENSOR)

        noisy_func = NormalNoiseGenerator(num_experts)
        router = Top2Router(capacity_factor, noisy_func=noisy_func)

        assert depth % 2 == 0

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
        blocks = []
        for i in range(depth):
            sa = VanillaSelfAttention(**moe_sa_args(
                d_model=d_model, n_heads=num_heads, d_kv=d_kv, attention_drop=attention_drop, drop_rate=drop_rate))
            ffn = VanillaFFN(**moe_mlp_args(
                d_model=d_model, d_ff=d_ff, drop_rate=drop_rate)) if i % 2 == 0 else \
                MoeLayer(dim_model=d_model, num_experts=num_experts, router=router,
                         experts=FFNExperts(num_experts, d_model, d_ff, drop_rate=drop_rate))
            layer = TransformerLayer(att=sa,
                                     ffn=ffn,
                                     norm1=nn.LayerNorm(d_model, eps=1e-6),
                                     norm2=nn.LayerNorm(d_model, eps=1e-6),
                                     droppath=DropPath(p=dpr[i], mode=ParallelMode.TENSOR))
            blocks.append(layer)

        norm = nn.LayerNorm(d_model, eps=1e-6)
        self.linear = VanillaClassifier(in_features=d_model, num_classes=num_classes)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)
        self.vitmoe = nn.Sequential(embedding, embed_dropout, *blocks, norm)

    def forward(self, x):
        moe_env.reset_loss()
        x = self.vitmoe(x)
        x = torch.mean(x, dim=1)
        x = self.linear(x)
        return x
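A hypothetical end-to-end usage sketch for the new ViTMoE model. It assumes colossalai has already been launched with an MoE parallel configuration (e.g. parallel=dict(moe=dict(size=...))) so MoeLayer can build its process groups, and that a CUDA device is available; moe_env.reset_loss() at the top of forward clears the auxiliary loss accumulated by the routers:

import torch

model = ViTMoE(num_experts=4, capacity_factor=1.2).cuda()
images = torch.randn(2, 3, 224, 224, device='cuda')
logits = model(images)                         # (2, 1000) classification logits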
setup.py
@ -162,6 +162,10 @@ if build_cuda_ext:
                                       ['scaled_masked_softmax.cpp', 'scaled_masked_softmax_cuda.cu'],
                                       extra_cuda_flags + cc_flag))

    ext_modules.append(cuda_ext_helper('colossal_moe_cuda',
                                       ['moe_cuda.cpp', 'moe_cuda_kernel.cu'],
                                       extra_cuda_flags + cc_flag))

    extra_cuda_flags = ['-maxrregcount=50']

    ext_modules.append(cuda_ext_helper('colossal_layer_norm_cuda',
@ -0,0 +1,97 @@
import os
from functools import partial
from pathlib import Path
import pytest
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Top2Router, MoeLayer
from colossalai.global_variables import moe_env


BATCH_SIZE = 32
NUM_EXPERTS = 4
CONFIG = dict(parallel=dict(moe=dict(size=4)))


def check_equal(A, B, atol=1e-06):
    assert torch.allclose(A, B, rtol=0, atol=atol) is True


def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # torch.set_printoptions(precision=30)
    torch.backends.cuda.matmul.allow_tf32 = False
    local_rank = gpc.get_local_rank(ParallelMode.GLOBAL)
    torch.manual_seed(rs + local_rank)
    moe_env.reset_loss()
    tokens = torch.randn(BATCH_SIZE, hidden_size,
                         dtype=data_type, device=get_current_device(), requires_grad=True)
    # print(f"tokens:\n{tokens}")
    router = Top2Router(1)
    layer = MoeLayer(hidden_size, NUM_EXPERTS, router, nn.Identity())
    if data_type == torch.float16:
        layer = layer.half()
    layer.cuda_mode = False

    old_out = layer(tokens)
    # print(f"old output:\n{old_out}")

    ech = old_out.shape
    grad = torch.randn(ech, device=get_current_device())
    old_out.backward(grad)

    o_tk_grad = tokens.grad.data.clone()
    o_gt_grad = layer.gate.weight.grad.data.clone()

    tokens.grad.zero_()
    layer.gate.weight.grad.zero_()

    layer.cuda_mode = True
    new_out = layer(tokens)

    # print(torch.max(torch.abs(old_out - new_out)))
    if data_type == torch.float32:
        check_equal(old_out, new_out)
    else:
        check_equal(old_out, new_out, 1e-2)
    # print(f"forward functions passed")

    # print(f"new output:\n{new_out}")
    new_out.backward(grad)
    n_tk_grad = tokens.grad.data.clone()
    n_gt_grad = layer.gate.weight.grad.data.clone()

    # print(torch.max(torch.abs(o_tk_grad - n_tk_grad)))
    if data_type == torch.float32:
        check_equal(o_tk_grad, n_tk_grad)
    else:
        check_equal(o_tk_grad, n_tk_grad, 1e-2)
    # print(f"tokens gradient passed")

    # print(torch.max(torch.abs(o_gt_grad - n_gt_grad)))
    if data_type == torch.float32:
        check_equal(o_gt_grad, n_gt_grad, 5e-05)
    else:
        check_equal(o_gt_grad, n_gt_grad, 2e-01)
    # print(f"linear weight gradient passed")


@pytest.mark.dist
@pytest.mark.parametrize("rs", [131])
@pytest.mark.parametrize("hidden_size", [32, 144])
@pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
def test_moe_top2(rs, hidden_size, data_type):
    world_size = 4
    run_func = partial(run_routing, world_size=world_size, port=free_port(),
                       rs=rs, hidden_size=hidden_size, data_type=data_type)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_moe_top2(2, 256, torch.float16)
@ -0,0 +1,97 @@
from functools import partial
import pytest
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Top1Router, MoeLayer
from colossalai.global_variables import moe_env

BATCH_SIZE = 32
NUM_EXPERTS = 4
CONFIG = dict(parallel=dict(moe=dict(size=4)))


def check_equal(A, B, atol=1e-06):
    assert torch.allclose(A, B, rtol=0, atol=atol) is True


def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # torch.set_printoptions(precision=30)
    torch.backends.cuda.matmul.allow_tf32 = False
    local_rank = gpc.get_local_rank(ParallelMode.GLOBAL)
    torch.manual_seed(rs + local_rank)
    moe_env.reset_loss()
    tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True)
    # print(f"tokens:\n{tokens}")
    router = Top1Router(1)
    layer = MoeLayer(hidden_size, NUM_EXPERTS, router, nn.Identity())
    if data_type == torch.float16:
        layer = layer.half()
    layer.cuda_mode = False

    old_out = layer(tokens)
    # print(f"old output:\n{old_out}")

    ech = old_out.shape
    grad = torch.randn(ech, device=get_current_device())
    old_out.backward(grad)

    o_tk_grad = tokens.grad.data.clone()
    o_gt_grad = layer.gate.weight.grad.data.clone()

    tokens.grad.zero_()
    layer.gate.weight.grad.zero_()

    layer.cuda_mode = True
    new_out = layer(tokens)

    # print(torch.max(torch.abs(old_out - new_out)))
    if data_type == torch.float32:
        check_equal(old_out, new_out)
    else:
        check_equal(old_out, new_out, 1e-2)
    # print(f"forward functions passed")

    # print(f"new output:\n{new_out}")
    new_out.backward(grad)
    n_tk_grad = tokens.grad.data.clone()
    n_gt_grad = layer.gate.weight.grad.data.clone()

    # print(torch.max(torch.abs(o_tk_grad - n_tk_grad)))
    if data_type == torch.float32:
        check_equal(o_tk_grad, n_tk_grad)
    else:
        check_equal(o_tk_grad, n_tk_grad, 1e-2)
    # print(f"tokens gradient passed")

    # print(torch.max(torch.abs(o_gt_grad - n_gt_grad)))
    if data_type == torch.float32:
        check_equal(o_gt_grad, n_gt_grad, 5e-05)
    else:
        check_equal(o_gt_grad, n_gt_grad, 2e-01)
    # print(f"linear weight gradient passed")


@pytest.mark.skip(reason="Should be activated for detailed tests")
@pytest.mark.parametrize("rs", [2, 42, 60])
@pytest.mark.parametrize("hidden_size", [128, 256, 512, 768, 1024, 2048])
@pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
def test_moe_top2(rs, hidden_size, data_type):
    world_size = 4
    run_func = partial(run_routing,
                       world_size=world_size,
                       port=free_port(),
                       rs=rs,
                       hidden_size=hidden_size,
                       data_type=data_type)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_moe_top2(60, 512, torch.float16)
@ -0,0 +1,97 @@
from functools import partial
import pytest
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Top2Router, MoeLayer
from colossalai.global_variables import moe_env

BATCH_SIZE = 32
NUM_EXPERTS = 4
CONFIG = dict(parallel=dict(moe=dict(size=4)))


def check_equal(A, B, atol=1e-06):
    assert torch.allclose(A, B, rtol=0, atol=atol) is True


def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # torch.set_printoptions(precision=30)
    torch.backends.cuda.matmul.allow_tf32 = False
    local_rank = gpc.get_local_rank(ParallelMode.GLOBAL)
    torch.manual_seed(rs + local_rank)
    moe_env.reset_loss()
    tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True)
    # print(f"tokens:\n{tokens}")
    router = Top2Router(1)
    layer = MoeLayer(hidden_size, NUM_EXPERTS, router, nn.Identity())
    if data_type == torch.float16:
        layer = layer.half()
    layer.cuda_mode = False

    old_out = layer(tokens)
    # print(f"old output:\n{old_out}")

    ech = old_out.shape
    grad = torch.randn(ech, device=get_current_device())
    old_out.backward(grad)

    o_tk_grad = tokens.grad.data.clone()
    o_gt_grad = layer.gate.weight.grad.data.clone()

    tokens.grad.zero_()
    layer.gate.weight.grad.zero_()

    layer.cuda_mode = True
    new_out = layer(tokens)

    # print(torch.max(torch.abs(old_out - new_out)))
    if data_type == torch.float32:
        check_equal(old_out, new_out)
    else:
        check_equal(old_out, new_out, 1e-2)
    # print(f"forward functions passed")

    # print(f"new output:\n{new_out}")
    new_out.backward(grad)
    n_tk_grad = tokens.grad.data.clone()
    n_gt_grad = layer.gate.weight.grad.data.clone()

    # print(torch.max(torch.abs(o_tk_grad - n_tk_grad)))
    if data_type == torch.float32:
        check_equal(o_tk_grad, n_tk_grad)
    else:
        check_equal(o_tk_grad, n_tk_grad, 1e-2)
    # print(f"tokens gradient passed")

    # print(torch.max(torch.abs(o_gt_grad - n_gt_grad)))
    if data_type == torch.float32:
        check_equal(o_gt_grad, n_gt_grad, 5e-05)
    else:
        check_equal(o_gt_grad, n_gt_grad, 2e-01)
    # print(f"linear weight gradient passed")


@pytest.mark.skip(reason="Should be activated for detailed tests")
@pytest.mark.parametrize("rs", [2, 42, 60])
@pytest.mark.parametrize("hidden_size", [128, 256, 512, 768, 1024, 2048])
@pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
def test_moe_top2(rs, hidden_size, data_type):
    world_size = 4
    run_func = partial(run_routing,
                       world_size=world_size,
                       port=free_port(),
                       rs=rs,
                       hidden_size=hidden_size,
                       data_type=data_type)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_moe_top2(2, 256, torch.float16)