@@ -42,7 +42,6 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
         optimizer: Optimizer,
         model: Module,
         use_pipeline: bool,
-        force_overlap_comm: bool,  # force overlap comm
         dp_process_group: Optional[ProcessGroup],  # the dp pg for comm
         tp_process_group: Optional[ProcessGroup],  # if using tp
         pp_process_group: Optional[ProcessGroup],  # if using pp
@@ -65,17 +64,6 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
         forced_dtype: Optional[torch.dtype] = None,
         overlap_allgather: bool = False,
     ):
-        WARN_STR = "Note that you need to make sure every expert are routed (i.e.) every expert has backward, otherwise this might lead to program hang or inconsistent result."
-        if not force_overlap_comm and (overlap_communication or partition_grad):
-            raise RuntimeError(
-                WARN_STR
-                + " If you are not sure about this, set (overlap_communication=False and partition_grad=False) or force_overlap_comm=True"
-            )
-
-        if force_overlap_comm:
-            overlap_communication = True
-            warnings.warn(WARN_STR + " Please make sure of this.")
-
         pg_param_list = {
             dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
             moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
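The retained `pg_param_list` split is the crux of this class: dense parameters are reduced over the plain DP group, expert parameters over the MoE DP group. Below is a minimal sketch of that split with `is_moe_tensor` stubbed out; the `is_moe` marker attribute is an assumption for illustration, not the library's real MoE-tensor check.

```python
# Sketch of the per-group parameter bucketing done by pg_param_list above.
# `is_moe_tensor` is a stand-in: the real library inspects its own MoE
# tensor metadata, and the dict keys are torch.distributed ProcessGroups.
import torch.nn as nn

def is_moe_tensor(p: nn.Parameter) -> bool:
    return getattr(p, "is_moe", False)  # hypothetical marker attribute

dense = nn.Linear(4, 4)
expert = nn.Linear(4, 4)
for p in expert.parameters():
    p.is_moe = True  # mark expert weights the way an MoE layer would

params = list(dense.parameters()) + list(expert.parameters())
pg_param_list = {
    "dp_group": [p for p in params if not is_moe_tensor(p)],   # replicated params
    "moe_dp_group": [p for p in params if is_moe_tensor(p)],   # expert params
}
assert len(pg_param_list["moe_dp_group"]) == 2  # expert weight + bias
```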
@@ -116,9 +104,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
     Modified from colossalai.booster.plugin.hybrid_parallel_plugin.HybridParallelPlugin
     Extra Args:
         ep_size (int): The size of expert parallelism
-        force_overlap_comm (bool):
-            For LowLevelZeroOptimizer, it might causes program hang when some experts are routed and overlap_communication is True during training.
-            This flag is used to force overlap_communication = True. Make sure every expert are routed when you use this.
     """

     def __init__(
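With the flag gone from the docstring and signature, callers no longer reason about `force_overlap_comm` at all. A hypothetical usage sketch under the new API, assuming a distributed environment already launched via `colossalai.launch` (the argument values are illustrative only):

```python
# Hypothetical usage after this change: pick a zero_stage and let the
# plugin downgrade it (with a warning) if it is unsafe for MoE.
from colossalai.booster import Booster
from colossalai.booster.plugin import MoeHybridParallelPlugin

plugin = MoeHybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    ep_size=2,      # expert parallelism over 2 ranks
    zero_stage=1,   # asking for 2 would be downgraded to 1 with a warning
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
```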
@@ -167,8 +152,15 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         dp_outside: bool = True,
         overlap_p2p: bool = True,
         overlap_allgather: bool = False,
-        force_overlap_comm: bool = False,
     ) -> None:
+        if overlap_communication or zero_stage == 2:
+            overlap_communication = False
+            zero_stage = 1
+            warnings.warn(
+                "overlap_communication and zero_stage are set to False and 1 because "
+                "ZeRO-2 or comm overlap may cause the program to hang when some experts are not routed."
+            )
+
         assert (
             dist.get_world_size() % (tp_size * pp_size) == 0
         ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
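This fallback is the behavioral heart of the change: instead of raising an error or trusting a force flag, the plugin downgrades the config and warns. A standalone sketch of the same logic showing the observable behavior; `resolve_zero_config` is a hypothetical helper extracted for illustration, not part of the library.

```python
# Reproduces the fallback added above: requesting ZeRO-2 (or comm overlap)
# is downgraded to ZeRO-1 without overlap, emitting a single warning.
import warnings

def resolve_zero_config(overlap_communication: bool, zero_stage: int):
    if overlap_communication or zero_stage == 2:
        overlap_communication = False
        zero_stage = 1
        warnings.warn(
            "overlap_communication and zero_stage are set to False and 1 because "
            "ZeRO-2 or comm overlap may cause the program to hang when some "
            "experts are not routed."
        )
    return overlap_communication, zero_stage

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert resolve_zero_config(True, 2) == (False, 1)   # downgraded, warns
    assert resolve_zero_config(False, 1) == (False, 1)  # already safe, silent
    assert len(caught) == 1
```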
@@ -326,7 +318,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         )
         self.max_norm = max_norm
-        self.force_overlap_comm = force_overlap_comm

     def get_checkpoint_io(self) -> MoECheckpointIO:
         return MoECheckpointIO(
@@ -421,7 +412,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             optimizer,
             model,
             use_pipeline=self.enable_pipeline_parallelism,
-            force_overlap_comm=self.force_overlap_comm,
             param_info=param_info,
             dp_process_group=dp_group,
             tp_process_group=self.tp_group,