|
|
|
@ -209,15 +209,13 @@ class ProcessGroupMesh:
|
|
|
|
|
axis: Union[int, List[int]], |
|
|
|
|
indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None, |
|
|
|
|
backend: Optional[str] = None, |
|
|
|
|
return_ranks_by_group: bool = False, |
|
|
|
|
) -> Union[ProcessGroup, List[Tuple[int, ...]]]: |
|
|
|
|
) -> ProcessGroup: |
|
|
|
|
"""Create all process groups along the given axis, and return the one which the current process belongs to. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
axis (int): Axis along which the process groups are created. |
|
|
|
|
indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None. |
|
|
|
|
backend (Optional[str], optional): Backend of the process group. Defaults to None. |
|
|
|
|
return_ranks_by_group (bool): Whether to return all ranks by group for creating submesh. Defaults to False. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
ProcessGroup: The process group along the given axis which the current process belongs to. |
|
|
|
@ -237,39 +235,25 @@ class ProcessGroupMesh:
|
|
|
|
|
# the choices on the axis are reduced to 1, since it's determined by `indices_at_axis` |
|
|
|
|
for ax in axis: |
|
|
|
|
reduced_shape[ax] = 1 |
|
|
|
|
if return_ranks_by_group: |
|
|
|
|
ranks_by_group = [] |
|
|
|
|
# use Cartesian product to generate all combinations of coordinates |
|
|
|
|
for base_coord in itertools.product(*[range(s) for s in reduced_shape]): |
|
|
|
|
coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis) |
|
|
|
|
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group]) |
|
|
|
|
ranks_by_group.append(ranks_in_group) |
|
|
|
|
return ranks_by_group |
|
|
|
|
else: |
|
|
|
|
target_group = None |
|
|
|
|
# use Cartesian product to generate all combinations of coordinates |
|
|
|
|
for base_coord in itertools.product(*[range(s) for s in reduced_shape]): |
|
|
|
|
coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis) |
|
|
|
|
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group]) |
|
|
|
|
group = self._get_group(ranks_in_group, backend=backend) |
|
|
|
|
if self._rank in ranks_in_group: |
|
|
|
|
target_group = group |
|
|
|
|
return target_group |
|
|
|
|
target_group = None |
|
|
|
|
# use Cartesian product to generate all combinations of coordinates |
|
|
|
|
for base_coord in itertools.product(*[range(s) for s in reduced_shape]): |
|
|
|
|
coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis) |
|
|
|
|
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group]) |
|
|
|
|
group = self._get_group(ranks_in_group, backend=backend) |
|
|
|
|
if self._rank in ranks_in_group: |
|
|
|
|
target_group = group |
|
|
|
|
return target_group |
|
|
|
|
|
|
|
|
|
def get_group_along_axis( |
|
|
|
|
self, |
|
|
|
|
axis: Union[int, List[int]], |
|
|
|
|
indices_at_axis: Optional[List[int]] = None, |
|
|
|
|
backend: Optional[str] = None, |
|
|
|
|
return_ranks_by_group: bool = False, |
|
|
|
|
) -> Union[ProcessGroup, List[Tuple[int, ...]]]: |
|
|
|
|
self, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None |
|
|
|
|
) -> ProcessGroup: |
|
|
|
|
"""Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
axis (int or list of int): Axes along which the process groups are created. |
|
|
|
|
indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None. |
|
|
|
|
backend (Optional[str], optional): Backend of the process group. Defaults to None. |
|
|
|
|
return_ranks_by_group (bool): Whether to return all ranks by group for creating submesh. Defaults to False. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
ProcessGroup: The process group along the given axis which the current process belongs to. |
|
|
|
@ -283,10 +267,6 @@ class ProcessGroupMesh:
|
|
|
|
|
|
|
|
|
|
coords_in_group = ProcessGroupMesh.get_coords_along_axis(self._coord, axis, indices_at_axis) |
|
|
|
|
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group]) |
|
|
|
|
|
|
|
|
|
if return_ranks_by_group: |
|
|
|
|
return self.create_group_along_axis(axis, indices_at_axis, backend=backend, return_ranks_by_group=True) |
|
|
|
|
|
|
|
|
|
if ranks_in_group not in self._ranks_to_group: |
|
|
|
|
# no need to cache it explicitly, since it will be cached in `create_group_along_axis` |
|
|
|
|
return self.create_group_along_axis(axis, indices_at_axis, backend=backend) |
|
|
|
|