mirror of https://github.com/hpcaitech/ColossalAI
[tensor] improve robustness of class 'ProcessGroup' (#1223)
parent
15d988f954
commit
280a81243d
|
@ -10,29 +10,17 @@ class PyTorchProcessGroupDict(metaclass=SingletonMeta):
|
||||||
# distributed settings
|
# distributed settings
|
||||||
self.dict = {}
|
self.dict = {}
|
||||||
|
|
||||||
def get(self, rank: int, world_size: int, tp_degree: int, dp_degree: int, backend: str = 'nccl'):
|
def get(self, rank_list: List[int], backend: str = 'nccl'):
|
||||||
key = (tp_degree, dp_degree, backend)
|
"""Reuse Pytorch ProcessGroup when such a group is initialized
|
||||||
if key in self.dict:
|
"""
|
||||||
return self.dict[key]
|
rank_tuple = tuple(rank_list)
|
||||||
else:
|
# we need to convert the passed list to a tuple
|
||||||
self.logger = get_dist_logger('PyTorchProcessGroupDict')
|
# since List is unhashable
|
||||||
_tp_rank_list = []
|
pg_key = (backend, rank_tuple)
|
||||||
_dp_rank_list = []
|
|
||||||
|
|
||||||
for rank_id in range(world_size):
|
if pg_key not in self.dict:
|
||||||
# rank_id and self._rank in the same tp group
|
self.dict[pg_key] = torch.distributed.new_group(ranks=rank_list, backend=backend)
|
||||||
if rank_id % tp_degree == rank % tp_degree:
|
return self.dict[pg_key]
|
||||||
_dp_rank_list.append(rank_id)
|
|
||||||
if rank_id // tp_degree == rank // tp_degree:
|
|
||||||
_tp_rank_list.append(rank_id)
|
|
||||||
|
|
||||||
_tp_process_group = torch.distributed.new_group(ranks=_tp_rank_list, backend=backend)
|
|
||||||
_dp_process_group = torch.distributed.new_group(ranks=_dp_rank_list, backend=backend)
|
|
||||||
self.logger.info(
|
|
||||||
f'rank {rank} initialize process group on {backend}, dp ranks: {_dp_rank_list} tp ranks: {_tp_rank_list}'
|
|
||||||
)
|
|
||||||
self.dict[key] = _tp_rank_list, _tp_process_group, _dp_rank_list, _dp_process_group
|
|
||||||
return _tp_rank_list, _tp_process_group, _dp_rank_list, _dp_process_group
|
|
||||||
|
|
||||||
|
|
||||||
PYTORCHPGDICT_ = PyTorchProcessGroupDict()
|
PYTORCHPGDICT_ = PyTorchProcessGroupDict()
|
||||||
|
@ -50,7 +38,6 @@ class ProcessGroup:
|
||||||
dp_degree: Optional[int], data parallelism degree, default None means len(ranks)
|
dp_degree: Optional[int], data parallelism degree, default None means len(ranks)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
#TODO(haichen) fix me! ranks now must start from 0,1,2,3...
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
rank: Optional[int] = None,
|
rank: Optional[int] = None,
|
||||||
ranks: Optional[List[int]] = None,
|
ranks: Optional[List[int]] = None,
|
||||||
|
@ -69,37 +56,57 @@ class ProcessGroup:
|
||||||
self._rank_list = list(range(torch.distributed.get_world_size()))
|
self._rank_list = list(range(torch.distributed.get_world_size()))
|
||||||
else:
|
else:
|
||||||
self._rank_list = ranks
|
self._rank_list = ranks
|
||||||
|
self._rank_list.sort() # ensure that the list is in order
|
||||||
|
|
||||||
|
self._rank_idx = self._rank_list.index(self._rank)
|
||||||
self._world_size = len(self._rank_list)
|
self._world_size = len(self._rank_list)
|
||||||
|
|
||||||
if dp_degree is None and tp_degree is None:
|
if dp_degree is None and tp_degree is None:
|
||||||
self._dp_degree = self._world_size
|
self._dp_degree = self._world_size
|
||||||
self._tp_degree = 1
|
self._tp_degree = 1
|
||||||
|
elif dp_degree and not tp_degree:
|
||||||
if dp_degree and not tp_degree:
|
|
||||||
self._dp_degree = dp_degree
|
self._dp_degree = dp_degree
|
||||||
assert self._world_size % self._dp_degree == 0, f"DP degree {dp_degree} should be divisible by {self._world_size} hen DP degree is None"
|
assert self._world_size % self._dp_degree == 0, f"DP degree {dp_degree} should be divisible by {self._world_size} hen DP degree is None"
|
||||||
self._tp_degree = self._world_size // dp_degree
|
self._tp_degree = self._world_size // dp_degree
|
||||||
|
elif not dp_degree and tp_degree:
|
||||||
if not dp_degree and tp_degree:
|
|
||||||
self._tp_degree = tp_degree
|
self._tp_degree = tp_degree
|
||||||
assert self._world_size % self._tp_degree == 0, f"TP degree {tp_degree} should be divisible by {self._world_size} when DP degree is None"
|
assert self._world_size % self._tp_degree == 0, f"TP degree {tp_degree} should be divisible by {self._world_size} when DP degree is None"
|
||||||
self._dp_degree = self._world_size // tp_degree
|
self._dp_degree = self._world_size // tp_degree
|
||||||
|
else:
|
||||||
|
self._dp_degree = dp_degree
|
||||||
|
self._tp_degree = tp_degree
|
||||||
|
assert self._dp_degree * self._tp_degree == self._world_size, \
|
||||||
|
f"the world size {self._world_size} should equals to the product of DP degree {self._dp_degree}" \
|
||||||
|
f"and TP degree {self._tp_degree}"
|
||||||
|
|
||||||
|
self._tp_rank_list = []
|
||||||
|
self._dp_rank_list = []
|
||||||
|
|
||||||
|
for idx, rank_id in enumerate(self._rank_list):
|
||||||
|
# idx and self._rank_idx in the same tp group
|
||||||
|
if idx % self._tp_degree == self._rank_idx % self._tp_degree:
|
||||||
|
self._dp_rank_list.append(rank_id)
|
||||||
|
if idx // self._tp_degree == self._rank_idx // self._tp_degree:
|
||||||
|
self._tp_rank_list.append(rank_id)
|
||||||
|
|
||||||
|
self._tp_process_group = PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
|
||||||
|
self._dp_process_group = PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')
|
||||||
|
|
||||||
|
self.logger = get_dist_logger('ProcessGroup')
|
||||||
|
self.logger.info(
|
||||||
|
f'{self._rank} NCCL initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
|
||||||
|
|
||||||
self._tp_rank_list, self._tp_process_group, self._dp_rank_list, self._dp_process_group = PYTORCHPGDICT_.get(
|
|
||||||
self._rank, self._world_size, self._tp_degree, self._dp_degree, 'nccl')
|
|
||||||
self._has_cpu_groups = False
|
self._has_cpu_groups = False
|
||||||
|
self._cpu_dp_process_group = None
|
||||||
|
self._cpu_tp_process_group = None
|
||||||
|
|
||||||
def set_cpu_groups(self):
|
def set_cpu_groups(self):
|
||||||
if self.has_cpu_groups:
|
if self.has_cpu_groups:
|
||||||
return
|
return
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
f'{self._rank} Gloo initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
|
f'{self._rank} Gloo initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
|
||||||
self._cpu_tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list, backend='gloo')
|
self._cpu_tp_process_group = PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
|
||||||
self._cpu_dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list, backend='gloo')
|
self._cpu_dp_process_group = PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
|
||||||
|
|
||||||
_, self._cpu_tp_process_group, _, self._cpu_dp_process_group = PYTORCHPGDICT_.get(
|
|
||||||
self._rank, self._world_size, self._tp_degree, self._dp_degree, 'gloo')
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def has_cpu_groups(self):
|
def has_cpu_groups(self):
|
||||||
|
|
Loading…
Reference in New Issue