mirror of https://github.com/InternLM/InternLM
overlap gating further
parent
d74ad7cca7
commit
66d6efd004
|
@ -2,6 +2,48 @@ import torch
|
||||||
|
|
||||||
from .communication import moe_all_to_all, moe_stream_acquire, moe_stream_release
|
from .communication import moe_all_to_all, moe_stream_acquire, moe_stream_release
|
||||||
|
|
||||||
|
# einsum rewrites are on par or more performant
|
||||||
|
# switch can be bubbled up in future
|
||||||
|
USE_EINSUM = True
|
||||||
|
|
||||||
|
|
||||||
|
# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
|
||||||
|
# See https://arxiv.org/pdf/2006.16668.pdf for details.
|
||||||
|
def einsum(rule, a, b):
|
||||||
|
if USE_EINSUM:
|
||||||
|
return torch.einsum(rule, a, b)
|
||||||
|
elif rule == "s,se->se":
|
||||||
|
# [1, s] * [s, e]
|
||||||
|
return a.reshape(a.shape[0], -1) * b
|
||||||
|
elif rule == "se,sc->sec":
|
||||||
|
# [s,e,1] * [s,1,c]
|
||||||
|
return a.unsqueeze(2) * b.unsqueeze(1)
|
||||||
|
elif rule == "se,se->s":
|
||||||
|
# [s,1,e] * [s,e,1]
|
||||||
|
return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1)
|
||||||
|
elif rule == "sec,sm->ecm":
|
||||||
|
# [e*c, s] * [s, m]
|
||||||
|
s = a.shape[0]
|
||||||
|
e = a.shape[1]
|
||||||
|
c = a.shape[2]
|
||||||
|
m = b.shape[1]
|
||||||
|
return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m)
|
||||||
|
elif rule == "sec,ecm->sm":
|
||||||
|
# [s, e*c] * [e*c, m]
|
||||||
|
return torch.matmul(a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1]))
|
||||||
|
elif rule == "ks,ksm->sm":
|
||||||
|
k = b.shape[0]
|
||||||
|
s = b.shape[1]
|
||||||
|
m = b.shape[2]
|
||||||
|
# [k, s] -> [s, k] -> [s, 1, k]
|
||||||
|
a = a.t().unsqueeze(1)
|
||||||
|
# [k,s,m] -> [k, sm] -> [sm, k] -> [s, m, k]
|
||||||
|
b = b.reshape(k, -1).t().reshape(s, m, k)
|
||||||
|
# bmm([s, 1, k], [s, m, k]^t) -> [s, m, 1]
|
||||||
|
return torch.bmm(a, b.transpose(1, 2)).squeeze(2)
|
||||||
|
else:
|
||||||
|
return torch.einsum(rule, a, b)
|
||||||
|
|
||||||
|
|
||||||
def no_overlap_moe_forward(inputs, expert_fn, ep_group, ep_size, num_local_experts, d_model):
|
def no_overlap_moe_forward(inputs, expert_fn, ep_group, ep_size, num_local_experts, d_model):
|
||||||
"""
|
"""
|
||||||
|
@ -21,7 +63,9 @@ def no_overlap_moe_forward(inputs, expert_fn, ep_group, ep_size, num_local_exper
|
||||||
return expert_output
|
return expert_output
|
||||||
|
|
||||||
|
|
||||||
def overlap_moe_forward(inputs, expert_fn, a2a_ffn_overlap_degree, ep_group, ep_size, num_local_experts, d_model):
|
def overlap_moe_forward(
|
||||||
|
reshaped_inputs, gata_fn, expert_fn, a2a_ffn_overlap_degree, ep_group, ep_size, num_local_experts, d_model
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Split the input based on a2a_ffn_overlap_degree and then execute the alltoall and experts function
|
Split the input based on a2a_ffn_overlap_degree and then execute the alltoall and experts function
|
||||||
on different stream to overlap the communication and computation cost.
|
on different stream to overlap the communication and computation cost.
|
||||||
|
@ -31,23 +75,38 @@ def overlap_moe_forward(inputs, expert_fn, a2a_ffn_overlap_degree, ep_group, ep_
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# inputs shape: (e,c,m). split the inputs on 'c' dimension
|
# variables for stream control
|
||||||
input_chunks = inputs.chunk(a2a_ffn_overlap_degree, dim=1)
|
|
||||||
|
|
||||||
expert_inputs = [None for _ in range(a2a_ffn_overlap_degree)]
|
|
||||||
expert_outputs = [None for _ in range(a2a_ffn_overlap_degree)]
|
|
||||||
|
|
||||||
ready_events = [torch.cuda.Event() for _ in range(a2a_ffn_overlap_degree)]
|
ready_events = [torch.cuda.Event() for _ in range(a2a_ffn_overlap_degree)]
|
||||||
alltoall_stream = [torch.cuda.Stream(torch.cuda.current_device()) for _ in range(a2a_ffn_overlap_degree)]
|
alltoall_stream = [torch.cuda.Stream(torch.cuda.current_device()) for _ in range(a2a_ffn_overlap_degree)]
|
||||||
experts_stream = [torch.cuda.Stream(torch.cuda.current_device()) for _ in range(a2a_ffn_overlap_degree)]
|
experts_stream = [torch.cuda.Stream(torch.cuda.current_device()) for _ in range(a2a_ffn_overlap_degree)]
|
||||||
|
|
||||||
# NOTE: async alltoall seems unable to improve the performance
|
# local variables for gating and expert computing
|
||||||
# first all2all, execute on alltoall streams
|
l_aux = torch.tensor(0.0, dtype=reshaped_inputs.dtype, device=reshaped_inputs.device)
|
||||||
|
dispatched_inputs = [None for _ in range(a2a_ffn_overlap_degree)]
|
||||||
|
expert_inputs = [None for _ in range(a2a_ffn_overlap_degree)]
|
||||||
|
expert_outputs = [None for _ in range(a2a_ffn_overlap_degree)]
|
||||||
|
combine_weights = [None for _ in range(a2a_ffn_overlap_degree)]
|
||||||
|
combined_output = [None for _ in range(a2a_ffn_overlap_degree)]
|
||||||
|
|
||||||
|
# (s,d), split by "s" dimension
|
||||||
|
input_chunks = reshaped_inputs.chunk(a2a_ffn_overlap_degree, dim=0)
|
||||||
|
|
||||||
|
# gating computing
|
||||||
for i, input_split in enumerate(input_chunks):
|
for i, input_split in enumerate(input_chunks):
|
||||||
moe_stream_release.apply(torch.cuda.default_stream(), ready_events[i])
|
moe_stream_release.apply(torch.cuda.default_stream(), ready_events[i])
|
||||||
|
|
||||||
|
moe_stream_acquire.apply(experts_stream[i], ready_events[i])
|
||||||
|
cur_l_aux, combine_weights[i], dispatch_mask, exp_counts = gata_fn(input_split)
|
||||||
|
dispatched_inputs[i] = einsum(
|
||||||
|
"sec,sm->ecm", dispatch_mask.type_as(input_split), input_split
|
||||||
|
) # TODO: heavy memory usage due to long sequence length
|
||||||
|
l_aux += cur_l_aux
|
||||||
|
moe_stream_release.apply(experts_stream[i], ready_events[i])
|
||||||
|
|
||||||
|
# NOTE: async alltoall seems unable to improve the performance
|
||||||
|
# first all2all, execute on alltoall streams
|
||||||
moe_stream_acquire.apply(alltoall_stream[i], ready_events[i])
|
moe_stream_acquire.apply(alltoall_stream[i], ready_events[i])
|
||||||
expert_inputs[i] = moe_all_to_all.apply(ep_group, input_split)
|
expert_inputs[i] = moe_all_to_all.apply(ep_group, dispatched_inputs[i])
|
||||||
moe_stream_release.apply(alltoall_stream[i], ready_events[i])
|
moe_stream_release.apply(alltoall_stream[i], ready_events[i])
|
||||||
|
|
||||||
# expert function, execute on experts stream
|
# expert function, execute on experts stream
|
||||||
|
@ -64,9 +123,16 @@ def overlap_moe_forward(inputs, expert_fn, a2a_ffn_overlap_degree, ep_group, ep_
|
||||||
expert_outputs[i] = moe_all_to_all.apply(ep_group, expert_outputs[i])
|
expert_outputs[i] = moe_all_to_all.apply(ep_group, expert_outputs[i])
|
||||||
moe_stream_release.apply(alltoall_stream[i], ready_events[i])
|
moe_stream_release.apply(alltoall_stream[i], ready_events[i])
|
||||||
|
|
||||||
|
for i in range(a2a_ffn_overlap_degree):
|
||||||
|
moe_stream_acquire.apply(experts_stream[i], ready_events[i])
|
||||||
|
# Re-shape back: gecm -> ecm
|
||||||
|
expert_outputs[i] = expert_outputs[i].reshape(ep_size * num_local_experts, -1, d_model)
|
||||||
|
combined_output[i] = einsum(
|
||||||
|
"sec,ecm->sm", combine_weights[i].type_as(input_chunks[0]), expert_outputs[i].type_as(input_chunks[0])
|
||||||
|
)
|
||||||
|
moe_stream_release.apply(experts_stream[i], ready_events[i])
|
||||||
|
|
||||||
moe_stream_acquire.apply(torch.cuda.default_stream(), ready_events[i])
|
moe_stream_acquire.apply(torch.cuda.default_stream(), ready_events[i])
|
||||||
|
|
||||||
# expert_outputs shape: (g, e,c,m). cat the outputs on 'c' dimension
|
combined_output = torch.cat(combined_output)
|
||||||
expert_output_gathered = torch.cat(expert_outputs, dim=2)
|
return combined_output, l_aux / a2a_ffn_overlap_degree, exp_counts
|
||||||
|
|
||||||
return expert_output_gathered
|
|
||||||
|
|
|
@ -16,7 +16,7 @@ from internlm.core.context.parallel_context import ParallelMode
|
||||||
from internlm.core.context.parallel_context import global_context as gpc
|
from internlm.core.context.parallel_context import global_context as gpc
|
||||||
from internlm.utils.logger import get_logger
|
from internlm.utils.logger import get_logger
|
||||||
|
|
||||||
from .forward_func import no_overlap_moe_forward, overlap_moe_forward
|
from .forward_func import einsum, no_overlap_moe_forward, overlap_moe_forward
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from tutel.impls.overlap import a2a_ffn_overlap_forward as tutel_overlap_moe_forward
|
from tutel.impls.overlap import a2a_ffn_overlap_forward as tutel_overlap_moe_forward
|
||||||
|
@ -72,49 +72,6 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
|
||||||
return gumbel(shape)
|
return gumbel(shape)
|
||||||
|
|
||||||
|
|
||||||
# einsum rewrites are on par or more performant
|
|
||||||
# switch can be bubbled up in future
|
|
||||||
USE_EINSUM = True
|
|
||||||
|
|
||||||
|
|
||||||
# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
|
|
||||||
# See https://arxiv.org/pdf/2006.16668.pdf for details.
|
|
||||||
def einsum(rule, a, b):
|
|
||||||
if USE_EINSUM:
|
|
||||||
return torch.einsum(rule, a, b)
|
|
||||||
elif rule == "s,se->se":
|
|
||||||
# [1, s] * [s, e]
|
|
||||||
return a.reshape(a.shape[0], -1) * b
|
|
||||||
elif rule == "se,sc->sec":
|
|
||||||
# [s,e,1] * [s,1,c]
|
|
||||||
return a.unsqueeze(2) * b.unsqueeze(1)
|
|
||||||
elif rule == "se,se->s":
|
|
||||||
# [s,1,e] * [s,e,1]
|
|
||||||
return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1)
|
|
||||||
elif rule == "sec,sm->ecm":
|
|
||||||
# [e*c, s] * [s, m]
|
|
||||||
s = a.shape[0]
|
|
||||||
e = a.shape[1]
|
|
||||||
c = a.shape[2]
|
|
||||||
m = b.shape[1]
|
|
||||||
return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m)
|
|
||||||
elif rule == "sec,ecm->sm":
|
|
||||||
# [s, e*c] * [e*c, m]
|
|
||||||
return torch.matmul(a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1]))
|
|
||||||
elif rule == "ks,ksm->sm":
|
|
||||||
k = b.shape[0]
|
|
||||||
s = b.shape[1]
|
|
||||||
m = b.shape[2]
|
|
||||||
# [k, s] -> [s, k] -> [s, 1, k]
|
|
||||||
a = a.t().unsqueeze(1)
|
|
||||||
# [k,s,m] -> [k, sm] -> [sm, k] -> [s, m, k]
|
|
||||||
b = b.reshape(k, -1).t().reshape(s, m, k)
|
|
||||||
# bmm([s, 1, k], [s, m, k]^t) -> [s, m, 1]
|
|
||||||
return torch.bmm(a, b.transpose(1, 2)).squeeze(2)
|
|
||||||
else:
|
|
||||||
return torch.einsum(rule, a, b)
|
|
||||||
|
|
||||||
|
|
||||||
# The following functions are extracted and scripted
|
# The following functions are extracted and scripted
|
||||||
# because otherwise during a torch.jit.trace, the non-Tensor
|
# because otherwise during a torch.jit.trace, the non-Tensor
|
||||||
# values used in the calculations get recorded as constants.
|
# values used in the calculations get recorded as constants.
|
||||||
|
@ -405,18 +362,10 @@ class MOELayer(Base):
|
||||||
# group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
|
# group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
|
||||||
reshaped_inputs = inputs[0].reshape(-1, d_model)
|
reshaped_inputs = inputs[0].reshape(-1, d_model)
|
||||||
|
|
||||||
self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_inputs, inputs[1])
|
if self.overlap_degree > 1 and not self.use_tutel:
|
||||||
dispatched_inputs = einsum(
|
combined_output, self.l_aux, self.exp_counts = overlap_moe_forward(
|
||||||
"sec,sm->ecm", dispatch_mask.type_as(inputs[0]), reshaped_inputs
|
reshaped_inputs,
|
||||||
) # TODO: heavy memory usage due to long sequence length
|
self.gate,
|
||||||
|
|
||||||
if self.overlap_degree == 1:
|
|
||||||
expert_output = no_overlap_moe_forward(
|
|
||||||
dispatched_inputs, self.experts, self.ep_group, self.ep_size, self.num_local_experts, d_model
|
|
||||||
)
|
|
||||||
elif self.overlap_degree > 1 and not self.use_tutel:
|
|
||||||
expert_output = overlap_moe_forward(
|
|
||||||
dispatched_inputs,
|
|
||||||
self.experts,
|
self.experts,
|
||||||
self.overlap_degree,
|
self.overlap_degree,
|
||||||
self.ep_group,
|
self.ep_group,
|
||||||
|
@ -424,17 +373,29 @@ class MOELayer(Base):
|
||||||
self.num_local_experts,
|
self.num_local_experts,
|
||||||
d_model,
|
d_model,
|
||||||
)
|
)
|
||||||
elif self.overlap_degree > 1 and self.use_tutel:
|
|
||||||
expert_output = tutel_overlap_moe_forward(
|
|
||||||
dispatched_inputs, self.experts, self.overlap_degree, True, self.ep_group
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
assert False, "unsupported moe forward strategy"
|
self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_inputs, inputs[1])
|
||||||
|
dispatched_inputs = einsum(
|
||||||
|
"sec,sm->ecm", dispatch_mask.type_as(inputs[0]), reshaped_inputs
|
||||||
|
) # TODO: heavy memory usage due to long sequence length
|
||||||
|
|
||||||
# Re-shape back: gecm -> ecm
|
if self.overlap_degree == 1:
|
||||||
expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)
|
expert_output = no_overlap_moe_forward(
|
||||||
|
dispatched_inputs, self.experts, self.ep_group, self.ep_size, self.num_local_experts, d_model
|
||||||
|
)
|
||||||
|
elif self.overlap_degree > 1 and self.use_tutel:
|
||||||
|
expert_output = tutel_overlap_moe_forward(
|
||||||
|
dispatched_inputs, self.experts, self.overlap_degree, True, self.ep_group
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert False, "unsupported moe forward strategy"
|
||||||
|
|
||||||
combined_output = einsum("sec,ecm->sm", combine_weights.type_as(inputs[0]), expert_output.type_as(inputs[0]))
|
# Re-shape back: gecm -> ecm
|
||||||
|
expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)
|
||||||
|
|
||||||
|
combined_output = einsum(
|
||||||
|
"sec,ecm->sm", combine_weights.type_as(inputs[0]), expert_output.type_as(inputs[0])
|
||||||
|
)
|
||||||
|
|
||||||
out = combined_output.reshape(inputs[0].shape)
|
out = combined_output.reshape(inputs[0].shape)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue