From 280a81243dcd634d7926f4e8d85266064027211b Mon Sep 17 00:00:00 2001
From: HELSON
Date: Thu, 7 Jul 2022 13:55:24 +0800
Subject: [PATCH] [tensor] improve robustness of class 'ProcessGroup' (#1223)

---
 colossalai/tensor/process_group.py | 75 ++++++++++++++++--------------
 1 file changed, 41 insertions(+), 34 deletions(-)

diff --git a/colossalai/tensor/process_group.py b/colossalai/tensor/process_group.py
index 5328f058f..3c959395c 100644
--- a/colossalai/tensor/process_group.py
+++ b/colossalai/tensor/process_group.py
@@ -10,29 +10,17 @@ class PyTorchProcessGroupDict(metaclass=SingletonMeta):
         # distributed settings
         self.dict = {}
 
-    def get(self, rank: int, world_size: int, tp_degree: int, dp_degree: int, backend: str = 'nccl'):
-        key = (tp_degree, dp_degree, backend)
-        if key in self.dict:
-            return self.dict[key]
-        else:
-            self.logger = get_dist_logger('PyTorchProcessGroupDict')
-            _tp_rank_list = []
-            _dp_rank_list = []
+    def get(self, rank_list: List[int], backend: str = 'nccl'):
+        """Reuse Pytorch ProcessGroup when such a group is initialized
+        """
+        rank_tuple = tuple(rank_list)
+        # we need to convert the passed list to a tuple
+        # since List is unhashable
+        pg_key = (backend, rank_tuple)
 
-            for rank_id in range(world_size):
-                # rank_id and self._rank in the same tp group
-                if rank_id % tp_degree == rank % tp_degree:
-                    _dp_rank_list.append(rank_id)
-                if rank_id // tp_degree == rank // tp_degree:
-                    _tp_rank_list.append(rank_id)
-
-            _tp_process_group = torch.distributed.new_group(ranks=_tp_rank_list, backend=backend)
-            _dp_process_group = torch.distributed.new_group(ranks=_dp_rank_list, backend=backend)
-            self.logger.info(
-                f'rank {rank} initialize process group on {backend}, dp ranks: {_dp_rank_list} tp ranks: {_tp_rank_list}'
-            )
-            self.dict[key] = _tp_rank_list, _tp_process_group, _dp_rank_list, _dp_process_group
-            return _tp_rank_list, _tp_process_group, _dp_rank_list, _dp_process_group
+        if pg_key not in self.dict:
+            self.dict[pg_key] = torch.distributed.new_group(ranks=rank_list, backend=backend)
+        return self.dict[pg_key]
 
 
 PYTORCHPGDICT_ = PyTorchProcessGroupDict()
@@ -50,7 +38,6 @@ class ProcessGroup:
         dp_degree: Optional[int], data parallelism degree, default None means len(ranks)
     """
 
-    #TODO(haichen) fix me! ranks now must start from 0,1,2,3...
     def __init__(self,
                  rank: Optional[int] = None,
                  ranks: Optional[List[int]] = None,
@@ -69,37 +56,57 @@ class ProcessGroup:
             self._rank_list = list(range(torch.distributed.get_world_size()))
         else:
             self._rank_list = ranks
+            self._rank_list.sort()    # ensure that the list is in order
+        self._rank_idx = self._rank_list.index(self._rank)
 
         self._world_size = len(self._rank_list)
 
         if dp_degree is None and tp_degree is None:
             self._dp_degree = self._world_size
             self._tp_degree = 1
-
-        if dp_degree and not tp_degree:
+        elif dp_degree and not tp_degree:
             self._dp_degree = dp_degree
             assert self._world_size % self._dp_degree == 0, f"DP degree {dp_degree} should be divisible by {self._world_size} hen DP degree is None"
             self._tp_degree = self._world_size // dp_degree
-
-        if not dp_degree and tp_degree:
+        elif not dp_degree and tp_degree:
             self._tp_degree = tp_degree
             assert self._world_size % self._tp_degree == 0, f"TP degree {tp_degree} should be divisible by {self._world_size} when DP degree is None"
             self._dp_degree = self._world_size // tp_degree
+        else:
+            self._dp_degree = dp_degree
+            self._tp_degree = tp_degree
+            assert self._dp_degree * self._tp_degree == self._world_size, \
+                f"the world size {self._world_size} should equal the product of DP degree {self._dp_degree} " \
+                f"and TP degree {self._tp_degree}"
+
+        self._tp_rank_list = []
+        self._dp_rank_list = []
+
+        for idx, rank_id in enumerate(self._rank_list):
+            # idx and self._rank_idx in the same tp group
+            if idx % self._tp_degree == self._rank_idx % self._tp_degree:
+                self._dp_rank_list.append(rank_id)
+            if idx // self._tp_degree == self._rank_idx // self._tp_degree:
+                self._tp_rank_list.append(rank_id)
+
+        self._tp_process_group = PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
+        self._dp_process_group = PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')
+
+        self.logger = get_dist_logger('ProcessGroup')
+        self.logger.info(
+            f'{self._rank} NCCL initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
 
-        self._tp_rank_list, self._tp_process_group, self._dp_rank_list, self._dp_process_group = PYTORCHPGDICT_.get(
-            self._rank, self._world_size, self._tp_degree, self._dp_degree, 'nccl')
         self._has_cpu_groups = False
+        self._cpu_dp_process_group = None
+        self._cpu_tp_process_group = None
 
     def set_cpu_groups(self):
         if self.has_cpu_groups:
             return
         self.logger.info(
             f'{self._rank} Gloo initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
-        self._cpu_tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list, backend='gloo')
-        self._cpu_dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list, backend='gloo')
-
-        _, self._cpu_tp_process_group, _, self._cpu_dp_process_group = PYTORCHPGDICT_.get(
-            self._rank, self._world_size, self._tp_degree, self._dp_degree, 'gloo')
+        self._cpu_tp_process_group = PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
+        self._cpu_dp_process_group = PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
 
     @property
     def has_cpu_groups(self):
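
A minimal standalone sketch of the behaviour the patch introduces, runnable without torch.distributed: split_ranks mirrors the index-based TP/DP grouping added to ProcessGroup.__init__, and GroupCache mimics the (backend, tuple(rank_list)) keyed reuse in PyTorchProcessGroupDict.get. The names split_ranks and GroupCache are illustrative stand-ins, not identifiers from the patch.

from typing import List, Tuple


def split_ranks(rank: int, rank_list: List[int], tp_degree: int) -> Tuple[List[int], List[int]]:
    """Mirror of the loop in ProcessGroup.__init__: positions in the sorted rank
    list that share idx // tp_degree form the TP group of `rank`, and positions
    that share idx % tp_degree form its DP group."""
    rank_list = sorted(rank_list)
    rank_idx = rank_list.index(rank)
    tp_rank_list, dp_rank_list = [], []
    for idx, rank_id in enumerate(rank_list):
        if idx % tp_degree == rank_idx % tp_degree:
            dp_rank_list.append(rank_id)
        if idx // tp_degree == rank_idx // tp_degree:
            tp_rank_list.append(rank_id)
    return tp_rank_list, dp_rank_list


class GroupCache:
    """Toy stand-in for PyTorchProcessGroupDict: one entry per
    (backend, rank tuple), so repeated requests reuse the same object."""

    def __init__(self):
        self.dict = {}

    def get(self, rank_list: List[int], backend: str = 'nccl'):
        pg_key = (backend, tuple(rank_list))    # tuple, since a list is unhashable
        if pg_key not in self.dict:
            self.dict[pg_key] = object()        # the real cache stores torch.distributed.new_group(...)
        return self.dict[pg_key]


if __name__ == '__main__':
    # 8 ranks with tp_degree=2: rank 5 lands in TP group [4, 5] and DP group [1, 3, 5, 7]
    tp, dp = split_ranks(rank=5, rank_list=list(range(8)), tp_degree=2)
    print(tp, dp)

    cache = GroupCache()
    assert cache.get([4, 5]) is cache.get([4, 5])    # second call reuses the cached group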