Mirror of https://github.com/InternLM/InternLM

Commit: feat(XXX): add moe
Parent: c219065348
Commit hash: c357288a8b
@@ -77,6 +77,7 @@ hybrid_zero_optimizer = dict(

 loss = dict(
     label_smoothing=0,
+    moe_loss_coeff=1.0,
 )

 adam = dict(
@@ -119,6 +120,7 @@ model = dict(
     use_flash_attn=True,
     num_chunks=1,  # if num_chunks > 1, interleaved pipeline scheduler is used.
     sequence_parallel=False,
+    num_experts=8,
 )
 """
 zero1 parallel:
@@ -143,6 +143,7 @@ class ParallelContext(metaclass=SingletonMeta):
         self.pipeline_parallel_size = 1
         self.tensor_parallel_size = 1
         self.zero1_parallel_size = -1
+        self.expert_parallel_size = -1
         self.num_processes_on_current_node = -1
         self.virtual_pipeline_parallel_size = None
         self.virtual_pipeline_parallel_rank = None
@@ -442,6 +443,9 @@ class ParallelContext(metaclass=SingletonMeta):
         # instead, it should be calculated based on other parallel config
         self.data_parallel_size = self.world_size // (self.pipeline_parallel_size * self.tensor_parallel_size)

+        # TODO: data parallel size can be different from expert parallel size
+        self.expert_parallel_size = self.data_parallel_size
+
         if self.zero1_parallel_size <= 0:
             self.zero1_parallel_size = self.data_parallel_size

@@ -454,6 +458,7 @@ class ParallelContext(metaclass=SingletonMeta):
             self.pipeline_parallel_size,
             self.tensor_parallel_size,
             self.zero1_parallel_size,
+            self.expert_parallel_size,
         ]

         # run initialization of different process groups
@@ -464,6 +469,8 @@ class ParallelContext(metaclass=SingletonMeta):
         initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
         if self.pipeline_parallel_size > 1:
             initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args))
+        if self.config.model.num_experts > 1:
+            initializers.append(pgroup_initializer.Initializer_Expert(*initializer_args))
         for initializer in initializers:
             parallel_setting = initializer.init_dist_group()
             if isinstance(parallel_setting, list):
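For orientation, a standalone sketch of the size arithmetic introduced above (plain Python, illustrative values only, not part of the patch):

# Sketch of the parallel-size derivation in ParallelContext.
world_size = 32
pipeline_parallel_size = 2
tensor_parallel_size = 4

data_parallel_size = world_size // (pipeline_parallel_size * tensor_parallel_size)  # 4
expert_parallel_size = data_parallel_size  # tied to dp until ep != dp is supported (see TODO)
zero1_parallel_size = data_parallel_size   # used when the config value is <= 0

print(data_parallel_size, expert_parallel_size, zero1_parallel_size)  # 4 4 4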
@@ -31,6 +31,9 @@ class ParallelMode(Enum):
     # zero1 parallel
     ZERO1 = "zero1"

+    # expert parallel
+    EXPERT = "expert"
+

 class ProcessGroupInitializer(ABC):
     """An object, knowing the parallelism configuration, that initializes parallel groups.
@@ -42,6 +45,7 @@ class ProcessGroupInitializer(ABC):
         pipeline_parallel_size (int): Size of pipeline parallel.
         tensor_parallel_size (int): Size of tensor parallel.
         zero1_parallel_size (int): Size of zero1 parallel.
+        expert_parallel_size (int): Size of expert parallel.
     """

     def __init__(
@@ -52,6 +56,7 @@ class ProcessGroupInitializer(ABC):
         pipeline_parallel_size: int,
         tensor_parallel_size: int,
         zero1_parallel_size: int,
+        expert_parallel_size: int,
     ):
         self.rank = rank
         self.world_size = world_size
@@ -59,6 +64,7 @@ class ProcessGroupInitializer(ABC):
         self.pipeline_parallel_size = pipeline_parallel_size
         self.tensor_parallel_size = tensor_parallel_size
         self.zero1_parallel_size = zero1_parallel_size
+        self.expert_parallel_size = expert_parallel_size
         super().__init__()

     @abstractmethod
@@ -76,6 +82,7 @@ class Initializer_Data(ProcessGroupInitializer):
         pipeline_parallel_size (int): Size of pipeline parallel.
         tensor_parallel_size (int): Size of tensor parallel.
         zero1_parallel_size (int): Size of zero1 parallel.
+        expert_parallel_size (int): Size of expert parallel.
     """

     def __init__(self, *args, **kwargs):
@@ -127,6 +134,7 @@ class Initializer_Model(ProcessGroupInitializer):
         pipeline_parallel_size (int): Size of pipeline parallel.
         tensor_parallel_size (int): Size of tensor parallel.
         zero1_parallel_size (int): Size of zero1 parallel.
+        expert_parallel_size (int): Size of expert parallel.
     """

     def __init__(self, *args, **kwargs):
@@ -178,6 +186,7 @@ class Initializer_Pipeline(ProcessGroupInitializer):
         pipeline_parallel_size (int): Size of pipeline parallel.
         tensor_parallel_size (int): Size of tensor parallel.
         zero1_parallel_size (int): Size of zero1 parallel.
+        expert_parallel_size (int): Size of expert parallel.
     """

     def __init__(self, *args, **kwargs):
@@ -238,6 +247,7 @@ class Initializer_Tensor(ProcessGroupInitializer):
         pipeline_parallel_size (int): Size of pipeline parallel.
         tensor_parallel_size (int): Size of tensor parallel.
         zero1_parallel_size (int): Size of zero1 parallel.
+        expert_parallel_size (int): Size of expert parallel.
     """

     def __init__(self, *args, **kwargs):
@@ -288,6 +298,7 @@ class Initializer_Zero1(ProcessGroupInitializer):
         pipeline_parallel_size (int): Size of pipeline parallel.
         tensor_parallel_size (int): Size of tensor parallel.
         zero1_parallel_size (int): Size of zero-1 parallel.
+        expert_parallel_size (int): Size of expert parallel.
     """

     def __init__(self, *args, **kwargs):
@@ -332,3 +343,59 @@ class Initializer_Zero1(ProcessGroupInitializer):
             ranks_in_group = ranks

         return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
+
+
+class Initializer_Expert(ProcessGroupInitializer):
+    """A ProcessGroupInitializer for expert parallelism.
+
+    Args:
+        rank (int): The rank of current process.
+        world_size (int): Size of whole communication world.
+        data_parallel_size (int): Size of data parallel.
+        pipeline_parallel_size (int): Size of pipeline parallel.
+        tensor_parallel_size (int): Size of tensor parallel.
+        zero1_parallel_size (int): Size of zero-1 parallel.
+        expert_parallel_size (int): Size of expert parallel.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.num_expert_parallel_group = self.world_size // self.expert_parallel_size
+
+        assert self.world_size % self.num_expert_parallel_group == 0
+
+        # TODO: to match expert parallel with different data parallel size
+        assert self.data_parallel_size == self.expert_parallel_size
+
+    def init_dist_group(self, use_cpu: bool = False):
+        """Initialize expert parallel groups, and assign local_ranks and groups to each gpu.
+
+        Example: world_size = 8, model_parallel_size = 2, expert_parallel_size = 4
+            model_parallel_group = [0,1], [2,3], [4,5], [6,7]
+            expert_parallel_group = [0,2,4,6], [1,3,5,7]
+
+        Returns:
+            Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
+                An expert parallelism's information tuple.
+        """
+        local_rank = None
+        ranks_in_group = None
+        process_group = None
+        cpu_group = None
+        group_world_size = None
+        mode = ParallelMode.EXPERT
+
+        for i in range(self.num_expert_parallel_group):
+            ranks = list(range(i, self.world_size, self.num_expert_parallel_group))
+            group = dist.new_group(ranks)
+            if use_cpu:
+                group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
+            else:
+                group_cpu = None
+            if self.rank in ranks:
+                local_rank = ranks.index(self.rank)
+                group_world_size = len(ranks)
+                process_group = group
+                cpu_group = group_cpu
+                ranks_in_group = ranks
+
+        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
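A standalone illustration of how Initializer_Expert strides ranks into groups (plain Python, values taken from the docstring example; not part of the patch):

# Sketch of the rank grouping used by Initializer_Expert.init_dist_group().
world_size = 8
expert_parallel_size = 4
num_expert_parallel_group = world_size // expert_parallel_size  # 2

groups = [list(range(i, world_size, num_expert_parallel_group))
          for i in range(num_expert_parallel_group)]
print(groups)  # [[0, 2, 4, 6], [1, 3, 5, 7]] -- matches the docstring example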
@@ -88,6 +88,7 @@ class NonPipelineScheduler(BaseScheduler):
         forward_only: bool = False,
         return_loss: bool = True,
         scale_loss: int = 1,
+        moe_loss_coeff: float = 1.0,
     ):
         """Trains one batch of data.

@@ -104,7 +105,7 @@ class NonPipelineScheduler(BaseScheduler):
         # forward
         with conditional_context(torch.no_grad(), enable=forward_only):
             self._call_hooks("before_forward", data)
-            output = self._call_engine(engine, data)
+            output, moe_losses = self._call_engine(engine, data)
             self._call_hooks("after_forward", output)

             self._call_hooks("post_helper_func", output, label)
@@ -113,7 +114,9 @@ class NonPipelineScheduler(BaseScheduler):
                 self._call_hooks("before_criterion", output, label)
                 loss = self._call_engine_criterion(engine, output, label)
                 self._call_hooks("after_criterion", loss)
-                loss /= scale_loss
+                moe_loss = sum(moe_losses) * moe_loss_coeff
+                loss += moe_loss
+                loss /= scale_loss  ## TODO: check whether moe_loss should be scaled

         # backward
         if not forward_only:
@@ -133,6 +136,7 @@ class NonPipelineScheduler(BaseScheduler):
         forward_only: bool = False,
         return_loss: bool = True,
         return_output_label: bool = True,
+        moe_loss_coeff: float = 1.0,
     ):
         """The process function that loads a batch of dataset and feeds it to the model.
         The returned labels and loss will be None if :attr:`return_loss` is False.
@@ -177,7 +181,7 @@ class NonPipelineScheduler(BaseScheduler):
             _data, _label = self._load_accum_batch(data, label)

             _output, _loss = self._train_one_batch(
-                _data, _label, engine, forward_only, return_loss, self._grad_accum_size
+                _data, _label, engine, forward_only, return_loss, self._grad_accum_size, moe_loss_coeff
             )

             if return_loss:
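The gist of the loss change, as a standalone sketch with dummy tensors (illustrative numbers, not the real engine/criterion objects):

import torch

# How the scheduler now folds the MoE auxiliary losses into the training loss.
ce_loss = torch.tensor(2.0)                            # criterion output
moe_losses = [torch.tensor(0.1), torch.tensor(0.3)]    # one l_aux per MoE block
moe_loss_coeff = 1.0
grad_accum_size = 4                                    # plays the role of scale_loss

moe_loss = sum(moe_losses) * moe_loss_coeff
loss = (ce_loss + moe_loss) / grad_accum_size
print(loss)  # tensor(0.6000)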
@@ -18,6 +18,7 @@ from internlm.model.linear import (
     RewardModelLinear,
     ScaleColumnParallelLinear,
 )
+from internlm.model.moe import MoE
 from internlm.model.multi_head_attention import MHA
 from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm
 from internlm.solver.pipeline_utils import partition_uniform
@@ -49,6 +50,7 @@ class PackedFlashBaseLayer1D(nn.Module):
         device (Optional[Union[str, torch.device]]): The device will be used.
         norm_type (str): Use RMS norm or layernorm. "rmsnorm" by default.
         use_flash_attn (bool): Whether use flash-attn. True by default.
+        num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
     """

     def __init__(
@@ -69,6 +71,7 @@ class PackedFlashBaseLayer1D(nn.Module):
         use_scaled_init: bool = True,
         use_swiglu: bool = True,
         use_flash_attn: bool = True,
+        num_experts: int = 1,
     ):
         super().__init__()
         self.checkpoint = checkpoint
@@ -101,37 +104,57 @@ class PackedFlashBaseLayer1D(nn.Module):
             self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
             self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)

-        if use_swiglu:
-            self.mlp = FeedForward(
-                hidden_size,
-                int(hidden_size * mlp_ratio),
-                out_features=hidden_size,
-                process_group=gpc.get_group(ParallelMode.TENSOR),
-                bias=False,
-                device=device,
-                dtype=dtype,
-            )
-        else:
-            self.mlp = ParallelFusedMLP(
-                hidden_size,
-                int(hidden_size * mlp_ratio),
-                out_features=hidden_size,
-                activation="gelu_approx",
-                process_group=gpc.get_group(ParallelMode.TENSOR),
-                bias1=False,
-                bias2=False,
-                sequence_parallel=gpc.config.model.sequence_parallel,
-                checkpoint_lvl=0,
-                heuristic="auto",
-                device=device,
-                dtype=dtype,
-            )
+        ## TODO: replace num_experts and epsize with function parameter
+        self.num_experts = num_experts
+        ep_size = gpc.get_world_size(ParallelMode.EXPERT)
+        if num_experts <= 1:  # dense, not MoE
+            if use_swiglu:
+                self.mlp = FeedForward(
+                    hidden_size,
+                    int(hidden_size * mlp_ratio),
+                    out_features=hidden_size,
+                    process_group=gpc.get_group(ParallelMode.TENSOR),
+                    bias=False,
+                    device=device,
+                    dtype=dtype,
+                )
+            else:
+                self.mlp = ParallelFusedMLP(
+                    hidden_size,
+                    int(hidden_size * mlp_ratio),
+                    out_features=hidden_size,
+                    activation="gelu_approx",
+                    process_group=gpc.get_group(ParallelMode.TENSOR),
+                    bias1=False,
+                    bias2=False,
+                    sequence_parallel=gpc.config.model.sequence_parallel,
+                    checkpoint_lvl=0,
+                    heuristic="auto",
+                    device=device,
+                    dtype=dtype,
+                )
+        else:
+            expert = torch.nn.ModuleList([FeedForward(
+                hidden_size,
+                int(hidden_size * gpc.config.model.mlp_ratio),
+                out_features=hidden_size,
+                process_group=gpc.get_group(ParallelMode.TENSOR),
+                bias=False,
+                device=torch.device("cuda"),
+                dtype=torch.float,
+            ) for i in range(num_experts // ep_size)])
+            # TODO: test moe for now, need more parameter such as: capacity_factor, eval_capacity_factor, min_capacity, drop_tokens
+            self.mlp = MoE(hidden_size=hidden_size,
+                           expert=expert,
+                           ep_size=ep_size,
+                           num_experts=num_experts,
+                           k=1)
         self.dropout2 = nn.Dropout(drop_rate)
         self.use_swiglu = use_swiglu
         self.use_scaled_init = use_scaled_init
         self.residual_in_fp32 = residual_in_fp32  # only make sense when using prenorm
         self.return_residual = False
-        self.reset_parameters()
+        self.reset_parameters()  ## TODO: check this should be changed when moe is added

     def reset_parameters(self):
         with torch.no_grad():
@@ -163,7 +186,7 @@ class PackedFlashBaseLayer1D(nn.Module):
         if self.checkpoint and self.training:
             return activation_checkpoint(
                 self._forward, False, hidden_states, cu_seqlens, indexes, inference_params, max_seqlen
-            )
+            )  ## TODO: check whether this will be affected by moe
         else:
             return self._forward(hidden_states, cu_seqlens, indexes, inference_params, max_seqlen)

@@ -213,9 +236,14 @@ class PackedFlashBaseLayer1D(nn.Module):
         if self.residual_in_fp32:
             residual = residual.to(torch.float32)

-        hidden_states = self.mlp(hidden_states)
+        # MLP.
+        moe_loss = torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype)
+        if self.num_experts <= 1:
+            hidden_states = self.mlp(hidden_states)
+        else:
+            hidden_states, moe_loss, _ = self.mlp(hidden_states)

-        return hidden_states + residual
+        return hidden_states + residual, moe_loss


 class PackedFlashInternLm1D(nn.Module):
@@ -246,6 +274,7 @@ class PackedFlashInternLm1D(nn.Module):
         residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
         norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
         use_flash_attn (bool): Whether to use flash-attn. True by default.
+        num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.

     """

@@ -276,6 +305,7 @@ class PackedFlashInternLm1D(nn.Module):
         use_scaled_init: bool = True,
         use_swiglu: bool = True,
         use_flash_attn: bool = True,
+        num_experts: int = 1,
     ):
         super().__init__()

@@ -327,6 +357,7 @@ class PackedFlashInternLm1D(nn.Module):
                     use_scaled_init=use_scaled_init,
                     use_swiglu=use_swiglu,
                     use_flash_attn=use_flash_attn,
+                    num_experts=num_experts,
                 )
                 for lid in range(num_layers)
             ]
@@ -374,14 +405,16 @@ class PackedFlashInternLm1D(nn.Module):
             indexes = indexes[0]
         max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None

+        moe_losses = []
         for _, block in enumerate(self.blocks):
-            hidden_states = block(
+            hidden_states, moe_loss = block(
                 hidden_states,
                 cu_seqlens=cu_seqlens,
                 indexes=indexes,
                 inference_params=inference_params,
                 max_seqlen=max_seqlen,
             )
+            moe_losses.append(moe_loss)

         if hasattr(self, "norm"):
             hidden_states = self.norm(hidden_states.float())
@@ -390,7 +423,7 @@ class PackedFlashInternLm1D(nn.Module):

         if not self.parallel_output:
             hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
-        return hidden_states
+        return hidden_states, moe_losses


 def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), **kwargs):
@@ -462,6 +495,7 @@ def build_model_with_cfg(
     use_swiglu: bool = True,
     use_flash_attn: bool = True,
     sequence_parallel: bool = False,
+    num_experts: int = 1,
 ):
     """
     Build model with config.
@@ -492,6 +526,7 @@ def build_model_with_cfg(
         use_scaled_init (bool): Whether to use scaled init. True by default.
         use_swiglu (bool): Whether to use swiglu. True by default.
         use_flash_attn (bool): Whether to use flash-attn. True by default.
+        num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.

     """

@@ -515,6 +550,7 @@ def build_model_with_cfg(
         use_scaled_init=use_scaled_init,
         use_swiglu=use_swiglu,
         use_flash_attn=use_flash_attn,
+        num_experts=num_experts,
     )

     return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
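A standalone toy version of the expert construction in the else branch above (plain torch.nn.Linear stands in for FeedForward; sizes are illustrative, no distributed setup needed):

import torch

# Each rank keeps num_experts // ep_size local experts, wrapped in a ModuleList.
hidden_size, num_experts, ep_size = 16, 8, 4
num_local_experts = num_experts // ep_size  # 2 experts live on this rank

experts = torch.nn.ModuleList(
    [torch.nn.Linear(hidden_size, hidden_size) for _ in range(num_local_experts)]
)
print(len(experts))  # 2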
@@ -0,0 +1,111 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import typing
+
+import torch
+
+from internlm.core.context import ParallelMode
+from internlm.core.context import global_context as gpc
+from internlm.moe.experts import Experts
+from internlm.moe.sharded_moe import MOELayer, TopKGate
+from internlm.utils.logger import get_logger
+
+# global llm logger
+logger = get_logger(__file__)
+
+
+def has_moe_layers(m):
+    has_moe = False
+    num_experts = 0
+
+    for _, module in m.named_modules():
+        if isinstance(module, MoE):
+            has_moe = True
+            num_experts = module.num_experts
+            break
+    return has_moe, num_experts
+
+
+def is_moe_param(param: torch.Tensor) -> bool:
+    if hasattr(param, "allreduce") and not param.allreduce:
+        return True
+    return False
+
+
+class MoE(torch.nn.Module):
+    """Initialize an MoE layer.
+
+    Arguments:
+        hidden_size (int): the hidden dimension of the model, importantly this is also the input and output
+            dimension.
+        expert (torch.nn.Module): the torch module that defines the expert (e.g., MLP, torch.linear).
+        num_experts (int, optional): default=1, the total number of experts per layer.
+        ep_size (int, optional): default=1, number of ranks in the expert parallel world or group.
+        k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
+        capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
+        eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
+        min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
+        noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'
+            or 'None'.
+        drop_tokens (bool, optional): default=True, whether to drop tokens (setting to False is equivalent to
+            infinite capacity).
+        use_rts (bool, optional): default=True, whether to use Random Token Selection.
+    """
+
+    def __init__(self,
+                 hidden_size,
+                 expert,
+                 num_experts=1,
+                 ep_size=1,
+                 k=1,
+                 capacity_factor=1.,
+                 eval_capacity_factor=1.,
+                 min_capacity=4,
+                 noisy_gate_policy: typing.Optional[str] = None,
+                 drop_tokens: bool = True,
+                 use_rts: bool = True,
+                 using_default_moe: bool = True):
+
+        super(MoE, self).__init__()
+
+        assert num_experts % ep_size == 0, \
+            f"Number of experts ({num_experts}) should be divisible by expert parallel size ({ep_size})"
+        self.ep_size = ep_size
+        self.num_experts = num_experts
+        self.num_local_experts = num_experts // self.ep_size
+
+        logger.info(
+            f"Creating MoE layer with num_experts: {num_experts} | "
+            f"num_local_experts: {self.num_local_experts} | expert_parallel_size: {self.ep_size}")
+
+        assert noisy_gate_policy is None or noisy_gate_policy in ['None', 'Jitter', 'RSample'], \
+            'Unsupported noisy_gate_policy: ' + noisy_gate_policy
+
+        experts = Experts(expert, self.num_local_experts)
+
+        if using_default_moe:
+            self.moe_layer = MOELayer(TopKGate(hidden_size, num_experts, k, capacity_factor, eval_capacity_factor,
+                                               min_capacity, noisy_gate_policy, drop_tokens, use_rts),
+                                      experts,
+                                      gpc.get_group(ParallelMode.EXPERT),
+                                      self.ep_size,
+                                      self.num_local_experts)
+
+    def forward(self, hidden_states, used_token=None):
+        """MoE forward
+
+        Arguments:
+            hidden_states (Tensor): input to the layer
+            used_token (Tensor, optional): default: None, mask only used tokens
+
+        Returns:
+            A tuple including output, gate loss, and expert count.
+
+            * output (Tensor): output of the model
+
+            * l_aux (Tensor): gate loss value
+
+            * exp_counts (int): expert count
+        """
+        output = self.moe_layer(hidden_states, used_token)
+
+        return output, self.moe_layer.l_aux, self.moe_layer.exp_counts
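A standalone illustration of the is_moe_param convention defined above: expert parameters are tagged with allreduce=False so the optimizer can skip them (plain torch, no InternLM import; the local is_moe_param mirrors the function in this file):

import torch

# Mirror of the is_moe_param check above, to show the tagging convention.
def is_moe_param(param: torch.Tensor) -> bool:
    return hasattr(param, "allreduce") and not param.allreduce

dense_w = torch.nn.Parameter(torch.zeros(4))
expert_w = torch.nn.Parameter(torch.zeros(4))
expert_w.allreduce = False  # as set on expert parameters by Experts (next file)

print(is_moe_param(dense_w), is_moe_param(expert_w))  # False True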
@@ -0,0 +1,49 @@
+"""
+The file has been adapted from the following files:
+https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py
+Git commit hash: f3943cf9109226ed3ecf2d5dbb639a11cd925555
+We retain the following license from the original files:
+"""
+
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import copy
+from typing import Union, cast
+
+import torch
+from torch.nn import Module, ModuleList
+
+
+class Experts(torch.nn.Module):
+
+    def __init__(self, experts: Union[Module, ModuleList], num_local_experts=1):
+        super(Experts, self).__init__()
+
+        # TODO: We can not deepcopy FeedForward since it contains a process_group in submodules
+        # self.experts = torch.nn.ModuleList([copy.deepcopy(expert) for i in range(num_local_experts)])
+
+        if type(experts) == ModuleList:
+            self.experts = cast(ModuleList, experts)
+        else:
+            self.experts = ModuleList([experts])
+        self.num_local_experts = num_local_experts
+
+        # TODO: revisit allreduce for moe.gate...
+        for expert in self.experts:
+            # TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group)
+            for name, param in expert.named_parameters():
+                param.allreduce = False  # must match the attribute checked by is_moe_param()
+
+    def forward(self, inputs):
+        chunks = inputs.chunk(self.num_local_experts, dim=1)
+        expert_outputs = []
+        for chunk, expert in zip(chunks, self.experts):
+            out = expert(chunk)
+            if type(out) is tuple:
+                out = out[0]  # Ignore the bias term for now
+            expert_outputs += [out]
+
+        expert_output = torch.cat(expert_outputs, dim=1)
+        return expert_output
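A standalone toy run of the chunk-per-local-expert pattern in Experts.forward (illustrative shapes, plain torch):

import torch

# Split the local-expert dimension, run each expert, re-concatenate.
ep_size, num_local_experts, capacity, d_model = 1, 2, 4, 8
dispatched = torch.randn(ep_size, num_local_experts, capacity, d_model)

experts = [torch.nn.Linear(d_model, d_model) for _ in range(num_local_experts)]
chunks = dispatched.chunk(num_local_experts, dim=1)   # one chunk per local expert
outputs = [e(c) for e, c in zip(experts, chunks)]
merged = torch.cat(outputs, dim=1)
print(merged.shape)  # torch.Size([1, 2, 4, 8])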
@@ -0,0 +1,503 @@
+"""
+The file has been adapted from the following files:
+https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py
+Git commit hash: f3943cf9109226ed3ecf2d5dbb639a11cd925555
+We retain the following license from the original files:
+"""
+
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import Module
+
+from internlm.utils.logger import get_logger
+from internlm.utils.megatron_timers import megatron_timer as timer
+
+# global llm logger
+logger = get_logger(__file__)
+
+if TYPE_CHECKING:
+    Base = Module[Tensor]
+else:
+    Base = Module
+
+uniform_map: Dict[torch.device, Callable] = {}
+gumbel_map: Dict[torch.device, Callable] = {}
+exp_selection_uniform_map: Dict[torch.device, Callable] = {}
+
+
+def multiplicative_jitter(x, device: torch.device, epsilon=1e-2):
+    """
+    Modified from the Switch Transformer paper and the mesh-tensorflow implementation.
+    Multiply values by a random number between 1-epsilon and 1+epsilon.
+    Makes models more resilient to rounding errors introduced by bfloat16.
+    This seems particularly important for logits.
+    Args:
+        x: a torch.tensor
+        device: torch.device
+        epsilon: a floating point value
+    Returns:
+        a jittered x.
+    """
+    if epsilon == 0:
+        return x
+    uniform = uniform_map.get(device)
+    if uniform is None:
+        uniform = torch.distributions.uniform.Uniform(
+            low=torch.tensor(1.0 - epsilon, device=device),
+            high=torch.tensor(1.0 + epsilon, device=device)).rsample  # type: ignore
+        uniform_map[device] = uniform
+    return x * uniform(x.shape)
+
+
+def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
+    gumbel = gumbel_map.get(device)
+    if gumbel is None:
+        one = torch.tensor(1.0, device=device)
+        zero = torch.tensor(0.0, device=device)
+        gumbel = torch.distributions.gumbel.Gumbel(zero, one).rsample  # type: ignore
+        gumbel_map[device] = gumbel
+    return gumbel(shape)
+
+
+# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
+# See https://arxiv.org/pdf/2006.16668.pdf for details.
+
+
+# Based on https://github.com/pytorch/pytorch/pull/40762
+class _AllToAll(torch.autograd.Function):
+
+    @staticmethod
+    def forward(
+            ctx: Any,
+            # TODO: replace with DS process group
+            group: torch.distributed.ProcessGroup,
+            input: Tensor) -> Tensor:  # type: ignore
+        ctx.group = group
+        input = input.contiguous()
+        output = torch.empty_like(input)
+        dist.all_to_all_single(output, input, group=group)
+        return output
+
+    @staticmethod
+    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
+        return (None, _AllToAll.apply(ctx.group, *grad_output))
+
+
+# einsum rewrites are on par or more performant
+# switch can be bubbled up in future
+USE_EINSUM = True
+
+
+# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
+# See https://arxiv.org/pdf/2006.16668.pdf for details.
+def einsum(rule, a, b):
+    if USE_EINSUM:
+        return torch.einsum(rule, a, b)
+    elif rule == 's,se->se':
+        ## [1, s] * [s, e]
+        return a.reshape(a.shape[0], -1) * b
+    elif rule == 'se,sc->sec':
+        ## [s,e,1] * [s,1,c]
+        return a.unsqueeze(2) * b.unsqueeze(1)
+    elif rule == 'se,se->s':
+        ## [s,1,e] * [s,e,1]
+        return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1)
+    elif rule == 'sec,sm->ecm':
+        ## [e*c, s] * [s, m]
+        s = a.shape[0]
+        e = a.shape[1]
+        c = a.shape[2]
+        m = b.shape[1]
+        return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m)
+    elif rule == 'sec,ecm->sm':
+        ## [s, e*c] * [e*c, m]
+        return torch.matmul(a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1]))
+    elif rule == 'ks,ksm->sm':
+        k = b.shape[0]
+        s = b.shape[1]
+        m = b.shape[2]
+        # [k, s] -> [s, k] -> [s, 1, k]
+        a = a.t().unsqueeze(1)
+        # [k,s,m] -> [k, sm] -> [sm, k] -> [s, m, k]
+        b = b.reshape(k, -1).t().reshape(s, m, k)
+        # bmm([s, 1, k], [s, m, k]^t) -> [s, m, 1]
+        return torch.bmm(a, b.transpose(1, 2)).squeeze(2)
+    else:
+        return torch.einsum(rule, a, b)
+
+
+# The following functions are extracted and scripted
+# because otherwise during a torch.jit.trace, the non-Tensor
+# values used in the calculations get recorded as constants.
+# torch.jit.script coerces them into Tensors and preserves
+# their dynamic shapes. This enables ONNX export.
+# We can't script the entire top1gating function because it
+# includes stateful caching logic which is incompatible with ONNX.
+
+
+@torch.jit.script
+def _capacity(gates: Tensor, capacity_factor: Tensor, min_capacity: Tensor) -> Tensor:
+    # gates has shape of SE
+    num_tokens = gates.shape[0]
+    num_experts = gates.shape[1]
+    # to(torch.int64) works around a bug in torch.onnx.export:
+    # it should cast k to int64 when converting torch.topk but it doesn't.
+    capacity = torch.ceil((num_tokens / num_experts) * capacity_factor).to(torch.int64)
+    if capacity < min_capacity:
+        capacity = min_capacity.to(torch.int64)
+    return capacity
+
+
+@torch.jit.script
+def _top_idx(source, k):
+    return torch.topk(source, k=k, dim=0)[1]
+
+
+@torch.jit.script
+def _one_hot_to_float(x, num_classes):
+    return F.one_hot(x, num_classes=num_classes).float()
+
+
+def top1gating(logits: Tensor,
+               capacity_factor: float,
+               min_capacity: int,
+               used_token: Tensor = None,
+               noisy_gate_policy: Optional[str] = None,
+               drop_tokens: bool = True,
+               use_rts: bool = True,
+               use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+    """Implements Top1Gating on logits."""
+    if noisy_gate_policy == 'RSample':
+        logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
+    # everything is in fp32 in this function
+    gates = F.softmax(logits, dim=1)
+
+    capacity = _capacity(gates, torch.tensor(capacity_factor), torch.tensor(min_capacity))
+
+    # Create a mask for 1st's expert per token
+    # noisy gating
+    indices1_s = torch.argmax(logits_w_noise if noisy_gate_policy == 'RSample' else gates, dim=1)
+    num_experts = int(gates.shape[1])
+    mask1 = F.one_hot(indices1_s, num_classes=num_experts)
+
+    # mask only used tokens
+    if used_token is not None:
+        mask1 = einsum("s,se->se", used_token, mask1)
+
+    # gating decisions
+    exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')
+
+    # if we don't want to drop any tokens
+    if not drop_tokens:
+        new_capacity = torch.max(exp_counts).to(logits.device)
+        # all-reduce over the default (world) group
+        dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.group.WORLD)
+        capacity = new_capacity
+
+    # Compute l_aux
+    me = torch.mean(gates, dim=0)
+    ce = torch.mean(mask1.float(), dim=0)
+    l_aux = torch.sum(me * ce) * num_experts
+
+    # Random Token Selection
+    if use_rts:
+        uniform = exp_selection_uniform_map.get(logits.device)
+        if uniform is None:
+            uniform = torch.distributions.uniform.Uniform(
+                low=torch.tensor(0.0, device=logits.device),
+                high=torch.tensor(1.0, device=logits.device)).rsample
+            exp_selection_uniform_map[logits.device] = uniform
+
+        mask1_rand = mask1 * uniform(mask1.shape)
+    else:
+        mask1_rand = mask1
+
+    assert logits.shape[0] >= min_capacity, (
+        "No. of tokens (batch-size) should be greater than min_capacity. "
+        "Either set min_capacity to 0 or increase your batch size.")
+
+    top_idx = _top_idx(mask1_rand, capacity)  # @wenwen: token index
+
+    new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)
+    mask1 = new_mask1
+
+    if use_tutel:
+        # Tutel doesn't support index values masked with zero
+        # so we need to replace masked indices with -1
+        indices_mask = mask1.sum(dim=1) * num_experts - 1
+        indices1_s = torch.min(indices1_s, indices_mask)
+
+    # Compute locations in capacity buffer
+    locations1 = torch.cumsum(mask1, dim=0) - 1
+
+    if use_tutel:
+        gates1_s = (gates * mask1).sum(dim=1)
+        locations1_s = torch.sum(locations1 * mask1, dim=1)
+        return l_aux, capacity, num_experts, [
+            indices1_s,
+        ], [
+            locations1_s,
+        ], [
+            gates1_s,
+        ], exp_counts
+
+    # Store the capacity location for each token
+    locations1_s = torch.sum(locations1 * mask1, dim=1)
+
+    # Normalize gate probabilities
+    mask1_float = mask1.float()
+    gates = gates * mask1_float
+
+    locations1_sc = _one_hot_to_float(locations1_s, capacity)
+    combine_weights = einsum("se,sc->sec", gates, locations1_sc)
+
+    dispatch_mask = combine_weights.bool()
+
+    return l_aux, combine_weights, dispatch_mask, exp_counts
+
+
+def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+    """Implements Top2Gating on logits."""
+    # everything is in fp32 in this function
+    gates = F.softmax(logits, dim=1)
+
+    capacity = _capacity(gates, torch.tensor(capacity_factor * 2), torch.tensor(min_capacity))
+
+    # Create a mask for 1st's expert per token
+    indices1_s = torch.argmax(gates, dim=1)
+    num_experts = int(gates.shape[1])
+    mask1 = F.one_hot(indices1_s, num_classes=num_experts)
+
+    # Create a mask for 2nd's expert per token using Gumbel-max trick
+    # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
+    logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
+    # Replace top-expert with min value
+    logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf"))
+    indices2_s = torch.argmax(logits_except1, dim=1)
+    mask2 = F.one_hot(indices2_s, num_classes=num_experts)
+
+    # Compute locations in capacity buffer
+    locations1 = torch.cumsum(mask1, dim=0) - 1
+    locations2 = torch.cumsum(mask2, dim=0) - 1
+    # Update 2nd's location by accounting for locations of 1st
+    locations2 += torch.sum(mask1, dim=0, keepdim=True)
+
+    # gating decisions
+    exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')
+
+    # Compute l_aux
+    me = torch.mean(gates, dim=0)
+    ce = torch.mean(mask1.float(), dim=0)
+    l_aux = torch.mean(me * ce) * num_experts * num_experts
+
+    # Remove locations outside capacity from mask
+    mask1 *= torch.lt(locations1, capacity)
+    mask2 *= torch.lt(locations2, capacity)
+
+    # Store the capacity location for each token
+    locations1_s = torch.sum(locations1 * mask1, dim=1)
+    locations2_s = torch.sum(locations2 * mask2, dim=1)
+
+    # Normalize gate probabilities
+    mask1_float = mask1.float()
+    mask2_float = mask2.float()
+    gates1_s = einsum("se,se->s", gates, mask1_float)
+    gates2_s = einsum("se,se->s", gates, mask2_float)
+    denom_s = gates1_s + gates2_s
+    # Avoid divide-by-zero
+    denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
+    gates1_s /= denom_s
+    gates2_s /= denom_s
+
+    # Calculate combine_weights and dispatch_mask
+    gates1 = einsum("s,se->se", gates1_s, mask1_float)
+    gates2 = einsum("s,se->se", gates2_s, mask2_float)
+    locations1_sc = _one_hot_to_float(locations1_s, capacity)
+    locations2_sc = _one_hot_to_float(locations2_s, capacity)
+    combine1_sec = einsum("se,sc->sec", gates1, locations1_sc)
+    combine2_sec = einsum("se,sc->sec", gates2, locations2_sc)
+    combine_weights = combine1_sec + combine2_sec
+    dispatch_mask = combine_weights.bool()
+
+    return l_aux, combine_weights, dispatch_mask, exp_counts
+
+
+class TopKGate(Module):
+    """Gate module which implements Top2Gating as described in Gshard_.
+    ::
+
+        gate = TopKGate(model_dim, num_experts)
+        l_aux, combine_weights, dispatch_mask = gate(input)
+
+    .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf
+
+    Args:
+        model_dim (int):
+            size of model embedding dimension
+        num_experts (int):
+            number of experts in model
+    """
+
+    wg: torch.nn.Linear
+
+    def __init__(self,
+                 model_dim: int,
+                 num_experts: int,
+                 k: int = 1,
+                 capacity_factor: float = 1.0,
+                 eval_capacity_factor: float = 1.0,
+                 min_capacity: int = 8,
+                 noisy_gate_policy: Optional[str] = None,
+                 drop_tokens: bool = True,
+                 use_rts: bool = True) -> None:
+        super().__init__()
+
+        # Only top-1 and top-2 are supported at the moment.
+        if k != 1 and k != 2:
+            raise ValueError('Only top-1 and top-2 gatings are supported.')
+        self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
+        self.k = k
+        self.capacity_factor = capacity_factor
+        self.eval_capacity_factor = eval_capacity_factor
+        self.min_capacity = min_capacity
+        self.noisy_gate_policy = noisy_gate_policy
+        self.wall_clock_breakdown = False
+        self.gate_time = 0.0
+        self.drop_tokens = drop_tokens
+        self.use_rts = use_rts
+
+    def forward(self,
+                input: torch.Tensor,
+                used_token: torch.Tensor = None,
+                use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor]:  # type: ignore
+
+        if self.wall_clock_breakdown:
+            timer('TopKGate').start()
+
+        if self.wg.weight.dtype != torch.float32:
+            self.wg = self.wg.float()
+        input_fp32 = input.float()
+        # input jittering
+        if self.noisy_gate_policy == 'Jitter' and self.training:
+            input_fp32 = multiplicative_jitter(input_fp32, device=input.device)
+        logits = self.wg(input_fp32)
+
+        if self.k == 1:
+            gate_output = top1gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
+                                     self.min_capacity, used_token, self.noisy_gate_policy if self.training else None,
+                                     self.drop_tokens, self.use_rts, use_tutel)
+
+        else:
+            gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
+                                     self.min_capacity)
+
+        if self.wall_clock_breakdown:
+            timer('TopKGate').stop()
+            self.gate_time = timer('TopKGate').elapsed(reset=False)
+
+        return gate_output
+
+
+class MOELayer(Base):
+    """MOELayer module which implements MixtureOfExperts as described in Gshard_.
+    ::
+
+        gate = TopKGate(model_dim, num_experts)
+        moe = MOELayer(gate, expert)
+        output = moe(input)
+        l_aux = moe.l_aux
+
+    .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf
+
+    Args:
+        gate (torch.nn.Module):
+            gate network
+        expert (torch.nn.Module):
+            expert network
+    """
+
+    def __init__(self,
+                 gate: Module,
+                 experts: Module,
+                 ep_group,
+                 ep_size,
+                 num_local_experts: int) -> None:
+        super().__init__()
+        self.gate = gate
+        self.experts = experts
+        self.ep_group = ep_group
+        self.ep_size = ep_size
+        self.num_local_experts = num_local_experts
+        self.time_falltoall = 0.0
+        self.time_salltoall = 0.0
+        self.time_moe = 0.0
+        self.wall_clock_breakdown = False
+
+    def _set_ep_group(self, ep_group):
+        self.ep_group = ep_group
+
+    def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
+
+        if self.wall_clock_breakdown:
+            timer('moe').start()
+
+        # Implement Algorithm 2 from GShard paper.
+        d_model = input[0].shape[-1]
+
+        # Initial implementation -> Reshape into S tokens by dropping sequence dimension.
+        # Reshape into G groups so that each group can distribute tokens equally
+        # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
+        reshaped_input = input[0].reshape(-1, d_model)
+
+        self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input, input[1])
+        dispatched_input = einsum(
+            "sec,sm->ecm", dispatch_mask.type_as(input[0]), reshaped_input
+        )  ## TODO: heavy memory usage due to long sequence length
+
+        if self.wall_clock_breakdown:
+            timer('falltoall').start()
+
+        dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)
+
+        if self.wall_clock_breakdown:
+            timer('falltoall').stop()
+            self.time_falltoall = timer('falltoall').elapsed(reset=False)
+
+        # Re-shape after all-to-all: ecm -> gecm
+        dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model)
+
+        expert_output = self.experts(dispatched_input)
+
+        if self.wall_clock_breakdown:
+            timer('salltoall').start()
+
+        expert_output = _AllToAll.apply(self.ep_group, expert_output)
+
+        if self.wall_clock_breakdown:
+            timer('salltoall').stop()
+            self.time_salltoall = timer('salltoall').elapsed(reset=False)
+
+        # Re-shape back: gecm -> ecm
+        expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)
+
+        combined_output = einsum("sec,ecm->sm", combine_weights.type_as(input[0]), expert_output)
+
+        a = combined_output.reshape(input[0].shape)
+
+        if self.wall_clock_breakdown:
+            timer('moe').stop()
+            self.time_moe = timer('moe').elapsed(reset=False)
+
+        return a
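To make the top-1 gating math above concrete, a standalone sketch with tiny tensors (pure torch, no distributed setup, no token dropping or RTS; numbers are illustrative):

import math
import torch
import torch.nn.functional as F

# Core top-1 gating quantities: capacity, per-expert token counts, load-balancing loss.
torch.manual_seed(0)
num_tokens, num_experts, capacity_factor = 8, 4, 1.0

logits = torch.randn(num_tokens, num_experts)
gates = F.softmax(logits, dim=1)

# capacity = ceil(tokens / experts * capacity_factor), as in _capacity()
capacity = math.ceil(num_tokens / num_experts * capacity_factor)  # 2 slots per expert

# top-1 expert per token and the auxiliary loss, as in top1gating()
mask1 = F.one_hot(torch.argmax(gates, dim=1), num_classes=num_experts)
me = gates.mean(dim=0)            # mean gate probability per expert
ce = mask1.float().mean(dim=0)    # fraction of tokens routed to each expert
l_aux = torch.sum(me * ce) * num_experts

print(capacity, mask1.sum(dim=0), l_aux)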
@@ -25,6 +25,7 @@ from internlm.solver.optimizer.utils import (
     split_half_float_double,
     sync_param,
 )
+from internlm.model.moe import is_moe_param
 from internlm.utils.common import get_current_device
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
@@ -285,7 +286,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         for group_id in range(self.num_param_groups):
             param_group = self._fp16_param_groups[group_id]
             for param in param_group:
-                if param.requires_grad:
+                if param.requires_grad and not is_moe_param(param):
                     reduce_rank = None

                     def _define_and_attach(param, reduce_rank=None):
train.py
@@ -580,7 +580,7 @@ def main(args):

         # do forward and backward
         timer("fwd-bwd").start()
-        _, _, loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False)
+        _, _, loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False, moe_loss_coeff=gpc.config.loss.moe_loss_coeff)
        timer("fwd-bwd").stop()

         # update parameters, and returns (success_update, grad_norm)