solve hang when parallel mode = pp + dp

moe_sp
haze188 2024-07-11 02:12:44 +00:00 committed by hxwang
parent 0210bead8c
commit a613edd517
3 changed files with 57 additions and 34 deletions

View File

@@ -27,6 +27,7 @@ from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
+from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -1068,8 +1069,10 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
         self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
-        self.logger.info(f"{type(self).__name__}: {self.pp_size=} {self.dp_size=} {self.tp_size=} {self.sp_size=}", ranks=[0])
+        self.logger.info(
+            f"{type(self).__name__}: {self.pp_size=} {self.dp_size=} {self.tp_size=} {self.sp_size=}", ranks=[0]
+        )
         self.stage_manager = None
         self.schedule = None
         self.custom_policy = custom_policy
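Editor note: the hunk above builds a 4D ProcessGroupMesh whose axes 0-3 are (pp, dp, tp, sp). The snippet below is an illustrative, self-contained sketch of that coordinate layout only; rank_to_mesh_coord is a hypothetical helper (ColossalAI's ProcessGroupMesh handles this internally), and a row-major ordering with the sp axis fastest is assumed.

# Hypothetical helper, not part of ColossalAI: map a flat rank to (pp, dp, tp, sp)
# coordinates in a row-major 4D mesh, mirroring the axis order used above.
def rank_to_mesh_coord(rank: int, pp: int, dp: int, tp: int, sp: int) -> tuple[int, int, int, int]:
    assert 0 <= rank < pp * dp * tp * sp, "rank must lie inside the mesh"
    sp_idx, rank = rank % sp, rank // sp
    tp_idx, rank = rank % tp, rank // tp
    dp_idx, pp_idx = rank % dp, rank // dp
    return pp_idx, dp_idx, tp_idx, sp_idx

# Example: 8 ranks arranged as pp=2, dp=2, tp=2, sp=1.
print(rank_to_mesh_coord(5, pp=2, dp=2, tp=2, sp=1))  # -> (1, 0, 1, 0)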

View File

@@ -15,6 +15,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
     HybridParallelModule,
     HybridParallelNaiveOptimizer,
     HybridParallelPlugin,
+    HybridParallelZeroOptimizer,
     get_param_info,
     reinitialize_optimizer,
 )
@@ -22,16 +23,18 @@ from colossalai.checkpoint_io import MoECheckpointIO
 from colossalai.cluster.process_group_mesh import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.tensor.moe_tensor.api import is_moe_tensor
-from colossalai.zero.low_level import LowLevelZeroOptimizer


-class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
+class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
     def __init__(
         self,
         optimizer: Optimizer,
         model: Module,
         use_pipeline: bool,
         force_overlap_comm: bool,  # force overlap comm
-        dp_process_group: ProcessGroup,  # dp pg for comm
+        dp_process_group: Optional[ProcessGroup],  # the dp pg for comm
+        tp_process_group: Optional[ProcessGroup],  # if using tp
+        pp_process_group: Optional[ProcessGroup],  # if using pp
         moe_dp_group: ProcessGroup,  # moe dp pg for comm
         param_info: OrderedDict,
         initial_scale: int = 2**16,  # grad scaler config
@@ -49,32 +52,28 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         partition_grad: bool = False,  # stage 2 flag
         cpu_offload: bool = False,  # cpu offload
         forced_dtype: Optional[torch.dtype] = None,
     ):
         WARN_STR = "Note that you need to make sure every expert are routed (i.e.) every expert has backward, otherwise this might lead to program hang or inconsistent result"
         if not force_overlap_comm and (overlap_communication or partition_grad):
-            raise RuntimeError(WARN_STR + " If you are not sure about this, set (overlap_communication=False and partition_grad=False) or force_overlap_comm=True")
+            raise RuntimeError(
+                WARN_STR
+                + " If you are not sure about this, set (overlap_communication=False and partition_grad=False) or force_overlap_comm=True"
+            )
         if force_overlap_comm:
             overlap_communication = True
             warnings.warn(WARN_STR + " Please make sure of this.")
-        self.param_info = param_info
-        self.stage_manager = model.stage_manager
-        self.shared_params = model.shared_params
-        self.dp_pg = dp_process_group
-        if use_pipeline:
-            reinitialize_optimizer(optimizer, model)
         pg_param_list = {
             dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
             moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
         }
         super().__init__(
+            model=model,
             optimizer=optimizer,
-            pg_to_param_list=pg_param_list,
+            use_pipeline=use_pipeline,
+            param_info=param_info,
             initial_scale=initial_scale,
             min_scale=min_scale,
             growth_factor=growth_factor,
@@ -89,7 +88,12 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             overlap_communication=overlap_communication,
             partition_grad=partition_grad,
             cpu_offload=cpu_offload,
+            # dp_process_group=dp_process_group,
+            tp_process_group=tp_process_group,
+            pp_process_group=pp_process_group,
             forced_dtype=forced_dtype,
+            ## moe args
+            pg_to_param_list=pg_param_list,
         )
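Editor note: the hunks above rebase the MoE ZeRO optimizer on HybridParallelZeroOptimizer and route each parameter to the right reduction group: non-MoE parameters go to the plain dp group, MoE parameters to the moe_dp group, and the tp/pp groups are now forwarded so the base class can account for tensor and pipeline parallelism (e.g. when clipping gradients). Below is a minimal, illustrative sketch of that bucketing only; looks_like_moe_param is a stand-in for colossalai.tensor.moe_tensor.api.is_moe_tensor, and plain strings stand in for the real ProcessGroup handles.

# Illustrative only: partition a model's parameters into per-group buckets the
# way pg_param_list is built above.
import torch.nn as nn

def looks_like_moe_param(param: nn.Parameter) -> bool:
    # Hypothetical marker; real MoE tensors carry metadata attached by ColossalAI.
    return getattr(param, "is_moe", False)

def build_pg_param_list(model: nn.Module) -> dict[str, list[nn.Parameter]]:
    return {
        "dp_group": [p for p in model.parameters() if not looks_like_moe_param(p)],
        "moe_dp_group": [p for p in model.parameters() if looks_like_moe_param(p)],
    }

if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
    model[1].weight.is_moe = True  # pretend the second layer is an expert
    buckets = build_pg_param_list(model)
    print({k: len(v) for k, v in buckets.items()})  # {'dp_group': 3, 'moe_dp_group': 1}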
@@ -180,7 +184,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                     optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
                 )
             else:
-                if not(self.dp_size > 1 or self.moe_dp_size > 1):
+                if not (self.dp_size > 1 or self.moe_dp_size > 1):
                     warnings.warn(
                         "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
                         "If you do not intend to use cpu_offload, please consider set zero_stage=0."
@@ -192,6 +196,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                     force_overlap_comm=self.force_overlap_comm,
                     param_info=param_info,
                     dp_process_group=self.dp_group,
+                    tp_process_group=self.tp_group,
+                    pp_process_group=self.pp_group,
                     moe_dp_group=self.moe_dp_group,
                     verbose=True,
                     clip_grad_norm=self.max_norm,
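Editor note: for context on force_overlap_comm passed here, the optimizer constructor (earlier hunk) refuses overlap_communication / partition_grad unless the caller explicitly opts in, because an expert that receives no tokens gets no backward pass and the MoE gradient collectives can hang or give inconsistent results. The function below is an illustrative restatement of that guard, not library code.

# Illustrative re-statement of the constructor guard; not part of ColossalAI.
import warnings

_WARN = (
    "Make sure every expert is routed (i.e. every expert receives a backward pass); "
    "otherwise overlap_communication/partition_grad may hang or give inconsistent results."
)

def resolve_overlap_comm(force_overlap_comm: bool, overlap_communication: bool, partition_grad: bool) -> bool:
    """Return the effective overlap_communication flag, mirroring MoeHybridParallelZeroOptimizer."""
    if not force_overlap_comm and (overlap_communication or partition_grad):
        raise RuntimeError(
            _WARN + " Set overlap_communication=False and partition_grad=False, or pass force_overlap_comm=True."
        )
    if force_overlap_comm:
        warnings.warn(_WARN)
        return True
    return overlap_communication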

View File

@@ -117,23 +117,35 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 1,
             "pp_size": 1,
             "ep_size": 1,
-            "zero_stage": 2,
+            "zero_stage": 1,
+            "overlap_communication": False,
             "precision": "fp32",
-        },  # [dp(2) + pp(2)] + [moe_dp(4)]
+        },  # [dp(4)] + [moe_dp(4)]
+        {
+            "tp_size": 1,
+            "pp_size": 2,
+            "num_microbatches": 2,
+            "ep_size": 1,
+            "zero_stage": 1,
+            "overlap_communication": False,
+            "precision": "fp32",
+        },  # [dp(2) + pp(2)] + [moe_pp(2)]
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "num_microbatches": 2,
+            "ep_size": 1,
+            "zero_stage": 1,
+            "overlap_communication": False,
+            "precision": "fp32",
+        },  # [pp(2) + tp(2)] + [pp(2), replicate(2)] pass
         # {
        #     "tp_size": 1,
        #     "pp_size": 2,
        #     "num_microbatches": 2,
-       #     "ep_size": 1,
-       #     "zero_stage": 1,
-       #     "precision": "fp32",
-       # },  # [dp(2) + pp(2)] + [moe_dp(4)]
-       # {
-       #     "tp_size": 1,
-       #     "pp_size": 2,
-       #     "num_microbatches": 2,
-       #     "ep_size": 4,
+       #     "ep_size": 2,
        #     "zero_stage": 1,
+       #     "overlap_communication": False,
        #     "precision": "fp32",
        # },  # [dp(2) + pp(2)] + [ep(4))]
        # {
@@ -141,13 +153,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
        #     "pp_size": 1,
        #     "ep_size": 2,
        #     "zero_stage": 0,
+       #     "overlap_communication": False,
        #     "precision": "fp32",
        # },  # [dp(4)] + [ep(2) + moe_tp(2)]
        # {
        #     "tp_size": 1,
        #     "pp_size": 1,
        #     "ep_size": 4,
-       #     "zero_stage": 0,
+       #     "overlap_communication": False,
+       #     "zero_stage": 0,
        #     "precision": "fp32"
        # },  # full dp for non-moe and full ep for moe
     ],
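Editor note: each dict above parameterizes one parallel layout for what appears to be a 4-process test run; the bracketed comments describe how non-MoE and MoE parameters end up sharded. The sketch below is a hypothetical wrapper, not the actual test harness, and assumes the plugin accepts these keys as keyword arguments and that the distributed environment has already been launched (e.g. colossalai.launch_from_torch() under torchrun).

# Hypothetical helper: turn one of the config dicts above into a Booster backed
# by the MoE hybrid-parallel plugin. Requires an initialized distributed setup.
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

def booster_from_config(test_config: dict) -> Booster:
    plugin = MoeHybridParallelPlugin(**test_config)  # tp_size / pp_size / ep_size / zero_stage / ...
    return Booster(plugin=plugin)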