import dataclasses
import math
from typing import Any, Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from colossalai.moe._operation import AllGather, AllToAll, HierarchicalAllToAll, MoeCombine, MoeDispatch, ReduceScatter
from colossalai.moe.experts import MLPExperts
from colossalai.moe.load_balance import LoadBalancer
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.routers import MoeRouter, get_router_cls
from colossalai.moe.utils import create_ep_hierarchical_group, get_noise_generator
from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_group_ranks, get_ep_size


class SparseMLP(nn.Module):
    """A class for users to create MoE (sparse MLP) modules in their models.

    A linear gate scores every token against every expert, a router picks the
    experts (``router_top_k`` per token), tokens are dispatched to ``MLPExperts``
    (locally, with expert parallelism "EP", or with tensor parallelism "TP" as
    reported by ``MOE_MANAGER.get_parallel()``), and the expert outputs are
    combined back per token.

    Args:
        num_experts (int): The number of experts.
        hidden_size (int): Hidden dimension of the model (last dim of the input).
        intermediate_size (int): Hidden dimension inside each expert MLP.
        router_top_k (int, optional): The number of experts each token is dispatched to;
            also selects the router class via ``get_router_cls``.
        router_loss (bool, optional): Passed to the router as ``use_loss``.
        router_norm (bool, optional): Passed to the router as ``use_norm``.
        router_capacity_factor_train (float, optional): Capacity factor in routing during training.
        router_capacity_factor_eval (float, optional): Capacity factor in routing during evaluation.
        router_min_capacity (int, optional): The minimum capacity of each expert.
        router_noisy_policy (str, optional): The policy of the noise function. Now we have 'Jitter'
            and 'Gaussian'. 'Jitter' can be found in `Switch Transformer paper`_. 'Gaussian' can be
            found in `ViT-MoE paper`_.
        router_drop_tks (bool, optional): Passed to the router as ``drop_tks`` (token dropping).
        mlp_activation (str, optional): Activation used by the expert MLPs (passed to ``MLPExperts``).
        mlp_gated (bool, optional): Whether the expert MLPs are gated (passed to ``MLPExperts``).
        enable_load_balance (bool, optional): Track per-expert load and build a ``LoadBalancer``.
        load_balance_tolerance (float, optional): Forwarded to ``LoadBalancer`` as ``tolerance``.
        load_balance_beam_width (int, optional): Forwarded to ``LoadBalancer`` as ``beam_width``.
        load_balance_group_swap_factor (float, optional): Forwarded to ``LoadBalancer`` as
            ``group_swap_factor``.
        enable_kernel (bool, optional): Use the fused ``MoeDispatch``/``MoeCombine`` ops and pass
            ``use_kernel`` to the router and experts.
        enable_comm_overlap (bool, optional): Split the work into chunks and overlap communication
            with expert computation.
        enable_hierarchical_comm (bool, optional): Use ``HierarchicalAllToAll`` for expert parallelism
            (only when overlap is not used).
        return_gate_logits (bool, optional): If True, ``forward`` returns ``(output, gate_logits)``.

    .. _Switch Transformer paper:
        https://arxiv.org/abs/2101.03961
    .. _ViT-MoE paper:
        https://arxiv.org/abs/2106.05974
    """

    def __init__(
        self,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        router_top_k: int = 1,
        router_loss: bool = True,
        router_norm: bool = False,
        router_capacity_factor_train: float = 1.25,
        router_capacity_factor_eval: float = 2.0,
        router_min_capacity: int = 4,
        router_noisy_policy: Optional[str] = None,
        router_drop_tks: bool = True,
        mlp_activation: Optional[str] = None,
        mlp_gated: bool = False,
        enable_load_balance: bool = False,
        load_balance_tolerance: float = 0.1,
        load_balance_beam_width: int = 8,
        load_balance_group_swap_factor: float = 0.4,
        enable_kernel: bool = False,
        enable_comm_overlap: bool = False,
        enable_hierarchical_comm: bool = False,
        return_gate_logits: bool = False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_experts = num_experts
        self.gated = mlp_gated
        self.return_gate_logits = return_gate_logits
        self.enable_kernel = enable_kernel
        self.enable_comm_overlap = enable_comm_overlap
        # None (local experts), "EP" or "TP"; see forward() for the dispatch.
        self.expert_parallel = MOE_MANAGER.get_parallel()
        self.router_loss = router_loss
        self.router_norm = router_norm

        # moe router
        noisy_func = get_noise_generator(router_noisy_policy, num_experts)
        router_cls = get_router_cls(router_top_k)
        self.topk = router_top_k
        self.router: MoeRouter = router_cls(
            capacity_factor_train=router_capacity_factor_train,
            capacity_factor_eval=router_capacity_factor_eval,
            min_capacity=router_min_capacity,
            noisy_func=noisy_func,
            drop_tks=router_drop_tks,
        )

        # gate: one row of weights per expert
        self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, self.hidden_size))

        # moe experts
        self.experts = MLPExperts(
            num_experts=self.num_experts,
            expert_parallel=self.expert_parallel,
            hidden_size=self.hidden_size,
            intermediate_size=self.intermediate_size,
            activation=mlp_activation,
            gated=mlp_gated,
            use_kernel=self.enable_kernel,
        )

        # get parallel settings
        # NOTE: ep_size / ep_hierarchical_group are only defined when expert_parallel is set;
        # they are only read by the parallel code paths.
        if self.expert_parallel is not None:
            self.ep_group = get_ep_group(self.experts)
            self.ep_size = get_ep_size(self.experts)
            self.ep_hierarchical_group = None
            if enable_hierarchical_comm:
                self.ep_intra_src_rank, *self.ep_hierarchical_group = create_ep_hierarchical_group(
                    get_ep_group_ranks(self.experts)
                )
            self.dp_group = get_dp_group(self.experts)
        else:
            self.ep_group = None
            self.dp_group = None
        self.num_local_experts = self.experts.num_local_experts

        # load balance
        self.enable_load_balance = enable_load_balance
        if self.enable_load_balance:
            self.load_balancer = LoadBalancer(
                experts=self.experts,
                gate=self.gate_weight,
                local_expert_num=self.num_local_experts,
                expert_num=self.num_experts,
                ep_group=self.ep_group,
                dp_group=self.dp_group,
                tolerance=load_balance_tolerance,
                beam_width=load_balance_beam_width,
                group_swap_factor=load_balance_group_swap_factor,
            )

        # init param
        self.reset_parameters()

    @torch.no_grad()
    def reset_parameters(self):
        """Initialize the gate weights from N(0, 0.1 / hidden_size)."""
        torch.nn.init.normal_(self.gate_weight, std=math.sqrt(0.1 / self.hidden_size))

    def forward(self, inputs: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
            inputs (torch.Tensor): The input tensor of shape (batch_size, seq_len, hidden_size).
                Any shape whose last dim is ``hidden_size`` is accepted; leading dims are flattened.

        Returns:
            torch.Tensor: The output tensor, with the same shape as ``inputs``.
            If ``return_gate_logits`` is set, a tuple ``(output, gate_logits)`` is returned instead,
            where ``gate_logits`` has shape (num_tokens, num_experts).
        """
        # flatten all leading dims into one token dimension
        tokens = inputs.reshape(-1, self.hidden_size)

        # the router operates on fp32 logits
        gate_logits = F.linear(tokens, self.gate_weight)
        gate_output = gate_logits.to(torch.float)

        # update expert load
        if self.enable_load_balance:
            with torch.no_grad():
                # TODO: optimize computation
                expert_load = torch.topk(gate_output, k=self.topk, dim=-1)[1]
                # TODO: bincount introduces synchronize, fix it
                # minlength keeps one entry per expert even when the highest-indexed
                # experts receive no tokens in this batch.
                expert_load = torch.bincount(expert_load.view(-1), minlength=self.num_experts)
                self.load_balancer.update_load(expert_load)

        # the result from the router
        used_capacity, *route_result_list = self.router(
            inputs=gate_output,
            use_kernel=self.enable_kernel,
            ep_group=self.ep_group,
            use_loss=self.router_loss,
            use_norm=self.router_norm,
        )

        # dispatch_data: (num_experts, capacity, hidden_size)
        if self.enable_kernel:
            dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:])
            dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.hidden_size)
        else:
            sec_mask_f = route_result_list[1].type_as(inputs)
            dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)

        # expert_output: layout depends on the parallel mode; it is flattened to
        # (-1, hidden_size) below before being combined.
        if self.expert_parallel == "EP":
            expert_output = self._ep_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
        elif self.expert_parallel == "TP":
            expert_output = self._tp_process(dispatch_data, used_capacity, overlap=self.enable_comm_overlap)
        elif self.expert_parallel is None:
            expert_output = self._local_process(dispatch_data)
        else:
            raise NotImplementedError(
                "This kind of communication has not been implemented yet.\n" "Please use Experts build function."
            )

        if self.enable_kernel:
            expert_output = expert_output.reshape(-1, self.hidden_size)
            ans = MoeCombine.apply(expert_output, *route_result_list)
        else:
            combine_weights = route_result_list[0].type_as(inputs)
            combine_weights = combine_weights.view(combine_weights.shape[0], -1)
            expert_output = expert_output.view(-1, expert_output.shape[-1])
            ans = torch.matmul(combine_weights, expert_output)

        ans = ans.reshape(inputs.shape)

        if self.return_gate_logits:
            return ans, gate_logits
        return ans

    def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor:
        """Run all experts locally; a leading singleton dim is added before calling the experts."""
        expert_in = expert_in.unsqueeze(0)
        expert_out = self.experts(expert_in)
        return expert_out

    def _ep_process(
        self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
    ) -> torch.Tensor:
        """
        Expert Parallel

        Args:
            dispatch_data (torch.Tensor): (num_experts, capacity, hidden_size)
            used_capacity (torch.Tensor): not used by this method (kept for parity with ``_tp_process``).
            overlap (bool): split along capacity into chunks and pipeline all-to-all with computation.

        Returns:
            torch.Tensor: expert outputs after the return all-to-all. The overlap path returns a tensor
            of shape (ep_size, num_local_experts, capacity, hidden_size); the plain all-to-all path
            presumably returns the same layout, and the hierarchical path returns whatever
            ``HierarchicalAllToAll`` produces (TODO confirm).
        """
        if not overlap or dist.get_world_size(self.ep_group) == 1:
            if self.ep_hierarchical_group is not None:
                expert_input = HierarchicalAllToAll.apply(
                    dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank
                )
                expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
                expert_output = self.experts(expert_input)
                expert_output = HierarchicalAllToAll.apply(
                    expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank
                )
                return expert_output
            else:
                expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
                expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
                expert_output = self.experts(expert_input)
                expert_output = AllToAll.apply(expert_output, self.ep_group, False)[0]
                return expert_output
        else:

            @dataclasses.dataclass
            class Capsule:
                data: torch.Tensor
                handle: Any = None

            NUM_CHUNK = 4
            # stages: all-to-all input -> compute -> all-to-all output -> write back
            NUM_STAGES = 4

            assert dispatch_data.shape[1] % NUM_CHUNK == 0, "arbitrary chunk num is not supported yet"
            chunk_size = dispatch_data.shape[1] // NUM_CHUNK
            input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size)
            dispatch_data = dispatch_data.reshape(*input_shape)
            chunk_data = torch.split(dispatch_data, chunk_size, dim=2)
            output = torch.empty_like(dispatch_data)

            offset = 0
            # Underscored names hold in-flight async ops issued in this iteration; they are
            # promoted to the non-underscored slot for the next pipeline stage.
            _expert_in, expert_in, _expert_out, expert_out = None, None, None, None

            for i in range(NUM_CHUNK + NUM_STAGES - 1):
                if expert_out is not None:
                    expert_out.handle.wait()
                    output[:, :, offset : offset + chunk_size, :] = expert_out.data
                    offset += chunk_size
                    expert_out = None

                # all2all last output
                if _expert_out is not None:
                    expert_out = Capsule(
                        *AllToAll.apply(_expert_out.data, self.ep_group, True),
                    )
                    _expert_out = None

                # all2all next input
                if i < NUM_CHUNK:
                    _expert_in = Capsule(*AllToAll.apply(chunk_data[i].contiguous(), self.ep_group, True))

                # compute
                if expert_in is not None:
                    expert_in.handle.wait()
                    _expert_out = Capsule(data=self.experts(expert_in.data), handle=None)
                    expert_in = None

                if _expert_in is not None:
                    expert_in = _expert_in
                    _expert_in = None

            return output

    def _tp_process(
        self, dispatch_data: torch.Tensor, used_capacity: torch.Tensor, overlap: bool = False
    ) -> torch.Tensor:
        """
        without overlap:
        |    C    |
        |     A     |     |    R    |

        with overlap:
        |    C1   ||    C2   ||    C3   ||    C4   |
        | A1 || A2 |     | R1 | A3 || R2 | A4 || R3 |     | R4 |

        where C is computation, A is all gather, R is reduce scatter.

        Args:
            dispatch_data (torch.Tensor): (num_experts, capacity, hidden_size)
            used_capacity (torch.Tensor): not used by this method (kept for parity with ``_ep_process``).
            overlap (bool): split along the expert dim into chunks and pipeline communication with computation.

        Returns:
            torch.Tensor: (num_experts, capacity, hidden_size)
        """
        if not overlap or dist.get_world_size(self.ep_group) == 1:
            expert_in = AllGather.apply(dispatch_data, self.ep_group, False)[0]
            expert_out = self.experts(expert_in)
            expert_out = ReduceScatter.apply(expert_out, self.ep_group, False)[0]
            return expert_out
        else:

            @dataclasses.dataclass
            class Capsule:
                data: torch.Tensor
                handle: Any
                indices: Tuple

            NUM_CHUNK = 4
            NUM_STAGES = 4

            assert (
                dispatch_data.shape[0] % NUM_CHUNK == 0
            ), "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
            chunk_size = dispatch_data.shape[0] // NUM_CHUNK
            chunk_data = torch.split(dispatch_data, chunk_size, dim=0)
            output = torch.empty_like(dispatch_data)

            def get_chunk_slice(idx: int, chunk_size: int) -> Tuple[slice]:
                # index tuple selecting chunk `idx` along the expert dim
                return (slice(idx * chunk_size, (idx + 1) * chunk_size),)

            _expert_in, expert_in, _expert_out, expert_out = None, None, None, None

            for i in range(NUM_CHUNK + NUM_STAGES - 1):
                if expert_out is not None:
                    expert_out.handle.wait()
                    output[expert_out.indices] = expert_out.data
                    expert_out = None

                # reduce scatter last output
                if _expert_out is not None:
                    expert_out = Capsule(
                        *ReduceScatter.apply(_expert_out.data, self.ep_group, True),
                        indices=_expert_out.indices,
                    )
                    _expert_out = None

                # all gather next input
                if i < NUM_CHUNK:
                    _expert_in = Capsule(
                        *AllGather.apply(chunk_data[i].contiguous(), self.ep_group, True),
                        indices=get_chunk_slice(i, chunk_size),
                    )

                # compute
                if expert_in is not None:
                    expert_in.handle.wait()
                    _expert_out = Capsule(
                        self.experts(expert_in.data, expert_in.indices),
                        handle=None,
                        indices=expert_in.indices,
                    )
                    expert_in = None

                if _expert_in is not None:
                    expert_in = _expert_in
                    _expert_in = None

            return output


def apply_load_balance(model: nn.Module, optim: Any) -> None:
    """
    Apply load balance to every SparseMLP in the model that has ``enable_load_balance`` set.

    Submodules of ``model`` are searched recursively (the root module itself is not checked).
    The CUDA caching allocator is flushed after each rebalance and once at the end.

    Args:
        model (nn.Module): The model to traverse.
        optim (Any): Optimizer forwarded to ``LoadBalancer.balance_load``.
    """

    def _apply_recursive(module: nn.Module) -> None:
        for sub_module in module.children():
            if isinstance(sub_module, SparseMLP) and sub_module.enable_load_balance:
                sub_module.load_balancer.balance_load(optim)
                # balancing may move expert parameters/optimizer state around; release freed blocks
                torch.cuda.empty_cache()
            _apply_recursive(sub_module)

    _apply_recursive(model)
    torch.cuda.empty_cache()