[moe] implement submesh initialization

Branch: moe_sp
Authored by botbw 4 months ago, committed by hxwang
Commit 2f9bce6686
3 changed files:

  1. colossalai/booster/plugin/moe_hybrid_parallel_plugin.py (68 lines changed)
  2. colossalai/cluster/process_group_mesh.py (40 lines changed)
  3. tests/test_shardformer/test_model/test_shard_mixtral.py (17 lines changed)

colossalai/booster/plugin/moe_hybrid_parallel_plugin.py (68 lines changed)

@@ -1,6 +1,7 @@
 import warnings
 from types import MethodType
 from typing import Callable, Optional, OrderedDict, Tuple

+import numpy as np
 import torch
 import torch.distributed as dist
@@ -64,6 +65,14 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
             overlap_communication = True
             warnings.warn(WARN_STR + " Please make sure of this.")

+        self.param_info = param_info
+        self.stage_manager = model.stage_manager
+        self.shared_params = model.shared_params
+        self.dp_pg = dp_process_group
+
+        if use_pipeline:
+            reinitialize_optimizer(optimizer, model)
+
         pg_param_list = {
             dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
             moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
@@ -116,17 +125,16 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             raise NotImplementedError

         world_size = dist.get_world_size()
-        self.moe_dp_size = world_size // (ep_size * moe_tp_size)
+        self.moe_dp_size = world_size // (self.pp_size * ep_size * moe_tp_size)
         self.ep_size = ep_size
         self.moe_tp_size = moe_tp_size
-        self.moe_pg_mesh = ProcessGroupMesh(self.moe_dp_size, self.ep_size, self.moe_tp_size)
-        self.moe_dp_axis, self.ep_axis, self.moe_tp_axis = 0, 1, 2
-        self.moe_dp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_dp_axis)
-        self.ep_group = self.moe_pg_mesh.get_group_along_axis(self.ep_axis)
-        self.moe_tp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_tp_axis)
+        if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size != world_size:
+            raise ValueError(
+                f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size}"
+            )
+        self._init_moe_param_comm()

         self.logger.info(f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}", ranks=[0])
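Note: a quick back-of-the-envelope check of the new sizing logic. The 8-rank layout below is an assumed example for illustration, not part of the commit:

```python
# Hypothetical launch: 8 ranks, pp_size=2, ep_size=2, moe_tp_size=1 (all values assumed).
world_size = 8
pp_size, ep_size, moe_tp_size = 2, 2, 1

# The MoE mesh now lives inside each pipeline stage, so pp_size joins the divisor.
moe_dp_size = world_size // (pp_size * ep_size * moe_tp_size)

# The new guard rejects layouts that do not tile the world exactly.
assert pp_size * moe_dp_size * ep_size * moe_tp_size == world_size
print(moe_dp_size)  # 2 -> two MoE data-parallel replicas inside each pipeline stage
```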
@@ -136,6 +144,52 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         self.force_overlap_comm = force_overlap_comm

+    def _init_moe_param_comm(self):
+        self.moe_dp_group = None
+        self.ep_group = None
+        self.moe_tp_group = None
+
+        # create submesh for ep, moe_dp, moe_tp
+        ranks_by_pp_stage = self.pg_mesh.get_group_along_axis(
+            [self.dp_axis, self.tp_axis, self.sp_axis], return_ranks_by_group=True
+        )
+
+        global_rank = self.pg_mesh.rank
+        pp_rank = self.pg_mesh.coordinate(self.pp_axis)
+
+        # create groups from submesh
+        for stage_idx, stage_rank in enumerate(ranks_by_pp_stage):
+            # axis 0 is dp, axis 1 is tp, axis 2 is sp
+            submesh = np.array(stage_rank).reshape(self.moe_dp_size, self.ep_size, self.moe_tp_size)
+
+            # hardcode here since we only have 3 axis
+            # moe_dp_group
+            for ep_idx in range(self.ep_size):
+                for moe_tp_idx in range(self.moe_tp_size):
+                    moe_dp_ranks = submesh[:, ep_idx, moe_tp_idx].flatten().tolist()
+                    group = dist.new_group(moe_dp_ranks)
+                    if pp_rank == stage_idx and global_rank in moe_dp_ranks:
+                        assert self.moe_dp_group is None
+                        self.moe_dp_group = group
+            # ep_group
+            for moe_dp_idx in range(self.moe_dp_size):
+                for moe_tp_idx in range(self.moe_tp_size):
+                    ep_ranks = submesh[moe_dp_idx, :, moe_tp_idx].flatten().tolist()
+                    group = dist.new_group(ep_ranks)
+                    if pp_rank == stage_idx and global_rank in ep_ranks:
+                        assert self.ep_group is None
+                        self.ep_group = group
+            # moe_tp_group
+            for moe_dp_idx in range(self.moe_dp_size):
+                for ep_idx in range(self.ep_size):
+                    moe_tp_ranks = submesh[moe_dp_idx, ep_idx, :].flatten().tolist()
+                    group = dist.new_group(moe_tp_ranks)
+                    if pp_rank == stage_idx and global_rank in moe_tp_ranks:
+                        assert self.moe_tp_group is None
+                        self.moe_tp_group = group
+
+        self.logger.info(f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}")
+
     def get_checkpoint_io(self) -> MoECheckpointIO:
         return MoECheckpointIO(
             self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
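To make the grouping above concrete, here is a minimal sketch of the same submesh arithmetic with plain NumPy, for an assumed 8-rank run with pp_size=2, moe_dp_size=2, ep_size=2, moe_tp_size=1. The per-stage rank lists are a stand-in for what `pg_mesh.get_group_along_axis(..., return_ranks_by_group=True)` would return, and the sketch prints rank lists instead of calling `dist.new_group`:

```python
import numpy as np

# Assumed toy layout: 8 ranks, 2 pipeline stages, each stage owning 4 ranks (stand-in values).
pp_size, moe_dp_size, ep_size, moe_tp_size = 2, 2, 2, 1
ranks_by_pp_stage = [[0, 1, 2, 3], [4, 5, 6, 7]]

for stage_idx, stage_ranks in enumerate(ranks_by_pp_stage):
    # Same reshape as _init_moe_param_comm: (moe_dp, ep, moe_tp) inside one pipeline stage.
    submesh = np.array(stage_ranks).reshape(moe_dp_size, ep_size, moe_tp_size)
    moe_dp_groups = [submesh[:, e, t].tolist() for e in range(ep_size) for t in range(moe_tp_size)]
    ep_groups = [submesh[d, :, t].tolist() for d in range(moe_dp_size) for t in range(moe_tp_size)]
    moe_tp_groups = [submesh[d, e, :].tolist() for d in range(moe_dp_size) for e in range(ep_size)]
    print(f"stage {stage_idx}: moe_dp={moe_dp_groups} ep={ep_groups} moe_tp={moe_tp_groups}")

# stage 0: moe_dp=[[0, 2], [1, 3]] ep=[[0, 1], [2, 3]] moe_tp=[[0], [1], [2], [3]]
# stage 1: moe_dp=[[4, 6], [5, 7]] ep=[[4, 5], [6, 7]] moe_tp=[[4], [5], [6], [7]]
```

Since `dist.new_group` is a collective that every rank must enter, the real method loops over all stages and all index combinations on every rank and only keeps the group that contains the current rank, which is what the `pp_rank == stage_idx and global_rank in ...` checks enforce.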

colossalai/cluster/process_group_mesh.py (40 lines changed)

@@ -209,13 +209,15 @@ class ProcessGroupMesh:
         axis: Union[int, List[int]],
         indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None,
         backend: Optional[str] = None,
-    ) -> ProcessGroup:
+        return_ranks_by_group: bool = False,
+    ) -> Union[ProcessGroup, List[Tuple[int, ...]]]:
         """Create all process groups along the given axis, and return the one which the current process belongs to.

         Args:
             axis (int): Axis along which the process groups are created.
             indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
             backend (Optional[str], optional): Backend of the process group. Defaults to None.
+            return_ranks_by_group (bool): Whether to return all ranks by group for creating submesh. Defaults to False.

         Returns:
             ProcessGroup: The process group along the given axis which the current process belongs to.
@@ -235,25 +237,35 @@ class ProcessGroupMesh:
         # the choices on the axis are reduced to 1, since it's determined by `indices_at_axis`
         for ax in axis:
             reduced_shape[ax] = 1
-        target_group = None
-        # use Cartesian product to generate all combinations of coordinates
-        for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
-            coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
-            ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
-            group = self._get_group(ranks_in_group, backend=backend)
-            if self._rank in ranks_in_group:
-                target_group = group
-        return target_group
+        if return_ranks_by_group:
+            ranks_by_group = []
+            # use Cartesian product to generate all combinations of coordinates
+            for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
+                coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
+                ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
+                ranks_by_group.append(ranks_in_group)
+            return ranks_by_group
+        else:
+            target_group = None
+            # use Cartesian product to generate all combinations of coordinates
+            for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
+                coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
+                ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
+                group = self._get_group(ranks_in_group, backend=backend)
+                if self._rank in ranks_in_group:
+                    target_group = group
+            return target_group

     def get_group_along_axis(
-        self, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
-    ) -> ProcessGroup:
+        self, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None, return_ranks_by_group: bool = False
+    ) -> Union[ProcessGroup, List[Tuple[int, ...]]]:
         """Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created.

         Args:
             axis (int or list of int): Axes along which the process groups are created.
             indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
             backend (Optional[str], optional): Backend of the process group. Defaults to None.
+            return_ranks_by_group (bool): Whether to return all ranks by group for creating submesh. Defaults to False.

         Returns:
             ProcessGroup: The process group along the given axis which the current process belongs to.
@@ -267,6 +279,10 @@ class ProcessGroupMesh:
         coords_in_group = ProcessGroupMesh.get_coords_along_axis(self._coord, axis, indices_at_axis)
         ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])

+        if return_ranks_by_group:
+            return self.create_group_along_axis(axis, indices_at_axis, backend=backend, return_ranks_by_group=True)
+
         if ranks_in_group not in self._ranks_to_group:
             # no need to cache it explicitly, since it will be cached in `create_group_along_axis`
             return self.create_group_along_axis(axis, indices_at_axis, backend=backend)
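A small self-contained sketch of what the `return_ranks_by_group=True` path enumerates, assuming a hypothetical 2 x 2 x 2 mesh grouped along axes [1, 2] and row-major rank numbering. It mirrors the Cartesian-product-plus-ravel logic without needing `torch.distributed` to be initialized:

```python
import itertools

import numpy as np

shape = (2, 2, 2)  # hypothetical mesh shape, e.g. (dp, tp, sp)
axis = [1, 2]      # collapse tp and sp into one group per dp index
mesh = np.arange(np.prod(shape)).reshape(shape)  # rank = row-major ravel of its coordinate

# Reduce the grouped axes to a single choice, as create_group_along_axis does.
reduced_shape = [1 if ax in axis else s for ax, s in enumerate(shape)]

ranks_by_group = []
for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
    coords = [tuple(base_coord)]
    for ax in axis:  # expand the collapsed axes back to every index they can take
        coords = [c[:ax] + (i,) + c[ax + 1:] for c in coords for i in range(shape[ax])]
    ranks_by_group.append(tuple(int(mesh[c]) for c in coords))

print(ranks_by_group)  # [(0, 1, 2, 3), (4, 5, 6, 7)]
```

The plugin then reshapes each returned tuple into its (moe_dp, ep, moe_tp) submesh; no process groups are created on this path, which is the point of the new flag.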

tests/test_shardformer/test_model/test_shard_mixtral.py (17 lines changed)

@@ -29,10 +29,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
         model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.Adam
     )
-    org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
-        org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
-    )
+    with torch.autograd.set_detect_anomaly(True):
+        org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
+            org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
+        )

     stage_manager = booster.plugin.stage_manager
     tp_group = booster.plugin.tp_group
@@ -115,8 +115,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     [
         {
             "tp_size": 1,
-            "pp_size": 1,
-            "ep_size": 1,
+            "pp_size": 2,
+            "num_microbatches": 2,
+            "ep_size": 2,
             "zero_stage": 1,
             "overlap_communication": False,
             "precision": "fp32",
@@ -125,7 +126,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 1,
             "pp_size": 2,
             "num_microbatches": 2,
-            "ep_size": 1,
+            "ep_size": 2,
             "zero_stage": 1,
             "overlap_communication": False,
             "precision": "fp32",
@@ -134,7 +135,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 2,
-            "ep_size": 1,
+            "ep_size": 2,
             "zero_stage": 1,
             "overlap_communication": False,
             "precision": "fp32",

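Assuming the usual 4-process spawn for these shardformer tests and the plugin's default `moe_tp_size=1` (both assumptions, not stated in the diff), the updated configs resolve to the following derived sizes:

```python
world_size = 4  # assumed test world size
for tp, pp, ep, moe_tp in [(1, 2, 2, 1), (1, 2, 2, 1), (2, 2, 2, 1)]:
    moe_dp = world_size // (pp * ep * moe_tp)
    assert pp * moe_dp * ep * moe_tp == world_size  # passes the plugin's new divisibility check
    print(f"tp={tp} pp={pp} ep={ep} moe_tp={moe_tp} -> moe_dp={moe_dp}")
# each config yields moe_dp=1, i.e. pure expert parallelism inside every pipeline stage
```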