mirror of https://github.com/InternLM/InternLM
implement overlap moe forward
parent 21624f6f81
commit d20aa41d86
@@ -144,6 +144,8 @@ model = dict(
    num_experts=8,
    moe_use_residual=False,
    moe_gate_k=2,
    use_tutel=False,
    moe_overlap_degree=1,
)
"""
zero1 parallel:
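For reference, a minimal sketch of a user config that actually turns the overlap on. Only use_tutel and moe_overlap_degree are the keys added by this commit; the surrounding fields mirror the context above, and the chosen values are illustrative assumptions:

# Hypothetical excerpt of a training config enabling the overlapped MoE forward.
model = dict(
    num_experts=8,
    moe_use_residual=False,
    moe_gate_k=2,
    use_tutel=False,       # set True to delegate the overlap to Tutel (requires tutel to be installed)
    moe_overlap_degree=2,  # values > 1 split the capacity dim and overlap all-to-all with expert compute
)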
@@ -293,6 +293,10 @@ def args_sanity_check():
        model._add_item("moe_use_residual", False)
    if "moe_gate_k" not in model:
        model._add_item("moe_gate_k", 2)
    if "use_tutel" not in model:
        model._add_item("use_tutel", False)
    if "moe_overlap_degree" not in model:
        model._add_item("moe_overlap_degree", 1)
    # process the parallel config
    if "sequence_parallel" not in gpc.config.parallel:
        gpc.config.parallel._add_item("sequence_parallel", False)
@@ -2,6 +2,7 @@
# -*- encoding: utf-8 -*-

import math
import os
from typing import Optional

import torch
@@ -22,6 +23,7 @@ from internlm.model.linear import (
from internlm.model.moe import MoE
from internlm.model.multi_head_attention import MHA
from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm
from internlm.moe.sharded_moe import TUTEL_INSTALLED
from internlm.solver.pipeline_utils import partition_uniform
from internlm.utils.checkpoint import activation_checkpoint
from internlm.utils.common import filter_kwargs
@@ -567,6 +569,8 @@ def build_model_with_moe_cfg(
    moe_drop_tokens: bool = True,
    moe_use_rts: bool = True,
    moe_use_residual: bool = False,
    use_tutel=False,
    moe_overlap_degree=1,  # pylint: disable=W0613
):
    """
    Build model with config.
@@ -608,6 +612,8 @@ def build_model_with_moe_cfg(
        moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
        moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
            (https://arxiv.org/abs/2201.05596) layer.
        use_tutel (bool): default=False, whether to use Tutel to overlap MoE communication and computation.
        moe_overlap_degree (int): default=1, overlap MoE computation and communication when set greater than 1.
    """

    cfg = dict(
@@ -643,4 +649,16 @@ def build_model_with_moe_cfg(
        moe_use_residual=moe_use_residual,
    )

    if use_tutel:
        if TUTEL_INSTALLED:
            # NOTE: disable 2DH alltoall communication for the current version
            os.environ["LOCAL_SIZE"] = "1"
        elif gpc.is_rank_for_log():
            logger.warning(
                "Failed to import tutel! If tutel is not installed, please install it via: "
                "python3 -m pip install --verbose --upgrade git+https://github.com/microsoft/tutel@main. "
                "Ref: https://github.com/microsoft/tutel."
            )
            logger.warning("Falling back to the default moe overlap strategy. Please note this!")

    return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
@@ -0,0 +1,154 @@
from typing import Any, Tuple

import torch
import torch.distributed as dist
from torch import Tensor

from internlm.core.context import ParallelMode
from internlm.core.context.parallel_context import global_context as gpc


# Based on https://github.com/pytorch/pytorch/pull/40762
class moe_all_to_all(torch.autograd.Function):
    """
    All to all communication
    """

    @staticmethod
    def forward(
        ctx: Any,
        group: torch.distributed.ProcessGroup,
        inputs: Tensor,
    ) -> Tensor:  # type: ignore
        ctx.group = group
        inputs = inputs.contiguous()
        output = torch.empty_like(inputs)
        dist.all_to_all_single(output, inputs, group=group)
        return output

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
        return (None, moe_all_to_all.apply(ctx.group, *grad_output))


class moe_stream_acquire(torch.autograd.Function):
    """
    Switch the current CUDA stream to the given stream.
    """

    @staticmethod
    def forward(
        ctx: Any,
        stream,
        event,
    ):
        ctx.origin_stream = torch.cuda.current_stream()
        ctx.event = event
        event.wait(stream)
        torch.cuda.set_stream(stream)

    @staticmethod
    def backward(ctx: Any):
        ctx.event.record(ctx.origin_stream)
        torch.cuda.set_stream(ctx.origin_stream)
        return None, None


class moe_stream_release(torch.autograd.Function):
    """
    Switch from the given stream back to the default stream.
    """

    @staticmethod
    def forward(
        ctx: Any,
        stream,
        event,
    ) -> Tensor:  # type: ignore
        ctx.origin_stream = stream
        ctx.event = event
        event.record(stream)
        torch.cuda.set_stream(torch.cuda.default_stream())

    @staticmethod
    def backward(ctx: Any):
        ctx.event.wait(ctx.origin_stream)
        torch.cuda.set_stream(ctx.origin_stream)
        return None, None


# NOTE: not used for now because the workload is less than 1M
# Based on https://arxiv.org/pdf/2206.03382.pdf
def _2DHAllToAll(inputs):
    output = torch.empty_like(inputs)
    length = inputs.shape[0]
    slice_size = length // gpc.get_world_size(ParallelMode.EXPERT)
    ngpus = 8  # TODO: should be set by the user
    nnodes = gpc.get_world_size(ParallelMode.EXPERT) // ngpus

    # phase 0. per-gpu (ngpus) stride copy
    width = nnodes
    height = ngpus
    for i in range(length):
        index = i / slice_size
        offset = i % slice_size
        j = int((width * (index % height) + (index / height)) * slice_size + offset)
        output[j] = inputs[i]
    # print("after intra swap from rank ", gpc.get_global_rank(), " : ", output, flush=True)

    # phase 1. intra-node alltoall
    reqs = []
    node_rank = int(gpc.get_local_rank(ParallelMode.EXPERT) / ngpus)
    for i in range(ngpus):
        reqs.append(
            dist.P2POp(
                dist.isend, output[i * nnodes * slice_size : (i + 1) * nnodes * slice_size], i + node_rank * ngpus
            )
        )
        reqs.append(
            dist.P2POp(
                dist.irecv, inputs[i * nnodes * slice_size : (i + 1) * nnodes * slice_size], i + node_rank * ngpus
            )
        )

    if len(reqs) > 0:
        reqs = dist.batch_isend_irecv(reqs)

    for req in reqs:
        req.wait()
    # print("after intra communication from rank ", gpc.get_global_rank(), " : ", inputs, flush=True)

    # phase 2. per-gpu (nnodes) stride copy
    width = ngpus
    height = nnodes
    for i in range(length):
        index = i / slice_size
        offset = i % slice_size
        j = int((width * (index % height) + (index / height)) * slice_size + offset)
        output[j] = inputs[i]
    # print("after inter swap from rank ", gpc.get_global_rank(), " : ", output, flush=True)

    # phase 3. inter-node alltoall
    reqs = []
    node_rank = int(gpc.get_local_rank(ParallelMode.EXPERT) / ngpus)
    g_local_rank = int(gpc.get_local_rank(ParallelMode.EXPERT) % ngpus)
    for i in range(nnodes):
        reqs.append(
            dist.P2POp(
                dist.isend, output[i * ngpus * slice_size : (i + 1) * ngpus * slice_size], i * ngpus + g_local_rank
            )
        )
        reqs.append(
            dist.P2POp(
                dist.irecv, inputs[i * ngpus * slice_size : (i + 1) * ngpus * slice_size], i * ngpus + g_local_rank
            )
        )

    if len(reqs) > 0:
        reqs = dist.batch_isend_irecv(reqs)

    for req in reqs:
        req.wait()
    # print("after inter communication from rank ", gpc.get_global_rank(), " : ", inputs, flush=True)

    return inputs
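The acquire/release helpers above implement a stream handoff: moe_stream_release records an event on the stream being left and switches back to the default stream, while moe_stream_acquire makes the target stream wait on that event before switching to it; their backward methods are meant to mirror the handoff for the backward pass. A minimal forward-only usage sketch, assuming a CUDA device, the module path shown in the import, and purely illustrative tensor names:

import torch

from internlm.moe.communication import moe_stream_acquire, moe_stream_release  # assumed module path

side_stream = torch.cuda.Stream(torch.cuda.current_device())
ready = torch.cuda.Event()
x = torch.randn(16, 16, device="cuda")

# Leave the default stream: record `ready` so the side stream can wait on prior work.
moe_stream_release.apply(torch.cuda.default_stream(), ready)

# Enter the side stream: wait on `ready`, then make it the current stream.
moe_stream_acquire.apply(side_stream, ready)
y = x @ x  # this kernel is launched on side_stream

# Leave the side stream and hand control back to the default stream.
moe_stream_release.apply(side_stream, ready)
moe_stream_acquire.apply(torch.cuda.default_stream(), ready)

torch.cuda.synchronize()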
@@ -0,0 +1,72 @@
import torch

from .communication import moe_all_to_all, moe_stream_acquire, moe_stream_release


def no_overlap_moe_forward(inputs, expert_fn, ep_group, ep_size, num_local_experts, d_model):
    """
    Perform the MoE forward computation sequentially.
    For example:
        alltoall(d) ----> expert_fn(d) ----> alltoall(d)
    """

    inputs = moe_all_to_all.apply(ep_group, inputs)

    # Re-shape after all-to-all: ecm -> gecm
    inputs = inputs.reshape(ep_size, num_local_experts, -1, d_model)
    expert_output = expert_fn(inputs)

    expert_output = moe_all_to_all.apply(ep_group, expert_output)

    return expert_output


def overlap_moe_forward(inputs, expert_fn, a2a_ffn_overlap_degree, ep_group, ep_size, num_local_experts, d_model):
    """
    Split the inputs based on a2a_ffn_overlap_degree and then execute the alltoall and expert functions
    on different streams to overlap the communication and computation costs.
    For example:
        communication stream: alltoall(d[0]) ----> alltoall(d[1]) ----> alltoall(d[0]) ----> alltoall(d[1])
        computation stream:              expert_fn(d[0])    ---->   expert_fn(d[1])
    """

    # inputs shape: (e, c, m). split the inputs along the 'c' dimension
    input_chunks = inputs.chunk(a2a_ffn_overlap_degree, dim=1)

    expert_inputs = [None for _ in range(a2a_ffn_overlap_degree)]
    expert_outputs = [None for _ in range(a2a_ffn_overlap_degree)]

    ready_events = [torch.cuda.Event() for _ in range(a2a_ffn_overlap_degree)]
    alltoall_stream = [torch.cuda.Stream(torch.cuda.current_device()) for _ in range(a2a_ffn_overlap_degree)]
    experts_stream = [torch.cuda.Stream(torch.cuda.current_device()) for _ in range(a2a_ffn_overlap_degree)]

    # NOTE: async alltoall seems unable to improve the performance
    # first all2all, executed on the alltoall streams
    for i, input_split in enumerate(input_chunks):
        moe_stream_release.apply(torch.cuda.default_stream(), ready_events[i])

        moe_stream_acquire.apply(alltoall_stream[i], ready_events[i])
        expert_inputs[i] = moe_all_to_all.apply(ep_group, input_split)
        moe_stream_release.apply(alltoall_stream[i], ready_events[i])

    # expert function, executed on the experts streams
    for i in range(a2a_ffn_overlap_degree):
        moe_stream_acquire.apply(experts_stream[i], ready_events[i])
        # Re-shape after all-to-all: ecm -> gecm
        expert_inputs[i] = expert_inputs[i].reshape(ep_size, num_local_experts, -1, d_model)
        expert_outputs[i] = expert_fn(expert_inputs[i])
        moe_stream_release.apply(experts_stream[i], ready_events[i])

    # second all2all, executed on the alltoall streams
    for i in range(a2a_ffn_overlap_degree):
        moe_stream_acquire.apply(alltoall_stream[i], ready_events[i])
        expert_outputs[i] = moe_all_to_all.apply(ep_group, expert_outputs[i])
        moe_stream_release.apply(alltoall_stream[i], ready_events[i])

        moe_stream_acquire.apply(torch.cuda.default_stream(), ready_events[i])

    # expert_outputs shape: (g, e, c, m). concatenate the outputs along the 'c' dimension
    expert_output_gathered = torch.cat(expert_outputs, dim=2)

    return expert_output_gathered
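To make the tensor layout concrete: the inputs arrive as (e, c, m) = (experts, capacity, model dim), are chunked along c, and each chunk gets its own all-to-all / expert / all-to-all round on dedicated streams. A rough invocation sketch; it assumes torch.distributed is already initialised with ep_size ranks (in InternLM the group would be the EXPERT parallel group), the import path is assumed, and the shapes and identity expert_fn are illustrative:

import torch
import torch.distributed as dist

from internlm.moe.forward_func import overlap_moe_forward  # assumed module path

# Assumed setup: dist.init_process_group(...) has been called with world_size == ep_size.
ep_size, num_local_experts, capacity, d_model = 4, 1, 128, 1024
ep_group = dist.new_group(ranks=list(range(ep_size)))  # stand-in for the expert-parallel group

# (e, c, m) tensor as produced by the gating einsum "sec,sm->ecm".
dispatched = torch.randn(ep_size * num_local_experts, capacity, d_model, device="cuda")

def expert_fn(x):
    return x  # stand-in for MOELayer.experts; a shape-preserving per-expert FFN

out = overlap_moe_forward(
    dispatched,
    expert_fn,
    2,                  # a2a_ffn_overlap_degree: capacity is split into 2 chunks of 64
    ep_group,
    ep_size,
    num_local_experts,
    d_model,
)
# out: (ep_size, num_local_experts, capacity, d_model); the caller reshapes it back to (e, c, m)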
@@ -4,7 +4,7 @@ https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/moe/experts.py
Git commit hash: f3943cf9109226ed3ecf2d5dbb639a11cd925555
We retain the following license from the original files:
"""
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple
from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple

import torch
import torch.distributed as dist
@@ -12,8 +12,17 @@ import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module

from internlm.core.context.parallel_context import global_context as gpc
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer

from .forward_func import no_overlap_moe_forward, overlap_moe_forward

try:
    from tutel.impls.overlap import a2a_ffn_overlap_forward as tutel_overlap_moe_forward

    TUTEL_INSTALLED = True
except (ModuleNotFoundError, ImportError):
    TUTEL_INSTALLED = False

# global llm logger
logger = get_logger(__file__)
@@ -62,30 +71,6 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
    return gumbel(shape)


# Based on https://github.com/pytorch/pytorch/pull/40762
class _AllToAll(torch.autograd.Function):
    """
    All to all communication
    """

    @staticmethod
    def forward(
        ctx: Any,
        # TODO: replace with DS process group
        group: torch.distributed.ProcessGroup,
        inputs: Tensor,
    ) -> Tensor:  # type: ignore
        ctx.group = group
        inputs = inputs.contiguous()
        output = torch.empty_like(inputs)
        dist.all_to_all_single(output, inputs, group=group)
        return output

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
        return (None, _AllToAll.apply(ctx.group, *grad_output))


# einsum rewrites are on par or more performant
# switch can be bubbled up in future
USE_EINSUM = True
@@ -347,18 +332,12 @@ class TopKGate(Module):
        self.eval_capacity_factor = eval_capacity_factor
        self.min_capacity = min_capacity
        self.noisy_gate_policy = noisy_gate_policy
        self.wall_clock_breakdown = False
        self.gate_time = 0.0
        self.drop_tokens = drop_tokens
        self.use_rts = use_rts

    def forward(
        self, inputs: torch.Tensor, used_token: torch.Tensor = None
    ) -> Tuple[Tensor, Tensor, Tensor]:  # type: ignore

        if self.wall_clock_breakdown:
            timer("TopKGate").start()

        # input jittering
        if self.noisy_gate_policy == "Jitter" and self.training:
            inputs = multiplicative_jitter(inputs, device=inputs.device)
@@ -380,10 +359,6 @@ class TopKGate(Module):
                logits, self.capacity_factor if self.training else self.eval_capacity_factor, self.min_capacity
            )

        if self.wall_clock_breakdown:
            timer("TopKGate").stop()
            self.gate_time = timer("TopKGate").elapsed(reset=False)

        return gate_output
@@ -412,16 +387,13 @@ class MOELayer(Base):
        self.ep_group = ep_group
        self.ep_size = ep_size
        self.num_local_experts = num_local_experts
        self.time_falltoall = 0.0
        self.time_salltoall = 0.0
        self.time_moe = 0.0
        self.wall_clock_breakdown = False
        self.use_tutel = gpc.config.model.use_tutel and TUTEL_INSTALLED
        self.overlap_degree = gpc.config.model.moe_overlap_degree

        # TODO: tutel does not reshape inputs for each expert, so its logic differs from the current experts.py
        assert (not self.use_tutel) or self.num_local_experts == 1, "only support num_local_experts=1 when tutel is enabled"

    def forward(self, *inputs: Tensor) -> Tensor:

        if self.wall_clock_breakdown:
            timer("moe").start()

        # Implement Algorithm 2 from the GShard paper.
        d_model = inputs[0].shape[-1]
@@ -435,38 +407,32 @@ class MOELayer(Base):
            "sec,sm->ecm", dispatch_mask.type_as(inputs[0]), reshaped_inputs
        )  # TODO: heavy memory usage due to long sequence length

        if self.wall_clock_breakdown:
            timer("falltoall").start()

        dispatched_inputs = _AllToAll.apply(self.ep_group, dispatched_inputs)

        if self.wall_clock_breakdown:
            timer("falltoall").stop()
            self.time_falltoall = timer("falltoall").elapsed(reset=False)

        # Re-shape after all-to-all: ecm -> gecm
        dispatched_inputs = dispatched_inputs.reshape(self.ep_size, self.num_local_experts, -1, d_model)

        expert_output = self.experts(dispatched_inputs)

        if self.wall_clock_breakdown:
            timer("salltoall").start()

        expert_output = _AllToAll.apply(self.ep_group, expert_output)

        if self.wall_clock_breakdown:
            timer("salltoall").stop()
            self.time_salltoall = timer("salltoall").elapsed(reset=False)
        if self.overlap_degree == 1:
            expert_output = no_overlap_moe_forward(
                dispatched_inputs, self.experts, self.ep_group, self.ep_size, self.num_local_experts, d_model
            )
        elif self.overlap_degree > 1 and not self.use_tutel:
            expert_output = overlap_moe_forward(
                dispatched_inputs,
                self.experts,
                self.overlap_degree,
                self.ep_group,
                self.ep_size,
                self.num_local_experts,
                d_model,
            )
        elif self.overlap_degree > 1 and self.use_tutel:
            expert_output = tutel_overlap_moe_forward(
                dispatched_inputs, self.experts, self.overlap_degree, True, self.ep_group
            )
        else:
            assert False, "unsupported moe forward strategy"

        # Re-shape back: gecm -> ecm
        expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)

        combined_output = einsum("sec,ecm->sm", combine_weights.type_as(inputs[0]), expert_output)
        combined_output = einsum("sec,ecm->sm", combine_weights.type_as(inputs[0]), expert_output.type_as(inputs[0]))

        out = combined_output.reshape(inputs[0].shape)

        if self.wall_clock_breakdown:
            timer("moe").stop()
            self.time_moe = timer("moe").elapsed(reset=False)

        return out
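For context on the einsum notation used in this forward pass: s is the token count, e the number of experts, c the per-expert capacity and m the model dimension; dispatch scatters each token into its expert/capacity slot and combine gathers the expert outputs back per token. A shape-only sketch with illustrative sizes:

import torch

s, e, c, m = 16, 4, 8, 32
reshaped_inputs = torch.randn(s, m)
dispatch_mask = torch.zeros(s, e, c)   # one-hot routing decisions produced by the gate
combine_weights = torch.rand(s, e, c)  # gate weights per (token, expert, capacity slot)

dispatched = torch.einsum("sec,sm->ecm", dispatch_mask, reshaped_inputs)  # (e, c, m)
expert_out = dispatched                                                   # stand-in for the experts
combined = torch.einsum("sec,ecm->sm", combine_weights, expert_out)       # (s, m)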