[chore] handle non member group

moe_sp
hxwang 2024-07-05 07:03:45 +00:00
parent 4fc6f9aa98
commit 000456bf94
No known key found for this signature in database
GPG Key ID: 0EC383D418F0B9F8
1 changed file with 4 additions and 2 deletions

View File

@@ -7,6 +7,7 @@ from typing import Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch.distributed as dist import torch.distributed as dist
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import GroupMember
def prod(nums: List[int]) -> int: def prod(nums: List[int]) -> int:
@@ -47,7 +48,7 @@ class ProcessGroupMesh:
self._shape = size self._shape = size
self._rank = dist.get_rank() self._rank = dist.get_rank()
self._coord = ProcessGroupMesh.unravel(self._rank, self._shape) self._coord = ProcessGroupMesh.unravel(self._rank, self._shape)
self._ranks_to_group: Dict[Tuple[int, ...], ProcessGroup] = {} self._ranks_to_group: Dict[Tuple[int, ...], Union[ProcessGroup, GroupMember]] = {}
self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {} self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {}
def destroy_mesh_process_groups(self): def destroy_mesh_process_groups(self):
@@ -150,6 +151,7 @@ class ProcessGroupMesh:
if tuple(ranks_in_group) not in self._ranks_to_group: if tuple(ranks_in_group) not in self._ranks_to_group:
group = dist.new_group(ranks_in_group, backend=backend) group = dist.new_group(ranks_in_group, backend=backend)
self._ranks_to_group[tuple(ranks_in_group)] = group self._ranks_to_group[tuple(ranks_in_group)] = group
if group is not GroupMember.NON_GROUP_MEMBER:
self._group_to_ranks[group] = tuple(ranks_in_group) self._group_to_ranks[group] = tuple(ranks_in_group)
return self._ranks_to_group[tuple(ranks_in_group)] return self._ranks_to_group[tuple(ranks_in_group)]