From c9023d4078a4e9701a1a23df8214e5534947c959 Mon Sep 17 00:00:00 2001
From: HELSON <72907851+1SAA@users.noreply.github.com>
Date: Tue, 22 Mar 2022 16:48:22 +0800
Subject: [PATCH] [MOE] support PR-MOE (#488)

---
 colossalai/nn/layer/moe/__init__.py |   4 +-
 colossalai/nn/layer/moe/layers.py   | 103 +++++++++++++++++++++++++++-
 model_zoo/moe/models.py             |  44 ++++++++----
 3 files changed, 133 insertions(+), 18 deletions(-)

diff --git a/colossalai/nn/layer/moe/__init__.py b/colossalai/nn/layer/moe/__init__.py
index 36977ee05..14b3a7ee4 100644
--- a/colossalai/nn/layer/moe/__init__.py
+++ b/colossalai/nn/layer/moe/__init__.py
@@ -1,8 +1,8 @@
 from .experts import Experts, FFNExperts, TPExperts
-from .layers import MoeLayer, Top1Router, Top2Router
+from .layers import MoeLayer, Top1Router, Top2Router, MoeModule
 from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts
 
 __all__ = [
     'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator',
-    'UniformNoiseGenerator', 'build_ffn_experts'
+    'UniformNoiseGenerator', 'build_ffn_experts', 'MoeModule'
 ]
diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/nn/layer/moe/layers.py
index 7903b6286..ebd8b4f79 100644
--- a/colossalai/nn/layer/moe/layers.py
+++ b/colossalai/nn/layer/moe/layers.py
@@ -7,9 +7,9 @@ import torch.distributed as dist
 from colossalai.core import MOE_CONTEXT
 from colossalai.utils import get_current_device
 from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
-from .experts import MoeExperts
-from .utils import ForceFP32Parameter
-from typing import Callable, Optional
+from .experts import MoeExperts, Experts
+from .utils import ForceFP32Parameter, UniformNoiseGenerator, NormalNoiseGenerator
+from typing import Callable, Optional, Type
 from torch.distributed import ProcessGroup
 
 
@@ -315,3 +315,100 @@ class MoeLayer(nn.Module):
         ans = ans.reshape(inputs.shape)
 
         return ans
+
+
+class MoeModule(nn.Module):
+    """A class for users to create MoE modules in their models.
+
+    Args:
+        dim_model (int): Hidden dimension of the training model
+        num_experts (int): The number of experts
+        top_k (int, optional): The number of experts each token is dispatched to
+        capacity_factor_train (float, optional): Capacity factor in routing during training
+        capacity_factor_eval (float, optional): Capacity factor in routing during evaluation
+        min_capacity (int, optional): The minimum capacity of each expert
+        noisy_policy (str, optional): The policy of the routing noise. Currently 'Jitter' and 'Gaussian' are supported.
+            'Jitter' can be found in the Switch Transformer paper (https://arxiv.org/abs/2101.03961).
+            'Gaussian' can be found in the ViT-MoE paper (https://arxiv.org/abs/2106.05974).
+        drop_tks (bool, optional): Whether to drop tokens during evaluation
+        use_residual (bool, optional): Makes this MoE layer a Residual MoE.
+            More information can be found in the DeepSpeed-MoE paper (https://arxiv.org/abs/2201.05596).
+        residual_instance (nn.Module, optional): The instance of the residual module in Residual MoE
+        expert_instance (MoeExperts, optional): The instance of the experts module in MoeLayer
+        expert_cls (Type[nn.Module], optional): The class of each expert when no instance is given
+        expert_args (optional): The args of each expert when no instance is given
+    """
+
+    def __init__(self,
+                 dim_model: int,
+                 num_experts: int,
+                 top_k: int = 1,
+                 capacity_factor_train: float = 1.25,
+                 capacity_factor_eval: float = 2.0,
+                 min_capacity: int = 4,
+                 noisy_policy: Optional[str] = None,
+                 drop_tks: bool = True,
+                 use_residual: bool = False,
+                 residual_instance: Optional[nn.Module] = None,
+                 expert_instance: Optional[MoeExperts] = None,
+                 expert_cls: Optional[Type[nn.Module]] = None,
+                 **expert_args):
+        super().__init__()
+
+        noisy_func = None
+        if noisy_policy is not None:
+            if noisy_policy == 'Jitter':
+                noisy_func = UniformNoiseGenerator()
+            elif noisy_policy == 'Gaussian':
+                noisy_func = NormalNoiseGenerator(num_experts)
+            else:
+                raise NotImplementedError("Unsupported input noisy policy")
+
+        if top_k == 1:
+            moe_router_cls = Top1Router
+        elif top_k == 2:
+            moe_router_cls = Top2Router
+        else:
+            raise NotImplementedError("top_k > 2 is not supported yet")
+
+        self.moe_router = moe_router_cls(capacity_factor_train=capacity_factor_train,
+                                         capacity_factor_eval=capacity_factor_eval,
+                                         min_capacity=min_capacity,
+                                         noisy_func=noisy_func,
+                                         drop_tks=drop_tks)
+
+        self.use_residual = use_residual
+        if use_residual:
+            if residual_instance is not None:
+                self.residual_module = residual_instance
+            else:
+                assert expert_cls is not None, \
+                    "Expert class can't be None when residual instance is not given"
+                self.residual_module = expert_cls(**expert_args)
+
+            self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device())
+
+        if expert_instance is not None:
+            self.experts = expert_instance
+        else:
+            assert expert_cls is not None, \
+                "Expert class can't be None when experts instance is not given"
+            self.experts = Experts(expert_cls, num_experts, **expert_args)
+
+        self.moe_layer = MoeLayer(dim_model=dim_model,
+                                  num_experts=num_experts,
+                                  router=self.moe_router,
+                                  experts=self.experts)
+
+    def forward(self, inputs: torch.Tensor):
+        moe_output = self.moe_layer(inputs)
+
+        if self.use_residual:
+            residual_output = self.residual_module(inputs)
+            combine_coef = self.residual_combine(inputs)
+            combine_coef = F.softmax(combine_coef, dim=-1)
+            output = moe_output * combine_coef[..., 0:1] + residual_output * combine_coef[..., 1:]
+        else:
+            output = moe_output
+
+        return output
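Before the model-zoo changes that wire this module into ViTMoE, here is a minimal usage sketch of the new MoeModule. It is illustrative only: SimpleFFN is a stand-in expert class that is not part of this patch, the hidden sizes are arbitrary, and it assumes Colossal-AI has already been launched with the MoE context initialized so that MoeLayer and Experts can query MOE_CONTEXT.

import torch
import torch.nn as nn

from colossalai.nn.layer.moe import MoeModule


class SimpleFFN(nn.Module):
    """Stand-in expert: any module mapping (..., d_model) -> (..., d_model) can serve."""

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


# Plain top-2 MoE layer: 8 experts with Gaussian routing noise; extra kwargs are
# forwarded to each expert's constructor via **expert_args.
moe_ffn = MoeModule(dim_model=768,
                    num_experts=8,
                    top_k=2,
                    noisy_policy='Gaussian',
                    expert_cls=SimpleFFN,
                    d_model=768,
                    d_ff=3072)

# PR-MoE flavour: top-1 routing plus a shared residual expert; the two branches are
# blended per token by a learned softmax over the two outputs of residual_combine.
pr_moe_ffn = MoeModule(dim_model=768,
                       num_experts=8,
                       top_k=1,
                       noisy_policy='Jitter',
                       use_residual=True,
                       expert_cls=SimpleFFN,
                       d_model=768,
                       d_ff=3072)

x = torch.randn(4, 196, 768)    # (batch, tokens, dim_model), placed on the device the MoE setup expects
y = pr_moe_ffn(x)               # output has the same shape as x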
diff --git a/model_zoo/moe/models.py b/model_zoo/moe/models.py
index 627f2a4d0..e9659a347 100644
--- a/model_zoo/moe/models.py
+++ b/model_zoo/moe/models.py
@@ -4,11 +4,12 @@ import torch.nn as nn
 from colossalai.context import ParallelMode
 from colossalai.nn.layer import VanillaPatchEmbedding, VanillaClassifier, \
     WrappedDropout as Dropout, WrappedDropPath as DropPath
-from colossalai.nn.layer.moe import build_ffn_experts, MoeLayer, Top2Router, NormalNoiseGenerator
+from colossalai.nn.layer.moe import build_ffn_experts, MoeLayer, Top2Router, NormalNoiseGenerator, MoeModule
 from .util import moe_sa_args, moe_mlp_args
 from ..helper import TransformerLayer
 from colossalai.core import MOE_CONTEXT
 from colossalai.utils import get_current_device
+from typing import List, Union
 
 
 class VanillaSelfAttention(nn.Module):
@@ -146,7 +147,8 @@ class Widenet(nn.Module):
 class ViTMoE(nn.Module):
 
     def __init__(self,
-                 num_experts: int,
+                 num_experts: Union[int, List[int]],
+                 use_residual: bool = False,
                  capacity_factor_train: float = 1.25,
                  capacity_factor_eval: float = 2.0,
                  drop_tks: bool = True,
@@ -164,29 +166,45 @@ class ViTMoE(nn.Module):
                  drop_path: float = 0.):
         super().__init__()
 
+        assert depth % 2 == 0, "The number of layers should be even right now"
+
+        if isinstance(num_experts, list):
+            assert len(num_experts) == depth // 2, \
+                "The length of num_experts should equal the number of MoE layers"
+            num_experts_list = num_experts
+        else:
+            num_experts_list = [num_experts] * (depth // 2)
+
         embedding = VanillaPatchEmbedding(img_size=img_size,
                                           patch_size=patch_size,
                                           in_chans=in_chans,
                                           embed_size=d_model)
         embed_dropout = Dropout(p=drop_rate, mode=ParallelMode.TENSOR)
-        noisy_func = NormalNoiseGenerator(num_experts)
-        router = Top2Router(capacity_factor_train=capacity_factor_train,
-                            capacity_factor_eval=capacity_factor_eval,
-                            noisy_func=noisy_func,
-                            drop_tks=drop_tks)
-        assert depth % 2 == 0
-
         # stochastic depth decay rule
         dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
         blocks = []
         for i in range(depth):
             sa = VanillaSelfAttention(**moe_sa_args(
                 d_model=d_model, n_heads=num_heads, d_kv=d_kv, attention_drop=attention_drop, drop_rate=drop_rate))
-            ffn = VanillaFFN(**moe_mlp_args(
-                d_model=d_model, d_ff=d_ff, drop_rate=drop_rate)) if i % 2 == 0 else \
-                MoeLayer(dim_model=d_model, num_experts=num_experts, router=router,
-                         experts=build_ffn_experts(num_experts, d_model, d_ff, drop_rate=drop_rate))
+
+            if i % 2 == 0:
+                ffn = VanillaFFN(**moe_mlp_args(d_model=d_model, d_ff=d_ff, drop_rate=drop_rate))
+            else:
+                num_experts = num_experts_list[i // 2]
+                experts = build_ffn_experts(num_experts, d_model, d_ff, drop_rate=drop_rate)
+                ffn = MoeModule(dim_model=d_model,
+                                num_experts=num_experts,
+                                top_k=1 if use_residual else 2,
+                                capacity_factor_train=capacity_factor_train,
+                                capacity_factor_eval=capacity_factor_eval,
+                                noisy_policy='Jitter' if use_residual else 'Gaussian',
+                                drop_tks=drop_tks,
+                                use_residual=use_residual,
+                                expert_instance=experts,
+                                expert_cls=VanillaFFN,
+                                **moe_mlp_args(d_model=d_model, d_ff=d_ff, drop_rate=drop_rate))
+
             layer = TransformerLayer(att=sa,
                                      ffn=ffn,
                                      norm1=nn.LayerNorm(d_model, eps=1e-6),
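As a usage note for the model-zoo change, the sketch below shows how the updated ViTMoE could be constructed with PR-MoE enabled. The depth and expert counts are illustrative, the remaining constructor arguments are assumed to keep their model-zoo defaults (including 224x224 RGB inputs), and the usual Colossal-AI launch plus MoE context initialization is assumed to have been done beforehand.

import torch

from model_zoo.moe.models import ViTMoE

# With depth=12, every second block uses an MoE FFN, so there are depth // 2 = 6 MoE layers.
# Passing a list of per-layer expert counts produces the pyramid structure used by PR-MoE;
# use_residual=True switches each MoE layer to top-1 routing with a shared residual expert.
model = ViTMoE(num_experts=[4, 4, 8, 8, 16, 16],
               use_residual=True,
               depth=12,
               d_model=768,
               num_heads=12)

images = torch.randn(2, 3, 224, 224, device=next(model.parameters()).device)
logits = model(images)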