
solve hang when parallel mode = pp + dp

Branch: moe_sp
haze188 authored 4 months ago · committed by hxwang
parent · commit a613edd517
3 changed files:
  1. colossalai/booster/plugin/hybrid_parallel_plugin.py (5 lines changed)
  2. colossalai/booster/plugin/moe_hybrid_parallel_plugin.py (34 lines changed)
  3. tests/test_shardformer/test_model/test_shard_mixtral.py (36 lines changed)

colossalai/booster/plugin/hybrid_parallel_plugin.py (5 lines changed)

@@ -27,6 +27,7 @@ from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -1068,7 +1069,9 @@ class HybridParallelPlugin(PipelinePluginBase):
self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)
self.logger.info(f"{type(self).__name__}: {self.pp_size=} {self.dp_size=} {self.tp_size=} {self.sp_size=}", ranks=[0])
self.logger.info(
f"{type(self).__name__}: {self.pp_size=} {self.dp_size=} {self.tp_size=} {self.sp_size=}", ranks=[0]
)
self.stage_manager = None
self.schedule = None
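For context on the reformatted call above: get_dist_logger returns ColossalAI's distributed logger, and its info method accepts a ranks list so the message is only emitted on the selected ranks. A minimal standalone sketch, not part of this commit, assuming a torchrun launch and a ColossalAI version whose launch_from_torch() takes no config dict:

import colossalai
from colossalai.logging import get_dist_logger

colossalai.launch_from_torch()  # older versions require a config dict here

logger = get_dist_logger()
# ranks=[0] restricts the message to global rank 0, mirroring how the plugin
# logs its parallel sizes.
logger.info("pp_size=2 dp_size=2 tp_size=1 sp_size=1", ranks=[0])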

colossalai/booster/plugin/moe_hybrid_parallel_plugin.py (34 lines changed)

@@ -15,6 +15,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
HybridParallelModule,
HybridParallelNaiveOptimizer,
HybridParallelPlugin,
HybridParallelZeroOptimizer,
get_param_info,
reinitialize_optimizer,
)
@@ -22,16 +23,18 @@ from colossalai.checkpoint_io import MoECheckpointIO
from colossalai.cluster.process_group_mesh import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.tensor.moe_tensor.api import is_moe_tensor
from colossalai.zero.low_level import LowLevelZeroOptimizer
class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
def __init__(
self,
optimizer: Optimizer,
model: Module,
use_pipeline: bool,
force_overlap_comm: bool, # force overlap comm
dp_process_group: ProcessGroup, # dp pg for comm
dp_process_group: Optional[ProcessGroup], # the dp pg for comm
tp_process_group: Optional[ProcessGroup], # if using tp
pp_process_group: Optional[ProcessGroup], # if using pp
moe_dp_group: ProcessGroup, # moe dp pg for comm
param_info: OrderedDict,
initial_scale: int = 2**16, # grad scaler config
@@ -50,31 +53,27 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
cpu_offload: bool = False, # cpu offload
forced_dtype: Optional[torch.dtype] = None,
):
WARN_STR = "Note that you need to make sure every expert are routed (i.e.) every expert has backward, otherwise this might lead to program hang or inconsistent result"
if not force_overlap_comm and (overlap_communication or partition_grad):
raise RuntimeError(WARN_STR + " If you are not sure about this, set (overlap_communication=False and partition_grad=False) or force_overlap_comm=True")
raise RuntimeError(
WARN_STR
+ " If you are not sure about this, set (overlap_communication=False and partition_grad=False) or force_overlap_comm=True"
)
if force_overlap_comm:
overlap_communication = True
warnings.warn(WARN_STR + " Please make sure of this.")
self.param_info = param_info
self.stage_manager = model.stage_manager
self.shared_params = model.shared_params
self.dp_pg = dp_process_group
if use_pipeline:
reinitialize_optimizer(optimizer, model)
pg_param_list = {
dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
}
super().__init__(
model=model,
optimizer=optimizer,
pg_to_param_list=pg_param_list,
use_pipeline=use_pipeline,
param_info=param_info,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
@@ -89,7 +88,12 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
overlap_communication=overlap_communication,
partition_grad=partition_grad,
cpu_offload=cpu_offload,
# dp_process_group=dp_process_group,
tp_process_group=tp_process_group,
pp_process_group=pp_process_group,
forced_dtype=forced_dtype,
## moe args
pg_to_param_list=pg_param_list,
)
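The core of the change in the hunks above: MoeHybridParallelZeroOptimizer now subclasses HybridParallelZeroOptimizer and hands it an explicit pg_to_param_list, so dense parameters are synchronized over the regular data-parallel group while expert parameters are synchronized over the MoE data-parallel group, and each group only runs collectives for parameters its ranks actually hold. A minimal sketch of that bucketing, with a stand-in predicate in place of is_moe_tensor and string placeholders in place of real process groups:

from typing import Callable, Dict, List

import torch.nn as nn


def build_pg_to_param_list(
    model: nn.Module,
    is_moe_param: Callable[[nn.Parameter], bool],
    dp_group: object,
    moe_dp_group: object,
) -> Dict[object, List[nn.Parameter]]:
    # Mirrors the pg_param_list construction above: every parameter lands in
    # exactly one process group's bucket.
    params = list(model.parameters())
    return {
        dp_group: [p for p in params if not is_moe_param(p)],
        moe_dp_group: [p for p in params if is_moe_param(p)],
    }


if __name__ == "__main__":
    # Toy model: treat the second linear layer as the "expert" weights.
    model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
    expert_ids = {id(p) for p in model[1].parameters()}
    mapping = build_pg_to_param_list(
        model,
        is_moe_param=lambda p: id(p) in expert_ids,  # stand-in for is_moe_tensor
        dp_group="dp_group",          # placeholder for a real ProcessGroup
        moe_dp_group="moe_dp_group",  # placeholder for a real ProcessGroup
    )
    print({k: len(v) for k, v in mapping.items()})  # {'dp_group': 2, 'moe_dp_group': 2}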
@@ -192,6 +196,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
force_overlap_comm=self.force_overlap_comm,
param_info=param_info,
dp_process_group=self.dp_group,
tp_process_group=self.tp_group,
pp_process_group=self.pp_group,
moe_dp_group=self.moe_dp_group,
verbose=True,
clip_grad_norm=self.max_norm,
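For reference, the plugin is normally consumed through the Booster API; the sketch below wires up one of the configurations exercised by the test file. It is only a sketch: build_moe_model is a hypothetical helper standing in for an actual MoE model (e.g. a small Mixtral), the launch assumes torchrun, and the exact keyword set can vary between ColossalAI versions.

import colossalai
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

colossalai.launch_from_torch()  # older versions require a config dict here

plugin = MoeHybridParallelPlugin(
    tp_size=1,
    pp_size=2,
    ep_size=1,
    num_microbatches=2,
    zero_stage=1,
    overlap_communication=False,  # see the force_overlap_comm / WARN_STR handling above
    precision="fp32",
)
booster = Booster(plugin=plugin)

model = build_moe_model()  # hypothetical helper returning a MoE model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)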

tests/test_shardformer/test_model/test_shard_mixtral.py (36 lines changed)

@@ -117,23 +117,35 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"tp_size": 1,
"pp_size": 1,
"ep_size": 1,
"zero_stage": 2,
"zero_stage": 1,
"overlap_communication": False,
"precision": "fp32",
}, # [dp(2) + pp(2)] + [moe_dp(4)]
# {
# "tp_size": 1,
# "pp_size": 2,
# "num_microbatches": 2,
# "ep_size": 1,
# "zero_stage": 1,
# "precision": "fp32",
# }, # [dp(2) + pp(2)] + [moe_dp(4)]
}, # [dp(4)] + [moe_dp(4)]
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 2,
"ep_size": 1,
"zero_stage": 1,
"overlap_communication": False,
"precision": "fp32",
}, # [dp(2) + pp(2)] + [moe_pp(2)]
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 2,
"ep_size": 1,
"zero_stage": 1,
"overlap_communication": False,
"precision": "fp32",
}, # [pp(2) + tp(2)] + [pp(2), replicate(2)] pass
# {
# "tp_size": 1,
# "pp_size": 2,
# "num_microbatches": 2,
# "ep_size": 4,
# "ep_size": 2,
# "zero_stage": 1,
# "overlap_communication": False,
# "precision": "fp32",
# }, # [dp(2) + pp(2)] + [ep(4))]
# {
@@ -141,12 +153,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
# "pp_size": 1,
# "ep_size": 2,
# "zero_stage": 0,
# "overlap_communication": False,
# "precision": "fp32",
# }, # [dp(4)] + [ep(2) + moe_tp(2)]
# {
# "tp_size": 1,
# "pp_size": 1,
# "ep_size": 4,
# "overlap_communication": False,
# "zero_stage": 0,
# "precision": "fp32"
# }, # full dp for non-moe and full ep for moe
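As a reading aid for the bracketed layout comments above (a sketch, not part of the test; it assumes the 4-GPU world size implied by comments such as [dp(4)], and sequence parallel size 1), the dense data-parallel degree is simply what remains after tensor and pipeline parallelism:

WORLD_SIZE = 4  # assumed from the [dp(4)] / [dp(2) + pp(2)] comments


def dp_size(tp_size: int, pp_size: int) -> int:
    # The hybrid plugin derives the data-parallel degree from the ranks left
    # over after tensor and pipeline parallelism claim theirs.
    return WORLD_SIZE // (tp_size * pp_size)


print(dp_size(tp_size=1, pp_size=1))  # 4 -> [dp(4)]
print(dp_size(tp_size=1, pp_size=2))  # 2 -> [dp(2) + pp(2)]
print(dp_size(tp_size=2, pp_size=2))  # 1 -> [pp(2) + tp(2)]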
