From 2f9bce6686d1415a83d5726dc5ff02222c742582 Mon Sep 17 00:00:00 2001 From: botbw Date: Thu, 11 Jul 2024 05:50:20 +0000 Subject: [PATCH] [moe] implement submesh initialization --- .../plugin/moe_hybrid_parallel_plugin.py | 68 +++++++++++++++++-- colossalai/cluster/process_group_mesh.py | 40 +++++++---- .../test_model/test_shard_mixtral.py | 17 ++--- 3 files changed, 98 insertions(+), 27 deletions(-) diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index a02deb80d..f689fe988 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -1,6 +1,7 @@ import warnings from types import MethodType from typing import Callable, Optional, OrderedDict, Tuple +import numpy as np import torch import torch.distributed as dist @@ -64,6 +65,14 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer): overlap_communication = True warnings.warn(WARN_STR + " Please make sure of this.") + self.param_info = param_info + self.stage_manager = model.stage_manager + self.shared_params = model.shared_params + self.dp_pg = dp_process_group + + if use_pipeline: + reinitialize_optimizer(optimizer, model) + pg_param_list = { dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())), moe_dp_group: list(filter(is_moe_tensor, model.parameters())), @@ -116,17 +125,16 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): raise NotImplementedError world_size = dist.get_world_size() - - self.moe_dp_size = world_size // (ep_size * moe_tp_size) + self.moe_dp_size = world_size // (self.pp_size * ep_size * moe_tp_size) self.ep_size = ep_size self.moe_tp_size = moe_tp_size - self.moe_pg_mesh = ProcessGroupMesh(self.moe_dp_size, self.ep_size, self.moe_tp_size) - self.moe_dp_axis, self.ep_axis, self.moe_tp_axis = 0, 1, 2 + if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size != world_size: + raise ValueError( + f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size}" + ) - self.moe_dp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_dp_axis) - self.ep_group = self.moe_pg_mesh.get_group_along_axis(self.ep_axis) - self.moe_tp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_tp_axis) + self._init_moe_param_comm() self.logger.info(f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}", ranks=[0]) @@ -136,6 +144,52 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): self.force_overlap_comm = force_overlap_comm + def _init_moe_param_comm(self): + self.moe_dp_group = None + self.ep_group = None + self.moe_tp_group = None + + # create submesh for ep, moe_dp, moe_tp + ranks_by_pp_stage = self.pg_mesh.get_group_along_axis( + [self.dp_axis, self.tp_axis, self.sp_axis], return_ranks_by_group=True + ) + + global_rank = self.pg_mesh.rank + pp_rank = self.pg_mesh.coordinate(self.pp_axis) + + # create groups from submesh + for stage_idx, stage_rank in enumerate(ranks_by_pp_stage): + # axis 0 is dp, axis 1 is tp, axis 2 is sp + submesh = np.array(stage_rank).reshape(self.moe_dp_size, self.ep_size, self.moe_tp_size) + + # hardcode here since we only have 3 axis + # moe_dp_group + for ep_idx in range(self.ep_size): + for moe_tp_idx in range(self.moe_tp_size): + moe_dp_ranks = submesh[:, ep_idx, moe_tp_idx].flatten().tolist() + group = dist.new_group(moe_dp_ranks) + if pp_rank == stage_idx and global_rank in moe_dp_ranks: + assert self.moe_dp_group is None + self.moe_dp_group = group + # ep_group + for moe_dp_idx in range(self.moe_dp_size): + for moe_tp_idx in range(self.moe_tp_size): + ep_ranks = submesh[moe_dp_idx, :, moe_tp_idx].flatten().tolist() + group = dist.new_group(ep_ranks) + if pp_rank == stage_idx and global_rank in ep_ranks: + assert self.ep_group is None + self.ep_group = group + # moe_tp_group + for moe_dp_idx in range(self.moe_dp_size): + for ep_idx in range(self.ep_size): + moe_tp_ranks = submesh[moe_dp_idx, ep_idx, :].flatten().tolist() + group = dist.new_group(moe_tp_ranks) + if pp_rank == stage_idx and global_rank in moe_tp_ranks: + assert self.moe_tp_group is None + self.moe_tp_group = group + + self.logger.info(f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}") + def get_checkpoint_io(self) -> MoECheckpointIO: return MoECheckpointIO( self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py index c09c7a2cc..ee9e2d71d 100644 --- a/colossalai/cluster/process_group_mesh.py +++ b/colossalai/cluster/process_group_mesh.py @@ -209,13 +209,15 @@ class ProcessGroupMesh: axis: Union[int, List[int]], indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None, backend: Optional[str] = None, - ) -> ProcessGroup: + return_ranks_by_group: bool = False + ) -> Union[ProcessGroup, List[Tuple[int, ...]]]: """Create all process groups along the given axis, and return the one which the current process belongs to. Args: axis (int): Axis along which the process groups are created. indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None. backend (Optional[str], optional): Backend of the process group. Defaults to None. + return_ranks_by_group (bool): Whether to return all ranks by group for creating submesh. Defaults to False. Returns: ProcessGroup: The process group along the given axis which the current process belongs to. @@ -235,25 +237,35 @@ class ProcessGroupMesh: # the choices on the axis are reduced to 1, since it's determined by `indices_at_axis` for ax in axis: reduced_shape[ax] = 1 - target_group = None - # use Cartesian product to generate all combinations of coordinates - for base_coord in itertools.product(*[range(s) for s in reduced_shape]): - coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis) - ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group]) - group = self._get_group(ranks_in_group, backend=backend) - if self._rank in ranks_in_group: - target_group = group - return target_group + if return_ranks_by_group: + ranks_by_group = [] + # use Cartesian product to generate all combinations of coordinates + for base_coord in itertools.product(*[range(s) for s in reduced_shape]): + coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis) + ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group]) + ranks_by_group.append(ranks_in_group) + return ranks_by_group + else: + target_group = None + # use Cartesian product to generate all combinations of coordinates + for base_coord in itertools.product(*[range(s) for s in reduced_shape]): + coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis) + ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group]) + group = self._get_group(ranks_in_group, backend=backend) + if self._rank in ranks_in_group: + target_group = group + return target_group def get_group_along_axis( - self, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None - ) -> ProcessGroup: + self, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None, return_ranks_by_group: bool = False + ) -> Union[ProcessGroup, List[Tuple[int, ...]]]: """Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created. Args: axis (int or list of int): Axes along which the process groups are created. indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None. backend (Optional[str], optional): Backend of the process group. Defaults to None. + return_ranks_by_group (bool): Whether to return all ranks by group for creating submesh. Defaults to False. Returns: ProcessGroup: The process group along the given axis which the current process belongs to. @@ -267,6 +279,10 @@ class ProcessGroupMesh: coords_in_group = ProcessGroupMesh.get_coords_along_axis(self._coord, axis, indices_at_axis) ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group]) + + if return_ranks_by_group: + return self.create_group_along_axis(axis, indices_at_axis, backend=backend, return_ranks_by_group=True) + if ranks_in_group not in self._ranks_to_group: # no need to cache it explicitly, since it will be cached in `create_group_along_axis` return self.create_group_along_axis(axis, indices_at_axis, backend=backend) diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py index 123e590c9..f268d1686 100644 --- a/tests/test_shardformer/test_model/test_shard_mixtral.py +++ b/tests/test_shardformer/test_model/test_shard_mixtral.py @@ -29,10 +29,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin( model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.Adam ) - with torch.autograd.set_detect_anomaly(True): - org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( - org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster - ) + + org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( + org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster + ) stage_manager = booster.plugin.stage_manager tp_group = booster.plugin.tp_group @@ -115,8 +115,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, [ { "tp_size": 1, - "pp_size": 1, - "ep_size": 1, + "pp_size": 2, + "num_microbatches": 2, + "ep_size": 2, "zero_stage": 1, "overlap_communication": False, "precision": "fp32", @@ -125,7 +126,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "tp_size": 1, "pp_size": 2, "num_microbatches": 2, - "ep_size": 1, + "ep_size": 2, "zero_stage": 1, "overlap_communication": False, "precision": "fp32", @@ -134,7 +135,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "tp_size": 2, "pp_size": 2, "num_microbatches": 2, - "ep_size": 1, + "ep_size": 2, "zero_stage": 1, "overlap_communication": False, "precision": "fp32",