mirror of https://github.com/hpcaitech/ColossalAI
parent
cb01c0d5ce
commit
5b4c12381b
|
@ -209,15 +209,13 @@ class ProcessGroupMesh:
|
|||
axis: Union[int, List[int]],
|
||||
indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None,
|
||||
backend: Optional[str] = None,
|
||||
return_ranks_by_group: bool = False,
|
||||
) -> Union[ProcessGroup, List[Tuple[int, ...]]]:
|
||||
) -> ProcessGroup:
|
||||
"""Create all process groups along the given axis, and return the one which the current process belongs to.
|
||||
|
||||
Args:
|
||||
axis (int or list of int): Axis or axes along which the process groups are created.
|
||||
indices_at_axis (Optional[Union[List[int], List[List[int]]]], optional): Indices at the axis. Defaults to None.
|
||||
backend (Optional[str], optional): Backend of the process group. Defaults to None.
|
||||
return_ranks_by_group (bool): Whether to return all ranks by group for creating submesh. Defaults to False.
|
||||
|
||||
Returns:
|
||||
ProcessGroup: The process group along the given axis which the current process belongs to.
|
||||
|
@ -237,15 +235,6 @@ class ProcessGroupMesh:
|
|||
# the choices on the axis are reduced to 1, since it's determined by `indices_at_axis`
|
||||
for ax in axis:
|
||||
reduced_shape[ax] = 1
|
||||
if return_ranks_by_group:
|
||||
ranks_by_group = []
|
||||
# use Cartesian product to generate all combinations of coordinates
|
||||
for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
|
||||
coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
|
||||
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
|
||||
ranks_by_group.append(ranks_in_group)
|
||||
return ranks_by_group
|
||||
else:
|
||||
target_group = None
|
||||
# use Cartesian product to generate all combinations of coordinates
|
||||
for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
|
||||
|
@ -257,19 +246,14 @@ class ProcessGroupMesh:
|
|||
return target_group
|
||||
|
||||
def get_group_along_axis(
|
||||
self,
|
||||
axis: Union[int, List[int]],
|
||||
indices_at_axis: Optional[List[int]] = None,
|
||||
backend: Optional[str] = None,
|
||||
return_ranks_by_group: bool = False,
|
||||
) -> Union[ProcessGroup, List[Tuple[int, ...]]]:
|
||||
self, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
|
||||
) -> ProcessGroup:
|
||||
"""Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created.
|
||||
|
||||
Args:
|
||||
axis (int or list of int): Axes along which the process groups are created.
|
||||
indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
|
||||
backend (Optional[str], optional): Backend of the process group. Defaults to None.
|
||||
return_ranks_by_group (bool): Whether to return all ranks by group for creating submesh. Defaults to False.
|
||||
|
||||
Returns:
|
||||
ProcessGroup: The process group along the given axis which the current process belongs to.
|
||||
|
@ -283,10 +267,6 @@ class ProcessGroupMesh:
|
|||
|
||||
coords_in_group = ProcessGroupMesh.get_coords_along_axis(self._coord, axis, indices_at_axis)
|
||||
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
|
||||
|
||||
if return_ranks_by_group:
|
||||
return self.create_group_along_axis(axis, indices_at_axis, backend=backend, return_ranks_by_group=True)
|
||||
|
||||
if ranks_in_group not in self._ranks_to_group:
|
||||
# no need to cache it explicitly, since it will be cached in `create_group_along_axis`
|
||||
return self.create_group_along_axis(axis, indices_at_axis, backend=backend)
|
||||
|
|
Loading…
Reference in New Issue