[tensor] sharded global process group (#1219)

pull/1223/head
Jiarui Fang 2 years ago committed by GitHub
parent db1bef9032
commit 15d988f954
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,6 +1,41 @@
import torch import torch
from typing import List, Optional from typing import List, Optional
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.context.singleton_meta import SingletonMeta
class PyTorchProcessGroupDict(metaclass=SingletonMeta):
    """Cache of tensor-parallel (TP) and data-parallel (DP) process groups.

    Groups are created with ``torch.distributed.new_group`` the first time a
    ``(tp_degree, dp_degree, backend)`` combination is requested and are
    reused on later requests, so repeated lookups do not create duplicate
    communicators.

    NOTE(review): ``SingletonMeta`` is defined elsewhere; presumably it makes
    every instantiation return the same object -- confirm.
    """

    def __init__(self):
        # distributed settings
        # (tp_degree, dp_degree, backend) ->
        #     (tp_rank_list, tp_process_group, dp_rank_list, dp_process_group)
        self.dict = {}

    def get(self, rank: int, world_size: int, tp_degree: int, dp_degree: int, backend: str = 'nccl'):
        """Return the TP and DP rank lists and process groups for ``rank``.

        Ranks are laid out in consecutive blocks of ``tp_degree``: ranks in
        the same block form a TP group, and ranks with the same offset inside
        their blocks form a DP group.

        Args:
            rank: global rank of the calling process; expected to satisfy
                ``0 <= rank < world_size``.
            world_size: total number of ranks to partition.
            tp_degree: size of each TP group.
            dp_degree: DP degree; only used as part of the cache key.
            backend: backend passed to ``torch.distributed.new_group``.

        Returns:
            A tuple ``(tp_rank_list, tp_process_group, dp_rank_list,
            dp_process_group)``.
        """
        key = (tp_degree, dp_degree, backend)
        if key in self.dict:
            return self.dict[key]

        self.logger = get_dist_logger('PyTorchProcessGroupDict')

        _tp_rank_list, _tp_process_group = [], None
        _dp_rank_list, _dp_process_group = [], None

        # torch.distributed.new_group must be entered by every process for
        # every group, in the same order, even by processes that are not
        # members. So all TP groups and then all DP groups are created here
        # on every rank, and only the groups containing `rank` are kept.
        for block_start in range(0, world_size, tp_degree):
            tp_ranks = list(range(block_start, min(block_start + tp_degree, world_size)))
            tp_group = torch.distributed.new_group(ranks=tp_ranks, backend=backend)
            # same block index (rank_id // tp_degree) -> same TP group
            if rank in tp_ranks:
                _tp_rank_list, _tp_process_group = tp_ranks, tp_group

        for offset in range(min(tp_degree, world_size)):
            dp_ranks = list(range(offset, world_size, tp_degree))
            dp_group = torch.distributed.new_group(ranks=dp_ranks, backend=backend)
            # same offset inside the block (rank_id % tp_degree) -> same DP group
            if rank in dp_ranks:
                _dp_rank_list, _dp_process_group = dp_ranks, dp_group

        self.logger.info(
            f'rank {rank} initialize process group on {backend}, dp ranks: {_dp_rank_list} tp ranks: {_tp_rank_list}'
        )

        result = (_tp_rank_list, _tp_process_group, _dp_rank_list, _dp_process_group)
        self.dict[key] = result
        return result


# Module-level instance used by ProcessGroup to look up its groups.
PYTORCHPGDICT_ = PyTorchProcessGroupDict()
class ProcessGroup: class ProcessGroup:
@ -15,6 +50,7 @@ class ProcessGroup:
dp_degree: Optional[int], data parallelism degree, default None means len(ranks) dp_degree: Optional[int], data parallelism degree, default None means len(ranks)
""" """
# TODO(haichen): fix me! For now, ranks must be consecutive and start from 0 (0, 1, 2, 3, ...).
def __init__(self, def __init__(self,
rank: Optional[int] = None, rank: Optional[int] = None,
ranks: Optional[List[int]] = None, ranks: Optional[List[int]] = None,
@ -50,23 +86,8 @@ class ProcessGroup:
assert self._world_size % self._tp_degree == 0, f"TP degree {tp_degree} should be divisible by {self._world_size} when DP degree is None" assert self._world_size % self._tp_degree == 0, f"TP degree {tp_degree} should be divisible by {self._world_size} when DP degree is None"
self._dp_degree = self._world_size // tp_degree self._dp_degree = self._world_size // tp_degree
self._tp_rank_list = [] self._tp_rank_list, self._tp_process_group, self._dp_rank_list, self._dp_process_group = PYTORCHPGDICT_.get(
self._dp_rank_list = [] self._rank, self._world_size, self._tp_degree, self._dp_degree, 'nccl')
for rank_id in range(self._world_size):
# rank_id and self._rank in the same tp group
if rank_id % self._tp_degree == self._rank % self._tp_degree:
self._dp_rank_list.append(rank_id)
if rank_id // self._tp_degree == self._rank // self._tp_degree:
self._tp_rank_list.append(rank_id)
self._tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list, backend='nccl')
self._dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list, backend='nccl')
self.logger = get_dist_logger('ProcessGroup')
self.logger.info(
f'{self._rank} NCCL initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
self._has_cpu_groups = False self._has_cpu_groups = False
def set_cpu_groups(self): def set_cpu_groups(self):
@ -77,6 +98,9 @@ class ProcessGroup:
self._cpu_tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list, backend='gloo') self._cpu_tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list, backend='gloo')
self._cpu_dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list, backend='gloo') self._cpu_dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list, backend='gloo')
_, self._cpu_tp_process_group, _, self._cpu_dp_process_group = PYTORCHPGDICT_.get(
self._rank, self._world_size, self._tp_degree, self._dp_degree, 'gloo')
@property @property
def has_cpu_groups(self): def has_cpu_groups(self):
return self._has_cpu_groups return self._has_cpu_groups

@ -103,9 +103,9 @@ def run_dist_tests(rank, world_size, port):
_run_view(world_size) _run_view(world_size)
_run_process_group(world_size) _run_process_group(world_size)
_run_tensor_indexing() _run_tensor_indexing()
_run_operand()
# TODO not passed # TODO not passed
# _run_wrapped_tensor_func() # _run_wrapped_tensor_func()
_run_operand()
@pytest.mark.dist @pytest.mark.dist

Loading…
Cancel
Save