mirror of https://github.com/hpcaitech/ColossalAI

solve hang when parallel mode = pp + dp

parent fe24789eb1
commit 5ed5e8cfba

@@ -27,6 +27,7 @@ from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
+from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -1068,8 +1069,10 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.pp_axis, self.dp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
         self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size, self.sp_size)

-        self.logger.info(f"{type(self).__name__}: {self.pp_size=} {self.dp_size=} {self.tp_size=} {self.sp_size=}", ranks=[0])
+        self.logger.info(
+            f"{type(self).__name__}: {self.pp_size=} {self.dp_size=} {self.tp_size=} {self.sp_size=}", ranks=[0]
+        )

         self.stage_manager = None
         self.schedule = None
         self.custom_policy = custom_policy
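
For orientation, a minimal standalone sketch of the axis layout logged above. Only the (pp, dp, tp, sp) ordering and the logged f-string come from the diff; the dp_size derivation and the example world size are assumptions about how HybridParallelPlugin sizes the mesh.

# Sketch: the (pp, dp, tp, sp) axis order used for ProcessGroupMesh above.
# Assumption: dp_size is whatever remains of the world size once pp/tp/sp are fixed.
world_size = 4
pp_size, tp_size, sp_size = 2, 1, 1
dp_size = world_size // (pp_size * tp_size * sp_size)  # 2, i.e. the dp(2) + pp(2) mode from the commit title
pp_axis, dp_axis, tp_axis, sp_axis = 0, 1, 2, 3
mesh_shape = (pp_size, dp_size, tp_size, sp_size)
print(f"{pp_size=} {dp_size=} {tp_size=} {sp_size=}")
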
@@ -15,6 +15,7 @@ from colossalai.booster.plugin.hybrid_parallel_plugin import (
     HybridParallelModule,
     HybridParallelNaiveOptimizer,
     HybridParallelPlugin,
+    HybridParallelZeroOptimizer,
     get_param_info,
     reinitialize_optimizer,
 )
@@ -22,16 +23,18 @@ from colossalai.checkpoint_io import MoECheckpointIO
 from colossalai.cluster.process_group_mesh import ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.tensor.moe_tensor.api import is_moe_tensor
 from colossalai.zero.low_level import LowLevelZeroOptimizer


-class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
+class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
     def __init__(
         self,
         optimizer: Optimizer,
         model: Module,
         use_pipeline: bool,
         force_overlap_comm: bool,  # force overlap comm
-        dp_process_group: ProcessGroup,  # dp pg for comm
+        dp_process_group: Optional[ProcessGroup],  # the dp pg for comm
         tp_process_group: Optional[ProcessGroup],  # if using tp
         pp_process_group: Optional[ProcessGroup],  # if using pp
         moe_dp_group: ProcessGroup,  # moe dp pg for comm
         param_info: OrderedDict,
         initial_scale: int = 2**16,  # grad scaler config
@@ -49,32 +52,28 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         partition_grad: bool = False,  # stage 2 flag
         cpu_offload: bool = False,  # cpu offload
         forced_dtype: Optional[torch.dtype] = None,
-    ):
+    ):
         WARN_STR = "Note that you need to make sure every expert are routed (i.e.) every expert has backward, otherwise this might lead to program hang or inconsistent result"
         if not force_overlap_comm and (overlap_communication or partition_grad):
-            raise RuntimeError(WARN_STR + " If you are not sure about this, set (overlap_communication=False and partition_grad=False) or force_overlap_comm=True")
+            raise RuntimeError(
+                WARN_STR
+                + " If you are not sure about this, set (overlap_communication=False and partition_grad=False) or force_overlap_comm=True"
+            )

         if force_overlap_comm:
             overlap_communication = True
             warnings.warn(WARN_STR + " Please make sure of this.")

         self.param_info = param_info
         self.stage_manager = model.stage_manager
         self.shared_params = model.shared_params
         self.dp_pg = dp_process_group

         if use_pipeline:
             reinitialize_optimizer(optimizer, model)

         pg_param_list = {
             dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
             moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
         }

         super().__init__(
             model=model,
             optimizer=optimizer,
             pg_to_param_list=pg_param_list,
             use_pipeline=use_pipeline,
             param_info=param_info,
             initial_scale=initial_scale,
             min_scale=min_scale,
             growth_factor=growth_factor,
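
The guard above refuses overlap_communication and ZeRO-2 gradient partitioning for MoE unless force_overlap_comm is set, because an expert that is never routed has no backward pass and the pending reduction can hang or give inconsistent results, likely the hang named in the commit title. A standalone sketch of that decision logic follows; the function name and its return value are illustrative, not the library API.

def resolve_overlap(force_overlap_comm: bool, overlap_communication: bool, partition_grad: bool) -> bool:
    # Mirrors the check above: reject the risky combination unless the caller opts in explicitly.
    if not force_overlap_comm and (overlap_communication or partition_grad):
        raise RuntimeError(
            "Every expert must be routed (i.e. have a backward pass); "
            "set overlap_communication=False and partition_grad=False, or pass force_overlap_comm=True."
        )
    # force_overlap_comm re-enables overlap, accepting the risk.
    return True if force_overlap_comm else overlap_communication
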
@@ -89,7 +88,12 @@ class MoeHybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             overlap_communication=overlap_communication,
             partition_grad=partition_grad,
             cpu_offload=cpu_offload,
+            # dp_process_group=dp_process_group,
+            tp_process_group=tp_process_group,
+            pp_process_group=pp_process_group,
             forced_dtype=forced_dtype,
+            ## moe args
+            pg_to_param_list=pg_param_list,
         )
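
pg_param_list, forwarded to the parent as pg_to_param_list above, assigns every parameter to the process group that should reduce it: regular parameters to the plain data-parallel group and expert parameters to the MoE data-parallel group. A standalone sketch of that split; the is_moe predicate stands in for colossalai's is_moe_tensor and the string keys are placeholders for the actual process-group handles.

from typing import Callable, Dict, List

import torch.nn as nn


def split_params_for_zero(model: nn.Module, is_moe: Callable) -> Dict[str, List[nn.Parameter]]:
    # Non-expert parameters are reduced over the regular data-parallel group,
    # expert parameters over the (usually smaller) MoE data-parallel group.
    return {
        "dp_process_group": [p for p in model.parameters() if not is_moe(p)],
        "moe_dp_group": [p for p in model.parameters() if is_moe(p)],
    }


# Toy example: a plain Linear has no expert weights, so the MoE bucket stays empty.
buckets = split_params_for_zero(nn.Linear(4, 4), is_moe=lambda p: False)
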
@@ -180,7 +184,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                         optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
                     )
             else:
-                if not(self.dp_size > 1 or self.moe_dp_size > 1):
+                if not (self.dp_size > 1 or self.moe_dp_size > 1):
                     warnings.warn(
                         "Use Zero Optimizer when data parallel size is 1 may introduce unnecessary overhead. "
                         "If you do not intend to use cpu_offload, please consider set zero_stage=0."
@@ -192,6 +196,8 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                     force_overlap_comm=self.force_overlap_comm,
                     param_info=param_info,
                     dp_process_group=self.dp_group,
+                    tp_process_group=self.tp_group,
+                    pp_process_group=self.pp_group,
                     moe_dp_group=self.moe_dp_group,
                     verbose=True,
                     clip_grad_norm=self.max_norm,
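
Both hunks above sit in the ZeRO branch of MoeHybridParallelPlugin: first warn when a ZeRO stage is requested although neither data-parallel dimension is larger than 1, then build MoeHybridParallelZeroOptimizer with every relevant process group, now including the tp and pp groups. A tiny standalone sketch of the warning condition with illustrative sizes; in the plugin these values come from self.dp_size, self.moe_dp_size and the configured zero_stage.

import warnings

# Illustrative sizes, not read from a real distributed launch.
dp_size, moe_dp_size, zero_stage = 1, 1, 1

if zero_stage > 0 and not (dp_size > 1 or moe_dp_size > 1):
    warnings.warn(
        "Using a ZeRO optimizer with data parallel size 1 may only add overhead; "
        "consider zero_stage=0 unless cpu_offload is needed."
    )
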
@@ -117,23 +117,35 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 1,
             "pp_size": 1,
             "ep_size": 1,
-            "zero_stage": 2,
+            "zero_stage": 1,
             "overlap_communication": False,
             "precision": "fp32",
-        },  # [dp(2) + pp(2)] + [moe_dp(4)]
+        },  # [dp(4)] + [moe_dp(4)]
         {
             "tp_size": 1,
             "pp_size": 2,
             "num_microbatches": 2,
             "ep_size": 1,
             "zero_stage": 1,
             "overlap_communication": False,
             "precision": "fp32",
         },  # [dp(2) + pp(2)] + [moe_pp(2)]
         {
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 2,
             "ep_size": 1,
             "zero_stage": 1,
             "overlap_communication": False,
             "precision": "fp32",
         },  # [pp(2) + tp(2)] + [pp(2), replicate(2)] pass
         # {
         # "tp_size": 1,
         # "pp_size": 2,
         # "num_microbatches": 2,
         # "ep_size": 1,
         # "zero_stage": 1,
         # "precision": "fp32",
         # }, # [dp(2) + pp(2)] + [moe_dp(4)]
         # {
         # "tp_size": 1,
         # "pp_size": 2,
         # "num_microbatches": 2,
-        # "ep_size": 4,
+        # "ep_size": 2,
         # "zero_stage": 1,
         # "overlap_communication": False,
         # "precision": "fp32",
         # }, # [dp(2) + pp(2)] + [ep(4))]
         # {
@@ -141,13 +153,15 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         # "pp_size": 1,
         # "ep_size": 2,
         # "zero_stage": 0,
+        # "overlap_communication": False,
         # "precision": "fp32",
         # }, # [dp(4)] + [ep(2) + moe_tp(2)]
         # {
-        # "tp_size": 1,
-        # "pp_size": 1,
-        # "ep_size": 4,
-        # "zero_stage": 0,
+        # "tp_size": 1,
+        # "pp_size": 1,
+        # "ep_size": 4,
+        # "overlap_communication": False,
+        # "zero_stage": 0,
         # "precision": "fp32"
         # }, # full dp for non-moe and full ep for moe
     ],
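
The dicts above appear to be the keyword arguments the test feeds to MoeHybridParallelPlugin. A hedged usage sketch for the dp(2) + pp(2) case named in the commit title; the import path and the Booster wiring follow the usual ColossalAI pattern, require a distributed launch, and are therefore left commented.

config = {
    "tp_size": 1,
    "pp_size": 2,
    "num_microbatches": 2,
    "ep_size": 1,
    "zero_stage": 1,
    "overlap_communication": False,
    "precision": "fp32",
}

# Under torchrun / colossalai.launch_from_torch (setup assumed, not shown in the diff):
# from colossalai.booster import Booster
# from colossalai.booster.plugin import MoeHybridParallelPlugin
# plugin = MoeHybridParallelPlugin(**config)
# booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)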