diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py index b6aff0d72..7f1ef9fce 100644 --- a/colossalai/cluster/process_group_mesh.py +++ b/colossalai/cluster/process_group_mesh.py @@ -7,6 +7,7 @@ from typing import Dict, List, Optional, Tuple, Union import numpy as np import torch.distributed as dist from torch.distributed import ProcessGroup +from torch.distributed.distributed_c10d import GroupMember def prod(nums: List[int]) -> int: @@ -47,7 +48,7 @@ class ProcessGroupMesh: self._shape = size self._rank = dist.get_rank() self._coord = ProcessGroupMesh.unravel(self._rank, self._shape) - self._ranks_to_group: Dict[Tuple[int, ...], ProcessGroup] = {} + self._ranks_to_group: Dict[Tuple[int, ...], Union[ProcessGroup, GroupMember]] = {} self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {} def destroy_mesh_process_groups(self): @@ -150,7 +151,8 @@ class ProcessGroupMesh: if tuple(ranks_in_group) not in self._ranks_to_group: group = dist.new_group(ranks_in_group, backend=backend) self._ranks_to_group[tuple(ranks_in_group)] = group - self._group_to_ranks[group] = tuple(ranks_in_group) + if group is not GroupMember.NON_GROUP_MEMBER: + self._group_to_ranks[group] = tuple(ranks_in_group) return self._ranks_to_group[tuple(ranks_in_group)] def get_ranks_in_group(self, group: ProcessGroup) -> List[int]: