ColossalAI/colossalai/moe/routers.py

442 lines
19 KiB
Python

import math
from abc import ABC
from typing import Callable, Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import ProcessGroup
from colossalai.accelerator import get_accelerator
from colossalai.moe._operation import moe_cumsum
from colossalai.moe.manager import MOE_MANAGER
class MoeRouter(nn.Module, ABC):
    """Base class for all MoE routers.

    Stores the routing hyper-parameters and provides helpers to compute the
    per-expert capacity and to record / publish the router losses.

    Args:
        k_value (int): The value of top_k.
        capacity_factor_train (float): Capacity factor in routing of training.
        capacity_factor_eval (float): Capacity factor in routing of evaluation.
        min_capacity (int): The minimum number of the capacity of each expert.
        noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
        drop_tks (bool, optional): Whether drops tokens in evaluation
        use_kernel (bool, optional): Flag stored on the instance as ``self.use_kernel``.
    """

    def __init__(
        self,
        k_value: int,
        capacity_factor_train: float,
        capacity_factor_eval: float,
        min_capacity: int,
        noisy_func: Optional[Callable] = None,
        drop_tks: bool = True,
        use_kernel: bool = False,
    ):
        super().__init__()
        self.k_value = k_value
        self.capacity_factor_train = capacity_factor_train
        self.capacity_factor_eval = capacity_factor_eval
        self.min_capacity = min_capacity
        self.noisy_func = noisy_func
        self.drop_tks = drop_tks
        # Losses are staged here by set_aux_loss / set_z_loss and cleared
        # again by pop_router_loss.
        self._aux_loss = None
        self._z_loss = None
        self.use_kernel = use_kernel

    def get_capacity(self, logits_shape):
        """Return the number of slots each expert gets for the given logits shape.

        The capacity is ``floor(k_value * capacity_factor * logits_shape[-2] /
        logits_shape[-1])``, rounded up to the next even number and then raised
        to at least ``min_capacity``. The training or evaluation capacity factor
        is picked according to ``self.training``.

        Args:
            logits_shape: Shape of the router logits; only the last two entries
                are read.

        Returns:
            int: The (strictly positive) capacity.
        """
        capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
        capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1])
        # Round odd capacities up to the next even number.
        capacity += capacity % 2
        capacity = max(capacity, self.min_capacity)
        assert capacity > 0
        return int(capacity)

    def set_aux_loss(self, router_probs: torch.Tensor, expert_indices: torch.Tensor, num_experts: int) -> None:
        """Computes auxiliary load balancing loss as in Switch Transformer.

        See Switch Transformer (https://arxiv.org/abs/2101.03961). This function
        implements the loss function presented in equations (4) - (6). It aims to
        penalize those cases where the routing between experts is unbalanced.

        The result is stored in ``self._aux_loss``; a previously stored loss must
        have been consumed by :meth:`pop_router_loss` first.

        Args:
            router_probs: Probability assigned to each expert per token. Shape:
                <float32>[num_groups, tokens_per_group, num_experts].
            expert_indices: <int>[num_groups, tokens_per_group, num_selected_experts]
                indices identifying the top num_selected_experts for a given token.
            num_experts: Total number of experts (size of the one-hot encoding).

        Both tensors may also be passed without the leading group dimension, in
        which case a single group is assumed.
        """
        assert self._aux_loss is None
        if router_probs.dim() == expert_indices.dim() == 2:
            router_probs = router_probs.unsqueeze(0)
            expert_indices = expert_indices.unsqueeze(0)
        assert (
            router_probs.dim() == expert_indices.dim() == 3
        ), "router_probs and expert_indices must both be 3D tensors"

        # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
        expert_mask = F.one_hot(expert_indices, num_experts)
        # For a given token, determine if it was routed to a given expert.
        # Shape: [num_groups, tokens_per_group, num_experts]
        expert_mask = expert_mask.max(dim=-2)[0]

        tokens_per_group_and_expert = torch.mean(expert_mask.float(), dim=-2)
        router_prob_per_group_and_expert = torch.mean(router_probs.float(), dim=-2)
        aux_loss = num_experts**2 * torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert)
        self._aux_loss = aux_loss

    def set_z_loss(self, router_logits: torch.Tensor):
        """Compute router z-loss.

        The router z-loss was introduced in Designing Effective Sparse Expert Models
        (https://arxiv.org/abs/2202.08906). It encourages router logits to remain
        small in an effort to improve stability.

        The result (mean of the squared log-sum-exp over all tokens) is stored in
        ``self._z_loss``.

        Args:
            router_logits: <float>[num_groups, tokens_per_group, num_experts] router logits.
                A 2D tensor is treated as a single group.
        """
        assert self._z_loss is None
        if router_logits.dim() == 2:
            router_logits = router_logits.unsqueeze(0)
        assert router_logits.dim() == 3, "router_logits must be 3D tensor"
        num_groups, tokens_per_group, _ = router_logits.shape
        log_z = torch.logsumexp(router_logits, dim=-1)
        z_loss = torch.sum(log_z**2, dtype=torch.float32) / (num_groups * tokens_per_group)
        self._z_loss = z_loss

    def pop_router_loss(self) -> None:
        """Hand the staged losses to ``MOE_MANAGER`` and clear them.

        The auxiliary loss must have been set; the z-loss is passed along as is
        and may still be ``None``.
        """
        assert self._aux_loss is not None
        MOE_MANAGER.add_loss(self._aux_loss, self._z_loss)
        self._aux_loss = None
        self._z_loss = None
class Top1Router(MoeRouter):
    """Top1 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity)
    and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More detailed
    function can be found in the paper about Switch Transformer of Google.

    Args:
        capacity_factor_train (float, optional): Capacity factor in routing of training.
        capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
        min_capacity (int, optional): The minimum number of the capacity of each expert.
        select_policy (str, optional): The policy about tokens selection, either
            ``"first"`` (keep tokens in order until the expert is full) or
            ``"random"`` (keep a random subset of the tokens routed to an expert).
        noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
        drop_tks (bool, optional): Whether drops tokens in evaluation
    """

    def __init__(
        self,
        capacity_factor_train: float = 1.25,
        capacity_factor_eval: float = 2.0,
        min_capacity: int = 4,
        select_policy: str = "first",
        noisy_func: Optional[Callable] = None,
        drop_tks: bool = True,
    ):
        super().__init__(
            k_value=1,
            capacity_factor_train=capacity_factor_train,
            capacity_factor_eval=capacity_factor_eval,
            min_capacity=min_capacity,
            noisy_func=noisy_func,
            drop_tks=drop_tks,
        )
        self.select_policy = select_policy
        assert select_policy in {"first", "random"}
        if select_policy == "random":
            # Sampler for U(0, 1) noise created on the current accelerator device.
            self.uniform = torch.distributions.uniform.Uniform(
                low=torch.tensor(0.0, device=get_accelerator().get_current_device()),
                high=torch.tensor(1.0, device=get_accelerator().get_current_device()),
            ).rsample

    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
        """Route every token to its highest-scoring expert.

        Args:
            inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
                Must be of dtype ``torch.float``.
            use_kernel (bool): Selects which of the two return layouts is produced.
            ep_group (ProcessGroup, optional): When given in evaluation mode with
                ``drop_tks`` disabled, the capacity is replaced by the maximum
                per-expert token count across this group.

        Returns:
            1. use_kernel is False: ``(used_capacity, combine_weights, sec_mask)``
                used_capacity: tokens kept per expert, shape (num_experts,).
                The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity).
                The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity).
            2. use_kernel is True: ``(used_capacity, probs, mask, dest_idx, num_experts * capacity)``
                probs: softmax of the inputs, shape (batch_size * seq_len, num_experts).
                mask: int32 keep-flag per token, shape (1, batch_size * seq_len).
                dest_idx: int32 ``expert * capacity + rank`` per token, same shape as mask.
        """
        if self.noisy_func is not None and self.training:
            inputs = self.noisy_func(inputs)

        assert inputs.dtype == torch.float
        probs = F.softmax(inputs, dim=-1)
        num_experts = probs.size(-1)
        capacity = self.get_capacity(inputs.shape)

        top1_idx = torch.argmax(inputs, dim=-1)
        mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)

        # calculate router loss
        self.set_aux_loss(probs, top1_idx.unsqueeze(-1), num_experts)
        self.set_z_loss(inputs)
        self.pop_router_loss()

        if not self.training and not self.drop_tks and ep_group is not None:
            # No-drop evaluation: size the buffers for the busiest expert in the group.
            max_num = torch.max(torch.sum(mask, dim=0))
            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
            capacity = max_num.item()

        # NOTE(review): the cumsum below is driven by ``self.use_kernel`` while the
        # return layout is driven by the ``use_kernel`` argument - confirm this
        # split is intended.
        if self.select_policy == "random":
            rand_mask = mask * self.uniform(mask.shape)
            # topk rejects k larger than the token dimension, which happens when
            # fewer tokens than ``capacity`` are routed in one call; clamp k.
            _, dispatch_idx = torch.topk(rand_mask, k=min(capacity, mask.size(0)), dim=0)
            mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
            ranks = moe_cumsum(mask, use_kernel=self.use_kernel)
        elif self.select_policy == "first":
            ranks = moe_cumsum(mask, use_kernel=self.use_kernel)
            # Drop tokens whose position in their expert's queue exceeds capacity.
            mask = mask * torch.lt(ranks, capacity)
        else:
            raise NotImplementedError("Not support such select policy yet.")

        # Collapse to one rank per token (0 for dropped tokens).
        ranks = torch.sum(mask * ranks, dim=-1)
        used_capacity = mask.sum(dim=0)

        if use_kernel:
            mask = torch.sum(mask, dim=-1)
            mask = torch.stack([mask], dim=0).to(torch.int32)
            dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
            return used_capacity, probs, mask, dest_idx, num_experts * capacity
        else:
            ranks = F.one_hot(ranks, num_classes=capacity)
            weight = mask * probs.type_as(inputs)
            combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
            sec_mask = combine_weights.bool()
            return used_capacity, combine_weights, sec_mask
class Top2Router(MoeRouter):
    """Top2 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity)
    and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More detailed
    function can be found in the paper about ViT-MoE.

    Args:
        capacity_factor_train (float, optional): Capacity factor in routing of training.
        capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
        min_capacity (int, optional): The minimum number of the capacity of each expert
        noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
        drop_tks (bool, optional): Whether drops tokens in evaluation.
    """

    def __init__(
        self,
        capacity_factor_train: float = 1.25,
        capacity_factor_eval: float = 2.0,
        min_capacity: int = 4,
        noisy_func: Optional[Callable] = None,
        drop_tks: bool = True,
    ):
        super().__init__(
            k_value=2,
            capacity_factor_train=capacity_factor_train,
            capacity_factor_eval=capacity_factor_eval,
            min_capacity=min_capacity,
            noisy_func=noisy_func,
            drop_tks=drop_tks,
        )

    def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
        """Route every token to its two highest-probability experts.

        Args:
            inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
                Must be of dtype ``torch.float``.
            use_kernel (bool): Selects which of the two return layouts is produced.
            ep_group (ProcessGroup, optional): When given in evaluation mode with
                ``drop_tks`` disabled, the capacity is replaced by the maximum
                per-expert token count across this group.

        Returns:
            1. use_kernel is False: ``(used_capacity, cb_weight, sec_mask)``
                used_capacity: tokens kept per expert, shape (num_experts,).
                The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity).
                The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity).
            2. use_kernel is True: ``(used_capacity, probs, mask, dest_idx, num_experts * capacity)``
                probs: softmax of the inputs, shape (batch_size * seq_len, num_experts).
                mask: int32 keep-flags for first / second choice, shape (2, batch_size * seq_len).
                dest_idx: int32 ``expert * capacity + rank`` per choice, same shape as mask.
        """
        if self.noisy_func is not None and self.training:
            inputs = self.noisy_func(inputs)

        assert inputs.dtype == torch.float
        probs = F.softmax(inputs, dim=-1)
        num_experts = probs.size(-1)
        capacity = self.get_capacity(inputs.shape)

        top1_idx = torch.argmax(probs, dim=-1)
        mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
        # Hide the first choice so argmax returns the runner-up.
        logits_except1 = probs.masked_fill(mask1.bool(), float("-inf"))
        top2_idx = torch.argmax(logits_except1, dim=-1)
        mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)

        # calculate loss
        expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
        self.set_aux_loss(probs, expert_indices, num_experts)
        self.set_z_loss(inputs)
        self.pop_router_loss()

        if not self.training and not self.drop_tks and ep_group is not None:
            # No-drop evaluation: every expert must be able to hold all of its
            # first AND second choices, so count them with the integer masks.
            # (A halved float mask would yield a fractional capacity that is
            # both too small and unusable as a tensor dimension.)
            max_num = torch.max(torch.sum(mask1 + mask2, dim=0))
            dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
            capacity = int(max_num.item())

        # NOTE(review): the cumsum below is driven by ``self.use_kernel`` while the
        # return layout is driven by the ``use_kernel`` argument - confirm this
        # split is intended.
        rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel)  # rank1: [s, e]
        rank2 = moe_cumsum(mask2, use_kernel=self.use_kernel)
        # Second choices queue up behind all first choices of the same expert.
        rank2 += torch.sum(mask1, dim=-2, keepdim=True)

        # Drop assignments whose queue position does not fit into the capacity.
        mask1 *= torch.lt(rank1, capacity)
        mask2 *= torch.lt(rank2, capacity)
        used_capacity = mask1.sum(dim=0) + mask2.sum(dim=0)

        # Collapse to one rank per token and choice (0 for dropped assignments).
        rank1 = torch.sum(mask1 * rank1, dim=-1)
        rank2 = torch.sum(mask2 * rank2, dim=-1)

        if use_kernel:
            mask1 = torch.sum(mask1, dim=-1)
            mask2 = torch.sum(mask2, dim=-1)

            mask = torch.stack([mask1, mask2], dim=0).to(torch.int32)
            dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32)

            return used_capacity, probs, mask, dest_idx, num_experts * capacity
        else:
            # The following code is equivalent to:
            #
            #     weight1 = mask1 * probs.type_as(inputs)
            #     weight2 = mask2 * probs.type_as(inputs)
            #     rank1_sc = F.one_hot(rank1, num_classes=capacity)
            #     rank2_sc = F.one_hot(rank2, num_classes=capacity)
            #     cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1)
            #     cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1)
            #     cb_weight = cb_weight1 + cb_weight2
            #     sec_mask = cb_weight.bool()
            #
            # but writes the (token, expert, slot) entries directly.
            weight1 = mask1 * probs.type_as(inputs)
            weight2 = mask2 * probs.type_as(inputs)

            cb_weight = torch.zeros(inputs.shape + (capacity,), device=inputs.device)
            sec_mask = torch.zeros_like(cb_weight, dtype=torch.bool)
            indices = torch.arange(0, inputs.shape[0], device=inputs.device)
            cb_weight[indices, top1_idx[indices], rank1[indices]] += weight1[indices, top1_idx[indices]]
            cb_weight[indices, top2_idx[indices], rank2[indices]] += weight2[indices, top2_idx[indices]]
            sec_mask[indices, top1_idx[indices], rank1[indices]] |= mask1.bool()[indices, top1_idx[indices]]
            sec_mask[indices, top2_idx[indices], rank2[indices]] |= mask2.bool()[indices, top2_idx[indices]]

            return used_capacity, cb_weight, sec_mask
class TopKRouter(MoeRouter):
    """Masked matmul router using tokens choose top-k experts assignment.

    NOTE: this is modified from flaxformer.

    Follows the "tokens choose" scheme of Switch Transformer
    (https://arxiv.org/abs/2101.03961) and V-MoE
    (https://arxiv.org/abs/2106.05974): each token picks its best experts and
    is placed into an expert's buffer until that buffer reaches
    ``expert_capacity``. A token may end up unprocessed, and an expert may end
    up without any token.

    Attributes:
        num_selected_experts: Maximum number of experts to which each token is
            routed. Tokens may be routed to fewer experts if particular experts are
            oversubscribed / reach capacity.
    """

    def __init__(
        self,
        num_selected_experts: int,
        capacity_factor_train: float = 1.25,
        capacity_factor_eval: float = 2.0,
        min_capacity: int = 4,
        noisy_func: Optional[Callable] = None,
        drop_tks: bool = True,
    ):
        super().__init__(
            num_selected_experts,
            capacity_factor_train,
            capacity_factor_eval,
            min_capacity,
            noisy_func,
            drop_tks,
        )

    def forward(
        self,
        router_probs: torch.Tensor,
        expert_capacity: int,
    ) -> Tuple:
        """Build dispatch and combine arrays for top-k routing.

        Args:
            router_probs: <float32>[num_groups, tokens_per_group, num_experts]
                probabilities used to determine the routing of tokens to the experts.
            expert_capacity: Number of buffer slots available per expert.

        Returns:
            ``(combine_array, dispatch_mask)``, both of shape
            [num_groups, tokens_per_group, num_experts, expert_capacity].
        """
        # TODO: FIXME: add parallel group
        num_groups, _, num_experts = router_probs.shape
        top_k = self.k_value

        # Indices of the k most probable experts per token.
        # Shape: [num_groups, tokens_per_group, top_k].
        _, chosen = router_probs.topk(top_k)

        self.set_aux_loss(router_probs, chosen, num_experts)
        self.pop_router_loss()

        # Put the choice rank first and flatten, so that every first choice is
        # served before any second choice, and so on.
        # Shape: [num_groups, top_k * tokens_per_group].
        flat_choice = chosen.transpose(1, 2).reshape(num_groups, -1)

        # Shape: [num_groups, top_k * tokens_per_group, num_experts].
        choice_one_hot = F.one_hot(flat_choice, num_experts).to(torch.int32)

        # Zero-based position of each assignment inside its expert's buffer;
        # -1 wherever the expert was not chosen.
        slot = choice_one_hot.cumsum(dim=1) * choice_one_hot - 1

        # Undo the flattening.
        # Shape: [num_groups, tokens_per_group, top_k, num_experts].
        slot = slot.reshape((num_groups, top_k, -1, num_experts)).transpose(1, 2)

        # At most one of a token's k choices targets a given expert, so the
        # maximum over the choice axis picks that position (or -1 if none).
        # Shape: [num_groups, tokens_per_group, num_experts].
        slot = slot.max(dim=2)[0]

        # An assignment survives only if its position is non-negative and fits
        # into the buffer; invalid positions are parked at 0 for the one-hot
        # encoding and blanked out again right after.
        in_range = (slot >= 0) & (slot < expert_capacity)
        slot = slot.masked_fill(~in_range, 0)

        # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity].
        dispatch_mask = F.one_hot(slot, expert_capacity).to(torch.bool)
        in_range = in_range.unsqueeze(-1).expand(-1, -1, -1, expert_capacity)
        dispatch_mask = dispatch_mask.masked_fill(~in_range, 0)

        # Scale the dispatch pattern by the router probabilities to obtain the
        # weights for combining expert outputs.
        combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask)

        return combine_array, dispatch_mask
def get_router_cls(top_k: int, grouped: bool = False) -> MoeRouter:
    """Pick the router class for the requested configuration.

    Args:
        top_k: Number of experts each token is routed to.
        grouped: When true, ``TopKRouter`` is returned regardless of ``top_k``.

    Returns:
        The router class (not an instance): ``TopKRouter`` if ``grouped``,
        otherwise ``Top1Router`` or ``Top2Router`` for ``top_k`` 1 or 2.

    Raises:
        NotImplementedError: If ``grouped`` is false and ``top_k`` is neither 1 nor 2.
    """
    if grouped:
        return TopKRouter
    if top_k == 1:
        return Top1Router
    if top_k == 2:
        return Top2Router
    raise NotImplementedError("top_k > 2 is not supported yet")