Browse Source

Revert "[moe] implement submesh initialization"

This reverts commit 2f9bce6686.
colossalchat
hxwang 4 months ago committed by Hongxin Liu
parent
commit
5b4c12381b
  1. 44
      colossalai/cluster/process_group_mesh.py

44
colossalai/cluster/process_group_mesh.py

@ -209,15 +209,13 @@ class ProcessGroupMesh:
axis: Union[int, List[int]],
indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None,
backend: Optional[str] = None,
return_ranks_by_group: bool = False,
) -> Union[ProcessGroup, List[Tuple[int, ...]]]:
) -> ProcessGroup:
"""Create all process groups along the given axis, and return the one which the current process belongs to.
Args:
axis (int): Axis along which the process groups are created.
indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
backend (Optional[str], optional): Backend of the process group. Defaults to None.
return_ranks_by_group (bool): Whether to return all ranks by group for creating submesh. Defaults to False.
Returns:
ProcessGroup: The process group along the given axis which the current process belongs to.
@ -237,39 +235,25 @@ class ProcessGroupMesh:
# the choices on the axis are reduced to 1, since it's determined by `indices_at_axis`
for ax in axis:
reduced_shape[ax] = 1
if return_ranks_by_group:
ranks_by_group = []
# use Cartesian product to generate all combinations of coordinates
for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
ranks_by_group.append(ranks_in_group)
return ranks_by_group
else:
target_group = None
# use Cartesian product to generate all combinations of coordinates
for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
group = self._get_group(ranks_in_group, backend=backend)
if self._rank in ranks_in_group:
target_group = group
return target_group
target_group = None
# use Cartesian product to generate all combinations of coordinates
for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
group = self._get_group(ranks_in_group, backend=backend)
if self._rank in ranks_in_group:
target_group = group
return target_group
def get_group_along_axis(
self,
axis: Union[int, List[int]],
indices_at_axis: Optional[List[int]] = None,
backend: Optional[str] = None,
return_ranks_by_group: bool = False,
) -> Union[ProcessGroup, List[Tuple[int, ...]]]:
self, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
) -> ProcessGroup:
"""Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created.
Args:
axis (int or list of int): Axes along which the process groups are created.
indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
backend (Optional[str], optional): Backend of the process group. Defaults to None.
return_ranks_by_group (bool): Whether to return all ranks by group for creating submesh. Defaults to False.
Returns:
ProcessGroup: The process group along the given axis which the current process belongs to.
@ -283,10 +267,6 @@ class ProcessGroupMesh:
coords_in_group = ProcessGroupMesh.get_coords_along_axis(self._coord, axis, indices_at_axis)
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
if return_ranks_by_group:
return self.create_group_along_axis(axis, indices_at_axis, backend=backend, return_ranks_by_group=True)
if ranks_in_group not in self._ranks_to_group:
# no need to cache it explicitly, since it will be cached in `create_group_along_axis`
return self.create_group_along_axis(axis, indices_at_axis, backend=backend)

Loading…
Cancel
Save