[MOE] support PR-MOE (#488)

pull/492/head
HELSON 2022-03-22 16:48:22 +08:00 committed by GitHub
parent a9ecb4b244
commit c9023d4078
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 133 additions and 18 deletions

View File

@ -1,8 +1,8 @@
from .experts import Experts, FFNExperts, TPExperts
from .layers import MoeLayer, Top1Router, Top2Router
from .layers import MoeLayer, Top1Router, Top2Router, MoeModule
from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts
__all__ = [
'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator',
'UniformNoiseGenerator', 'build_ffn_experts'
'UniformNoiseGenerator', 'build_ffn_experts', 'MoeModule'
]

View File

@ -7,9 +7,9 @@ import torch.distributed as dist
from colossalai.core import MOE_CONTEXT
from colossalai.utils import get_current_device
from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
from .experts import MoeExperts
from .utils import ForceFP32Parameter
from typing import Callable, Optional
from .experts import MoeExperts, Experts
from .utils import ForceFP32Parameter, UniformNoiseGenerator, NormalNoiseGenerator
from typing import Callable, Optional, Type
from torch.distributed import ProcessGroup
@ -315,3 +315,100 @@ class MoeLayer(nn.Module):
ans = ans.reshape(inputs.shape)
return ans
class MoeModule(nn.Module):
"""A class for users to create MoE modules in their models.
Args:
dim_model (int): Hidden dimension of training model
num_experts (int): The number experts
top_k (int, optional): The number of experts for dispatchment of each token
capacity_factor_train (float, optional): Capacity factor in routing during training
capacity_factor_eval (float, optional): Capacity factor in routing during evaluation
min_capacity (int, optional): The minimum number of the capacity of each expert
noisy_policy (str, optional): The policy of noisy function. Now we have 'Jitter' and 'Gaussian'.
'Jitter' can be found in Switch Transformer paper (https://arxiv.org/abs/2101.03961).
'Gaussian' can be found in ViT-MoE paper (https://arxiv.org/abs/2106.05974).
drop_tks (bool, optional): Whether drops tokens in evaluation
use_residual (bool, optional): Makes this MoE layer a Residual MoE.
More information can be found in Microsoft paper (https://arxiv.org/abs/2201.05596).
residual_instance (nn.Module, optional): The instance of residual module in Resiual MoE
expert_instance (MoeExperts, optional): The instance of experts module in MoeLayer
expert_cls (Type[nn.Module], optional): The class of each expert when no instance is given
expert_args (optional): The args of expert when no instance is given
"""
def __init__(self,
dim_model: int,
num_experts: int,
top_k: int = 1,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
noisy_policy: Optional[str] = None,
drop_tks: bool = True,
use_residual: bool = False,
residual_instance: Optional[nn.Module] = None,
expert_instance: Optional[MoeExperts] = None,
expert_cls: Optional[Type[nn.Module]] = None,
**expert_args):
super().__init__()
noisy_func = None
if noisy_policy is not None:
if noisy_policy == 'Jitter':
noisy_func = UniformNoiseGenerator()
elif noisy_policy == 'Gaussian':
noisy_func = NormalNoiseGenerator(num_experts)
else:
raise NotImplementedError("Unsupported input noisy policy")
if top_k == 1:
moe_router_cls = Top1Router
elif top_k == 2:
moe_router_cls = Top2Router
else:
raise NotImplementedError("top_k > 2 is not supported yet")
self.moe_router = moe_router_cls(capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
min_capacity=min_capacity,
noisy_func=noisy_func,
drop_tks=drop_tks)
self.use_residual = use_residual
if use_residual:
if residual_instance is not None:
self.residual_module = residual_instance
else:
assert expert_cls is not None, \
"Expert class can't be None when residual instance is not given"
self.residual_module = expert_cls(**expert_args)
self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device())
if expert_instance is not None:
self.experts = expert_instance
else:
assert expert_cls is not None, \
"Expert class can't be None when experts instance is not given"
self.experts = Experts(expert_cls, num_experts, **expert_args)
self.moe_layer = MoeLayer(dim_model=dim_model,
num_experts=num_experts,
router=self.moe_router,
experts=self.experts)
def forward(self, inputs: torch.Tensor):
moe_output = self.moe_layer(inputs)
if self.use_residual:
residual_output = self.residual_module(inputs)
combine_coef = self.residual_combine(inputs)
combine_coef = F.softmax(combine_coef, dim=-1)
output = moe_output * combine_coef[..., 0:1] + residual_output * combine_coef[..., 1:]
else:
output = moe_output
return output

View File

@ -4,11 +4,12 @@ import torch.nn as nn
from colossalai.context import ParallelMode
from colossalai.nn.layer import VanillaPatchEmbedding, VanillaClassifier, \
WrappedDropout as Dropout, WrappedDropPath as DropPath
from colossalai.nn.layer.moe import build_ffn_experts, MoeLayer, Top2Router, NormalNoiseGenerator
from colossalai.nn.layer.moe import build_ffn_experts, MoeLayer, Top2Router, NormalNoiseGenerator, MoeModule
from .util import moe_sa_args, moe_mlp_args
from ..helper import TransformerLayer
from colossalai.core import MOE_CONTEXT
from colossalai.utils import get_current_device
from typing import List
class VanillaSelfAttention(nn.Module):
@ -146,7 +147,8 @@ class Widenet(nn.Module):
class ViTMoE(nn.Module):
def __init__(self,
num_experts: int,
num_experts: int or List[int],
use_residual: bool = False,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
drop_tks: bool = True,
@ -164,29 +166,45 @@ class ViTMoE(nn.Module):
drop_path: float = 0.):
super().__init__()
assert depth % 2 == 0, "The number of layers should be even right now"
if isinstance(num_experts, list):
assert len(num_experts) == depth // 2, \
"The length of num_experts should equal to the number of MOE layers"
num_experts_list = num_experts
else:
num_experts_list = [num_experts] * (depth // 2)
embedding = VanillaPatchEmbedding(img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_size=d_model)
embed_dropout = Dropout(p=drop_rate, mode=ParallelMode.TENSOR)
noisy_func = NormalNoiseGenerator(num_experts)
router = Top2Router(capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
noisy_func=noisy_func,
drop_tks=drop_tks)
assert depth % 2 == 0
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
blocks = []
for i in range(depth):
sa = VanillaSelfAttention(**moe_sa_args(
d_model=d_model, n_heads=num_heads, d_kv=d_kv, attention_drop=attention_drop, drop_rate=drop_rate))
ffn = VanillaFFN(**moe_mlp_args(
d_model=d_model, d_ff=d_ff, drop_rate=drop_rate)) if i % 2 == 0 else \
MoeLayer(dim_model=d_model, num_experts=num_experts, router=router,
experts=build_ffn_experts(num_experts, d_model, d_ff, drop_rate=drop_rate))
if i % 2 == 0:
ffn = VanillaFFN(**moe_mlp_args(d_model=d_model, d_ff=d_ff, drop_rate=drop_rate))
else:
num_experts = num_experts_list[i // 2]
experts = build_ffn_experts(num_experts, d_model, d_ff, drop_rate=drop_rate)
ffn = MoeModule(dim_model=d_model,
num_experts=num_experts,
top_k=1 if use_residual else 2,
capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
noisy_policy='Jitter' if use_residual else 'Gaussian',
drop_tks=drop_tks,
use_residual=use_residual,
expert_instance=experts,
expert_cls=VanillaFFN,
**moe_mlp_args(d_model=d_model, d_ff=d_ff, drop_rate=drop_rate))
layer = TransformerLayer(att=sa,
ffn=ffn,
norm1=nn.LayerNorm(d_model, eps=1e-6),