mirror of https://github.com/hpcaitech/ColossalAI
parent
cb01c0d5ce
commit
5b4c12381b
|
@ -209,15 +209,13 @@ class ProcessGroupMesh:
|
||||||
axis: Union[int, List[int]],
|
axis: Union[int, List[int]],
|
||||||
indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None,
|
indices_at_axis: Optional[Union[List[int], List[List[int]]]] = None,
|
||||||
backend: Optional[str] = None,
|
backend: Optional[str] = None,
|
||||||
return_ranks_by_group: bool = False,
|
) -> ProcessGroup:
|
||||||
) -> Union[ProcessGroup, List[Tuple[int, ...]]]:
|
|
||||||
"""Create all process groups along the given axis, and return the one which the current process belongs to.
|
"""Create all process groups along the given axis, and return the one which the current process belongs to.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
axis (int): Axis along which the process groups are created.
|
axis (int): Axis along which the process groups are created.
|
||||||
indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
|
indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
|
||||||
backend (Optional[str], optional): Backend of the process group. Defaults to None.
|
backend (Optional[str], optional): Backend of the process group. Defaults to None.
|
||||||
return_ranks_by_group (bool): Whether to return all ranks by group for creating submesh. Defaults to False.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ProcessGroup: The process group along the given axis which the current process belongs to.
|
ProcessGroup: The process group along the given axis which the current process belongs to.
|
||||||
|
@ -237,15 +235,6 @@ class ProcessGroupMesh:
|
||||||
# the choices on the axis are reduced to 1, since it's determined by `indices_at_axis`
|
# the choices on the axis are reduced to 1, since it's determined by `indices_at_axis`
|
||||||
for ax in axis:
|
for ax in axis:
|
||||||
reduced_shape[ax] = 1
|
reduced_shape[ax] = 1
|
||||||
if return_ranks_by_group:
|
|
||||||
ranks_by_group = []
|
|
||||||
# use Cartesian product to generate all combinations of coordinates
|
|
||||||
for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
|
|
||||||
coords_in_group = ProcessGroupMesh.get_coords_along_axis(base_coord, axis, indices_at_axis)
|
|
||||||
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
|
|
||||||
ranks_by_group.append(ranks_in_group)
|
|
||||||
return ranks_by_group
|
|
||||||
else:
|
|
||||||
target_group = None
|
target_group = None
|
||||||
# use Cartesian product to generate all combinations of coordinates
|
# use Cartesian product to generate all combinations of coordinates
|
||||||
for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
|
for base_coord in itertools.product(*[range(s) for s in reduced_shape]):
|
||||||
|
@ -257,19 +246,14 @@ class ProcessGroupMesh:
|
||||||
return target_group
|
return target_group
|
||||||
|
|
||||||
def get_group_along_axis(
|
def get_group_along_axis(
|
||||||
self,
|
self, axis: Union[int, List[int]], indices_at_axis: Optional[List[int]] = None, backend: Optional[str] = None
|
||||||
axis: Union[int, List[int]],
|
) -> ProcessGroup:
|
||||||
indices_at_axis: Optional[List[int]] = None,
|
|
||||||
backend: Optional[str] = None,
|
|
||||||
return_ranks_by_group: bool = False,
|
|
||||||
) -> Union[ProcessGroup, List[Tuple[int, ...]]]:
|
|
||||||
"""Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created.
|
"""Get the process group along the given axis which the current process belongs to. If the process group doesn't exist, it will be created.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
axis (int or list of int): Axes along which the process groups are created.
|
axis (int or list of int): Axes along which the process groups are created.
|
||||||
indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
|
indices_at_axis (Optional[List[int]], optional): Indices at the axis. Defaults to None.
|
||||||
backend (Optional[str], optional): Backend of the process group. Defaults to None.
|
backend (Optional[str], optional): Backend of the process group. Defaults to None.
|
||||||
return_ranks_by_group (bool): Whether to return all ranks by group for creating submesh. Defaults to False.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ProcessGroup: The process group along the given axis which the current process belongs to.
|
ProcessGroup: The process group along the given axis which the current process belongs to.
|
||||||
|
@ -283,10 +267,6 @@ class ProcessGroupMesh:
|
||||||
|
|
||||||
coords_in_group = ProcessGroupMesh.get_coords_along_axis(self._coord, axis, indices_at_axis)
|
coords_in_group = ProcessGroupMesh.get_coords_along_axis(self._coord, axis, indices_at_axis)
|
||||||
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
|
ranks_in_group = tuple([ProcessGroupMesh.ravel(coord, self._shape) for coord in coords_in_group])
|
||||||
|
|
||||||
if return_ranks_by_group:
|
|
||||||
return self.create_group_along_axis(axis, indices_at_axis, backend=backend, return_ranks_by_group=True)
|
|
||||||
|
|
||||||
if ranks_in_group not in self._ranks_to_group:
|
if ranks_in_group not in self._ranks_to_group:
|
||||||
# no need to cache it explicitly, since it will be cached in `create_group_along_axis`
|
# no need to cache it explicitly, since it will be cached in `create_group_along_axis`
|
||||||
return self.create_group_along_axis(axis, indices_at_axis, backend=backend)
|
return self.create_group_along_axis(axis, indices_at_axis, backend=backend)
|
||||||
|
|
Loading…
Reference in New Issue