# InternLM/internlm/moe/forward_func.py
import torch
from .communication import moe_all_to_all, moe_stream_acquire, moe_stream_release
def no_overlap_moe_forward(inputs, expert_fn, ep_group, ep_size, num_local_experts, d_model):
    """Run the MoE forward pass sequentially, with no communication/compute overlap.

    The pipeline is: all-to-all dispatch -> expert computation -> all-to-all
    combine, executed one after another on the current stream, i.e.:
        alltoall(d) ----> expert_fn(d) ----> alltoall(d)

    Args:
        inputs: dispatched token tensor laid out as (e, c, m).
        expert_fn: callable that applies the local experts to a (g, e, c, m) tensor.
        ep_group: expert-parallel process group used by the all-to-all.
        ep_size: number of ranks in the expert-parallel group (g).
        num_local_experts: number of experts hosted on this rank (e).
        d_model: hidden dimension (m).

    Returns:
        Expert output after the combining (second) all-to-all.
    """
    dispatched = moe_all_to_all.apply(ep_group, inputs)
    # Re-shape after all-to-all: ecm -> gecm
    dispatched = dispatched.reshape(ep_size, num_local_experts, -1, d_model)
    computed = expert_fn(dispatched)
    # Second all-to-all sends each expert's output back to its source rank.
    return moe_all_to_all.apply(ep_group, computed)
def overlap_moe_forward(inputs, expert_fn, a2a_ffn_overlap_degree, ep_group, ep_size, num_local_experts, d_model):
    """
    Split the input along the capacity ('c') dimension into ``a2a_ffn_overlap_degree``
    chunks, then run each chunk's all-to-all and expert computation on dedicated
    CUDA streams so communication and computation overlap across chunks.

    For example:
        communication stream: alltoall(d[0])---->alltoall(d[1])---->alltoall(d[0])---->alltoall(d[1])
        computation   stream: expert_fn(d[0]) ----> expert_fn(d[1])

    Args:
        inputs: dispatched token tensor laid out as (e, c, m).
        expert_fn: callable that applies the local experts to a (g, e, c, m) tensor.
        a2a_ffn_overlap_degree: number of chunks / pipeline stages.
        ep_group: expert-parallel process group used by the all-to-all.
        ep_size: number of ranks in the expert-parallel group (g).
        num_local_experts: number of experts hosted on this rank (e).
        d_model: hidden dimension (m).

    Returns:
        Combined expert output with the per-chunk results concatenated back
        along the 'c' dimension (dim=2 of the (g, e, c, m) layout).
    """
    # inputs shape: (e,c,m). split the inputs on 'c' dimension
    input_chunks = inputs.chunk(a2a_ffn_overlap_degree, dim=1)
    expert_inputs = [None for _ in range(a2a_ffn_overlap_degree)]
    expert_outputs = [None for _ in range(a2a_ffn_overlap_degree)]
    # One event and one stream pair per chunk. NOTE(review): moe_stream_release /
    # moe_stream_acquire presumably record/wait on these events to hand a chunk
    # off between streams while keeping autograd ordering intact — confirm
    # against their implementation in .communication.
    ready_events = [torch.cuda.Event() for _ in range(a2a_ffn_overlap_degree)]
    alltoall_stream = [torch.cuda.Stream(torch.cuda.current_device()) for _ in range(a2a_ffn_overlap_degree)]
    experts_stream = [torch.cuda.Stream(torch.cuda.current_device()) for _ in range(a2a_ffn_overlap_degree)]

    # NOTE: async alltoall seems unable to improve the performance
    # first all2all, execute on alltoall streams
    for i, input_split in enumerate(input_chunks):
        # Hand chunk i from the default stream over to its all-to-all stream.
        moe_stream_release.apply(torch.cuda.default_stream(), ready_events[i])
        moe_stream_acquire.apply(alltoall_stream[i], ready_events[i])
        expert_inputs[i] = moe_all_to_all.apply(ep_group, input_split)
        moe_stream_release.apply(alltoall_stream[i], ready_events[i])

    # expert function, execute on experts stream
    for i in range(a2a_ffn_overlap_degree):
        moe_stream_acquire.apply(experts_stream[i], ready_events[i])
        # Re-shape after all-to-all: ecm -> gecm
        expert_inputs[i] = expert_inputs[i].reshape(ep_size, num_local_experts, -1, d_model)
        expert_outputs[i] = expert_fn(expert_inputs[i])
        moe_stream_release.apply(experts_stream[i], ready_events[i])

    # second all2all, execute on alltoall streams
    for i in range(a2a_ffn_overlap_degree):
        moe_stream_acquire.apply(alltoall_stream[i], ready_events[i])
        expert_outputs[i] = moe_all_to_all.apply(ep_group, expert_outputs[i])
        moe_stream_release.apply(alltoall_stream[i], ready_events[i])
        # Return chunk i to the default stream before the final concatenation.
        moe_stream_acquire.apply(torch.cuda.default_stream(), ready_events[i])

    # expert_outputs shape: (g, e, c, m). cat the outputs on 'c' dimension
    expert_output_gathered = torch.cat(expert_outputs, dim=2)
    return expert_output_gathered