implement overlapped MoE forward

pull/506/head
Wenwen Qu 2023-11-16 19:43:47 +08:00 committed by Qu Wenwen
parent 21624f6f81
commit d20aa41d86
6 changed files with 287 additions and 71 deletions

View File

@ -144,6 +144,8 @@ model = dict(
num_experts=8,
moe_use_residual=False,
moe_gate_k=2,
use_tutel=False,
moe_overlap_degree=1,
)
"""
zero1 parallel:
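The two new keys plug into the existing model config. A minimal sketch of how they might be set to enable the handwritten overlap path (the values here are illustrative, not taken from this PR):

model = dict(
    # ... other model fields ...
    num_experts=8,
    moe_gate_k=2,
    moe_use_residual=False,
    use_tutel=False,        # only set True if tutel is installed
    moe_overlap_degree=2,   # > 1 splits the tokens into chunks and overlaps all-to-all with the experts
)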

View File

@ -293,6 +293,10 @@ def args_sanity_check():
model._add_item("moe_use_residual", False)
if "moe_gate_k" not in model:
model._add_item("moe_gate_k", 2)
if "use_tutel" not in model:
model._add_item("use_tutel", False)
if "moe_overlap_degree" not in model:
model._add_item("moe_overlap_degree", 1)
# process the parallel config
if "sequence_parallel" not in gpc.config.parallel:
gpc.config.parallel._add_item("sequence_parallel", False)

View File

@ -2,6 +2,7 @@
# -*- encoding: utf-8 -*-
import math
import os
from typing import Optional
import torch
@ -22,6 +23,7 @@ from internlm.model.linear import (
from internlm.model.moe import MoE
from internlm.model.multi_head_attention import MHA
from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm
from internlm.moe.sharded_moe import TUTEL_INSTALLED
from internlm.solver.pipeline_utils import partition_uniform
from internlm.utils.checkpoint import activation_checkpoint
from internlm.utils.common import filter_kwargs
@ -567,6 +569,8 @@ def build_model_with_moe_cfg(
moe_drop_tokens: bool = True,
moe_use_rts: bool = True,
moe_use_residual: bool = False,
use_tutel: bool = False,
moe_overlap_degree: int = 1, # pylint: disable=W0613
):
"""
Build model with config.
@ -608,6 +612,8 @@ def build_model_with_moe_cfg(
moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
(https://arxiv.org/abs/2201.05596) layer.
use_tutel (bool, optional): default=False, whether to use tutel to overlap the MoE all-to-all communication with expert computation.
moe_overlap_degree (int, optional): default=1, number of chunks the MoE input is split into; communication/computation overlap is enabled when it is greater than 1.
"""
cfg = dict(
@ -643,4 +649,16 @@ def build_model_with_moe_cfg(
moe_use_residual=moe_use_residual,
)
if use_tutel:
if TUTEL_INSTALLED:
# NOTE: disable 2DH all-to-all communication in the current version
os.environ["LOCAL_SIZE"] = "1"
elif gpc.is_rank_for_log():
logger.warning(
"Try import tutel failed! Package import error, if tutel is not installed, please implement: "
"python3 -m pip install --verbose --upgrade git+https://github.com/microsoft/tutel@main, "
"Ref: https://github.com/microsoft/tutel. "
)
logger.warning("Using default moe overlap strategy. Please note this!!!")
return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
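For illustration only (this helper is not part of the PR), the strategy implied by these two options, which MOELayer.forward below implements, can be summarized as:

def select_moe_forward_strategy(moe_overlap_degree: int, use_tutel: bool, tutel_installed: bool) -> str:
    """Hypothetical helper mirroring the branching added to MOELayer.forward in this PR."""
    if moe_overlap_degree == 1:
        return "no_overlap_moe_forward"        # sequential a2a -> experts -> a2a
    if moe_overlap_degree > 1 and not (use_tutel and tutel_installed):
        return "overlap_moe_forward"           # handwritten multi-stream overlap
    if moe_overlap_degree > 1 and use_tutel and tutel_installed:
        return "tutel_overlap_moe_forward"     # tutel's a2a_ffn_overlap_forward
    raise ValueError("unsupported moe forward strategy")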

View File

@ -0,0 +1,154 @@
from typing import Any, Tuple
import torch
import torch.distributed as dist
from torch import Tensor
from internlm.core.context import ParallelMode
from internlm.core.context.parallel_context import global_context as gpc
# Based on https://github.com/pytorch/pytorch/pull/40762
class moe_all_to_all(torch.autograd.Function):
"""
All to all communication
"""
@staticmethod
def forward(
ctx: Any,
group: torch.distributed.ProcessGroup,
inputs: Tensor,
) -> Tensor: # type: ignore
ctx.group = group
inputs = inputs.contiguous()
output = torch.empty_like(inputs)
dist.all_to_all_single(output, inputs, group=group)
return output
@staticmethod
def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
return (None, moe_all_to_all.apply(ctx.group, *grad_output))
class moe_stream_acquire(torch.autograd.Function):
"""
make the given stream wait on the event, then set it as the current stream
"""
@staticmethod
def forward(
ctx: Any,
stream,
event,
):
ctx.origin_stream = torch.cuda.current_stream()
ctx.event = event
event.wait(stream)
torch.cuda.set_stream(stream)
@staticmethod
def backward(ctx: Any):
ctx.event.record(ctx.origin_stream)
torch.cuda.set_stream(ctx.origin_stream)
return None, None
class moe_stream_release(torch.autograd.Function):
"""
record the event on the given stream, then switch back to the default stream
"""
@staticmethod
def forward(
ctx: Any,
stream,
event,
):
ctx.origin_stream = stream
ctx.event = event
event.record(stream)
torch.cuda.set_stream(torch.cuda.default_stream())
@staticmethod
def backward(ctx: Any):
ctx.event.wait(ctx.origin_stream)
torch.cuda.set_stream(ctx.origin_stream)
return None, None
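A minimal usage sketch of how these three functions are meant to be paired (assuming a CUDA device; comm_op is a local stand-in for the real all-to-all, which needs a process group, and the import path is assumed from the relative imports in this PR). This is the same pattern overlap_moe_forward in forward_func.py follows:

import torch

from internlm.moe.communication import moe_stream_acquire, moe_stream_release  # path assumed

side_stream = torch.cuda.Stream(torch.cuda.current_device())
ready = torch.cuda.Event()
x = torch.randn(4, 8, device="cuda")
comm_op = lambda t: t * 2  # stand-in for moe_all_to_all.apply(group, t)

moe_stream_release.apply(torch.cuda.default_stream(), ready)  # record default-stream progress
moe_stream_acquire.apply(side_stream, ready)                  # side stream waits on the event, becomes current
out = comm_op(x)                                              # this work is issued on the side stream
moe_stream_release.apply(side_stream, ready)                  # record completion of the side-stream work
moe_stream_acquire.apply(torch.cuda.default_stream(), ready)  # default stream waits before consuming `out`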
# NOTE: currently unused because the workload is less than 1M
# Based on https://arxiv.org/pdf/2206.03382.pdf
def _2DHAllToAll(inputs):
output = torch.empty_like(inputs)
length = inputs.shape[0]
slice_size = length // gpc.get_world_size(ParallelMode.EXPERT)
ngpus = 8  # TODO: should be set by the user
nnodes = gpc.get_world_size(ParallelMode.EXPERT) // ngpus
# phase 0. per-gpu (ngpus) stride copy
width = nnodes
height = ngpus
for i in range(length):
index = i // slice_size
offset = i % slice_size
j = (width * (index % height) + index // height) * slice_size + offset
output[j] = inputs[i]
# print("after intra swap from rank ", gpc.get_global_rank(), " : ", output, flush=True)
# phase 1. intra-node alltoall
reqs = []
node_rank = gpc.get_local_rank(ParallelMode.EXPERT) // ngpus
for i in range(ngpus):
reqs.append(
dist.P2POp(
dist.isend, output[i * nnodes * slice_size : (i + 1) * nnodes * slice_size], i + node_rank * ngpus
)
)
reqs.append(
dist.P2POp(
dist.irecv, inputs[i * nnodes * slice_size : (i + 1) * nnodes * slice_size], i + node_rank * ngpus
)
)
if len(reqs) > 0:
reqs = dist.batch_isend_irecv(reqs)
for req in reqs:
req.wait()
# print("after intra communication from rank ", gpc.get_global_rank(), " : ", inputs, flush=True)
# phase 2. per-gpu (nnodes) stride copy
width = ngpus
height = nnodes
for i in range(length):
index = i // slice_size
offset = i % slice_size
j = (width * (index % height) + index // height) * slice_size + offset
output[j] = inputs[i]
# print("after inter swap from rank ", gpc.get_global_rank(), " : ", output, flush=True)
# phase 3. inter-node alltoall
reqs = []
node_rank = gpc.get_local_rank(ParallelMode.EXPERT) // ngpus
g_local_rank = gpc.get_local_rank(ParallelMode.EXPERT) % ngpus
for i in range(nnodes):
reqs.append(
dist.P2POp(
dist.isend, output[i * ngpus * slice_size : (i + 1) * ngpus * slice_size], i * ngpus + g_local_rank
)
)
reqs.append(
dist.P2POp(
dist.irecv, inputs[i * ngpus * slice_size : (i + 1) * ngpus * slice_size], i * ngpus + g_local_rank
)
)
if len(reqs) > 0:
reqs = dist.batch_isend_irecv(reqs)
for req in reqs:
req.wait()
# print("after inter communication from rank ", gpc.get_global_rank(), " : ", inputs, flush=True)
return inputs
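A small self-contained check (toy sizes, plain CPU tensor) that the per-element stride copy in phases 0 and 2 above is just a transpose of the [width, height] grid of slice_size-sized chunks:

import torch

width, height, slice_size = 2, 4, 3  # e.g. nnodes=2, ngpus=4 (illustrative values)
length = width * height * slice_size
inputs = torch.arange(length)

output = torch.empty_like(inputs)
for i in range(length):  # same stride copy as in _2DHAllToAll
    index = i // slice_size
    offset = i % slice_size
    j = (width * (index % height) + index // height) * slice_size + offset
    output[j] = inputs[i]

# the loop is equivalent to transposing the chunk grid
expected = inputs.view(width, height, slice_size).transpose(0, 1).reshape(-1)
assert torch.equal(output, expected)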

View File

@ -0,0 +1,72 @@
import torch
from .communication import moe_all_to_all, moe_stream_acquire, moe_stream_release
def no_overlap_moe_forward(inputs, expert_fn, ep_group, ep_size, num_local_experts, d_model):
"""
Perform the MoE forward computation sequentially.
For example:
alltoall(d)---->expert_fn(d)--->alltoall(d)
"""
inputs = moe_all_to_all.apply(ep_group, inputs)
# Re-shape after all-to-all: ecm -> gecm
inputs = inputs.reshape(ep_size, num_local_experts, -1, d_model)
expert_output = expert_fn(inputs)
expert_output = moe_all_to_all.apply(ep_group, expert_output)
return expert_output
def overlap_moe_forward(inputs, expert_fn, a2a_ffn_overlap_degree, ep_group, ep_size, num_local_experts, d_model):
"""
Split the input based on a2a_ffn_overlap_degree and then execute the all-to-all and expert functions
on different streams to overlap the communication and computation costs.
For example:
communication stream: alltoall(d[0])---->alltoall(d[1])---->alltoall(d[0])---->alltoall(d[1])
computation stream: expert_fn(d[0]) ----> expert_fn(d[1])
"""
# inputs shape: (e, c, m). split the inputs along the 'c' dimension
input_chunks = inputs.chunk(a2a_ffn_overlap_degree, dim=1)
expert_inputs = [None for _ in range(a2a_ffn_overlap_degree)]
expert_outputs = [None for _ in range(a2a_ffn_overlap_degree)]
ready_events = [torch.cuda.Event() for _ in range(a2a_ffn_overlap_degree)]
alltoall_stream = [torch.cuda.Stream(torch.cuda.current_device()) for _ in range(a2a_ffn_overlap_degree)]
experts_stream = [torch.cuda.Stream(torch.cuda.current_device()) for _ in range(a2a_ffn_overlap_degree)]
# NOTE: async all-to-all does not seem to improve performance
# first all2all, execute on alltoall streams
for i, input_split in enumerate(input_chunks):
moe_stream_release.apply(torch.cuda.default_stream(), ready_events[i])
moe_stream_acquire.apply(alltoall_stream[i], ready_events[i])
expert_inputs[i] = moe_all_to_all.apply(ep_group, input_split)
moe_stream_release.apply(alltoall_stream[i], ready_events[i])
# expert function, execute on experts stream
for i in range(a2a_ffn_overlap_degree):
moe_stream_acquire.apply(experts_stream[i], ready_events[i])
# Re-shape after all-to-all: ecm -> gecm
expert_inputs[i] = expert_inputs[i].reshape(ep_size, num_local_experts, -1, d_model)
expert_outputs[i] = expert_fn(expert_inputs[i])
moe_stream_release.apply(experts_stream[i], ready_events[i])
# second all2all, execute on alltoall streams
for i in range(a2a_ffn_overlap_degree):
moe_stream_acquire.apply(alltoall_stream[i], ready_events[i])
expert_outputs[i] = moe_all_to_all.apply(ep_group, expert_outputs[i])
moe_stream_release.apply(alltoall_stream[i], ready_events[i])
moe_stream_acquire.apply(torch.cuda.default_stream(), ready_events[i])
# expert_outputs shape: (g, e, c, m). concatenate the outputs along the 'c' dimension
expert_output_gathered = torch.cat(expert_outputs, dim=2)
return expert_output_gathered
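A shape-only sketch (toy sizes, identity stand-ins for the all-to-all and the experts, no process group needed) of the bookkeeping in overlap_moe_forward: split (e, c, m) on the capacity dimension, reshape every piece to (g, e_local, c/deg, m), and concatenate the results back on dim=2:

import torch

ep_size, num_local_experts, capacity, d_model = 4, 2, 8, 16
a2a_ffn_overlap_degree = 2

inputs = torch.randn(ep_size * num_local_experts, capacity, d_model)  # (e, c, m)
chunks = inputs.chunk(a2a_ffn_overlap_degree, dim=1)                  # deg pieces of (e, c/deg, m)

outputs = []
for chunk in chunks:
    x = chunk                                               # all-to-all keeps the shape
    x = x.reshape(ep_size, num_local_experts, -1, d_model)  # ecm -> gecm
    outputs.append(x)                                       # identity "expert_fn"

gathered = torch.cat(outputs, dim=2)                        # (g, e_local, c, m)
assert gathered.shape == (ep_size, num_local_experts, capacity, d_model)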

View File

@ -4,7 +4,7 @@ https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py
Git commit hash: f3943cf9109226ed3ecf2d5dbb639a11cd925555
We retain the following license from the original files:
"""
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple
from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple
import torch
import torch.distributed as dist
@ -12,8 +12,17 @@ import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module
from internlm.core.context.parallel_context import global_context as gpc
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from .forward_func import no_overlap_moe_forward, overlap_moe_forward
try:
from tutel.impls.overlap import a2a_ffn_overlap_forward as tutel_overlap_moe_forward
TUTEL_INSTALLED = True
except (ModuleNotFoundError, ImportError):
TUTEL_INSTALLED = False
# global llm logger
logger = get_logger(__file__)
@ -62,30 +71,6 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
return gumbel(shape)
# Based on https://github.com/pytorch/pytorch/pull/40762
class _AllToAll(torch.autograd.Function):
"""
All to all communication
"""
@staticmethod
def forward(
ctx: Any,
# TODO: replace with DS process group
group: torch.distributed.ProcessGroup,
inputs: Tensor,
) -> Tensor: # type: ignore
ctx.group = group
inputs = inputs.contiguous()
output = torch.empty_like(inputs)
dist.all_to_all_single(output, inputs, group=group)
return output
@staticmethod
def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
return (None, _AllToAll.apply(ctx.group, *grad_output))
# einsum rewrites are on par or more performant
# switch can be bubbled up in future
USE_EINSUM = True
@ -347,18 +332,12 @@ class TopKGate(Module):
self.eval_capacity_factor = eval_capacity_factor
self.min_capacity = min_capacity
self.noisy_gate_policy = noisy_gate_policy
self.wall_clock_breakdown = False
self.gate_time = 0.0
self.drop_tokens = drop_tokens
self.use_rts = use_rts
def forward(
self, inputs: torch.Tensor, used_token: torch.Tensor = None
) -> Tuple[Tensor, Tensor, Tensor]: # type: ignore
if self.wall_clock_breakdown:
timer("TopKGate").start()
# input jittering
if self.noisy_gate_policy == "Jitter" and self.training:
inputs = multiplicative_jitter(inputs, device=inputs.device)
@ -380,10 +359,6 @@ class TopKGate(Module):
logits, self.capacity_factor if self.training else self.eval_capacity_factor, self.min_capacity
)
if self.wall_clock_breakdown:
timer("TopKGate").stop()
self.gate_time = timer("TopKGate").elapsed(reset=False)
return gate_output
@ -412,16 +387,13 @@ class MOELayer(Base):
self.ep_group = ep_group
self.ep_size = ep_size
self.num_local_experts = num_local_experts
self.time_falltoall = 0.0
self.time_salltoall = 0.0
self.time_moe = 0.0
self.wall_clock_breakdown = False
self.use_tutel = gpc.config.model.use_tutel and TUTEL_INSTALLED
self.overlap_degree = gpc.config.model.moe_overlap_degree
# TODO: tutel does not reshape inputs for each expert, so its logic differs from the current experts.py
assert (not self.use_tutel) or self.num_local_experts == 1, "only num_local_experts=1 is supported when tutel is enabled"
def forward(self, *inputs: Tensor) -> Tensor:
if self.wall_clock_breakdown:
timer("moe").start()
# Implement Algorithm 2 from GShard paper.
d_model = inputs[0].shape[-1]
@ -435,38 +407,32 @@ class MOELayer(Base):
"sec,sm->ecm", dispatch_mask.type_as(inputs[0]), reshaped_inputs
) # TODO: heavy memory usage due to long sequence length
if self.wall_clock_breakdown:
timer("falltoall").start()
dispatched_inputs = _AllToAll.apply(self.ep_group, dispatched_inputs)
if self.wall_clock_breakdown:
timer("falltoall").stop()
self.time_falltoall = timer("falltoall").elapsed(reset=False)
# Re-shape after all-to-all: ecm -> gecm
dispatched_inputs = dispatched_inputs.reshape(self.ep_size, self.num_local_experts, -1, d_model)
expert_output = self.experts(dispatched_inputs)
if self.wall_clock_breakdown:
timer("salltoall").start()
expert_output = _AllToAll.apply(self.ep_group, expert_output)
if self.wall_clock_breakdown:
timer("salltoall").stop()
self.time_salltoall = timer("salltoall").elapsed(reset=False)
if self.overlap_degree == 1:
expert_output = no_overlap_moe_forward(
dispatched_inputs, self.experts, self.ep_group, self.ep_size, self.num_local_experts, d_model
)
elif self.overlap_degree > 1 and not self.use_tutel:
expert_output = overlap_moe_forward(
dispatched_inputs,
self.experts,
self.overlap_degree,
self.ep_group,
self.ep_size,
self.num_local_experts,
d_model,
)
elif self.overlap_degree > 1 and self.use_tutel:
expert_output = tutel_overlap_moe_forward(
dispatched_inputs, self.experts, self.overlap_degree, True, self.ep_group
)
else:
assert False, "unsupported moe forward strategy"
# Re-shape back: gecm -> ecm
expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)
combined_output = einsum("sec,ecm->sm", combine_weights.type_as(inputs[0]), expert_output)
combined_output = einsum("sec,ecm->sm", combine_weights.type_as(inputs[0]), expert_output.type_as(inputs[0]))
out = combined_output.reshape(inputs[0].shape)
if self.wall_clock_breakdown:
timer("moe").stop()
self.time_moe = timer("moe").elapsed(reset=False)
return out