mirror of https://github.com/hpcaitech/ColossalAI
[chore] handle non member group
parent
4fc6f9aa98
commit
000456bf94
|
@ -7,6 +7,7 @@ from typing import Dict, List, Optional, Tuple, Union
|
|||
import numpy as np
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.distributed.distributed_c10d import GroupMember
|
||||
|
||||
|
||||
def prod(nums: List[int]) -> int:
|
||||
|
@ -47,7 +48,7 @@ class ProcessGroupMesh:
|
|||
self._shape = size
|
||||
self._rank = dist.get_rank()
|
||||
self._coord = ProcessGroupMesh.unravel(self._rank, self._shape)
|
||||
self._ranks_to_group: Dict[Tuple[int, ...], ProcessGroup] = {}
|
||||
self._ranks_to_group: Dict[Tuple[int, ...], Union[ProcessGroup, GroupMember]] = {}
|
||||
self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {}
|
||||
|
||||
def destroy_mesh_process_groups(self):
|
||||
|
@ -150,6 +151,7 @@ class ProcessGroupMesh:
|
|||
if tuple(ranks_in_group) not in self._ranks_to_group:
|
||||
group = dist.new_group(ranks_in_group, backend=backend)
|
||||
self._ranks_to_group[tuple(ranks_in_group)] = group
|
||||
if group is not GroupMember.NON_GROUP_MEMBER:
|
||||
self._group_to_ranks[group] = tuple(ranks_in_group)
|
||||
return self._ranks_to_group[tuple(ranks_in_group)]
|
||||
|
||||
|
|
Loading…
Reference in New Issue