2022-01-07 07:08:36 +00:00
|
|
|
import math
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn.functional as F
|
2022-02-18 12:42:31 +00:00
|
|
|
import torch.distributed as dist
|
2022-03-19 07:36:25 +00:00
|
|
|
from colossalai.core import MOE_CONTEXT
|
2022-01-07 07:08:36 +00:00
|
|
|
from colossalai.utils import get_current_device
|
2022-03-19 07:36:25 +00:00
|
|
|
from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
|
2022-02-27 14:28:39 +00:00
|
|
|
from .experts import MoeExperts
|
2022-02-18 12:42:31 +00:00
|
|
|
from .utils import autocast_softmax
|
2022-03-19 07:36:25 +00:00
|
|
|
from typing import Callable, Optional
|
|
|
|
from torch.distributed import ProcessGroup
|
2022-01-07 07:08:36 +00:00
|
|
|
|
|
|
|
|
|
|
|
class Top1Router(nn.Module):
    """Top-1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
    for routing usage. More detailed function can be found in the paper about Switch Transformer
    of Google.

    Here ``s`` is the number of tokens, ``e`` the number of experts and ``c`` the per-expert capacity.

    :param capacity_factor_train: Capacity factor in routing during training
    :param capacity_factor_eval: Capacity factor in routing during evaluation
    :param min_capacity: The minimum number of the capacity of each expert
    :param select_policy: How tokens are kept when an expert is over capacity: ``"first"`` keeps
        tokens in order of their ``moe_cumsum`` position, ``"random"`` keeps a random subset
    :param noisy_func: Noisy function applied to the gate output before softmax, during training only
    :param drop_tks: Whether to drop tokens that exceed capacity during evaluation; if False, the
        capacity is enlarged to fit the most loaded expert across the expert-parallel group

    :type capacity_factor_train: float, optional
    :type capacity_factor_eval: float, optional
    :type min_capacity: int, optional
    :type select_policy: str, optional
    :type noisy_func: Callable, optional
    :type drop_tks: bool, optional
    """

    def __init__(self,
                 capacity_factor_train: float = 1.25,
                 capacity_factor_eval: float = 2.0,
                 min_capacity: int = 4,
                 select_policy: str = "first",
                 noisy_func: Callable = None,
                 drop_tks: bool = True):
        super().__init__()
        self.capacity_factor_train = capacity_factor_train
        self.capacity_factor_eval = capacity_factor_eval
        self.min_capacity = min_capacity
        self.select_policy = select_policy
        self.noisy_func = noisy_func
        self.drop_tks = drop_tks

        assert select_policy in {"first", "random"}
        if select_policy == "random":
            # Sampler of U[0, 1) values on the current device, used to pick a random token subset.
            self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=get_current_device()),
                                                               high=torch.tensor(1.0,
                                                                                 device=get_current_device())).rsample

    def get_capacity(
        self,
        logits_shape,
    ):
        """Compute the per-expert capacity for the current mode (train or eval).

        :param logits_shape: Shape of the routing logits; ``[-2]`` is the token count and
            ``[-1]`` the expert count
        :return: ``floor(factor * tokens / experts)``, rounded up to an even number and
            clamped to at least ``min_capacity``
        """
        capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
        capacity = math.floor(capacity_factor * logits_shape[-2] / logits_shape[-1])
        capacity += capacity % 2
        capacity = max(capacity, self.min_capacity)
        assert capacity > 0
        return capacity

    def _select_random(self, mask: torch.Tensor, capacity: int) -> torch.Tensor:
        """Keep a random subset of at most ``capacity`` routed tokens per expert.

        :param mask: One-hot routing mask of shape [s, e]
        :param capacity: Maximum number of tokens each expert may receive
        :return: ``mask`` with the non-selected tokens zeroed out (same dtype as ``mask``)
        """
        # topk cannot select more rows than exist. The capacity may exceed the local token
        # count: min_capacity can be larger than s, and in eval without dropping the capacity
        # is the MAX over the whole expert-parallel group.
        k = min(capacity, mask.size(0))
        # Routed tokens get a U[0, 1) score and unrouted ones get 0, so the top-k per expert
        # column is a random subset of the tokens routed to that expert.
        rand_mask = mask * self.uniform(mask.shape)
        _, dispatch_idx = torch.topk(rand_mask, k=k, dim=0)
        return mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)

    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
        """Route each token to its top-1 expert under a per-expert capacity.

        :param inputs: Gate output whose last dimension indexes experts (presumably [s, e] — TODO confirm)
        :param use_kernel: If True, return the compact format consumed by the fused dispatch/combine kernels
        :param ep_group: Process group used to agree on the capacity when tokens are not dropped in evaluation
        :return: ``(logits, mask, dest_idx, num_experts * capacity)`` if ``use_kernel``, where ``mask`` and
            ``dest_idx`` are int32 tensors of shape [1, s]; otherwise ``(combine_weights, sec_mask)``,
            both of shape [s, e, c]
        """
        if self.noisy_func is not None and self.training:
            inputs = self.noisy_func(inputs)

        logits = autocast_softmax(inputs, dim=-1)
        num_experts = logits.size(-1)
        capacity = self.get_capacity(logits.shape)

        # softmax is monotonic, so the argmax over the raw inputs picks the same expert as over logits
        top1_idx = torch.argmax(inputs, dim=-1)
        mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)

        if self.training:
            # Auxiliary load-balancing loss: mean router probability times fraction of tokens per expert.
            me = torch.mean(logits, dim=0)
            ce = torch.mean(mask.float(), dim=0)
            l_aux = num_experts * torch.sum(me * ce)
            MOE_CONTEXT.add_loss(l_aux)
        elif not self.drop_tks:
            # No dropping in eval: size the capacity to the most loaded expert across the group.
            max_num = torch.max(torch.sum(mask, dim=0))
            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
            capacity = max_num.item()

        if self.select_policy == "random":
            mask = self._select_random(mask, capacity)
            ranks = moe_cumsum(mask)
        elif self.select_policy == "first":
            ranks = moe_cumsum(mask)
            mask = mask * torch.lt(ranks, capacity)
        else:
            raise NotImplementedError("Not support such select policy yet.")

        # Position of each token inside its expert's buffer (0 for dropped tokens, whose mask is 0).
        ranks = torch.sum(mask * ranks, dim=-1)

        if use_kernel:
            mask = torch.sum(mask, dim=-1)
            mask = torch.stack([mask], dim=0).to(torch.int32)
            dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
            return logits, mask, dest_idx, num_experts * capacity
        else:
            ranks = F.one_hot(ranks, num_classes=capacity)
            weight = mask * logits.type_as(inputs)
            combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
            sec_mask = combine_weights.bool()
            return combine_weights, sec_mask
|
2022-01-07 07:08:36 +00:00
|
|
|
|
|
|
|
|
|
|
|
class Top2Router(nn.Module):
    """Top-2 router producing combine weights and a dispatch mask, both shaped [s, e, c].

    Every token is sent to its two highest-probability experts, following the gating described
    in the ViT-MoE paper. ``s`` is the number of tokens, ``e`` the number of experts and ``c``
    the per-expert capacity. Within an expert, all first-choice tokens are placed before the
    second-choice tokens.

    :param capacity_factor_train: Capacity factor in routing during training
    :param capacity_factor_eval: Capacity factor in routing during evaluation
    :param min_capacity: The minimum number of the capacity of each expert
    :param noisy_func: Noisy function applied to the gate output before softmax, during training only
    :param drop_tks: Whether to drop tokens that exceed capacity during evaluation; if False, the
        capacity is enlarged to fit the most loaded expert across the expert-parallel group

    :type capacity_factor_train: float, optional
    :type capacity_factor_eval: float, optional
    :type min_capacity: int, optional
    :type noisy_func: Callable, optional
    :type drop_tks: bool, optional
    """

    def __init__(self,
                 capacity_factor_train: float = 1.25,
                 capacity_factor_eval: float = 2.0,
                 min_capacity: int = 4,
                 noisy_func: Callable = None,
                 drop_tks: bool = True):
        super().__init__()
        self.capacity_factor_train = capacity_factor_train
        self.capacity_factor_eval = capacity_factor_eval
        self.min_capacity = min_capacity
        self.noisy_func = noisy_func
        self.drop_tks = drop_tks

    def get_capacity(
        self,
        logits_shape,
    ):
        """Return the per-expert capacity: an even number no smaller than ``min_capacity``.

        :param logits_shape: Shape whose last two entries are (token count, expert count)
        """
        num_tokens = logits_shape[-2]
        num_experts = logits_shape[-1]
        factor = self.capacity_factor_eval
        if self.training:
            factor = self.capacity_factor_train
        cap = math.floor(factor * num_tokens / num_experts)
        if cap % 2:
            cap += 1
        cap = max(cap, self.min_capacity)
        assert cap > 0
        return cap

    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
        """Assign each token to its best and second-best experts under a per-expert capacity.

        :param inputs: Gate output; its last dimension indexes experts
        :param use_kernel: If True, return the compact format consumed by the fused dispatch/combine kernels
        :param ep_group: Process group used to agree on the capacity when tokens are not dropped in evaluation
        :return: ``(logits, mask, dest_idx, num_experts * capacity)`` if ``use_kernel``, where ``mask`` and
            ``dest_idx`` are int32 tensors of shape [2, s]; otherwise ``(combine_weights, sec_mask)``,
            both of shape [s, e, c]
        """
        if self.training and self.noisy_func is not None:
            inputs = self.noisy_func(inputs)

        probs = autocast_softmax(inputs, dim=-1)    # [s, e]
        n_exp = probs.size(-1)
        cap = self.get_capacity(probs.shape)

        # Best expert, then the best among the remaining experts.
        first_idx = torch.argmax(probs, dim=-1)
        first_mask = F.one_hot(first_idx, num_classes=n_exp).to(torch.int32)
        second_idx = torch.argmax(probs.masked_fill(first_mask.bool(), float("-inf")), dim=-1)
        second_mask = F.one_hot(second_idx, num_classes=n_exp).to(torch.int32)

        load = first_mask + second_mask    # [s, e]
        if self.training:
            importance = torch.mean(probs, dim=0)
            density = torch.mean(load.float(), dim=0)
            # Dividing by 2 normalizes the loss, since every token counts twice in `load`.
            l_aux = n_exp * torch.sum(importance * density) / 2.0
            MOE_CONTEXT.add_loss(l_aux)
        elif not self.drop_tks:
            busiest = torch.max(torch.sum(load, dim=0))
            dist.all_reduce(busiest, op=dist.ReduceOp.MAX, group=ep_group)
            cap = busiest.item()

        pos1 = moe_cumsum(first_mask)    # [s, e]
        pos2 = moe_cumsum(second_mask)
        # Second-choice slots start after every first-choice token of the same expert.
        pos2 += torch.sum(first_mask, dim=-2, keepdim=True)

        first_mask = first_mask * torch.lt(pos1, cap)
        second_mask = second_mask * torch.lt(pos2, cap)

        pos1 = torch.sum(first_mask * pos1, dim=-1)
        pos2 = torch.sum(second_mask * pos2, dim=-1)

        if use_kernel:
            kept = torch.stack([torch.sum(first_mask, dim=-1), torch.sum(second_mask, dim=-1)], dim=0)
            kept = kept.to(torch.int32)
            dest_idx = torch.stack([first_idx * cap + pos1, second_idx * cap + pos2], dim=0).to(torch.int32)
            return probs, kept, dest_idx, n_exp * cap

        slot1 = F.one_hot(pos1, num_classes=cap)
        slot2 = F.one_hot(pos2, num_classes=cap)
        part1 = (first_mask * probs.type_as(inputs)).unsqueeze(2) * slot1.unsqueeze(1)
        part2 = (second_mask * probs.type_as(inputs)).unsqueeze(2) * slot2.unsqueeze(1)
        combined = part1 + part2
        return combined, combined.bool()
|
2022-01-07 07:08:36 +00:00
|
|
|
|
|
|
|
|
|
|
|
class MoeLayer(nn.Module):
    """A Mixture-of-Experts layer.

    The input is passed through a linear gate, and the resulting gate output is handed to the
    router, which decides how tokens are routed to experts. Tokens are then exchanged across
    the expert-parallel group (all-to-all, or all-gather/reduce-scatter, depending on the
    experts' ``comm_name``). The experts are run, their outputs are exchanged back, and the
    outputs are merged with the router's combine weights.

    :param dim_model: Dimension of model
    :param num_experts: The number of experts
    :param router: Instance of router used in routing
    :param experts: Instance of experts generated by Expert

    :type dim_model: int
    :type num_experts: int
    :type router: nn.Module
    :type experts: nn.Module
    """

    def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: MoeExperts):
        super().__init__()
        self.d_model = dim_model
        self.num_experts = num_experts
        # Submodules are registered in this order: gate, router, experts.
        self.gate = nn.Linear(dim_model, num_experts, bias=False, device=get_current_device())
        self.router = router
        self.experts = experts
        self.use_kernel = bool(COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim)
        self.ep_group = experts.dist_info.ep_group
        self.ep_size = experts.dist_info.ep_size
        self.num_local_experts = experts.num_local_experts

    def a2a_process(self, dispatch_data: torch.Tensor):
        """Exchange tokens with all-to-all, run the local experts, and send the outputs back."""
        received = AllToAll.apply(dispatch_data, self.ep_group)
        original_shape = received.shape
        grouped = received.reshape(self.ep_size, self.num_local_experts, -1, self.d_model)
        processed = self.experts(grouped).reshape(original_shape)
        return AllToAll.apply(processed, self.ep_group)

    def tp_process(self, dispatch_data: torch.Tensor):
        """Gather tokens from the group, run the experts, and reduce-scatter the outputs."""
        gathered = AllGather.apply(dispatch_data, self.ep_group)
        return ReduceScatter.apply(self.experts(gathered), self.ep_group)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Route the tokens in ``inputs`` through the experts; the output has the shape of ``inputs``."""
        tokens = inputs.reshape(-1, self.d_model)
        router_res = self.router(inputs=self.gate(tokens), use_kernel=self.use_kernel, ep_group=self.ep_group)

        # Build dispatch_data of shape [e, c, h].
        if self.use_kernel:
            dispatch_data = MoeDispatch.apply(tokens, *router_res[1:]).reshape(self.num_experts, -1, self.d_model)
        else:
            dispatch_mask = router_res[1].type_as(inputs)    # [s, e, c]
            dispatch_data = torch.matmul(dispatch_mask.permute(1, 2, 0), tokens)

        if self.experts.comm_name == "all_to_all":
            expert_output = self.a2a_process(dispatch_data)
        elif self.experts.comm_name == "all_gather":
            expert_output = self.tp_process(dispatch_data)
        else:
            raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts "
                                      "build function.")

        # expert_output has shape [e, c, h]; merge it back to one row per token.
        if self.use_kernel:
            ans = MoeCombine.apply(expert_output.reshape(-1, self.d_model), *router_res)
        else:
            combine_weights = router_res[0]
            flat_weights = combine_weights.view(combine_weights.shape[0], -1)
            flat_output = expert_output.view(-1, expert_output.shape[-1])
            ans = torch.matmul(flat_weights, flat_output)

        return ans.reshape(inputs.shape)
|