diff --git a/colossalai/tensor/process_group.py b/colossalai/tensor/process_group.py
index ebc8713b8..5328f058f 100644
--- a/colossalai/tensor/process_group.py
+++ b/colossalai/tensor/process_group.py
@@ -1,6 +1,41 @@
 import torch
 from typing import List, Optional
 from colossalai.logging import get_dist_logger
+from colossalai.context.singleton_meta import SingletonMeta
+
+
+class PyTorchProcessGroupDict(metaclass=SingletonMeta):
+
+    def __init__(self):
+        # distributed settings
+        self.dict = {}
+
+    def get(self, rank: int, world_size: int, tp_degree: int, dp_degree: int, backend: str = 'nccl'):
+        key = (tp_degree, dp_degree, backend)
+        if key in self.dict:
+            return self.dict[key]
+        else:
+            self.logger = get_dist_logger('PyTorchProcessGroupDict')
+            _tp_rank_list = []
+            _dp_rank_list = []
+
+            for rank_id in range(world_size):
+                # rank_id and self._rank in the same tp group
+                if rank_id % tp_degree == rank % tp_degree:
+                    _dp_rank_list.append(rank_id)
+                if rank_id // tp_degree == rank // tp_degree:
+                    _tp_rank_list.append(rank_id)
+
+            _tp_process_group = torch.distributed.new_group(ranks=_tp_rank_list, backend=backend)
+            _dp_process_group = torch.distributed.new_group(ranks=_dp_rank_list, backend=backend)
+            self.logger.info(
+                f'rank {rank} initialize process group on {backend}, dp ranks: {_dp_rank_list} tp ranks: {_tp_rank_list}'
+            )
+            self.dict[key] = _tp_rank_list, _tp_process_group, _dp_rank_list, _dp_process_group
+            return _tp_rank_list, _tp_process_group, _dp_rank_list, _dp_process_group
+
+
+PYTORCHPGDICT_ = PyTorchProcessGroupDict()
 
 
 class ProcessGroup:
@@ -15,6 +50,7 @@ class ProcessGroup:
         dp_degree: Optional[int], data parallelism degree, default None means len(ranks)
     """
 
+    #TODO(haichen) fix me! ranks now must start from 0,1,2,3...
     def __init__(self,
                  rank: Optional[int] = None,
                  ranks: Optional[List[int]] = None,
@@ -50,23 +86,8 @@ class ProcessGroup:
             assert self._world_size % self._tp_degree == 0, f"TP degree {tp_degree} should be divisible by {self._world_size} when DP degree is None"
             self._dp_degree = self._world_size // tp_degree
 
-        self._tp_rank_list = []
-        self._dp_rank_list = []
-
-        for rank_id in range(self._world_size):
-            # rank_id and self._rank in the same tp group
-            if rank_id % self._tp_degree == self._rank % self._tp_degree:
-                self._dp_rank_list.append(rank_id)
-            if rank_id // self._tp_degree == self._rank // self._tp_degree:
-                self._tp_rank_list.append(rank_id)
-
-        self._tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list, backend='nccl')
-        self._dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list, backend='nccl')
-
-        self.logger = get_dist_logger('ProcessGroup')
-        self.logger.info(
-            f'{self._rank} NCCL initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
-
+        self._tp_rank_list, self._tp_process_group, self._dp_rank_list, self._dp_process_group = PYTORCHPGDICT_.get(
+            self._rank, self._world_size, self._tp_degree, self._dp_degree, 'nccl')
         self._has_cpu_groups = False
 
     def set_cpu_groups(self):
@@ -77,6 +98,9 @@ class ProcessGroup:
         self._cpu_tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list, backend='gloo')
         self._cpu_dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list, backend='gloo')
 
+        _, self._cpu_tp_process_group, _, self._cpu_dp_process_group = PYTORCHPGDICT_.get(
+            self._rank, self._world_size, self._tp_degree, self._dp_degree, 'gloo')
+
     @property
     def has_cpu_groups(self):
         return self._has_cpu_groups
diff --git a/tests/test_tensor/test_tensor.py b/tests/test_tensor/test_tensor.py
index b144b00d3..9ed267301 100644
--- a/tests/test_tensor/test_tensor.py
+++ b/tests/test_tensor/test_tensor.py
@@ -103,9 +103,9 @@ def run_dist_tests(rank, world_size, port):
     _run_view(world_size)
     _run_process_group(world_size)
     _run_tensor_indexing()
+    _run_operand()
     # TODO not passed
     # _run_wrapped_tensor_func()
-    _run_operand()
 
 
 @pytest.mark.dist
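
Usage note (not part of the diff): the minimal sketch below, based only on the get signature added above, illustrates how the PYTORCHPGDICT_ singleton caches PyTorch process groups per (tp_degree, dp_degree, backend) key, so repeated lookups with the same degrees reuse the existing groups instead of calling torch.distributed.new_group again. It assumes torch.distributed is already initialized (for example via torchrun or colossalai.launch); the degree values and the demo function name are illustrative, not from the repository.

# hedged sketch, assuming an initialized default process group
import torch.distributed as dist
from colossalai.tensor.process_group import PYTORCHPGDICT_

def demo_cached_groups():
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    tp_degree = 2                        # illustrative; assumes world_size % 2 == 0
    dp_degree = world_size // tp_degree

    # First call builds the rank lists and creates the NCCL groups.
    tp_ranks_a, tp_group_a, dp_ranks_a, dp_group_a = PYTORCHPGDICT_.get(
        rank, world_size, tp_degree, dp_degree, 'nccl')

    # Second call with the same (tp_degree, dp_degree, backend) key hits the
    # cache and returns the very same group objects.
    tp_ranks_b, tp_group_b, dp_ranks_b, dp_group_b = PYTORCHPGDICT_.get(
        rank, world_size, tp_degree, dp_degree, 'nccl')

    assert tp_group_a is tp_group_b
    assert dp_group_a is dp_group_b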