mirror of https://github.com/hpcaitech/ColossalAI
Added TPExpert for special situation
parent 36b8477228
commit 82023779bb

@@ -1,5 +1,8 @@
-from .experts import Experts, FFNExperts
+from .experts import Experts, FFNExperts, TPExperts
 from .layers import MoeLayer, Top1Router, Top2Router
-from .utils import NormalNoiseGenerator
+from .utils import NormalNoiseGenerator, build_ffn_experts

-__all__ = ['Experts', 'FFNExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator']
+__all__ = [
+    'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator',
+    'build_ffn_experts'
+]

@@ -15,6 +15,55 @@ except ImportError:
     print("If you want to activate cuda mode for MoE, please install with cuda_ext!")


+class AllGather(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor:
+
+        if ctx is not None:
+            ctx.parallel_mode = parallel_mode
+
+        comm_size = gpc.get_world_size(parallel_mode)
+        if comm_size == 1:
+            return inputs.unsqueeze(0)
+
+        buffer_shape = (comm_size,) + inputs.shape
+        outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
+        buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
+        dist.all_gather(buffer_list, inputs, group=gpc.get_group(parallel_mode))
+        return outputs
+
+    @staticmethod
+    def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
+        return ReduceScatter.forward(None, grad_outputs, ctx.parallel_mode), None
+
+
+class ReduceScatter(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor:
+
+        if ctx is not None:
+            ctx.parallel_mode = parallel_mode
+
+        comm_size = gpc.get_world_size(parallel_mode)
+        if comm_size == 1:
+            return inputs.squeeze(0)
+
+        if not inputs.is_contiguous():
+            inputs = inputs.contiguous()
+
+        output_shape = inputs.shape[1:]
+        outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
+        buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
+        dist.reduce_scatter(outputs, buffer_list, group=gpc.get_group(parallel_mode))
+        return outputs
+
+    @staticmethod
+    def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
+        return AllGather.forward(None, grad_outputs, ctx.parallel_mode), None
+
+
 class AllToAll(torch.autograd.Function):
     """Dispatches input tensor [e, c, h] to all experts by all_to_all_single
     operation in torch.distributed.
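
Note on the two new collectives: all-gather and reduce-scatter are adjoint operations, which is why each one's backward simply calls the other's forward with the saved parallel_mode. A minimal single-process sketch of that gradient relationship (plain PyTorch only, no gpc/ParallelMode plumbing; the world size w is emulated locally):

import torch

# Emulate a group of w ranks, each holding a local shard of shape [c, h].
w, c, h = 2, 3, 4
shards = [torch.randn(c, h) for _ in range(w)]

# Forward of AllGather: every rank ends up with the stacked tensor [w, c, h].
gathered = torch.stack(shards)                      # same result on every rank

# Backward of AllGather == forward of ReduceScatter: each rank receives a full
# gradient for `gathered`, the group sums them, and each rank keeps its own slice.
grads = [torch.randn(w, c, h) for _ in range(w)]    # one upstream gradient per rank
reduced = torch.stack(grads).sum(dim=0)             # the "reduce" step
grad_shard_rank0 = reduced[0]                       # the "scatter" step for rank 0
assert grad_shard_rank0.shape == shards[0].shape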

@@ -7,7 +7,16 @@ from colossalai.context import ParallelMode, seed
 from colossalai.utils import get_current_device


-class Experts(nn.Module):
+class MoeExperts(nn.Module):
+
+    def __init__(self, comm: str):
+        super().__init__()
+        assert comm in {"all_to_all", "all_gather"}, \
+            "This kind of communication has not been implemented yet.\n Please use Experts build function."
+        self.comm = comm
+
+
+class Experts(MoeExperts):
     """A wrapper class to create experts. It will create E experts across the
     moe model parallel group, where E is the number of experts. Every expert
     is an instance of the class 'expert' given in the initialization parameters.
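
The new MoeExperts base class only records which collective its subclass expects; MoeLayer later branches on this tag (see the layers.py hunk below). A hedged sketch of the contract, with a hypothetical IdentityExperts class used purely for illustration (MoeExperts is re-declared locally so the snippet runs standalone):

import torch.nn as nn

class MoeExperts(nn.Module):             # mirror of the class added above
    def __init__(self, comm: str):
        super().__init__()
        assert comm in {"all_to_all", "all_gather"}
        self.comm = comm

class IdentityExperts(MoeExperts):       # hypothetical subclass, illustration only
    def __init__(self):
        super().__init__("all_to_all")   # tells MoeLayer to dispatch with AllToAll

    def forward(self, inputs):           # inputs [e, c, h]: one chunk per local expert
        return inputs

experts = IdentityExperts()
print(experts.comm)                      # "all_to_all"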

@@ -20,19 +29,22 @@ class Experts(nn.Module):
     """

     def __init__(self, expert, num_experts, **expert_args):
-        super().__init__()
+        super().__init__("all_to_all")

         assert num_experts % moe_env.model_parallel_size == 0, \
             "The number of experts should be divided by moe model size"

         num_local_experts = num_experts // moe_env.model_parallel_size

         with seed(ParallelMode.MOE_MODEL):
             self.experts = nn.ModuleList([expert(**expert_args) for _ in range(num_local_experts)])
+        self.num_local_experts = num_local_experts

         for exp in self.experts:
             for param in exp.parameters():
                 param.__setattr__('moe_param', True)

-        self.num_local_experts = num_local_experts
-
     def forward(self, inputs):
         expert_input = torch.chunk(inputs, self.num_local_experts, dim=1)
         expert_output = []

@@ -44,10 +56,10 @@ class Experts(nn.Module):
         return output


-class FFNExperts(nn.Module):
+class FFNExperts(MoeExperts):

     def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
-        super().__init__()
+        super().__init__("all_to_all")

         assert num_experts % moe_env.model_parallel_size == 0, \
             "The number of experts should be divided by moe model size"

@@ -75,7 +87,7 @@ class FFNExperts(nn.Module):
         for param in self.parameters():
             param.__setattr__('moe_param', True)

-    def forward(self, inputs):  # x [g, el, c, h]
+    def forward(self, inputs):  # inputs [g, el, c, h]

         el = inputs.size(1)
         h = inputs.size(-1)

@@ -96,3 +108,58 @@ class FFNExperts(nn.Module):
         outputs = outputs.reshape(inshape)
         outputs = outputs.transpose(0, 1).contiguous()
         return outputs
+
+
+class TPExperts(MoeExperts):
+
+    def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
+        super().__init__("all_gather")
+
+        assert d_ff % moe_env.model_parallel_size == 0, \
+            "d_ff should be divided by moe model size"
+
+        p_ff = d_ff // moe_env.model_parallel_size
+
+        self.w1 = nn.Parameter(torch.empty(num_experts, d_model, p_ff, device=get_current_device()))
+        self.b1 = nn.Parameter(torch.empty(num_experts, 1, p_ff, device=get_current_device()))
+
+        self.w2 = nn.Parameter(torch.empty(num_experts, p_ff, d_model, device=get_current_device()))
+        self.b2 = nn.Parameter(torch.empty(num_experts, 1, d_model, device=get_current_device()))
+
+        s1 = math.sqrt(0.1 / d_model)
+        s2 = math.sqrt(0.1 / d_ff)
+
+        with seed(ParallelMode.MOE_MODEL):
+            nn.init.trunc_normal_(self.w1, std=s1)
+            nn.init.trunc_normal_(self.b1, std=s1)
+            nn.init.trunc_normal_(self.w2, std=s2)
+
+        nn.init.trunc_normal_(self.b2, std=s2)
+
+        self.act = nn.GELU() if activation is None else activation
+        self.drop = nn.Dropout(p=drop_rate)
+
+        self.w1.__setattr__('moe_param', True)
+        self.w2.__setattr__('moe_param', True)
+        self.b1.__setattr__('moe_param', True)
+
+    def forward(self, inputs):  # inputs [g, e, c, h]
+
+        e = inputs.size(1)
+        h = inputs.size(-1)
+
+        inputs = inputs.transpose(0, 1)
+        inshape = inputs.shape
+        inputs = inputs.reshape(e, -1, h)
+
+        out_ff = torch.baddbmm(self.b1, inputs, self.w1)
+        out_act = self.act(out_ff)
+        with seed(ParallelMode.TENSOR):
+            inter = self.drop(out_act)
+
+        out_model = torch.baddbmm(self.b2, inter, self.w2)
+        outputs = self.drop(out_model)  # outputs [e, gc, h]
+
+        outputs = outputs.reshape(inshape)
+        outputs = outputs.transpose(0, 1).contiguous()
+        return outputs  # outputs [g, e, c, h]
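
For the TPExperts math: every expert's FFN is evaluated in one batched call, with the hidden dimension d_ff split across tensor-parallel ranks (p_ff per rank), so each rank produces a partial output that the later ReduceScatter in tp_process sums up. A standalone shape check of the two baddbmm calls, with illustrative sizes and the activation/dropout simplified to relu:

import torch

e, g, c, h, p_ff = 4, 2, 8, 16, 32         # experts, groups, capacity, d_model, local d_ff shard

inputs = torch.randn(e, g * c, h)           # dispatched tokens, one batch row per expert
w1 = torch.randn(e, h, p_ff)                # first projection, column-split along d_ff
b1 = torch.randn(e, 1, p_ff)
w2 = torch.randn(e, p_ff, h)                # second projection, row-split along d_ff
b2 = torch.randn(e, 1, h)

hidden = torch.baddbmm(b1, inputs, w1)                  # [e, g*c, p_ff]: per-expert first projection
out = torch.baddbmm(b2, torch.relu(hidden), w2)         # [e, g*c, h]: this rank's partial output
assert out.shape == (e, g * c, h)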

@@ -8,7 +8,8 @@ from colossalai.core import global_context as gpc
 from colossalai.global_variables import moe_env
 from colossalai.context import ParallelMode
 from colossalai.utils import get_current_device
-from ._operation import U_CUDA_MODE, AllToAll, MoeDispatch, MoeCombine, moe_cumsum
+from ._operation import U_CUDA_MODE, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
+from .experts import MoeExperts
 from .utils import autocast_softmax

@@ -198,7 +199,7 @@ class MoeLayer(nn.Module):
     :type experts: nn.Module
     """

-    def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: nn.Module):
+    def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: MoeExperts):
         super().__init__()
         self.d_model = dim_model
         self.num_experts = num_experts

@@ -207,8 +208,8 @@ class MoeLayer(nn.Module):
         self.experts = experts
         self.cuda_mode = True if U_CUDA_MODE and moe_env.enable_cuda else False

-    def expert_part(self, expert_input: torch.Tensor):
-        expert_input = AllToAll.apply(expert_input, ParallelMode.MOE_MODEL)
+    def a2a_process(self, dispatch_data: torch.Tensor):
+        expert_input = AllToAll.apply(dispatch_data, ParallelMode.MOE_MODEL)

         input_shape = expert_input.shape

@@ -221,24 +222,42 @@ class MoeLayer(nn.Module):
         expert_output = AllToAll.apply(expert_output, ParallelMode.MOE_MODEL)
         return expert_output

+    def tp_process(self, dispatch_data: torch.Tensor):
+        expert_in = AllGather.apply(dispatch_data, ParallelMode.MOE_MODEL)
+        expert_out = self.experts(expert_in)
+        expert_out = ReduceScatter.apply(expert_out, ParallelMode.MOE_MODEL)
+        return expert_out
+
     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         tokens = inputs.reshape(-1, self.d_model)
         gate_output = self.gate(tokens)
         router_res = self.router(gate_output, self.cuda_mode)

         if self.cuda_mode:
-            logits, mask, dest_idx, ec = router_res
-            expert_input = MoeDispatch.apply(tokens, mask, dest_idx, ec)
-            expert_output = self.expert_part(expert_input)
-            ret = MoeCombine.apply(expert_output, logits, mask, dest_idx, ec)
+            dispatch_data = MoeDispatch.apply(tokens, *router_res[1:])
+            dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model)
         else:
-            combine_weights, sec_mask = router_res
-            sec_mask_f = sec_mask.type_as(inputs)
-            expert_input = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
-            expert_output = self.expert_part(expert_input)
+            sec_mask_f = router_res[1].type_as(inputs)
+            dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
+
+        # dispatch_data [e, c, h]
+        if self.experts.comm == "all_to_all":
+            expert_output = self.a2a_process(dispatch_data)
+        elif self.experts.comm == "all_gather":
+            expert_output = self.tp_process(dispatch_data)
+        else:
+            raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts "
+                                      "build function.")
+        # expert_output [e, c, h]
+
+        if self.cuda_mode:
+            expert_output = expert_output.reshape(-1, self.d_model)
+            ans = MoeCombine.apply(expert_output, *router_res)
+        else:
+            combine_weights = router_res[0]
             combine_weights = combine_weights.view(combine_weights.shape[0], -1)
             expert_output = expert_output.view(-1, expert_output.shape[-1])
-            ret = torch.matmul(combine_weights, expert_output)
+            ans = torch.matmul(combine_weights, expert_output)

-        ret = ret.reshape(inputs.shape)
-        return ret
+        ans = ans.reshape(inputs.shape)
+        return ans
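
The reworked forward first builds an explicit dispatch_data tensor of shape [e, c, h], then picks the communication path from experts.comm, and finally combines the expert outputs back into token order. A hedged shape sketch of the non-cuda_mode dispatch and combine matmuls (sec_mask_f and the identity experts are random stand-ins, not a real router or expert):

import torch

t, e, c, h = 12, 4, 6, 8                      # tokens, experts, capacity, d_model
tokens = torch.randn(t, h)

sec_mask_f = torch.rand(t, e, c)              # stand-in for the router's [t, e, c] assignment mask
dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)    # [e, c, t] @ [t, h] -> [e, c, h]

expert_output = dispatch_data                 # pretend the experts are the identity
combine_weights = sec_mask_f.reshape(t, -1)   # [t, e*c]
ans = torch.matmul(combine_weights, expert_output.reshape(-1, h))    # [t, e*c] @ [e*c, h] -> [t, h]
assert ans.shape == tokens.shape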

@@ -1,6 +1,8 @@
 import torch
 import torch.nn.functional as F
 from colossalai.utils import get_current_device
+from colossalai.global_variables import moe_env
+from .experts import FFNExperts, TPExperts


 class NormalNoiseGenerator:

@@ -14,10 +16,9 @@ class NormalNoiseGenerator:
     """

     def __init__(self, num_experts: int):
-        self.normal = torch.distributions.normal.Normal(
-            loc=torch.tensor(0.0, device=get_current_device()),
-            scale=torch.tensor(1.0 / num_experts ** 2, device=get_current_device())
-        ).rsample
+        self.normal = torch.distributions.normal.Normal(loc=torch.tensor(0.0, device=get_current_device()),
+                                                         scale=torch.tensor(1.0 / num_experts**2,
+                                                                            device=get_current_device())).rsample

     def __call__(self, inputs: torch.Tensor):
         noisy = self.normal(inputs.shape)

@@ -30,3 +31,13 @@ def autocast_softmax(inputs: torch.Tensor, dim: int):
     sm_input = inputs.to(torch.float32) if fp16_flag else inputs
     sm_output = F.softmax(sm_input, dim)
     return sm_output
+
+
+def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
+    moe_mp_size = moe_env.model_parallel_size
+    if num_experts % moe_mp_size == 0:
+        return FFNExperts(num_experts, d_model, d_ff, activation, drop_rate)
+    elif d_ff % moe_mp_size == 0:
+        return TPExperts(num_experts, d_model, d_ff, activation, drop_rate)
+    else:
+        raise NotImplementedError(f"Cannot build {num_experts} experts in {moe_mp_size} GPUs.")
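
build_ffn_experts is just a selector: if the expert count divides evenly across the MoE model-parallel group it keeps whole experts per rank (FFNExperts, all-to-all), otherwise it falls back to splitting d_ff inside every expert (TPExperts, all-gather). A hedged sketch of that decision, with the parallel size passed as a plain argument instead of reading moe_env:

def pick_expert_layout(num_experts: int, d_ff: int, moe_mp_size: int) -> str:
    # Illustrative helper, not part of the library.
    if num_experts % moe_mp_size == 0:
        return "FFNExperts / all_to_all"    # each rank owns num_experts // moe_mp_size whole experts
    elif d_ff % moe_mp_size == 0:
        return "TPExperts / all_gather"     # each rank owns a d_ff // moe_mp_size slice of every expert
    else:
        raise NotImplementedError

print(pick_expert_layout(num_experts=4, d_ff=3072, moe_mp_size=4))   # FFNExperts / all_to_all
print(pick_expert_layout(num_experts=6, d_ff=3072, moe_mp_size=4))   # TPExperts / all_gather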

@@ -4,7 +4,7 @@ import torch.nn as nn
 from colossalai.context import ParallelMode
 from colossalai.nn.layer import VanillaPatchEmbedding, VanillaClassifier, \
     WrappedDropout as Dropout, WrappedDropPath as DropPath
-from colossalai.nn.layer.moe import FFNExperts, MoeLayer, Top2Router, NormalNoiseGenerator
+from colossalai.nn.layer.moe import build_ffn_experts, MoeLayer, Top2Router, NormalNoiseGenerator
 from .util import moe_sa_args, moe_mlp_args
 from ..helper import TransformerLayer
 from colossalai.global_variables import moe_env

@@ -110,7 +110,7 @@ class Widenet(nn.Module):

         noisy_func = NormalNoiseGenerator(num_experts)
         shared_router = Top2Router(capacity_factor, noisy_func=noisy_func)
-        shared_experts = FFNExperts(num_experts, d_model, d_ff, drop_rate=drop_rate)
+        shared_experts = build_ffn_experts(num_experts, d_model, d_ff, drop_rate=drop_rate)

         # stochastic depth decay rule
         dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]

@@ -177,7 +177,7 @@ class ViTMoE(nn.Module):
             ffn = VanillaFFN(**moe_mlp_args(
                 d_model=d_model, d_ff=d_ff, drop_rate=drop_rate)) if i % 2 == 0 else \
                 MoeLayer(dim_model=d_model, num_experts=num_experts, router=router,
-                         experts=FFNExperts(num_experts, d_model, d_ff, drop_rate=drop_rate))
+                         experts=build_ffn_experts(num_experts, d_model, d_ff, drop_rate=drop_rate))
             layer = TransformerLayer(att=sa,
                                      ffn=ffn,
                                      norm1=nn.LayerNorm(d_model, eps=1e-6),

@@ -1,6 +1,4 @@
-import os
 from functools import partial
-from pathlib import Path
 import pytest
 import torch
 import torch.nn as nn

@@ -9,10 +7,10 @@ import colossalai
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils import free_port, get_current_device
-from colossalai.nn.layer.moe import Top2Router, MoeLayer
+from colossalai.nn.layer.moe import Top2Router, MoeLayer, Experts
 from colossalai.context.random import moe_set_seed
 from colossalai.global_variables import moe_env

 BATCH_SIZE = 32
 NUM_EXPERTS = 4
 CONFIG = dict(parallel=dict(moe=dict(size=4)))

@@ -24,17 +22,17 @@ def check_equal(A, B, atol=1e-06):

 def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

     moe_set_seed(42)
     # torch.set_printoptions(precision=30)
     torch.backends.cuda.matmul.allow_tf32 = False
     local_rank = gpc.get_local_rank(ParallelMode.GLOBAL)
     torch.manual_seed(rs + local_rank)
     moe_env.reset_loss()
-    tokens = torch.randn(BATCH_SIZE, hidden_size,
-                         dtype=data_type, device=get_current_device(), requires_grad=True)
+    tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True)
     # print(f"tokens:\n{tokens}")
     router = Top2Router(1)
-    layer = MoeLayer(hidden_size, NUM_EXPERTS, router, nn.Identity())
+    expert = Experts(nn.Identity, 4)
+    layer = MoeLayer(hidden_size, NUM_EXPERTS, router, expert)
     if data_type == torch.float16:
         layer = layer.half()
     layer.cuda_mode = False
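
The bare nn.Identity() argument had to become Experts(nn.Identity, 4) because MoeLayer.forward now reads experts.comm to pick between a2a_process and tp_process. A tiny standalone check of why a plain module no longer qualifies:

import torch.nn as nn

plain = nn.Identity()
print(hasattr(plain, "comm"))   # False: MoeLayer would fail before reaching any dispatch branch
# Experts(nn.Identity, 4) subclasses MoeExperts, so it carries comm == "all_to_all".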

@@ -88,8 +86,12 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32):
 @pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
 def test_moe_top2(rs, hidden_size, data_type):
     world_size = 4
-    run_func = partial(run_routing, world_size=world_size, port=free_port(),
-                       rs=rs, hidden_size=hidden_size, data_type=data_type)
+    run_func = partial(run_routing,
+                       world_size=world_size,
+                       port=free_port(),
+                       rs=rs,
+                       hidden_size=hidden_size,
+                       data_type=data_type)
     mp.spawn(run_func, nprocs=world_size)