@@ -15,6 +15,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
    HybridParallelModule,
    HybridParallelNaiveOptimizer,
    HybridParallelPlugin,
    HybridParallelZeroOptimizer,
    get_param_info,
    reinitialize_optimizer,
)
@@ -22,16 +23,18 @@ from colossalai.checkpoint_io import MoECheckpointIO
from colossalai.cluster.process_group_mesh import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.zero.low_level import LowLevelZeroOptimizer
class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
    def __init__(
        self,
        optimizer: Optimizer,
        model: Module,
        use_pipeline: bool,
        force_overlap_comm: bool,  # force overlap comm
        dp_process_group: ProcessGroup,  # dp pg for comm
        dp_process_group: Optional[ProcessGroup],  # the dp pg for comm
        tp_process_group: Optional[ProcessGroup],  # if using tp
        pp_process_group: Optional[ProcessGroup],  # if using pp
        moe_dp_group: ProcessGroup,  # moe dp pg for comm
        param_info: OrderedDict,
        initial_scale: int = 2**16,  # grad scaler config
@@ -49,32 +52,28 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
        partition_grad: bool = False,  # stage 2 flag
        cpu_offload: bool = False,  # cpu offload
        forced_dtype: Optional[torch.dtype] = None,
    ):
    ):
        WARN_STR = "Note that you need to make sure every expert is routed (i.e. every expert has a backward pass); otherwise this might lead to a program hang or inconsistent results."
        if not force_overlap_comm and (overlap_communication or partition_grad):
            raise RuntimeError(WARN_STR + " If you are not sure about this, set (overlap_communication=False and partition_grad=False) or force_overlap_comm=True")
            raise RuntimeError(
                WARN_STR
                + " If you are not sure about this, set (overlap_communication=False and partition_grad=False) or force_overlap_comm=True"
            )
        if force_overlap_comm:
            overlap_communication = True
            warnings.warn(WARN_STR + " Please make sure of this.")
        self.param_info = param_info
        self.stage_manager = model.stage_manager
        self.shared_params = model.shared_params
        self.dp_pg = dp_process_group
        if use_pipeline:
            reinitialize_optimizer(optimizer, model)
        pg_param_list = {
            dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
            moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
        }
        super().__init__(
            model=model,
            optimizer=optimizer,
            pg_to_param_list=pg_param_list,
            use_pipeline=use_pipeline,
            param_info=param_info,
            initial_scale=initial_scale,
            min_scale=min_scale,
            growth_factor=growth_factor,
@@ -89,7 +88,12 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
            overlap_communication=overlap_communication,
            partition_grad=partition_grad,
            cpu_offload=cpu_offload,
            # dp_process_group=dp_process_group,
            tp_process_group=tp_process_group,
            pp_process_group=pp_process_group,
            forced_dtype=forced_dtype,
            ## moe args
            pg_to_param_list=pg_param_list,
        )
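The mapping built in `pg_param_list` is the crux of this change: every parameter is reduced over the process group it belongs to, so expert weights sync across the (smaller) MoE data-parallel group while all other weights sync across the regular data-parallel group. Below is a minimal, self-contained sketch of that idea only, not the ColossalAI implementation; the `is_moe_param` predicate stands in for `is_moe_tensor` from the imports above, and the group handles are placeholders.

# --- illustration only, not part of the patch ---
import torch.distributed as dist
from torch import nn

def build_pg_to_param_list(model: nn.Module, dp_group, moe_dp_group, is_moe_param):
    # Same bucketing as pg_param_list in __init__ above:
    # non-expert params -> regular dp group, expert params -> MoE dp group.
    return {
        dp_group: [p for p in model.parameters() if not is_moe_param(p)],
        moe_dp_group: [p for p in model.parameters() if is_moe_param(p)],
    }

def reduce_grads_per_group(pg_to_param_list):
    # Average each parameter's gradient only within its own process group,
    # roughly what the ZeRO optimizer's gradient handling needs per bucket.
    for group, params in pg_to_param_list.items():
        world_size = dist.get_world_size(group)
        for p in params:
            if p.grad is not None:
                dist.all_reduce(p.grad, group=group)  # SUM by default
                p.grad.div_(world_size)
# --- end illustration ---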
@@ -180,7 +184,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                        optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
                    )
            else:
                if not (self.dp_size > 1 or self.moe_dp_size > 1):
                if not (self.dp_size > 1 or self.moe_dp_size > 1):
                    warnings.warn(
                        "Using the Zero Optimizer when the data parallel size is 1 may introduce unnecessary overhead. "
                        "If you do not intend to use cpu_offload, please consider setting zero_stage=0."
@@ -192,6 +196,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                    force_overlap_comm=self.force_overlap_comm,
                    param_info=param_info,
                    dp_process_group=self.dp_group,
                    tp_process_group=self.tp_group,
                    pp_process_group=self.pp_group,
                    moe_dp_group=self.moe_dp_group,
                    verbose=True,
                    clip_grad_norm=self.max_norm,
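For context, here is a rough usage sketch of the plugin path that reaches MoeHybridParallelZeroOptimizer. It is an assumption-laden example, not taken from the patch: the constructor keywords (tp_size, pp_size, ep_size, zero_stage, overlap_communication, force_overlap_comm) are inferred from this diff and the HybridParallelPlugin convention and may differ across versions, and my_moe_model / my_optimizer are placeholders.

# --- illustration only, not part of the patch ---
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

# Assumes a distributed environment is already initialized under torchrun,
# e.g. via colossalai.launch_from_torch(...).
plugin = MoeHybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    ep_size=2,                  # expert parallelism for the MoE layers
    zero_stage=1,               # takes the MoeHybridParallelZeroOptimizer branch above
    overlap_communication=False,
    # force_overlap_comm=True would re-enable overlap and warn instead of raising;
    # see the __init__ check earlier in this patch.
)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model=my_moe_model, optimizer=my_optimizer)
# --- end illustration ---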