@@ -1,9 +1,8 @@
import warnings
from collections import defaultdict
from types import MethodType
from typing import Callable, Optional, OrderedDict, Tuple
from typing import Callable, List, Optional, OrderedDict, Tuple
import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
@@ -13,6 +12,8 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from colossalai.booster.plugin.hybrid_parallel_plugin import (
PRECISION_TORCH_TYPE,
SUPPORT_SP_MODE,
HybridParallelAMPOptimizer,
HybridParallelModule,
HybridParallelNaiveOptimizer,
@@ -22,9 +23,16 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
reinitialize_optimizer,
)
from colossalai.checkpoint_io import MoECheckpointIO
from colossalai.cluster.process_group_mesh import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim
from colossalai.nn.optimizer import cast_to_distributed
from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.shardformer.shard.grad_ckpt_config import GradientCheckpointConfig
from colossalai.shardformer.shard.shard_config import ShardConfig
from colossalai.tensor.moe_tensor.api import is_moe_tensor
@@ -57,7 +65,7 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
forced_dtype: Optional[torch.dtype] = None,
overlap_allgather: bool = False,
):
WARN_STR = "Note that you need to make sure every expert are routed (i.e.) every expert has backward, otherwise this might lead to program hang or inconsistent result"
WARN_STR = "Note that you need to make sure every expert are routed (i.e.) every expert has backward, otherwise this might lead to program hang or inconsistent result."
if not force_overlap_comm and (overlap_communication or partition_grad):
raise RuntimeError(
WARN_STR
@@ -105,129 +113,218 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
class MoeHybridParallelPlugin(HybridParallelPlugin):
"""
TODO: add docstring
Modified from colossalai.booster.plugin.hybrid_parallel_plugin.HybridParallelPlugin
Extra Args:
ep_size (int): The size of expert parallelism; use HybridParallelPlugin instead when ep_size is 1.
force_overlap_comm (bool): For LowLevelZeroOptimizer, the program might hang when some experts are not routed and overlap_communication is True during training. This flag is used to force overlap_communication=True.
"""
def __init__(self, ep_size: int, moe_tp_size: int = 1, force_overlap_comm=False, *args, **kwargs) -> None:
if "overlap_communication" not in kwargs:
kwargs["overlap_communication"] = False  # defaults to True in the superclass
super().__init__(*args, **kwargs)
if ep_size <= 1:
raise ValueError("Use HybridParallelPlugin when ep_size <= 1")
def __init__(
self,
tp_size: int,
pp_size: int,
ep_size: int,
sp_size: int = None,
precision: str = "fp16",
zero_stage: int = 0,
enable_all_optimization: bool = False,
enable_fused_normalization: bool = False,
enable_flash_attention: bool = False,
enable_jit_fused: bool = False,
enable_sequence_parallelism: bool = False,
sequence_parallelism_mode: str = None,
enable_sequence_overlap: bool = False,
parallel_output: bool = True,
num_microbatches: Optional[int] = None,
microbatch_size: Optional[int] = None,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0,
broadcast_buffers: bool = True,
ddp_bucket_cap_mb: int = 25,
find_unused_parameters: bool = False,
check_reduction: bool = False,
gradient_as_bucket_view: bool = False,
static_graph: bool = False,
zero_bucket_size_in_m: int = 12,
cpu_offload: bool = False,
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True,
custom_policy: Policy = None,
pp_style: str = "1f1b",
num_model_chunks: int = 1,
num_layers_per_stage: Optional[List[int]] = None,
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
enable_metadata_cache: bool = True,
make_vocab_size_divisible_by: int = 64,
dp_outside: bool = True,
overlap_p2p: bool = True,
overlap_allgather: bool = False,
force_overlap_comm: bool = False,
) -> None:
assert (
dist.get_world_size() % (tp_size * pp_size) == 0
), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
if enable_sequence_parallelism:
self.sequence_parallelism_mode = (
sequence_parallelism_mode if sequence_parallelism_mode is not None else "all_to_all"
)
assert (
self.sequence_parallelism_mode in SUPPORT_SP_MODE
), f"Sequence parallelism mode {self.sequence_parallelism_mode} is not in the supported list {SUPPORT_SP_MODE}"
if self.sequence_parallelism_mode in ["split_gather", "ring"]:
assert (
tp_size > 1
), f"Sequence parallelism mode {self.sequence_parallelism_mode} must be enabled when using tensor parallelism"
if sp_size != 1:
warnings.warn(
f"The sp_size will be the same as tp_size in sequence parallelism mode {self.sequence_parallelism_mode}, will ignore the given sequence parallelism size."
)
self.sp_size = 1
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
elif self.sequence_parallelism_mode in ["all_to_all"]:
self.sp_size = 1 if sp_size is None else sp_size
self.dp_size = dist.get_world_size() // (self.sp_size * pp_size * tp_size)
else:
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
assert (
sp_size == 1 or sp_size is None
), f"You should not set sp_size when sequence parallelism is not enabled."
self.sp_size = 1
assert self.dp_size % ep_size == 0, f"dp_size should be divisible by ep_size, {self.dp_size=} {ep_size=}"
self.moe_dp_size = self.dp_size // ep_size
self.ep_size = ep_size
self.moe_tp_size = moe_tp_size
self._init_moe_param_comm()
self.use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
self.dp_size == 1
and self.pp_size == 1
and self.enable_sequence_parallelism
and self.sequence_parallelism_mode == "all_to_all"
)
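# Editorial example, not part of this diff: with dp_size=4, pp_size=1 and zero_stage=0 the
# first clause above is true, so the module is later wrapped in plain PyTorch DDP for
# gradient reduction; with zero_stage > 0 the ZeRO optimizer handles the reduction instead.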
self.tp_size = tp_size
self.pp_size = pp_size
self.precision = precision
self.zero_stage = zero_stage
self.cpu_offload = cpu_offload
self.enable_all_optimization = enable_all_optimization
self.enable_fused_normalization = enable_fused_normalization
self.enable_flash_attention = enable_flash_attention
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism
if dp_outside:
self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
self.moe_dp_axis, self.ep_axis = 0, 1
self.moe_pg_mesh = ProcessGroupMesh(
self.moe_dp_size, self.ep_size, self.pp_size, self.tp_size, self.sp_size
)
else:
self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
self.moe_dp_axis, self.ep_axis = 1, 2
self.moe_pg_mesh = ProcessGroupMesh(
self.pp_size, self.moe_dp_size, self.ep_size, self.tp_size, self.sp_size
)
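# Editorial sizing sketch, not part of this diff, assuming 16 ranks with tp_size=2,
# pp_size=2, sp_size=1 and ep_size=2:
#     dp_size     = 16 // (tp_size * pp_size)  # = 4
#     moe_dp_size = dp_size // ep_size          # = 2
# so with dp_outside=True the dense mesh is (dp, pp, tp, sp) = (4, 2, 2, 1) and the MoE
# mesh is (moe_dp, ep, pp, tp, sp) = (2, 2, 2, 2, 1); both cover all 16 ranks.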
if self.use_ddp:
warnings.warn(
f"Will have to check all params are used in pytorch DDP since not all experts are always activated"
self.stage_manager = None
self.schedule = None
self.custom_policy = custom_policy
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
assert (
num_microbatches is not None or microbatch_size is not None
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
assert (
self.zero_stage <= 1
), "To avoid prohibitive gradient synchronization costs, zero stage must be 0 or 1 when using pipeline parallelism"
self.stage_manager = PipelineStageManager(
self.pg_mesh,
pipeline_axis=self.pp_axis,
enable_interleave=pp_style == "interleaved",
num_model_chunks=num_model_chunks,
num_layers_per_stage=num_layers_per_stage,
)
self.ddp_config["find_unused_parameters"] = True
if dist.get_process_group_ranks(self.dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
# TODO: it might make sense to support non-moe with tp on but moe with tp off
raise ValueError(
f"if ddp is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to use HybridParallelPlugin or set zero_stage > 0"
if pp_style == "interleaved":
assert num_model_chunks > 1, "number of model chunks must be > 1 when using interleaved"
self.schedule = InterleavedSchedule(
stage_manager=self.stage_manager,
num_model_chunks=num_model_chunks,
num_microbatch=num_microbatches,
microbatch_size=microbatch_size,
enable_metadata_cache=enable_metadata_cache,
overlap_p2p=overlap_p2p,
)
# set param group in shard config
self.shard_config.ep_group = self.ep_group
self.shard_config.moe_dp_group = self.moe_dp_group
self.shard_config.moe_tp_group = self.moe_tp_group
self.force_overlap_comm = force_overlap_comm
def _init_moe_param_comm(self):
world_size = dist.get_world_size()
if self.enable_sequence_parallelism:
if self.sequence_parallelism_mode == "all_to_all":
# if sequence parallelism is enabled, ep_group reuses sp_group
if self.ep_size != self.sp_size:
raise ValueError(
f"ep_size={self.ep_size} should be equal to sp_size={self.sp_size} or turned off when sequence parallelism is enabled"
)
# since we are reusing sp_group, moe_dp_group will be derived as dp_group
self.moe_dp_size = self.dp_size
self.moe_dp_group = self.dp_group
self.dp_sp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
self.ep_group = self.sp_group
self.moe_tp_group = self.tp_group
else:
raise NotImplementedError(
f"sequence_parallelism_mode={self.sequence_parallelism_mode} is not supported"
elif pp_style == "1f1b":
self.schedule = OneForwardOneBackwardSchedule(
stage_manager=self.stage_manager,
num_microbatches=num_microbatches,
microbatch_size=microbatch_size,
enable_metadata_cache=enable_metadata_cache,
)
else:
raise NotImplementedError()
self.tp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
self.dp_group = self.pg_mesh.get_group_along_axis(self.dp_axis)
self.pp_group = self.pg_mesh.get_group_along_axis(self.pp_axis)
self.moe_dp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_dp_axis)
self.ep_group = self.moe_pg_mesh.get_group_along_axis(self.ep_axis)
if self.enable_sequence_parallelism and self.sequence_parallelism_mode in ["split_gather", "ring"]:
self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
else:
self.moe_dp_size = world_size // (self.pp_size * self.ep_size * self.moe_tp_size)
if self.moe_dp_size * self.pp_size * self.ep_size * self.moe_tp_size != world_size:
raise ValueError(
f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size}"
)
self.moe_dp_group = None
self.ep_group = None
self.moe_tp_group = None
self.dp_sp_group = self.dp_group
# create submesh for ep, moe_dp, moe_tp
ranks_by_pp_stage = self.pg_mesh.get_group_along_axis(
[self.dp_axis, self.tp_axis, self.sp_axis], return_ranks_by_group=True
)
global_rank = self.pg_mesh.rank
pp_rank = self.pg_mesh.coordinate(self.pp_axis)
self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
sequence_parallel_process_group=self.sp_group,
ep_group=self.ep_group,
moe_dp_group=self.moe_dp_group,
pipeline_stage_manager=self.stage_manager,
enable_tensor_parallelism=self.tp_size > 1,
enable_all_optimization=self.enable_all_optimization,
enable_fused_normalization=self.enable_fused_normalization,
enable_flash_attention=self.enable_flash_attention,
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=enable_sequence_parallelism,
sequence_parallelism_mode=sequence_parallelism_mode,
enable_sequence_overlap=enable_sequence_overlap,
parallel_output=parallel_output,
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
gradient_checkpoint_config=gradient_checkpoint_config,
)
self.amp_config = dict(
initial_scale=initial_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
)
# create groups from submesh
for stage_idx, stage_rank in enumerate(ranks_by_pp_stage):
# axis 0 is moe_dp, axis 1 is ep, axis 2 is moe_tp
submesh = np.array(stage_rank).reshape(self.moe_dp_size, self.ep_size, self.moe_tp_size)
self.ddp_config = dict(
broadcast_buffers=broadcast_buffers,
bucket_cap_mb=ddp_bucket_cap_mb,
find_unused_parameters=find_unused_parameters,
check_reduction=check_reduction,
gradient_as_bucket_view=gradient_as_bucket_view,
static_graph=static_graph,
)
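# Editorial note, not part of this diff: these keys mirror the keyword arguments of
# torch.nn.parallel.DistributedDataParallel, which (as a sketch, assuming a sharded
# `module` and a data-parallel group `dp_group`) is roughly applied later as:
#     DistributedDataParallel(module, process_group=dp_group, **self.ddp_config)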
# hardcoded here since we only have 3 axes
# moe_dp_group
for ep_idx in range(self.ep_size):
for moe_tp_idx in range(self.moe_tp_size):
moe_dp_ranks = submesh[:, ep_idx, moe_tp_idx].flatten().tolist()
group = dist.new_group(moe_dp_ranks)
if pp_rank == stage_idx and global_rank in moe_dp_ranks:
assert self.moe_dp_group is None
self.moe_dp_group = group
# ep_group
for moe_dp_idx in range(self.moe_dp_size):
for moe_tp_idx in range(self.moe_tp_size):
ep_ranks = submesh[moe_dp_idx, :, moe_tp_idx].flatten().tolist()
group = dist.new_group(ep_ranks)
if pp_rank == stage_idx and global_rank in ep_ranks:
assert self.ep_group is None
self.ep_group = group
# moe_tp_group
for moe_dp_idx in range(self.moe_dp_size):
for ep_idx in range(self.ep_size):
moe_tp_ranks = submesh[moe_dp_idx, ep_idx, :].flatten().tolist()
group = dist.new_group(moe_tp_ranks)
if pp_rank == stage_idx and global_rank in moe_tp_ranks:
assert self.moe_tp_group is None
self.moe_tp_group = group
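# Editorial sketch, not part of this diff: how the reshape above turns the flat ranks of
# one pipeline stage into groups, assuming moe_dp_size=2, ep_size=2, moe_tp_size=2:
#     submesh = np.arange(8).reshape(2, 2, 2)
#     submesh[:, 0, 0].tolist()  # moe_dp group -> [0, 4]
#     submesh[0, :, 0].tolist()  # ep group     -> [0, 2]
#     submesh[0, 0, :].tolist()  # moe_tp group -> [0, 1]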
self.zero_config = dict(
reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
cpu_offload=cpu_offload,
partition_grad=(self.zero_stage == 2),
forced_dtype=PRECISION_TORCH_TYPE[precision],
overlap_allgather=overlap_allgather,
)
if dist.get_process_group_ranks(self.tp_group) != dist.get_process_group_ranks(self.moe_tp_group):
# NOTE: different TP settings between MoE and non-MoE params are complex to handle
# we simply reuse tp_group as moe_tp_group, this implies that dp_size == moe_dp_size * ep_size
raise NotImplementedError(
f"Only support shared tp group between moe and non moe params, but found non-moe tp {dist.get_process_group_ranks(self.tp_group)}, moe tp {dist.get_process_group_ranks(self.moe_tp_group)}, please make sure tp_size == moe_tp_size"
)
self.max_norm = max_norm
self.force_overlap_comm = force_overlap_comm
def get_checkpoint_io(self) -> MoECheckpointIO:
return MoECheckpointIO(
@@ -249,14 +346,37 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
optimizer = cast_to_distributed(optimizer)
if not isinstance(model, ModelWrapper):
use_ddp = (self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0) or (
self.dp_size == 1
and self.pp_size == 1
and self.enable_sequence_parallelism
and self.sequence_parallelism_mode == "all_to_all"
)
if use_ddp:
warnings.warn(
f"Will have to check all params are used in pytorch DDP since not all experts are always activated"
)
self.ddp_config["find_unused_parameters"] = True
if dist.get_process_group_ranks(self.dp_group) != dist.get_process_group_ranks(self.moe_dp_group):
raise ValueError(
f"if pytorch ddp is used, dp_group and moe_dp_group are expected to be the same since DDP can only reduce grad across a single group, but found dp_group {dist.get_process_group_ranks(self.dp_group)} and moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)}, you might want to use HybridParallelPlugin (i.e. set ep_size = 1) or set zero_stage > 0"
)
# sync gradients across DP * SP ranks
if self.enable_sequence_parallelism and self.sequence_parallelism_mode == "all_to_all":
dp_group = self.pg_mesh.create_group_along_axis([self.dp_axis, self.sp_axis])
else:
dp_group = self.dp_group
model = HybridParallelModule(
module=model,
precision=self.precision,
shard_config=self.shard_config,
dp_group=self.dp_sp_group,
dp_group=dp_group,
tp_group=self.tp_group,
sp_group=self.sp_group,
use_ddp=self.use_ddp,
use_ddp=use_ddp,
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
)
@@ -301,7 +421,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
use_pipeline=self.enable_pipeline_parallelism,
force_overlap_comm=self.force_overlap_comm,
param_info=param_info,
dp_process_group=self.dp_sp_group,
dp_process_group=dp_group,
tp_process_group=self.tp_group,
pp_process_group=self.pp_group,
moe_dp_group=self.moe_dp_group,