mirror of https://github.com/hpcaitech/ColossalAI
[moe] remove force_overlap_comm flag and add warning instead
parent
f7c5485ed6
commit
7bedd03739
|
@ -42,7 +42,6 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
|
|||
optimizer: Optimizer,
|
||||
model: Module,
|
||||
use_pipeline: bool,
|
||||
force_overlap_comm: bool, # force overlap comm
|
||||
dp_process_group: Optional[ProcessGroup], # the dp pg for comm
|
||||
tp_process_group: Optional[ProcessGroup], # if using tp
|
||||
pp_process_group: Optional[ProcessGroup], # if using pp
|
||||
|
@ -65,17 +64,6 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
|
|||
forced_dtype: Optional[torch.dtype] = None,
|
||||
overlap_allgather: bool = False,
|
||||
):
|
||||
WARN_STR = "Note that you need to make sure every expert are routed (i.e.) every expert has backward, otherwise this might lead to program hang or inconsistent result."
|
||||
if not force_overlap_comm and (overlap_communication or partition_grad):
|
||||
raise RuntimeError(
|
||||
WARN_STR
|
||||
+ " If you are not sure about this, set (overlap_communication=False and partition_grad=False) or force_overlap_comm=True"
|
||||
)
|
||||
|
||||
if force_overlap_comm:
|
||||
overlap_communication = True
|
||||
warnings.warn(WARN_STR + " Please make sure of this.")
|
||||
|
||||
pg_param_list = {
|
||||
dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
|
||||
moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
|
||||
|
@ -116,9 +104,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
Modified from colossalai.booster.plugin.hybrid_parallel_plugin.HybridParallelPlugin
|
||||
Extra Args:
|
||||
ep_size (int): The size of expert parallelism
|
||||
force_overlap_comm (bool):
|
||||
For LowLevelZeroOptimizer, it might causes program hang when some experts are routed and overlap_communication is True during training.
|
||||
This flag is used to force overlap_communication=True. Make sure every expert are routed when you use this.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -167,8 +152,15 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
dp_outside: bool = True,
|
||||
overlap_p2p: bool = True,
|
||||
overlap_allgather: bool = False,
|
||||
force_overlap_comm: bool = False,
|
||||
) -> None:
|
||||
if overlap_communication or zero_stage == 2:
|
||||
overlap_communication = False
|
||||
zero_stage = 1
|
||||
warnings.warn(
|
||||
f"overlap_communication and zero_stage are set to False and 1 because "
|
||||
f"ZeRO-2 or comm overlap cause program hang when some experts are not routed. "
|
||||
)
|
||||
|
||||
assert (
|
||||
dist.get_world_size() % (tp_size * pp_size) == 0
|
||||
), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
|
||||
|
@ -326,7 +318,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
)
|
||||
|
||||
self.max_norm = max_norm
|
||||
self.force_overlap_comm = force_overlap_comm
|
||||
|
||||
def get_checkpoint_io(self) -> MoECheckpointIO:
|
||||
return MoECheckpointIO(
|
||||
|
@ -421,7 +412,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
optimizer,
|
||||
model,
|
||||
use_pipeline=self.enable_pipeline_parallelism,
|
||||
force_overlap_comm=self.force_overlap_comm,
|
||||
param_info=param_info,
|
||||
dp_process_group=dp_group,
|
||||
tp_process_group=self.tp_group,
|
||||
|
|
Loading…
Reference in New Issue