[moe] implement submesh initialization

colossalchat
botbw 5 months ago committed by Hongxin Liu
parent 5ed5e8cfba
commit e28e05345b

@ -1,6 +1,7 @@
import warnings
from types import MethodType
from typing import Callable, Optional, OrderedDict, Tuple
import numpy as np
import torch
import torch.distributed as dist
@ -64,6 +65,14 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
overlap_communication = True
warnings.warn(WARN_STR + " Please make sure of this.")
self.param_info = param_info
self.stage_manager = model.stage_manager
self.shared_params = model.shared_params
self.dp_pg = dp_process_group
if use_pipeline:
reinitialize_optimizer(optimizer, model)
pg_param_list = {
dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
@ -116,17 +125,16 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
raise NotImplementedError
world_size = dist.get_world_size()
self.moe_dp_size = world_size // (ep_size * moe_tp_size)
self.moe_dp_size = world_size // (self.pp_size * ep_size * moe_tp_size)
self.ep_size = ep_size
self.moe_tp_size = moe_tp_size
self.moe_pg_mesh = ProcessGroupMesh(self.moe_dp_size, self.ep_size, self.moe_tp_size)
self.moe_dp_axis, self.ep_axis, self.moe_tp_axis = 0, 1, 2
if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size != world_size:
raise ValueError(
f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size}"
)
self.moe_dp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_dp_axis)
self.ep_group = self.moe_pg_mesh.get_group_along_axis(self.ep_axis)
self.moe_tp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_tp_axis)
self._init_moe_param_comm()
self.logger.info(f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}", ranks=[0])
@ -136,6 +144,52 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.force_overlap_comm = force_overlap_comm
def _init_moe_param_comm(self):
self.moe_dp_group = None
self.ep_group = None
self.moe_tp_group = None
# create submesh for ep, moe_dp, moe_tp
ranks_by_pp_stage = self.pg_mesh.get_group_along_axis(
[self.dp_axis, self.tp_axis, self.sp_axis], return_ranks_by_group=True
)
global_rank = self.pg_mesh.rank
pp_rank = self.pg_mesh.coordinate(self.pp_axis)
# create groups from submesh
for stage_idx, stage_rank in enumerate(ranks_by_pp_stage):
# axis 0 is dp, axis 1 is tp, axis 2 is sp
submesh = np.array(stage_rank).reshape(self.moe_dp_size, self.ep_size, self.moe_tp_size)
# hardcode here since we only have 3 axis
# moe_dp_group
for ep_idx in range(self.ep_size):
for moe_tp_idx in range(self.moe_tp_size):
moe_dp_ranks = submesh[:, ep_idx, moe_tp_idx].flatten().tolist()
group = dist.new_group(moe_dp_ranks)
if pp_rank == stage_idx and global_rank in moe_dp_ranks:
assert self.moe_dp_group is None
self.moe_dp_group = group
# ep_group
for moe_dp_idx in range(self.moe_dp_size):
for moe_tp_idx in range(self.moe_tp_size):
ep_ranks = submesh[moe_dp_idx, :, moe_tp_idx].flatten().tolist()
group = dist.new_group(ep_ranks)
if pp_rank == stage_idx and global_rank in ep_ranks:
assert self.ep_group is None
self.ep_group = group
# moe_tp_group
for moe_dp_idx in range(self.moe_dp_size):
for ep_idx in range(self.ep_size):
moe_tp_ranks = submesh[moe_dp_idx, ep_idx, :].flatten().tolist()
group = dist.new_group(moe_tp_ranks)
if pp_rank == stage_idx and global_rank in moe_tp_ranks:
assert self.moe_tp_group is None
self.moe_tp_group = group
self.logger.info(f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}")
def get_checkpoint_io(self) -> MoECheckpointIO:
return MoECheckpointIO(
self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage

@ -209,13 +209,15 @@ class ProcessGroupMesh:
axis: Union[int, List[int]],
indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None,
backend: Optional[str] = None,
) -> ProcessGroup:
return_ranks_by_group: bool = False
) -> Union[ProcessGroup, List[Tuple[int, ...]]]:
"""Create all process groups along the given axis, and return the one which the current process belongs to.
Args:
axis (int): Axis along which the process groups are created.
indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
backend (Optional[str], optional): Backend of the process group. Defaults to None.
return_ranks_by_group (bool): Whether to return all ranks by group for creating submesh. Defaults to False.
Returns:
ProcessGroup: The process group along the given axis which the current process belongs to.
@ -235,6 +237,15 @@ class ProcessGroupMesh:
# the choices on the axis are reduced to 1, since it's determined by `indices_at_axis`
for ax in axis:
reduced_shape[ax] = 1
if return_ranks_by_group:
ranks_by_group = []
# use Cartesian product to generate all combinations of coordinates
for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
ranks_by_group.append(ranks_in_group)
return ranks_by_group
else:
target_group = None
# use Cartesian product to generate all combinations of coordinates
for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
@ -246,14 +257,15 @@ class ProcessGroupMesh:
return target_group
def get_group_along_axis(
self, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
) -> ProcessGroup:
self, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None, return_ranks_by_group: bool = False
) -> Union[ProcessGroup, List[Tuple[int, ...]]]:
"""Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created.
Args:
axis (int or list of int): Axes along which the process groups are created.
indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
backend (Optional[str], optional): Backend of the process group. Defaults to None.
return_ranks_by_group (bool): Whether to return all ranks by group for creating submesh. Defaults to False.
Returns:
ProcessGroup: The process group along the given axis which the current process belongs to.
@ -267,6 +279,10 @@ class ProcessGroupMesh:
coords_in_group = ProcessGroupMesh.get_coords_along_axis(self._coord, axis, indices_at_axis)
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
if return_ranks_by_group:
return self.create_group_along_axis(axis, indices_at_axis, backend=backend, return_ranks_by_group=True)
if ranks_in_group not in self._ranks_to_group:
# no need to cache it explicitly, since it will be cached in `create_group_along_axis`
return self.create_group_along_axis(axis, indices_at_axis, backend=backend)

@ -29,7 +29,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.Adam
)
with torch.autograd.set_detect_anomaly(True):
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
)
@ -115,8 +115,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
[
{
"tp_size": 1,
"pp_size": 1,
"ep_size": 1,
"pp_size": 2,
"num_microbatches": 2,
"ep_size": 2,
"zero_stage": 1,
"overlap_communication": False,
"precision": "fp32",
@ -125,7 +126,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 2,
"ep_size": 1,
"ep_size": 2,
"zero_stage": 1,
"overlap_communication": False,
"precision": "fp32",
@ -134,7 +135,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 2,
"ep_size": 1,
"ep_size": 2,
"zero_stage": 1,
"overlap_communication": False,
"precision": "fp32",

Loading…
Cancel
Save