|
|
|
@ -18,7 +18,6 @@ class ProcessGroup:
|
|
|
|
|
def __init__(self,
|
|
|
|
|
rank: Optional[int] = None,
|
|
|
|
|
ranks: Optional[List[int]] = None,
|
|
|
|
|
backend: str = 'nccl',
|
|
|
|
|
tp_degree: Optional[int] = None,
|
|
|
|
|
dp_degree: Optional[int] = None) -> None:
|
|
|
|
|
assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized"
|
|
|
|
@ -32,7 +31,6 @@ class ProcessGroup:
|
|
|
|
|
else:
|
|
|
|
|
self._rank_list = ranks
|
|
|
|
|
|
|
|
|
|
self._backend = backend
|
|
|
|
|
self._world_size = len(self._rank_list)
|
|
|
|
|
|
|
|
|
|
if dp_degree is None and tp_degree is None:
|
|
|
|
@ -59,16 +57,26 @@ class ProcessGroup:
|
|
|
|
|
if rank_id // self._tp_degree == self._rank // self._tp_degree:
|
|
|
|
|
self._tp_rank_list.append(rank_id)
|
|
|
|
|
|
|
|
|
|
assert backend == 'nccl'
|
|
|
|
|
self._tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list)
|
|
|
|
|
self._dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list)
|
|
|
|
|
self._tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list, backend='nccl')
|
|
|
|
|
self._dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list, backend='nccl')
|
|
|
|
|
|
|
|
|
|
self.logger = get_dist_logger('ProcessGroup')
|
|
|
|
|
self.logger.info(f'{self._rank} initialize TP group on {self._tp_rank_list} DP group pn {self._dp_rank_list}')
|
|
|
|
|
self.logger.info(
|
|
|
|
|
f'{self._rank} NCCL initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
|
|
|
|
|
|
|
|
|
|
self._has_cpu_groups = False
|
|
|
|
|
|
|
|
|
|
def set_cpu_groups(self):
|
|
|
|
|
if self.has_cpu_groups:
|
|
|
|
|
return
|
|
|
|
|
self.logger.info(
|
|
|
|
|
f'{self._rank} Gloo initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
|
|
|
|
|
self._cpu_tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list, backend='gloo')
|
|
|
|
|
self._cpu_dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list, backend='gloo')
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def backend(self):
|
|
|
|
|
return self._backend
|
|
|
|
|
def has_cpu_groups(self):
|
|
|
|
|
return self._has_cpu_groups
|
|
|
|
|
|
|
|
|
|
def __eq__(self, obj: 'ProcessGroup') -> bool:
|
|
|
|
|
if not isinstance(obj, ProcessGroup):
|
|
|
|
@ -81,8 +89,6 @@ class ProcessGroup:
|
|
|
|
|
assert False
|
|
|
|
|
if self._dp_rank_list != obj._dp_rank_list:
|
|
|
|
|
assert False
|
|
|
|
|
if self._backend != obj._backend:
|
|
|
|
|
assert False
|
|
|
|
|
if self._tp_degree != obj._tp_degree:
|
|
|
|
|
return False
|
|
|
|
|
if self._dp_degree != obj._dp_degree:
|
|
|
|
@ -112,3 +118,9 @@ class ProcessGroup:
|
|
|
|
|
|
|
|
|
|
def tp_process_group(self):
|
|
|
|
|
return self._tp_process_group
|
|
|
|
|
|
|
|
|
|
def cpu_dp_process_group(self):
|
|
|
|
|
return self._cpu_dp_process_group
|
|
|
|
|
|
|
|
|
|
def cpu_tp_process_group(self):
|
|
|
|
|
return self._cpu_tp_process_group
|
|
|
|
|