feat(XXX): add moe

pull/375/head
Wenwen Qu 2023-08-07 20:17:49 +08:00
parent c219065348
commit c357288a8b
10 changed files with 812 additions and 32 deletions


@ -77,6 +77,7 @@ hybrid_zero_optimizer = dict(
loss = dict(
label_smoothing=0,
moe_loss_coeff=1.0,
)
adam = dict(
@ -119,6 +120,7 @@ model = dict(
use_flash_attn=True,
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
sequence_parallel=False,
num_experts=8,
)
"""
zero1 parallel:
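Taken together, the MoE-related config additions in this commit boil down to two fields; the condensed sketch below repeats only the values shown in the hunks above (all other config entries are unchanged and omitted):

loss = dict(
    label_smoothing=0,
    moe_loss_coeff=1.0,  # weight of the auxiliary gate (load-balancing) loss added to the LM loss
)
model = dict(
    # ... existing fields unchanged ...
    num_experts=8,  # <= 1 keeps the dense FFN path; > 1 enables the MoE branch
)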


@ -143,6 +143,7 @@ class ParallelContext(metaclass=SingletonMeta):
self.pipeline_parallel_size = 1
self.tensor_parallel_size = 1
self.zero1_parallel_size = -1
self.expert_parallel_size = -1
self.num_processes_on_current_node = -1
self.virtual_pipeline_parallel_size = None
self.virtual_pipeline_parallel_rank = None
@ -442,6 +443,9 @@ class ParallelContext(metaclass=SingletonMeta):
# instead, it should be calculated based on other parallel config
self.data_parallel_size = self.world_size // (self.pipeline_parallel_size * self.tensor_parallel_size)
# TODO: data parallel size can be different from expert parallel size
self.expert_parallel_size = self.data_parallel_size
if self.zero1_parallel_size <= 0:
self.zero1_parallel_size = self.data_parallel_size
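As a quick illustration of how these sizes relate (illustrative numbers, not taken from the patch): expert parallel size is currently tied to data parallel size, which is itself derived from the remaining parallel dimensions.

# e.g. 32 GPUs with pipeline_parallel_size=2 and tensor_parallel_size=4
world_size, pipeline_parallel_size, tensor_parallel_size = 32, 2, 4
data_parallel_size = world_size // (pipeline_parallel_size * tensor_parallel_size)
expert_parallel_size = data_parallel_size  # tied together for now, see the TODO above
assert (data_parallel_size, expert_parallel_size) == (4, 4)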
@ -454,6 +458,7 @@ class ParallelContext(metaclass=SingletonMeta):
self.pipeline_parallel_size,
self.tensor_parallel_size,
self.zero1_parallel_size,
self.expert_parallel_size,
]
# run initialization of different process groups
@ -464,6 +469,8 @@ class ParallelContext(metaclass=SingletonMeta):
initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
if self.pipeline_parallel_size > 1:
initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args))
if self.config.model.num_experts > 1:
initializers.append(pgroup_initializer.Initializer_Expert(*initializer_args))
for initializer in initializers:
parallel_setting = initializer.init_dist_group()
if isinstance(parallel_setting, list):


@ -31,6 +31,9 @@ class ParallelMode(Enum):
# zero1 parallel
ZERO1 = "zero1"
# expert parallel
EXPERT = "expert"
class ProcessGroupInitializer(ABC):
"""An object, knowing the parallelism configuration, that initializes parallel groups.
@ -42,6 +45,7 @@ class ProcessGroupInitializer(ABC):
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
zero1_parallel_size (int): Size of zero1 parallel.
expert_parallel_size (int): Size of expert parallel.
"""
def __init__(
@ -52,6 +56,7 @@ class ProcessGroupInitializer(ABC):
pipeline_parallel_size: int,
tensor_parallel_size: int,
zero1_parallel_size: int,
expert_parallel_size: int,
):
self.rank = rank
self.world_size = world_size
@ -59,6 +64,7 @@ class ProcessGroupInitializer(ABC):
self.pipeline_parallel_size = pipeline_parallel_size
self.tensor_parallel_size = tensor_parallel_size
self.zero1_parallel_size = zero1_parallel_size
self.expert_parallel_size = expert_parallel_size
super().__init__()
@abstractmethod
@ -76,6 +82,7 @@ class Initializer_Data(ProcessGroupInitializer):
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
zero1_parallel_size (int): Size of zero1 parallel.
expert_parallel_size (int): Size of expert parallel.
"""
def __init__(self, *args, **kwargs):
@ -127,6 +134,7 @@ class Initializer_Model(ProcessGroupInitializer):
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
zero1_parallel_size (int): Size of zero1 parallel.
expert_parallel_size (int): Size of expert parallel.
"""
def __init__(self, *args, **kwargs):
@ -178,6 +186,7 @@ class Initializer_Pipeline(ProcessGroupInitializer):
pipeline_parallel_size (int): Size of pipeline parallel
tensor_parallel_size (int): Size of tensor parallel
zero1_parallel_size (int): Size of zero1 parallel.
expert_parallel_size (int): Size of expert parallel.
"""
def __init__(self, *args, **kwargs):
@ -238,6 +247,7 @@ class Initializer_Tensor(ProcessGroupInitializer):
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
zero1_parallel_size (int): Size of zero1 parallel.
expert_parallel_size (int): Size of expert parallel.
"""
def __init__(self, *args, **kwargs):
@ -288,6 +298,7 @@ class Initializer_Zero1(ProcessGroupInitializer):
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
zero1_parallel_size (int): Size of zero-1 parallel.
expert_parallel_size (int): Size of expert parallel.
"""
def __init__(self, *args, **kwargs):
@ -332,3 +343,59 @@ class Initializer_Zero1(ProcessGroupInitializer):
ranks_in_group = ranks
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
class Initializer_Expert(ProcessGroupInitializer):
"""A ProcessGroupInitializer for zero-1 parallelism.
Args:
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
zero1_parallel_size (int): Size of zero-1 parallel.
expert_parallel_size (int): Size of expert parallel.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.num_expert_parallel_group = self.world_size // self.expert_parallel_size
assert self.world_size % self.num_expert_parallel_group == 0
# TODO: support an expert parallel size different from the data parallel size
assert self.data_parallel_size == self.expert_parallel_size
def init_dist_group(self, use_cpu: bool = False):
"""Initialize expert parallel groups, and assign local_ranks and groups to each gpu.
Example: world_size = 8, model_parallel_size = 2, expert_parallel_size = 4
model_parallel_group = [0,1], [2,3], [4,5], [6,7]
expert_parallel_group = [0,2,4,6], [1,3,5,7]
Returns:
Tuple (local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode):
An expert parallelism's information tuple.
"""
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.EXPERT
for i in range(self.num_expert_parallel_group):
ranks = list(range(i, self.world_size, self.num_expert_parallel_group))
group = dist.new_group(ranks)
if use_cpu:
group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
else:
group_cpu = None
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
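As a sanity check on the docstring example, the grouping is just a strided slice over the ranks; the standalone sketch below (not part of the patch) reproduces the two expert groups for world_size = 8 and expert_parallel_size = 4.

world_size, expert_parallel_size = 8, 4
num_expert_parallel_group = world_size // expert_parallel_size  # = 2
groups = [
    list(range(i, world_size, num_expert_parallel_group))
    for i in range(num_expert_parallel_group)
]
assert groups == [[0, 2, 4, 6], [1, 3, 5, 7]]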


@ -88,6 +88,7 @@ class NonPipelineScheduler(BaseScheduler):
forward_only: bool = False,
return_loss: bool = True,
scale_loss: int = 1,
moe_loss_coeff: float = 1.0,
):
"""Trains one batch of data.
@ -104,7 +105,7 @@ class NonPipelineScheduler(BaseScheduler):
# forward
with conditional_context(torch.no_grad(), enable=forward_only):
self._call_hooks("before_forward", data)
output = self._call_engine(engine, data)
output, moe_losses = self._call_engine(engine, data)
self._call_hooks("after_forward", output)
self._call_hooks("post_helper_func", output, label)
@ -113,7 +114,9 @@ class NonPipelineScheduler(BaseScheduler):
self._call_hooks("before_criterion", output, label)
loss = self._call_engine_criterion(engine, output, label)
self._call_hooks("after_criterion", loss)
loss /= scale_loss
moe_loss = sum(moe_losses) * moe_loss_coeff
loss += moe_loss
loss /= scale_loss  # TODO: check whether moe_loss should also be scaled
# backward
if not forward_only:
@ -133,6 +136,7 @@ class NonPipelineScheduler(BaseScheduler):
forward_only: bool = False,
return_loss: bool = True,
return_output_label: bool = True,
moe_loss_coeff: float = 1.0,
):
"""The process function that loads a batch of dataset and feeds it to the model.
The returned labels and loss will None if :attr:`return_loss` is False.
@ -177,7 +181,7 @@ class NonPipelineScheduler(BaseScheduler):
_data, _label = self._load_accum_batch(data, label)
_output, _loss = self._train_one_batch(
_data, _label, engine, forward_only, return_loss, self._grad_accum_size
_data, _label, engine, forward_only, return_loss, self._grad_accum_size, moe_loss_coeff
)
if return_loss:


@ -18,6 +18,7 @@ from internlm.model.linear import (
RewardModelLinear,
ScaleColumnParallelLinear,
)
from internlm.model.moe import MoE
from internlm.model.multi_head_attention import MHA
from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm
from internlm.solver.pipeline_utils import partition_uniform
@ -49,6 +50,7 @@ class PackedFlashBaseLayer1D(nn.Module):
device (Optional[Union[str, torch.device]]): The device will be used.
norm_type (str): Use RMS norm or layernorm."rmsnorm" by default.
use_flash_attn (bool): Whether use flash-attn. True by default.
num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
"""
def __init__(
@ -69,6 +71,7 @@ class PackedFlashBaseLayer1D(nn.Module):
use_scaled_init: bool = True,
use_swiglu: bool = True,
use_flash_attn: bool = True,
num_experts: int = 1,
):
super().__init__()
self.checkpoint = checkpoint
@ -101,37 +104,57 @@ class PackedFlashBaseLayer1D(nn.Module):
self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
if use_swiglu:
self.mlp = FeedForward(
# TODO: pass num_experts and ep_size as function parameters
self.num_experts = num_experts
ep_size = gpc.get_world_size(ParallelMode.EXPERT)
if num_experts <= 1: # dense, not MoE
if use_swiglu:
self.mlp = FeedForward(
hidden_size,
int(hidden_size * mlp_ratio),
out_features=hidden_size,
process_group=gpc.get_group(ParallelMode.TENSOR),
bias=False,
device=device,
dtype=dtype,
)
else:
self.mlp = ParallelFusedMLP(
hidden_size,
int(hidden_size * mlp_ratio),
out_features=hidden_size,
activation="gelu_approx",
process_group=gpc.get_group(ParallelMode.TENSOR),
bias1=False,
bias2=False,
sequence_parallel=gpc.config.model.sequence_parallel,
checkpoint_lvl=0,
heuristic="auto",
device=device,
dtype=dtype,
)
else:
expert = torch.nn.ModuleList([FeedForward(
hidden_size,
int(hidden_size * mlp_ratio),
int(hidden_size * gpc.config.model.mlp_ratio),
out_features=hidden_size,
process_group=gpc.get_group(ParallelMode.TENSOR),
bias=False,
device=device,
dtype=dtype,
)
else:
self.mlp = ParallelFusedMLP(
hidden_size,
int(hidden_size * mlp_ratio),
out_features=hidden_size,
activation="gelu_approx",
process_group=gpc.get_group(ParallelMode.TENSOR),
bias1=False,
bias2=False,
sequence_parallel=gpc.config.model.sequence_parallel,
checkpoint_lvl=0,
heuristic="auto",
device=device,
dtype=dtype,
)
device=torch.device("cuda"),
dtype=torch.float,
) for i in range(num_experts // ep_size)])
# TODO: MoE is only being tested for now; more parameters are needed, such as capacity_factor, eval_capacity_factor, min_capacity and drop_tokens
self.mlp = MoE(hidden_size=hidden_size,
expert=expert,
ep_size=ep_size,
num_experts=num_experts,
k=1)
self.dropout2 = nn.Dropout(drop_rate)
self.use_swiglu = use_swiglu
self.use_scaled_init = use_scaled_init
self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm
self.return_residual = False
self.reset_parameters()
self.reset_parameters()  # TODO: check whether this should be changed when MoE is added
def reset_parameters(self):
with torch.no_grad():
@ -163,7 +186,7 @@ class PackedFlashBaseLayer1D(nn.Module):
if self.checkpoint and self.training:
return activation_checkpoint(
self._forward, False, hidden_states, cu_seqlens, indexes, inference_params, max_seqlen
)
)  # TODO: check whether this is affected by MoE
else:
return self._forward(hidden_states, cu_seqlens, indexes, inference_params, max_seqlen)
@ -213,9 +236,14 @@ class PackedFlashBaseLayer1D(nn.Module):
if self.residual_in_fp32:
residual = residual.to(torch.float32)
hidden_states = self.mlp(hidden_states)
# MLP.
moe_loss = torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype)
if self.num_experts <= 1:
hidden_states = self.mlp(hidden_states)
else:
hidden_states, moe_loss, _ = self.mlp(hidden_states)
return hidden_states + residual
return hidden_states + residual, moe_loss
class PackedFlashInternLm1D(nn.Module):
@ -246,6 +274,7 @@ class PackedFlashInternLm1D(nn.Module):
residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
use_flash_attn (bool): Whether to use flash-attn. True by default.
num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
"""
@ -276,6 +305,7 @@ class PackedFlashInternLm1D(nn.Module):
use_scaled_init: bool = True,
use_swiglu: bool = True,
use_flash_attn: bool = True,
num_experts: int = 1,
):
super().__init__()
@ -327,6 +357,7 @@ class PackedFlashInternLm1D(nn.Module):
use_scaled_init=use_scaled_init,
use_swiglu=use_swiglu,
use_flash_attn=use_flash_attn,
num_experts=num_experts,
)
for lid in range(num_layers)
]
@ -374,14 +405,16 @@ class PackedFlashInternLm1D(nn.Module):
indexes = indexes[0]
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
moe_losses = []
for _, block in enumerate(self.blocks):
hidden_states = block(
hidden_states, moe_loss = block(
hidden_states,
cu_seqlens=cu_seqlens,
indexes=indexes,
inference_params=inference_params,
max_seqlen=max_seqlen,
)
moe_losses.append(moe_loss)
if hasattr(self, "norm"):
hidden_states = self.norm(hidden_states.float())
@ -390,7 +423,7 @@ class PackedFlashInternLm1D(nn.Module):
if not self.parallel_output:
hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
return hidden_states
return hidden_states, moe_losses
def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), **kwargs):
@ -462,6 +495,7 @@ def build_model_with_cfg(
use_swiglu: bool = True,
use_flash_attn: bool = True,
sequence_parallel: bool = False,
num_experts: int = 1,
):
"""
Build model with config
@ -492,6 +526,7 @@ def build_model_with_cfg(
use_scaled_init (bool): Whether to use scaled init. True by default.
use_swiglu (bool): Whether to use swiglu. True by default.
use_flash_attn (bool): Whether to use flash-attn. True by default.
num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
"""
@ -515,6 +550,7 @@ def build_model_with_cfg(
use_scaled_init=use_scaled_init,
use_swiglu=use_swiglu,
use_flash_attn=use_flash_attn,
num_experts=num_experts,
)
return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)

internlm/model/moe.py (new file, 111 lines)

@ -0,0 +1,111 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import typing

import torch

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.moe.experts import Experts
from internlm.moe.sharded_moe import MOELayer, TopKGate
from internlm.utils.logger import get_logger

# global llm logger
logger = get_logger(__file__)
def has_moe_layers(m):
has_moe = False
num_experts = 0
for _, module in m.named_modules():
if isinstance(module, MoE):
has_moe = True
num_experts = module.num_experts
break
return has_moe, num_experts
def is_moe_param(param: torch.Tensor) -> bool:
if hasattr(param, "allreduce") and not param.allreduce:
return True
return False
class MoE(torch.nn.Module):
"""Initialize an MoE layer.
Arguments:
hidden_size (int): the hidden dimension of the model, importantly this is also the input and output dimension.
expert (torch.nn.Module): the torch module that defines the expert (e.g., MLP, torch.nn.Linear).
num_experts (int, optional): default=1, the total number of experts per layer.
ep_size (int, optional): default=1, number of ranks in the expert parallel world or group.
k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample' or 'None'.
drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to infinite capacity).
use_rts (bool, optional): default=True, whether to use Random Token Selection.
"""
def __init__(self,
hidden_size,
expert,
num_experts=1,
ep_size=1,
k=1,
capacity_factor=1.,
eval_capacity_factor=1.,
min_capacity=4,
noisy_gate_policy: typing.Optional[str] = None,
drop_tokens: bool = True,
use_rts: bool = True,
using_default_moe: bool = True):
super(MoE, self).__init__()
assert num_experts % ep_size == 0, f"Number of experts ({num_experts}) should be divisible by expert parallel size ({ep_size})"
self.ep_size = ep_size
self.num_experts = num_experts
self.num_local_experts = num_experts // self.ep_size
logger.info(
f'Creating MoE layer with num_experts: {num_experts} | num_local_experts: {self.num_local_experts} | expert_parallel_size: {self.ep_size}')
assert noisy_gate_policy is None or noisy_gate_policy in ['None', 'Jitter', 'RSample'], \
'Unsupported noisy_gate_policy: ' + noisy_gate_policy
experts = Experts(expert, self.num_local_experts)
if using_default_moe:
self.moe_layer = MOELayer(TopKGate(hidden_size, num_experts, k, capacity_factor, eval_capacity_factor,
min_capacity, noisy_gate_policy, drop_tokens, use_rts),
experts,
gpc.get_group(ParallelMode.EXPERT),
self.ep_size,
self.num_local_experts)
def forward(self, hidden_states, used_token=None):
""" MoE forward
Arguments:
hidden_states (Tensor): input to the layer
used_token (Tensor, optional): default: None, mask only used tokens
Returns:
A tuple including output, gate loss, and expert count.
* output (Tensor): output of the model
* l_aux (Tensor): gate loss value
* exp_counts (int): expert count
"""
output = self.moe_layer(hidden_states, used_token)
return output, self.moe_layer.l_aux, self.moe_layer.exp_counts
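A hedged usage sketch of this class, mirroring how PackedFlashBaseLayer1D wires it up in the modeling change above; it assumes torch.distributed and the EXPERT/TENSOR process groups have already been initialized, that FeedForward is imported from internlm.model.linear, and that mlp_ratio is 4 (illustrative value only):

hidden_size, num_experts = 4096, 8
ep_size = gpc.get_world_size(ParallelMode.EXPERT)
experts = torch.nn.ModuleList(
    [
        FeedForward(
            hidden_size,
            4 * hidden_size,
            out_features=hidden_size,
            process_group=gpc.get_group(ParallelMode.TENSOR),
            bias=False,
            device=torch.device("cuda"),
            dtype=torch.float,
        )
        for _ in range(num_experts // ep_size)
    ]
)
moe = MoE(hidden_size=hidden_size, expert=experts, ep_size=ep_size, num_experts=num_experts, k=1)
hidden_states = torch.randn(1024, hidden_size, device="cuda")  # (tokens, hidden)
output, l_aux, exp_counts = moe(hidden_states)  # l_aux is what the scheduler sums into moe_loss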

internlm/moe/experts.py (new file, 49 lines)

@ -0,0 +1,49 @@
"""
The file has been adapted from the following files:
https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py
Git commit hash: f3943cf9109226ed3ecf2d5dbb639a11cd925555
We retain the following license from the original files:
"""
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
import copy
from torch.nn import Module, ModuleList
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast
class Experts(torch.nn.Module):
def __init__(self, experts: Union[Module, ModuleList], num_local_experts=1):
super(Experts, self).__init__()
# TODO: We can not deepcopy FeedForward since it contains a process_group in submodules
# self.experts = torch.nn.ModuleList([copy.deepcopy(expert) for i in range(num_local_experts)])
if type(experts) == ModuleList:
self.experts = cast(ModuleList, experts)
else:
self.experts = ModuleList([experts])
self.num_local_experts = num_local_experts
# TODO: revisit allreduce for moe.gate...
for expert in self.experts:
# TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group)
for name, param in expert.named_parameters():
param.allreduce = False  # matches the attribute checked by internlm.model.moe.is_moe_param
def forward(self, inputs):
chunks = inputs.chunk(self.num_local_experts, dim=1)
expert_outputs = []
for chunk, expert in zip(chunks, self.experts):
out = expert(chunk)
if type(out) is tuple:
out = out[0] # Ignore the bias term for now
expert_outputs += [out]
expert_output = torch.cat(expert_outputs, dim=1)
return expert_output
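A small shape check for the chunking convention used here (a sketch, not part of the patch): MOELayer reshapes the dispatched tokens to (ep_size, num_local_experts, capacity, d_model) before calling Experts, so each local expert receives one chunk along dim=1 and the outputs are concatenated back on that same dim.

dummy = Experts(ModuleList([torch.nn.Identity() for _ in range(2)]), num_local_experts=2)
x = torch.randn(4, 2, 8, 16)  # (ep_size=4, num_local_experts=2, capacity=8, d_model=16)
assert dummy(x).shape == x.shape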

internlm/moe/sharded_moe.py (new file, 503 lines)

@ -0,0 +1,503 @@
"""
The file has been adapted from the following files:
https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/sharded_moe.py
Git commit hash: f3943cf9109226ed3ecf2d5dbb639a11cd925555
We retain the following license from the original files:
"""
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer

# global llm logger
logger = get_logger(__file__)
if TYPE_CHECKING:
Base = Module[Tensor]
else:
Base = Module
uniform_map: Dict[torch.device, Callable] = {}
gumbel_map: Dict[torch.device, Callable] = {}
exp_selection_uniform_map: Dict[torch.device, Callable] = {}
def multiplicative_jitter(x, device: torch.device, epsilon=1e-2):
"""
Modified from the Switch Transformer paper (Mesh Transformers implementation).
Multiply values by a random number between 1-epsilon and 1+epsilon.
Makes models more resilient to rounding errors introduced by bfloat16.
This seems particularly important for logits.
Args:
x: a torch.tensor
device: torch.device
epsilon: a floating point value
Returns:
a jittered x.
"""
if epsilon == 0:
return x
uniform = uniform_map.get(device)
if uniform is None:
uniform = torch.distributions.uniform.Uniform(low=torch.tensor(1.0 - epsilon, device=device),
high=torch.tensor(1.0 + epsilon,
device=device)).rsample # type: ignore
uniform_map[device] = uniform
return x * uniform(x.shape)
def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
gumbel = gumbel_map.get(device)
if gumbel is None:
one = torch.tensor(1.0, device=device)
zero = torch.tensor(0.0, device=device)
gumbel = torch.distributions.gumbel.Gumbel(zero, one).rsample # type: ignore
gumbel_map[device] = gumbel
return gumbel(shape)
# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
# See https://arxiv.org/pdf/2006.16668.pdf for details.
# Based on https://github.com/pytorch/pytorch/pull/40762
class _AllToAll(torch.autograd.Function):
@staticmethod
def forward(
ctx: Any,
# TODO: replace with DS process group
group: torch.distributed.ProcessGroup,
input: Tensor) -> Tensor: # type: ignore
ctx.group = group
input = input.contiguous()
output = torch.empty_like(input)
dist.all_to_all_single(output, input, group=group)
return output
@staticmethod
def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
return (None, _AllToAll.apply(ctx.group, *grad_output))
# einsum rewrites are on par or more performant
# switch can be bubbled up in future
USE_EINSUM = True
# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
# See https://arxiv.org/pdf/2006.16668.pdf for details.
def einsum(rule, a, b):
if USE_EINSUM:
return torch.einsum(rule, a, b)
elif rule == 's,se->se':
## [1, s] * [s, e]
return a.reshape(a.shape[0], -1) * b
elif rule == 'se,sc->sec':
## [s,e,1] * [s,1,c]
return a.unsqueeze(2) * b.unsqueeze(1)
elif rule == 'se,se->s':
## [s,1,e] * [s,e,1]
return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1)
elif rule == 'sec,sm->ecm':
## [e*c, s] * [s, m]
s = a.shape[0]
e = a.shape[1]
c = a.shape[2]
m = b.shape[1]
return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m)
elif rule == 'sec,ecm->sm':
## [s, e*c] * [e*c, m]
return torch.matmul(a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1]))
elif rule == 'ks,ksm->sm':
k = b.shape[0]
s = b.shape[1]
m = b.shape[2]
# [k, s] -> [s, k] -> [s, 1, k]
a = a.t().unsqueeze(1)
# [k,s,m] -> [k, sm] -> [sm, k] -> [s, m, k]
b = b.reshape(k, -1).t().reshape(s, m, k)
# bmm([s, 1, k], [s, m, k]^t) -> [s, m, 1]
return torch.bmm(a, b.transpose(1, 2)).squeeze(2)
else:
return torch.einsum(rule, a, b)
# The following functions are extracted and scripted
# because otherwise during a torch.jit.trace, the non-Tensor
# values used in the calculations get recorded as constants.
# torch.jit.script coerces them into Tensors and preserves
# their dynamic shapes. This enables ONNX export.
# We can't script the entire top1gating function because it
# includes stateful caching logic which is incompatible with ONNX.
@torch.jit.script
def _capacity(gates: Tensor, capacity_factor: Tensor, min_capacity: Tensor) -> Tensor:
# gates has shape of SE
num_tokens = gates.shape[0]
num_experts = gates.shape[1]
# to(torch.int64) works around a bug in torch.onnx.export:
# it should cast k to int64 when converting torch.topk but it doesn't.
capacity = torch.ceil((num_tokens / num_experts) * capacity_factor).to(torch.int64)
if capacity < min_capacity:
capacity = min_capacity.to(torch.int64)
return capacity
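For intuition, the capacity formula above reduces to ceil(tokens / experts * capacity_factor), floored at min_capacity; a plain-Python restatement with illustrative numbers (not from the patch):

import math

tokens, experts, capacity_factor, min_capacity = 4096, 8, 1.0, 4
capacity = max(math.ceil(tokens / experts * capacity_factor), min_capacity)
assert capacity == 512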
@torch.jit.script
def _top_idx(source, k):
return torch.topk(source, k=k, dim=0)[1]
@torch.jit.script
def _one_hot_to_float(x, num_classes):
return F.one_hot(x, num_classes=num_classes).float()
def top1gating(logits: Tensor,
capacity_factor: float,
min_capacity: int,
used_token: Tensor = None,
noisy_gate_policy: Optional[str] = None,
drop_tokens: bool = True,
use_rts: bool = True,
use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Implements Top1Gating on logits."""
if noisy_gate_policy == 'RSample':
logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
# everything is in fp32 in this function
gates = F.softmax(logits, dim=1)
capacity = _capacity(gates, torch.tensor(capacity_factor), torch.tensor(min_capacity))
# Create a mask for 1st's expert per token
# noisy gating
indices1_s = torch.argmax(logits_w_noise if noisy_gate_policy == 'RSample' else gates, dim=1)
num_experts = int(gates.shape[1])
mask1 = F.one_hot(indices1_s, num_classes=num_experts)
# mask only used tokens
if used_token is not None:
mask1 = einsum("s,se->se", used_token, mask1)
# gating decisions
exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')
# if we don't want to drop any tokens
if not drop_tokens:
new_capacity = torch.max(exp_counts).to(logits.device)
dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.group.WORLD)  # reduce over the default (world) group
capacity = new_capacity
# Compute l_aux
me = torch.mean(gates, dim=0)
ce = torch.mean(mask1.float(), dim=0)
l_aux = torch.sum(me * ce) * num_experts
# Random Token Selection
if use_rts:
uniform = exp_selection_uniform_map.get(logits.device)
if uniform is None:
uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=logits.device),
high=torch.tensor(1.0, device=logits.device)).rsample
exp_selection_uniform_map[logits.device] = uniform
mask1_rand = mask1 * uniform(mask1.shape)
else:
mask1_rand = mask1
assert logits.shape[0] >= min_capacity, (
"No. of tokens (batch-size) should be greater than min_capacity. "
"Either set min_capacity to 0 or increase your batch size."
)
top_idx = _top_idx(mask1_rand, capacity)  # for each expert, indices of the tokens it keeps (top `capacity` by random priority)
new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)
mask1 = new_mask1
if use_tutel:
# Tutel doesn't support index values masked with zero
# so we need to replace masked indices with -1
indices_mask = mask1.sum(dim=1) * num_experts - 1
indices1_s = torch.min(indices1_s, indices_mask)
# Compute locations in capacity buffer
locations1 = torch.cumsum(mask1, dim=0) - 1
if use_tutel:
gates1_s = (gates * mask1).sum(dim=1)
locations1_s = torch.sum(locations1 * mask1, dim=1)
return l_aux, capacity, num_experts, [
indices1_s,
], [
locations1_s,
], [
gates1_s,
], exp_counts
# Store the capacity location for each token
locations1_s = torch.sum(locations1 * mask1, dim=1)
# Normalize gate probabilities
mask1_float = mask1.float()
gates = gates * mask1_float
locations1_sc = _one_hot_to_float(locations1_s, capacity)
combine_weights = einsum("se,sc->sec", gates, locations1_sc)
dispatch_mask = combine_weights.bool()
return l_aux, combine_weights, dispatch_mask, exp_counts
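A quick property of the auxiliary loss computed above (illustration, not part of the patch): with perfectly balanced routing, each expert receives 1/E of the gate mass and 1/E of the tokens, so l_aux = E * sum_e (1/E * 1/E) = 1; values drifting well above 1 indicate an unbalanced gate.

E = 8
me = torch.full((E,), 1.0 / E)  # mean gate probability per expert
ce = torch.full((E,), 1.0 / E)  # fraction of tokens assigned to each expert
l_aux = torch.sum(me * ce) * E
assert torch.isclose(l_aux, torch.tensor(1.0))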
def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Implements Top2Gating on logits."""
# everything is in fp32 in this function
gates = F.softmax(logits, dim=1)
capacity = _capacity(gates, torch.tensor(capacity_factor * 2), torch.tensor(min_capacity))
# Create a mask for 1st's expert per token
indices1_s = torch.argmax(gates, dim=1)
num_experts = int(gates.shape[1])
mask1 = F.one_hot(indices1_s, num_classes=num_experts)
# Create a mask for 2nd's expert per token using Gumbel-max trick
# https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
# Replace top-expert with min value
logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf"))
indices2_s = torch.argmax(logits_except1, dim=1)
mask2 = F.one_hot(indices2_s, num_classes=num_experts)
# Compute locations in capacity buffer
locations1 = torch.cumsum(mask1, dim=0) - 1
locations2 = torch.cumsum(mask2, dim=0) - 1
# Update 2nd's location by accounting for locations of 1st
locations2 += torch.sum(mask1, dim=0, keepdim=True)
# gating decisions
exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')
# Compute l_aux
me = torch.mean(gates, dim=0)
ce = torch.mean(mask1.float(), dim=0)
l_aux = torch.mean(me * ce) * num_experts * num_experts
# Remove locations outside capacity from mask
mask1 *= torch.lt(locations1, capacity)
mask2 *= torch.lt(locations2, capacity)
# Store the capacity location for each token
locations1_s = torch.sum(locations1 * mask1, dim=1)
locations2_s = torch.sum(locations2 * mask2, dim=1)
# Normalize gate probabilities
mask1_float = mask1.float()
mask2_float = mask2.float()
gates1_s = einsum("se,se->s", gates, mask1_float)
gates2_s = einsum("se,se->s", gates, mask2_float)
denom_s = gates1_s + gates2_s
# Avoid divide-by-zero
denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
gates1_s /= denom_s
gates2_s /= denom_s
# Calculate combine_weights and dispatch_mask
gates1 = einsum("s,se->se", gates1_s, mask1_float)
gates2 = einsum("s,se->se", gates2_s, mask2_float)
locations1_sc = _one_hot_to_float(locations1_s, capacity)
locations2_sc = _one_hot_to_float(locations2_s, capacity)
combine1_sec = einsum("se,sc->sec", gates1, locations1_sc)
combine2_sec = einsum("se,sc->sec", gates2, locations2_sc)
combine_weights = combine1_sec + combine2_sec
dispatch_mask = combine_weights.bool()
return l_aux, combine_weights, dispatch_mask, exp_counts
class TopKGate(Module):
"""Gate module which implements Top2Gating as described in Gshard_.
::
gate = TopKGate(model_dim, num_experts)
l_aux, combine_weights, dispatch_mask = gate(input)
.. Gshard_: https://arxiv.org/pdf/2006.16668.pdf
Args:
model_dim (int):
size of model embedding dimension
num_experts (int):
number of experts in model
"""
wg: torch.nn.Linear
def __init__(self,
model_dim: int,
num_experts: int,
k: int = 1,
capacity_factor: float = 1.0,
eval_capacity_factor: float = 1.0,
min_capacity: int = 8,
noisy_gate_policy: Optional[str] = None,
drop_tokens: bool = True,
use_rts: bool = True) -> None:
super().__init__()
# Only top-1 and top-2 are supported at the moment.
if k != 1 and k != 2:
raise ValueError('Only top-1 and top-2 gatings are supported.')
self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
self.k = k
self.capacity_factor = capacity_factor
self.eval_capacity_factor = eval_capacity_factor
self.min_capacity = min_capacity
self.noisy_gate_policy = noisy_gate_policy
self.wall_clock_breakdown = False
self.gate_time = 0.0
self.drop_tokens = drop_tokens
self.use_rts = use_rts
def forward(self,
input: torch.Tensor,
used_token: torch.Tensor = None,
use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor]: # type: ignore
if self.wall_clock_breakdown:
timer('TopKGate').start()
if self.wg.weight.dtype != torch.float32:
self.wg = self.wg.float()
input_fp32 = input.float()
# input jittering
if self.noisy_gate_policy == 'Jitter' and self.training:
input_fp32 = multiplicative_jitter(input_fp32, device=input.device)
logits = self.wg(input_fp32)
if self.k == 1:
gate_output = top1gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
self.min_capacity, used_token, self.noisy_gate_policy if self.training else None,
self.drop_tokens, self.use_rts, use_tutel)
else:
gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
self.min_capacity)
if self.wall_clock_breakdown:
timer('TopKGate').stop()
self.gate_time = timer('TopKGate').elapsed(reset=False)
return gate_output
class MOELayer(Base):
"""MOELayer module which implements MixtureOfExperts as described in Gshard_.
::
gate = TopKGate(model_dim, num_experts)
moe = MOELayer(gate, expert)
output = moe(input)
l_aux = moe.l_aux
.. Gshard_: https://arxiv.org/pdf/2006.16668.pdf
Args:
gate (torch.nn.Module):
gate network
expert (torch.nn.Module):
expert network
"""
def __init__(self,
gate: Module,
experts: Module,
ep_group,
ep_size,
num_local_experts: int) -> None:
super().__init__()
self.gate = gate
self.experts = experts
self.ep_group = ep_group
self.ep_size = ep_size
self.num_local_experts = num_local_experts
self.time_falltoall = 0.0
self.time_salltoall = 0.0
self.time_moe = 0.0
self.wall_clock_breakdown = False
def _set_ep_group(self, ep_group):
self.ep_group = ep_group
def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
if self.wall_clock_breakdown:
timer('moe').start()
# Implement Algorithm 2 from GShard paper.
d_model = input[0].shape[-1]
# Initial implementation -> Reshape into S tokens by dropping sequence dimension.
# Reshape into G groups so that each group can distribute tokens equally
# group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
reshaped_input = input[0].reshape(-1, d_model)
self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input, input[1])
dispatched_input = einsum("sec,sm->ecm", dispatch_mask.type_as(input[0]), reshaped_input) ## TODO: heavy memory usage due to long sequence length
if self.wall_clock_breakdown:
timer('falltoall').start()
dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)
if self.wall_clock_breakdown:
timer('falltoall').stop()
self.time_falltoall = timer('falltoall').elapsed(reset=False)
# Re-shape after all-to-all: ecm -> gecm
dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model)
expert_output = self.experts(dispatched_input)
if self.wall_clock_breakdown:
timer('salltoall').start()
expert_output = _AllToAll.apply(self.ep_group, expert_output)
if self.wall_clock_breakdown:
timer('salltoall').stop()
self.time_salltoall = timer('salltoall').elapsed(reset=False)
# Re-shape back: gecm -> ecm
expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)
combined_output = einsum("sec,ecm->sm", combine_weights.type_as(input[0]), expert_output)
a = combined_output.reshape(input[0].shape)
if self.wall_clock_breakdown:
timer('moe').stop()
self.time_moe = timer('moe').elapsed(reset=False)
return a
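To make the two einsums in forward() concrete, here is a standalone shape walk-through (illustrative sizes, not from the patch), using the s/e/c/m naming from the comments above: dispatch maps (tokens, d_model) into per-expert capacity slots, and combine maps them back.

s, e, c, m = 16, 4, 8, 32  # tokens, experts, capacity, d_model
dispatch_mask = torch.zeros(s, e, c)
reshaped_input = torch.randn(s, m)
dispatched = torch.einsum("sec,sm->ecm", dispatch_mask, reshaped_input)
assert dispatched.shape == (e, c, m)
combine_weights = torch.rand(s, e, c)
expert_output = torch.randn(e, c, m)
combined = torch.einsum("sec,ecm->sm", combine_weights, expert_output)
assert combined.shape == (s, m)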


@ -25,6 +25,7 @@ from internlm.solver.optimizer.utils import (
split_half_float_double,
sync_param,
)
from internlm.model.moe import is_moe_param
from internlm.utils.common import get_current_device
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
@ -285,7 +286,7 @@ class HybridZeroOptimizer(BaseOptimizer):
for group_id in range(self.num_param_groups):
param_group = self._fp16_param_groups[group_id]
for param in param_group:
if param.requires_grad:
if param.requires_grad and not is_moe_param(param):
reduce_rank = None
def _define_and_attach(param, reduce_rank=None):


@ -580,7 +580,7 @@ def main(args):
# do forward and backward
timer("fwd-bwd").start()
_, _, loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False)
_, _, loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False, moe_loss_coeff=gpc.config.loss.moe_loss_coeff)
timer("fwd-bwd").stop()
# update parameters, and returns (success_update, grad_norm)