mirror of https://github.com/hpcaitech/ColossalAI
[moe] implement submesh initialization
parent
5ed5e8cfba
commit
e28e05345b
|
@ -1,6 +1,7 @@
|
|||
import warnings
|
||||
from types import MethodType
|
||||
from typing import Callable, Optional, OrderedDict, Tuple
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -64,6 +65,14 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
|
|||
overlap_communication = True
|
||||
warnings.warn(WARN_STR + " Please make sure of this.")
|
||||
|
||||
self.param_info = param_info
|
||||
self.stage_manager = model.stage_manager
|
||||
self.shared_params = model.shared_params
|
||||
self.dp_pg = dp_process_group
|
||||
|
||||
if use_pipeline:
|
||||
reinitialize_optimizer(optimizer, model)
|
||||
|
||||
pg_param_list = {
|
||||
dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
|
||||
moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
|
||||
|
@ -116,17 +125,16 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
raise NotImplementedError
|
||||
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
self.moe_dp_size = world_size // (ep_size * moe_tp_size)
|
||||
self.moe_dp_size = world_size // (self.pp_size * ep_size * moe_tp_size)
|
||||
self.ep_size = ep_size
|
||||
self.moe_tp_size = moe_tp_size
|
||||
|
||||
self.moe_pg_mesh = ProcessGroupMesh(self.moe_dp_size, self.ep_size, self.moe_tp_size)
|
||||
self.moe_dp_axis, self.ep_axis, self.moe_tp_axis = 0, 1, 2
|
||||
if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size != world_size:
|
||||
raise ValueError(
|
||||
f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size}"
|
||||
)
|
||||
|
||||
self.moe_dp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_dp_axis)
|
||||
self.ep_group = self.moe_pg_mesh.get_group_along_axis(self.ep_axis)
|
||||
self.moe_tp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_tp_axis)
|
||||
self._init_moe_param_comm()
|
||||
|
||||
self.logger.info(f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}", ranks=[0])
|
||||
|
||||
|
@ -136,6 +144,52 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
|
||||
self.force_overlap_comm = force_overlap_comm
|
||||
|
||||
def _init_moe_param_comm(self):
|
||||
self.moe_dp_group = None
|
||||
self.ep_group = None
|
||||
self.moe_tp_group = None
|
||||
|
||||
# create submesh for ep, moe_dp, moe_tp
|
||||
ranks_by_pp_stage = self.pg_mesh.get_group_along_axis(
|
||||
[self.dp_axis, self.tp_axis, self.sp_axis], return_ranks_by_group=True
|
||||
)
|
||||
|
||||
global_rank = self.pg_mesh.rank
|
||||
pp_rank = self.pg_mesh.coordinate(self.pp_axis)
|
||||
|
||||
# create groups from submesh
|
||||
for stage_idx, stage_rank in enumerate(ranks_by_pp_stage):
|
||||
# axis 0 is dp, axis 1 is tp, axis 2 is sp
|
||||
submesh = np.array(stage_rank).reshape(self.moe_dp_size, self.ep_size, self.moe_tp_size)
|
||||
|
||||
# hardcode here since we only have 3 axis
|
||||
# moe_dp_group
|
||||
for ep_idx in range(self.ep_size):
|
||||
for moe_tp_idx in range(self.moe_tp_size):
|
||||
moe_dp_ranks = submesh[:, ep_idx, moe_tp_idx].flatten().tolist()
|
||||
group = dist.new_group(moe_dp_ranks)
|
||||
if pp_rank == stage_idx and global_rank in moe_dp_ranks:
|
||||
assert self.moe_dp_group is None
|
||||
self.moe_dp_group = group
|
||||
# ep_group
|
||||
for moe_dp_idx in range(self.moe_dp_size):
|
||||
for moe_tp_idx in range(self.moe_tp_size):
|
||||
ep_ranks = submesh[moe_dp_idx, :, moe_tp_idx].flatten().tolist()
|
||||
group = dist.new_group(ep_ranks)
|
||||
if pp_rank == stage_idx and global_rank in ep_ranks:
|
||||
assert self.ep_group is None
|
||||
self.ep_group = group
|
||||
# moe_tp_group
|
||||
for moe_dp_idx in range(self.moe_dp_size):
|
||||
for ep_idx in range(self.ep_size):
|
||||
moe_tp_ranks = submesh[moe_dp_idx, ep_idx, :].flatten().tolist()
|
||||
group = dist.new_group(moe_tp_ranks)
|
||||
if pp_rank == stage_idx and global_rank in moe_tp_ranks:
|
||||
assert self.moe_tp_group is None
|
||||
self.moe_tp_group = group
|
||||
|
||||
self.logger.info(f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}")
|
||||
|
||||
def get_checkpoint_io(self) -> MoECheckpointIO:
|
||||
return MoECheckpointIO(
|
||||
self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
|
||||
|
|
|
@ -209,13 +209,15 @@ class ProcessGroupMesh:
|
|||
axis: Union[int, List[int]],
|
||||
indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None,
|
||||
backend: Optional[str] = None,
|
||||
) -> ProcessGroup:
|
||||
return_ranks_by_group: bool = False
|
||||
) -> Union[ProcessGroup, List[Tuple[int, ...]]]:
|
||||
"""Create all process groups along the given axis, and return the one which the current process belongs to.
|
||||
|
||||
Args:
|
||||
axis (int): Axis along which the process groups are created.
|
||||
indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
|
||||
backend (Optional[str], optional): Backend of the process group. Defaults to None.
|
||||
return_ranks_by_group (bool): Whether to return all ranks by group for creating submesh. Defaults to False.
|
||||
|
||||
Returns:
|
||||
ProcessGroup: The process group along the given axis which the current process belongs to.
|
||||
|
@ -235,25 +237,35 @@ class ProcessGroupMesh:
|
|||
# the choices on the axis are reduced to 1, since it's determined by `indices_at_axis`
|
||||
for ax in axis:
|
||||
reduced_shape[ax] = 1
|
||||
target_group = None
|
||||
# use Cartesian product to generate all combinations of coordinates
|
||||
for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
|
||||
coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
|
||||
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
|
||||
group = self._get_group(ranks_in_group, backend=backend)
|
||||
if self._rank in ranks_in_group:
|
||||
target_group = group
|
||||
return target_group
|
||||
if return_ranks_by_group:
|
||||
ranks_by_group = []
|
||||
# use Cartesian product to generate all combinations of coordinates
|
||||
for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
|
||||
coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
|
||||
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
|
||||
ranks_by_group.append(ranks_in_group)
|
||||
return ranks_by_group
|
||||
else:
|
||||
target_group = None
|
||||
# use Cartesian product to generate all combinations of coordinates
|
||||
for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
|
||||
coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
|
||||
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
|
||||
group = self._get_group(ranks_in_group, backend=backend)
|
||||
if self._rank in ranks_in_group:
|
||||
target_group = group
|
||||
return target_group
|
||||
|
||||
def get_group_along_axis(
|
||||
self, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
|
||||
) -> ProcessGroup:
|
||||
self, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None, return_ranks_by_group: bool = False
|
||||
) -> Union[ProcessGroup, List[Tuple[int, ...]]]:
|
||||
"""Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created.
|
||||
|
||||
Args:
|
||||
axis (int or list of int): Axes along which the process groups are created.
|
||||
indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
|
||||
backend (Optional[str], optional): Backend of the process group. Defaults to None.
|
||||
return_ranks_by_group (bool): Whether to return all ranks by group for creating submesh. Defaults to False.
|
||||
|
||||
Returns:
|
||||
ProcessGroup: The process group along the given axis which the current process belongs to.
|
||||
|
@ -267,6 +279,10 @@ class ProcessGroupMesh:
|
|||
|
||||
coords_in_group = ProcessGroupMesh.get_coords_along_axis(self._coord, axis, indices_at_axis)
|
||||
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
|
||||
|
||||
if return_ranks_by_group:
|
||||
return self.create_group_along_axis(axis, indices_at_axis, backend=backend, return_ranks_by_group=True)
|
||||
|
||||
if ranks_in_group not in self._ranks_to_group:
|
||||
# no need to cache it explicitly, since it will be cached in `create_group_along_axis`
|
||||
return self.create_group_along_axis(axis, indices_at_axis, backend=backend)
|
||||
|
|
|
@ -29,10 +29,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
|
||||
model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.Adam
|
||||
)
|
||||
with torch.autograd.set_detect_anomaly(True):
|
||||
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
|
||||
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
|
||||
)
|
||||
|
||||
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
|
||||
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
|
||||
)
|
||||
|
||||
stage_manager = booster.plugin.stage_manager
|
||||
tp_group = booster.plugin.tp_group
|
||||
|
@ -115,8 +115,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
[
|
||||
{
|
||||
"tp_size": 1,
|
||||
"pp_size": 1,
|
||||
"ep_size": 1,
|
||||
"pp_size": 2,
|
||||
"num_microbatches": 2,
|
||||
"ep_size": 2,
|
||||
"zero_stage": 1,
|
||||
"overlap_communication": False,
|
||||
"precision": "fp32",
|
||||
|
@ -125,7 +126,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
"tp_size": 1,
|
||||
"pp_size": 2,
|
||||
"num_microbatches": 2,
|
||||
"ep_size": 1,
|
||||
"ep_size": 2,
|
||||
"zero_stage": 1,
|
||||
"overlap_communication": False,
|
||||
"precision": "fp32",
|
||||
|
@ -134,7 +135,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
"tp_size": 2,
|
||||
"pp_size": 2,
|
||||
"num_microbatches": 2,
|
||||
"ep_size": 1,
|
||||
"ep_size": 2,
|
||||
"zero_stage": 1,
|
||||
"overlap_communication": False,
|
||||
"precision": "fp32",
|
||||
|
|
Loading…
Reference in New Issue