overlap gating further

pull/506/head
Wenwen Qu 2023-11-23 17:46:32 +08:00 committed by Qu Wenwen
parent d74ad7cca7
commit 66d6efd004
2 changed files with 105 additions and 78 deletions

View File

@ -2,6 +2,48 @@ import torch
from .communication import moe_all_to_all, moe_stream_acquire, moe_stream_release
# einsum rewrites are on par or more performant
# switch can be bubbled up in future
USE_EINSUM = True
# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
# See https://arxiv.org/pdf/2006.16668.pdf for details.
def einsum(rule, a, b):
if USE_EINSUM:
return torch.einsum(rule, a, b)
elif rule == "s,se->se":
# [1, s] * [s, e]
return a.reshape(a.shape[0], -1) * b
elif rule == "se,sc->sec":
# [s,e,1] * [s,1,c]
return a.unsqueeze(2) * b.unsqueeze(1)
elif rule == "se,se->s":
# [s,1,e] * [s,e,1]
return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1)
elif rule == "sec,sm->ecm":
# [e*c, s] * [s, m]
s = a.shape[0]
e = a.shape[1]
c = a.shape[2]
m = b.shape[1]
return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m)
elif rule == "sec,ecm->sm":
# [s, e*c] * [e*c, m]
return torch.matmul(a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1]))
elif rule == "ks,ksm->sm":
k = b.shape[0]
s = b.shape[1]
m = b.shape[2]
# [k, s] -> [s, k] -> [s, 1, k]
a = a.t().unsqueeze(1)
# [k,s,m] -> [k, sm] -> [sm, k] -> [s, m, k]
b = b.reshape(k, -1).t().reshape(s, m, k)
# bmm([s, 1, k], [s, m, k]^t) -> [s, m, 1]
return torch.bmm(a, b.transpose(1, 2)).squeeze(2)
else:
return torch.einsum(rule, a, b)
def no_overlap_moe_forward(inputs, expert_fn, ep_group, ep_size, num_local_experts, d_model):
"""
@ -21,7 +63,9 @@ def no_overlap_moe_forward(inputs, expert_fn, ep_group, ep_size, num_local_exper
return expert_output
def overlap_moe_forward(inputs, expert_fn, a2a_ffn_overlap_degree, ep_group, ep_size, num_local_experts, d_model):
def overlap_moe_forward(
reshaped_inputs, gata_fn, expert_fn, a2a_ffn_overlap_degree, ep_group, ep_size, num_local_experts, d_model
):
"""
Split the input based on a2a_ffn_overlap_degree and then execute the alltoall and experts function
on different stream to overlap the communication and computation cost.
@ -31,23 +75,38 @@ def overlap_moe_forward(inputs, expert_fn, a2a_ffn_overlap_degree, ep_group, ep_
"""
# inputs shape: (e,c,m). split the inputs on 'c' dimension
input_chunks = inputs.chunk(a2a_ffn_overlap_degree, dim=1)
expert_inputs = [None for _ in range(a2a_ffn_overlap_degree)]
expert_outputs = [None for _ in range(a2a_ffn_overlap_degree)]
# variables for stream control
ready_events = [torch.cuda.Event() for _ in range(a2a_ffn_overlap_degree)]
alltoall_stream = [torch.cuda.Stream(torch.cuda.current_device()) for _ in range(a2a_ffn_overlap_degree)]
experts_stream = [torch.cuda.Stream(torch.cuda.current_device()) for _ in range(a2a_ffn_overlap_degree)]
# NOTE: async alltoall seems unable to improve the performance
# first all2all, execute on alltoall streams
# local variables for gating and expert computing
l_aux = torch.tensor(0.0, dtype=reshaped_inputs.dtype, device=reshaped_inputs.device)
dispatched_inputs = [None for _ in range(a2a_ffn_overlap_degree)]
expert_inputs = [None for _ in range(a2a_ffn_overlap_degree)]
expert_outputs = [None for _ in range(a2a_ffn_overlap_degree)]
combine_weights = [None for _ in range(a2a_ffn_overlap_degree)]
combined_output = [None for _ in range(a2a_ffn_overlap_degree)]
# (s,d), split by "s" dimension
input_chunks = reshaped_inputs.chunk(a2a_ffn_overlap_degree, dim=0)
# gating computing
for i, input_split in enumerate(input_chunks):
moe_stream_release.apply(torch.cuda.default_stream(), ready_events[i])
moe_stream_acquire.apply(experts_stream[i], ready_events[i])
cur_l_aux, combine_weights[i], dispatch_mask, exp_counts = gata_fn(input_split)
dispatched_inputs[i] = einsum(
"sec,sm->ecm", dispatch_mask.type_as(input_split), input_split
) # TODO: heavy memory usage due to long sequence length
l_aux += cur_l_aux
moe_stream_release.apply(experts_stream[i], ready_events[i])
# NOTE: async alltoall seems unable to improve the performance
# first all2all, execute on alltoall streams
moe_stream_acquire.apply(alltoall_stream[i], ready_events[i])
expert_inputs[i] = moe_all_to_all.apply(ep_group, input_split)
expert_inputs[i] = moe_all_to_all.apply(ep_group, dispatched_inputs[i])
moe_stream_release.apply(alltoall_stream[i], ready_events[i])
# expert function, execute on experts stream
@ -64,9 +123,16 @@ def overlap_moe_forward(inputs, expert_fn, a2a_ffn_overlap_degree, ep_group, ep_
expert_outputs[i] = moe_all_to_all.apply(ep_group, expert_outputs[i])
moe_stream_release.apply(alltoall_stream[i], ready_events[i])
for i in range(a2a_ffn_overlap_degree):
moe_stream_acquire.apply(experts_stream[i], ready_events[i])
# Re-shape back: gecm -> ecm
expert_outputs[i] = expert_outputs[i].reshape(ep_size * num_local_experts, -1, d_model)
combined_output[i] = einsum(
"sec,ecm->sm", combine_weights[i].type_as(input_chunks[0]), expert_outputs[i].type_as(input_chunks[0])
)
moe_stream_release.apply(experts_stream[i], ready_events[i])
moe_stream_acquire.apply(torch.cuda.default_stream(), ready_events[i])
# expert_outputs shape: (g, e,c,m). cat the outputs on 'c' dimension
expert_output_gathered = torch.cat(expert_outputs, dim=2)
return expert_output_gathered
combined_output = torch.cat(combined_output)
return combined_output, l_aux / a2a_ffn_overlap_degree, exp_counts

View File

@ -16,7 +16,7 @@ from internlm.core.context.parallel_context import ParallelMode
from internlm.core.context.parallel_context import global_context as gpc
from internlm.utils.logger import get_logger
from .forward_func import no_overlap_moe_forward, overlap_moe_forward
from .forward_func import einsum, no_overlap_moe_forward, overlap_moe_forward
try:
from tutel.impls.overlap import a2a_ffn_overlap_forward as tutel_overlap_moe_forward
@ -72,49 +72,6 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
return gumbel(shape)
# einsum rewrites are on par or more performant
# switch can be bubbled up in future
USE_EINSUM = True
# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
# See https://arxiv.org/pdf/2006.16668.pdf for details.
def einsum(rule, a, b):
if USE_EINSUM:
return torch.einsum(rule, a, b)
elif rule == "s,se->se":
# [1, s] * [s, e]
return a.reshape(a.shape[0], -1) * b
elif rule == "se,sc->sec":
# [s,e,1] * [s,1,c]
return a.unsqueeze(2) * b.unsqueeze(1)
elif rule == "se,se->s":
# [s,1,e] * [s,e,1]
return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1)
elif rule == "sec,sm->ecm":
# [e*c, s] * [s, m]
s = a.shape[0]
e = a.shape[1]
c = a.shape[2]
m = b.shape[1]
return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m)
elif rule == "sec,ecm->sm":
# [s, e*c] * [e*c, m]
return torch.matmul(a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1]))
elif rule == "ks,ksm->sm":
k = b.shape[0]
s = b.shape[1]
m = b.shape[2]
# [k, s] -> [s, k] -> [s, 1, k]
a = a.t().unsqueeze(1)
# [k,s,m] -> [k, sm] -> [sm, k] -> [s, m, k]
b = b.reshape(k, -1).t().reshape(s, m, k)
# bmm([s, 1, k], [s, m, k]^t) -> [s, m, 1]
return torch.bmm(a, b.transpose(1, 2)).squeeze(2)
else:
return torch.einsum(rule, a, b)
# The following functions are extracted and scripted
# because otherwise during a torch.jit.trace, the non-Tensor
# values used in the calculations get recorded as constants.
@ -405,6 +362,18 @@ class MOELayer(Base):
# group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
reshaped_inputs = inputs[0].reshape(-1, d_model)
if self.overlap_degree > 1 and not self.use_tutel:
combined_output, self.l_aux, self.exp_counts = overlap_moe_forward(
reshaped_inputs,
self.gate,
self.experts,
self.overlap_degree,
self.ep_group,
self.ep_size,
self.num_local_experts,
d_model,
)
else:
self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_inputs, inputs[1])
dispatched_inputs = einsum(
"sec,sm->ecm", dispatch_mask.type_as(inputs[0]), reshaped_inputs
@ -414,16 +383,6 @@ class MOELayer(Base):
expert_output = no_overlap_moe_forward(
dispatched_inputs, self.experts, self.ep_group, self.ep_size, self.num_local_experts, d_model
)
elif self.overlap_degree > 1 and not self.use_tutel:
expert_output = overlap_moe_forward(
dispatched_inputs,
self.experts,
self.overlap_degree,
self.ep_group,
self.ep_size,
self.num_local_experts,
d_model,
)
elif self.overlap_degree > 1 and self.use_tutel:
expert_output = tutel_overlap_moe_forward(
dispatched_inputs, self.experts, self.overlap_degree, True, self.ep_group
@ -434,7 +393,9 @@ class MOELayer(Base):
# Re-shape back: gecm -> ecm
expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)
combined_output = einsum("sec,ecm->sm", combine_weights.type_as(inputs[0]), expert_output.type_as(inputs[0]))
combined_output = einsum(
"sec,ecm->sm", combine_weights.type_as(inputs[0]), expert_output.type_as(inputs[0])
)
out = combined_output.reshape(inputs[0].shape)