mirror of https://github.com/InternLM/InternLM
overlap gating further
parent d74ad7cca7
commit 66d6efd004
@@ -2,6 +2,48 @@ import torch

 from .communication import moe_all_to_all, moe_stream_acquire, moe_stream_release

+
+# einsum rewrites are on par or more performant
+# switch can be bubbled up in future
+USE_EINSUM = True
+
+
+# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
+# See https://arxiv.org/pdf/2006.16668.pdf for details.
+def einsum(rule, a, b):
+    if USE_EINSUM:
+        return torch.einsum(rule, a, b)
+    elif rule == "s,se->se":
+        # [1, s] * [s, e]
+        return a.reshape(a.shape[0], -1) * b
+    elif rule == "se,sc->sec":
+        # [s,e,1] * [s,1,c]
+        return a.unsqueeze(2) * b.unsqueeze(1)
+    elif rule == "se,se->s":
+        # [s,1,e] * [s,e,1]
+        return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1)
+    elif rule == "sec,sm->ecm":
+        # [e*c, s] * [s, m]
+        s = a.shape[0]
+        e = a.shape[1]
+        c = a.shape[2]
+        m = b.shape[1]
+        return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m)
+    elif rule == "sec,ecm->sm":
+        # [s, e*c] * [e*c, m]
+        return torch.matmul(a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1]))
+    elif rule == "ks,ksm->sm":
+        k = b.shape[0]
+        s = b.shape[1]
+        m = b.shape[2]
+        # [k, s] -> [s, k] -> [s, 1, k]
+        a = a.t().unsqueeze(1)
+        # [k,s,m] -> [k, sm] -> [sm, k] -> [s, m, k]
+        b = b.reshape(k, -1).t().reshape(s, m, k)
+        # bmm([s, 1, k], [s, m, k]^t) -> [s, m, 1]
+        return torch.bmm(a, b.transpose(1, 2)).squeeze(2)
+    else:
+        return torch.einsum(rule, a, b)
+
+
 def no_overlap_moe_forward(inputs, expert_fn, ep_group, ep_size, num_local_experts, d_model):
     """
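For context, the manual rewrites above are meant as drop-in equivalents of `torch.einsum` for the handful of rules the MoE path actually uses. A minimal sanity check (hypothetical shapes and names — `mask`, `tokens`, `weights` are not from the commit) exercising the `sec,sm->ecm` and `sec,ecm->sm` branches:

import torch

s, e, c, m = 8, 4, 2, 16  # hypothetical sizes: sequence, experts, capacity, model dim

mask = torch.randn(s, e, c)     # stands in for dispatch_mask
tokens = torch.randn(s, m)      # stands in for the reshaped inputs
weights = torch.randn(s, e, c)  # stands in for combine_weights

# manual rewrite of "sec,sm->ecm": [e*c, s] @ [s, m] -> [e, c, m]
dispatched = torch.matmul(mask.reshape(s, -1).t(), tokens).reshape(e, c, m)
assert torch.allclose(dispatched, torch.einsum("sec,sm->ecm", mask, tokens), atol=1e-5)

# manual rewrite of "sec,ecm->sm": [s, e*c] @ [e*c, m] -> [s, m]
combined = torch.matmul(weights.reshape(s, -1), dispatched.reshape(-1, m))
assert torch.allclose(combined, torch.einsum("sec,ecm->sm", weights, dispatched), atol=1e-5)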
@@ -21,7 +63,9 @@ def no_overlap_moe_forward(inputs, expert_fn, ep_group, ep_size, num_local_exper
     return expert_output


-def overlap_moe_forward(inputs, expert_fn, a2a_ffn_overlap_degree, ep_group, ep_size, num_local_experts, d_model):
+def overlap_moe_forward(
+    reshaped_inputs, gata_fn, expert_fn, a2a_ffn_overlap_degree, ep_group, ep_size, num_local_experts, d_model
+):
     """
     Split the input based on a2a_ffn_overlap_degree, then execute the all-to-all and expert functions
     on different streams to overlap the communication and computation cost.
@@ -31,23 +75,38 @@ def overlap_moe_forward(inputs, expert_fn, a2a_ffn_overlap_degree, ep_group, ep_
     """

-    # inputs shape: (e,c,m). split the inputs on 'c' dimension
-    input_chunks = inputs.chunk(a2a_ffn_overlap_degree, dim=1)
-
-    expert_inputs = [None for _ in range(a2a_ffn_overlap_degree)]
-    expert_outputs = [None for _ in range(a2a_ffn_overlap_degree)]
-
     # variables for stream control
     ready_events = [torch.cuda.Event() for _ in range(a2a_ffn_overlap_degree)]
     alltoall_stream = [torch.cuda.Stream(torch.cuda.current_device()) for _ in range(a2a_ffn_overlap_degree)]
     experts_stream = [torch.cuda.Stream(torch.cuda.current_device()) for _ in range(a2a_ffn_overlap_degree)]

-    # NOTE: async alltoall seems unable to improve the performance
-    # first all2all, execute on alltoall streams
+    # local variables for gating and expert computing
+    l_aux = torch.tensor(0.0, dtype=reshaped_inputs.dtype, device=reshaped_inputs.device)
+    dispatched_inputs = [None for _ in range(a2a_ffn_overlap_degree)]
+    expert_inputs = [None for _ in range(a2a_ffn_overlap_degree)]
+    expert_outputs = [None for _ in range(a2a_ffn_overlap_degree)]
+    combine_weights = [None for _ in range(a2a_ffn_overlap_degree)]
+    combined_output = [None for _ in range(a2a_ffn_overlap_degree)]
+
+    # (s,d), split by "s" dimension
+    input_chunks = reshaped_inputs.chunk(a2a_ffn_overlap_degree, dim=0)
+
+    # gating computing
     for i, input_split in enumerate(input_chunks):
         moe_stream_release.apply(torch.cuda.default_stream(), ready_events[i])

+        moe_stream_acquire.apply(experts_stream[i], ready_events[i])
+        cur_l_aux, combine_weights[i], dispatch_mask, exp_counts = gata_fn(input_split)
+        dispatched_inputs[i] = einsum(
+            "sec,sm->ecm", dispatch_mask.type_as(input_split), input_split
+        ) # TODO: heavy memory usage due to long sequence length
+        l_aux += cur_l_aux
+        moe_stream_release.apply(experts_stream[i], ready_events[i])
+
+        # NOTE: async alltoall seems unable to improve the performance
+        # first all2all, execute on alltoall streams
         moe_stream_acquire.apply(alltoall_stream[i], ready_events[i])
-        expert_inputs[i] = moe_all_to_all.apply(ep_group, input_split)
+        expert_inputs[i] = moe_all_to_all.apply(ep_group, dispatched_inputs[i])
         moe_stream_release.apply(alltoall_stream[i], ready_events[i])

     # expert function, execute on experts stream
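As a rough illustration of the release/acquire pattern in the loop above, here is a minimal sketch using plain CUDA events and streams (not the `moe_stream_acquire`/`moe_stream_release` autograd functions from `.communication`, whose internals are not shown in this diff). It requires a CUDA device, and `fn_compute`/`fn_comm` are hypothetical stand-ins for the gating/expert work and the all-to-all:

import torch

def overlapped_chunks(x, degree, fn_compute, fn_comm):
    # Pipeline per-chunk compute and communication on side streams,
    # in the spirit of overlap_moe_forward.
    chunks = x.chunk(degree, dim=0)
    outputs = [None] * degree
    events = [torch.cuda.Event() for _ in range(degree)]
    comp_streams = [torch.cuda.Stream() for _ in range(degree)]
    comm_streams = [torch.cuda.Stream() for _ in range(degree)]

    for i, chunk in enumerate(chunks):
        # "release" the default stream: record the point the side streams wait on
        events[i].record(torch.cuda.default_stream())
        # "acquire" the compute stream and run the per-chunk work there
        comp_streams[i].wait_event(events[i])
        with torch.cuda.stream(comp_streams[i]):
            y = fn_compute(chunk)
        events[i].record(comp_streams[i])
        # hand the result off to the communication stream
        comm_streams[i].wait_event(events[i])
        with torch.cuda.stream(comm_streams[i]):
            outputs[i] = fn_comm(y)
        events[i].record(comm_streams[i])

    # re-join: the default stream waits for every chunk before the concat
    for ev in events:
        torch.cuda.default_stream().wait_event(ev)
    # NB: production code must also manage cross-stream tensor lifetimes,
    # e.g. via Tensor.record_stream().
    return torch.cat(outputs, dim=0)

With `fn_compute` playing the expert MLP and `fn_comm` the all-to-all, chunk i+1's compute can start while chunk i's communication is still in flight, which is the point of splitting by `a2a_ffn_overlap_degree`.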
@@ -64,9 +123,16 @@ def overlap_moe_forward(inputs, expert_fn, a2a_ffn_overlap_degree, ep_group, ep_
         expert_outputs[i] = moe_all_to_all.apply(ep_group, expert_outputs[i])
         moe_stream_release.apply(alltoall_stream[i], ready_events[i])

+    for i in range(a2a_ffn_overlap_degree):
+        moe_stream_acquire.apply(experts_stream[i], ready_events[i])
+        # Re-shape back: gecm -> ecm
+        expert_outputs[i] = expert_outputs[i].reshape(ep_size * num_local_experts, -1, d_model)
+        combined_output[i] = einsum(
+            "sec,ecm->sm", combine_weights[i].type_as(input_chunks[0]), expert_outputs[i].type_as(input_chunks[0])
+        )
+        moe_stream_release.apply(experts_stream[i], ready_events[i])
+
         moe_stream_acquire.apply(torch.cuda.default_stream(), ready_events[i])

-    # expert_outputs shape: (g, e,c,m). cat the outputs on 'c' dimension
-    expert_output_gathered = torch.cat(expert_outputs, dim=2)
-
-    return expert_output_gathered
+    combined_output = torch.cat(combined_output)
+    return combined_output, l_aux / a2a_ffn_overlap_degree, exp_counts
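A toy illustration (hypothetical numbers) of the new bookkeeping: chunking on the token dimension round-trips through `torch.cat`, and the summed auxiliary loss is normalized once by `a2a_ffn_overlap_degree`:

import torch

degree = 4
x = torch.arange(8 * 3, dtype=torch.float32).reshape(8, 3)  # (s, d) with s=8, d=3

chunks = x.chunk(degree, dim=0)           # four (2, 3) chunks along the token dim
assert torch.equal(torch.cat(chunks), x)  # concatenation restores the original order

# per-chunk auxiliary losses are accumulated in the loop, then averaged once
l_aux = sum(torch.tensor(v) for v in (0.1, 0.2, 0.3, 0.4))
print(l_aux / degree)  # tensor(0.2500)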
@@ -16,7 +16,7 @@ from internlm.core.context.parallel_context import ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.utils.logger import get_logger

-from .forward_func import no_overlap_moe_forward, overlap_moe_forward
+from .forward_func import einsum, no_overlap_moe_forward, overlap_moe_forward

 try:
     from tutel.impls.overlap import a2a_ffn_overlap_forward as tutel_overlap_moe_forward
@@ -72,49 +72,6 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
     return gumbel(shape)


-# einsum rewrites are on par or more performant
-# switch can be bubbled up in future
-USE_EINSUM = True
-
-
-# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
-# See https://arxiv.org/pdf/2006.16668.pdf for details.
-def einsum(rule, a, b):
-    if USE_EINSUM:
-        return torch.einsum(rule, a, b)
-    elif rule == "s,se->se":
-        # [1, s] * [s, e]
-        return a.reshape(a.shape[0], -1) * b
-    elif rule == "se,sc->sec":
-        # [s,e,1] * [s,1,c]
-        return a.unsqueeze(2) * b.unsqueeze(1)
-    elif rule == "se,se->s":
-        # [s,1,e] * [s,e,1]
-        return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1)
-    elif rule == "sec,sm->ecm":
-        # [e*c, s] * [s, m]
-        s = a.shape[0]
-        e = a.shape[1]
-        c = a.shape[2]
-        m = b.shape[1]
-        return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m)
-    elif rule == "sec,ecm->sm":
-        # [s, e*c] * [e*c, m]
-        return torch.matmul(a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1]))
-    elif rule == "ks,ksm->sm":
-        k = b.shape[0]
-        s = b.shape[1]
-        m = b.shape[2]
-        # [k, s] -> [s, k] -> [s, 1, k]
-        a = a.t().unsqueeze(1)
-        # [k,s,m] -> [k, sm] -> [sm, k] -> [s, m, k]
-        b = b.reshape(k, -1).t().reshape(s, m, k)
-        # bmm([s, 1, k], [s, m, k]^t) -> [s, m, 1]
-        return torch.bmm(a, b.transpose(1, 2)).squeeze(2)
-    else:
-        return torch.einsum(rule, a, b)
-
-
 # The following functions are extracted and scripted
 # because otherwise during a torch.jit.trace, the non-Tensor
 # values used in the calculations get recorded as constants.
@@ -405,6 +362,18 @@ class MOELayer(Base):
         # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
         reshaped_inputs = inputs[0].reshape(-1, d_model)

+        if self.overlap_degree > 1 and not self.use_tutel:
+            combined_output, self.l_aux, self.exp_counts = overlap_moe_forward(
+                reshaped_inputs,
+                self.gate,
+                self.experts,
+                self.overlap_degree,
+                self.ep_group,
+                self.ep_size,
+                self.num_local_experts,
+                d_model,
+            )
+        else:
+            self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_inputs, inputs[1])
+            dispatched_inputs = einsum(
+                "sec,sm->ecm", dispatch_mask.type_as(inputs[0]), reshaped_inputs
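To make the dispatch step concrete: in the usual GShard formulation (per the paper linked above), `dispatch_mask` is a one-hot (s, e, c) tensor, so `sec,sm->ecm` scatters each token's features into its assigned expert/capacity slot. A tiny hypothetical example:

import torch

s, e, c, m = 3, 2, 2, 4  # 3 tokens, 2 experts, capacity 2, model dim 4
tokens = torch.arange(s * m, dtype=torch.float32).reshape(s, m)

# token 0 -> expert 0 slot 0, token 1 -> expert 1 slot 0, token 2 -> expert 0 slot 1
dispatch_mask = torch.zeros(s, e, c)
dispatch_mask[0, 0, 0] = 1
dispatch_mask[1, 1, 0] = 1
dispatch_mask[2, 0, 1] = 1

dispatched = torch.einsum("sec,sm->ecm", dispatch_mask, tokens)  # (e, c, m)
assert torch.equal(dispatched[0, 0], tokens[0])
assert torch.equal(dispatched[1, 0], tokens[1])
assert torch.equal(dispatched[0, 1], tokens[2])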
@@ -414,16 +383,6 @@ class MOELayer(Base):
             expert_output = no_overlap_moe_forward(
                 dispatched_inputs, self.experts, self.ep_group, self.ep_size, self.num_local_experts, d_model
             )
-        elif self.overlap_degree > 1 and not self.use_tutel:
-            expert_output = overlap_moe_forward(
-                dispatched_inputs,
-                self.experts,
-                self.overlap_degree,
-                self.ep_group,
-                self.ep_size,
-                self.num_local_experts,
-                d_model,
-            )
         elif self.overlap_degree > 1 and self.use_tutel:
             expert_output = tutel_overlap_moe_forward(
                 dispatched_inputs, self.experts, self.overlap_degree, True, self.ep_group
@@ -434,7 +393,9 @@ class MOELayer(Base):
             # Re-shape back: gecm -> ecm
             expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)

-        combined_output = einsum("sec,ecm->sm", combine_weights.type_as(inputs[0]), expert_output.type_as(inputs[0]))
+            combined_output = einsum(
+                "sec,ecm->sm", combine_weights.type_as(inputs[0]), expert_output.type_as(inputs[0])
+            )

         out = combined_output.reshape(inputs[0].shape)
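And the combine step in reverse (again a hypothetical sketch; `b`, `seq`, and the random tensors are illustrative, not from the commit): `combine_weights` carries gate probabilities at the same (s, e, c) positions, so `sec,ecm->sm` gathers each token back out of the expert buckets, after which the final reshape restores the original input shape:

import torch

b, seq, m = 2, 3, 4  # hypothetical batch, sequence, model dims
e, c = 2, 3          # experts and capacity
inputs = torch.randn(b, seq, m)

reshaped = inputs.reshape(-1, m)             # (s, m) with s = b * seq
combine_weights = torch.rand(b * seq, e, c)  # gate probabilities per slot
expert_output = torch.randn(e, c, m)         # what the experts returned

combined = torch.einsum("sec,ecm->sm", combine_weights, expert_output)
out = combined.reshape(inputs.shape)         # back to (b, seq, m)
assert out.shape == inputs.shape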