|
|
|
@ -7,6 +7,7 @@ from typing import Dict, List, Optional, Tuple, Union
|
|
|
|
|
import numpy as np
|
|
|
|
|
import torch.distributed as dist
|
|
|
|
|
from torch.distributed import ProcessGroup
|
|
|
|
|
from torch.distributed.distributed_c10d import GroupMember
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def prod(nums: List[int]) -> int:
|
|
|
|
@ -47,7 +48,7 @@ class ProcessGroupMesh:
|
|
|
|
|
self._shape = size
|
|
|
|
|
self._rank = dist.get_rank()
|
|
|
|
|
self._coord = ProcessGroupMesh.unravel(self._rank, self._shape)
|
|
|
|
|
self._ranks_to_group: Dict[Tuple[int, ...], ProcessGroup] = {}
|
|
|
|
|
self._ranks_to_group: Dict[Tuple[int, ...], Union[ProcessGroup, GroupMember]] = {}
|
|
|
|
|
self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {}
|
|
|
|
|
|
|
|
|
|
def destroy_mesh_process_groups(self):
|
|
|
|
@ -150,7 +151,8 @@ class ProcessGroupMesh:
|
|
|
|
|
if tuple(ranks_in_group) not in self._ranks_to_group:
|
|
|
|
|
group = dist.new_group(ranks_in_group, backend=backend)
|
|
|
|
|
self._ranks_to_group[tuple(ranks_in_group)] = group
|
|
|
|
|
self._group_to_ranks[group] = tuple(ranks_in_group)
|
|
|
|
|
if group is not GroupMember.NON_GROUP_MEMBER:
|
|
|
|
|
self._group_to_ranks[group] = tuple(ranks_in_group)
|
|
|
|
|
return self._ranks_to_group[tuple(ranks_in_group)]
|
|
|
|
|
|
|
|
|
|
def get_ranks_in_group(self, group: ProcessGroup) -> List[int]:
|
|
|
|
|