Added TPExpert for special situation

pull/394/head
1SAA 2022-02-27 22:28:39 +08:00 committed by Frank Lee
parent 36b8477228
commit 82023779bb
7 changed files with 192 additions and 41 deletions

View File

@@ -1,5 +1,8 @@
-from .experts import Experts, FFNExperts
+from .experts import Experts, FFNExperts, TPExperts
 from .layers import MoeLayer, Top1Router, Top2Router
-from .utils import NormalNoiseGenerator
+from .utils import NormalNoiseGenerator, build_ffn_experts
 
-__all__ = ['Experts', 'FFNExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator']
+__all__ = [
+    'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator',
+    'build_ffn_experts'
+]

View File

@@ -15,6 +15,55 @@ except ImportError:
     print("If you want to activate cuda mode for MoE, please install with cuda_ext!")
 
 
+class AllGather(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor:
+        if ctx is not None:
+            ctx.parallel_mode = parallel_mode
+
+        comm_size = gpc.get_world_size(parallel_mode)
+        if comm_size == 1:
+            return inputs.unsqueeze(0)
+
+        buffer_shape = (comm_size,) + inputs.shape
+        outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
+        buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
+        dist.all_gather(buffer_list, inputs, group=gpc.get_group(parallel_mode))
+        return outputs
+
+    @staticmethod
+    def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
+        return ReduceScatter.forward(None, grad_outputs, ctx.parallel_mode), None
+
+
+class ReduceScatter(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor:
+        if ctx is not None:
+            ctx.parallel_mode = parallel_mode
+
+        comm_size = gpc.get_world_size(parallel_mode)
+        if comm_size == 1:
+            return inputs.squeeze(0)
+
+        if not inputs.is_contiguous():
+            inputs = inputs.contiguous()
+
+        output_shape = inputs.shape[1:]
+        outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
+        buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
+        dist.reduce_scatter(outputs, buffer_list, group=gpc.get_group(parallel_mode))
+        return outputs
+
+    @staticmethod
+    def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
+        return AllGather.forward(None, grad_outputs, ctx.parallel_mode), None
+
+
 class AllToAll(torch.autograd.Function):
     """Dispatches input tensor [e, c, h] to all experts by all_to_all_single
     operation in torch.distributed.
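Note on the two new collectives: AllGather.backward calls ReduceScatter.forward and ReduceScatter.backward calls AllGather.forward, so the pair acts as each other's adjoint and gradients flow through without extra bookkeeping. Below is a single-process sketch (not part of the commit; the sizes w, e, c, h are made up) of the shape contract the two ops implement.

    # Sketch only: emulate the collective shapes on one process; 'w' plays the role
    # of gpc.get_world_size(parallel_mode).
    import torch

    w, e, c, h = 4, 8, 16, 32
    per_rank_inputs = [torch.randn(e, c, h) for _ in range(w)]

    # AllGather.forward: every rank ends up with the stacked [w, e, c, h] tensor.
    gathered = torch.stack(per_rank_inputs, dim=0)

    # ReduceScatter.forward: chunk along dim 0 and sum the matching chunk across ranks,
    # so each rank receives an [e, c, h] tensor summing all ranks' contributions to its slot.
    per_rank_partials = [torch.randn(w, e, c, h) for _ in range(w)]
    rank0_result = sum(p[0] for p in per_rank_partials)

    assert gathered.shape == (w, e, c, h)
    assert rank0_result.shape == (e, c, h)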

View File

@@ -7,7 +7,16 @@ from colossalai.context import ParallelMode, seed
 from colossalai.utils import get_current_device
 
 
-class Experts(nn.Module):
+class MoeExperts(nn.Module):
+
+    def __init__(self, comm: str):
+        super().__init__()
+        assert comm in {"all_to_all", "all_gather"}, \
+            "This kind of communication has not been implemented yet.\n Please use Experts build function."
+        self.comm = comm
+
+
+class Experts(MoeExperts):
     """A wrapper class to create experts. It will create E experts across the
     moe model parallel group, where E is the number of experts. Every expert
     is a instence of the class, 'expert' in initialization parameters.

@@ -20,19 +29,22 @@ class Experts(nn.Module):
     """
 
     def __init__(self, expert, num_experts, **expert_args):
-        super().__init__()
+        super().__init__("all_to_all")
 
         assert num_experts % moe_env.model_parallel_size == 0, \
             "The number of experts should be divied by moe model size"
 
         num_local_experts = num_experts // moe_env.model_parallel_size
+
         with seed(ParallelMode.MOE_MODEL):
             self.experts = nn.ModuleList([expert(**expert_args) for _ in range(num_local_experts)])
 
+        self.num_local_experts = num_local_experts
+
         for exp in self.experts:
             for param in exp.parameters():
                 param.__setattr__('moe_param', True)
 
-        self.num_local_experts = num_local_experts
-
     def forward(self, inputs):
         expert_input = torch.chunk(inputs, self.num_local_experts, dim=1)
         expert_output = []

@@ -44,10 +56,10 @@ class Experts(nn.Module):
         return output
 
 
-class FFNExperts(nn.Module):
+class FFNExperts(MoeExperts):
 
     def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
-        super().__init__()
+        super().__init__("all_to_all")
 
         assert num_experts % moe_env.model_parallel_size == 0, \
             "The number of experts should be divied by moe model size"

@@ -75,7 +87,7 @@ class FFNExperts(nn.Module):
         for param in self.parameters():
             param.__setattr__('moe_param', True)
 
-    def forward(self, inputs):  # x [g, el, c, h]
+    def forward(self, inputs):  # inputs [g, el, c, h]
 
         el = inputs.size(1)
         h = inputs.size(-1)

@@ -96,3 +108,58 @@ class FFNExperts(nn.Module):
         outputs = outputs.reshape(inshape)
         outputs = outputs.transpose(0, 1).contiguous()
         return outputs
+
+
+class TPExperts(MoeExperts):
+
+    def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
+        super().__init__("all_gather")
+
+        assert d_ff % moe_env.model_parallel_size == 0, \
+            "d_ff should be divied by moe model size"
+
+        p_ff = d_ff // moe_env.model_parallel_size
+
+        self.w1 = nn.Parameter(torch.empty(num_experts, d_model, p_ff, device=get_current_device()))
+        self.b1 = nn.Parameter(torch.empty(num_experts, 1, p_ff, device=get_current_device()))
+
+        self.w2 = nn.Parameter(torch.empty(num_experts, p_ff, d_model, device=get_current_device()))
+        self.b2 = nn.Parameter(torch.empty(num_experts, 1, d_model, device=get_current_device()))
+
+        s1 = math.sqrt(0.1 / d_model)
+        s2 = math.sqrt(0.1 / d_ff)
+
+        with seed(ParallelMode.MOE_MODEL):
+            nn.init.trunc_normal_(self.w1, std=s1)
+            nn.init.trunc_normal_(self.b1, std=s1)
+            nn.init.trunc_normal_(self.w2, std=s2)
+            nn.init.trunc_normal_(self.b2, std=s2)
+
+        self.act = nn.GELU() if activation is None else activation
+        self.drop = nn.Dropout(p=drop_rate)
+
+        self.w1.__setattr__('moe_param', True)
+        self.w2.__setattr__('moe_param', True)
+        self.b1.__setattr__('moe_param', True)
+
+    def forward(self, inputs):  # inputs [g, e, c, h]
+
+        e = inputs.size(1)
+        h = inputs.size(-1)
+
+        inputs = inputs.transpose(0, 1)
+        inshape = inputs.shape
+        inputs = inputs.reshape(e, -1, h)
+
+        out_ff = torch.baddbmm(self.b1, inputs, self.w1)
+        out_act = self.act(out_ff)
+        with seed(ParallelMode.TENSOR):
+            inter = self.drop(out_act)
+
+        out_model = torch.baddbmm(self.b2, inter, self.w2)
+        outputs = self.drop(out_model)  # outputs [e, gc, h]
+
+        outputs = outputs.reshape(inshape)
+        outputs = outputs.transpose(0, 1).contiguous()
+        return outputs  # outputs [g, e, c, h]
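Why the all-gather/reduce-scatter pattern around TPExperts is correct: each rank holds a d_ff slice of w1 and the matching slice of w2, computes a partial FFN output, and the sum over ranks equals the unsharded FFN. A minimal single-process check (not part of the commit; ReLU stands in for the activation, biases and dropout are omitted):

    # Sketch only: the column/row split of a two-layer FFN that TPExperts relies on.
    import torch

    mp_size, d_model, d_ff, num_tokens = 4, 64, 256, 10
    x = torch.randn(num_tokens, d_model, dtype=torch.float64)
    w1 = torch.randn(d_model, d_ff, dtype=torch.float64)
    w2 = torch.randn(d_ff, d_model, dtype=torch.float64)

    reference = torch.relu(x @ w1) @ w2        # unsharded FFN

    p_ff = d_ff // mp_size
    partials = [
        torch.relu(x @ w1[:, r * p_ff:(r + 1) * p_ff]) @ w2[r * p_ff:(r + 1) * p_ff, :]
        for r in range(mp_size)                # one partial output per rank
    ]
    assert torch.allclose(reference, sum(partials))   # the reduce-scatter sum restores the full output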

View File

@@ -8,7 +8,8 @@ from colossalai.core import global_context as gpc
 from colossalai.global_variables import moe_env
 from colossalai.context import ParallelMode
 from colossalai.utils import get_current_device
-from ._operation import U_CUDA_MODE, AllToAll, MoeDispatch, MoeCombine, moe_cumsum
+from ._operation import U_CUDA_MODE, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
+from .experts import MoeExperts
 from .utils import autocast_softmax

@@ -198,7 +199,7 @@ class MoeLayer(nn.Module):
     :type experts: nn.Module
     """
 
-    def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: nn.Module):
+    def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: MoeExperts):
         super().__init__()
         self.d_model = dim_model
         self.num_experts = num_experts

@@ -207,8 +208,8 @@ class MoeLayer(nn.Module):
         self.experts = experts
         self.cuda_mode = True if U_CUDA_MODE and moe_env.enable_cuda else False
 
-    def expert_part(self, expert_input: torch.Tensor):
-        expert_input = AllToAll.apply(expert_input, ParallelMode.MOE_MODEL)
+    def a2a_process(self, dispatch_data: torch.Tensor):
+        expert_input = AllToAll.apply(dispatch_data, ParallelMode.MOE_MODEL)
 
         input_shape = expert_input.shape

@@ -221,24 +222,42 @@ class MoeLayer(nn.Module):
         expert_output = AllToAll.apply(expert_output, ParallelMode.MOE_MODEL)
         return expert_output
 
+    def tp_process(self, dispatch_data: torch.Tensor):
+        expert_in = AllGather.apply(dispatch_data, ParallelMode.MOE_MODEL)
+        expert_out = self.experts(expert_in)
+        expert_out = ReduceScatter.apply(expert_out, ParallelMode.MOE_MODEL)
+        return expert_out
+
     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         tokens = inputs.reshape(-1, self.d_model)
         gate_output = self.gate(tokens)
         router_res = self.router(gate_output, self.cuda_mode)
 
         if self.cuda_mode:
-            logits, mask, dest_idx, ec = router_res
-            expert_input = MoeDispatch.apply(tokens, mask, dest_idx, ec)
-            expert_output = self.expert_part(expert_input)
-            ret = MoeCombine.apply(expert_output, logits, mask, dest_idx, ec)
+            dispatch_data = MoeDispatch.apply(tokens, *router_res[1:])
+            dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model)
         else:
-            combine_weights, sec_mask = router_res
-            sec_mask_f = sec_mask.type_as(inputs)
-            expert_input = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
-            expert_output = self.expert_part(expert_input)
+            sec_mask_f = router_res[1].type_as(inputs)
+            dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
+
+        # dispatch_data [e, c, h]
+        if self.experts.comm == "all_to_all":
+            expert_output = self.a2a_process(dispatch_data)
+        elif self.experts.comm == "all_gather":
+            expert_output = self.tp_process(dispatch_data)
+        else:
+            raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts "
+                                      "build function.")
+
+        # expert_output [e, c, h]
+        if self.cuda_mode:
+            expert_output = expert_output.reshape(-1, self.d_model)
+            ans = MoeCombine.apply(expert_output, *router_res)
+        else:
+            combine_weights = router_res[0]
             combine_weights = combine_weights.view(combine_weights.shape[0], -1)
             expert_output = expert_output.view(-1, expert_output.shape[-1])
-            ret = torch.matmul(combine_weights, expert_output)
+            ans = torch.matmul(combine_weights, expert_output)
 
-        ret = ret.reshape(inputs.shape)
-        return ret
+        ans = ans.reshape(inputs.shape)
+        return ans
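For the non-CUDA path the dispatch and combine steps are plain matrix products over the routing mask, so the rewrite above is mostly shape bookkeeping: [e, c, s] x [s, h] scatters tokens into per-expert capacity slots, and [s, e*c] x [e*c, h] gathers weighted expert outputs back into token order. A shape-only sketch (not part of the commit; sizes are made up and the experts are replaced by the identity):

    # Sketch only: shapes of the mask-based dispatch/combine used when cuda_mode is off.
    import torch

    s, e, c, h = 12, 4, 6, 8                    # tokens, experts, capacity, hidden size
    tokens = torch.randn(s, h)
    sec_mask_f = torch.zeros(s, e, c)           # routing mask from the router (0/1 entries)
    combine_weights = torch.rand(s, e, c)       # routing probabilities for the combine step

    dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)    # [e, c, h]
    expert_output = dispatch_data                                        # pretend the experts are identity

    out = torch.matmul(combine_weights.view(s, -1), expert_output.view(-1, h))    # [s, h]
    assert dispatch_data.shape == (e, c, h) and out.shape == (s, h)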

View File

@@ -1,6 +1,8 @@
 import torch
 import torch.nn.functional as F
 from colossalai.utils import get_current_device
+from colossalai.global_variables import moe_env
+from .experts import FFNExperts, TPExperts
 
 
 class NormalNoiseGenerator:

@@ -14,10 +16,9 @@ class NormalNoiseGenerator:
     """
 
     def __init__(self, num_experts: int):
-        self.normal = torch.distributions.normal.Normal(
-            loc=torch.tensor(0.0, device=get_current_device()),
-            scale=torch.tensor(1.0 / num_experts ** 2, device=get_current_device())
-        ).rsample
+        self.normal = torch.distributions.normal.Normal(loc=torch.tensor(0.0, device=get_current_device()),
+                                                        scale=torch.tensor(1.0 / num_experts**2,
+                                                                           device=get_current_device())).rsample
 
     def __call__(self, inputs: torch.Tensor):
         noisy = self.normal(inputs.shape)

@@ -30,3 +31,13 @@ def autocast_softmax(inputs: torch.Tensor, dim: int):
     sm_input = inputs.to(torch.float32) if fp16_flag else inputs
     sm_output = F.softmax(sm_input, dim)
     return sm_output
+
+
+def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
+    moe_mp_size = moe_env.model_parallel_size
+    if num_experts % moe_mp_size == 0:
+        return FFNExperts(num_experts, d_model, d_ff, activation, drop_rate)
+    elif d_ff % moe_mp_size == 0:
+        return TPExperts(num_experts, d_model, d_ff, activation, drop_rate)
+    else:
+        raise NotImplementedError(f"Can not build {num_experts} experts in {moe_mp_size} GPUS.")
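build_ffn_experts is where the "special situation" from the commit title is handled: when the expert count divides the MoE model-parallel size the usual FFNExperts is built, and only when it does not (too few experts for the group) does the helper fall back to TPExperts, which shards d_ff instead. A toy illustration of the selection rule (not part of the commit; the model-parallel size and configurations are made up):

    # Sketch only: which expert class the helper would pick for a few configurations.
    mp_size = 4
    for num_experts, d_ff in [(8, 3072), (2, 3072), (2, 3070)]:
        if num_experts % mp_size == 0:
            choice = "FFNExperts"    # enough experts to split across the group
        elif d_ff % mp_size == 0:
            choice = "TPExperts"     # too few experts, so shard d_ff instead
        else:
            choice = "neither (NotImplementedError)"
        print(f"num_experts={num_experts}, d_ff={d_ff} -> {choice}")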

View File

@@ -4,7 +4,7 @@ import torch.nn as nn
 from colossalai.context import ParallelMode
 from colossalai.nn.layer import VanillaPatchEmbedding, VanillaClassifier, \
     WrappedDropout as Dropout, WrappedDropPath as DropPath
-from colossalai.nn.layer.moe import FFNExperts, MoeLayer, Top2Router, NormalNoiseGenerator
+from colossalai.nn.layer.moe import build_ffn_experts, MoeLayer, Top2Router, NormalNoiseGenerator
 from .util import moe_sa_args, moe_mlp_args
 from ..helper import TransformerLayer
 from colossalai.global_variables import moe_env

@@ -110,7 +110,7 @@ class Widenet(nn.Module):
         noisy_func = NormalNoiseGenerator(num_experts)
         shared_router = Top2Router(capacity_factor, noisy_func=noisy_func)
-        shared_experts = FFNExperts(num_experts, d_model, d_ff, drop_rate=drop_rate)
+        shared_experts = build_ffn_experts(num_experts, d_model, d_ff, drop_rate=drop_rate)
 
         # stochastic depth decay rule
         dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]

@@ -177,7 +177,7 @@ class ViTMoE(nn.Module):
             ffn = VanillaFFN(**moe_mlp_args(
                 d_model=d_model, d_ff=d_ff, drop_rate=drop_rate)) if i % 2 == 0 else \
                 MoeLayer(dim_model=d_model, num_experts=num_experts, router=router,
-                         experts=FFNExperts(num_experts, d_model, d_ff, drop_rate=drop_rate))
+                         experts=build_ffn_experts(num_experts, d_model, d_ff, drop_rate=drop_rate))
             layer = TransformerLayer(att=sa,
                                      ffn=ffn,
                                      norm1=nn.LayerNorm(d_model, eps=1e-6),

View File

@@ -1,6 +1,4 @@
-import os
 from functools import partial
-from pathlib import Path
 
 import pytest
 import torch
 import torch.nn as nn

@@ -9,10 +7,10 @@ import colossalai
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils import free_port, get_current_device
-from colossalai.nn.layer.moe import Top2Router, MoeLayer
-from colossalai.context.random import moe_set_seed
+from colossalai.nn.layer.moe import Top2Router, MoeLayer, Experts
 from colossalai.global_variables import moe_env
 
 BATCH_SIZE = 32
 NUM_EXPERTS = 4
 CONFIG = dict(parallel=dict(moe=dict(size=4)))

@@ -24,17 +22,17 @@ def check_equal(A, B, atol=1e-06):
 
 def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    moe_set_seed(42)
 
     # torch.set_printoptions(precision=30)
     torch.backends.cuda.matmul.allow_tf32 = False
     local_rank = gpc.get_local_rank(ParallelMode.GLOBAL)
     torch.manual_seed(rs + local_rank)
     moe_env.reset_loss()
-    tokens = torch.randn(BATCH_SIZE, hidden_size,
-                         dtype=data_type, device=get_current_device(), requires_grad=True)
+    tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True)
     # print(f"tokens:\n{tokens}")
 
     router = Top2Router(1)
-    layer = MoeLayer(hidden_size, NUM_EXPERTS, router, nn.Identity())
+    expert = Experts(nn.Identity, 4)
+    layer = MoeLayer(hidden_size, NUM_EXPERTS, router, expert)
     if data_type == torch.float16:
         layer = layer.half()
     layer.cuda_mode = False

@@ -88,8 +86,12 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32):
 @pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
 def test_moe_top2(rs, hidden_size, data_type):
     world_size = 4
-    run_func = partial(run_routing, world_size=world_size, port=free_port(),
-                       rs=rs, hidden_size=hidden_size, data_type=data_type)
+    run_func = partial(run_routing,
+                       world_size=world_size,
+                       port=free_port(),
+                       rs=rs,
+                       hidden_size=hidden_size,
+                       data_type=data_type)
     mp.spawn(run_func, nprocs=world_size)