mirror of https://github.com/hpcaitech/ColossalAI
[moe] remove force_overlap_comm flag and add warning instead
parent f7c5485ed6
commit 7bedd03739
@@ -42,7 +42,6 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
         optimizer: Optimizer,
         model: Module,
         use_pipeline: bool,
-        force_overlap_comm: bool, # force overlap comm
         dp_process_group: Optional[ProcessGroup], # the dp pg for comm
         tp_process_group: Optional[ProcessGroup], # if using tp
         pp_process_group: Optional[ProcessGroup], # if using pp
@@ -65,17 +64,6 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
         forced_dtype: Optional[torch.dtype] = None,
         overlap_allgather: bool = False,
     ):
-        WARN_STR = "Note that you need to make sure every expert are routed (i.e.) every expert has backward, otherwise this might lead to program hang or inconsistent result."
-        if not force_overlap_comm and (overlap_communication or partition_grad):
-            raise RuntimeError(
-                WARN_STR
-                + " If you are not sure about this, set (overlap_communication=False and partition_grad=False) or force_overlap_comm=True"
-            )
-
-        if force_overlap_comm:
-            overlap_communication = True
-            warnings.warn(WARN_STR + " Please make sure of this.")
-
         pg_param_list = {
             dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
             moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
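
For reference, the context lines kept in this hunk show how the optimizer builds pg_param_list: each process group maps to the subset of model parameters it owns, selected with is_moe_tensor. Below is a minimal standalone sketch of that filter split; the predicate here is a stand-in for ColossalAI's is_moe_tensor and its tagging scheme is purely illustrative.

import torch.nn as nn

def is_moe_tensor(p: nn.Parameter) -> bool:
    # Stand-in predicate: pretend an MoE layer tagged its parameters with a
    # "moe_info" attribute. ColossalAI's real helper works differently.
    return getattr(p, "moe_info", None) is not None

model = nn.Linear(4, 4)  # toy module standing in for the wrapped MoE model
dense_params = list(filter(lambda p: not is_moe_tensor(p), model.parameters()))
moe_params = list(filter(is_moe_tensor, model.parameters()))
print(len(dense_params), len(moe_params))  # 2 0: the toy module has no MoE parameters
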
@@ -116,9 +104,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
     Modified from colossalai.booster.plugin.hybrid_parallel_plugin.HybridParallelPlugin
     Extra Args:
         ep_size (int): The size of expert parallelism
-        force_overlap_comm (bool):
-            For LowLevelZeroOptimizer, it might causes program hang when some experts are routed and overlap_communication is True during training.
-            This flag is used to force overlap_communication=True. Make sure every expert are routed when you use this.
     """

     def __init__(
@@ -167,8 +152,15 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         dp_outside: bool = True,
         overlap_p2p: bool = True,
         overlap_allgather: bool = False,
-        force_overlap_comm: bool = False,
     ) -> None:
+        if overlap_communication or zero_stage == 2:
+            overlap_communication = False
+            zero_stage = 1
+            warnings.warn(
+                f"overlap_communication and zero_stage are set to False and 1 because "
+                f"ZeRO-2 or comm overlap cause program hang when some experts are not routed. "
+            )
+
         assert (
             dist.get_world_size() % (tp_size * pp_size) == 0
         ), f"World size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
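
The added guard downgrades the risky settings and warns instead of raising. A minimal standalone sketch of that downgrade-and-warn pattern, mirroring the added lines outside the plugin (the helper name is hypothetical):

import warnings

def resolve_moe_settings(overlap_communication: bool, zero_stage: int) -> tuple[bool, int]:
    # Mirror of the guard added in __init__: ZeRO-2 or communication overlap can
    # hang when some experts receive no tokens (and hence no backward), so fall
    # back to safe values and warn the user instead of raising.
    if overlap_communication or zero_stage == 2:
        overlap_communication = False
        zero_stage = 1
        warnings.warn(
            "overlap_communication and zero_stage are set to False and 1 because "
            "ZeRO-2 or comm overlap cause program hang when some experts are not routed."
        )
    return overlap_communication, zero_stage

# Requesting ZeRO-2 with overlap gets downgraded to ZeRO-1 without overlap.
print(resolve_moe_settings(overlap_communication=True, zero_stage=2))  # (False, 1)
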
@@ -326,7 +318,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         )

         self.max_norm = max_norm
-        self.force_overlap_comm = force_overlap_comm

     def get_checkpoint_io(self) -> MoECheckpointIO:
         return MoECheckpointIO(
@@ -421,7 +412,6 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 optimizer,
                 model,
                 use_pipeline=self.enable_pipeline_parallelism,
-                force_overlap_comm=self.force_overlap_comm,
                 param_info=param_info,
                 dp_process_group=dp_group,
                 tp_process_group=self.tp_group,
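
With the flag gone, callers simply drop force_overlap_comm from the plugin constructor; asking for ZeRO-2 or communication overlap now produces the warning and the downgraded settings instead of a RuntimeError. A hedged usage sketch follows; the import path and the launch call are assumptions inferred from the class names in this diff and may differ by ColossalAI version, and the script must run under a multi-process launcher such as torchrun.

# Usage sketch after this commit: no force_overlap_comm argument any more.
# Assumptions: ColossalAI is installed, the script is launched with torchrun,
# and the plugin lives at this import path (inferred, may differ by version).
import colossalai
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

colossalai.launch_from_torch()  # set up torch.distributed from torchrun's env vars

plugin = MoeHybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    ep_size=2,                   # expert-parallel size, see "Extra Args" above
    zero_stage=2,                # downgraded to 1 by the new guard, with a warning
    overlap_communication=True,  # downgraded to False by the new guard
)
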