mirror of https://github.com/hpcaitech/ColossalAI
[MOE] support PR-MOE (#488)
parent a9ecb4b244
commit c9023d4078
@@ -1,8 +1,8 @@
 from .experts import Experts, FFNExperts, TPExperts
-from .layers import MoeLayer, Top1Router, Top2Router
+from .layers import MoeLayer, Top1Router, Top2Router, MoeModule
 from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts
 
 __all__ = [
     'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator',
-    'UniformNoiseGenerator', 'build_ffn_experts'
+    'UniformNoiseGenerator', 'build_ffn_experts', 'MoeModule'
 ]
@@ -7,9 +7,9 @@ import torch.distributed as dist
 from colossalai.core import MOE_CONTEXT
 from colossalai.utils import get_current_device
 from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
-from .experts import MoeExperts
-from .utils import ForceFP32Parameter
-from typing import Callable, Optional
+from .experts import MoeExperts, Experts
+from .utils import ForceFP32Parameter, UniformNoiseGenerator, NormalNoiseGenerator
+from typing import Callable, Optional, Type
 from torch.distributed import ProcessGroup
 
 
@@ -315,3 +315,100 @@ class MoeLayer(nn.Module):
 
         ans = ans.reshape(inputs.shape)
         return ans
+
+
+class MoeModule(nn.Module):
+    """A class for users to create MoE modules in their models.
+
+    Args:
+        dim_model (int): Hidden dimension of the training model
+        num_experts (int): The number of experts
+        top_k (int, optional): The number of experts each token is dispatched to
+        capacity_factor_train (float, optional): Capacity factor in routing during training
+        capacity_factor_eval (float, optional): Capacity factor in routing during evaluation
+        min_capacity (int, optional): The minimum capacity of each expert
+        noisy_policy (str, optional): The policy of the noisy function. Two policies are supported: 'Jitter' and 'Gaussian'.
+            'Jitter' can be found in the Switch Transformer paper (https://arxiv.org/abs/2101.03961).
+            'Gaussian' can be found in the ViT-MoE paper (https://arxiv.org/abs/2106.05974).
+        drop_tks (bool, optional): Whether to drop tokens during evaluation
+        use_residual (bool, optional): Makes this MoE layer a Residual MoE.
+            More information can be found in the Microsoft paper (https://arxiv.org/abs/2201.05596).
+        residual_instance (nn.Module, optional): The instance of the residual module in Residual MoE
+        expert_instance (MoeExperts, optional): The instance of the experts module in MoeLayer
+        expert_cls (Type[nn.Module], optional): The class of each expert when no instance is given
+        expert_args (optional): The args of each expert when no instance is given
+    """
+
+    def __init__(self,
+                 dim_model: int,
+                 num_experts: int,
+                 top_k: int = 1,
+                 capacity_factor_train: float = 1.25,
+                 capacity_factor_eval: float = 2.0,
+                 min_capacity: int = 4,
+                 noisy_policy: Optional[str] = None,
+                 drop_tks: bool = True,
+                 use_residual: bool = False,
+                 residual_instance: Optional[nn.Module] = None,
+                 expert_instance: Optional[MoeExperts] = None,
+                 expert_cls: Optional[Type[nn.Module]] = None,
+                 **expert_args):
+        super().__init__()
+
+        noisy_func = None
+        if noisy_policy is not None:
+            if noisy_policy == 'Jitter':
+                noisy_func = UniformNoiseGenerator()
+            elif noisy_policy == 'Gaussian':
+                noisy_func = NormalNoiseGenerator(num_experts)
+            else:
+                raise NotImplementedError("Unsupported input noisy policy")
+
+        if top_k == 1:
+            moe_router_cls = Top1Router
+        elif top_k == 2:
+            moe_router_cls = Top2Router
+        else:
+            raise NotImplementedError("top_k > 2 is not supported yet")
+
+        self.moe_router = moe_router_cls(capacity_factor_train=capacity_factor_train,
+                                         capacity_factor_eval=capacity_factor_eval,
+                                         min_capacity=min_capacity,
+                                         noisy_func=noisy_func,
+                                         drop_tks=drop_tks)
+
+        self.use_residual = use_residual
+        if use_residual:
+            if residual_instance is not None:
+                self.residual_module = residual_instance
+            else:
+                assert expert_cls is not None, \
+                    "Expert class can't be None when residual instance is not given"
+                self.residual_module = expert_cls(**expert_args)
+
+            self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device())
+
+        if expert_instance is not None:
+            self.experts = expert_instance
+        else:
+            assert expert_cls is not None, \
+                "Expert class can't be None when experts instance is not given"
+            self.experts = Experts(expert_cls, num_experts, **expert_args)
+
+        self.moe_layer = MoeLayer(dim_model=dim_model,
+                                  num_experts=num_experts,
+                                  router=self.moe_router,
+                                  experts=self.experts)
+
+    def forward(self, inputs: torch.Tensor):
+        moe_output = self.moe_layer(inputs)
+
+        if self.use_residual:
+            residual_output = self.residual_module(inputs)
+            combine_coef = self.residual_combine(inputs)
+            combine_coef = F.softmax(combine_coef, dim=-1)
+            output = moe_output * combine_coef[..., 0:1] + residual_output * combine_coef[..., 1:]
+        else:
+            output = moe_output
+
+        return output
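For context, here is a minimal usage sketch of the new `MoeModule` (not part of the diff). It assumes the MoE parallel context has already been initialized via `colossalai.launch`; `SimpleFFN` and the dimension values are illustrative stand-ins, not names from this PR:

```python
# Hypothetical usage of the new MoeModule. Assumes the ColossalAI MoE
# environment (MOE_CONTEXT) has already been set up by colossalai.launch.
import torch
import torch.nn as nn
from colossalai.nn.layer.moe import MoeModule

class SimpleFFN(nn.Module):
    """Illustrative expert: a plain two-layer feed-forward block."""
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))

    def forward(self, x):
        return self.net(x)

# PR-MoE style block: top-1 routing with 'Jitter' noise plus a shared
# residual expert, mirroring what this PR wires up in ViTMoE.
moe_ffn = MoeModule(dim_model=768,
                    num_experts=8,
                    top_k=1,
                    noisy_policy='Jitter',
                    use_residual=True,
                    expert_cls=SimpleFFN,
                    d_model=768,   # forwarded to SimpleFFN via **expert_args
                    d_ff=3072)

tokens = torch.randn(2, 16, 768)   # (batch, sequence, hidden)
out = moe_ffn(tokens)              # output keeps the input shape
```

With `use_residual=True`, the forward pass learns a per-token softmax over two coefficients and blends the routed expert output with the dense residual branch, which is the Residual MoE formulation from the paper linked in the docstring.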
@@ -4,11 +4,12 @@ import torch.nn as nn
 from colossalai.context import ParallelMode
 from colossalai.nn.layer import VanillaPatchEmbedding, VanillaClassifier, \
     WrappedDropout as Dropout, WrappedDropPath as DropPath
-from colossalai.nn.layer.moe import build_ffn_experts, MoeLayer, Top2Router, NormalNoiseGenerator
+from colossalai.nn.layer.moe import build_ffn_experts, MoeLayer, Top2Router, NormalNoiseGenerator, MoeModule
 from .util import moe_sa_args, moe_mlp_args
 from ..helper import TransformerLayer
 from colossalai.core import MOE_CONTEXT
 from colossalai.utils import get_current_device
+from typing import List
 
 
 class VanillaSelfAttention(nn.Module):
@@ -146,7 +147,8 @@ class Widenet(nn.Module):
 class ViTMoE(nn.Module):
 
     def __init__(self,
-                 num_experts: int,
+                 num_experts: int or List[int],
+                 use_residual: bool = False,
                  capacity_factor_train: float = 1.25,
                  capacity_factor_eval: float = 2.0,
                  drop_tks: bool = True,
@@ -164,29 +166,45 @@ class ViTMoE(nn.Module):
                  drop_path: float = 0.):
         super().__init__()
 
+        assert depth % 2 == 0, "The number of layers should be even right now"
+
+        if isinstance(num_experts, list):
+            assert len(num_experts) == depth // 2, \
+                "The length of num_experts should equal the number of MoE layers"
+            num_experts_list = num_experts
+        else:
+            num_experts_list = [num_experts] * (depth // 2)
+
         embedding = VanillaPatchEmbedding(img_size=img_size,
                                           patch_size=patch_size,
                                           in_chans=in_chans,
                                           embed_size=d_model)
         embed_dropout = Dropout(p=drop_rate, mode=ParallelMode.TENSOR)
 
-        noisy_func = NormalNoiseGenerator(num_experts)
-        router = Top2Router(capacity_factor_train=capacity_factor_train,
-                            capacity_factor_eval=capacity_factor_eval,
-                            noisy_func=noisy_func,
-                            drop_tks=drop_tks)
-        assert depth % 2 == 0
-
         # stochastic depth decay rule
         dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
         blocks = []
         for i in range(depth):
             sa = VanillaSelfAttention(**moe_sa_args(
                 d_model=d_model, n_heads=num_heads, d_kv=d_kv, attention_drop=attention_drop, drop_rate=drop_rate))
-            ffn = VanillaFFN(**moe_mlp_args(
-                d_model=d_model, d_ff=d_ff, drop_rate=drop_rate)) if i % 2 == 0 else \
-                MoeLayer(dim_model=d_model, num_experts=num_experts, router=router,
-                         experts=build_ffn_experts(num_experts, d_model, d_ff, drop_rate=drop_rate))
+
+            if i % 2 == 0:
+                ffn = VanillaFFN(**moe_mlp_args(d_model=d_model, d_ff=d_ff, drop_rate=drop_rate))
+            else:
+                num_experts = num_experts_list[i // 2]
+                experts = build_ffn_experts(num_experts, d_model, d_ff, drop_rate=drop_rate)
+                ffn = MoeModule(dim_model=d_model,
+                                num_experts=num_experts,
+                                top_k=1 if use_residual else 2,
+                                capacity_factor_train=capacity_factor_train,
+                                capacity_factor_eval=capacity_factor_eval,
+                                noisy_policy='Jitter' if use_residual else 'Gaussian',
+                                drop_tks=drop_tks,
+                                use_residual=use_residual,
+                                expert_instance=experts,
+                                expert_cls=VanillaFFN,
+                                **moe_mlp_args(d_model=d_model, d_ff=d_ff, drop_rate=drop_rate))
 
             layer = TransformerLayer(att=sa,
                                      ffn=ffn,
                                      norm1=nn.LayerNorm(d_model, eps=1e-6),
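Since `num_experts` may now be a per-MoE-layer list, a pyramid of expert counts (as in PR-MoE) can be expressed directly. A construction sketch follows; the import path, the `depth` keyword, and all values are assumptions for illustration, not taken from this diff:

```python
# Hypothetical: build a PR-MoE ViT with a pyramid of expert counts.
# The import path depends on the examples layout and is an assumption here.
from model_zoo.moe.models import ViTMoE

# With depth=12, every odd transformer layer is a MoE layer, so
# num_experts needs depth // 2 = 6 entries, typically growing with depth.
pyramid = [4, 4, 8, 8, 16, 16]

model = ViTMoE(num_experts=pyramid,          # one entry per MoE layer
               use_residual=True,            # top-1 routing + residual FFN branch
               capacity_factor_train=1.25,
               capacity_factor_eval=2.0,
               drop_tks=True,
               depth=12)                     # assumed to be a constructor argument
```

Passing a plain `int` keeps the old behavior: the same expert count is broadcast to every MoE layer.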