"""
 | 
						|
The file has been adapted from the following files:
 | 
						|
https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py
 | 
						|
 Git commit hash: f3943cf9109226ed3ecf2d5dbb639a11cd925555
 | 
						|
 We retain the following license from the original files:
 | 
						|
"""
 | 
						|
 | 
						|
# Copyright (c) Microsoft Corporation.
 | 
						|
# SPDX-License-Identifier: Apache-2.0
 | 
						|
 | 
						|
# DeepSpeed Team
 | 
						|
 | 
						|
 | 
						|
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple
 | 
						|
 | 
						|
import torch
 | 
						|
import torch.distributed as dist
 | 
						|
import torch.nn.functional as F
 | 
						|
from torch import Tensor
 | 
						|
from torch.nn import Module
 | 
						|
 | 
						|
from internlm.utils.logger import get_logger
 | 
						|
from internlm.utils.megatron_timers import megatron_timer as timer
 | 
						|
 | 
						|
# global llm logger
 | 
						|
logger = get_logger(__file__)
 | 
						|
 | 
						|
if TYPE_CHECKING:
 | 
						|
    Base = Module[Tensor]
 | 
						|
else:
 | 
						|
    Base = Module
 | 
						|
 | 
						|
uniform_map: Dict[torch.device, Callable] = {}
 | 
						|
gumbel_map: Dict[torch.device, Callable] = {}
 | 
						|
exp_selection_uniform_map: Dict[torch.device, Callable] = {}
 | 
						|
 | 
						|
 | 
						|
def multiplicative_jitter(x, device: torch.device, epsilon=1e-2):
    """
    Modified from the Switch Transformer paper and the Mesh Transformers codebase.
    Multiplies values by a random number drawn uniformly from [1 - epsilon, 1 + epsilon].
    This makes models more resilient to rounding errors introduced by bfloat16,
    which seems particularly important for the gate logits.

    Args:
        x: a torch.Tensor
        device: torch.device
        epsilon: a floating point value

    Returns:
        a jittered x.
    """
    if epsilon == 0:
        return x
    uniform = uniform_map.get(device)
    if uniform is None:
        uniform = torch.distributions.uniform.Uniform(
            low=torch.tensor(1.0 - epsilon, device=device), high=torch.tensor(1.0 + epsilon, device=device)
        ).rsample  # type: ignore
        uniform_map[device] = uniform
    return x * uniform(x.shape)

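# Illustrative sketch (not part of the original module): how the jitter and the
# per-device sampler cache behave. Shapes and values here are assumptions.
#
#     x = torch.randn(4, 8, device="cpu")
#     y = multiplicative_jitter(x, device=x.device, epsilon=1e-2)
#     # y = x * u with u ~ Uniform(0.99, 1.01), sampled element-wise;
#     # the Uniform sampler is created once per device and reused via uniform_map.

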
def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
    gumbel = gumbel_map.get(device)
    if gumbel is None:
        one = torch.tensor(1.0, device=device)
        zero = torch.tensor(0.0, device=device)
        gumbel = torch.distributions.gumbel.Gumbel(zero, one).rsample  # type: ignore
        gumbel_map[device] = gumbel
    return gumbel(shape)


# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
# See https://arxiv.org/pdf/2006.16668.pdf for details.


# Based on https://github.com/pytorch/pytorch/pull/40762
class _AllToAll(torch.autograd.Function):
    """
    All to all communication
    """

    @staticmethod
    def forward(
        ctx: Any,
        # TODO: replace with DS process group
        group: torch.distributed.ProcessGroup,
        inputs: Tensor,
    ) -> Tensor:  # type: ignore
        ctx.group = group
        inputs = inputs.contiguous()
        output = torch.empty_like(inputs)
        dist.all_to_all_single(output, inputs, group=group)
        return output

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
        return (None, _AllToAll.apply(ctx.group, *grad_output))


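# Illustrative note (an assumption about usage, not from the original file): with a
# world size of W in `group`, `dist.all_to_all_single` splits the first dimension of
# the input into W equal chunks and sends chunk i to rank i, so the [e, c, m] tensor
# produced by the "sec,sm->ecm" dispatch is exchanged expert-wise across ranks.
#
#     # On every rank (expert dimension e == ep_size * num_local_experts):
#     #     dispatched = _AllToAll.apply(ep_group, dispatched)   # [e, c, m] -> [e, c, m]
#     # Rank r ends up holding the capacity slots destined for its local experts.

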
# einsum rewrites are on par or more performant
# switch can be bubbled up in future
USE_EINSUM = True


# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
# See https://arxiv.org/pdf/2006.16668.pdf for details.
def einsum(rule, a, b):
    if USE_EINSUM:
        return torch.einsum(rule, a, b)
    elif rule == "s,se->se":
        # [s] -> [s, 1], broadcast against [s, e]
        return a.reshape(a.shape[0], -1) * b
    elif rule == "se,sc->sec":
        # [s, e, 1] * [s, 1, c]
        return a.unsqueeze(2) * b.unsqueeze(1)
    elif rule == "se,se->s":
        # [s, 1, e] * [s, e, 1]
        return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1)
    elif rule == "sec,sm->ecm":
        # [e*c, s] * [s, m]
        s = a.shape[0]
        e = a.shape[1]
        c = a.shape[2]
        m = b.shape[1]
        return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m)
    elif rule == "sec,ecm->sm":
        # [s, e*c] * [e*c, m]
        return torch.matmul(a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1]))
    elif rule == "ks,ksm->sm":
        k = b.shape[0]
        s = b.shape[1]
        m = b.shape[2]
        # [k, s] -> [s, k] -> [s, 1, k]
        a = a.t().unsqueeze(1)
        # [k, s, m] -> [k, s*m] -> [s*m, k] -> [s, m, k]
        b = b.reshape(k, -1).t().reshape(s, m, k)
        # bmm([s, 1, k], [s, k, m]) -> [s, 1, m] -> [s, m]
        return torch.bmm(a, b.transpose(1, 2)).squeeze(1)
    else:
        return torch.einsum(rule, a, b)


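# Illustrative sketch (assumed toy shapes, not part of the original file): the two
# rules the MoE layer relies on most, with s=4 tokens, e=2 experts, c=3 capacity
# slots and m=8 model dims.
#
#     gates = torch.rand(4, 2)                                # [s, e]
#     locations_sc = torch.rand(4, 3)                          # [s, c]
#     combine = einsum("se,sc->sec", gates, locations_sc)      # [s, e, c] = [4, 2, 3]
#
#     tokens = torch.rand(4, 8)                                # [s, m]
#     dispatched = einsum("sec,sm->ecm", combine, tokens)      # [e, c, m] = [2, 3, 8]

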
# The following functions are extracted and scripted
# because otherwise during a torch.jit.trace, the non-Tensor
# values used in the calculations get recorded as constants.
# torch.jit.script coerces them into Tensors and preserves
# their dynamic shapes. This enables ONNX export.
# We can't script the entire top1gating function because it
# includes stateful caching logic which is incompatible with ONNX.


@torch.jit.script
def _capacity(gates: Tensor, capacity_factor: Tensor, min_capacity: Tensor) -> Tensor:
    # gates has shape of SE
    num_tokens = gates.shape[0]
    num_experts = gates.shape[1]
    # to(torch.int64) works around a bug in torch.onnx.export:
    # it should cast k to int64 when converting torch.topk but it doesn't.
    capacity = torch.ceil((num_tokens / num_experts) * capacity_factor).to(torch.int64)
    if capacity < min_capacity:
        capacity = min_capacity.to(torch.int64)
    return capacity


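# Worked example (assumed numbers, for illustration only): with 64 tokens,
# 8 experts and capacity_factor = 1.25, capacity = ceil(64 / 8 * 1.25) = 10
# slots per expert; with 8 tokens and 8 experts it would be ceil(1 * 1.25) = 2,
# which is then raised to min_capacity (default 8 in TopKGate below).
#
#     cap = _capacity(torch.rand(64, 8), torch.tensor(1.25), torch.tensor(8))
#     # cap == tensor(10)

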
@torch.jit.script
def _top_idx(source, k):
    return torch.topk(source, k=k, dim=0)[1]


@torch.jit.script
def _one_hot_to_float(x, num_classes):
    return F.one_hot(x, num_classes=num_classes).float()


def top1gating(
    logits: Tensor,
    capacity_factor: float,
    min_capacity: int,
    used_token: Tensor = None,
    noisy_gate_policy: Optional[str] = None,
    drop_tokens: bool = True,
    use_rts: bool = True,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Implements Top1Gating on logits."""
    if noisy_gate_policy == "RSample":
        logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
    # everything is in fp32 in this function
    gates = F.softmax(logits, dim=1)

    capacity = _capacity(gates, torch.tensor(capacity_factor), torch.tensor(min_capacity))

    # Create a mask for each token's top-1 expert
    # (use the noisy logits if noisy gating is enabled)
    indices1_s = torch.argmax(logits_w_noise if noisy_gate_policy == "RSample" else gates, dim=1)
    num_experts = int(gates.shape[1])
    mask1 = F.one_hot(indices1_s, num_classes=num_experts)

    # mask only used tokens
    if used_token is not None:
        mask1 = einsum("s,se->se", used_token, mask1)

    # gating decisions
    exp_counts = torch.sum(mask1, dim=0).detach().to("cpu")

    # if we don't want to drop any tokens
    if not drop_tokens:
        new_capacity = torch.max(exp_counts).to(logits.device)
        # reduce over the default (world) process group
        dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.group.WORLD)
        capacity = new_capacity

    # Compute l_aux
    me = torch.mean(gates, dim=0)
    ce = torch.mean(mask1.float(), dim=0)
    l_aux = torch.sum(me * ce) * num_experts

    # Random Token Selection
    if use_rts:
        uniform = exp_selection_uniform_map.get(logits.device)
        if uniform is None:
            uniform = torch.distributions.uniform.Uniform(
                low=torch.tensor(0.0, device=logits.device), high=torch.tensor(1.0, device=logits.device)
            ).rsample
            exp_selection_uniform_map[logits.device] = uniform

        mask1_rand = mask1 * uniform(mask1.shape)
    else:
        mask1_rand = mask1

    assert logits.shape[0] >= min_capacity, (
        "No. of tokens (batch-size) should be greater than min_capacity. "
        "Either set min_capacity to 0 or increase your batch size."
    )

    top_idx = _top_idx(mask1_rand, capacity)  # token index

    new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)
    mask1 = new_mask1

    # Compute locations in capacity buffer
    locations1 = torch.cumsum(mask1, dim=0) - 1

    # Store the capacity location for each token
    locations1_s = torch.sum(locations1 * mask1, dim=1)

    # Normalize gate probabilities
    mask1_float = mask1.float()
    gates = gates * mask1_float

    locations1_sc = _one_hot_to_float(locations1_s, capacity)
    combine_weights = einsum("se,sc->sec", gates, locations1_sc)

    dispatch_mask = combine_weights.bool()

    return l_aux, combine_weights, dispatch_mask, exp_counts


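# Output shapes of top1gating (illustrative summary, with s tokens, e experts, c = capacity):
#     l_aux            scalar load-balancing loss, added to the task loss by the caller
#     combine_weights  [s, e, c] float weights used to re-combine expert outputs
#     dispatch_mask    [s, e, c] bool mask selecting which capacity slot each token fills
#     exp_counts       [e] per-expert token counts (on CPU), useful for logging

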
def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Implements Top2Gating on logits."""
    # everything is in fp32 in this function
    gates = F.softmax(logits, dim=1)

    capacity = _capacity(gates, torch.tensor(capacity_factor * 2), torch.tensor(min_capacity))

    # Create a mask for each token's top-1 expert
    indices1_s = torch.argmax(gates, dim=1)
    num_experts = int(gates.shape[1])
    mask1 = F.one_hot(indices1_s, num_classes=num_experts)

    # Create a mask for each token's top-2 expert using the Gumbel-max trick
    # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
    logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
    # Replace the top-1 expert's logit with the minimum value so it cannot be picked again
    logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf"))
    indices2_s = torch.argmax(logits_except1, dim=1)
    mask2 = F.one_hot(indices2_s, num_classes=num_experts)

    # Compute locations in capacity buffer
    locations1 = torch.cumsum(mask1, dim=0) - 1
    locations2 = torch.cumsum(mask2, dim=0) - 1
    # Update the top-2 locations by accounting for the locations of the top-1 choices
    locations2 += torch.sum(mask1, dim=0, keepdim=True)

    # gating decisions
    exp_counts = torch.sum(mask1, dim=0).detach().to("cpu")

    # Compute l_aux
    me = torch.mean(gates, dim=0)
    ce = torch.mean(mask1.float(), dim=0)
    l_aux = torch.mean(me * ce) * num_experts * num_experts

    # Remove locations outside capacity from mask
    mask1 *= torch.lt(locations1, capacity)
    mask2 *= torch.lt(locations2, capacity)

    # Store the capacity location for each token
    locations1_s = torch.sum(locations1 * mask1, dim=1)
    locations2_s = torch.sum(locations2 * mask2, dim=1)

    # Normalize gate probabilities
    mask1_float = mask1.float()
    mask2_float = mask2.float()
    gates1_s = einsum("se,se->s", gates, mask1_float)
    gates2_s = einsum("se,se->s", gates, mask2_float)
    denom_s = gates1_s + gates2_s
    # Avoid divide-by-zero
    denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
    gates1_s /= denom_s
    gates2_s /= denom_s

    # Calculate combine_weights and dispatch_mask
    gates1 = einsum("s,se->se", gates1_s, mask1_float)
    gates2 = einsum("s,se->se", gates2_s, mask2_float)
    locations1_sc = _one_hot_to_float(locations1_s, capacity)
    locations2_sc = _one_hot_to_float(locations2_s, capacity)
    combine1_sec = einsum("se,sc->sec", gates1, locations1_sc)
    combine2_sec = einsum("se,sc->sec", gates2, locations2_sc)
    combine_weights = combine1_sec + combine2_sec
    dispatch_mask = combine_weights.bool()

    return l_aux, combine_weights, dispatch_mask, exp_counts


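# Worked example of the top-2 weight normalization above (assumed numbers):
# if a token's softmax puts 0.5 on its top-1 expert and 0.2 on its top-2 expert,
# then denom_s = 0.7 and the combine weights become 0.5 / 0.7 ≈ 0.714 and
# 0.2 / 0.7 ≈ 0.286, so the two expert outputs are mixed with weights that sum to 1.

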
class TopKGate(Module):
    """Gate module which implements Top1Gating and Top2Gating as described in Gshard_.
    ::

        gate = TopKGate(model_dim, num_experts)
        l_aux, combine_weights, dispatch_mask, exp_counts = gate(input)

    .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf

    Args:
        model_dim (int):
            size of model embedding dimension
        num_experts (int):
            number of experts in model
    """

    wg: torch.nn.Linear

    def __init__(
        self,
        model_dim: int,
        num_experts: int,
        k: int = 1,
        capacity_factor: float = 1.0,
        eval_capacity_factor: float = 1.0,
        min_capacity: int = 8,
        noisy_gate_policy: Optional[str] = None,
        drop_tokens: bool = True,
        use_rts: bool = True,
    ) -> None:
        super().__init__()

        # Only top-1 and top-2 are supported at the moment.
        if k not in (1, 2):
            raise ValueError("Only top-1 and top-2 gatings are supported.")
        self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
        self.k = k
        self.capacity_factor = capacity_factor
        self.eval_capacity_factor = eval_capacity_factor
        self.min_capacity = min_capacity
        self.noisy_gate_policy = noisy_gate_policy
        self.wall_clock_breakdown = False
        self.gate_time = 0.0
        self.drop_tokens = drop_tokens
        self.use_rts = use_rts

    def forward(
        self, inputs: torch.Tensor, used_token: torch.Tensor = None
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:  # type: ignore

        if self.wall_clock_breakdown:
            timer("TopKGate").start()

        if self.wg.weight.dtype != torch.float32:
            self.wg = self.wg.float()
        inputs_fp32 = inputs.float()
        # input jittering
        if self.noisy_gate_policy == "Jitter" and self.training:
            inputs_fp32 = multiplicative_jitter(inputs_fp32, device=inputs.device)
        logits = self.wg(inputs_fp32)

        if self.k == 1:
            gate_output = top1gating(
                logits,
                self.capacity_factor if self.training else self.eval_capacity_factor,
                self.min_capacity,
                used_token,
                self.noisy_gate_policy if self.training else None,
                self.drop_tokens,
                self.use_rts,
            )

        else:
            gate_output = top2gating(
                logits, self.capacity_factor if self.training else self.eval_capacity_factor, self.min_capacity
            )

        if self.wall_clock_breakdown:
            timer("TopKGate").stop()
            self.gate_time = timer("TopKGate").elapsed(reset=False)

        return gate_output


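# Minimal usage sketch (illustrative only; the dimensions below are assumptions):
#
#     gate = TopKGate(model_dim=1024, num_experts=8, k=2, capacity_factor=1.25)
#     tokens = torch.randn(512, 1024)                               # [s, m] flattened tokens
#     l_aux, combine_weights, dispatch_mask, exp_counts = gate(tokens)
#     # The gate projection runs in fp32; logits are jittered/noised only in
#     # training mode, depending on noisy_gate_policy.

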
class MOELayer(Base):
    """MOELayer module which implements MixtureOfExperts as described in Gshard_.
    ::

        gate = TopKGate(model_dim, num_experts)
        moe = MOELayer(gate, expert)
        output = moe(inputs)
        l_aux = moe.l_aux

    .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf

    Args:
        gate (torch.nn.Module):
            gate network
        experts (torch.nn.Module):
            expert network
    """

    def __init__(self, gate: Module, experts: Module, ep_group, ep_size, num_local_experts: int) -> None:
        super().__init__()
        self.gate = gate
        self.experts = experts
        self.ep_group = ep_group
        self.ep_size = ep_size
        self.num_local_experts = num_local_experts
        self.time_falltoall = 0.0
        self.time_salltoall = 0.0
        self.time_moe = 0.0
        self.wall_clock_breakdown = False

    def _set_ep_group(self, ep_group):
        self.ep_group = ep_group

    def forward(self, *inputs: Tensor) -> Tensor:

        if self.wall_clock_breakdown:
            timer("moe").start()

        # Implement Algorithm 2 from the GShard paper.
        d_model = inputs[0].shape[-1]

        # Initial implementation -> Reshape into S tokens by dropping the sequence dimension.
        # Reshape into G groups so that each group can distribute tokens equally
        # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
        reshaped_inputs = inputs[0].reshape(-1, d_model)

        self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_inputs, inputs[1])
        dispatched_inputs = einsum(
            "sec,sm->ecm", dispatch_mask.type_as(inputs[0]), reshaped_inputs
        )  # TODO: heavy memory usage due to long sequence length

        if self.wall_clock_breakdown:
            timer("falltoall").start()

        dispatched_inputs = _AllToAll.apply(self.ep_group, dispatched_inputs)

        if self.wall_clock_breakdown:
            timer("falltoall").stop()
            self.time_falltoall = timer("falltoall").elapsed(reset=False)

        # Re-shape after all-to-all: ecm -> gecm
        dispatched_inputs = dispatched_inputs.reshape(self.ep_size, self.num_local_experts, -1, d_model)

        expert_output = self.experts(dispatched_inputs)

        if self.wall_clock_breakdown:
            timer("salltoall").start()

        expert_output = _AllToAll.apply(self.ep_group, expert_output)

        if self.wall_clock_breakdown:
            timer("salltoall").stop()
            self.time_salltoall = timer("salltoall").elapsed(reset=False)

        # Re-shape back: gecm -> ecm
        expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)

        combined_output = einsum("sec,ecm->sm", combine_weights.type_as(inputs[0]), expert_output)

        output = combined_output.reshape(inputs[0].shape)

        if self.wall_clock_breakdown:
            timer("moe").stop()
            self.time_moe = timer("moe").elapsed(reset=False)

        return output
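

# Minimal end-to-end sketch (illustrative only; `experts`, `ep_group` and the
# dimensions below are assumptions, e.g. an Experts-style wrapper mapping
# [e, c, m] -> [e, c, m] and an expert-parallel process group of size 8):
#
#     gate = TopKGate(model_dim=1024, num_experts=8, k=1)
#     moe = MOELayer(gate, experts, ep_group=ep_group, ep_size=8, num_local_experts=1)
#     out = moe(hidden_states, used_token)   # forward reads inputs[1], so pass used_token explicitly
#     loss = task_loss + 0.01 * moe.l_aux    # auxiliary load-balancing loss; coefficient is an assumption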