InternLM/internlm/moe/sharded_moe.py

"""
The file has been adapted from the following files:
https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py
Git commit hash: f3943cf9109226ed3ecf2d5dbb639a11cd925555
We retain the following license from the original files:
"""
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module

from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer

# global llm logger
logger = get_logger(__file__)

if TYPE_CHECKING:
    Base = Module[Tensor]
else:
    Base = Module
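
# Per-device caches for the jitter / Gumbel / RTS samplers, so each distribution
# object is constructed only once per device.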
uniform_map: Dict[torch.device, Callable] = {}
gumbel_map: Dict[torch.device, Callable] = {}
exp_selection_uniform_map: Dict[torch.device, Callable] = {}

def multiplicative_jitter(x, device: torch.device, epsilon=1e-2):
    """
    Modified from the Switch Transformer paper (mesh transformers implementation).
    Multiply values by a random number between 1-epsilon and 1+epsilon.
    Makes models more resilient to rounding errors introduced by bfloat16.
    This seems particularly important for logits.

    Args:
        x: a torch.Tensor
        device: torch.device
        epsilon: a floating point value

    Returns:
        a jittered x.
    """
    if epsilon == 0:
        return x

    uniform = uniform_map.get(device)

    if uniform is None:
        uniform = torch.distributions.uniform.Uniform(
            low=torch.tensor(1.0 - epsilon, device=device), high=torch.tensor(1.0 + epsilon, device=device)
        ).rsample  # type: ignore
        uniform_map[device] = uniform

    return x * uniform(x.shape)

def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
    gumbel = gumbel_map.get(device)
    if gumbel is None:
        one = torch.tensor(1.0, device=device)
        zero = torch.tensor(0.0, device=device)
        gumbel = torch.distributions.gumbel.Gumbel(zero, one).rsample  # type: ignore
        gumbel_map[device] = gumbel
    return gumbel(shape)

# Based on https://github.com/pytorch/pytorch/pull/40762
class _AllToAll(torch.autograd.Function):
    """
    All to all communication
    """

    @staticmethod
    def forward(
        ctx: Any,
        # TODO: replace with DS process group
        group: torch.distributed.ProcessGroup,
        inputs: Tensor,
    ) -> Tensor:  # type: ignore
        ctx.group = group
        inputs = inputs.contiguous()
        output = torch.empty_like(inputs)
        dist.all_to_all_single(output, inputs, group=group)
        return output

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
        return (None, _AllToAll.apply(ctx.group, *grad_output))
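
# In MOELayer below, the forward all-to-all exchanges the (e, c, m) dispatch buffer so that each
# expert-parallel rank receives the tokens routed to its local experts; the backward pass applies
# the inverse all-to-all to the gradients.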

# einsum rewrites are on par or more performant
# switch can be bubbled up in future
USE_EINSUM = True


# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
# See https://arxiv.org/pdf/2006.16668.pdf for details.
def einsum(rule, a, b):
    if USE_EINSUM:
        return torch.einsum(rule, a, b)
    elif rule == "s,se->se":
        # [s, 1] * [s, e]
        return a.reshape(a.shape[0], -1) * b
    elif rule == "ks,kse->kse":
        # [k, s, 1] * [s, e]
        return a.reshape(a.shape[0], a.shape[1], -1) * b
    elif rule == "se,sc->sec":
        # [s,e,1] * [s,1,c]
        return a.unsqueeze(2) * b.unsqueeze(1)
    elif rule == "kse,ksc->ksec":
        # [k,s,e,1] * [k,s,1,c]
        return a.unsqueeze(3) * b.unsqueeze(2)
    elif rule == "se,se->s":
        # [s,1,e] * [s,e,1]
        return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1)
    elif rule == "se,kse->ks":
        # [s,1,e] * [k,s,e,1]
        return torch.matmul(a.unsqueeze(1), b.unsqueeze(3)).reshape(b.shape[0], -1)
    elif rule == "sec,sm->ecm":
        # [e*c, s] @ [s, m]
        s = a.shape[0]
        e = a.shape[1]
        c = a.shape[2]
        m = b.shape[1]
        return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m)
    elif rule == "sec,ecm->sm":
        # [s, e*c] @ [e*c, m]
        return torch.matmul(a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1]))
    elif rule == "ks,ksm->sm":
        k = b.shape[0]
        s = b.shape[1]
        m = b.shape[2]
        # [k, s] -> [s, k] -> [s, 1, k]
        a = a.t().unsqueeze(1)
        # [k,s,m] -> [k, sm] -> [sm, k] -> [s, m, k]
        b = b.reshape(k, -1).t().reshape(s, m, k)
        # bmm([s, 1, k], [s, m, k]^t) -> [s, m, 1]
        return torch.bmm(a, b.transpose(1, 2)).squeeze(2)
    else:
        return torch.einsum(rule, a, b)
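
# For example, with s=4096 tokens, e=8 experts, capacity c=512 and model dim m=2048,
# "sec,sm->ecm" combines a [4096, 8, 512] dispatch mask with [4096, 2048] token features into an
# [8, 512, 2048] per-expert buffer, and "sec,ecm->sm" is the inverse combine step.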

# The following functions are extracted and scripted
# because otherwise during a torch.jit.trace, the non-Tensor
# values used in the calculations get recorded as constants.
# torch.jit.script coerces them into Tensors and preserves
# their dynamic shapes. This enables ONNX export.
# We can't script the entire top1gating function because it
# includes stateful caching logic which is incompatible with ONNX.


@torch.jit.script
def _capacity(gates: Tensor, capacity_factor: Tensor, min_capacity: Tensor) -> Tensor:
    # gates has shape of SE
    num_tokens = gates.shape[0]
    num_experts = gates.shape[1]
    # to(torch.int64) works around a bug in torch.onnx.export:
    # it should cast k to int64 when converting torch.topk but it doesn't.
    capacity = torch.ceil((num_tokens / num_experts) * capacity_factor).to(torch.int64)
    if capacity < min_capacity:
        capacity = min_capacity.to(torch.int64)
    return capacity
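
# Example: with 4096 tokens routed over 8 experts and capacity_factor=1.0,
# capacity = ceil(4096 / 8 * 1.0) = 512 slots per expert (raised to min_capacity if smaller).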

@torch.jit.script
def _top_idx(source, k):
    return torch.topk(source, k=k, dim=0)[1]

def top1gating(
    logits: Tensor,
    capacity_factor: float,
    min_capacity: int,
    used_token: Tensor = None,
    noisy_gate_policy: Optional[str] = None,
    drop_tokens: bool = True,
    use_rts: bool = True,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Implements Top1Gating on logits."""
    if noisy_gate_policy == "RSample":
        logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)

    # everything is in fp32 in this function
    gates = F.softmax(logits, dim=1)
    capacity = _capacity(gates, torch.tensor(capacity_factor), torch.tensor(min_capacity))

    # Create a mask for 1st's expert per token
    # noisy gating
    indices1_s = torch.argmax(logits_w_noise if noisy_gate_policy == "RSample" else gates, dim=1)
    num_experts = int(gates.shape[1])
    mask1 = F.one_hot(indices1_s, num_classes=num_experts)

    # mask only used tokens
    if used_token is not None:
        mask1 = einsum("s,se->se", used_token, mask1)

    # gating decisions
    exp_counts = torch.sum(mask1, dim=0).detach().to("cpu")

    # if we don't want to drop any tokens
    if not drop_tokens:
        new_capacity = torch.max(exp_counts).to(logits.device)
        # synchronize the no-drop capacity across all ranks via the global (world) process group
        dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.group.WORLD)
        capacity = new_capacity

    # Compute l_aux
    me = torch.mean(gates, dim=0)
    ce = torch.mean(mask1.type_as(logits), dim=0)
    l_aux = torch.sum(me * ce) * num_experts

    # Random Token Selection
    if use_rts:
        uniform = exp_selection_uniform_map.get(logits.device)
        if uniform is None:
            uniform = torch.distributions.uniform.Uniform(
                low=torch.tensor(0.0, device=logits.device), high=torch.tensor(1.0, device=logits.device)
            ).rsample
            exp_selection_uniform_map[logits.device] = uniform
        mask1_rand = mask1 * uniform(mask1.shape)
    else:
        mask1_rand = mask1

    assert logits.shape[0] >= min_capacity, (
        "No. of tokens (batch-size) should be greater than min_capacity. "
        "Either set min_capacity to 0 or increase your batch size."
    )

    top_idx = _top_idx(mask1_rand, capacity)  # token index

    new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)
    mask1 = new_mask1

    # Compute locations in capacity buffer
    locations1 = torch.cumsum(mask1, dim=0) - 1

    # Store the capacity location for each token
    locations1_s = torch.sum(locations1 * mask1, dim=1)

    # Normalize gate probabilities
    mask1_float = mask1.type_as(logits)
    gates = gates * mask1_float

    locations1_sc = F.one_hot(locations1_s, num_classes=capacity).type_as(logits)
    combine_weights = einsum("se,sc->sec", gates, locations1_sc)
    dispatch_mask = combine_weights.bool()

    return l_aux, combine_weights, dispatch_mask, exp_counts
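
# top1gating returns: l_aux (scalar load-balancing loss), combine_weights and dispatch_mask of
# shape [s, e, c], and exp_counts (per-expert token counts, shape [e], moved to CPU).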

def top2gating(
    logits: Tensor,
    capacity_factor: float,
    min_capacity: int,
    noisy_gate_policy: Optional[str] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Implements Top2Gating on logits."""
    # everything is in fp32 in this function
    gates = F.softmax(logits, dim=1)
    num_experts = int(gates.shape[1])

    capacity = _capacity(gates, torch.tensor(capacity_factor * 2), torch.tensor(min_capacity))

    # NOTE: here we just add noise on 2nd expert, following
    # https://github.com/facebookresearch/fairscale/blob/main/fairscale/nn/moe/top2gate.py
    if noisy_gate_policy == "RSample":
        # Create a mask for 1st's expert per token
        indices1_s = torch.argmax(gates, dim=1)
        mask1 = F.one_hot(indices1_s, num_classes=num_experts)

        # Create a mask for 2nd's expert per token using Gumbel-max trick
        # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
        logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
        # Replace top-expert with min value
        logits_except1 = logits_w_noise.masked_fill(mask1.bool(), torch.finfo(logits.dtype).min)
        indices2_s = torch.argmax(logits_except1, dim=1)
        mask2 = F.one_hot(indices2_s, num_classes=num_experts)

        # merge operands in topk gating to save launch overhead
        masks = torch.cat((mask1, mask2), dim=0)
    else:
        # Create a mask by top-2 experts
        indices_s = torch.topk(gates, 2, dim=1).indices
        indices_s = indices_s.permute(1, 0).reshape(-1)
        masks = F.one_hot(indices_s, num_classes=num_experts)

    # Compute locations in capacity buffer
    locations = torch.cumsum(masks, dim=0) - 1
    # reshape (s,e) to (k,s,e)
    masks = masks.reshape(-1, gates.shape[0], num_experts)
    locations = locations.reshape(-1, gates.shape[0], num_experts)

    # gating decisions
    exp_counts = torch.sum(masks[0], dim=0).detach().to("cpu")

    # Compute l_aux
    me = torch.mean(gates, dim=0)
    ce = torch.mean(masks[0].type_as(logits), dim=0)
    l_aux = torch.mean(me * ce) * num_experts * num_experts

    # Remove locations outside capacity from mask
    masks *= torch.lt(locations, capacity)

    # Store the capacity location for each token
    locations_s = torch.sum(locations * masks, dim=2)

    # Normalize gate probabilities
    mask_float = masks.type_as(logits)
    gate_s = einsum("se,kse->ks", gates, mask_float)
    denom_s = torch.sum(gate_s, dim=0)
    # Avoid divide-by-zero
    denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
    gate_s /= denom_s

    # Calculate combine_weights and dispatch_mask
    gate_all = einsum("ks,kse->kse", gate_s, mask_float)
    locations_sc = F.one_hot(locations_s, num_classes=capacity).type_as(logits)
    combine_sec = einsum("kse,ksc->ksec", gate_all, locations_sc)
    combine_weights = torch.sum(combine_sec, dim=0)
    dispatch_mask = combine_weights.bool()

    return l_aux, combine_weights, dispatch_mask, exp_counts
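
# Example of the top-2 renormalization above: a token with softmax gates (0.5, 0.3, 0.2) over
# three experts selects experts 0 and 1; assuming neither assignment is dropped for capacity,
# its combine weights become 0.5 / 0.8 = 0.625 and 0.3 / 0.8 = 0.375, so the two expert
# outputs are mixed with weights that sum to 1.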

class TopKGate(Module):
    """Gate module which implements Top2Gating as described in Gshard_.
    ::

        gate = TopKGate(model_dim, num_experts)
        l_aux, combine_weights, dispatch_mask, exp_counts = gate(input)

    .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf

    Args:
        model_dim (int):
            size of model embedding dimension
        num_experts (int):
            number of experts in model
    """

    wg: torch.nn.Linear

    def __init__(
        self,
        model_dim: int,
        num_experts: int,
        k: int = 1,
        capacity_factor: float = 1.0,
        eval_capacity_factor: float = 1.0,
        min_capacity: int = 8,
        noisy_gate_policy: Optional[str] = None,
        drop_tokens: bool = True,
        use_rts: bool = True,
    ) -> None:
        super().__init__()

        # Only top-1 and top-2 are supported at the moment.
        if k not in (1, 2):
            raise ValueError("Only top-1 and top-2 gatings are supported.")
        # Following DeepSpeed's mechanism, the gate always runs in fp32.
        self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
        self.k = k
        self.capacity_factor = capacity_factor
        self.eval_capacity_factor = eval_capacity_factor
        self.min_capacity = min_capacity
        self.noisy_gate_policy = noisy_gate_policy
        self.drop_tokens = drop_tokens
        self.use_rts = use_rts

    def forward(
        self, inputs: torch.Tensor, used_token: torch.Tensor = None
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:  # type: ignore
        # input jittering
        if self.noisy_gate_policy == "Jitter" and self.training:
            inputs = multiplicative_jitter(inputs, device=inputs.device)
        logits = self.wg(inputs)

        if self.k == 1:
            gate_output = top1gating(
                logits,
                self.capacity_factor if self.training else self.eval_capacity_factor,
                self.min_capacity,
                used_token,
                self.noisy_gate_policy if self.training else None,
                self.drop_tokens,
                self.use_rts,
            )
        else:
            gate_output = top2gating(
                logits,
                self.capacity_factor if self.training else self.eval_capacity_factor,
                self.min_capacity,
                self.noisy_gate_policy,
            )

        return gate_output
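
# Illustrative standalone use of the gate (hypothetical sizes; the gate itself needs no
# distributed setup when drop_tokens is left enabled):
#
#   gate = TopKGate(model_dim=512, num_experts=8, k=2, min_capacity=4)
#   tokens = torch.randn(16, 512)                       # 16 token embeddings
#   l_aux, combine_weights, dispatch_mask, exp_counts = gate(tokens)
#   # combine_weights / dispatch_mask: [16, 8, capacity]; exp_counts: tokens per expert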

class MOELayer(Base):
    """MOELayer module which implements MixtureOfExperts as described in Gshard_.
    ::

        gate = TopKGate(model_dim, num_experts)
        moe = MOELayer(gate, expert)
        output = moe(inputs, used_token)
        l_aux = moe.l_aux

    .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf

    Args:
        gate (torch.nn.Module):
            gate network
        experts (torch.nn.Module):
            expert network
        ep_group:
            expert-parallel process group used for the all-to-all exchanges
        ep_size (int):
            number of ranks in the expert-parallel group
        num_local_experts (int):
            number of experts hosted on the local rank
    """

    def __init__(self, gate: Module, experts: Module, ep_group, ep_size, num_local_experts: int) -> None:
        super().__init__()
        self.gate = gate
        self.experts = experts
        self.ep_group = ep_group
        self.ep_size = ep_size
        self.num_local_experts = num_local_experts
        self.time_falltoall = 0.0
        self.time_salltoall = 0.0
        self.time_moe = 0.0
        self.wall_clock_breakdown = False

    def forward(self, *inputs: Tensor) -> Tensor:
        if self.wall_clock_breakdown:
            timer("moe").start()

        # Implement Algorithm 2 from GShard paper.
        d_model = inputs[0].shape[-1]

        # Initial implementation -> Reshape into S tokens by dropping sequence dimension.
        # Reshape into G groups so that each group can distribute tokens equally
        # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
        reshaped_inputs = inputs[0].reshape(-1, d_model)

        self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_inputs, inputs[1])
        dispatched_inputs = einsum(
            "sec,sm->ecm", dispatch_mask.type_as(inputs[0]), reshaped_inputs
        )  # TODO: heavy memory usage due to long sequence length

        if self.wall_clock_breakdown:
            timer("falltoall").start()

        dispatched_inputs = _AllToAll.apply(self.ep_group, dispatched_inputs)

        if self.wall_clock_breakdown:
            timer("falltoall").stop()
            self.time_falltoall = timer("falltoall").elapsed(reset=False)

        # Re-shape after all-to-all: ecm -> gecm
        dispatched_inputs = dispatched_inputs.reshape(self.ep_size, self.num_local_experts, -1, d_model)

        expert_output = self.experts(dispatched_inputs)

        if self.wall_clock_breakdown:
            timer("salltoall").start()

        expert_output = _AllToAll.apply(self.ep_group, expert_output)

        if self.wall_clock_breakdown:
            timer("salltoall").stop()
            self.time_salltoall = timer("salltoall").elapsed(reset=False)

        # Re-shape back: gecm -> ecm
        expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)

        combined_output = einsum("sec,ecm->sm", combine_weights.type_as(inputs[0]), expert_output)

        out = combined_output.reshape(inputs[0].shape)

        if self.wall_clock_breakdown:
            timer("moe").stop()
            self.time_moe = timer("moe").elapsed(reset=False)

        return out
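
# Illustrative wiring sketch (assumptions: torch.distributed is already initialized, and the
# experts module maps an [ep_size, num_local_experts, capacity, d_model] tensor to the same shape):
#
#   gate = TopKGate(model_dim, num_experts, k=1)
#   moe = MOELayer(gate, experts, ep_group=dist.group.WORLD, ep_size=world_size,
#                  num_local_experts=num_experts // world_size)
#   output = moe(hidden_states, used_token)   # forward reads inputs[0] and inputs[1]
#   loss = task_loss + moe.l_aux              # auxiliary load-balancing loss from the gate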