@ -209,13 +209,15 @@ class ProcessGroupMesh:
axis : Union [ int , List [ int ] ] ,
axis : Union [ int , List [ int ] ] ,
indices_at_axis : Optional [ Union [ List [ int ] , List [ List [ int ] ] ] ] = None ,
indices_at_axis : Optional [ Union [ List [ int ] , List [ List [ int ] ] ] ] = None ,
backend : Optional [ str ] = None ,
backend : Optional [ str ] = None ,
) - > ProcessGroup :
return_ranks_by_group : bool = False
) - > Union [ ProcessGroup , List [ Tuple [ int , . . . ] ] ] :
""" Create all process groups along the given axis, and return the one which the current process belongs to.
""" Create all process groups along the given axis, and return the one which the current process belongs to.
Args :
Args :
axis ( int ) : Axis along which the process groups are created .
axis ( int ) : Axis along which the process groups are created .
indices_at_axis ( Optional [ List [ int ] ] , optional ) : Indices at the axis . Defaults to None .
indices_at_axis ( Optional [ List [ int ] ] , optional ) : Indices at the axis . Defaults to None .
backend ( Optional [ str ] , optional ) : Backend of the process group . Defaults to None .
backend ( Optional [ str ] , optional ) : Backend of the process group . Defaults to None .
return_ranks_by_group ( bool ) : Whether to return all ranks by group for creating submesh . Defaults to False .
Returns :
Returns :
ProcessGroup : The process group along the given axis which the current process belongs to .
ProcessGroup : The process group along the given axis which the current process belongs to .
@ -235,25 +237,35 @@ class ProcessGroupMesh:
# the choices on the axis are reduced to 1, since it's determined by `indices_at_axis`
# the choices on the axis are reduced to 1, since it's determined by `indices_at_axis`
for ax in axis :
for ax in axis :
reduced_shape [ ax ] = 1
reduced_shape [ ax ] = 1
target_group = None
if return_ranks_by_group :
# use Cartesian product to generate all combinations of coordinates
ranks_by_group = [ ]
for base_coord in itertools . product ( * [ range ( s ) for s in reduced_shape ] ) :
# use Cartesian product to generate all combinations of coordinates
coords_in_group = ProcessGroupMesh . get_coords_along_axis ( base_coord , axis , indices_at_axis )
for base_coord in itertools . product ( * [ range ( s ) for s in reduced_shape ] ) :
ranks_in_group = tuple ( [ ProcessGroupMesh . ravel ( coord , self . _shape ) for coord in coords_in_group ] )
coords_in_group = ProcessGroupMesh . get_coords_along_axis ( base_coord , axis , indices_at_axis )
group = self . _get_group ( ranks_in_group , backend = backend )
ranks_in_group = tuple ( [ ProcessGroupMesh . ravel ( coord , self . _shape ) for coord in coords_in_group ] )
if self . _rank in ranks_in_group :
ranks_by_group . append ( ranks_in_group )
target_group = group
return ranks_by_group
return target_group
else :
target_group = None
# use Cartesian product to generate all combinations of coordinates
for base_coord in itertools . product ( * [ range ( s ) for s in reduced_shape ] ) :
coords_in_group = ProcessGroupMesh . get_coords_along_axis ( base_coord , axis , indices_at_axis )
ranks_in_group = tuple ( [ ProcessGroupMesh . ravel ( coord , self . _shape ) for coord in coords_in_group ] )
group = self . _get_group ( ranks_in_group , backend = backend )
if self . _rank in ranks_in_group :
target_group = group
return target_group
def get_group_along_axis (
def get_group_along_axis (
self , axis : Union [ int , List [ int ] ] , indices_at_axis : Optional [ List [ int ] ] = None , backend : Optional [ str ] = None
self , axis : Union [ int , List [ int ] ] , indices_at_axis : Optional [ List [ int ] ] = None , backend : Optional [ str ] = None , return_ranks_by_group : bool = False
) - > ProcessGroup :
) - > Union [ ProcessGroup , List [ Tuple [ int , . . . ] ] ] :
""" Get the process group along the given axis which the current process belongs to. If the process group doesn ' t exist, it will be created.
""" Get the process group along the given axis which the current process belongs to. If the process group doesn ' t exist, it will be created.
Args :
Args :
axis ( int or list of int ) : Axes along which the process groups are created .
axis ( int or list of int ) : Axes along which the process groups are created .
indices_at_axis ( Optional [ List [ int ] ] , optional ) : Indices at the axis . Defaults to None .
indices_at_axis ( Optional [ List [ int ] ] , optional ) : Indices at the axis . Defaults to None .
backend ( Optional [ str ] , optional ) : Backend of the process group . Defaults to None .
backend ( Optional [ str ] , optional ) : Backend of the process group . Defaults to None .
return_ranks_by_group ( bool ) : Whether to return all ranks by group for creating submesh . Defaults to False .
Returns :
Returns :
ProcessGroup : The process group along the given axis which the current process belongs to .
ProcessGroup : The process group along the given axis which the current process belongs to .
@ -267,6 +279,10 @@ class ProcessGroupMesh:
coords_in_group = ProcessGroupMesh . get_coords_along_axis ( self . _coord , axis , indices_at_axis )
coords_in_group = ProcessGroupMesh . get_coords_along_axis ( self . _coord , axis , indices_at_axis )
ranks_in_group = tuple ( [ ProcessGroupMesh . ravel ( coord , self . _shape ) for coord in coords_in_group ] )
ranks_in_group = tuple ( [ ProcessGroupMesh . ravel ( coord , self . _shape ) for coord in coords_in_group ] )
if return_ranks_by_group :
return self . create_group_along_axis ( axis , indices_at_axis , backend = backend , return_ranks_by_group = True )
if ranks_in_group not in self . _ranks_to_group :
if ranks_in_group not in self . _ranks_to_group :
# no need to cache it explicitly, since it will be cached in `create_group_along_axis`
# no need to cache it explicitly, since it will be cached in `create_group_along_axis`
return self . create_group_along_axis ( axis , indices_at_axis , backend = backend )
return self . create_group_along_axis ( axis , indices_at_axis , backend = backend )