[chore] handle non member group

moe_sp
hxwang 5 months ago
parent 4fc6f9aa98
commit 000456bf94
No known key found for this signature in database
GPG Key ID: 0EC383D418F0B9F8

@ -7,6 +7,7 @@ from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import GroupMember
def prod(nums: List[int]) -> int:
@ -47,7 +48,7 @@ class ProcessGroupMesh:
self._shape = size
self._rank = dist.get_rank()
self._coord = ProcessGroupMesh.unravel(self._rank, self._shape)
self._ranks_to_group: Dict[Tuple[int, ...], ProcessGroup] = {}
self._ranks_to_group: Dict[Tuple[int, ...], Union[ProcessGroup, GroupMember]] = {}
self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {}
def destroy_mesh_process_groups(self):
@ -150,7 +151,8 @@ class ProcessGroupMesh:
if tuple(ranks_in_group) not in self._ranks_to_group:
group = dist.new_group(ranks_in_group, backend=backend)
self._ranks_to_group[tuple(ranks_in_group)] = group
self._group_to_ranks[group] = tuple(ranks_in_group)
if group is not GroupMember.NON_GROUP_MEMBER:
self._group_to_ranks[group] = tuple(ranks_in_group)
return self._ranks_to_group[tuple(ranks_in_group)]
def get_ranks_in_group(self, group: ProcessGroup) -> List[int]:

Loading…
Cancel
Save