@@ -7,9 +9,9 @@ import torch.distributed as dist
 from colossalai.core import MOE_CONTEXT
 from colossalai.utils import get_current_device
 from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
-from .experts import MoeExperts
-from .utils import ForceFP32Parameter
-from typing import Callable, Optional
+from .experts import MoeExperts, Experts
+from .utils import ForceFP32Parameter, UniformNoiseGenerator, NormalNoiseGenerator
+from typing import Callable, Optional, Type
+from torch.distributed import ProcessGroup
@@ -315,3 +315,100 @@ class MoeLayer(nn.Module):
         ans = ans.reshape(inputs.shape)
         return ans
+
+
+class MoeModule(nn.Module):
+    """A class for users to create MoE modules in their models.
+
+    Args:
+        dim_model (int): Hidden dimension of the training model
+        num_experts (int): The number of experts
+        top_k (int, optional): The number of experts to which each token is dispatched
+        capacity_factor_train (float, optional): Capacity factor in routing during training
+        capacity_factor_eval (float, optional): Capacity factor in routing during evaluation
+        min_capacity (int, optional): The minimum capacity of each expert
+        noisy_policy (str, optional): The policy of the noisy function. Currently 'Jitter' and 'Gaussian' are supported.
+            'Jitter' can be found in the Switch Transformer paper (https://arxiv.org/abs/2101.03961) and
+            'Gaussian' in the ViT-MoE paper (https://arxiv.org/abs/2106.05974).
+        drop_tks (bool, optional): Whether to drop tokens during evaluation
+        use_residual (bool, optional): Makes this MoE layer a Residual MoE.
+            More information can be found in the DeepSpeed-MoE paper from Microsoft (https://arxiv.org/abs/2201.05596).
+        residual_instance (nn.Module, optional): The instance of the residual module in Residual MoE
+        expert_instance (MoeExperts, optional): The instance of the experts module in MoeLayer
+        expert_cls (Type[nn.Module], optional): The class of each expert when no instance is given
+        expert_args (optional): The args used to initialize each expert when no instance is given
+    """
+
+    def __init__(self,
+                 dim_model: int,
+                 num_experts: int,
+                 top_k: int = 1,
+                 capacity_factor_train: float = 1.25,
+                 capacity_factor_eval: float = 2.0,
+                 min_capacity: int = 4,
+                 noisy_policy: Optional[str] = None,
+                 drop_tks: bool = True,
+                 use_residual: bool = False,
+                 residual_instance: Optional[nn.Module] = None,
+                 expert_instance: Optional[MoeExperts] = None,
+                 expert_cls: Optional[Type[nn.Module]] = None,
+                 **expert_args):
+        super().__init__()
+
+        # Build the optional noise generator applied by the router
+        noisy_func = None
+        if noisy_policy is not None:
+            if noisy_policy == 'Jitter':
+                noisy_func = UniformNoiseGenerator()
+            elif noisy_policy == 'Gaussian':
+                noisy_func = NormalNoiseGenerator(num_experts)
+            else:
+                raise NotImplementedError("Unsupported input noisy policy")
+
+        # Choose the router class according to top_k
+        if top_k == 1:
+            moe_router_cls = Top1Router
+        elif top_k == 2:
+            moe_router_cls = Top2Router
+        else:
+            raise NotImplementedError("top_k > 2 is not supported yet")
+
+        self.moe_router = moe_router_cls(capacity_factor_train=capacity_factor_train,
+                                         capacity_factor_eval=capacity_factor_eval,
+                                         min_capacity=min_capacity,
+                                         noisy_func=noisy_func,
+                                         drop_tks=drop_tks)
+
+        self.use_residual = use_residual
+        if use_residual:
+            if residual_instance is not None:
+                self.residual_module = residual_instance
+            else:
+                assert expert_cls is not None, \
+                    "Expert class can't be None when residual instance is not given"
+                self.residual_module = expert_cls(**expert_args)
+
+            # Gate producing the two coefficients that mix the MoE and residual branches
+            self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device())
+
+        if expert_instance is not None:
+            self.experts = expert_instance
+        else:
+            assert expert_cls is not None, \
+                "Expert class can't be None when experts instance is not given"
+            self.experts = Experts(expert_cls, num_experts, **expert_args)
+
+        self.moe_layer = MoeLayer(dim_model=dim_model,
+                                  num_experts=num_experts,
+                                  router=self.moe_router,
+                                  experts=self.experts)
+
+    def forward(self, inputs: torch.Tensor):
+        moe_output = self.moe_layer(inputs)
+
+        if self.use_residual:
+            residual_output = self.residual_module(inputs)
+            # Mix the two branches with softmax-normalized, per-token coefficients
+            combine_coef = self.residual_combine(inputs)
+            combine_coef = F.softmax(combine_coef, dim=-1)
+            output = moe_output * combine_coef[..., 0:1] + residual_output * combine_coef[..., 1:]
+        else:
+            output = moe_output
+
+        return output
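
For reference, a minimal sketch of what the two noisy policies mentioned in MoeModule's docstring usually amount to. The real generators are colossalai's UniformNoiseGenerator and NormalNoiseGenerator (imported in the first hunk); the epsilon and noise scale below follow the cited papers in shape but are assumptions, not necessarily the library's exact defaults.

import torch

def jitter(router_input: torch.Tensor, eps: float = 1e-2) -> torch.Tensor:
    # 'Jitter' (Switch Transformer): scale the router input elementwise by
    # uniform noise drawn from [1 - eps, 1 + eps].
    return router_input * torch.empty_like(router_input).uniform_(1.0 - eps, 1.0 + eps)

def gaussian(router_logits: torch.Tensor, num_experts: int) -> torch.Tensor:
    # 'Gaussian' (ViT-MoE): add zero-mean normal noise whose scale shrinks
    # with the number of experts (the exact scale here is an assumption).
    return router_logits + torch.randn_like(router_logits) / num_experts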
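
The capacity_factor_train / capacity_factor_eval and min_capacity arguments bound how many tokens each expert may receive per batch. A common formulation from GShard and Switch Transformer is sketched below; the exact rounding used by Top1Router and Top2Router may differ.

def expert_capacity(num_tokens: int, num_experts: int,
                    capacity_factor: float, min_capacity: int) -> int:
    # Tokens beyond an expert's capacity are dropped (see drop_tks) or
    # rerouted, depending on the router.
    return max(int(capacity_factor * num_tokens / num_experts), min_capacity)

# Example: 4096 tokens over 8 experts with factor 1.25 -> 640 tokens per expert.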
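
A usage sketch of the new module. The SimpleFFN expert class and the import path are assumptions for illustration; any nn.Module whose constructor accepts the trailing keyword arguments should work, and a real run additionally requires colossalai's MOE_CONTEXT / distributed environment to be initialized first.

import torch
import torch.nn as nn

from colossalai.nn.layer.moe import MoeModule  # import path is an assumption


class SimpleFFN(nn.Module):
    # Hypothetical expert: a two-layer feed-forward block.
    def __init__(self, hidden_size: int, ffn_size: int):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(hidden_size, ffn_size),
                                 nn.GELU(),
                                 nn.Linear(ffn_size, hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


moe = MoeModule(dim_model=512,
                num_experts=8,
                top_k=2,
                noisy_policy='Jitter',
                use_residual=True,
                expert_cls=SimpleFFN,
                hidden_size=512,    # forwarded to SimpleFFN via **expert_args
                ffn_size=2048)

tokens = torch.randn(4, 128, 512)   # (batch, seq_len, dim_model)
out = moe(tokens)                   # same shape as the input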