diff --git a/configs/7B_MoE4_sft.py b/configs/7B_MoE8_sft.py
similarity index 98%
rename from configs/7B_MoE4_sft.py
rename to configs/7B_MoE8_sft.py
index 92a93d0..86d2e50 100644
--- a/configs/7B_MoE4_sft.py
+++ b/configs/7B_MoE8_sft.py
@@ -144,6 +144,8 @@ model = dict(
     num_experts=8,
     moe_use_residual=False,
     moe_gate_k=2,
+    use_tutel=False,
+    moe_overlap_degree=1,
 )
 """
 zero1 parallel:
diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index ad404f2..5d04c20 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -293,6 +293,10 @@ def args_sanity_check():
         model._add_item("moe_use_residual", False)
     if "moe_gate_k" not in model:
         model._add_item("moe_gate_k", 2)
+    if "use_tutel" not in model:
+        model._add_item("use_tutel", False)
+    if "moe_overlap_degree" not in model:
+        model._add_item("moe_overlap_degree", 1)
     # process the parallel config
     if "sequence_parallel" not in gpc.config.parallel:
         gpc.config.parallel._add_item("sequence_parallel", False)
diff --git a/internlm/model/modeling_moe.py b/internlm/model/modeling_moe.py
index df6c7a8..175f7ff 100644
--- a/internlm/model/modeling_moe.py
+++ b/internlm/model/modeling_moe.py
@@ -2,6 +2,7 @@
 # -*- encoding: utf-8 -*-

 import math
+import os
 from typing import Optional

 import torch
@@ -22,6 +23,7 @@ from internlm.model.linear import (
 from internlm.model.moe import MoE
 from internlm.model.multi_head_attention import MHA
 from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm
+from internlm.moe.sharded_moe import TUTEL_INSTALLED
 from internlm.solver.pipeline_utils import partition_uniform
 from internlm.utils.checkpoint import activation_checkpoint
 from internlm.utils.common import filter_kwargs
@@ -567,6 +569,8 @@ def build_model_with_moe_cfg(
     moe_drop_tokens: bool = True,
     moe_use_rts: bool = True,
     moe_use_residual: bool = False,
+    use_tutel: bool = False,
+    moe_overlap_degree: int = 1,  # pylint: disable=W0613
 ):
     """
     Build model with config.
@@ -608,6 +612,8 @@
         moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
         moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
                                            (https://arxiv.org/abs/2201.05596) layer.
+        use_tutel (bool): default=False, whether to use tutel to overlap MoE communication and computation.
+        moe_overlap_degree (int): default=1, split MoE inputs into this many chunks; overlap is enabled when > 1.
     """
     cfg = dict(
@@ -643,4 +649,16 @@
         moe_use_residual=moe_use_residual,
     )

+    if use_tutel:
+        if TUTEL_INSTALLED:
+            # NOTE: disable 2DH alltoall communication for the current version
+            os.environ["LOCAL_SIZE"] = "1"
+        elif gpc.is_rank_for_log():
+            logger.warning(
+                "Failed to import tutel! If tutel is not installed, install it via: "
+                "python3 -m pip install --verbose --upgrade git+https://github.com/microsoft/tutel@main. "
+                "Ref: https://github.com/microsoft/tutel."
+            )
+            logger.warning("Falling back to the default MoE overlap strategy.")
+
     return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
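For context, a minimal sketch of how the two new fields are meant to appear in a user config; values here are illustrative, not the defaults added above:

# hypothetical excerpt of an MoE model config using the new fields
model = dict(
    num_experts=8,
    moe_gate_k=2,
    moe_use_residual=False,
    use_tutel=False,       # opt into tutel's overlap implementation when it is installed
    moe_overlap_degree=2,  # >1 enables communication/computation overlap
)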
diff --git a/internlm/moe/communication.py b/internlm/moe/communication.py
new file mode 100644
index 0000000..728f590
--- /dev/null
+++ b/internlm/moe/communication.py
@@ -0,0 +1,154 @@
+from typing import Any, Tuple
+
+import torch
+import torch.distributed as dist
+from torch import Tensor
+
+from internlm.core.context import ParallelMode
+from internlm.core.context.parallel_context import global_context as gpc
+
+
+# Based on https://github.com/pytorch/pytorch/pull/40762
+class moe_all_to_all(torch.autograd.Function):
+    """
+    All-to-all communication.
+    """
+
+    @staticmethod
+    def forward(
+        ctx: Any,
+        group: torch.distributed.ProcessGroup,
+        inputs: Tensor,
+    ) -> Tensor:  # type: ignore
+        ctx.group = group
+        inputs = inputs.contiguous()
+        output = torch.empty_like(inputs)
+        dist.all_to_all_single(output, inputs, group=group)
+        return output
+
+    @staticmethod
+    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
+        return (None, moe_all_to_all.apply(ctx.group, *grad_output))
+
+
+class moe_stream_acquire(torch.autograd.Function):
+    """
+    Wait on `event`, then switch the current CUDA stream to `stream`.
+    """
+
+    @staticmethod
+    def forward(
+        ctx: Any,
+        stream,
+        event,
+    ):
+        ctx.origin_stream = torch.cuda.current_stream()
+        ctx.event = event
+        event.wait(stream)
+        torch.cuda.set_stream(stream)
+
+    @staticmethod
+    def backward(ctx: Any):
+        ctx.event.record(ctx.origin_stream)
+        torch.cuda.set_stream(ctx.origin_stream)
+        return None, None
+
+
+class moe_stream_release(torch.autograd.Function):
+    """
+    Record `event` on `stream`, then switch back to the default stream.
+    """
+
+    @staticmethod
+    def forward(
+        ctx: Any,
+        stream,
+        event,
+    ) -> Tensor:  # type: ignore
+        ctx.origin_stream = stream
+        ctx.event = event
+        event.record(stream)
+        torch.cuda.set_stream(torch.cuda.default_stream())
+
+    @staticmethod
+    def backward(ctx: Any):
+        ctx.event.wait(ctx.origin_stream)
+        torch.cuda.set_stream(ctx.origin_stream)
+        return None, None
+
+
+# NOTE: currently unused, since the per-rank workload is below 1M elements
+# Based on https://arxiv.org/pdf/2206.03382.pdf
+def _2DHAllToAll(inputs):
+    output = torch.empty_like(inputs)
+    length = inputs.shape[0]
+    slice_size = length // gpc.get_world_size(ParallelMode.EXPERT)
+    ngpus = 8  # TODO: should be set by the user
+    nnodes = gpc.get_world_size(ParallelMode.EXPERT) // ngpus
+
+    # phase 0. per-gpu (ngpus) stride copy
+    width = nnodes
+    height = ngpus
+    for i in range(length):
+        index = i // slice_size
+        offset = i % slice_size
+        j = (width * (index % height) + index // height) * slice_size + offset
+        output[j] = inputs[i]
+    # print("after intra swap from rank ", gpc.get_global_rank(), " : ", output, flush=True)
+
+    # phase 1. intra-node alltoall
+    reqs = []
+    node_rank = gpc.get_local_rank(ParallelMode.EXPERT) // ngpus
+    for i in range(ngpus):
+        reqs.append(
+            dist.P2POp(
+                dist.isend, output[i * nnodes * slice_size : (i + 1) * nnodes * slice_size], i + node_rank * ngpus
+            )
+        )
+        reqs.append(
+            dist.P2POp(
+                dist.irecv, inputs[i * nnodes * slice_size : (i + 1) * nnodes * slice_size], i + node_rank * ngpus
+            )
+        )
+
+    if len(reqs) > 0:
+        reqs = dist.batch_isend_irecv(reqs)
+
+    for req in reqs:
+        req.wait()
+    # print("after intra communication from rank ", gpc.get_global_rank(), " : ", inputs, flush=True)
+
+    # phase 2. per-gpu (nnodes) stride copy
+    width = ngpus
+    height = nnodes
+    for i in range(length):
+        index = i // slice_size
+        offset = i % slice_size
+        j = (width * (index % height) + index // height) * slice_size + offset
+        output[j] = inputs[i]
+    # print("after inter swap from rank ", gpc.get_global_rank(), " : ", output, flush=True)
+
+    # phase 3. inter-node alltoall
+    reqs = []
+    node_rank = gpc.get_local_rank(ParallelMode.EXPERT) // ngpus
+    g_local_rank = gpc.get_local_rank(ParallelMode.EXPERT) % ngpus
+    for i in range(nnodes):
+        reqs.append(
+            dist.P2POp(
+                dist.isend, output[i * ngpus * slice_size : (i + 1) * ngpus * slice_size], i * ngpus + g_local_rank
+            )
+        )
+        reqs.append(
+            dist.P2POp(
+                dist.irecv, inputs[i * ngpus * slice_size : (i + 1) * ngpus * slice_size], i * ngpus + g_local_rank
+            )
+        )
+
+    if len(reqs) > 0:
+        reqs = dist.batch_isend_irecv(reqs)
+
+    for req in reqs:
+        req.wait()
+    # print("after inter communication from rank ", gpc.get_global_rank(), " : ", inputs, flush=True)
+
+    return inputs
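As a sanity check on the stride-copy index math in _2DHAllToAll, here is a hypothetical standalone reproduction (stride_copy is a helper introduced only for illustration; it mirrors the phase-0/phase-2 loops above):

import torch

def stride_copy(inputs, slice_size, width, height):
    # transpose the chunk grid: chunk k moves from (height x width) row-major
    # order to (width x height) order, preserving the offset within each chunk
    output = torch.empty_like(inputs)
    for i in range(inputs.shape[0]):
        index, offset = i // slice_size, i % slice_size
        j = (width * (index % height) + index // height) * slice_size + offset
        output[j] = inputs[i]
    return output

x = torch.arange(8.0)
print(stride_copy(x, slice_size=1, width=2, height=4))
# tensor([0., 4., 1., 5., 2., 6., 3., 7.])  == x.reshape(2, 4).t().flatten()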
diff --git a/internlm/moe/forward_func.py b/internlm/moe/forward_func.py
new file mode 100644
index 0000000..13e2ad5
--- /dev/null
+++ b/internlm/moe/forward_func.py
@@ -0,0 +1,72 @@
+import torch
+
+from .communication import moe_all_to_all, moe_stream_acquire, moe_stream_release
+
+
+def no_overlap_moe_forward(inputs, expert_fn, ep_group, ep_size, num_local_experts, d_model):
+    """
+    Perform the MoE forward computation sequentially, for example:
+    alltoall(d) ----> expert_fn(d) ----> alltoall(d)
+    """
+    inputs = moe_all_to_all.apply(ep_group, inputs)
+
+    # Re-shape after all-to-all: ecm -> gecm
+    inputs = inputs.reshape(ep_size, num_local_experts, -1, d_model)
+    expert_output = expert_fn(inputs)
+
+    expert_output = moe_all_to_all.apply(ep_group, expert_output)
+
+    return expert_output
+
+
+def overlap_moe_forward(inputs, expert_fn, a2a_ffn_overlap_degree, ep_group, ep_size, num_local_experts, d_model):
+    """
+    Split the input into a2a_ffn_overlap_degree chunks, then run the alltoall and expert
+    functions on different streams to overlap communication and computation, for example:
+    communication stream: alltoall(d[0]) ----> alltoall(d[1]) ----> alltoall(d[0]) ----> alltoall(d[1])
+    computation stream:                 expert_fn(d[0])    ----> expert_fn(d[1])
+    """
+    # inputs shape: (e, c, m); split the inputs along the 'c' dimension
+    input_chunks = inputs.chunk(a2a_ffn_overlap_degree, dim=1)
+
+    expert_inputs = [None for _ in range(a2a_ffn_overlap_degree)]
+    expert_outputs = [None for _ in range(a2a_ffn_overlap_degree)]
+
+    ready_events = [torch.cuda.Event() for _ in range(a2a_ffn_overlap_degree)]
+    alltoall_stream = [torch.cuda.Stream(torch.cuda.current_device()) for _ in range(a2a_ffn_overlap_degree)]
+    experts_stream = [torch.cuda.Stream(torch.cuda.current_device()) for _ in range(a2a_ffn_overlap_degree)]
+
+    # NOTE: async alltoall does not seem to improve performance
+    # first all2all, executed on the alltoall streams
+    for i, input_split in enumerate(input_chunks):
+        moe_stream_release.apply(torch.cuda.default_stream(), ready_events[i])
+
+        moe_stream_acquire.apply(alltoall_stream[i], ready_events[i])
+        expert_inputs[i] = moe_all_to_all.apply(ep_group, input_split)
+        moe_stream_release.apply(alltoall_stream[i], ready_events[i])
+
+    # expert function, executed on the experts streams
+    for i in range(a2a_ffn_overlap_degree):
+        moe_stream_acquire.apply(experts_stream[i], ready_events[i])
+        # Re-shape after all-to-all: ecm -> gecm
+        expert_inputs[i] = expert_inputs[i].reshape(ep_size, num_local_experts, -1, d_model)
+        expert_outputs[i] = expert_fn(expert_inputs[i])
+        moe_stream_release.apply(experts_stream[i], ready_events[i])
+
+    # second all2all, executed on the alltoall streams
+    for i in range(a2a_ffn_overlap_degree):
+        moe_stream_acquire.apply(alltoall_stream[i], ready_events[i])
+        expert_outputs[i] = moe_all_to_all.apply(ep_group, expert_outputs[i])
+        moe_stream_release.apply(alltoall_stream[i], ready_events[i])
+
+        moe_stream_acquire.apply(torch.cuda.default_stream(), ready_events[i])
+
+    # expert_outputs shape: (g, e, c, m); concatenate the outputs along the 'c' dimension
+    expert_output_gathered = torch.cat(expert_outputs, dim=2)
+
+    return expert_output_gathered
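A hypothetical shape walk-through of the chunking in overlap_moe_forward (sizes are made up; e = experts, c = capacity, m = d_model, and the capacity dimension must be divisible by the overlap degree):

import torch

e, c, m, degree = 4, 8, 16, 2
dispatched = torch.randn(e, c, m)           # result of the "sec,sm->ecm" dispatch einsum
chunks = dispatched.chunk(degree, dim=1)    # two (e, c // degree, m) slices of the capacity dim
assert all(ch.shape == (e, c // degree, m) for ch in chunks)

# each chunk flows through alltoall -> expert_fn -> alltoall on its own stream pair;
# after the gecm reshape (g = 1 here), results are re-joined along the capacity dim
outputs = [ch.reshape(1, e, c // degree, m) for ch in chunks]
assert torch.cat(outputs, dim=2).shape == (1, e, c, m)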
diff --git a/internlm/moe/sharded_moe.py b/internlm/moe/sharded_moe.py
index 5d695ac..14fd451 100644
--- a/internlm/moe/sharded_moe.py
+++ b/internlm/moe/sharded_moe.py
@@ -4,7 +4,7 @@ https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py
 Git commit hash: f3943cf9109226ed3ecf2d5dbb639a11cd925555
 We retain the following license from the original files:
 """
-from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple

 import torch
 import torch.distributed as dist
@@ -12,8 +12,17 @@ import torch.nn.functional as F
 from torch import Tensor
 from torch.nn import Module

+from internlm.core.context.parallel_context import global_context as gpc
 from internlm.utils.logger import get_logger
-from internlm.utils.megatron_timers import megatron_timer as timer
+
+from .forward_func import no_overlap_moe_forward, overlap_moe_forward
+
+try:
+    from tutel.impls.overlap import a2a_ffn_overlap_forward as tutel_overlap_moe_forward
+
+    TUTEL_INSTALLED = True
+except (ModuleNotFoundError, ImportError):
+    TUTEL_INSTALLED = False

 # global llm logger
 logger = get_logger(__file__)
@@ -62,30 +71,6 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
     return gumbel(shape)


-# Based on https://github.com/pytorch/pytorch/pull/40762
-class _AllToAll(torch.autograd.Function):
-    """
-    All to all communication
-    """
-
-    @staticmethod
-    def forward(
-        ctx: Any,
-        # TODO: replace with DS process group
-        group: torch.distributed.ProcessGroup,
-        inputs: Tensor,
-    ) -> Tensor:  # type: ignore
-        ctx.group = group
-        inputs = inputs.contiguous()
-        output = torch.empty_like(inputs)
-        dist.all_to_all_single(output, inputs, group=group)
-        return output
-
-    @staticmethod
-    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
-        return (None, _AllToAll.apply(ctx.group, *grad_output))
-
-
 # einsum rewrites are on par or more performant
 # switch can be bubbled up in future
 USE_EINSUM = True
@@ -347,18 +332,12 @@
         self.eval_capacity_factor = eval_capacity_factor
         self.min_capacity = min_capacity
         self.noisy_gate_policy = noisy_gate_policy
-        self.wall_clock_breakdown = False
-        self.gate_time = 0.0
         self.drop_tokens = drop_tokens
         self.use_rts = use_rts

     def forward(
         self, inputs: torch.Tensor, used_token: torch.Tensor = None
     ) -> Tuple[Tensor, Tensor, Tensor]:  # type: ignore
-
-        if self.wall_clock_breakdown:
-            timer("TopKGate").start()
-
         # input jittering
         if self.noisy_gate_policy == "Jitter" and self.training:
             inputs = multiplicative_jitter(inputs, device=inputs.device)
@@ -380,10 +359,6 @@
             logits, self.capacity_factor if self.training else self.eval_capacity_factor, self.min_capacity
         )

-        if self.wall_clock_breakdown:
-            timer("TopKGate").stop()
-            self.gate_time = timer("TopKGate").elapsed(reset=False)
-
         return gate_output


@@ -412,16 +387,13 @@ class MOELayer(Base):
         self.ep_group = ep_group
         self.ep_size = ep_size
         self.num_local_experts = num_local_experts
-        self.time_falltoall = 0.0
-        self.time_salltoall = 0.0
-        self.time_moe = 0.0
-        self.wall_clock_breakdown = False
+        self.use_tutel = gpc.config.model.use_tutel and TUTEL_INSTALLED
+        self.overlap_degree = gpc.config.model.moe_overlap_degree
+
+        # TODO: tutel does not reshape inputs for each expert, so its logic differs from the current experts.py
+        assert (not self.use_tutel) or self.num_local_experts == 1, "only num_local_experts=1 is supported with tutel"

     def forward(self, *inputs: Tensor) -> Tensor:
-
-        if self.wall_clock_breakdown:
-            timer("moe").start()
-
         # Implement Algorithm 2 from GShard paper.
         d_model = inputs[0].shape[-1]
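For reference, a toy shape check of the dispatch/combine einsums that bracket the expert call below (hypothetical sizes; s = tokens, e = experts, c = capacity, m = d_model):

import torch

s, e, c, m = 6, 2, 3, 4
dispatch_mask = torch.zeros(s, e, c)     # one-hot routing of tokens to expert capacity slots
combine_weights = torch.rand(s, e, c)    # gate weights used to merge expert outputs
tokens = torch.randn(s, m)

dispatched = torch.einsum("sec,sm->ecm", dispatch_mask, tokens)
assert dispatched.shape == (e, c, m)

combined = torch.einsum("sec,ecm->sm", combine_weights, dispatched)
assert combined.shape == (s, m)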
@@ -435,38 +407,32 @@ class MOELayer(Base):
             "sec,sm->ecm", dispatch_mask.type_as(inputs[0]), reshaped_inputs
         )  # TODO: heavy memory usage due to long sequence length

-        if self.wall_clock_breakdown:
-            timer("falltoall").start()
-
-        dispatched_inputs = _AllToAll.apply(self.ep_group, dispatched_inputs)
-
-        if self.wall_clock_breakdown:
-            timer("falltoall").stop()
-            self.time_falltoall = timer("falltoall").elapsed(reset=False)
-
-        # Re-shape after all-to-all: ecm -> gecm
-        dispatched_inputs = dispatched_inputs.reshape(self.ep_size, self.num_local_experts, -1, d_model)
-
-        expert_output = self.experts(dispatched_inputs)
-
-        if self.wall_clock_breakdown:
-            timer("salltoall").start()
-
-        expert_output = _AllToAll.apply(self.ep_group, expert_output)
-
-        if self.wall_clock_breakdown:
-            timer("salltoall").stop()
-            self.time_salltoall = timer("salltoall").elapsed(reset=False)
+        if self.overlap_degree == 1:
+            expert_output = no_overlap_moe_forward(
+                dispatched_inputs, self.experts, self.ep_group, self.ep_size, self.num_local_experts, d_model
+            )
+        elif self.overlap_degree > 1 and not self.use_tutel:
+            expert_output = overlap_moe_forward(
+                dispatched_inputs,
+                self.experts,
+                self.overlap_degree,
+                self.ep_group,
+                self.ep_size,
+                self.num_local_experts,
+                d_model,
+            )
+        elif self.overlap_degree > 1 and self.use_tutel:
+            expert_output = tutel_overlap_moe_forward(
+                dispatched_inputs, self.experts, self.overlap_degree, True, self.ep_group
+            )
+        else:
+            raise RuntimeError("unsupported moe forward strategy")

         # Re-shape back: gecm -> ecm
         expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)

-        combined_output = einsum("sec,ecm->sm", combine_weights.type_as(inputs[0]), expert_output)
+        combined_output = einsum("sec,ecm->sm", combine_weights.type_as(inputs[0]), expert_output.type_as(inputs[0]))

         out = combined_output.reshape(inputs[0].shape)

-        if self.wall_clock_breakdown:
-            timer("moe").stop()
-            self.time_moe = timer("moe").elapsed(reset=False)
-
         return out
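To summarize the branching added to MOELayer.forward, a condensed standalone sketch (select_moe_forward is a hypothetical helper; it mirrors, rather than replaces, the code above):

def select_moe_forward(overlap_degree: int, use_tutel: bool) -> str:
    if overlap_degree == 1:
        return "no_overlap_moe_forward"       # sequential alltoall -> experts -> alltoall
    if overlap_degree > 1 and not use_tutel:
        return "overlap_moe_forward"          # chunked multi-stream overlap (forward_func.py)
    if overlap_degree > 1 and use_tutel:
        return "tutel_overlap_moe_forward"    # tutel's fused a2a/ffn overlap
    raise ValueError("unsupported moe forward strategy")

assert select_moe_forward(1, False) == "no_overlap_moe_forward"
assert select_moe_forward(2, False) == "overlap_moe_forward"
assert select_moe_forward(2, True) == "tutel_overlap_moe_forward"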