mirror of https://github.com/hpcaitech/ColossalAI

[moe] implement submesh initialization
parent 5ed5e8cfba
commit e28e05345b
@@ -1,6 +1,7 @@
 import warnings
 from types import MethodType
 from typing import Callable, Optional, OrderedDict, Tuple
 
+import numpy as np
 import torch
 import torch.distributed as dist
@@ -64,6 +65,14 @@ class MoeHybridParallelZeroOptimizer(HybridParallelZeroOptimizer):
             overlap_communication = True
             warnings.warn(WARN_STR + " Please make sure of this.")
 
+        self.param_info = param_info
+        self.stage_manager = model.stage_manager
+        self.shared_params = model.shared_params
+        self.dp_pg = dp_process_group
+
+        if use_pipeline:
+            reinitialize_optimizer(optimizer, model)
+
         pg_param_list = {
             dp_process_group: list(filter(lambda p: not is_moe_tensor(p), model.parameters())),
             moe_dp_group: list(filter(is_moe_tensor, model.parameters())),
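Note: pg_param_list keys the ZeRO parameter buckets by process group. A minimal sketch of the partition it produces (all names are from the diff; `model` stands for any module whose expert weights are tagged so that is_moe_tensor returns True for them):

    non_moe_params = [p for p in model.parameters() if not is_moe_tensor(p)]
    moe_params = [p for p in model.parameters() if is_moe_tensor(p)]
    # plain parameters are synchronized over the regular DP group, expert
    # parameters over the (smaller) MoE DP group built by the plugin below
    pg_param_list = {dp_process_group: non_moe_params, moe_dp_group: moe_params}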
@@ -116,17 +125,16 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             raise NotImplementedError
 
         world_size = dist.get_world_size()
-        self.moe_dp_size = world_size // (ep_size * moe_tp_size)
+        self.moe_dp_size = world_size // (self.pp_size * ep_size * moe_tp_size)
         self.ep_size = ep_size
         self.moe_tp_size = moe_tp_size
 
-        self.moe_pg_mesh = ProcessGroupMesh(self.moe_dp_size, self.ep_size, self.moe_tp_size)
-        self.moe_dp_axis, self.ep_axis, self.moe_tp_axis = 0, 1, 2
+        if self.pp_size * self.moe_dp_size * self.ep_size * self.moe_tp_size != world_size:
+            raise ValueError(
+                f"world_size={world_size} is not divisible by pp_size={self.pp_size} * moe_dp_size={self.moe_dp_size} * ep_size={self.ep_size} * moe_tp_size={self.moe_tp_size}"
+            )
 
-        self.moe_dp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_dp_axis)
-        self.ep_group = self.moe_pg_mesh.get_group_along_axis(self.ep_axis)
-        self.moe_tp_group = self.moe_pg_mesh.get_group_along_axis(self.moe_tp_axis)
+        self._init_moe_param_comm()
 
         self.logger.info(f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=}", ranks=[0])
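A worked sizing check for the new formula (the numbers are hypothetical, not from the diff):

    world_size, pp_size, ep_size, moe_tp_size = 8, 2, 2, 1
    moe_dp_size = world_size // (pp_size * ep_size * moe_tp_size)  # 8 // 4 == 2
    assert pp_size * moe_dp_size * ep_size * moe_tp_size == world_size  # 2*2*2*1 == 8

Under the old formula the pipeline dimension was not divided out, so moe_dp_size would have been 8 // (2 * 1) == 4 and the product 2 * 4 * 2 * 1 == 16 != 8; the new guard turns such mismatches into an immediate ValueError instead of a silent misconfiguration.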
@@ -136,6 +144,52 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
 
         self.force_overlap_comm = force_overlap_comm
 
+    def _init_moe_param_comm(self):
+        self.moe_dp_group = None
+        self.ep_group = None
+        self.moe_tp_group = None
+
+        # create submesh for ep, moe_dp, moe_tp
+        ranks_by_pp_stage = self.pg_mesh.get_group_along_axis(
+            [self.dp_axis, self.tp_axis, self.sp_axis], return_ranks_by_group=True
+        )
+
+        global_rank = self.pg_mesh.rank
+        pp_rank = self.pg_mesh.coordinate(self.pp_axis)
+
+        # create groups from submesh
+        for stage_idx, stage_rank in enumerate(ranks_by_pp_stage):
+            # axis 0 is dp, axis 1 is tp, axis 2 is sp
+            submesh = np.array(stage_rank).reshape(self.moe_dp_size, self.ep_size, self.moe_tp_size)
+
+            # hardcode here since we only have 3 axis
+            # moe_dp_group
+            for ep_idx in range(self.ep_size):
+                for moe_tp_idx in range(self.moe_tp_size):
+                    moe_dp_ranks = submesh[:, ep_idx, moe_tp_idx].flatten().tolist()
+                    group = dist.new_group(moe_dp_ranks)
+                    if pp_rank == stage_idx and global_rank in moe_dp_ranks:
+                        assert self.moe_dp_group is None
+                        self.moe_dp_group = group
+            # ep_group
+            for moe_dp_idx in range(self.moe_dp_size):
+                for moe_tp_idx in range(self.moe_tp_size):
+                    ep_ranks = submesh[moe_dp_idx, :, moe_tp_idx].flatten().tolist()
+                    group = dist.new_group(ep_ranks)
+                    if pp_rank == stage_idx and global_rank in ep_ranks:
+                        assert self.ep_group is None
+                        self.ep_group = group
+            # moe_tp_group
+            for moe_dp_idx in range(self.moe_dp_size):
+                for ep_idx in range(self.ep_size):
+                    moe_tp_ranks = submesh[moe_dp_idx, ep_idx, :].flatten().tolist()
+                    group = dist.new_group(moe_tp_ranks)
+                    if pp_rank == stage_idx and global_rank in moe_tp_ranks:
+                        assert self.moe_tp_group is None
+                        self.moe_tp_group = group
+
+        self.logger.info(f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)}")
+
     def get_checkpoint_io(self) -> MoECheckpointIO:
         return MoECheckpointIO(
             self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage
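To see what _init_moe_param_comm builds, here is a standalone sketch of the submesh slicing for a single pipeline stage, with hypothetical sizes moe_dp_size=2, ep_size=2, moe_tp_size=1 and stage ranks [0, 1, 2, 3] (plain numpy, no torch.distributed needed to follow the indexing):

    import numpy as np

    moe_dp_size, ep_size, moe_tp_size = 2, 2, 1
    stage_rank = [0, 1, 2, 3]  # ranks of one pp stage, flattened over (dp, tp, sp)
    submesh = np.array(stage_rank).reshape(moe_dp_size, ep_size, moe_tp_size)

    # fixing all axes but one yields each group's ranks, exactly as in the loops above
    [submesh[:, e, t].flatten().tolist() for e in range(ep_size) for t in range(moe_tp_size)]
    # -> [[0, 2], [1, 3]]   (moe_dp groups)
    [submesh[d, :, t].flatten().tolist() for d in range(moe_dp_size) for t in range(moe_tp_size)]
    # -> [[0, 1], [2, 3]]   (ep groups)

Every rank loops over all pp stages and calls dist.new_group for every slice because torch.distributed requires new_group to be entered collectively by all processes, even those that will not be members of the group being created; each rank then keeps only the group on its own stage that contains it. The next hunks extend ProcessGroupMesh (the class named in the hunk headers) with the return_ranks_by_group flag that _init_moe_param_comm consumes.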
@@ -209,13 +209,15 @@ class ProcessGroupMesh:
         axis: Union[int, List[int]],
         indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None,
         backend: Optional[str] = None,
-    ) -> ProcessGroup:
+        return_ranks_by_group: bool = False
+    ) -> Union[ProcessGroup, List[Tuple[int, ...]]]:
         """Create all process groups along the given axis, and return the one which the current process belongs to.
 
         Args:
             axis (int): Axis along which the process groups are created.
             indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
             backend (Optional[str], optional): Backend of the process group. Defaults to None.
+            return_ranks_by_group (bool): Whether to return all ranks by group for creating submesh. Defaults to False.
 
         Returns:
             ProcessGroup: The process group along the given axis which the current process belongs to.
@@ -235,25 +237,35 @@ class ProcessGroupMesh:
         # the choices on the axis are reduced to 1, since it's determined by `indices_at_axis`
         for ax in axis:
             reduced_shape[ax] = 1
-        target_group = None
-        # use Cartesian product to generate all combinations of coordinates
-        for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
-            coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
-            ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
-            group = self._get_group(ranks_in_group, backend=backend)
-            if self._rank in ranks_in_group:
-                target_group = group
-        return target_group
+        if return_ranks_by_group:
+            ranks_by_group = []
+            # use Cartesian product to generate all combinations of coordinates
+            for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
+                coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
+                ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
+                ranks_by_group.append(ranks_in_group)
+            return ranks_by_group
+        else:
+            target_group = None
+            # use Cartesian product to generate all combinations of coordinates
+            for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
+                coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
+                ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
+                group = self._get_group(ranks_in_group, backend=backend)
+                if self._rank in ranks_in_group:
+                    target_group = group
+            return target_group
 
     def get_group_along_axis(
-        self, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
-    ) -> ProcessGroup:
+        self, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None, return_ranks_by_group: bool = False
+    ) -> Union[ProcessGroup, List[Tuple[int, ...]]]:
         """Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created.
 
         Args:
             axis (int or list of int): Axes along which the process groups are created.
             indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
             backend (Optional[str], optional): Backend of the process group. Defaults to None.
+            return_ranks_by_group (bool): Whether to return all ranks by group for creating submesh. Defaults to False.
 
         Returns:
             ProcessGroup: The process group along the given axis which the current process belongs to.
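A minimal sketch of what the return_ranks_by_group path computes, on a hypothetical 2x2 mesh (axis 0 = dp, axis 1 = tp), grouping along axis 1; the local ravel helper stands in for ProcessGroupMesh.ravel:

    import itertools

    shape = (2, 2)   # hypothetical 2x2 mesh
    axis = [1]       # build groups along the second axis
    reduced_shape = list(shape)
    for ax in axis:
        reduced_shape[ax] = 1  # the grouped axis contributes a single base choice

    def ravel(coord, shape):  # row-major rank of a mesh coordinate
        rank = 0
        for c, s in zip(coord, shape):
            rank = rank * s + c
        return rank

    ranks_by_group = []
    for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
        coords_in_group = [(base_coord[0], i) for i in range(shape[1])]  # vary the grouped axis
        ranks_by_group.append(tuple(ravel(c, shape) for c in coords_in_group))

    print(ranks_by_group)  # [(0, 1), (2, 3)]

Unlike the default branch, this path never calls _get_group, so no process groups are created: it only enumerates rank tuples, which is what lets _init_moe_param_comm create the submesh groups itself via dist.new_group.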
@@ -267,6 +279,10 @@ class ProcessGroupMesh:
 
         coords_in_group = ProcessGroupMesh.get_coords_along_axis(self._coord, axis, indices_at_axis)
         ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
+
+        if return_ranks_by_group:
+            return self.create_group_along_axis(axis, indices_at_axis, backend=backend, return_ranks_by_group=True)
+
         if ranks_in_group not in self._ranks_to_group:
             # no need to cache it explicitly, since it will be cached in `create_group_along_axis`
             return self.create_group_along_axis(axis, indices_at_axis, backend=backend)
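A hedged usage sketch of the new flag, mirroring the call in _init_moe_param_comm (the mesh shape and axes are hypothetical, and this assumes an initialized torch.distributed job with 8 ranks):

    # axes: (pp, dp, tp, sp) with shape (2, 2, 2, 1), ranks laid out row-major
    mesh = ProcessGroupMesh(2, 2, 2, 1)

    # default path: returns the ProcessGroup this rank belongs to
    my_dp_group = mesh.get_group_along_axis(1)

    # new path: returns every group's rank tuple instead of creating groups,
    # one tuple per pp coordinate: [(0, 1, 2, 3), (4, 5, 6, 7)]
    ranks_by_pp_stage = mesh.get_group_along_axis([1, 2, 3], return_ranks_by_group=True)

The remaining hunks update the MoE shardformer test configs to exercise the new path.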
@@ -29,10 +29,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
         model_fn, loss_fn, test_config, pluggin_cls=MoeHybridParallelPlugin, optim_class=torch.optim.Adam
     )
-    with torch.autograd.set_detect_anomaly(True):
+
     org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
         org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
     )
 
     stage_manager = booster.plugin.stage_manager
     tp_group = booster.plugin.tp_group
@@ -115,8 +115,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     [
         {
             "tp_size": 1,
-            "pp_size": 1,
-            "ep_size": 1,
+            "pp_size": 2,
+            "num_microbatches": 2,
+            "ep_size": 2,
             "zero_stage": 1,
             "overlap_communication": False,
             "precision": "fp32",
@@ -125,7 +126,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 1,
             "pp_size": 2,
             "num_microbatches": 2,
-            "ep_size": 1,
+            "ep_size": 2,
             "zero_stage": 1,
             "overlap_communication": False,
             "precision": "fp32",
@@ -134,7 +135,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 2,
-            "ep_size": 1,
+            "ep_size": 2,
             "zero_stage": 1,
             "overlap_communication": False,
             "precision": "fp32",
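For reference, a sketch of the plugin construction the first updated config exercises (argument names follow the test dicts, which build_model_from_hybrid_plugin forwards to the plugin; other arguments of MoeHybridParallelPlugin are elided, and a 4-rank launch is assumed):

    # tp_size=1, pp_size=2, ep_size=2 on 4 ranks -> moe_dp_size = 4 // (2 * 2 * 1) = 1
    plugin = MoeHybridParallelPlugin(
        tp_size=1,
        pp_size=2,
        ep_size=2,
        num_microbatches=2,
        zero_stage=1,
        overlap_communication=False,
        precision="fp32",
    )

With ep_size raised from 1 to 2 in all three configs, every test now routes experts across ranks and so exercises the submesh initialization added above.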