mirror of https://github.com/InternLM/InternLM
496 lines
17 KiB
Python
496 lines
17 KiB
Python
"""
|
|
The file has been adapted from the following files:
|
|
https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py
|
|
Git commit hash: f3943cf9109226ed3ecf2d5dbb639a11cd925555
|
|
We retain the following license from the original files:
|
|
"""
|
|
|
|
# Copyright (c) Microsoft Corporation.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
# DeepSpeed Team
|
|
|
|
|
|
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch.nn.functional as F
|
|
from torch import Tensor
|
|
from torch.nn import Module
|
|
|
|
from internlm.utils.logger import get_logger
|
|
from internlm.utils.megatron_timers import megatron_timer as timer
|
|
|
|
# global llm logger
logger = get_logger(__file__)

# Static type checkers see the generic ``Module[Tensor]``; at runtime the plain
# ``torch.nn.Module`` class is used as the base of the MoE layer.
if not TYPE_CHECKING:
    Base = Module
else:
    Base = Module[Tensor]

# Per-device caches of ``rsample`` callables. They are filled lazily by the sampling
# code in this file, so each distribution object is only built once per device.
uniform_map: Dict[torch.device, Callable] = {}  # used by multiplicative_jitter
gumbel_map: Dict[torch.device, Callable] = {}  # used by gumbel_rsample
exp_selection_uniform_map: Dict[torch.device, Callable] = {}  # used by top1gating (random token selection)
|
|
|
|
|
|
def multiplicative_jitter(x, device: torch.device, epsilon=1e-2):
    """
    Modified from switch transformer paper. mesh transformers
    Multiply values by a random number between 1-epsilon and 1+epsilon.
    Makes models more resilient to rounding errors introduced by bfloat16.
    This seems particularly important for logits.
    Args:
        x: a torch.tensor
        device: torch.device on which the noise is sampled
        epsilon: a floating point value; ``0`` returns ``x`` unchanged
    Returns:
        a jittered x.
    """
    if epsilon == 0:
        return x
    # Cache one U(0, 1) sampler per device and rescale it to [1 - epsilon, 1 + epsilon]
    # on every call. The previous version cached a sampler bound to the first epsilon
    # seen on a device, so later calls with a different epsilon silently reused the
    # wrong range.
    uniform = uniform_map.get(device)
    if uniform is None:
        uniform = torch.distributions.uniform.Uniform(
            low=torch.tensor(0.0, device=device), high=torch.tensor(1.0, device=device)
        ).rsample  # type: ignore
        uniform_map[device] = uniform
    noise = uniform(x.shape) * (2.0 * epsilon) + (1.0 - epsilon)
    return x * noise
|
|
|
|
|
|
def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
    """Draw standard Gumbel(0, 1) samples of the given ``shape`` on ``device``.

    The reparameterized sampler is built once per device and memoized in ``gumbel_map``.
    """
    sampler = gumbel_map.get(device)
    if sampler is None:
        standard_gumbel = torch.distributions.gumbel.Gumbel(
            torch.tensor(0.0, device=device), torch.tensor(1.0, device=device)
        )
        sampler = gumbel_map[device] = standard_gumbel.rsample  # type: ignore
    return sampler(shape)
|
|
|
|
|
|
# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
|
|
# See https://arxiv.org/pdf/2006.16668.pdf for details.
|
|
|
|
|
|
# Based on https://github.com/pytorch/pytorch/pull/40762
|
|
class _AllToAll(torch.autograd.Function):
    """
    All to all communication

    Autograd wrapper around ``dist.all_to_all_single``. When no split sizes are given,
    the collective splits dim 0 of the input evenly across the ranks of ``group``. The
    gradient of an all-to-all is another all-to-all of the incoming gradient over the
    same group.
    """

    @staticmethod
    def forward(
        ctx: Any,
        # TODO: replace with DS process group
        group: torch.distributed.ProcessGroup,
        inputs: Tensor,
    ) -> Tensor:  # type: ignore
        ctx.group = group
        send_buf = inputs.contiguous()
        recv_buf = torch.empty_like(send_buf)
        dist.all_to_all_single(recv_buf, send_buf, group=group)
        return recv_buf

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
        # No gradient for the process-group argument.
        grad_inputs = _AllToAll.apply(ctx.group, *grad_output)
        return None, grad_inputs
|
|
|
|
|
|
# einsum rewrites are on par or more performant
# switch can be bubbled up in future
USE_EINSUM = True


def _einsum_fallback(rule, a, b):
    """Evaluate einsum ``rule`` on ``a`` and ``b`` with explicit broadcast/matmul ops.

    Only the contraction patterns used by the MoE gating code are rewritten; any other
    rule is still delegated to ``torch.einsum``.
    """
    if rule == "s,se->se":
        # [s] -> [s, 1], broadcast against [s, e]
        return a.reshape(a.shape[0], -1) * b
    if rule == "se,sc->sec":
        # [s, e, 1] * [s, 1, c]
        return a.unsqueeze(2) * b.unsqueeze(1)
    if rule == "se,se->s":
        # bmm([s, 1, e], [s, e, 1]) -> [s, 1, 1] -> [s]
        return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1)
    if rule == "sec,sm->ecm":
        # [e*c, s] @ [s, m] -> [e*c, m] -> [e, c, m]
        s = a.shape[0]
        e = a.shape[1]
        c = a.shape[2]
        m = b.shape[1]
        return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m)
    if rule == "sec,ecm->sm":
        # [s, e*c] @ [e*c, m]
        return torch.matmul(a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1]))
    if rule == "ks,ksm->sm":
        k = b.shape[0]
        s = b.shape[1]
        m = b.shape[2]
        # [k, s] -> [s, k] -> [s, 1, k]
        a = a.t().unsqueeze(1)
        # [k, s, m] -> [k, s*m] -> [s*m, k] -> [s, m, k]
        b = b.reshape(k, -1).t().reshape(s, m, k)
        # bmm([s, 1, k], [s, k, m]) -> [s, 1, m]; drop the singleton dim 1 to get [s, m].
        # (Squeezing dim 2 here, as before, left a [s, 1, m] result unless m == 1.)
        return torch.bmm(a, b.transpose(1, 2)).squeeze(1)
    return torch.einsum(rule, a, b)


# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
# See https://arxiv.org/pdf/2006.16668.pdf for details.
def einsum(rule, a, b):
    """Compute ``torch.einsum(rule, a, b)``.

    When the module flag ``USE_EINSUM`` is ``False``, the gating-specific rules are
    evaluated by ``_einsum_fallback`` with hand-written broadcast/matmul kernels instead.
    """
    if USE_EINSUM:
        return torch.einsum(rule, a, b)
    return _einsum_fallback(rule, a, b)
|
|
|
|
|
|
# The following functions are extracted and scripted
|
|
# because otherwise during a torch.jit.trace, the non-Tensor
|
|
# values used in the calculations get recorded as constants.
|
|
# torch.jit.script coerces them into Tensors and preserves
|
|
# their dynamic shapes. This enables ONNX export.
|
|
# We can't script the entire top1gating function because it
|
|
# includes stateful caching logic which is incompatible with ONNX.
|
|
|
|
|
|
@torch.jit.script
def _capacity(gates: Tensor, capacity_factor: Tensor, min_capacity: Tensor) -> Tensor:
    """Per-expert capacity ``ceil(tokens / experts * capacity_factor)``, floored at ``min_capacity``.

    ``gates`` is indexed ``[tokens, experts]``. The result is an int64 tensor (0-dim for
    the scalar arguments passed elsewhere in this file).
    """
    # to(torch.int64) works around a bug in torch.onnx.export:
    # it should cast k to int64 when converting torch.topk but it doesn't.
    tokens_per_expert = gates.shape[0] / gates.shape[1]
    scaled = torch.ceil(tokens_per_expert * capacity_factor).to(torch.int64)
    if scaled < min_capacity:
        return min_capacity.to(torch.int64)
    return scaled
|
|
|
|
|
|
@torch.jit.script
def _top_idx(source, k):
    """Return, for each column of ``source``, the row indices of its ``k`` largest entries."""
    top = torch.topk(source, k=k, dim=0)
    return top[1]
|
|
|
|
|
|
@torch.jit.script
def _one_hot_to_float(x, num_classes):
    """One-hot encode integer tensor ``x`` over ``num_classes`` classes, as float32."""
    encoded = F.one_hot(x, num_classes=num_classes)
    return encoded.to(torch.float32)
|
|
|
|
|
|
def top1gating(
    logits: Tensor,
    capacity_factor: float,
    min_capacity: int,
    used_token: Optional[Tensor] = None,
    noisy_gate_policy: Optional[str] = None,
    drop_tokens: bool = True,
    use_rts: bool = True,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Implements Top1Gating on logits.

    Each token is routed to one expert: the argmax of the softmax gates, or of
    Gumbel-perturbed logits when ``noisy_gate_policy == "RSample"``. Tokens beyond an
    expert's capacity are dropped.

    Args:
        logits: gating logits indexed ``[tokens, experts]`` (softmax/argmax run over dim 1).
        capacity_factor: scales the per-expert capacity ``ceil(tokens / experts * capacity_factor)``.
        min_capacity: lower bound on the per-expert capacity.
        used_token: optional per-token tensor multiplied into the routing mask
            via ``einsum("s,se->se", ...)``.
        noisy_gate_policy: ``"RSample"`` adds Gumbel noise to the logits before expert
            selection; any other value selects on the clean gates.
        drop_tokens: when ``False``, capacity is raised to the largest per-expert token
            count, max-reduced over the default process group.
        use_rts: Random Token Selection. When ``True``, the tokens kept for each expert are
            those with the highest random score. When ``False``, ``topk`` runs on the 0/1
            mask itself, and tie-breaking is left to ``torch.topk``.

    Returns:
        ``(l_aux, combine_weights, dispatch_mask, exp_counts)``:
        - the load-balancing loss;
        - the ``[tokens, experts, capacity]`` combine weights;
        - their boolean dispatch mask;
        - per-expert token counts, computed before capacity dropping and moved to CPU.

    Raises:
        AssertionError: if there are fewer tokens than ``min_capacity``.
    """
    # everything is in fp32 in this function
    gates = F.softmax(logits, dim=1)

    capacity = _capacity(gates, torch.tensor(capacity_factor), torch.tensor(min_capacity))

    # Create a mask for 1st's expert per token (optionally on Gumbel-noised logits)
    if noisy_gate_policy == "RSample":
        selection_scores = logits + gumbel_rsample(logits.shape, device=logits.device)
    else:
        selection_scores = gates
    indices1_s = torch.argmax(selection_scores, dim=1)
    num_experts = int(gates.shape[1])
    mask1 = F.one_hot(indices1_s, num_classes=num_experts)

    # mask only used tokens
    if used_token is not None:
        mask1 = einsum("s,se->se", used_token, mask1)

    # gating decisions
    exp_counts = torch.sum(mask1, dim=0).detach().to("cpu")

    # if we don't want to drop any tokens
    if not drop_tokens:
        new_capacity = torch.max(exp_counts).to(logits.device)
        # ``torch.distributed`` has no ``get_world_group()`` (that helper belongs to
        # ``deepspeed.comm``); omitting ``group`` reduces over the default (world) group.
        dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX)
        capacity = new_capacity

    # Compute l_aux (load-balancing loss)
    me = torch.mean(gates, dim=0)
    ce = torch.mean(mask1.float(), dim=0)
    l_aux = torch.sum(me * ce) * num_experts

    # Random Token Selection: random scores decide which tokens win a full expert's slots
    if use_rts:
        uniform = exp_selection_uniform_map.get(logits.device)
        if uniform is None:
            uniform = torch.distributions.uniform.Uniform(
                low=torch.tensor(0.0, device=logits.device), high=torch.tensor(1.0, device=logits.device)
            ).rsample
            exp_selection_uniform_map[logits.device] = uniform

        mask1_rand = mask1 * uniform(mask1.shape)
    else:
        mask1_rand = mask1

    assert logits.shape[0] >= min_capacity, (
        "No. of tokens (batch-size) should be greater than min_capacity."
        "Either set min_capacity to 0 or increase your batch size."
    )

    # Keep at most `capacity` tokens per expert (topk over the token dim)
    top_idx = _top_idx(mask1_rand, capacity)  # token index

    mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)

    # Compute locations in capacity buffer
    locations1 = torch.cumsum(mask1, dim=0) - 1

    # Store the capacity location for each token
    locations1_s = torch.sum(locations1 * mask1, dim=1)

    # Zero out the gate value of every non-selected (token, expert) pair
    mask1_float = mask1.float()
    gates = gates * mask1_float

    locations1_sc = _one_hot_to_float(locations1_s, capacity)
    combine_weights = einsum("se,sc->sec", gates, locations1_sc)

    dispatch_mask = combine_weights.bool()

    return l_aux, combine_weights, dispatch_mask, exp_counts
|
|
|
|
|
|
def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Implements Top2Gating on logits.

    Each token goes to its most probable expert. It also goes to a second expert, chosen
    by argmax over Gumbel-perturbed logits with the first choice excluded. Expert capacity
    is computed with ``2 * capacity_factor``. Second-choice slots are queued after all
    first-choice slots, and anything past capacity is dropped.

    Args:
        logits: gating logits indexed ``[tokens, experts]`` (softmax over dim 1).
        capacity_factor: capacity multiplier (doubled for top-2).
        min_capacity: lower bound on the per-expert capacity.

    Returns:
        ``(l_aux, combine_weights, dispatch_mask, exp_counts)``:
        - ``combine_weights`` is ``[tokens, experts, capacity]``;
        - ``dispatch_mask`` is its boolean form;
        - ``exp_counts`` holds per-expert first-choice counts before capacity dropping,
          on CPU.
    """
    # everything is in fp32 in this function
    probs = F.softmax(logits, dim=1)
    cap = _capacity(probs, torch.tensor(capacity_factor * 2), torch.tensor(min_capacity))
    n_exp = int(probs.shape[1])

    # First choice: the most probable expert for every token.
    first = F.one_hot(torch.argmax(probs, dim=1), num_classes=n_exp)

    # Second choice via the Gumbel-max trick, with the first choice excluded.
    # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
    noisy_logits = logits + gumbel_rsample(logits.shape, device=logits.device)
    without_first = noisy_logits.masked_fill(first.bool(), float("-inf"))
    second = F.one_hot(torch.argmax(without_first, dim=1), num_classes=n_exp)

    # Buffer slot of every token per expert; second choices queue after all first choices.
    first_pos = torch.cumsum(first, dim=0) - 1
    second_pos = torch.cumsum(second, dim=0) - 1 + torch.sum(first, dim=0, keepdim=True)

    # Gating decisions and the load-balancing loss use the unclipped first-choice mask.
    exp_counts = torch.sum(first, dim=0).detach().to("cpu")
    mean_prob = torch.mean(probs, dim=0)
    mean_assign = torch.mean(first.float(), dim=0)
    l_aux = torch.mean(mean_prob * mean_assign) * n_exp * n_exp

    # Drop assignments that overflow an expert's capacity.
    first = first * torch.lt(first_pos, cap)
    second = second * torch.lt(second_pos, cap)

    first_slot = torch.sum(first_pos * first, dim=1)
    second_slot = torch.sum(second_pos * second, dim=1)

    first_f = first.float()
    second_f = second.float()
    first_gate = einsum("se,se->s", probs, first_f)
    second_gate = einsum("se,se->s", probs, second_f)
    # Renormalize the two kept gate values; the clamp avoids dividing by zero.
    denom = torch.clamp(first_gate + second_gate, min=torch.finfo(first_gate.dtype).eps)
    first_gate = first_gate / denom
    second_gate = second_gate / denom

    def _combine(gate_s, mask_f, slot_s):
        # [s] x [s, e] -> [s, e], then outer product with the one-hot slot -> [s, e, c]
        weighted = einsum("s,se->se", gate_s, mask_f)
        return einsum("se,sc->sec", weighted, _one_hot_to_float(slot_s, cap))

    combine_weights = _combine(first_gate, first_f, first_slot) + _combine(second_gate, second_f, second_slot)
    dispatch_mask = combine_weights.bool()

    return l_aux, combine_weights, dispatch_mask, exp_counts
|
|
|
|
|
|
class TopKGate(Module):
    """Gate module which implements Top1/Top2 gating as described in Gshard_.
    ::

        gate = TopKGate(model_dim, num_experts)
        l_aux, combine_weights, dispatch_mask, exp_counts = gate(input)

    .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf

    Args:
        model_dim (int):
            size of model embedding dimension
        num_experts (int):
            number of experts in model
        k (int):
            number of experts per token; only 1 and 2 are supported
        capacity_factor (float):
            capacity factor used in training mode
        eval_capacity_factor (float):
            capacity factor used in eval mode
        min_capacity (int):
            lower bound on the per-expert capacity
        noisy_gate_policy (str, optional):
            ``"Jitter"`` multiplies the fp32 gate inputs by random noise in training
            mode; the policy is also forwarded to top-1 gating in training mode
            (where ``"RSample"`` adds Gumbel noise)
        drop_tokens (bool):
            top-1 only; when False, capacity is raised to the largest per-expert
            token count (max over ranks)
        use_rts (bool):
            top-1 only; enables random token selection

    Raises:
        ValueError: if ``k`` is not 1 or 2.
    """

    wg: torch.nn.Linear

    def __init__(
        self,
        model_dim: int,
        num_experts: int,
        k: int = 1,
        capacity_factor: float = 1.0,
        eval_capacity_factor: float = 1.0,
        min_capacity: int = 8,
        noisy_gate_policy: Optional[str] = None,
        drop_tokens: bool = True,
        use_rts: bool = True,
    ) -> None:
        super().__init__()

        # Only top-1 and top-2 are supported at the moment.
        if k not in (1, 2):
            raise ValueError("Only top-1 and top-2 gatings are supported.")
        # TODO: can we use tensor parallel here?
        # Following DeepSpeed, the gate projection always runs in fp32.
        self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
        self.k = k
        self.capacity_factor = capacity_factor
        self.eval_capacity_factor = eval_capacity_factor
        self.min_capacity = min_capacity
        self.noisy_gate_policy = noisy_gate_policy
        self.wall_clock_breakdown = False
        self.gate_time = 0.0
        self.drop_tokens = drop_tokens
        self.use_rts = use_rts

    def forward(
        self, inputs: torch.Tensor, used_token: Optional[torch.Tensor] = None
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:  # type: ignore
        """Compute expert routing for ``inputs``.

        Args:
            inputs: token features whose last dim is ``model_dim``. Presumably
                ``[tokens, model_dim]`` (MOELayer passes a 2-D view). The features are
                cast to fp32 before the gate projection.
            used_token: forwarded to ``top1gating``; not used when ``k == 2``.

        Returns:
            ``(l_aux, combine_weights, dispatch_mask, exp_counts)`` from the selected
            gating function.
        """
        if self.wall_clock_breakdown:
            timer("TopKGate").start()

        # The gate may have been cast (e.g. to bf16) together with the rest of the model.
        if self.wg.weight.dtype != torch.float32:
            self.wg = self.wg.float()
        inputs_fp32 = inputs.float()
        # input jittering
        if self.noisy_gate_policy == "Jitter" and self.training:
            inputs_fp32 = multiplicative_jitter(inputs_fp32, device=inputs.device)
        logits = self.wg(inputs_fp32)

        capacity_factor = self.capacity_factor if self.training else self.eval_capacity_factor
        if self.k == 1:
            gate_output = top1gating(
                logits,
                capacity_factor,
                self.min_capacity,
                used_token,
                self.noisy_gate_policy if self.training else None,
                self.drop_tokens,
                self.use_rts,
            )
        else:
            gate_output = top2gating(logits, capacity_factor, self.min_capacity)

        if self.wall_clock_breakdown:
            timer("TopKGate").stop()
            self.gate_time = timer("TopKGate").elapsed(reset=False)

        return gate_output
|
|
|
|
|
|
class MOELayer(Base):
    """MOELayer module which implements MixtureOfExperts as described in Gshard_.
    ::

        gate = TopKGate(model_dim, num_experts)
        moe = MOELayer(gate, experts, ep_group, ep_size, num_local_experts)
        output = moe(inputs)
        l_aux = moe.l_aux

    .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf

    Args:
        gate (torch.nn.Module):
            gate network; called as ``gate(tokens, used_token)`` and expected to return
            ``(l_aux, combine_weights, dispatch_mask, exp_counts)``
        experts (torch.nn.Module):
            local expert network(s); receives the dispatched tokens reshaped to
            ``(ep_size, num_local_experts, -1, d_model)``
        ep_group:
            process group used for the two all-to-all exchanges
        ep_size (int):
            number of ranks in the expert-parallel group
        num_local_experts (int):
            number of experts hosted on this rank
    """

    def __init__(self, gate: Module, experts: Module, ep_group, ep_size, num_local_experts: int) -> None:
        super().__init__()
        self.gate = gate
        self.experts = experts
        self.ep_group = ep_group
        self.ep_size = ep_size
        self.num_local_experts = num_local_experts
        self.time_falltoall = 0.0
        self.time_salltoall = 0.0
        self.time_moe = 0.0
        self.wall_clock_breakdown = False

    def _set_ep_group(self, ep_group):
        self.ep_group = ep_group

    def forward(self, *inputs: Tensor) -> Tensor:
        """Route tokens through the experts and combine the results.

        Args:
            inputs[0]: hidden states whose last dim is ``d_model``; all leading dims are
                flattened into a single token dim and restored on output.
            inputs[1] (optional): passed to the gate as ``used_token``. When omitted,
                ``None`` is passed (previously a single-argument call raised IndexError).

        Returns:
            Tensor with the same shape as ``inputs[0]``. Sets ``self.l_aux`` and
            ``self.exp_counts`` as side effects.
        """
        if self.wall_clock_breakdown:
            timer("moe").start()

        # Implement Algorithm 2 from GShard paper.
        hidden_states = inputs[0]
        used_token = inputs[1] if len(inputs) > 1 else None
        d_model = hidden_states.shape[-1]

        # Initial implementation -> Reshape into S tokens by dropping sequence dimension.
        # Reshape into G groups so that each group can distribute tokens equally
        # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
        reshaped_inputs = hidden_states.reshape(-1, d_model)

        self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_inputs, used_token)
        dispatched_inputs = einsum(
            "sec,sm->ecm", dispatch_mask.type_as(hidden_states), reshaped_inputs
        )  # TODO: heavy memory usage due to long sequence length

        if self.wall_clock_breakdown:
            timer("falltoall").start()

        dispatched_inputs = _AllToAll.apply(self.ep_group, dispatched_inputs)

        if self.wall_clock_breakdown:
            timer("falltoall").stop()
            self.time_falltoall = timer("falltoall").elapsed(reset=False)

        # Re-shape after all-to-all: ecm -> gecm
        dispatched_inputs = dispatched_inputs.reshape(self.ep_size, self.num_local_experts, -1, d_model)

        expert_output = self.experts(dispatched_inputs)

        if self.wall_clock_breakdown:
            timer("salltoall").start()

        expert_output = _AllToAll.apply(self.ep_group, expert_output)

        if self.wall_clock_breakdown:
            timer("salltoall").stop()
            self.time_salltoall = timer("salltoall").elapsed(reset=False)

        # Re-shape back: gecm -> ecm
        expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)

        combined_output = einsum("sec,ecm->sm", combine_weights.type_as(hidden_states), expert_output)

        output = combined_output.reshape(hidden_states.shape)

        if self.wall_clock_breakdown:
            timer("moe").stop()
            self.time_moe = timer("moe").elapsed(reset=False)

        return output
|