@@ -11,6 +11,7 @@ from colossalai.utils import get_current_device
 from ._operation import U_CUDA_MODE, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
 from .experts import MoeExperts
 from .utils import autocast_softmax
+from typing import Callable


 class Top1Router(nn.Module):
@@ -18,21 +19,35 @@ class Top1Router(nn.Module):
     for routing usage. More detailed information can be found in the Switch Transformer
     paper from Google.

-    :param capacity_factor: Capacity factor in routing
+    :param capacity_factor_train: Capacity factor in routing during training
+    :param capacity_factor_eval: Capacity factor in routing during evaluation
     :param min_capacity: The minimum number of the capacity of each expert
     :param select_policy: The policy about tokens selection
     :param noisy_func: Noisy function used in logits
+    :param drop_tks: Whether to drop tokens in evaluation

-    :type capacity_factor: float
-    :type min_capacity: int
+    :type capacity_factor_train: float, optional
+    :type capacity_factor_eval: float, optional
+    :type min_capacity: int, optional
     :type select_policy: str, optional
     :type noisy_func: Callable, optional
+    :type drop_tks: bool, optional
     """

-    def __init__(self, capacity_factor: float, min_capacity: int = 0, select_policy: str = "first", noisy_func=None):
+    def __init__(self,
+                 capacity_factor_train: float = 1.25,
+                 capacity_factor_eval: float = 2.0,
+                 min_capacity: int = 4,
+                 select_policy: str = "first",
+                 noisy_func: Callable = None,
+                 drop_tks: bool = True):
         super().__init__()
-        self.capacity_factor = capacity_factor
+        self.capacity_factor_train = capacity_factor_train
+        self.capacity_factor_eval = capacity_factor_eval
         self.min_capacity = min_capacity
         self.select_policy = select_policy
         self.noisy_func = noisy_func
+        self.drop_tks = drop_tks

         assert select_policy in {"first", "random"}
         if select_policy == "random":
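As a quick orientation for the new constructor signature, here is a minimal usage sketch; the jitter function and the import path are illustrative assumptions, not part of this patch:

    import torch
    from colossalai.nn.layer.moe import Top1Router   # import path assumed

    def jitter(logits: torch.Tensor) -> torch.Tensor:
        # hypothetical noisy_func: multiplicative jitter on the routing logits
        return logits * torch.empty_like(logits).uniform_(0.99, 1.01)

    router = Top1Router(capacity_factor_train=1.25,
                        capacity_factor_eval=2.0,
                        min_capacity=4,
                        select_policy="first",
                        noisy_func=jitter,
                        drop_tks=True)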
@@ -44,7 +59,8 @@ class Top1Router(nn.Module):
         self,
         logits_shape,
     ):
-        capacity = math.floor(self.capacity_factor * logits_shape[-2] / logits_shape[-1])
+        capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
+        capacity = math.floor(capacity_factor * logits_shape[-2] / logits_shape[-1])
         capacity += capacity % 2
         capacity = max(capacity, self.min_capacity)
         assert capacity > 0
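Concretely, the capacity computation above works out as follows for a routing batch of 1024 tokens over 16 experts (numbers chosen only for illustration):

    import math

    capacity_factor, tokens, experts = 1.25, 1024, 16            # training-mode factor
    capacity = math.floor(capacity_factor * tokens / experts)    # floor(80.0) = 80
    capacity += capacity % 2     # round odd capacities up to even; 80 is unchanged
    capacity = max(capacity, 4)  # never drop below min_capacity
    assert capacity == 80        # each expert accepts at most 80 tokens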
@@ -53,15 +69,13 @@ class Top1Router(nn.Module):
     def forward(self, inputs: torch.Tensor, cuda_mode: bool = False):

         if self.noisy_func is not None and self.training:
-            inputs_noisy = self.noisy_func(inputs)
-        else:
-            inputs_noisy = inputs
+            inputs = self.noisy_func(inputs)

         logits = autocast_softmax(inputs, dim=-1)
         num_experts = logits.size(-1)
         capacity = self.get_capacity(logits.shape)

-        top1_idx = torch.argmax(inputs_noisy, dim=-1)
+        top1_idx = torch.argmax(inputs, dim=-1)
         mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)

         if self.training:
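The argmax/one_hot pair above converts per-token routing scores into a hard [s, e] dispatch mask; a tiny self-contained illustration with made-up values:

    import torch
    import torch.nn.functional as F

    scores = torch.tensor([[0.1, 0.7, 0.2],    # token 0 prefers expert 1
                           [0.5, 0.3, 0.2]])   # token 1 prefers expert 0
    top1_idx = torch.argmax(scores, dim=-1)                    # tensor([1, 0])
    mask = F.one_hot(top1_idx, num_classes=3).to(torch.int32)
    # mask == [[0, 1, 0],
    #          [1, 0, 0]]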
@@ -69,14 +83,14 @@ class Top1Router(nn.Module):
             me = torch.mean(logits, dim=0)
             ce = torch.mean(mask.float(), dim=0)
             l_aux = num_experts * torch.sum(me * ce)
             moe_env.add_loss(l_aux)
-        else:
+        elif not self.drop_tks:
             max_num = torch.max(torch.sum(mask, dim=0))
             dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MOE_MODEL))
             capacity = max_num.item()
+        else:
+            pass

-        if not self.training:
-            ranks = moe_cumsum(mask)
-        elif self.select_policy == "random":
+        if self.select_policy == "random":
             rand_mask = mask * self.uniform(mask.shape)
             _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
             mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
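The training branch accumulates the Switch-style load-balancing loss: me is the mean router probability per expert and ce the fraction of tokens dispatched to each expert, so num_experts * sum(me * ce) equals 1 under perfectly uniform routing and grows as routing collapses onto a few experts. A standalone sketch of the same quantity, for reference:

    import torch

    def balance_loss(route_probs: torch.Tensor, dispatch_mask: torch.Tensor) -> torch.Tensor:
        # route_probs: softmax outputs [s, e]; dispatch_mask: one-hot top-1 choices [s, e]
        num_experts = route_probs.size(-1)
        me = torch.mean(route_probs, dim=0)            # mean probability per expert
        ce = torch.mean(dispatch_mask.float(), dim=0)  # token share per expert
        return num_experts * torch.sum(me * ce)        # == 1.0 when perfectly balanced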
@@ -106,21 +120,40 @@ class Top2Router(nn.Module):
     """Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
     for routing usage. More detailed information can be found in the ViT-MoE paper.

-    :param capacity_factor: Capacity factor in routing
+    :param capacity_factor_train: Capacity factor in routing during training
+    :param capacity_factor_eval: Capacity factor in routing during evaluation
+    :param min_capacity: The minimum number of the capacity of each expert
     :param noisy_func: Noisy function used in logits
+    :param drop_tks: Whether to drop tokens in evaluation

-    :type capacity_factor: float
+    :type capacity_factor_train: float, optional
+    :type capacity_factor_eval: float, optional
+    :type min_capacity: int, optional
     :type noisy_func: Callable, optional
+    :type drop_tks: bool, optional
     """

-    def __init__(self, capacity_factor: float, noisy_func=None):
+    def __init__(self,
+                 capacity_factor_train: float = 1.25,
+                 capacity_factor_eval: float = 2.0,
+                 min_capacity: int = 4,
+                 noisy_func: Callable = None,
+                 drop_tks: bool = True):
         super().__init__()
-        self.capacity_factor = capacity_factor
+        self.capacity_factor_train = capacity_factor_train
+        self.capacity_factor_eval = capacity_factor_eval
+        self.min_capacity = min_capacity
         self.noisy_func = noisy_func
+        self.drop_tks = drop_tks

-    def get_capacity(self, logits_shape):
-        capacity = math.floor(2 * self.capacity_factor * logits_shape[-2] / logits_shape[-1])
+    def get_capacity(
+        self,
+        logits_shape,
+    ):
+        capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
+        capacity = math.floor(capacity_factor * logits_shape[-2] / logits_shape[-1])
         capacity += capacity % 2
         capacity = max(capacity, self.min_capacity)
         assert capacity > 0
         return capacity
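Since get_capacity now keys off self.training, switching the module between train and eval mode changes which factor applies; a quick behavioral check under the default constructor arguments above (assuming Top2Router is imported as in the earlier sketch):

    router = Top2Router()        # capacity_factor_train=1.25, capacity_factor_eval=2.0
    logits_shape = (1024, 16)    # [tokens, experts], illustrative values

    router.train()
    assert router.get_capacity(logits_shape) == 80     # floor(1.25 * 1024 / 16)
    router.eval()
    assert router.get_capacity(logits_shape) == 128    # floor(2.0 * 1024 / 16)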
@@ -143,12 +176,14 @@ class Top2Router(nn.Module):
         if self.training:
             me = torch.mean(logits, dim=0)
             ce = torch.mean(cmask.float(), dim=0)
-            l_aux = num_experts * torch.sum(me * ce) / 2.0
+            l_aux = num_experts * torch.sum(me * ce) / 2.0    # div 2 to normalize it to 1
             moe_env.add_loss(l_aux)
-        else:
+        elif not self.drop_tks:
             max_num = torch.max(torch.sum(cmask, dim=0))
             dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MOE_MODEL))
             capacity = max_num.item()
+        else:
+            pass

         rank1 = moe_cumsum(mask1)    # rank1: [s, e]
         rank2 = moe_cumsum(mask2)
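The new comment on the division by 2 can be spelled out: with top-2 routing each token contributes two entries to cmask, so under perfectly uniform routing the unnormalized loss approaches 2 rather than 1:

    # uniform top-2 routing over e experts:
    #   me_i = 1/e                     (mean softmax probability per expert)
    #   ce_i = 2/e                     (cmask counts both expert choices per token)
    #   e * sum_i(me_i * ce_i) = e * (2/e) = 2
    # dividing by 2 restores a balanced-routing value of 1, matching Top1Router's scale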