
[MOE] support PR-MOE (#488)

Branch: pull/492/head
Author: HELSON, 3 years ago, committed by GitHub
Commit: c9023d4078
Changed files:
  1. colossalai/nn/layer/moe/__init__.py (4 changed lines)
  2. colossalai/nn/layer/moe/layers.py (103 changed lines)
  3. model_zoo/moe/models.py (44 changed lines)

colossalai/nn/layer/moe/__init__.py (4 changed lines)

@@ -1,8 +1,8 @@
 from .experts import Experts, FFNExperts, TPExperts
-from .layers import MoeLayer, Top1Router, Top2Router
+from .layers import MoeLayer, Top1Router, Top2Router, MoeModule
 from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts
 
 __all__ = [
     'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator',
-    'UniformNoiseGenerator', 'build_ffn_experts'
+    'UniformNoiseGenerator', 'build_ffn_experts', 'MoeModule'
 ]
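
With MoeModule re-exported here, downstream code can import it directly from the package namespace; for example (illustrative one-liner, not part of the diff):

    from colossalai.nn.layer.moe import MoeModule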

colossalai/nn/layer/moe/layers.py (103 changed lines)

@@ -7,9 +7,9 @@ import torch.distributed as dist
 from colossalai.core import MOE_CONTEXT
 from colossalai.utils import get_current_device
 from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
-from .experts import MoeExperts
-from .utils import ForceFP32Parameter
-from typing import Callable, Optional
+from .experts import MoeExperts, Experts
+from .utils import ForceFP32Parameter, UniformNoiseGenerator, NormalNoiseGenerator
+from typing import Callable, Optional, Type
 from torch.distributed import ProcessGroup
@@ -315,3 +315,100 @@ class MoeLayer(nn.Module):
         ans = ans.reshape(inputs.shape)
 
         return ans
+
+
+class MoeModule(nn.Module):
+    """A class for users to create MoE modules in their models.
+
+    Args:
+        dim_model (int): Hidden dimension of the training model
+        num_experts (int): The number of experts
+        top_k (int, optional): The number of experts to which each token is dispatched
+        capacity_factor_train (float, optional): Capacity factor in routing during training
+        capacity_factor_eval (float, optional): Capacity factor in routing during evaluation
+        min_capacity (int, optional): The minimum capacity of each expert
+        noisy_policy (str, optional): The policy of the noise function. Currently 'Jitter' and 'Gaussian' are supported.
+            'Jitter' can be found in the Switch Transformer paper (https://arxiv.org/abs/2101.03961).
+            'Gaussian' can be found in the ViT-MoE paper (https://arxiv.org/abs/2106.05974).
+        drop_tks (bool, optional): Whether to drop tokens during evaluation
+        use_residual (bool, optional): Whether to make this MoE layer a Residual MoE.
+            More information can be found in the Microsoft paper (https://arxiv.org/abs/2201.05596).
+        residual_instance (nn.Module, optional): The instance of the residual module in a Residual MoE
+        expert_instance (MoeExperts, optional): The instance of the experts module in the MoeLayer
+        expert_cls (Type[nn.Module], optional): The class of each expert when no instance is given
+        expert_args (optional): The arguments of the expert class when no instance is given
+    """
+
+    def __init__(self,
+                 dim_model: int,
+                 num_experts: int,
+                 top_k: int = 1,
+                 capacity_factor_train: float = 1.25,
+                 capacity_factor_eval: float = 2.0,
+                 min_capacity: int = 4,
+                 noisy_policy: Optional[str] = None,
+                 drop_tks: bool = True,
+                 use_residual: bool = False,
+                 residual_instance: Optional[nn.Module] = None,
+                 expert_instance: Optional[MoeExperts] = None,
+                 expert_cls: Optional[Type[nn.Module]] = None,
+                 **expert_args):
+        super().__init__()
+
+        noisy_func = None
+        if noisy_policy is not None:
+            if noisy_policy == 'Jitter':
+                noisy_func = UniformNoiseGenerator()
+            elif noisy_policy == 'Gaussian':
+                noisy_func = NormalNoiseGenerator(num_experts)
+            else:
+                raise NotImplementedError("Unsupported input noisy policy")
+
+        if top_k == 1:
+            moe_router_cls = Top1Router
+        elif top_k == 2:
+            moe_router_cls = Top2Router
+        else:
+            raise NotImplementedError("top_k > 2 is not supported yet")
+
+        self.moe_router = moe_router_cls(capacity_factor_train=capacity_factor_train,
+                                         capacity_factor_eval=capacity_factor_eval,
+                                         min_capacity=min_capacity,
+                                         noisy_func=noisy_func,
+                                         drop_tks=drop_tks)
+        self.use_residual = use_residual
+        if use_residual:
+            if residual_instance is not None:
+                self.residual_module = residual_instance
+            else:
+                assert expert_cls is not None, \
+                    "Expert class can't be None when residual instance is not given"
+                self.residual_module = expert_cls(**expert_args)
+
+            self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device())
+
+        if expert_instance is not None:
+            self.experts = expert_instance
+        else:
+            assert expert_cls is not None, \
+                "Expert class can't be None when experts instance is not given"
+            self.experts = Experts(expert_cls, num_experts, **expert_args)
+
+        self.moe_layer = MoeLayer(dim_model=dim_model,
+                                  num_experts=num_experts,
+                                  router=self.moe_router,
+                                  experts=self.experts)
+
+    def forward(self, inputs: torch.Tensor):
+        moe_output = self.moe_layer(inputs)
+
+        if self.use_residual:
+            residual_output = self.residual_module(inputs)
+            combine_coef = self.residual_combine(inputs)
+            combine_coef = F.softmax(combine_coef, dim=-1)
+            output = moe_output * combine_coef[..., 0:1] + residual_output * combine_coef[..., 1:]
+        else:
+            output = moe_output
+
+        return output
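
For context, a minimal usage sketch of the new MoeModule API (not part of this commit). It assumes a Colossal-AI distributed/MoE environment has already been initialized (e.g. via colossalai.launch, which sets up MOE_CONTEXT) and that a CUDA device is available, since the experts and the residual gate are created on the current device. ToyFFN is a hypothetical expert class used only for illustration; its constructor arguments are forwarded through **expert_args.

    import torch
    import torch.nn as nn
    from colossalai.nn.layer.moe import MoeModule

    class ToyFFN(nn.Module):
        # Hypothetical expert: any nn.Module mapping (..., d_model) -> (..., d_model) will do.
        def __init__(self, d_model: int, d_ff: int):
            super().__init__()
            self.net = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))

        def forward(self, x):
            return self.net(x)

    # top_k=1 with noisy_policy='Jitter' and use_residual=True mirrors the PR-MoE setup
    # used in model_zoo/moe/models.py below; top_k=2 would select Top2Router instead.
    moe_ffn = MoeModule(dim_model=768,
                        num_experts=8,
                        top_k=1,
                        noisy_policy='Jitter',
                        use_residual=True,
                        expert_cls=ToyFFN,
                        d_model=768, d_ff=3072)    # forwarded to ToyFFN via **expert_args

    tokens = torch.randn(4, 16, 768)               # (batch, seq, hidden)
    out = moe_ffn(tokens)                          # same shape as the input

When use_residual=True, the forward pass softmax-gates per token between the MoE output and the dense residual branch, which is the Residual MoE (PR-MoE) design from the paper linked in the docstring.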

model_zoo/moe/models.py (44 changed lines)

@@ -4,11 +4,12 @@ import torch.nn as nn
 from colossalai.context import ParallelMode
 from colossalai.nn.layer import VanillaPatchEmbedding, VanillaClassifier, \
     WrappedDropout as Dropout, WrappedDropPath as DropPath
-from colossalai.nn.layer.moe import build_ffn_experts, MoeLayer, Top2Router, NormalNoiseGenerator
+from colossalai.nn.layer.moe import build_ffn_experts, MoeLayer, Top2Router, NormalNoiseGenerator, MoeModule
 from .util import moe_sa_args, moe_mlp_args
 from ..helper import TransformerLayer
 from colossalai.core import MOE_CONTEXT
 from colossalai.utils import get_current_device
+from typing import List
 
 
 class VanillaSelfAttention(nn.Module):
@@ -146,7 +147,8 @@ class Widenet(nn.Module):
 class ViTMoE(nn.Module):
 
     def __init__(self,
-                 num_experts: int,
+                 num_experts: int or List[int],
+                 use_residual: bool = False,
                  capacity_factor_train: float = 1.25,
                  capacity_factor_eval: float = 2.0,
                  drop_tks: bool = True,
@@ -164,29 +166,45 @@ class ViTMoE(nn.Module):
                  drop_path: float = 0.):
         super().__init__()
 
+        assert depth % 2 == 0, "The number of layers should be even right now"
+        if isinstance(num_experts, list):
+            assert len(num_experts) == depth // 2, \
+                "The length of num_experts should equal to the number of MOE layers"
+            num_experts_list = num_experts
+        else:
+            num_experts_list = [num_experts] * (depth // 2)
+
         embedding = VanillaPatchEmbedding(img_size=img_size,
                                           patch_size=patch_size,
                                           in_chans=in_chans,
                                           embed_size=d_model)
         embed_dropout = Dropout(p=drop_rate, mode=ParallelMode.TENSOR)
-        noisy_func = NormalNoiseGenerator(num_experts)
-        router = Top2Router(capacity_factor_train=capacity_factor_train,
-                            capacity_factor_eval=capacity_factor_eval,
-                            noisy_func=noisy_func,
-                            drop_tks=drop_tks)
-        assert depth % 2 == 0
+
         # stochastic depth decay rule
         dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
         blocks = []
         for i in range(depth):
             sa = VanillaSelfAttention(**moe_sa_args(
                 d_model=d_model, n_heads=num_heads, d_kv=d_kv, attention_drop=attention_drop, drop_rate=drop_rate))
-            ffn = VanillaFFN(**moe_mlp_args(
-                d_model=d_model, d_ff=d_ff, drop_rate=drop_rate)) if i % 2 == 0 else \
-                MoeLayer(dim_model=d_model, num_experts=num_experts, router=router,
-                         experts=build_ffn_experts(num_experts, d_model, d_ff, drop_rate=drop_rate))
+
+            if i % 2 == 0:
+                ffn = VanillaFFN(**moe_mlp_args(d_model=d_model, d_ff=d_ff, drop_rate=drop_rate))
+            else:
+                num_experts = num_experts_list[i // 2]
+                experts = build_ffn_experts(num_experts, d_model, d_ff, drop_rate=drop_rate)
+                ffn = MoeModule(dim_model=d_model,
+                                num_experts=num_experts,
+                                top_k=1 if use_residual else 2,
+                                capacity_factor_train=capacity_factor_train,
+                                capacity_factor_eval=capacity_factor_eval,
+                                noisy_policy='Jitter' if use_residual else 'Gaussian',
+                                drop_tks=drop_tks,
+                                use_residual=use_residual,
+                                expert_instance=experts,
+                                expert_cls=VanillaFFN,
+                                **moe_mlp_args(d_model=d_model, d_ff=d_ff, drop_rate=drop_rate))
+
             layer = TransformerLayer(att=sa,
                                      ffn=ffn,
                                      norm1=nn.LayerNorm(d_model, eps=1e-6),
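
As a side note, here is a small standalone sketch (not from the commit) of the expert schedule this constructor builds: dense FFN layers and MoE layers alternate, and when num_experts is a list each MoE layer i takes its expert count from num_experts_list[i // 2], which is how a PR-MoE style pyramid of expert counts is expressed.

    # Standalone illustration of the alternating dense/MoE schedule in ViTMoE above.
    depth = 12
    num_experts = [4, 4, 8, 8, 16, 16]     # one entry per MoE layer, i.e. depth // 2 entries
    num_experts_list = num_experts if isinstance(num_experts, list) else [num_experts] * (depth // 2)

    for i in range(depth):
        kind = "dense FFN" if i % 2 == 0 else f"MoE FFN with {num_experts_list[i // 2]} experts"
        print(f"layer {i:2d}: {kind}")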
