diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py index a9d341efa..dc96708f0 100644 --- a/colossalai/cluster/process_group_mesh.py +++ b/colossalai/cluster/process_group_mesh.py @@ -209,15 +209,13 @@ class ProcessGroupMesh: axis: Union[int, List[int]], indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None, backend: Optional[str] = None, - return_ranks_by_group: bool = False, - ) -> Union[ProcessGroup, List[Tuple[int, ...]]]: + ) -> ProcessGroup: """Create all process groups along the given axis, and return the one which the current process belongs to. Args: axis (int): Axis along which the process groups are created. indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None. backend (Optional[str], optional): Backend of the process group. Defaults to None. - return_ranks_by_group (bool): Whether to return all ranks by group for creating submesh. Defaults to False. Returns: ProcessGroup: The process group along the given axis which the current process belongs to. @@ -237,39 +235,25 @@ class ProcessGroupMesh: # the choices on the axis are reduced to 1, since it's determined by `indices_at_axis` for ax in axis: reduced_shape[ax] = 1 - if return_ranks_by_group: - ranks_by_group = [] - # use Cartesian product to generate all combinations of coordinates - for base_coord in itertools.product(*[range(s) for s in reduced_shape]): - coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis) - ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group]) - ranks_by_group.append(ranks_in_group) - return ranks_by_group - else: - target_group = None - # use Cartesian product to generate all combinations of coordinates - for base_coord in itertools.product(*[range(s) for s in reduced_shape]): - coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis) - ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group]) - group = self._get_group(ranks_in_group, backend=backend) - if self._rank in ranks_in_group: - target_group = group - return target_group + target_group = None + # use Cartesian product to generate all combinations of coordinates + for base_coord in itertools.product(*[range(s) for s in reduced_shape]): + coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis) + ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group]) + group = self._get_group(ranks_in_group, backend=backend) + if self._rank in ranks_in_group: + target_group = group + return target_group def get_group_along_axis( - self, - axis: Union[int, List[int]], - indices_at_axis: Optional[List[int]] = None, - backend: Optional[str] = None, - return_ranks_by_group: bool = False, - ) -> Union[ProcessGroup, List[Tuple[int, ...]]]: + self, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None + ) -> ProcessGroup: """Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created. Args: axis (int or list of int): Axes along which the process groups are created. indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None. backend (Optional[str], optional): Backend of the process group. Defaults to None. - return_ranks_by_group (bool): Whether to return all ranks by group for creating submesh. Defaults to False. Returns: ProcessGroup: The process group along the given axis which the current process belongs to. @@ -283,10 +267,6 @@ class ProcessGroupMesh: coords_in_group = ProcessGroupMesh.get_coords_along_axis(self._coord, axis, indices_at_axis) ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group]) - - if return_ranks_by_group: - return self.create_group_along_axis(axis, indices_at_axis, backend=backend, return_ranks_by_group=True) - if ranks_in_group not in self._ranks_to_group: # no need to cache it explicitly, since it will be cached in `create_group_along_axis` return self.create_group_along_axis(axis, indices_at_axis, backend=backend)