2022-03-29 09:57:59 +00:00
|
|
|
import functools
|
2022-01-07 07:08:36 +00:00
|
|
|
import math
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn.functional as F
|
2022-02-18 12:42:31 +00:00
|
|
|
import torch.distributed as dist
|
2022-03-23 10:03:39 +00:00
|
|
|
from colossalai.context.moe_context import MOE_CONTEXT
|
2022-01-07 07:08:36 +00:00
|
|
|
from colossalai.utils import get_current_device
|
2022-03-19 07:36:25 +00:00
|
|
|
from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
|
2022-03-22 08:48:22 +00:00
|
|
|
from .experts import MoeExperts, Experts
|
2022-03-31 10:34:11 +00:00
|
|
|
from .utils import ForceFP32Parameter, UniformNoiseGenerator, NormalNoiseGenerator, autocast_softmax
|
2022-03-29 09:57:59 +00:00
|
|
|
from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator
|
2022-03-22 08:48:22 +00:00
|
|
|
from typing import Callable, Optional, Type
|
2022-03-19 07:36:25 +00:00
|
|
|
from torch.distributed import ProcessGroup
|
2022-01-07 07:08:36 +00:00
|
|
|
|
|
|
|
|
|
|
|
class Top1Router(nn.Module):
    """Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
    for routing usage. More detailed function can be found in the paper about Switch Transformer
    of Google.

    Every token is routed to the expert with the highest gate score. Each expert accepts at
    most ``capacity`` tokens; the mask entries of the remaining tokens are zeroed.

    Args:
        capacity_factor_train (float, optional): Capacity factor in routing of training.
        capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
        min_capacity (int, optional): The minimum number of the capacity of each expert.
        select_policy (str, optional): The policy about tokens selection, either ``"first"``
            (keep tokens in order until the expert is full) or ``"random"`` (keep a random
            subset of the tokens routed to each expert).
        noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
            Only applied while the module is in training mode.
        drop_tks (bool, optional): Whether drops tokens in evaluation
    """

    def __init__(self,
                 capacity_factor_train: float = 1.25,
                 capacity_factor_eval: float = 2.0,
                 min_capacity: int = 4,
                 select_policy: str = "first",
                 noisy_func: Optional[Callable] = None,
                 drop_tks: bool = True):
        super().__init__()
        self.capacity_factor_train = capacity_factor_train
        self.capacity_factor_eval = capacity_factor_eval
        self.min_capacity = min_capacity
        self.select_policy = select_policy
        self.noisy_func = noisy_func
        self.drop_tks = drop_tks

        assert select_policy in {"first", "random"}
        if select_policy == "random":
            # Sampler of U[0, 1) scores, created on the current device, used to rank
            # the tokens competing for one expert.
            self.uniform = torch.distributions.uniform.Uniform(
                low=torch.tensor(0.0, device=get_current_device()),
                high=torch.tensor(1.0, device=get_current_device())).rsample

    def get_capacity(
        self,
        logits_shape,
    ):
        """Return the number of tokens each expert may receive.

        Args:
            logits_shape: Shape of the gate logits; ``logits_shape[-2]`` is read as the
                token count and ``logits_shape[-1]`` as the expert count.

        Returns:
            int: ``floor(capacity_factor * tokens / experts)`` rounded up to an even
            number and floored at ``min_capacity``.
        """
        capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
        capacity = math.floor(capacity_factor * logits_shape[-2] / logits_shape[-1])
        capacity += capacity % 2    # round odd values up to the next even number
        capacity = max(capacity, self.min_capacity)
        assert capacity > 0
        return capacity

    @staticmethod
    def _no_drop_capacity(mask: torch.Tensor, ep_group: Optional[ProcessGroup]) -> int:
        """Return the capacity needed so that no token is dropped.

        This is the largest per-expert token count in ``mask`` ([s, e]), reduced with MAX
        over ``ep_group`` when a process group is initialised.
        """
        max_num = torch.max(torch.sum(mask, dim=0))
        # all_reduce raises without an initialised process group; in that case there is
        # nothing to synchronise and the local maximum is already the answer.
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
        return max_num.item()

    def _keep_random(self, mask: torch.Tensor, capacity: int) -> torch.Tensor:
        """Randomly keep at most ``capacity`` routed tokens per expert.

        Args:
            mask (torch.Tensor): Integer one-hot routing mask of shape [s, e].
            capacity (int): Maximum number of tokens each expert may keep.

        Returns:
            torch.Tensor: ``mask`` with the entries of the unselected tokens zeroed.
        """
        rand_mask = mask * self.uniform(mask.shape)
        # topk raises when k exceeds the size of the reduced dimension, which happens
        # whenever the capacity (e.g. the min_capacity floor) is larger than the number
        # of tokens. Every routed token fits in that case, so clamping k loses nothing.
        k = min(capacity, mask.size(0))
        _, dispatch_idx = torch.topk(rand_mask, k=k, dim=0)
        # Rows picked for an expert they were not routed to are removed again by the
        # multiplication with the original mask.
        return mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)

    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
        """Route every token to its top-1 expert.

        Args:
            inputs (torch.Tensor): Gate scores of shape [s, e].
            use_kernel (bool, optional): Selects the return format (see below).
            ep_group (ProcessGroup, optional): Group used to synchronise the capacity
                when tokens are not dropped in evaluation.

        Returns:
            If ``use_kernel`` is true: ``(logits, mask, dest_idx, num_experts * capacity)``
            where ``mask`` and ``dest_idx`` are int32 tensors of shape [1, s].
            Otherwise: ``(combine_weights, sec_mask)``, both of shape [s, e, c], where
            ``sec_mask`` is the boolean view of ``combine_weights``.
        """
        if self.noisy_func is not None and self.training:
            inputs = self.noisy_func(inputs)

        logits = autocast_softmax(inputs, dim=-1)
        num_experts = logits.size(-1)
        capacity = self.get_capacity(logits.shape)

        top1_idx = torch.argmax(inputs, dim=-1)
        mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)

        if self.training:
            # Auxiliary balancing loss: mean gate probability times the fraction of
            # tokens routed to each expert, summed and scaled by the expert count.
            me = torch.mean(logits, dim=0)
            ce = torch.mean(mask.float(), dim=0)
            l_aux = num_experts * torch.sum(me * ce)
            MOE_CONTEXT.add_loss(l_aux)
        elif not self.drop_tks:
            capacity = self._no_drop_capacity(mask, ep_group)

        # NOTE(review): moe_cumsum presumably yields each token's zero-based position
        # within its expert column -- confirm against ._operation.
        if self.select_policy == "random":
            mask = self._keep_random(mask, capacity)
            ranks = moe_cumsum(mask)
        elif self.select_policy == "first":
            ranks = moe_cumsum(mask)
            mask = mask * torch.lt(ranks, capacity)
        else:
            raise NotImplementedError("Not support such select policy yet.")

        # Collapse the expert dimension: one slot index per token (0 for dropped tokens).
        ranks = torch.sum(mask * ranks, dim=-1)

        if use_kernel:
            mask = torch.sum(mask, dim=-1)
            mask = torch.stack([mask], dim=0).to(torch.int32)
            dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
            return logits, mask, dest_idx, num_experts * capacity
        else:
            ranks = F.one_hot(ranks, num_classes=capacity)
            weight = mask * logits.type_as(inputs)
            combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
            sec_mask = combine_weights.bool()
            return combine_weights, sec_mask
|
2022-01-07 07:08:36 +00:00
|
|
|
|
|
|
|
|
|
|
|
class Top2Router(nn.Module):
    """Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
    for routing usage. More detailed function can be found in the paper about ViT-MoE.

    Every token is routed to the two experts with the highest gate probabilities.
    Second-choice tokens are queued behind all first-choice tokens of the same expert,
    and tokens whose position reaches ``capacity`` are dropped for that expert.

    Args:
        capacity_factor_train (float, optional): Capacity factor in routing of training.
        capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
        min_capacity (int, optional): The minimum number of the capacity of each expert
        noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
            Only applied while the module is in training mode.
        drop_tks (bool, optional): Whether drops tokens in evaluation.
    """

    def __init__(self,
                 capacity_factor_train: float = 1.25,
                 capacity_factor_eval: float = 2.0,
                 min_capacity: int = 4,
                 noisy_func: Optional[Callable] = None,
                 drop_tks: bool = True):
        super().__init__()
        self.capacity_factor_train = capacity_factor_train
        self.capacity_factor_eval = capacity_factor_eval
        self.min_capacity = min_capacity
        self.noisy_func = noisy_func
        self.drop_tks = drop_tks

    def get_capacity(
        self,
        logits_shape,
    ):
        """Return the number of tokens each expert may receive.

        Args:
            logits_shape: Shape of the gate logits; ``logits_shape[-2]`` is read as the
                token count and ``logits_shape[-1]`` as the expert count.

        Returns:
            int: ``floor(capacity_factor * tokens / experts)`` rounded up to an even
            number and floored at ``min_capacity``.
        """
        capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
        capacity = math.floor(capacity_factor * logits_shape[-2] / logits_shape[-1])
        capacity += capacity % 2    # round odd values up to the next even number
        capacity = max(capacity, self.min_capacity)
        assert capacity > 0
        return capacity

    @staticmethod
    def _no_drop_capacity(cmask: torch.Tensor, ep_group: Optional[ProcessGroup]) -> int:
        """Return the capacity needed so that no token is dropped.

        This is the largest per-expert token count in ``cmask`` ([s, e]), reduced with MAX
        over ``ep_group`` when a process group is initialised.
        """
        max_num = torch.max(torch.sum(cmask, dim=0))
        # all_reduce raises without an initialised process group; in that case there is
        # nothing to synchronise and the local maximum is already the answer.
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
        return max_num.item()

    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
        """Route every token to its top-2 experts.

        Args:
            inputs (torch.Tensor): Gate scores of shape [s, e].
            use_kernel (bool, optional): Selects the return format (see below).
            ep_group (ProcessGroup, optional): Group used to synchronise the capacity
                when tokens are not dropped in evaluation.

        Returns:
            If ``use_kernel`` is true: ``(logits, mask, dest_idx, num_experts * capacity)``
            where ``mask`` and ``dest_idx`` are int32 tensors of shape [2, s] (first and
            second choice). Otherwise: ``(cb_weight, sec_mask)``, both of shape
            [s, e, c], where ``sec_mask`` is the boolean view of ``cb_weight``.
        """
        # inputs: [s, e]
        if self.noisy_func is not None and self.training:
            inputs = self.noisy_func(inputs)

        logits = autocast_softmax(inputs, dim=-1)    # logits: [s, e]
        num_experts = logits.size(-1)
        capacity = self.get_capacity(logits.shape)

        top1_idx = torch.argmax(logits, dim=-1)
        mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
        # Hide the first choice so the next argmax yields the runner-up expert.
        logits_except1 = logits.masked_fill(mask1.bool(), float("-inf"))
        top2_idx = torch.argmax(logits_except1, dim=-1)
        mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)

        cmask = (mask1 + mask2)    # loss: [s, e]
        if self.training:
            me = torch.mean(logits, dim=0)
            ce = torch.mean(cmask.float(), dim=0)
            l_aux = num_experts * torch.sum(me * ce) / 2.0    # div 2 to normalize it to 1
            MOE_CONTEXT.add_loss(l_aux)
        elif not self.drop_tks:
            capacity = self._no_drop_capacity(cmask, ep_group)

        # NOTE(review): moe_cumsum presumably yields each token's zero-based position
        # within its expert column -- confirm against ._operation.
        rank1 = moe_cumsum(mask1)    # rank1: [s, e]
        rank2 = moe_cumsum(mask2)
        # Second choices start after all first choices of the same expert.
        rank2 += torch.sum(mask1, dim=-2, keepdim=True)

        mask1 *= torch.lt(rank1, capacity)
        mask2 *= torch.lt(rank2, capacity)

        # Collapse the expert dimension: one slot index per token and choice.
        rank1 = torch.sum(mask1 * rank1, dim=-1)
        rank2 = torch.sum(mask2 * rank2, dim=-1)

        if use_kernel:
            mask1 = torch.sum(mask1, dim=-1)
            mask2 = torch.sum(mask2, dim=-1)

            mask = torch.stack([mask1, mask2], dim=0).to(torch.int32)
            dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32)

            return logits, mask, dest_idx, num_experts * capacity
        else:
            weight1 = mask1 * logits.type_as(inputs)
            weight2 = mask2 * logits.type_as(inputs)
            rank1_sc = F.one_hot(rank1, num_classes=capacity)
            rank2_sc = F.one_hot(rank2, num_classes=capacity)

            cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1)
            cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1)
            cb_weight = cb_weight1 + cb_weight2
            sec_mask = cb_weight.bool()

            return cb_weight, sec_mask
|
2022-01-07 07:08:36 +00:00
|
|
|
|
|
|
|
|
2022-03-29 09:57:59 +00:00
|
|
|
class FP32LinearGate(nn.Module):
    """Bias-free linear gate for MoE layers.

    The gate weight is intended to be kept in fp32 at all times.

    Args:
        d_model (int): Hidden dimension of training model
        num_experts (int): The number experts
        scale (float, optional): Controls the initialisation spread; the weight is drawn
            from a truncated normal with ``std = sqrt(scale / d_model)``.

    Attributes:
        weight (ForceFP32Parameter): The weight of linear gate, shape [num_experts, d_model]
    """

    def __init__(self, d_model: int, num_experts: int, scale: float = 0.1):
        super().__init__()
        gate_storage = torch.empty(num_experts, d_model, device=get_current_device())
        self.weight = ForceFP32Parameter(gate_storage)
        init_std = math.sqrt(scale / d_model)
        nn.init.trunc_normal_(self.weight, std=init_std)

    def forward(self, x: torch.Tensor):
        """Project ``x`` onto the expert dimension: ``x @ weight.T``."""
        gate_weight = self.weight
        return F.linear(x, gate_weight)
|
2022-03-22 02:50:20 +00:00
|
|
|
|
|
|
|
|
2022-01-07 07:08:36 +00:00
|
|
|
class MoeLayer(nn.Module):
    """A MoE layer, that puts its input tensor to its gate and uses the output logits
    to route all tokens, is mainly used to exchange all tokens for every expert across
    the moe tensor group by all to all communication. Then it will get the output of all
    experts and exchange the output. At last returns the output of the moe system.

    Args:
        dim_model (int): Dimension of model.
        num_experts (int): The number of experts.
        router (:class:`torch.nn.Module`): Instance of router used in routing.
        experts (:class:`torch.nn.Module`): Instance of experts generated by Expert.
    """

    @no_shard_zero_decrator(is_replicated=True)
    def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: MoeExperts):
        super().__init__()
        self.d_model = dim_model
        self.num_experts = num_experts
        self.gate = FP32LinearGate(dim_model, num_experts)
        self.router = router
        self.experts = experts
        # The fused dispatch/combine kernels are used only when both the kernel flag
        # and the MoE context option are truthy.
        self.use_kernel = bool(COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim)
        self.ep_group = experts.dist_info.ep_group
        self.ep_size = experts.dist_info.ep_size
        self.num_local_experts = experts.num_local_experts

    def a2a_process(self, dispatch_data: torch.Tensor):
        """Run the experts with all-to-all exchange before and after.

        Args:
            dispatch_data (torch.Tensor): Tokens grouped per expert.

        Returns:
            torch.Tensor: Expert outputs, in the shape produced by the first exchange.
        """
        expert_input = AllToAll.apply(dispatch_data, self.ep_group)

        input_shape = expert_input.shape

        expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.d_model)

        expert_output = self.experts(expert_input)
        # Restore the layout the second exchange expects.
        expert_output = expert_output.reshape(input_shape)

        expert_output = AllToAll.apply(expert_output, self.ep_group)
        return expert_output

    def tp_process(self, dispatch_data: torch.Tensor):
        """Run the experts with an all-gather before and a reduce-scatter after."""
        expert_in = AllGather.apply(dispatch_data, self.ep_group)
        expert_out = self.experts(expert_in)
        expert_out = ReduceScatter.apply(expert_out, self.ep_group)
        return expert_out

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Gate, dispatch, run the experts and combine.

        Args:
            inputs (torch.Tensor): Tensor whose elements can be regrouped into rows of
                ``dim_model`` features.

        Returns:
            torch.Tensor: Tensor with the same shape as ``inputs``.

        Raises:
            NotImplementedError: If ``experts.comm_name`` is neither ``"all_to_all"``
                nor ``"all_gather"``.
        """
        tokens = inputs.reshape(-1, self.d_model)
        # The gate always runs on fp32 tokens; `.to` returns the tensor itself when it
        # already is fp32, so no copy is made in that case.
        fp32_input = tokens.to(torch.float32)
        gate_output = self.gate(fp32_input)
        router_res = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)

        if self.use_kernel:
            dispatch_data = MoeDispatch.apply(tokens, *router_res[1:])
            dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model)
        else:
            sec_mask_f = router_res[1].type_as(inputs)
            # [e, c, s] x [s, h] -> [e, c, h]
            dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)

        # dispatch_data [e, c, h]
        if self.experts.comm_name == "all_to_all":
            expert_output = self.a2a_process(dispatch_data)
        elif self.experts.comm_name == "all_gather":
            expert_output = self.tp_process(dispatch_data)
        else:
            raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts "
                                      "build function.")
        # expert_output [e, c, h]

        if self.use_kernel:
            expert_output = expert_output.reshape(-1, self.d_model)
            ans = MoeCombine.apply(expert_output, *router_res)
        else:
            combine_weights = router_res[0].type_as(inputs)
            # reshape (rather than view) also handles non-contiguous tensors, matching
            # the kernel branch above.
            combine_weights = combine_weights.reshape(combine_weights.shape[0], -1)
            expert_output = expert_output.reshape(-1, expert_output.shape[-1])
            # [s, e*c] x [e*c, h] -> [s, h]
            ans = torch.matmul(combine_weights, expert_output)

        ans = ans.reshape(inputs.shape)
        return ans
|
2022-03-22 08:48:22 +00:00
|
|
|
|
|
|
|
|
|
|
|
class MoeModule(nn.Module):
    """User-facing wrapper that assembles a router, the experts and a MoE layer.

    Args:
        dim_model (int): Hidden dimension of training model
        num_experts (int): The number experts
        top_k (int, optional): The number of experts each token is dispatched to (1 or 2)
        capacity_factor_train (float, optional): Capacity factor in routing during training
        capacity_factor_eval (float, optional): Capacity factor in routing during evaluation
        min_capacity (int, optional): The minimum number of the capacity of each expert
        noisy_policy (str, optional): The policy of noisy function. Now we have 'Jitter' and 'Gaussian'.
            'Jitter' can be found in Switch Transformer paper (https://arxiv.org/abs/2101.03961).
            'Gaussian' can be found in ViT-MoE paper (https://arxiv.org/abs/2106.05974).
        drop_tks (bool, optional): Whether drops tokens in evaluation
        use_residual (bool, optional): Makes this MoE layer a Residual MoE.
            More information can be found in Microsoft paper (https://arxiv.org/abs/2201.05596).
        residual_instance (nn.Module, optional): The instance of residual module in Residual MoE
        expert_instance (MoeExperts, optional): The instance of experts module in MoeLayer
        expert_cls (Type[nn.Module], optional): The class of each expert when no instance is given
        expert_args (optional): The args of expert when no instance is given
    """

    def __init__(self,
                 dim_model: int,
                 num_experts: int,
                 top_k: int = 1,
                 capacity_factor_train: float = 1.25,
                 capacity_factor_eval: float = 2.0,
                 min_capacity: int = 4,
                 noisy_policy: Optional[str] = None,
                 drop_tks: bool = True,
                 use_residual: bool = False,
                 residual_instance: Optional[nn.Module] = None,
                 expert_instance: Optional[MoeExperts] = None,
                 expert_cls: Optional[Type[nn.Module]] = None,
                 **expert_args):
        super().__init__()

        # Pick the optional noise generator handed to the router.
        if noisy_policy is None:
            noise_fn = None
        elif noisy_policy == 'Jitter':
            noise_fn = UniformNoiseGenerator()
        elif noisy_policy == 'Gaussian':
            noise_fn = NormalNoiseGenerator(num_experts)
        else:
            raise NotImplementedError("Unsupported input noisy policy")

        # Only top-1 and top-2 routing are available.
        if top_k not in (1, 2):
            raise NotImplementedError("top_k > 2 is not supported yet")
        router_cls = Top1Router if top_k == 1 else Top2Router

        router_kwargs = dict(capacity_factor_train=capacity_factor_train,
                             capacity_factor_eval=capacity_factor_eval,
                             min_capacity=min_capacity,
                             noisy_func=noise_fn,
                             drop_tks=drop_tks)
        self.moe_router = router_cls(**router_kwargs)

        self.use_residual = use_residual
        if use_residual:
            # Fall back to a fresh expert as the residual branch when none is supplied.
            if residual_instance is None:
                assert expert_cls is not None, \
                    "Expert class can't be None when residual instance is not given"
                residual_instance = expert_cls(**expert_args)
            self.residual_module = residual_instance

            # Two-way mixing head: one coefficient for the MoE output, one for the residual.
            with no_shard_zero_context():
                self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device())

        if expert_instance is None:
            assert expert_cls is not None, \
                "Expert class can't be None when experts instance is not given"
            expert_instance = Experts(expert_cls, num_experts, **expert_args)
        self.experts = expert_instance

        self.moe_layer = MoeLayer(dim_model=dim_model,
                                  num_experts=num_experts,
                                  router=self.moe_router,
                                  experts=self.experts)

    def forward(self, inputs: torch.Tensor):
        """Apply the MoE layer and, for a Residual MoE, blend in the residual branch.

        The blend uses per-token softmax coefficients produced by ``residual_combine``.
        """
        moe_output = self.moe_layer(inputs)

        if not self.use_residual:
            return moe_output

        residual_output = self.residual_module(inputs)
        mix = F.softmax(self.residual_combine(inputs), dim=-1)
        moe_coef = mix[..., 0:1]
        residual_coef = mix[..., 1:]
        return moe_output * moe_coef + residual_output * residual_coef
|