diff --git a/internlm/moe/forward_func.py b/internlm/moe/forward_func.py
index 13e2ad5..b153895 100644
--- a/internlm/moe/forward_func.py
+++ b/internlm/moe/forward_func.py
@@ -2,6 +2,48 @@ import torch
 
 from .communication import moe_all_to_all, moe_stream_acquire, moe_stream_release
 
+# einsum rewrites are on par with or more performant than the manual rewrites below
+# the switch can be bubbled up in the future
+USE_EINSUM = True
+
+
+# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
+# See https://arxiv.org/pdf/2006.16668.pdf for details.
+def einsum(rule, a, b):
+    if USE_EINSUM:
+        return torch.einsum(rule, a, b)
+    elif rule == "s,se->se":
+        # [1, s] * [s, e]
+        return a.reshape(a.shape[0], -1) * b
+    elif rule == "se,sc->sec":
+        # [s,e,1] * [s,1,c]
+        return a.unsqueeze(2) * b.unsqueeze(1)
+    elif rule == "se,se->s":
+        # [s,1,e] * [s,e,1]
+        return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1)
+    elif rule == "sec,sm->ecm":
+        # [e*c, s] * [s, m]
+        s = a.shape[0]
+        e = a.shape[1]
+        c = a.shape[2]
+        m = b.shape[1]
+        return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m)
+    elif rule == "sec,ecm->sm":
+        # [s, e*c] * [e*c, m]
+        return torch.matmul(a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1]))
+    elif rule == "ks,ksm->sm":
+        k = b.shape[0]
+        s = b.shape[1]
+        m = b.shape[2]
+        # [k, s] -> [s, k] -> [s, 1, k]
+        a = a.t().unsqueeze(1)
+        # [k,s,m] -> [k, sm] -> [sm, k] -> [s, m, k]
+        b = b.reshape(k, -1).t().reshape(s, m, k)
+        # bmm([s, 1, k], [s, m, k]^t) -> [s, m, 1]
+        return torch.bmm(a, b.transpose(1, 2)).squeeze(2)
+    else:
+        return torch.einsum(rule, a, b)
+
 
 def no_overlap_moe_forward(inputs, expert_fn, ep_group, ep_size, num_local_experts, d_model):
     """
@@ -21,7 +63,9 @@ def no_overlap_moe_forward(inputs, expert_fn, ep_group, ep_size, num_local_exper
 
     return expert_output
 
 
-def overlap_moe_forward(inputs, expert_fn, a2a_ffn_overlap_degree, ep_group, ep_size, num_local_experts, d_model):
+def overlap_moe_forward(
+    reshaped_inputs, gate_fn, expert_fn, a2a_ffn_overlap_degree, ep_group, ep_size, num_local_experts, d_model
+):
     """
     Split the input based on a2a_ffn_overlap_degree, then execute the all-to-all and expert functions on different
     streams to overlap the communication and computation cost.
@@ -31,23 +75,38 @@ def overlap_moe_forward(inputs, expert_fn, a2a_ffn_overlap_degree, ep_group, ep_
 
     """
 
-    # inputs shape: (e,c,m). split the inputs on 'c' dimension
-    input_chunks = inputs.chunk(a2a_ffn_overlap_degree, dim=1)
-
-    expert_inputs = [None for _ in range(a2a_ffn_overlap_degree)]
-    expert_outputs = [None for _ in range(a2a_ffn_overlap_degree)]
-
+    # variables for stream control
     ready_events = [torch.cuda.Event() for _ in range(a2a_ffn_overlap_degree)]
     alltoall_stream = [torch.cuda.Stream(torch.cuda.current_device()) for _ in range(a2a_ffn_overlap_degree)]
     experts_stream = [torch.cuda.Stream(torch.cuda.current_device()) for _ in range(a2a_ffn_overlap_degree)]
 
-    # NOTE: async alltoall seems unable to improve the performance
-    # first all2all, execute on alltoall streams
+    # local variables for gating and expert computation
+    l_aux = torch.tensor(0.0, dtype=reshaped_inputs.dtype, device=reshaped_inputs.device)
+    dispatched_inputs = [None for _ in range(a2a_ffn_overlap_degree)]
+    expert_inputs = [None for _ in range(a2a_ffn_overlap_degree)]
+    expert_outputs = [None for _ in range(a2a_ffn_overlap_degree)]
+    combine_weights = [None for _ in range(a2a_ffn_overlap_degree)]
+    combined_output = [None for _ in range(a2a_ffn_overlap_degree)]
+
+    # (s,d), split along the "s" dimension
+    input_chunks = reshaped_inputs.chunk(a2a_ffn_overlap_degree, dim=0)
+
+    # gating computation
     for i, input_split in enumerate(input_chunks):
         moe_stream_release.apply(torch.cuda.default_stream(), ready_events[i])
+        moe_stream_acquire.apply(experts_stream[i], ready_events[i])
+        cur_l_aux, combine_weights[i], dispatch_mask, exp_counts = gate_fn(input_split)
+        dispatched_inputs[i] = einsum(
+            "sec,sm->ecm", dispatch_mask.type_as(input_split), input_split
+        )  # TODO: heavy memory usage due to long sequence length
+        l_aux += cur_l_aux
+        moe_stream_release.apply(experts_stream[i], ready_events[i])
+
+        # NOTE: async alltoall seems unable to improve the performance
+        # first all2all, execute on alltoall streams
         moe_stream_acquire.apply(alltoall_stream[i], ready_events[i])
-        expert_inputs[i] = moe_all_to_all.apply(ep_group, input_split)
+        expert_inputs[i] = moe_all_to_all.apply(ep_group, dispatched_inputs[i])
         moe_stream_release.apply(alltoall_stream[i], ready_events[i])
 
     # expert function, execute on experts stream
@@ -64,9 +123,16 @@ def overlap_moe_forward(inputs, expert_fn, a2a_ffn_overlap_degree, ep_group, ep_
         expert_outputs[i] = moe_all_to_all.apply(ep_group, expert_outputs[i])
         moe_stream_release.apply(alltoall_stream[i], ready_events[i])
 
+    for i in range(a2a_ffn_overlap_degree):
+        moe_stream_acquire.apply(experts_stream[i], ready_events[i])
+        # Re-shape back: gecm -> ecm
+        expert_outputs[i] = expert_outputs[i].reshape(ep_size * num_local_experts, -1, d_model)
+        combined_output[i] = einsum(
+            "sec,ecm->sm", combine_weights[i].type_as(input_chunks[0]), expert_outputs[i].type_as(input_chunks[0])
+        )
+        moe_stream_release.apply(experts_stream[i], ready_events[i])
+        moe_stream_acquire.apply(torch.cuda.default_stream(), ready_events[i])
 
-    # expert_outputs shape: (g, e,c,m). cat the outputs on 'c' dimension
-    expert_output_gathered = torch.cat(expert_outputs, dim=2)
-
-    return expert_output_gathered
+    combined_output = torch.cat(combined_output)
+    return combined_output, l_aux / a2a_ffn_overlap_degree, exp_counts
diff --git a/internlm/moe/sharded_moe.py b/internlm/moe/sharded_moe.py
index 211c610..dc36115 100644
--- a/internlm/moe/sharded_moe.py
+++ b/internlm/moe/sharded_moe.py
@@ -16,7 +16,7 @@ from internlm.core.context.parallel_context import ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.utils.logger import get_logger
 
-from .forward_func import no_overlap_moe_forward, overlap_moe_forward
+from .forward_func import einsum, no_overlap_moe_forward, overlap_moe_forward
 
 try:
     from tutel.impls.overlap import a2a_ffn_overlap_forward as tutel_overlap_moe_forward
@@ -72,49 +72,6 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
     return gumbel(shape)
 
 
-# einsum rewrites are on par or more performant
-# switch can be bubbled up in future
-USE_EINSUM = True
-
-
-# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
-# See https://arxiv.org/pdf/2006.16668.pdf for details.
-def einsum(rule, a, b):
-    if USE_EINSUM:
-        return torch.einsum(rule, a, b)
-    elif rule == "s,se->se":
-        # [1, s] * [s, e]
-        return a.reshape(a.shape[0], -1) * b
-    elif rule == "se,sc->sec":
-        # [s,e,1] * [s,1,c]
-        return a.unsqueeze(2) * b.unsqueeze(1)
-    elif rule == "se,se->s":
-        # [s,1,e] * [s,e,1]
-        return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1)
-    elif rule == "sec,sm->ecm":
-        # [e*c, s] * [s, m]
-        s = a.shape[0]
-        e = a.shape[1]
-        c = a.shape[2]
-        m = b.shape[1]
-        return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m)
-    elif rule == "sec,ecm->sm":
-        # [s, e*c] * [e*c, m]
-        return torch.matmul(a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1]))
-    elif rule == "ks,ksm->sm":
-        k = b.shape[0]
-        s = b.shape[1]
-        m = b.shape[2]
-        # [k, s] -> [s, k] -> [s, 1, k]
-        a = a.t().unsqueeze(1)
-        # [k,s,m] -> [k, sm] -> [sm, k] -> [s, m, k]
-        b = b.reshape(k, -1).t().reshape(s, m, k)
-        # bmm([s, 1, k], [s, m, k]^t) -> [s, m, 1]
-        return torch.bmm(a, b.transpose(1, 2)).squeeze(2)
-    else:
-        return torch.einsum(rule, a, b)
-
-
 # The following functions are extracted and scripted
 # because otherwise during a torch.jit.trace, the non-Tensor
 # values used in the calculations get recorded as constants.
@@ -405,18 +362,10 @@ class MOELayer(Base):
         # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
         reshaped_inputs = inputs[0].reshape(-1, d_model)
 
-        self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_inputs, inputs[1])
-        dispatched_inputs = einsum(
-            "sec,sm->ecm", dispatch_mask.type_as(inputs[0]), reshaped_inputs
-        )  # TODO: heavy memory usage due to long sequence length
-
-        if self.overlap_degree == 1:
-            expert_output = no_overlap_moe_forward(
-                dispatched_inputs, self.experts, self.ep_group, self.ep_size, self.num_local_experts, d_model
-            )
-        elif self.overlap_degree > 1 and not self.use_tutel:
-            expert_output = overlap_moe_forward(
-                dispatched_inputs,
+        if self.overlap_degree > 1 and not self.use_tutel:
+            combined_output, self.l_aux, self.exp_counts = overlap_moe_forward(
+                reshaped_inputs,
+                self.gate,
                 self.experts,
                 self.overlap_degree,
                 self.ep_group,
@@ -424,17 +373,29 @@ class MOELayer(Base):
                 self.num_local_experts,
                 d_model,
             )
-        elif self.overlap_degree > 1 and self.use_tutel:
-            expert_output = tutel_overlap_moe_forward(
-                dispatched_inputs, self.experts, self.overlap_degree, True, self.ep_group
-            )
         else:
-            assert False, "unsupported moe forward strategy"
+            self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_inputs, inputs[1])
+            dispatched_inputs = einsum(
+                "sec,sm->ecm", dispatch_mask.type_as(inputs[0]), reshaped_inputs
+            )  # TODO: heavy memory usage due to long sequence length
 
-        # Re-shape back: gecm -> ecm
-        expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)
+            if self.overlap_degree == 1:
+                expert_output = no_overlap_moe_forward(
+                    dispatched_inputs, self.experts, self.ep_group, self.ep_size, self.num_local_experts, d_model
+                )
+            elif self.overlap_degree > 1 and self.use_tutel:
+                expert_output = tutel_overlap_moe_forward(
+                    dispatched_inputs, self.experts, self.overlap_degree, True, self.ep_group
+                )
+            else:
+                assert False, "unsupported moe forward strategy"
 
-        combined_output = einsum("sec,ecm->sm", combine_weights.type_as(inputs[0]), expert_output.type_as(inputs[0]))
+            # Re-shape back: gecm -> ecm
+            expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)
+
+            combined_output = einsum(
+                "sec,ecm->sm", combine_weights.type_as(inputs[0]), expert_output.type_as(inputs[0])
+            )
 
         out = combined_output.reshape(inputs[0].shape)
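As a quick sanity check of the moved `einsum` helper, the two rules this patch exercises, `"sec,sm->ecm"` for dispatch and `"sec,ecm->sm"` for combine, can be verified against `torch.einsum` directly. This standalone snippet is illustrative and not part of the patch:

```python
import torch

# Verify the manual rewrites used by einsum() when USE_EINSUM is False.
s, e, c, m = 8, 4, 2, 16  # (s)equence, (e)xpert, (c)apacity, (m)odel; arbitrary sizes

dispatch_mask = torch.randint(0, 2, (s, e, c)).float()
tokens = torch.randn(s, m)

# "sec,sm->ecm": scatter tokens into per-expert capacity slots
ref = torch.einsum("sec,sm->ecm", dispatch_mask, tokens)
manual = torch.matmul(dispatch_mask.reshape(s, -1).t(), tokens).reshape(e, c, m)
assert torch.allclose(ref, manual, atol=1e-6)

# "sec,ecm->sm": gather expert outputs back per token, weighted by the gate
combine_weights = torch.rand(s, e, c)
expert_out = torch.randn(e, c, m)
ref2 = torch.einsum("sec,ecm->sm", combine_weights, expert_out)
manual2 = torch.matmul(combine_weights.reshape(s, -1), expert_out.reshape(-1, m))
assert torch.allclose(ref2, manual2, atol=1e-6)
```

Both rewrites are a single `matmul` over a flattened `e*c` axis, which is why they stay on par with `torch.einsum`.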
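The `moe_stream_release`/`moe_stream_acquire` pairs implement an event handshake between the default stream and the private all-to-all/expert streams. Below is a rough, forward-only sketch of that pattern in plain PyTorch, under the assumption that release records an event on the producer stream and acquire makes the consumer stream wait on it; the real ops in `internlm/moe/communication.py` are autograd Functions that also order the backward pass:

```python
import torch

def handoff(producer: torch.cuda.Stream, consumer: torch.cuda.Stream, event: torch.cuda.Event):
    """Make `consumer` wait until all work queued so far on `producer` finishes."""
    event.record(producer)
    consumer.wait_event(event)

if torch.cuda.is_available():
    default = torch.cuda.default_stream()
    side = torch.cuda.Stream()
    ev = torch.cuda.Event()

    x = torch.randn(1024, 1024, device="cuda")
    y = x @ x  # queued on the default stream

    handoff(default, side, ev)
    with torch.cuda.stream(side):
        z = y.relu()  # safe: `side` waits for the matmul to complete
    y.record_stream(side)  # keep the allocator from reusing y's memory early

    torch.cuda.synchronize()
```

With one event per chunk, as in `overlap_moe_forward`, chunk `i+1`'s gating can proceed on its own stream while chunk `i` is still in flight in the all-to-all.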
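For reference, a hypothetical call site for the new `overlap_moe_forward` signature (`hidden`, `moe_layer`, and the surrounding names are illustrative, not from the repo). Per chunk, `gate_fn` must return `(l_aux, combine_weights, dispatch_mask, exp_counts)`, matching how `MOELayer` passes `self.gate`:

```python
# Hypothetical wiring; mirrors the call MOELayer.forward now makes.
combined, l_aux, exp_counts = overlap_moe_forward(
    hidden.reshape(-1, d_model),  # (s, d) tokens, chunked internally along dim 0
    moe_layer.gate,               # gate_fn
    moe_layer.experts,            # expert_fn
    2,                            # a2a_ffn_overlap_degree
    ep_group,
    ep_size,
    num_local_experts,
    d_model,
)
# l_aux is averaged over the chunks; exp_counts reflects only the last chunk.
```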