import math from abc import ABC from typing import Callable, Optional, Tuple import torch import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F from torch.distributed import ProcessGroup from colossalai.accelerator import get_accelerator from colossalai.moe._operation import moe_cumsum from colossalai.moe.manager import MOE_MANAGER class MoeRouter(nn.Module, ABC): """Base class for all MoE routers. Args: k_value (int): The value of top_k. capacity_factor_train (float): Capacity factor in routing of training. capacity_factor_eval (float): Capacity factor in routing of evaluation. min_capacity (int): The minimum number of the capacity of each expert. noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. drop_tks (bool, optional): Whether drops tokens in evaluation """ def __init__( self, k_value: int, capacity_factor_train: float, capacity_factor_eval: float, min_capacity: int, noisy_func: Optional[Callable] = None, drop_tks: bool = True, use_kernel: bool = False, ): super().__init__() self.k_value = k_value self.capacity_factor_train = capacity_factor_train self.capacity_factor_eval = capacity_factor_eval self.min_capacity = min_capacity self.noisy_func = noisy_func self.drop_tks = drop_tks self._aux_loss = None self._z_loss = None self.use_kernel = use_kernel def get_capacity(self, num_tokens, num_experts, ep_group=None): if ep_group is not None: num_tokens_tensor = torch.tensor(num_tokens, device=get_accelerator().get_current_device()) dist.all_reduce(num_tokens_tensor, group=ep_group) num_tokens = num_tokens_tensor.item() // dist.get_world_size(ep_group) capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval capacity = math.floor(self.k_value * capacity_factor * num_tokens / num_experts) capacity += capacity % 2 capacity = max(capacity, self.min_capacity) assert capacity > 0 return int(capacity) def set_aux_loss(self, router_probs: torch.Tensor, expert_indices: torch.Tensor, num_experts: int) -> None: """Computes auxiliary load balancing loss as in Switch Transformer. See Switch Transformer (https://arxiv.org/abs/2101.03961). This function implements the loss function presented in equations (4) - (6). It aims to penalize those cases where the routing between experts is unbalanced. Args: router_probs: Probability assigned to each expert per token. Shape: [num_groups, tokens_per_group, num_experts]. expert_indices: [num_groups, tokens_per_group, num_selected_experts] indices identifying the top num_selected_experts for a given token. """ assert self._aux_loss is None if router_probs.dim() == expert_indices.dim() == 2: router_probs = router_probs.unsqueeze(0) expert_indices = expert_indices.unsqueeze(0) assert ( router_probs.dim() == expert_indices.dim() == 3 ), "router_probs must be 3D tensor and expert_indices must be 4D tensor" # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. expert_mask = F.one_hot(expert_indices, num_experts) # For a given token, determine if it was routed to a given expert. # Shape: [num_groups, tokens_per_group, num_experts] expert_mask = expert_mask.max(dim=-2)[0] tokens_per_group_and_expert = torch.mean(expert_mask.float(), dim=-2) router_prob_per_group_and_expert = torch.mean(router_probs.float(), dim=-2) aux_loss = num_experts**2 * torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) self._aux_loss = aux_loss def set_z_loss(self, router_logits: torch.Tensor): """Compute router z-loss. The router z-loss was introduced in Designing Effective Sparse Expert Models (https://arxiv.org/abs/2202.08906). It encourages router logits to remain small in an effort to improve stability. Args: router_logits: [num_groups, tokens_per_group, num_experts] router logits. """ assert self._z_loss is None if router_logits.dim() == 2: router_logits = router_logits.unsqueeze(0) assert router_logits.dim() == 3, "router_logits must be 3D tensor" num_groups, tokens_per_group, _ = router_logits.shape log_z = torch.logsumexp(router_logits, dim=-1) z_loss = torch.sum(log_z**2, dtype=torch.float32) / (num_groups * tokens_per_group) self._z_loss = z_loss def pop_router_loss(self) -> torch.Tensor: assert self._aux_loss is not None MOE_MANAGER.add_loss(self._aux_loss, self._z_loss) self._aux_loss = None self._z_loss = None class Top1Router(MoeRouter): """Top1 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity) and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More detailed function can be found in the paper about Switch Transformer of Google. Args: capacity_factor_train (float, optional): Capacity factor in routing of training. capacity_factor_eval (float, optional): Capacity factor in routing of evaluation. min_capacity (int, optional): The minimum number of the capacity of each expert. select_policy (str, optional): The policy about tokens selection. noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. drop_tks (bool, optional): Whether drops tokens in evaluation """ def __init__( self, capacity_factor_train: float = 1.25, capacity_factor_eval: float = 2.0, min_capacity: int = 4, select_policy: str = "first", noisy_func: Optional[Callable] = None, drop_tks: bool = True, ): super().__init__( k_value=1, capacity_factor_train=capacity_factor_train, capacity_factor_eval=capacity_factor_eval, min_capacity=min_capacity, noisy_func=noisy_func, drop_tks=drop_tks, ) self.select_policy = select_policy assert select_policy in {"first", "random"} if select_policy == "random": self.uniform = torch.distributions.uniform.Uniform( low=torch.tensor(0.0, device=get_accelerator().get_current_device()), high=torch.tensor(1.0, device=get_accelerator().get_current_device()), ).rsample def forward( self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None, use_loss: bool = False, use_norm: bool = False, ) -> Tuple: """ Args: inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts). Returns: 1. use_kernel is False: The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity). The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity). 2. use_kernel is True: ... """ if self.noisy_func is not None and self.training: inputs = self.noisy_func(inputs) assert inputs.dtype == torch.float probs = F.softmax(inputs, dim=-1) num_experts = probs.size(-1) num_tokens = inputs.size(0) capacity = self.get_capacity(num_tokens, num_experts, ep_group) top1_idx = torch.argmax(inputs, dim=-1) mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) # calculate router loss self.set_aux_loss(probs, top1_idx.unsqueeze(-1), num_experts) self.set_z_loss(inputs) self.pop_router_loss() if not self.training and not self.drop_tks and ep_group is not None: max_num = torch.max(torch.sum(mask, dim=0)) dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) capacity = max_num.item() if self.select_policy == "random": rand_mask = mask * self.uniform(mask.shape) _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0) mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1) ranks = moe_cumsum(mask, use_kernel=self.use_kernel) elif self.select_policy == "first": ranks = moe_cumsum(mask, use_kernel=self.use_kernel) mask = mask * torch.lt(ranks, capacity) else: raise NotImplementedError("Not support such select policy yet.") ranks = torch.sum(mask * ranks, dim=-1) used_capacity = mask.sum(dim=0) if use_kernel: mask = torch.sum(mask, dim=-1) mask = torch.stack([mask], dim=0).to(torch.int32) dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32) return used_capacity, probs, mask, dest_idx, num_experts * capacity else: ranks = F.one_hot(ranks, num_classes=capacity) weight = mask * probs.type_as(inputs) combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1) sec_mask = combine_weights.bool() return used_capacity, combine_weights, sec_mask, probs class Top2Router(MoeRouter): """Top2 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity) and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More detailed function can be found in the paper about ViT-MoE. Args: capacity_factor_train (float, optional): Capacity factor in routing of training. capacity_factor_eval (float, optional): Capacity factor in routing of evaluation. min_capacity (int, optional): The minimum number of the capacity of each expert noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. drop_tks (bool, optional): Whether drops tokens in evaluation. """ def __init__( self, capacity_factor_train: float = 1.25, capacity_factor_eval: float = 2.0, min_capacity: int = 4, noisy_func: Optional[Callable] = None, drop_tks: bool = True, ): super().__init__( k_value=2, capacity_factor_train=capacity_factor_train, capacity_factor_eval=capacity_factor_eval, min_capacity=min_capacity, noisy_func=noisy_func, drop_tks=drop_tks, ) def forward( self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None, use_norm: bool = False, use_loss: bool = True, ) -> Tuple: """ Args: inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts). Returns: 1. use_kernel is False: The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity). The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity). 2. use_kernel is True: ... """ if self.noisy_func is not None and self.training: inputs = self.noisy_func(inputs) assert inputs.dtype == torch.float probs = F.softmax(inputs, dim=-1) if use_norm: routing_weights, _ = torch.topk(probs, 2, dim=-1) probs = probs / routing_weights.sum(dim=-1, keepdim=True) num_experts = probs.size(-1) num_tokens = inputs.size(0) capacity = self.get_capacity(num_tokens, num_experts, ep_group) top1_idx = torch.argmax(probs, dim=-1) mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) logits_except1 = probs.masked_fill(mask1.bool(), float("-inf")) top2_idx = torch.argmax(logits_except1, dim=-1) mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32) cmask = mask1 + mask2 # loss: [s, e] cmask = cmask.float() / 2.0 # div 2 to normalize it to 1 # calculate loss if use_loss: expert_indices = torch.stack([top1_idx, top2_idx], dim=-1) self.set_aux_loss(probs, expert_indices, num_experts) self.set_z_loss(inputs) self.pop_router_loss() if not self.training and not self.drop_tks and ep_group is not None: max_num = torch.max(torch.sum(cmask, dim=0)) dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) capacity = max_num.item() rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel) # rank1: [s, e] rank2 = moe_cumsum(mask2, use_kernel=self.use_kernel) rank2 += torch.sum(mask1, dim=-2, keepdim=True) mask1 *= torch.lt(rank1, capacity) mask2 *= torch.lt(rank2, capacity) used_capacity = mask1.sum(dim=0) + mask2.sum(dim=0) rank1 = torch.sum(mask1 * rank1, dim=-1) rank2 = torch.sum(mask2 * rank2, dim=-1) if use_kernel: mask1 = torch.sum(mask1, dim=-1) mask2 = torch.sum(mask2, dim=-1) mask = torch.stack([mask1, mask2], dim=0).to(torch.int32) dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32) return used_capacity, probs, mask, dest_idx, num_experts * capacity else: """ The following code is equivalent to: ``` weight1 = mask1 * probs.type_as(inputs) weight2 = mask2 * probs.type_as(inputs) rank1_sc = F.one_hot(rank1, num_classes=capacity) rank2_sc = F.one_hot(rank2, num_classes=capacity) cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1) cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1) cb_weight = cb_weight1 + cb_weight2 sec_mask = cb_weight.bool() ``` """ weight1 = mask1 * probs.type_as(inputs) weight2 = mask2 * probs.type_as(inputs) cb_weight = torch.zeros(inputs.shape + (capacity,), device=inputs.device) sec_mask = torch.zeros_like(cb_weight, dtype=torch.bool) indices = torch.arange(0, inputs.shape[0], device=inputs.device) cb_weight[indices, top1_idx[indices], rank1[indices]] += weight1[indices, top1_idx[indices]] cb_weight[indices, top2_idx[indices], rank2[indices]] += weight2[indices, top2_idx[indices]] sec_mask[indices, top1_idx[indices], rank1[indices]] |= mask1.bool()[indices, top1_idx[indices]] sec_mask[indices, top2_idx[indices], rank2[indices]] |= mask2.bool()[indices, top2_idx[indices]] return used_capacity, cb_weight, sec_mask class TopKRouter(MoeRouter): """Masked matmul router using tokens choose top-k experts assignment. NOTE: this is modified from flaxformer. This router uses the same mechanism as in Switch Transformer (https://arxiv.org/abs/2101.03961) and V-MoE (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are sorted by router_probs and then routed to their choice of expert until the expert's expert_capacity is reached. There is no guarantee that each token is processed by an expert, or that each expert receives at least one token. Attributes: num_selected_experts: Maximum number of experts to which each token is routed. Tokens may be routed to fewer experts if particular experts are oversubscribed / reach capacity. """ def __init__( self, num_selected_experts: int, capacity_factor_train: float = 1.25, capacity_factor_eval: float = 2.0, min_capacity: int = 4, noisy_func: Optional[Callable] = None, drop_tks: bool = True, ): super().__init__( num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func, drop_tks ) def forward( self, router_probs: torch.Tensor, expert_capacity: int, ) -> Tuple: """Computes masks for the top-k experts per token. Args: router_probs: [num_groups, tokens_per_group, num_experts] probabilities used to determine the routing of tokens to the experts. Returns: Dispatch and combine arrays for routing with masked matmuls. """ # TODO: FIXME: add parallel group num_groups, _, num_experts = router_probs.shape # Top-k router probability and corresponding expert indices for each token. # Shape: [num_groups, tokens_per_group, num_selected_experts]. expert_gate, expert_index = torch.topk(router_probs, self.k_value) self.set_aux_loss(router_probs, expert_index, num_experts) self.pop_router_loss() # Make num_selected_experts the leading axis to ensure that top-1 choices # have priority over top-2 choices, which have priority over top-3 choices, # etc. expert_index = torch.transpose(expert_index, 1, 2) # Shape: [num_groups, num_selected_experts * tokens_per_group] expert_index = expert_index.reshape(num_groups, -1) # Create mask out of indices. # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. expert_mask = F.one_hot(expert_index, num_experts).to(torch.int32) # Experts have a fixed capacity that we cannot exceed. A token's priority # within the expert's buffer is given by the masked, cumulative capacity of # its target expert. # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts]. token_priority = torch.cumsum(expert_mask, dim=1) * expert_mask - 1 # Shape: [num_groups, num_selected_experts, tokens_per_group, num_experts]. token_priority = token_priority.reshape((num_groups, self.k_value, -1, num_experts)) # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts]. token_priority = torch.transpose(token_priority, 1, 2) # For each token, across all selected experts, select the only non-negative # (unmasked) priority. Now, for group G routing to expert E, token T has # non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E # is its targeted expert. # Shape: [num_groups, tokens_per_group, num_experts]. token_priority = torch.max(token_priority, dim=2)[0] # Token T can only be routed to expert E if its priority is positive and # less than the expert capacity. One-hot matrix will ignore indices outside # the range [0, expert_capacity). # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity]. valid_mask = torch.logical_and(token_priority >= 0, token_priority < expert_capacity) token_priority = torch.masked_fill(token_priority, ~valid_mask, 0) dispatch_mask = F.one_hot(token_priority, expert_capacity).to(torch.bool) valid_mask = valid_mask.unsqueeze(-1).expand(-1, -1, -1, expert_capacity) dispatch_mask = torch.masked_fill(dispatch_mask, ~valid_mask, 0) # The combine array will be used for combining expert outputs, scaled by the # router probabilities. Shape: [num_groups, tokens_per_group, num_experts, # expert_capacity]. combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask) return combine_array, dispatch_mask def get_router_cls(top_k: int, grouped: bool = False) -> MoeRouter: if not grouped: if top_k == 1: return Top1Router elif top_k == 2: return Top2Router else: raise NotImplementedError("top_k > 2 is not supported yet") else: return TopKRouter