diff --git a/configs/7B_sft.py b/configs/7B_sft.py index 30655de..1e190d1 100644 --- a/configs/7B_sft.py +++ b/configs/7B_sft.py @@ -77,6 +77,7 @@ hybrid_zero_optimizer = dict( loss = dict( label_smoothing=0, + moe_loss_coeff=1.0, ) adam = dict( @@ -119,6 +120,7 @@ model = dict( use_flash_attn=True, num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. sequence_parallel=False, + num_experts=8, ) """ zero1 parallel: diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py index 87d3114..ceea60c 100644 --- a/internlm/core/context/parallel_context.py +++ b/internlm/core/context/parallel_context.py @@ -143,6 +143,7 @@ class ParallelContext(metaclass=SingletonMeta): self.pipeline_parallel_size = 1 self.tensor_parallel_size = 1 self.zero1_parallel_size = -1 + self.expert_parallel_size = -1 self.num_processes_on_current_node = -1 self.virtual_pipeline_parallel_size = None self.virtual_pipeline_parallel_rank = None @@ -442,6 +443,9 @@ class ParallelContext(metaclass=SingletonMeta): # instead, it should be calculated based on other parallel config self.data_parallel_size = self.world_size // (self.pipeline_parallel_size * self.tensor_parallel_size) + # TODO : data parallel size can be different with expert parallel size + self.expert_parallel_size = self.data_parallel_size + if self.zero1_parallel_size <= 0: self.zero1_parallel_size = self.data_parallel_size @@ -454,6 +458,7 @@ class ParallelContext(metaclass=SingletonMeta): self.pipeline_parallel_size, self.tensor_parallel_size, self.zero1_parallel_size, + self.expert_parallel_size, ] # run initialization of different process groups @@ -464,6 +469,8 @@ class ParallelContext(metaclass=SingletonMeta): initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args)) if self.pipeline_parallel_size > 1: initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args)) + if self.config.model.num_experts > 1: + initializers.append(pgroup_initializer.Initializer_Expert(*initializer_args)) for initializer in initializers: parallel_setting = initializer.init_dist_group() if isinstance(parallel_setting, list): diff --git a/internlm/core/context/process_group_initializer.py b/internlm/core/context/process_group_initializer.py index 56cf16d..9cc7a7d 100644 --- a/internlm/core/context/process_group_initializer.py +++ b/internlm/core/context/process_group_initializer.py @@ -31,6 +31,9 @@ class ParallelMode(Enum): # zero1 parallel ZERO1 = "zero1" + # expert parallel + EXPERT = "expert" + class ProcessGroupInitializer(ABC): """An object, knowing the parallelism configuration, that initializes parallel groups. @@ -42,6 +45,7 @@ class ProcessGroupInitializer(ABC): pipeline_parallel_size (int): Size of pipeline parallel. tensor_parallel_size (int): Size of tensor parallel. zero1_parallel_size (int): Size of zero1 parallel. + expert_parallel_size (int): Size of expert parallel. 
""" def __init__( @@ -52,6 +56,7 @@ class ProcessGroupInitializer(ABC): pipeline_parallel_size: int, tensor_parallel_size: int, zero1_parallel_size: int, + expert_parallel_size: int, ): self.rank = rank self.world_size = world_size @@ -59,6 +64,7 @@ class ProcessGroupInitializer(ABC): self.pipeline_parallel_size = pipeline_parallel_size self.tensor_parallel_size = tensor_parallel_size self.zero1_parallel_size = zero1_parallel_size + self.expert_parallel_size = expert_parallel_size super().__init__() @abstractmethod @@ -76,6 +82,7 @@ class Initializer_Data(ProcessGroupInitializer): pipeline_parallel_size (int): Size of pipeline parallel. tensor_parallel_size (int): Size of tensor parallel. zero1_parallel_size (int): Size of zero1 parallel. + expert_parallel_size (int): Size of expert parallel. """ def __init__(self, *args, **kwargs): @@ -127,6 +134,7 @@ class Initializer_Model(ProcessGroupInitializer): pipeline_parallel_size (int): Size of pipeline parallel. tensor_parallel_size (int): Size of tensor parallel. zero1_parallel_size (int): Size of zero1 parallel. + expert_parallel_size (int): Size of expert parallel. """ def __init__(self, *args, **kwargs): @@ -178,6 +186,7 @@ class Initializer_Pipeline(ProcessGroupInitializer): pipeline_parallel_size (int): Size of pipeline parallel tensor_parallel_size (int): Size of tensor parallel zero1_parallel_size (int): Size of zero1 parallel. + expert_parallel_size (int): Size of expert parallel. """ def __init__(self, *args, **kwargs): @@ -238,6 +247,7 @@ class Initializer_Tensor(ProcessGroupInitializer): pipeline_parallel_size (int): Size of pipeline parallel. tensor_parallel_size (int): Size of tensor parallel. zero1_parallel_size (int): Size of zero1 parallel. + expert_parallel_size (int): Size of expert parallel. """ def __init__(self, *args, **kwargs): @@ -288,6 +298,7 @@ class Initializer_Zero1(ProcessGroupInitializer): pipeline_parallel_size (int): Size of pipeline parallel. tensor_parallel_size (int): Size of tensor parallel. zero1_parallel_size (int): Size of zero-1 parallel. + expert_parallel_size (int): Size of expert parallel. """ def __init__(self, *args, **kwargs): @@ -332,3 +343,59 @@ class Initializer_Zero1(ProcessGroupInitializer): ranks_in_group = ranks return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode + +class Initializer_Expert(ProcessGroupInitializer): + """A ProcessGroupInitializer for zero-1 parallelism. + + Args: + rank (int): The rank of current process. + world_size (int): Size of whole communication world. + data_parallel_size (int): Size of data parallel. + pipeline_parallel_size (int): Size of pipeline parallel. + tensor_parallel_size (int): Size of tensor parallel. + zero1_parallel_size (int): Size of zero-1 parallel. + expert_parallel_size (int): Size of expert parallel. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.num_expert_parallel_group = self.world_size // self.expert_parallel_size + + assert self.world_size % self.num_expert_parallel_group == 0 + + # TODO: to match expert parallel with differnt data parallel size + assert self.data_parallel_size == self.expert_parallel_size + + def init_dist_group(self, use_cpu: bool = False): + """Initialize expert parallel groups, and assign local_ranks and groups to each gpu. 
+ + Example: world_size = 8, model_parallel_size = 2, expert_parallel_size = 4 + model_parallel_group = [0,1], [2,3], [4,5], [6,7] + expert_parallel_group = [0,2,4,6], [1,3,5,7] + + Returns: + Tuple (local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode): + An expert parallelism's information tuple. + """ + local_rank = None + ranks_in_group = None + process_group = None + cpu_group = None + group_world_size = None + mode = ParallelMode.EXPERT + + for i in range(self.num_expert_parallel_group): + ranks = list(range(i, self.world_size, self.num_expert_parallel_group)) + group = dist.new_group(ranks) + if use_cpu: + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group + else: + group_cpu = None + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py index 2633a9c..e2084d5 100644 --- a/internlm/core/scheduler/no_pipeline_scheduler.py +++ b/internlm/core/scheduler/no_pipeline_scheduler.py @@ -88,6 +88,7 @@ class NonPipelineScheduler(BaseScheduler): forward_only: bool = False, return_loss: bool = True, scale_loss: int = 1, + moe_loss_coeff: float = 1.0, ): """Trains one batch of data. @@ -104,7 +105,7 @@ class NonPipelineScheduler(BaseScheduler): # forward with conditional_context(torch.no_grad(), enable=forward_only): self._call_hooks("before_forward", data) - output = self._call_engine(engine, data) + output, moe_losses = self._call_engine(engine, data) self._call_hooks("after_forward", output) self._call_hooks("post_helper_func", output, label) @@ -113,7 +114,9 @@ class NonPipelineScheduler(BaseScheduler): self._call_hooks("before_criterion", output, label) loss = self._call_engine_criterion(engine, output, label) self._call_hooks("after_criterion", loss) - loss /= scale_loss + moe_loss = sum(moe_losses) * moe_loss_coeff + loss += moe_loss + loss /= scale_loss ## TODO: check whether moe_loss should be scaled # backward if not forward_only: @@ -133,6 +136,7 @@ class NonPipelineScheduler(BaseScheduler): forward_only: bool = False, return_loss: bool = True, return_output_label: bool = True, + moe_loss_coeff: float = 1.0, ): """The process function that loads a batch of dataset and feeds it to the model. The returned labels and loss will be None if :attr:`return_loss` is False.
@@ -177,7 +181,7 @@ class NonPipelineScheduler(BaseScheduler): _data, _label = self._load_accum_batch(data, label) _output, _loss = self._train_one_batch( - _data, _label, engine, forward_only, return_loss, self._grad_accum_size + _data, _label, engine, forward_only, return_loss, self._grad_accum_size, moe_loss_coeff ) if return_loss: diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py index 31138fa..b57c8f0 100644 --- a/internlm/model/modeling_internlm.py +++ b/internlm/model/modeling_internlm.py @@ -18,6 +18,7 @@ from internlm.model.linear import ( RewardModelLinear, ScaleColumnParallelLinear, ) +from internlm.model.moe import MoE from internlm.model.multi_head_attention import MHA from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm from internlm.solver.pipeline_utils import partition_uniform @@ -49,6 +50,7 @@ class PackedFlashBaseLayer1D(nn.Module): device (Optional[Union[str, torch.device]]): The device will be used. norm_type (str): Use RMS norm or layernorm."rmsnorm" by default. use_flash_attn (bool): Whether use flash-attn. True by default. + num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default. """ def __init__( @@ -69,6 +71,7 @@ class PackedFlashBaseLayer1D(nn.Module): use_scaled_init: bool = True, use_swiglu: bool = True, use_flash_attn: bool = True, + num_experts: int = 1, ): super().__init__() self.checkpoint = checkpoint @@ -101,37 +104,57 @@ class PackedFlashBaseLayer1D(nn.Module): self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) - if use_swiglu: - self.mlp = FeedForward( + ## TODO: replace num_experts and epsize with function parameter + self.num_experts = num_experts + ep_size = gpc.get_world_size(ParallelMode.EXPERT) + if num_experts <= 1: # dense, not MoE + if use_swiglu: + self.mlp = FeedForward( + hidden_size, + int(hidden_size * mlp_ratio), + out_features=hidden_size, + process_group=gpc.get_group(ParallelMode.TENSOR), + bias=False, + device=device, + dtype=dtype, + ) + else: + self.mlp = ParallelFusedMLP( + hidden_size, + int(hidden_size * mlp_ratio), + out_features=hidden_size, + activation="gelu_approx", + process_group=gpc.get_group(ParallelMode.TENSOR), + bias1=False, + bias2=False, + sequence_parallel=gpc.config.model.sequence_parallel, + checkpoint_lvl=0, + heuristic="auto", + device=device, + dtype=dtype, + ) + else: + expert = torch.nn.ModuleList([FeedForward( hidden_size, - int(hidden_size * mlp_ratio), + int(hidden_size * gpc.config.model.mlp_ratio), out_features=hidden_size, process_group=gpc.get_group(ParallelMode.TENSOR), bias=False, - device=device, - dtype=dtype, - ) - else: - self.mlp = ParallelFusedMLP( - hidden_size, - int(hidden_size * mlp_ratio), - out_features=hidden_size, - activation="gelu_approx", - process_group=gpc.get_group(ParallelMode.TENSOR), - bias1=False, - bias2=False, - sequence_parallel=gpc.config.model.sequence_parallel, - checkpoint_lvl=0, - heuristic="auto", - device=device, - dtype=dtype, - ) + device=torch.device("cuda"), + dtype=torch.float, + ) for i in range(num_experts // ep_size)]) + # TODO: test moe for now, need more parameter such as: capacity_factor, eval_capacity_factor, min_capacity, drop_tokens + self.mlp = MoE(hidden_size=hidden_size, + expert=expert, + ep_size=ep_size, + num_experts=num_experts, + k=1) self.dropout2 = nn.Dropout(drop_rate) self.use_swiglu = use_swiglu self.use_scaled_init = use_scaled_init self.residual_in_fp32 = 
residual_in_fp32 # only make sense when using prenorm self.return_residual = False - self.reset_parameters() + self.reset_parameters() ## TODO: check whether this should be changed when moe is added def reset_parameters(self): with torch.no_grad(): @@ -163,7 +186,7 @@ class PackedFlashBaseLayer1D(nn.Module): if self.checkpoint and self.training: return activation_checkpoint( self._forward, False, hidden_states, cu_seqlens, indexes, inference_params, max_seqlen - ) + ) ## TODO: check whether this will be affected by moe else: return self._forward(hidden_states, cu_seqlens, indexes, inference_params, max_seqlen) @@ -213,9 +236,14 @@ class PackedFlashBaseLayer1D(nn.Module): if self.residual_in_fp32: residual = residual.to(torch.float32) - hidden_states = self.mlp(hidden_states) + # MLP. + moe_loss = torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype) + if self.num_experts <= 1: + hidden_states = self.mlp(hidden_states) + else: + hidden_states, moe_loss, _ = self.mlp(hidden_states) - return hidden_states + residual + return hidden_states + residual, moe_loss class PackedFlashInternLm1D(nn.Module): @@ -246,6 +274,7 @@ class PackedFlashInternLm1D(nn.Module): residual_in_fp32 (bool): Whether to use residual in fp32. False by default. norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. use_flash_attn (bool): Whether to use flash-attn. True by default. + num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default. """ @@ -276,6 +305,7 @@ class PackedFlashInternLm1D(nn.Module): use_scaled_init: bool = True, use_swiglu: bool = True, use_flash_attn: bool = True, + num_experts: int = 1, ): super().__init__() @@ -327,6 +357,7 @@ class PackedFlashInternLm1D(nn.Module): use_scaled_init=use_scaled_init, use_swiglu=use_swiglu, use_flash_attn=use_flash_attn, + num_experts=num_experts, ) for lid in range(num_layers) ] @@ -374,14 +405,16 @@ class PackedFlashInternLm1D(nn.Module): indexes = indexes[0] max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None + moe_losses = [] for _, block in enumerate(self.blocks): - hidden_states = block( + hidden_states, moe_loss = block( hidden_states, cu_seqlens=cu_seqlens, indexes=indexes, inference_params=inference_params, max_seqlen=max_seqlen, ) + moe_losses.append(moe_loss) if hasattr(self, "norm"): hidden_states = self.norm(hidden_states.float()) @@ -390,7 +423,7 @@ class PackedFlashInternLm1D(nn.Module): if not self.parallel_output: hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1) - return hidden_states + return hidden_states, moe_losses def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), **kwargs): @@ -462,6 +495,7 @@ def build_model_with_cfg( use_swiglu: bool = True, use_flash_attn: bool = True, sequence_parallel: bool = False, + num_experts: int = 1, ): """ Build model with config @@ -492,6 +526,7 @@ def build_model_with_cfg( use_scaled_init (bool): Whether to use scaled init. True by default. use_swiglu (bool): Whether to use swiglu. True by default. use_flash_attn (bool): Whether to use flash-attn. True by default. + num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
""" @@ -515,6 +550,7 @@ def build_model_with_cfg( use_scaled_init=use_scaled_init, use_swiglu=use_swiglu, use_flash_attn=use_flash_attn, + num_experts=num_experts, ) return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg) diff --git a/internlm/model/moe.py b/internlm/model/moe.py new file mode 100644 index 0000000..3bb35bf --- /dev/null +++ b/internlm/model/moe.py @@ -0,0 +1,111 @@ + +from internlm.moe.sharded_moe import MOELayer, TopKGate +from internlm.moe.experts import Experts +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc + +from internlm.utils.logger import get_logger + +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import typing + +# global llm logger +logger = get_logger(__file__) + + +def has_moe_layers(m): + has_moe = False + num_experts = 0 + + for _, module in m.named_modules(): + if isinstance(module, MoE): + has_moe = True + num_experts = module.num_experts + break + return has_moe, num_experts + + +def is_moe_param(param: torch.Tensor) -> bool: + if hasattr(param, "allreduce") and not param.allreduce: + return True + return False + +class MoE(torch.nn.Module): + """Initialize an MoE layer. + + Arguments: + hidden_size (int): the hidden dimension of the model, importantly this is also the input and output dimension. + expert (torch.nn.Module): the torch module that defines the expert (e.g., MLP, torch.linear). + num_experts (int, optional): default=1, the total number of experts per layer. + ep_size (int, optional): default=1, number of ranks in the expert parallel world or group. + k (int, optional): default=1, top-k gating value, only supports k=1 or k=2. + capacity_factor (float, optional): default=1.0, the capacity of the expert at training time. + eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time. + min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor. + noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample' or 'None'. + drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to infinite capacity). + use_rts (bool, optional): default=True, whether to use Random Token Selection. 
+ """ + + def __init__(self, + hidden_size, + expert, + num_experts=1, + ep_size=1, + k=1, + capacity_factor=1., + eval_capacity_factor=1., + min_capacity=4, + noisy_gate_policy: typing.Optional[str] = None, + drop_tokens: bool = True, + use_rts: bool = True, + using_default_moe: bool = True): + + super(MoE, self).__init__() + + assert num_experts % ep_size == 0, f"Number of experts ({num_experts}) should be divisible by expert parallel size ({ep_size})" + self.ep_size = ep_size + self.num_experts = num_experts + self.num_local_experts = num_experts // self.ep_size + + logger.info( + f'Creating MoE layer with num_experts: {num_experts} | num_local_experts: {self.num_local_experts} | expert_parallel_size: {self.ep_size}') + + assert noisy_gate_policy is None or noisy_gate_policy in ['None', 'Jitter', 'RSample'], \ + 'Unsupported noisy_gate_policy: ' + noisy_gate_policy + + experts = Experts(expert, self.num_local_experts) + + if using_default_moe: + self.moe_layer = MOELayer(TopKGate(hidden_size, num_experts, k, capacity_factor, eval_capacity_factor, + min_capacity, noisy_gate_policy, drop_tokens, use_rts), + experts, + gpc.get_group(ParallelMode.EXPERT), + self.ep_size, + self.num_local_experts) + + + def forward(self, hidden_states, used_token=None): + """ MoE forward + + Arguments: + hidden_states (Tensor): input to the layer + used_token (Tensor, optional): default: None, mask only used tokens + + Returns: + A tuple including output, gate loss, and expert count. + + * output (Tensor): output of the model + + * l_aux (Tensor): gate loss value + + * exp_counts (int): expert count + """ + output = self.moe_layer(hidden_states, used_token) + + return output, self.moe_layer.l_aux, self.moe_layer.exp_counts diff --git a/internlm/moe/experts.py b/internlm/moe/experts.py new file mode 100644 index 0000000..3b7af2c --- /dev/null +++ b/internlm/moe/experts.py @@ -0,0 +1,49 @@ +""" +The file has been adapted from the following files: +https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py + Git commit hash: f3943cf9109226ed3ecf2d5dbb639a11cd925555 + We retain the following license from the original files: +""" + +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import copy +from torch.nn import Module, ModuleList +from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast + +class Experts(torch.nn.Module): + + def __init__(self, experts: Union[Module, ModuleList], num_local_experts=1): + super(Experts, self).__init__() + + # TODO: We can not deepcopy FeedForward since it contains a process_group in submodules + # self.experts = torch.nn.ModuleList([copy.deepcopy(expert) for i in range(num_local_experts)]) + + + if type(experts) == ModuleList: + self.experts = cast(ModuleList, experts) + else: + self.experts = ModuleList([experts]) + self.num_local_experts = num_local_experts + + # TODO: revisit allreduce for moe.gate... + for expert in self.experts: + # TODO: Create param groups to handle expert + data case (e.g. 
param.group = moe_group) + for name, param in expert.named_parameters(): + param.allreduce = False # is_moe_param() checks this flag to exclude expert params from the dp gradient all-reduce + + def forward(self, inputs): + chunks = inputs.chunk(self.num_local_experts, dim=1) + expert_outputs = [] + for chunk, expert in zip(chunks, self.experts): + out = expert(chunk) + if type(out) is tuple: + out = out[0] # Ignore the bias term for now + expert_outputs += [out] + + expert_output = torch.cat(expert_outputs, dim=1) + return expert_output diff --git a/internlm/moe/sharded_moe.py b/internlm/moe/sharded_moe.py new file mode 100644 index 0000000..01daecc --- /dev/null +++ b/internlm/moe/sharded_moe.py @@ -0,0 +1,503 @@ +import torch.distributed as dist + +from internlm.utils.logger import get_logger +from internlm.utils.megatron_timers import megatron_timer as timer +from internlm.core.context import global_context as gpc +from internlm.core.context import ParallelMode + + +# global llm logger +logger = get_logger(__file__) + +""" +The file has been adapted from the following files: +https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/sharded_moe.py + Git commit hash: f3943cf9109226ed3ecf2d5dbb639a11cd925555 + We retain the following license from the original files: +""" + +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + + + +from typing import Callable, Dict, TYPE_CHECKING, Any, Optional, Tuple + +import torch +from torch import Tensor +from torch.nn import Module +import torch.nn.functional as F + +if TYPE_CHECKING: + Base = Module[Tensor] +else: + Base = Module + +uniform_map: Dict[torch.device, Callable] = {} +gumbel_map: Dict[torch.device, Callable] = {} +exp_selection_uniform_map: Dict[torch.device, Callable] = {} + + +def multiplicative_jitter(x, device: torch.device, epsilon=1e-2): + """ + Modified from the Switch Transformer paper and Mesh Transformers. + Multiply values by a random number between 1-epsilon and 1+epsilon. + Makes models more resilient to rounding errors introduced by bfloat16. + This seems particularly important for logits. + Args: + x: a torch.tensor + device: torch.device + epsilon: a floating point value + Returns: + a jittered x. + """ + if epsilon == 0: + return x + uniform = uniform_map.get(device) + if uniform is None: + uniform = torch.distributions.uniform.Uniform(low=torch.tensor(1.0 - epsilon, device=device), + high=torch.tensor(1.0 + epsilon, + device=device)).rsample # type: ignore + uniform_map[device] = uniform + return x * uniform(x.shape) + + +def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor: + gumbel = gumbel_map.get(device) + if gumbel is None: + one = torch.tensor(1.0, device=device) + zero = torch.tensor(0.0, device=device) + gumbel = torch.distributions.gumbel.Gumbel(zero, one).rsample # type: ignore + gumbel_map[device] = gumbel + return gumbel(shape) + +# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity +# See https://arxiv.org/pdf/2006.16668.pdf for details.
+ + +# Based on https://github.com/pytorch/pytorch/pull/40762 +class _AllToAll(torch.autograd.Function): + + @staticmethod + def forward( + ctx: Any, + # TODO: replace with DS process group + group: torch.distributed.ProcessGroup, + input: Tensor) -> Tensor: # type: ignore + ctx.group = group + input = input.contiguous() + output = torch.empty_like(input) + dist.all_to_all_single(output, input, group=group) + return output + + @staticmethod + def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]: + return (None, _AllToAll.apply(ctx.group, *grad_output)) + + +# einsum rewrites are on par or more performant +# switch can be bubbled up in future +USE_EINSUM = True + + +# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity +# See https://arxiv.org/pdf/2006.16668.pdf for details. +def einsum(rule, a, b): + if USE_EINSUM: + return torch.einsum(rule, a, b) + elif rule == 's,se->se': + ## [1, s] * [s, e] + return a.reshape(a.shape[0], -1) * b + elif rule == 'se,sc->sec': + ## [s,e,1] * [s,1,c] + return a.unsqueeze(2) * b.unsqueeze(1) + elif rule == 'se,se->s': + ## [s,1,e] * [s,e,1] + return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1) + elif rule == 'sec,sm->ecm': + ## [e*c, s] * [s, m] + s = a.shape[0] + e = a.shape[1] + c = a.shape[2] + m = b.shape[1] + return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m) + elif rule == 'sec,ecm->sm': + ## [s, e*c] * [e*c, m] + return torch.matmul(a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1])) + elif rule == 'ks,ksm->sm': + k = b.shape[0] + s = b.shape[1] + m = b.shape[2] + # [k, s] -> [s, k] -> [s, 1, k] + a = a.t().unsqueeze(1) + # [k,s,m] -> [k, sm] -> [sm, k] -> [s, m, k] + b = b.reshape(k, -1).t().reshape(s, m, k) + # bmm([s, 1, k], [s, m, k]^t) -> [s, m, 1] + return torch.bmm(a, b.transpose(1, 2)).squeeze(2) + else: + return torch.einsum(rule, a, b) + + +# The following functions are extracted and scripted +# because otherwise during a torch.jit.trace, the non-Tensor +# values used in the calculations get recorded as constants. +# torch.jit.script coerces them into Tensors and preserves +# their dynamic shapes. This enables ONNX export. +# We can't script the entire top1gating function because it +# includes stateful caching logic which is incompatible with ONNX. + + +@torch.jit.script +def _capacity(gates: Tensor, capacity_factor: Tensor, min_capacity: Tensor) -> Tensor: + # gates has shape of SE + num_tokens = gates.shape[0] + num_experts = gates.shape[1] + # to(torch.int64) works around a bug in torch.onnx.export: + # it should cast k to int64 when converting torch.topk but it doesn't. 
+ capacity = torch.ceil((num_tokens / num_experts) * capacity_factor).to(torch.int64) + if capacity < min_capacity: + capacity = min_capacity.to(torch.int64) + return capacity + + +@torch.jit.script +def _top_idx(source, k): + return torch.topk(source, k=k, dim=0)[1] + + +@torch.jit.script +def _one_hot_to_float(x, num_classes): + return F.one_hot(x, num_classes=num_classes).float() + + +def top1gating(logits: Tensor, + capacity_factor: float, + min_capacity: int, + used_token: Tensor = None, + noisy_gate_policy: Optional[str] = None, + drop_tokens: bool = True, + use_rts: bool = True, + use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Implements Top1Gating on logits.""" + if noisy_gate_policy == 'RSample': + logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device) + # everything is in fp32 in this function + gates = F.softmax(logits, dim=1) + + capacity = _capacity(gates, torch.tensor(capacity_factor), torch.tensor(min_capacity)) + + # Create a mask for the 1st expert per token + # noisy gating + indices1_s = torch.argmax(logits_w_noise if noisy_gate_policy == 'RSample' else gates, dim=1) + num_experts = int(gates.shape[1]) + mask1 = F.one_hot(indices1_s, num_classes=num_experts) + + # mask only used tokens + if used_token is not None: + mask1 = einsum("s,se->se", used_token, mask1) + + # gating decisions + exp_counts = torch.sum(mask1, dim=0).detach().to('cpu') + + # if we don't want to drop any tokens + if not drop_tokens: + new_capacity = torch.max(exp_counts).to(logits.device) + dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX) # default (world) group; torch.distributed has no get_world_group() + capacity = new_capacity + + # Compute l_aux + me = torch.mean(gates, dim=0) + ce = torch.mean(mask1.float(), dim=0) + l_aux = torch.sum(me * ce) * num_experts + + # Random Token Selection + if use_rts: + uniform = exp_selection_uniform_map.get(logits.device) + if uniform is None: + uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=logits.device), + high=torch.tensor(1.0, device=logits.device)).rsample + exp_selection_uniform_map[logits.device] = uniform + + mask1_rand = mask1 * uniform(mask1.shape) + else: + mask1_rand = mask1 + + assert logits.shape[0] >= min_capacity, "No. of tokens (batch-size) should be greater than min_capacity. Either set min_capacity to 0 or increase your batch size."
+ + top_idx = _top_idx(mask1_rand, capacity) #@wenwen: token index + + new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1) + mask1 = new_mask1 + + if use_tutel: + # Tutel doesn't support index values masked with zero + # so we need to replace masked indices with -1 + indices_mask = mask1.sum(dim=1) * num_experts - 1 + indices1_s = torch.min(indices1_s, indices_mask) + + # Compute locations in capacity buffer + + locations1 = torch.cumsum(mask1, dim=0) - 1 + + if use_tutel: + gates1_s = (gates * mask1).sum(dim=1) + locations1_s = torch.sum(locations1 * mask1, dim=1) + return l_aux, capacity, num_experts, [ + indices1_s, + ], [ + locations1_s, + ], [ + gates1_s, + ], exp_counts + + # Store the capacity location for each token + locations1_s = torch.sum(locations1 * mask1, dim=1) + + # Normalize gate probabilities + mask1_float = mask1.float() + gates = gates * mask1_float + + locations1_sc = _one_hot_to_float(locations1_s, capacity) + combine_weights = einsum("se,sc->sec", gates, locations1_sc) + + dispatch_mask = combine_weights.bool() + + return l_aux, combine_weights, dispatch_mask, exp_counts + + +def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Implements Top2Gating on logits.""" + # everything is in fp32 in this function + gates = F.softmax(logits, dim=1) + + capacity = _capacity(gates, torch.tensor(capacity_factor * 2), torch.tensor(min_capacity)) + + # Create a mask for 1st's expert per token + indices1_s = torch.argmax(gates, dim=1) + num_experts = int(gates.shape[1]) + mask1 = F.one_hot(indices1_s, num_classes=num_experts) + + # Create a mask for 2nd's expert per token using Gumbel-max trick + # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/ + logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device) + # Replace top-expert with min value + logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf")) + indices2_s = torch.argmax(logits_except1, dim=1) + mask2 = F.one_hot(indices2_s, num_classes=num_experts) + + # Compute locations in capacity buffer + locations1 = torch.cumsum(mask1, dim=0) - 1 + locations2 = torch.cumsum(mask2, dim=0) - 1 + # Update 2nd's location by accounting for locations of 1st + locations2 += torch.sum(mask1, dim=0, keepdim=True) + + # gating decisions + exp_counts = torch.sum(mask1, dim=0).detach().to('cpu') + + # Compute l_aux + me = torch.mean(gates, dim=0) + ce = torch.mean(mask1.float(), dim=0) + l_aux = torch.mean(me * ce) * num_experts * num_experts + + # Remove locations outside capacity from mask + mask1 *= torch.lt(locations1, capacity) + mask2 *= torch.lt(locations2, capacity) + + # Store the capacity location for each token + locations1_s = torch.sum(locations1 * mask1, dim=1) + locations2_s = torch.sum(locations2 * mask2, dim=1) + + # Normalize gate probabilities + mask1_float = mask1.float() + mask2_float = mask2.float() + gates1_s = einsum("se,se->s", gates, mask1_float) + gates2_s = einsum("se,se->s", gates, mask2_float) + denom_s = gates1_s + gates2_s + # Avoid divide-by-zero + denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps) + gates1_s /= denom_s + gates2_s /= denom_s + + # Calculate combine_weights and dispatch_mask + gates1 = einsum("s,se->se", gates1_s, mask1_float) + gates2 = einsum("s,se->se", gates2_s, mask2_float) + locations1_sc = _one_hot_to_float(locations1_s, capacity) + locations2_sc = _one_hot_to_float(locations2_s, capacity) + combine1_sec = einsum("se,sc->sec", gates1, 
locations1_sc) + combine2_sec = einsum("se,sc->sec", gates2, locations2_sc) + combine_weights = combine1_sec + combine2_sec + dispatch_mask = combine_weights.bool() + + return l_aux, combine_weights, dispatch_mask, exp_counts + + +class TopKGate(Module): + """Gate module which implements Top2Gating as described in Gshard_. + :: + + gate = TopKGate(model_dim, num_experts) + l_aux, combine_weights, dispatch_mask = gate(input) + + .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf + + Args: + model_dim (int): + size of model embedding dimension + num_experts (ints): + number of experts in model + """ + + wg: torch.nn.Linear + + def __init__(self, + model_dim: int, + num_experts: int, + k: int = 1, + capacity_factor: float = 1.0, + eval_capacity_factor: float = 1.0, + min_capacity: int = 8, + noisy_gate_policy: Optional[str] = None, + drop_tokens: bool = True, + use_rts: bool = True) -> None: + super().__init__() + + # Only top-1 and top-2 are supported at the moment. + if k != 1 and k != 2: + raise ValueError('Only top-1 and top-2 gatings are supported.') + self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float() + self.k = k + self.capacity_factor = capacity_factor + self.eval_capacity_factor = eval_capacity_factor + self.min_capacity = min_capacity + self.noisy_gate_policy = noisy_gate_policy + self.wall_clock_breakdown = False + self.gate_time = 0.0 + self.drop_tokens = drop_tokens + self.use_rts = use_rts + + def forward(self, + input: torch.Tensor, + used_token: torch.Tensor = None, + use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor]: # type: ignore + + if self.wall_clock_breakdown: + timer('TopKGate').start() + + if self.wg.weight.dtype != torch.float32: + self.wg = self.wg.float() + input_fp32 = input.float() + # input jittering + if self.noisy_gate_policy == 'Jitter' and self.training: + input_fp32 = multiplicative_jitter(input_fp32, device=input.device) + logits = self.wg(input_fp32) + + if self.k == 1: + gate_output = top1gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor, + self.min_capacity, used_token, self.noisy_gate_policy if self.training else None, + self.drop_tokens, self.use_rts, use_tutel) + + else: + gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor, + self.min_capacity) + + if self.wall_clock_breakdown: + timer('TopKGate').stop() + self.gate_time = timer('TopKGate').elapsed(reset=False) + + return gate_output + + +class MOELayer(Base): + """MOELayer module which implements MixtureOfExperts as described in Gshard_. + :: + + gate = TopKGate(model_dim, num_experts) + moe = MOELayer(gate, expert) + output = moe(input) + l_aux = moe.l_aux + + .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf + + Args: + gate (torch.nn.Module): + gate network + expert (torch.nn.Module): + expert network + """ + + def __init__(self, + gate: Module, + experts: Module, + ep_group, + ep_size, + num_local_experts: int) -> None: + super().__init__() + self.gate = gate + self.experts = experts + self.ep_group = ep_group + self.ep_size = ep_size + self.num_local_experts = num_local_experts + self.time_falltoall = 0.0 + self.time_salltoall = 0.0 + self.time_moe = 0.0 + self.wall_clock_breakdown = False + + + def _set_ep_group(self, ep_group): + self.ep_group = ep_group + + def forward(self, *input: Tensor, **kwargs: Any) -> Tensor: + + if self.wall_clock_breakdown: + timer('moe').start() + + # Implement Algorithm 2 from GShard paper. 
+ d_model = input[0].shape[-1] + + # Initial implementation -> Reshape into S tokens by dropping sequence dimension. + # Reshape into G groups so that each group can distribute tokens equally + # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1 + reshaped_input = input[0].reshape(-1, d_model) + + self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input, input[1]) + dispatched_input = einsum("sec,sm->ecm", dispatch_mask.type_as(input[0]), reshaped_input) ## TODO: heavy memory usage due to long sequence length + + if self.wall_clock_breakdown: + timer('falltoall').start() + + + dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input) + + if self.wall_clock_breakdown: + timer('falltoall').stop() + self.time_falltoall = timer('falltoall').elapsed(reset=False) + + # Re-shape after all-to-all: ecm -> gecm + dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model) + + expert_output = self.experts(dispatched_input) + + if self.wall_clock_breakdown: + timer('salltoall').start() + + expert_output = _AllToAll.apply(self.ep_group, expert_output) + + if self.wall_clock_breakdown: + timer('salltoall').stop() + self.time_salltoall = timer('salltoall').elapsed(reset=False) + + # Re-shape back: gecm -> ecm + expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model) + + combined_output = einsum("sec,ecm->sm", combine_weights.type_as(input[0]), expert_output) + + a = combined_output.reshape(input[0].shape) + + if self.wall_clock_breakdown: + timer('moe').stop() + self.time_moe = timer('moe').elapsed(reset=False) + + return a diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 116ffc2..41ab97c 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -25,6 +25,7 @@ from internlm.solver.optimizer.utils import ( split_half_float_double, sync_param, ) +from internlm.model.moe import is_moe_param from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger from internlm.utils.megatron_timers import megatron_timer as timer @@ -285,7 +286,7 @@ class HybridZeroOptimizer(BaseOptimizer): for group_id in range(self.num_param_groups): param_group = self._fp16_param_groups[group_id] for param in param_group: - if param.requires_grad: + if param.requires_grad and not is_moe_param(param): reduce_rank = None def _define_and_attach(param, reduce_rank=None): diff --git a/train.py b/train.py index 59729e7..675cc77 100644 --- a/train.py +++ b/train.py @@ -580,7 +580,7 @@ def main(args): # do forward and backward timer("fwd-bwd").start() - _, _, loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False) + _, _, loss = trainer.execute_schedule(batch, forward_only=False, return_loss=True, return_output_label=False, moe_loss_coeff = gpc.config.loss.moe_loss_coeff) timer("fwd-bwd").stop() # update parameters, and returns (success_update, grad_norm)
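Usage note (not part of the patch): with these changes, model.num_experts in the config turns each block's MLP into an MoE layer and creates the EXPERT process group, every MoE block returns its gate's auxiliary loss, and NonPipelineScheduler._train_one_batch folds those losses into the training loss weighted by loss.moe_loss_coeff. Below is a minimal, self-contained sketch of that loss combination only; the function name and numbers are illustrative and not part of the repository API.

import torch

def combine_losses(main_loss, moe_losses, moe_loss_coeff=1.0, scale_loss=1):
    # moe_losses holds one auxiliary (l_aux) tensor per transformer block; dense blocks contribute 0.0
    moe_loss = sum(moe_losses) * moe_loss_coeff
    # the combined loss is divided by the gradient-accumulation size, mirroring _train_one_batch
    return (main_loss + moe_loss) / scale_loss

main_loss = torch.tensor(2.30)
moe_losses = [torch.tensor(0.02) for _ in range(32)]  # e.g. one value per layer
print(combine_losses(main_loss, moe_losses, moe_loss_coeff=1.0, scale_loss=4))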