[moe] remove force_overlap_comm flag and add warning instead

colossalchat
hxwang 4 months ago committed by Hongxin Liu
parent f7c5485ed6
commit 7bedd03739

@ -42,7 +42,6 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
optimizer: Optimizer, optimizer: Optimizer,
model: Module, model: Module,
use_pipeline: bool, use_pipeline: bool,
force_overlap_comm: bool, # force overlap comm
dp_process_group: Optional[ProcessGroup], # the dp pg for comm dp_process_group: Optional[ProcessGroup], # the dp pg for comm
tp_process_group: Optional[ProcessGroup], # if using tp tp_process_group: Optional[ProcessGroup], # if using tp
pp_process_group: Optional[ProcessGroup], # if using pp pp_process_group: Optional[ProcessGroup], # if using pp
@ -65,17 +64,6 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
forced_dtype: Optional[torch.dtype] = None, forced_dtype: Optional[torch.dtype] = None,
overlap_allgather: bool = False, overlap_allgather: bool = False,
): ):
WARN_STR = "Note that you need to make sure every expert are routed (i.e.) every expert has backward, otherwise this might lead to program hang or inconsistent result."
if not force_overlap_comm and (overlap_communication or partition_grad):
raise RuntimeError(
WARN_STR
+ " If you are not sure about this, set (overlap_communication=False and partition_grad=False) or force_overlap_comm=True"
)
if force_overlap_comm:
overlap_communication = True
warnings.warn(WARN_STR + " Please make sure of this.")
pg_param_list = { pg_param_list = {
dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())), dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
moe_dp_group: list(filter(is_moe_tensor, model.parameters())), moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
@ -116,9 +104,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
Modified from colossalai.booster.plugin.hybrid_parallel_plugin.HybridParallelPlugin Modified from colossalai.booster.plugin.hybrid_parallel_plugin.HybridParallelPlugin
Extra Args: Extra Args:
ep_size (int): The size of expert parallelism ep_size (int): The size of expert parallelism
force_overlap_comm (bool):
For LowLevelZeroOptimizer, it might causes program hang when some experts are routed and overlap_communication is True during training.
This flag is used to force overlap_communication=True. Make sure every expert are routed when you use this.
""" """
def __init__( def __init__(
@ -167,8 +152,15 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
dp_outside: bool = True, dp_outside: bool = True,
overlap_p2p: bool = True, overlap_p2p: bool = True,
overlap_allgather: bool = False, overlap_allgather: bool = False,
force_overlap_comm: bool = False,
) -> None: ) -> None:
if overlap_communication or zero_stage == 2:
overlap_communication = False
zero_stage = 1
warnings.warn(
f"overlap_communication and zero_stage are set to False and 1 because "
f"ZeRO-2 or comm overlap cause program hang when some experts are not routed. "
)
assert ( assert (
dist.get_world_size() % (tp_size * pp_size) == 0 dist.get_world_size() % (tp_size * pp_size) == 0
), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}" ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
@ -326,7 +318,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
) )
self.max_norm = max_norm self.max_norm = max_norm
self.force_overlap_comm = force_overlap_comm
def get_checkpoint_io(self) -> MoECheckpointIO: def get_checkpoint_io(self) -> MoECheckpointIO:
return MoECheckpointIO( return MoECheckpointIO(
@ -421,7 +412,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
optimizer, optimizer,
model, model,
use_pipeline=self.enable_pipeline_parallelism, use_pipeline=self.enable_pipeline_parallelism,
force_overlap_comm=self.force_overlap_comm,
param_info=param_info, param_info=param_info,
dp_process_group=dp_group, dp_process_group=dp_group,
tp_process_group=self.tp_group, tp_process_group=self.tp_group,

Loading…
Cancel
Save