# Mirror of https://github.com/InternLM/InternLM
import torch
from .communication import moe_all_to_all, moe_stream_acquire, moe_stream_release
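# `moe_all_to_all`, `moe_stream_acquire` and `moe_stream_release` are provided by
# `.communication`. For orientation only: an all-to-all autograd Function of this
# kind is typically a thin wrapper around `dist.all_to_all_single` whose backward
# is the inverse all-to-all. An illustrative sketch, not the actual implementation:
#
#     import torch.distributed as dist
#
#     class _AllToAllSketch(torch.autograd.Function):
#         @staticmethod
#         def forward(ctx, group, inputs):
#             ctx.group = group
#             inputs = inputs.contiguous()
#             output = torch.empty_like(inputs)
#             dist.all_to_all_single(output, inputs, group=group)
#             return output
#
#         @staticmethod
#         def backward(ctx, grad_output):
#             # The gradient of an all-to-all is the inverse all-to-all.
#             return None, _AllToAllSketch.apply(ctx.group, grad_output)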
def no_overlap_moe_forward(inputs, expert_fn, ep_group, ep_size, num_local_experts, d_model):
    """
    Perform the MoE forward computation sequentially.

    For example:

        alltoall(d) ----> expert_fn(d) ----> alltoall(d)
    """
    # Dispatch: send each token to the rank that hosts its expert.
    inputs = moe_all_to_all.apply(ep_group, inputs)

    # Re-shape after all-to-all: ecm -> gecm
    inputs = inputs.reshape(ep_size, num_local_experts, -1, d_model)
    expert_output = expert_fn(inputs)

    # Combine: return each expert output to the token's source rank.
    expert_output = moe_all_to_all.apply(ep_group, expert_output)

    return expert_output
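# A minimal usage sketch (hedged: assumes an initialized expert-parallel process
# group `ep_group`; `ffn` and `capacity` are caller-supplied placeholders, not
# part of this module):
#
#     e = ep_size * num_local_experts  # total number of experts
#     inputs = torch.randn(e, capacity, d_model, device="cuda")
#     out = no_overlap_moe_forward(
#         inputs, ffn, ep_group=ep_group, ep_size=ep_size,
#         num_local_experts=num_local_experts, d_model=d_model,
#     )
#     # `out` holds each token's expert output back on its source rank.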
def overlap_moe_forward(inputs, expert_fn, a2a_ffn_overlap_degree, ep_group, ep_size, num_local_experts, d_model):
    """
    Split the input into a2a_ffn_overlap_degree chunks, then execute the alltoall and expert
    functions on different streams so that the communication cost overlaps with the computation cost.

    For example:

        communication stream: alltoall(d[0]) ----> alltoall(d[1]) ----> alltoall(d[0]) ----> alltoall(d[1])
        computation stream:                        expert_fn(d[0]) ----> expert_fn(d[1])
    """
    # inputs shape: (e, c, m). Split the inputs along the 'c' (capacity) dimension.
    input_chunks = inputs.chunk(a2a_ffn_overlap_degree, dim=1)
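    # Splitting along 'c' keeps the 'e' dimension intact, so each chunk can run
    # its own all-to-all independently; the chunks only need to be concatenated
    # back along 'c' at the end.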
    expert_inputs = [None for _ in range(a2a_ffn_overlap_degree)]
    expert_outputs = [None for _ in range(a2a_ffn_overlap_degree)]

    ready_events = [torch.cuda.Event() for _ in range(a2a_ffn_overlap_degree)]
    alltoall_stream = [torch.cuda.Stream(torch.cuda.current_device()) for _ in range(a2a_ffn_overlap_degree)]
    experts_stream = [torch.cuda.Stream(torch.cuda.current_device()) for _ in range(a2a_ffn_overlap_degree)]
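    # Scheduling note (inferred from the names): `moe_stream_release` records a
    # chunk's `ready_events[i]` on the given stream, and `moe_stream_acquire`
    # makes a stream wait on that event, so each chunk is handed off
    # default stream -> alltoall stream -> experts stream -> alltoall stream
    # -> default stream without a global synchronization.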
    # NOTE: async alltoall seems unable to improve the performance
    # First all2all, executed on the alltoall streams.
    for i, input_split in enumerate(input_chunks):
        moe_stream_release.apply(torch.cuda.default_stream(), ready_events[i])

        moe_stream_acquire.apply(alltoall_stream[i], ready_events[i])
        expert_inputs[i] = moe_all_to_all.apply(ep_group, input_split)
        moe_stream_release.apply(alltoall_stream[i], ready_events[i])
    # Expert function, executed on the experts streams.
    for i in range(a2a_ffn_overlap_degree):
        moe_stream_acquire.apply(experts_stream[i], ready_events[i])
        # Re-shape after all-to-all: ecm -> gecm
        expert_inputs[i] = expert_inputs[i].reshape(ep_size, num_local_experts, -1, d_model)
        expert_outputs[i] = expert_fn(expert_inputs[i])
        moe_stream_release.apply(experts_stream[i], ready_events[i])
    # Second all2all, executed on the alltoall streams.
    for i in range(a2a_ffn_overlap_degree):
        moe_stream_acquire.apply(alltoall_stream[i], ready_events[i])
        expert_outputs[i] = moe_all_to_all.apply(ep_group, expert_outputs[i])
        moe_stream_release.apply(alltoall_stream[i], ready_events[i])

        # Hand the finished chunk back to the default stream.
        moe_stream_acquire.apply(torch.cuda.default_stream(), ready_events[i])

    # expert_outputs shape: (g, e, c, m). Concatenate the outputs along the 'c' dimension.
    expert_output_gathered = torch.cat(expert_outputs, dim=2)
    return expert_output_gathered
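# A minimal usage sketch of the overlapped path (hedged: `ep_group`, `ffn` and
# `capacity` are caller-supplied placeholders; `capacity` is ideally divisible
# by `a2a_ffn_overlap_degree` so the overlapped chunks stay evenly sized):
#
#     e = ep_size * num_local_experts
#     inputs = torch.randn(e, capacity, d_model, device="cuda")
#     out = overlap_moe_forward(
#         inputs, ffn, a2a_ffn_overlap_degree=2, ep_group=ep_group,
#         ep_size=ep_size, num_local_experts=num_local_experts, d_model=d_model,
#     )
#     # `out` matches the layout returned by `no_overlap_moe_forward`.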