""" The file has been adapted from the following files: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py Git commit hash: f3943cf9109226ed3ecf2d5dbb639a11cd925555 We retain the following license from the original files: """ # Copyright (c) Microsoft Corporation. # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple import torch import torch.distributed as dist import torch.nn.functional as F from torch import Tensor from torch.nn import Module from internlm.utils.logger import get_logger from internlm.utils.megatron_timers import megatron_timer as timer # global llm logger logger = get_logger(__file__) if TYPE_CHECKING: Base = Module[Tensor] else: Base = Module uniform_map: Dict[torch.device, Callable] = {} gumbel_map: Dict[torch.device, Callable] = {} exp_selection_uniform_map: Dict[torch.device, Callable] = {} def multiplicative_jitter(x, device: torch.device, epsilon=1e-2): """ Modified from switch transformer paper. mesh transformers Multiply values by a random number between 1-epsilon and 1+epsilon. Makes models more resilient to rounding errors introduced by bfloat16. This seems particularly important for logits. Args: x: a torch.tensor device: torch.device epsilon: a floating point value Returns: a jittered x. """ if epsilon == 0: return x uniform = uniform_map.get(device) if uniform is None: uniform = torch.distributions.uniform.Uniform( low=torch.tensor(1.0 - epsilon, device=device), high=torch.tensor(1.0 + epsilon, device=device) ).rsample # type: ignore uniform_map[device] = uniform return x * uniform(x.shape) def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor: gumbel = gumbel_map.get(device) if gumbel is None: one = torch.tensor(1.0, device=device) zero = torch.tensor(0.0, device=device) gumbel = torch.distributions.gumbel.Gumbel(zero, one).rsample # type: ignore gumbel_map[device] = gumbel return gumbel(shape) # einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity # See https://arxiv.org/pdf/2006.16668.pdf for details. # Based on https://github.com/pytorch/pytorch/pull/40762 class _AllToAll(torch.autograd.Function): """ All to all communication """ @staticmethod def forward( ctx: Any, # TODO: replace with DS process group group: torch.distributed.ProcessGroup, inputs: Tensor, ) -> Tensor: # type: ignore ctx.group = group inputs = inputs.contiguous() output = torch.empty_like(inputs) dist.all_to_all_single(output, inputs, group=group) return output @staticmethod def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]: return (None, _AllToAll.apply(ctx.group, *grad_output)) # einsum rewrites are on par or more performant # switch can be bubbled up in future USE_EINSUM = True # einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity # See https://arxiv.org/pdf/2006.16668.pdf for details. 
def einsum(rule, a, b):
    if USE_EINSUM:
        return torch.einsum(rule, a, b)
    elif rule == "s,se->se":
        # [s, 1] * [s, e]
        return a.reshape(a.shape[0], -1) * b
    elif rule == "se,sc->sec":
        # [s, e, 1] * [s, 1, c]
        return a.unsqueeze(2) * b.unsqueeze(1)
    elif rule == "se,se->s":
        # [s, 1, e] * [s, e, 1]
        return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1)
    elif rule == "sec,sm->ecm":
        # [e*c, s] * [s, m]
        s = a.shape[0]
        e = a.shape[1]
        c = a.shape[2]
        m = b.shape[1]
        return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m)
    elif rule == "sec,ecm->sm":
        # [s, e*c] * [e*c, m]
        return torch.matmul(a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1]))
    elif rule == "ks,ksm->sm":
        k = b.shape[0]
        s = b.shape[1]
        m = b.shape[2]
        # [k, s] -> [s, k] -> [s, 1, k]
        a = a.t().unsqueeze(1)
        # [k, s, m] -> [k, s*m] -> [s*m, k] -> [s, m, k]
        b = b.reshape(k, -1).t().reshape(s, m, k)
        # bmm([s, 1, k], [s, k, m]) -> [s, 1, m] -> [s, m]
        return torch.bmm(a, b.transpose(1, 2)).squeeze(1)
    else:
        return torch.einsum(rule, a, b)


# The following functions are extracted and scripted
# because otherwise during a torch.jit.trace, the non-Tensor
# values used in the calculations get recorded as constants.
# torch.jit.script coerces them into Tensors and preserves
# their dynamic shapes. This enables ONNX export.
# We can't script the entire top1gating function because it
# includes stateful caching logic which is incompatible with ONNX.


@torch.jit.script
def _capacity(gates: Tensor, capacity_factor: Tensor, min_capacity: Tensor) -> Tensor:
    # gates has shape of SE
    num_tokens = gates.shape[0]
    num_experts = gates.shape[1]
    # to(torch.int64) works around a bug in torch.onnx.export:
    # it should cast k to int64 when converting torch.topk but it doesn't.
    capacity = torch.ceil((num_tokens / num_experts) * capacity_factor).to(torch.int64)
    if capacity < min_capacity:
        capacity = min_capacity.to(torch.int64)
    return capacity


@torch.jit.script
def _top_idx(source, k):
    return torch.topk(source, k=k, dim=0)[1]


@torch.jit.script
def _one_hot_to_float(x, num_classes):
    return F.one_hot(x, num_classes=num_classes).float()


def top1gating(
    logits: Tensor,
    capacity_factor: float,
    min_capacity: int,
    used_token: Tensor = None,
    noisy_gate_policy: Optional[str] = None,
    drop_tokens: bool = True,
    use_rts: bool = True,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Implements Top1Gating on logits."""
    if noisy_gate_policy == "RSample":
        logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
    # everything is in fp32 in this function
    gates = F.softmax(logits, dim=1)

    capacity = _capacity(gates, torch.tensor(capacity_factor), torch.tensor(min_capacity))

    # Create a mask for the 1st expert per token
    # noisy gating
    indices1_s = torch.argmax(logits_w_noise if noisy_gate_policy == "RSample" else gates, dim=1)
    num_experts = int(gates.shape[1])
    mask1 = F.one_hot(indices1_s, num_classes=num_experts)

    # mask only used tokens
    if used_token is not None:
        mask1 = einsum("s,se->se", used_token, mask1)

    # gating decisions
    exp_counts = torch.sum(mask1, dim=0).detach().to("cpu")

    # if we don't want to drop any tokens
    if not drop_tokens:
        new_capacity = torch.max(exp_counts).to(logits.device)
        # torch.distributed exposes the default (world) process group as dist.group.WORLD
        dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.group.WORLD)
        capacity = new_capacity

    # Compute l_aux
    me = torch.mean(gates, dim=0)
    ce = torch.mean(mask1.float(), dim=0)
    l_aux = torch.sum(me * ce) * num_experts
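    # Worked illustration (added for clarity; the numbers are assumed): with num_experts=2 and perfectly
    # balanced routing, me = ce = [0.5, 0.5], so l_aux = 2 * (0.25 + 0.25) = 1.0, while routing every
    # token to one expert gives me = ce = [1.0, 0.0] and l_aux = 2.0; minimizing l_aux favors uniform load.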
    # Random Token Selection
    if use_rts:
        uniform = exp_selection_uniform_map.get(logits.device)
        if uniform is None:
            uniform = torch.distributions.uniform.Uniform(
                low=torch.tensor(0.0, device=logits.device), high=torch.tensor(1.0, device=logits.device)
            ).rsample
            exp_selection_uniform_map[logits.device] = uniform

        mask1_rand = mask1 * uniform(mask1.shape)
    else:
        mask1_rand = mask1

    assert logits.shape[0] >= min_capacity, (
        "No. of tokens (batch size) should be at least min_capacity. "
        "Either set min_capacity to 0 or increase your batch size."
    )

    top_idx = _top_idx(mask1_rand, capacity)  # token index

    new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)
    mask1 = new_mask1

    # Compute locations in capacity buffer
    locations1 = torch.cumsum(mask1, dim=0) - 1

    # Store the capacity location for each token
    locations1_s = torch.sum(locations1 * mask1, dim=1)

    # Normalize gate probabilities
    mask1_float = mask1.float()
    gates = gates * mask1_float

    locations1_sc = _one_hot_to_float(locations1_s, capacity)
    combine_weights = einsum("se,sc->sec", gates, locations1_sc)

    dispatch_mask = combine_weights.bool()

    return l_aux, combine_weights, dispatch_mask, exp_counts


def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Implements Top2Gating on logits."""
    # everything is in fp32 in this function
    gates = F.softmax(logits, dim=1)

    capacity = _capacity(gates, torch.tensor(capacity_factor * 2), torch.tensor(min_capacity))

    # Create a mask for the 1st expert per token
    indices1_s = torch.argmax(gates, dim=1)
    num_experts = int(gates.shape[1])
    mask1 = F.one_hot(indices1_s, num_classes=num_experts)

    # Create a mask for the 2nd expert per token using the Gumbel-max trick
    # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
    logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
    # Replace top-expert with min value
    logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf"))
    indices2_s = torch.argmax(logits_except1, dim=1)
    mask2 = F.one_hot(indices2_s, num_classes=num_experts)

    # Compute locations in capacity buffer
    locations1 = torch.cumsum(mask1, dim=0) - 1
    locations2 = torch.cumsum(mask2, dim=0) - 1
    # Update 2nd's location by accounting for locations of 1st
    locations2 += torch.sum(mask1, dim=0, keepdim=True)

    # gating decisions
    exp_counts = torch.sum(mask1, dim=0).detach().to("cpu")

    # Compute l_aux
    me = torch.mean(gates, dim=0)
    ce = torch.mean(mask1.float(), dim=0)
    l_aux = torch.mean(me * ce) * num_experts * num_experts

    # Remove locations outside capacity from mask
    mask1 *= torch.lt(locations1, capacity)
    mask2 *= torch.lt(locations2, capacity)

    # Store the capacity location for each token
    locations1_s = torch.sum(locations1 * mask1, dim=1)
    locations2_s = torch.sum(locations2 * mask2, dim=1)

    # Normalize gate probabilities
    mask1_float = mask1.float()
    mask2_float = mask2.float()
    gates1_s = einsum("se,se->s", gates, mask1_float)
    gates2_s = einsum("se,se->s", gates, mask2_float)
    denom_s = gates1_s + gates2_s
    # Avoid divide-by-zero
    denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
    gates1_s /= denom_s
    gates2_s /= denom_s

    # Calculate combine_weights and dispatch_mask
    gates1 = einsum("s,se->se", gates1_s, mask1_float)
    gates2 = einsum("s,se->se", gates2_s, mask2_float)
    locations1_sc = _one_hot_to_float(locations1_s, capacity)
    locations2_sc = _one_hot_to_float(locations2_s, capacity)
    combine1_sec = einsum("se,sc->sec", gates1, locations1_sc)
    combine2_sec = einsum("se,sc->sec", gates2, locations2_sc)
    combine_weights = combine1_sec + combine2_sec
    dispatch_mask = combine_weights.bool()

    return l_aux, combine_weights, dispatch_mask, exp_counts

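# Return-shape summary (added for clarity; the shapes are implied by the code above): for logits of
# shape [s, e] and a computed capacity c, both gating functions return
#   l_aux           - scalar auxiliary load-balancing loss
#   combine_weights - [s, e, c] float weights consumed by the "sec,ecm->sm" combine einsum
#   dispatch_mask   - [s, e, c] bool mask consumed by the "sec,sm->ecm" dispatch einsum
#   exp_counts      - [e] per-expert token counts, detached and moved to CPU
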
"""Gate module which implements Top2Gating as described in Gshard_. :: gate = TopKGate(model_dim, num_experts) l_aux, combine_weights, dispatch_mask = gate(input) .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf Args: model_dim (int): size of model embedding dimension num_experts (ints): number of experts in model """ wg: torch.nn.Linear def __init__( self, model_dim: int, num_experts: int, k: int = 1, capacity_factor: float = 1.0, eval_capacity_factor: float = 1.0, min_capacity: int = 8, noisy_gate_policy: Optional[str] = None, drop_tokens: bool = True, use_rts: bool = True, ) -> None: super().__init__() # Only top-1 and top-2 are supported at the moment. if k not in (1, 2): raise ValueError("Only top-1 and top-2 gatings are supported.") # TODO: can we use tensor parallel here? # Deepspeed's mechisms, alway use fp32 self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float() self.k = k self.capacity_factor = capacity_factor self.eval_capacity_factor = eval_capacity_factor self.min_capacity = min_capacity self.noisy_gate_policy = noisy_gate_policy self.wall_clock_breakdown = False self.gate_time = 0.0 self.drop_tokens = drop_tokens self.use_rts = use_rts def forward( self, inputs: torch.Tensor, used_token: torch.Tensor = None ) -> Tuple[Tensor, Tensor, Tensor]: # type: ignore if self.wall_clock_breakdown: timer("TopKGate").start() if self.wg.weight.dtype != torch.float32: self.wg = self.wg.float() inputs_fp32 = inputs.float() # input jittering if self.noisy_gate_policy == "Jitter" and self.training: inputs_fp32 = multiplicative_jitter(inputs_fp32, device=inputs.device) logits = self.wg(inputs_fp32) if self.k == 1: gate_output = top1gating( logits, self.capacity_factor if self.training else self.eval_capacity_factor, self.min_capacity, used_token, self.noisy_gate_policy if self.training else None, self.drop_tokens, self.use_rts, ) else: gate_output = top2gating( logits, self.capacity_factor if self.training else self.eval_capacity_factor, self.min_capacity ) if self.wall_clock_breakdown: timer("TopKGate").stop() self.gate_time = timer("TopKGate").elapsed(reset=False) return gate_output class MOELayer(Base): """MOELayer module which implements MixtureOfExperts as described in Gshard_. :: gate = TopKGate(model_dim, num_experts) moe = MOELayer(gate, expert) output = moe(inputs) l_aux = moe.l_aux .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf Args: gate (torch.nn.Module): gate network expert (torch.nn.Module): expert network """ def __init__(self, gate: Module, experts: Module, ep_group, ep_size, num_local_experts: int) -> None: super().__init__() self.gate = gate self.experts = experts self.ep_group = ep_group self.ep_size = ep_size self.num_local_experts = num_local_experts self.time_falltoall = 0.0 self.time_salltoall = 0.0 self.time_moe = 0.0 self.wall_clock_breakdown = False def _set_ep_group(self, ep_group): self.ep_group = ep_group def forward(self, *inputs: Tensor) -> Tensor: if self.wall_clock_breakdown: timer("moe").start() # Implement Algorithm 2 from GShard paper. d_model = inputs[0].shape[-1] # Initial implementation -> Reshape into S tokens by dropping sequence dimension. 
        # Initial implementation -> Reshape into S tokens by dropping sequence dimension.
        # Reshape into G groups so that each group can distribute tokens equally
        # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
        reshaped_inputs = inputs[0].reshape(-1, d_model)

        self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_inputs, inputs[1])
        # TODO: heavy memory usage due to long sequence length
        dispatched_inputs = einsum("sec,sm->ecm", dispatch_mask.type_as(inputs[0]), reshaped_inputs)

        if self.wall_clock_breakdown:
            timer("falltoall").start()

        dispatched_inputs = _AllToAll.apply(self.ep_group, dispatched_inputs)

        if self.wall_clock_breakdown:
            timer("falltoall").stop()
            self.time_falltoall = timer("falltoall").elapsed(reset=False)

        # Re-shape after all-to-all: ecm -> gecm
        dispatched_inputs = dispatched_inputs.reshape(self.ep_size, self.num_local_experts, -1, d_model)

        expert_output = self.experts(dispatched_inputs)

        if self.wall_clock_breakdown:
            timer("salltoall").start()

        expert_output = _AllToAll.apply(self.ep_group, expert_output)

        if self.wall_clock_breakdown:
            timer("salltoall").stop()
            self.time_salltoall = timer("salltoall").elapsed(reset=False)

        # Re-shape back: gecm -> ecm
        expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)

        combined_output = einsum("sec,ecm->sm", combine_weights.type_as(inputs[0]), expert_output)

        a = combined_output.reshape(inputs[0].shape)

        if self.wall_clock_breakdown:
            timer("moe").stop()
            self.time_moe = timer("moe").elapsed(reset=False)

        return a
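

if __name__ == "__main__":
    # Minimal single-process sketch (added for illustration, not part of the original file). It only
    # exercises TopKGate; MOELayer additionally needs an expert module and an initialized
    # expert-parallel process group for the all-to-all calls above.
    torch.manual_seed(0)
    gate = TopKGate(model_dim=16, num_experts=4, k=1, min_capacity=2)
    tokens = torch.randn(8, 16)  # 8 tokens with a model dimension of 16
    l_aux, combine_weights, dispatch_mask, exp_counts = gate(tokens)
    # combine_weights / dispatch_mask are [tokens, experts, capacity]; exp_counts is one count per expert.
    print(l_aux.item(), combine_weights.shape, dispatch_mask.shape, exp_counts)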