Added TPExperts for the case where num_experts is not divisible by the MoE model parallel size

pull/394/head
1SAA 2022-02-27 22:28:39 +08:00 committed by Frank Lee
parent 36b8477228
commit 82023779bb
7 changed files with 192 additions and 41 deletions

View File

@@ -1,5 +1,8 @@
from .experts import Experts, FFNExperts
from .experts import Experts, FFNExperts, TPExperts
from .layers import MoeLayer, Top1Router, Top2Router
from .utils import NormalNoiseGenerator
from .utils import NormalNoiseGenerator, build_ffn_experts
__all__ = ['Experts', 'FFNExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator']
__all__ = [
'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator',
'build_ffn_experts'
]
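With the re-exports above, the new expert class and builder are reachable from the package root; a minimal import sketch (package path taken from the model zoo and test files later in this diff):

# Hypothetical import of the extended public API.
from colossalai.nn.layer.moe import TPExperts, build_ffn_experts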

View File

@@ -15,6 +15,55 @@ except ImportError:
print("If you want to activate cuda mode for MoE, please install with cuda_ext!")
class AllGather(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor:
if ctx is not None:
ctx.parallel_mode = parallel_mode
comm_size = gpc.get_world_size(parallel_mode)
if comm_size == 1:
return inputs.unsqueeze(0)
buffer_shape = (comm_size,) + inputs.shape
outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
dist.all_gather(buffer_list, inputs, group=gpc.get_group(parallel_mode))
return outputs
@staticmethod
def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
return ReduceScatter.forward(None, grad_outputs, ctx.parallel_mode), None
class ReduceScatter(torch.autograd.Function):
@staticmethod
def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor:
if ctx is not None:
ctx.parallel_mode = parallel_mode
comm_size = gpc.get_world_size(parallel_mode)
if comm_size == 1:
return inputs.squeeze(0)
if not inputs.is_contiguous():
inputs = inputs.contiguous()
output_shape = inputs.shape[1:]
outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
dist.reduce_scatter(outputs, buffer_list, group=gpc.get_group(parallel_mode))
return outputs
@staticmethod
def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
return AllGather.forward(None, grad_outputs, ctx.parallel_mode), None
class AllToAll(torch.autograd.Function):
"""Dispatches input tensor [e, c, h] to all experts by all_to_all_single
operation in torch.distributed.
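AllGather and ReduceScatter above are deliberately wired as mutual inverses: each one's backward simply calls the other's forward, so gradients take the opposite collective on the way back. A minimal usage sketch, assuming colossalai.launch has already initialized the MOE_MODEL process group (tensor names and sizes below are hypothetical):

import torch
from colossalai.context import ParallelMode
from colossalai.utils import get_current_device

# Hypothetical per-rank shard: e experts, c tokens per expert, hidden size h.
e, c, h = 4, 16, 64
local_tokens = torch.randn(e, c, h, device=get_current_device(), requires_grad=True)

gathered = AllGather.apply(local_tokens, ParallelMode.MOE_MODEL)    # [world_size, e, c, h]
processed = gathered * 2.0                                          # stand-in for a replicated expert computation
local_out = ReduceScatter.apply(processed, ParallelMode.MOE_MODEL)  # summed over ranks, back to [e, c, h]
local_out.sum().backward()                                          # gradients flow through the opposite collectives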

View File

@@ -7,7 +7,16 @@ from colossalai.context import ParallelMode, seed
from colossalai.utils import get_current_device
class Experts(nn.Module):
class MoeExperts(nn.Module):
def __init__(self, comm: str):
super().__init__()
assert comm in {"all_to_all", "all_gather"}, \
"This kind of communication has not been implemented yet.\n Please use Experts build function."
self.comm = comm
class Experts(MoeExperts):
"""A wrapper class to create experts. It will create E experts across the
moe model parallel group, where E is the number of experts. Every expert
is an instance of the class 'expert' given in the initialization parameters.
@@ -20,19 +29,22 @@ class Experts(nn.Module):
"""
def __init__(self, expert, num_experts, **expert_args):
super().__init__()
super().__init__("all_to_all")
assert num_experts % moe_env.model_parallel_size == 0, \
"The number of experts should be divied by moe model size"
num_local_experts = num_experts // moe_env.model_parallel_size
with seed(ParallelMode.MOE_MODEL):
self.experts = nn.ModuleList([expert(**expert_args) for _ in range(num_local_experts)])
self.num_local_experts = num_local_experts
for exp in self.experts:
for param in exp.parameters():
param.__setattr__('moe_param', True)
self.num_local_experts = num_local_experts
def forward(self, inputs):
expert_input = torch.chunk(inputs, self.num_local_experts, dim=1)
expert_output = []
@@ -44,10 +56,10 @@ class Experts(nn.Module):
return output
class FFNExperts(nn.Module):
class FFNExperts(MoeExperts):
def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
super().__init__()
super().__init__("all_to_all")
assert num_experts % moe_env.model_parallel_size == 0, \
"The number of experts should be divied by moe model size"
@@ -75,7 +87,7 @@ class FFNExperts(nn.Module):
for param in self.parameters():
param.__setattr__('moe_param', True)
def forward(self, inputs): # x [g, el, c, h]
def forward(self, inputs): # inputs [g, el, c, h]
el = inputs.size(1)
h = inputs.size(-1)
@@ -96,3 +108,58 @@ class FFNExperts(nn.Module):
outputs = outputs.reshape(inshape)
outputs = outputs.transpose(0, 1).contiguous()
return outputs
class TPExperts(MoeExperts):
def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
super().__init__("all_gather")
assert d_ff % moe_env.model_parallel_size == 0, \
"d_ff should be divied by moe model size"
p_ff = d_ff // moe_env.model_parallel_size
self.w1 = nn.Parameter(torch.empty(num_experts, d_model, p_ff, device=get_current_device()))
self.b1 = nn.Parameter(torch.empty(num_experts, 1, p_ff, device=get_current_device()))
self.w2 = nn.Parameter(torch.empty(num_experts, p_ff, d_model, device=get_current_device()))
self.b2 = nn.Parameter(torch.empty(num_experts, 1, d_model, device=get_current_device()))
s1 = math.sqrt(0.1 / d_model)
s2 = math.sqrt(0.1 / d_ff)
with seed(ParallelMode.MOE_MODEL):
nn.init.trunc_normal_(self.w1, std=s1)
nn.init.trunc_normal_(self.b1, std=s1)
nn.init.trunc_normal_(self.w2, std=s2)
nn.init.trunc_normal_(self.b2, std=s2)
self.act = nn.GELU() if activation is None else activation
self.drop = nn.Dropout(p=drop_rate)
self.w1.__setattr__('moe_param', True)
self.w2.__setattr__('moe_param', True)
self.b1.__setattr__('moe_param', True)
def forward(self, inputs): # inputs [g, e, c, h]
e = inputs.size(1)
h = inputs.size(-1)
inputs = inputs.transpose(0, 1)
inshape = inputs.shape
inputs = inputs.reshape(e, -1, h)
out_ff = torch.baddbmm(self.b1, inputs, self.w1)
out_act = self.act(out_ff)
with seed(ParallelMode.TENSOR):
inter = self.drop(out_act)
out_model = torch.baddbmm(self.b2, inter, self.w2)
outputs = self.drop(out_model) # outputs [e, gc, h]
outputs = outputs.reshape(inshape)
outputs = outputs.transpose(0, 1).contiguous()
return outputs # outputs [g, e, c, h]
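Compared with FFNExperts, which places num_experts // model_parallel_size whole experts on each rank and therefore needs the all_to_all path, TPExperts keeps every expert on every rank but splits the hidden dimension d_ff, so each rank computes a partial FFN and the partial results are summed by the ReduceScatter in the layer. A single-process shape sketch of the two baddbmm calls, with hypothetical sizes and an assumed model parallel size of 4:

import torch
import torch.nn.functional as F

g, e, c, h, d_ff, mp_size = 2, 8, 16, 64, 256, 4     # hypothetical sizes
p_ff = d_ff // mp_size                                # slice of d_ff held by one rank

x = torch.randn(g, e, c, h).transpose(0, 1).reshape(e, -1, h)    # [e, g*c, h]
w1, b1 = torch.randn(e, h, p_ff), torch.randn(e, 1, p_ff)
w2, b2 = torch.randn(e, p_ff, h), torch.randn(e, 1, h)

hidden = F.gelu(torch.baddbmm(b1, x, w1))             # [e, g*c, p_ff]  partial hidden states
partial = torch.baddbmm(b2, hidden, w2)               # [e, g*c, h]     partial output, summed across ranks by ReduceScatter
out = partial.reshape(e, g, c, h).transpose(0, 1)     # back to [g, e, c, h]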

View File

@@ -8,7 +8,8 @@ from colossalai.core import global_context as gpc
from colossalai.global_variables import moe_env
from colossalai.context import ParallelMode
from colossalai.utils import get_current_device
from ._operation import U_CUDA_MODE, AllToAll, MoeDispatch, MoeCombine, moe_cumsum
from ._operation import U_CUDA_MODE, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
from .experts import MoeExperts
from .utils import autocast_softmax
@@ -198,7 +199,7 @@ class MoeLayer(nn.Module):
:type experts: nn.Module
"""
def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: nn.Module):
def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: MoeExperts):
super().__init__()
self.d_model = dim_model
self.num_experts = num_experts
@@ -207,8 +208,8 @@
self.experts = experts
self.cuda_mode = True if U_CUDA_MODE and moe_env.enable_cuda else False
def expert_part(self, expert_input: torch.Tensor):
expert_input = AllToAll.apply(expert_input, ParallelMode.MOE_MODEL)
def a2a_process(self, dispatch_data: torch.Tensor):
expert_input = AllToAll.apply(dispatch_data, ParallelMode.MOE_MODEL)
input_shape = expert_input.shape
@@ -221,24 +222,42 @@ class MoeLayer(nn.Module):
expert_output = AllToAll.apply(expert_output, ParallelMode.MOE_MODEL)
return expert_output
def tp_process(self, dispatch_data: torch.Tensor):
expert_in = AllGather.apply(dispatch_data, ParallelMode.MOE_MODEL)
expert_out = self.experts(expert_in)
expert_out = ReduceScatter.apply(expert_out, ParallelMode.MOE_MODEL)
return expert_out
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
tokens = inputs.reshape(-1, self.d_model)
gate_output = self.gate(tokens)
router_res = self.router(gate_output, self.cuda_mode)
if self.cuda_mode:
logits, mask, dest_idx, ec = router_res
expert_input = MoeDispatch.apply(tokens, mask, dest_idx, ec)
expert_output = self.expert_part(expert_input)
ret = MoeCombine.apply(expert_output, logits, mask, dest_idx, ec)
dispatch_data = MoeDispatch.apply(tokens, *router_res[1:])
dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model)
else:
combine_weights, sec_mask = router_res
sec_mask_f = sec_mask.type_as(inputs)
expert_input = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
expert_output = self.expert_part(expert_input)
sec_mask_f = router_res[1].type_as(inputs)
dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
# dispatch_data [e, c, h]
if self.experts.comm == "all_to_all":
expert_output = self.a2a_process(dispatch_data)
elif self.experts.comm == "all_gather":
expert_output = self.tp_process(dispatch_data)
else:
raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts "
"build function.")
# expert_output [e, c, h]
if self.cuda_mode:
expert_output = expert_output.reshape(-1, self.d_model)
ans = MoeCombine.apply(expert_output, *router_res)
else:
combine_weights = router_res[0]
combine_weights = combine_weights.view(combine_weights.shape[0], -1)
expert_output = expert_output.view(-1, expert_output.shape[-1])
ret = torch.matmul(combine_weights, expert_output)
ans = torch.matmul(combine_weights, expert_output)
ret = ret.reshape(inputs.shape)
return ret
ans = ans.reshape(inputs.shape)
return ans
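MoeLayer now routes the dispatched tokens through whichever collective the expert module declares: experts.comm == "all_to_all" selects a2a_process (Experts/FFNExperts) and "all_gather" selects tp_process (TPExperts). A hedged end-to-end wiring sketch with hypothetical sizes, assuming colossalai.launch has set up the MoE parallel groups and that d_ff is divisible by the model parallel size:

import torch
from colossalai.nn.layer.moe import MoeLayer, Top2Router, TPExperts
from colossalai.utils import get_current_device

d_model, d_ff, num_experts = 512, 2048, 8             # hypothetical sizes
router = Top2Router(1.25)                             # capacity factor passed positionally, as in the test below
experts = TPExperts(num_experts, d_model, d_ff)       # comm == "all_gather" -> tp_process path
moe = MoeLayer(dim_model=d_model, num_experts=num_experts, router=router, experts=experts)

x = torch.randn(4, 16, d_model, device=get_current_device())
y = moe(x)                                            # output has the same shape as x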

View File

@@ -1,6 +1,8 @@
import torch
import torch.nn.functional as F
from colossalai.utils import get_current_device
from colossalai.global_variables import moe_env
from .experts import FFNExperts, TPExperts
class NormalNoiseGenerator:
@@ -14,10 +16,9 @@ class NormalNoiseGenerator:
"""
def __init__(self, num_experts: int):
self.normal = torch.distributions.normal.Normal(
loc=torch.tensor(0.0, device=get_current_device()),
scale=torch.tensor(1.0 / num_experts ** 2, device=get_current_device())
).rsample
self.normal = torch.distributions.normal.Normal(loc=torch.tensor(0.0, device=get_current_device()),
scale=torch.tensor(1.0 / num_experts**2,
device=get_current_device())).rsample
def __call__(self, inputs: torch.Tensor):
noisy = self.normal(inputs.shape)
@@ -30,3 +31,13 @@ def autocast_softmax(inputs: torch.Tensor, dim: int):
sm_input = inputs.to(torch.float32) if fp16_flag else inputs
sm_output = F.softmax(sm_input, dim)
return sm_output
def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
moe_mp_size = moe_env.model_parallel_size
if num_experts % moe_mp_size == 0:
return FFNExperts(num_experts, d_model, d_ff, activation, drop_rate)
elif d_ff % moe_mp_size == 0:
return TPExperts(num_experts, d_model, d_ff, activation, drop_rate)
else:
raise NotImplementedError(f"Cannot build {num_experts} experts with {moe_mp_size} GPUs.")
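build_ffn_experts therefore prefers plain expert parallelism and only falls back to the tensor-parallel variant when the expert count does not divide evenly; a small sketch of that decision (hypothetical numbers, assuming a MoE model parallel size of 4):

# Hypothetical outcomes of the selection rule above with moe_env.model_parallel_size == 4:
#   build_ffn_experts(8, 512, 2048)  -> FFNExperts (2 whole experts per rank, all_to_all)
#   build_ffn_experts(2, 512, 2048)  -> TPExperts  (d_ff sliced to 512 per rank, all_gather)
#   build_ffn_experts(3, 512, 1022)  -> NotImplementedError (neither dimension divides)
experts = build_ffn_experts(2, 512, 2048)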

View File

@@ -4,7 +4,7 @@ import torch.nn as nn
from colossalai.context import ParallelMode
from colossalai.nn.layer import VanillaPatchEmbedding, VanillaClassifier, \
WrappedDropout as Dropout, WrappedDropPath as DropPath
from colossalai.nn.layer.moe import FFNExperts, MoeLayer, Top2Router, NormalNoiseGenerator
from colossalai.nn.layer.moe import build_ffn_experts, MoeLayer, Top2Router, NormalNoiseGenerator
from .util import moe_sa_args, moe_mlp_args
from ..helper import TransformerLayer
from colossalai.global_variables import moe_env
@@ -110,7 +110,7 @@ class Widenet(nn.Module):
noisy_func = NormalNoiseGenerator(num_experts)
shared_router = Top2Router(capacity_factor, noisy_func=noisy_func)
shared_experts = FFNExperts(num_experts, d_model, d_ff, drop_rate=drop_rate)
shared_experts = build_ffn_experts(num_experts, d_model, d_ff, drop_rate=drop_rate)
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
@@ -177,7 +177,7 @@ class ViTMoE(nn.Module):
ffn = VanillaFFN(**moe_mlp_args(
d_model=d_model, d_ff=d_ff, drop_rate=drop_rate)) if i % 2 == 0 else \
MoeLayer(dim_model=d_model, num_experts=num_experts, router=router,
experts=FFNExperts(num_experts, d_model, d_ff, drop_rate=drop_rate))
experts=build_ffn_experts(num_experts, d_model, d_ff, drop_rate=drop_rate))
layer = TransformerLayer(att=sa,
ffn=ffn,
norm1=nn.LayerNorm(d_model, eps=1e-6),

View File

@@ -1,6 +1,4 @@
import os
from functools import partial
from pathlib import Path
import pytest
import torch
import torch.nn as nn
@@ -9,10 +7,10 @@ import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Top2Router, MoeLayer
from colossalai.nn.layer.moe import Top2Router, MoeLayer, Experts
from colossalai.context.random import moe_set_seed
from colossalai.global_variables import moe_env
BATCH_SIZE = 32
NUM_EXPERTS = 4
CONFIG = dict(parallel=dict(moe=dict(size=4)))
@@ -24,17 +22,17 @@ def check_equal(A, B, atol=1e-06):
def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
moe_set_seed(42)
# torch.set_printoptions(precision=30)
torch.backends.cuda.matmul.allow_tf32 = False
local_rank = gpc.get_local_rank(ParallelMode.GLOBAL)
torch.manual_seed(rs + local_rank)
moe_env.reset_loss()
tokens = torch.randn(BATCH_SIZE, hidden_size,
dtype=data_type, device=get_current_device(), requires_grad=True)
tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True)
# print(f"tokens:\n{tokens}")
router = Top2Router(1)
layer = MoeLayer(hidden_size, NUM_EXPERTS, router, nn.Identity())
expert = Experts(nn.Identity, 4)
layer = MoeLayer(hidden_size, NUM_EXPERTS, router, expert)
if data_type == torch.float16:
layer = layer.half()
layer.cuda_mode = False
@@ -88,8 +86,12 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
@pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
def test_moe_top2(rs, hidden_size, data_type):
world_size = 4
run_func = partial(run_routing, world_size=world_size, port=free_port(),
rs=rs, hidden_size=hidden_size, data_type=data_type)
run_func = partial(run_routing,
world_size=world_size,
port=free_port(),
rs=rs,
hidden_size=hidden_size,
data_type=data_type)
mp.spawn(run_func, nprocs=world_size)