@@ -1,6 +1,7 @@
import warnings
from types import MethodType
from typing import Callable, Optional, OrderedDict, Tuple
import numpy as np
import torch
import torch.distributed as dist
@@ -64,6 +65,14 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
            overlap_communication = True
            warnings.warn(WARN_STR + " Please make sure of this.")
        self.param_info = param_info
        self.stage_manager = model.stage_manager
        self.shared_params = model.shared_params
        self.dp_pg = dp_process_group
        if use_pipeline:
            reinitialize_optimizer(optimizer, model)
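        # Route non-MoE parameters to the plain dp group and MoE (expert) parameters
        # to moe_dp_group, so the ZeRO optimizer can shard each set over its own group.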
        pg_param_list = {
            dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
            moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
@@ -116,17 +125,16 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
            raise NotImplementedError
        world_size = dist.get_world_size()
-        self.moe_dp_size = world_size // (ep_size * moe_tp_size)
+        self.moe_dp_size = world_size // (self.pp_size * ep_size * moe_tp_size)
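        # global ranks factor as pp_size * moe_dp_size * ep_size * moe_tp_size (validated below)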
        self.ep_size = ep_size
        self.moe_tp_size = moe_tp_size
        self.moe_pg_mesh = ProcessGroupMesh(self.moe_dp_size, self.ep_size, self.moe_tp_size)
        self.moe_dp_axis, self.ep_axis, self.moe_tp_axis = 0, 1, 2
        if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size != world_size:
            raise ValueError(
                f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size}"
            )
        self.moe_dp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_dp_axis)
        self.ep_group = self.moe_pg_mesh.get_group_along_axis(self.ep_axis)
        self.moe_tp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_tp_axis)
        self._init_moe_param_comm()
        self.logger.info(f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}", ranks=[0])
@@ -136,6 +144,52 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
        self.force_overlap_comm = force_overlap_comm
    def _init_moe_param_comm(self):
        self.moe_dp_group = None
        self.ep_group = None
        self.moe_tp_group = None
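        # dist.new_group is a collective call: every rank must create every group, in the
        # same order, even groups it will not join; each rank then keeps only the handles
        # for the groups it belongs to on its own pipeline stage.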
        # create submesh for ep, moe_dp, moe_tp
        ranks_by_pp_stage = self.pg_mesh.get_group_along_axis(
            [self.dp_axis, self.tp_axis, self.sp_axis], return_ranks_by_group=True
        )
        global_rank = self.pg_mesh.rank
        pp_rank = self.pg_mesh.coordinate(self.pp_axis)
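        # Illustration only (sizes assumed for the example): with 16 ranks, pp_size=2,
        # moe_dp_size=2, ep_size=2, moe_tp_size=2, if stage 0 holds global ranks 0-7 the
        # reshape below gives submesh = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]; the ep group for
        # (moe_dp=0, moe_tp=1) is then [1, 3] and the moe_dp group for (ep=1, moe_tp=0) is [2, 6].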
        # create groups from submesh
        for stage_idx, stage_rank in enumerate(ranks_by_pp_stage):
            # axis 0 is dp, axis 1 is tp, axis 2 is sp
            submesh = np.array(stage_rank).reshape(self.moe_dp_size, self.ep_size, self.moe_tp_size)
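            # submesh[d, e, t] is the global rank at (moe_dp=d, ep=e, moe_tp=t) on this stage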
            # hardcoded here since we only have 3 axes
            # moe_dp_group
            for ep_idx in range(self.ep_size):
                for moe_tp_idx in range(self.moe_tp_size):
                    moe_dp_ranks = submesh[:, ep_idx, moe_tp_idx].flatten().tolist()
                    group = dist.new_group(moe_dp_ranks)
                    if pp_rank == stage_idx and global_rank in moe_dp_ranks:
                        assert self.moe_dp_group is None
                        self.moe_dp_group = group
            # ep_group
            for moe_dp_idx in range(self.moe_dp_size):
                for moe_tp_idx in range(self.moe_tp_size):
                    ep_ranks = submesh[moe_dp_idx, :, moe_tp_idx].flatten().tolist()
                    group = dist.new_group(ep_ranks)
                    if pp_rank == stage_idx and global_rank in ep_ranks:
                        assert self.ep_group is None
                        self.ep_group = group
            # moe_tp_group
            for moe_dp_idx in range(self.moe_dp_size):
                for ep_idx in range(self.ep_size):
                    moe_tp_ranks = submesh[moe_dp_idx, ep_idx, :].flatten().tolist()
                    group = dist.new_group(moe_tp_ranks)
                    if pp_rank == stage_idx and global_rank in moe_tp_ranks:
                        assert self.moe_tp_group is None
                        self.moe_tp_group = group
        self.logger.info(f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}")
    def get_checkpoint_io(self) -> MoECheckpointIO:
        return MoECheckpointIO(
            self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage