Revert "[moe] implement submesh initialization"

This reverts commit 2f9bce6686.
colossalchat
hxwang 2024-07-25 06:32:02 +00:00 committed by Hongxin Liu
parent cb01c0d5ce
commit 5b4c12381b
1 changed file with 12 additions and 32 deletions

View File

@ -209,15 +209,13 @@ class ProcessGroupMesh:
axis: Union[int, List[int]], axis: Union[int, List[int]],
indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None, indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None,
backend: Optional[str] = None, backend: Optional[str] = None,
return_ranks_by_group: bool = False, ) -> ProcessGroup:
) -> Union[ProcessGroup, List[Tuple[int, ...]]]:
"""Create all process groups along the given axis, and return the one which the current process belongs to. """Create all process groups along the given axis, and return the one which the current process belongs to.
Args: Args:
axis (int): Axis along which the process groups are created. axis (int): Axis along which the process groups are created.
indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None. indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
backend (Optional[str], optional): Backend of the process group. Defaults to None. backend (Optional[str], optional): Backend of the process group. Defaults to None.
return_ranks_by_group (bool): Whether to return all ranks by group for creating submesh. Defaults to False.
Returns: Returns:
ProcessGroup: The process group along the given axis which the current process belongs to. ProcessGroup: The process group along the given axis which the current process belongs to.
@ -237,15 +235,6 @@ class ProcessGroupMesh:
# the choices on the axis are reduced to 1, since it's determined by `indices_at_axis` # the choices on the axis are reduced to 1, since it's determined by `indices_at_axis`
for ax in axis: for ax in axis:
reduced_shape[ax] = 1 reduced_shape[ax] = 1
if return_ranks_by_group:
ranks_by_group = []
# use Cartesian product to generate all combinations of coordinates
for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
ranks_by_group.append(ranks_in_group)
return ranks_by_group
else:
target_group = None target_group = None
# use Cartesian product to generate all combinations of coordinates # use Cartesian product to generate all combinations of coordinates
for base_coord in itertools.product(*[range(s) for s in reduced_shape]): for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
@ -257,19 +246,14 @@ class ProcessGroupMesh:
return target_group return target_group
def get_group_along_axis( def get_group_along_axis(
self, self, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
axis: Union[int, List[int]], ) -> ProcessGroup:
indices_at_axis: Optional[List[int]] = None,
backend: Optional[str] = None,
return_ranks_by_group: bool = False,
) -> Union[ProcessGroup, List[Tuple[int, ...]]]:
"""Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created. """Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created.
Args: Args:
axis (int or list of int): Axes along which the process groups are created. axis (int or list of int): Axes along which the process groups are created.
indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None. indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
backend (Optional[str], optional): Backend of the process group. Defaults to None. backend (Optional[str], optional): Backend of the process group. Defaults to None.
return_ranks_by_group (bool): Whether to return all ranks by group for creating submesh. Defaults to False.
Returns: Returns:
ProcessGroup: The process group along the given axis which the current process belongs to. ProcessGroup: The process group along the given axis which the current process belongs to.
@ -283,10 +267,6 @@ class ProcessGroupMesh:
coords_in_group = ProcessGroupMesh.get_coords_along_axis(self._coord, axis, indices_at_axis) coords_in_group = ProcessGroupMesh.get_coords_along_axis(self._coord, axis, indices_at_axis)
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group]) ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
if return_ranks_by_group:
return self.create_group_along_axis(axis, indices_at_axis, backend=backend, return_ranks_by_group=True)
if ranks_in_group not in self._ranks_to_group: if ranks_in_group not in self._ranks_to_group:
# no need to cache it explicitly, since it will be cached in `create_group_along_axis` # no need to cache it explicitly, since it will be cached in `create_group_along_axis`
return self.create_group_along_axis(axis, indices_at_axis, backend=backend) return self.create_group_along_axis(axis, indices_at_axis, backend=backend)