# ColossalAI/colossalai/tensor/process_group.py

import torch
from typing import List, Optional
from colossalai.logging import get_dist_logger
from colossalai.context.singleton_meta import SingletonMeta
class PyTorchProcessGroupDict(metaclass=SingletonMeta):
    """Cache of PyTorch process groups keyed by ``(backend, ranks)``.

    ``torch.distributed.new_group`` is only called the first time a given
    backend / rank-list combination is requested; later requests return the
    stored group object.
    """

    def __init__(self):
        # maps (backend, tuple(rank_list)) -> process group returned by torch.distributed.new_group
        self.dict = {}

    def get(self, rank_list: List[int], backend: str = 'nccl'):
        """Reuse Pytorch ProcessGroup when such a group is initialized, otherwise create it.

        Args:
            rank_list: ranks that make up the group. The order of the list is part of
                the cache key, so callers should pass a consistently ordered list.
            backend: the torch.distributed backend name, e.g. ``'nccl'`` or ``'gloo'``.

        Returns:
            The process group object created by ``torch.distributed.new_group``.
        """
        # Lists are unhashable, so the rank list is converted to a tuple for the key.
        pg_key = (backend, tuple(rank_list))
        if pg_key not in self.dict:
            self.logger = get_dist_logger('ProcessGroup')
            # Report the backend actually being used (previously always said "NCCL").
            self.logger.info(f'{backend.upper()} initialize ProcessGroup on {rank_list}', ranks=[0])
            self.dict[pg_key] = torch.distributed.new_group(ranks=rank_list, backend=backend)
        return self.dict[pg_key]


# Module-level instance shared by every ProcessGroup in this module.
PYTORCHPGDICT_ = PyTorchProcessGroupDict()
class ProcessGroup:
    """ProcessGroup
    Process Group indicates how processes are organized in groups for parallel execution using
    Tensor Parallelism and Data Parallelism.

    The sorted rank list is laid out as a ``dp_degree x tp_degree`` grid in row-major order:
    each TP group is a run of ``tp_degree`` consecutive ranks of the list, and each DP group
    takes the rank at the same offset from every TP group.

    NOTE, the ProcessGroup must be used after the default torch.distributed process group is
    initialized (e.g. via `torch.distributed.init_process_group()`). If it is not, the object
    is created as an uninitialized placeholder (``is_init`` is False) and no layout is built.

    Args:
        rank: the global rank of the current process. default None means ``torch.distributed.get_rank()``.
        ranks: List[int], a list of rank ids belonging to this process group. default None means
            ``range(torch.distributed.get_world_size())``. A sorted copy is stored; the caller's
            list is not modified.
        tp_degree: Optional[int], tensor parallelism degree. How many processes are inside a tp process group.
        dp_degree: Optional[int], data parallelism degree. How many processes are inside a dp process group.
            If both degrees are None, dp_degree = len(ranks) and tp_degree = 1. If only one is given,
            the other is len(ranks) divided by it. If both are given, their product must equal len(ranks).
    """

    def __init__(self,
                 rank: Optional[int] = None,
                 ranks: Optional[List[int]] = None,
                 tp_degree: Optional[int] = None,
                 dp_degree: Optional[int] = None) -> None:
        if not torch.distributed.is_initialized():
            # No default process group yet: keep a placeholder without any layout attributes.
            self.is_init = False
            return

        self._rank = torch.distributed.get_rank() if rank is None else rank

        if ranks is None:
            self._rank_list = list(range(torch.distributed.get_world_size()))
        else:
            # sorted() returns a new list, so the caller's `ranks` is left untouched.
            self._rank_list = sorted(ranks)
        self._world_size = len(self._rank_list)

        if dp_degree is None and tp_degree is None:
            self._dp_degree = self._world_size
            self._tp_degree = 1
        elif dp_degree and not tp_degree:
            self._dp_degree = dp_degree
            assert self._world_size % self._dp_degree == 0, \
                f"world size {self._world_size} should be divisible by DP degree {dp_degree} when TP degree is None"
            self._tp_degree = self._world_size // dp_degree
        elif not dp_degree and tp_degree:
            self._tp_degree = tp_degree
            assert self._world_size % self._tp_degree == 0, \
                f"world size {self._world_size} should be divisible by TP degree {tp_degree} when DP degree is None"
            self._dp_degree = self._world_size // tp_degree
        else:
            self._dp_degree = dp_degree
            self._tp_degree = tp_degree
            assert self._dp_degree * self._tp_degree == self._world_size, \
                f"the world size {self._world_size} should equals to the product of DP degree {self._dp_degree} " \
                f"and TP degree {self._tp_degree}"

        # All TP groups and then all DP groups are created (not only the ones containing this
        # rank) in a fixed order: torch.distributed.new_group expects every process to enter it.
        self._tp_rank_list = None
        self._dp_rank_list = None
        for tp_list in self._tp_rank_lists():
            PYTORCHPGDICT_.get(tp_list, 'nccl')
            if self._rank in tp_list:
                self._tp_rank_list = tp_list
        for dp_list in self._dp_rank_lists():
            PYTORCHPGDICT_.get(dp_list, 'nccl')
            if self._rank in dp_list:
                self._dp_rank_list = dp_list

        self._has_cpu_groups = False
        self.is_init = True

    def _tp_rank_lists(self) -> List[List[int]]:
        """Rank lists of all TP groups; group i is the i-th run of ``tp_degree`` consecutive ranks."""
        tp = self._tp_degree
        return [self._rank_list[i * tp:(i + 1) * tp] for i in range(self._dp_degree)]

    def _dp_rank_lists(self) -> List[List[int]]:
        """Rank lists of all DP groups; group j is every ``tp_degree``-th rank starting at offset j."""
        return [self._rank_list[j::self._tp_degree] for j in range(self._tp_degree)]

    def set_cpu_groups(self):
        """set_cpu_groups
        Initialize Pytorch process groups for cpu communications ('gloo' backend), using the
        same TP/DP layout as the 'nccl' groups. Calling it again is a no-op.
        """
        if self.has_cpu_groups:
            return
        for tp_list in self._tp_rank_lists():
            PYTORCHPGDICT_.get(tp_list, 'gloo')
        for dp_list in self._dp_rank_lists():
            PYTORCHPGDICT_.get(dp_list, 'gloo')
        self._has_cpu_groups = True

    @property
    def has_cpu_groups(self) -> bool:
        """has_cpu_groups
        If cpu groups have been initialized.
        Returns:
            bool: cpu process groups have been initialized or not.
        """
        return self._has_cpu_groups

    def __repr__(self):
        if self.is_init:
            return "ProcessGroup:\n\tRank: {}, World size: {}, DP degree: {}, TP degree: {}\n\tRanks in group: {}".\
                format(self._rank, self._world_size, self._dp_degree, self._tp_degree, self._rank_list)
        else:
            return "ProcessGroup not initialized"

    def __eq__(self, obj: 'ProcessGroup') -> bool:
        if not isinstance(obj, ProcessGroup):
            return False
        if not (self.is_init and obj.is_init):
            # Placeholders carry no layout attributes; two placeholders compare equal,
            # a placeholder never equals an initialized group.
            return self.is_init == obj.is_init
        return (self._rank == obj._rank and self._rank_list == obj._rank_list
                and self._tp_rank_list == obj._tp_rank_list and self._dp_rank_list == obj._dp_rank_list
                and self._tp_degree == obj._tp_degree and self._dp_degree == obj._dp_degree)

    def rank(self) -> int:
        """rank
        The current rank in the global process group.
        Returns:
            int: the rank number
        """
        return self._rank

    def ranks_in_group(self) -> List[int]:
        """ranks_in_group
        The sorted list of global rank numbers that belong to this process group.
        Returns:
            List[int]: a list of rank number.
        """
        return self._rank_list

    def world_size(self) -> int:
        """world_size
        The number of ranks in this process group.
        Returns:
            int: world size
        """
        return self._world_size

    def tp_rank_list(self) -> List[int]:
        """tp_rank_list
        the rank list in the TP process group containing the current rank.
        Returns:
            List[int]: the list of rank number.
        """
        return self._tp_rank_list

    def dp_rank_list(self) -> List[int]:
        """dp_rank_list
        the rank list in the DP process group containing the current rank.
        Returns:
            List[int]: the list of rank number.
        """
        return self._dp_rank_list

    def tp_local_rank(self) -> int:
        """tp_local_rank
        The local rank number in the current TP process group, i.e. the position of the
        current rank inside its TP rank list.
        Returns:
            int: tp rank number.
        """
        # Position lookup rather than `rank % tp_degree`, which is only guaranteed to be
        # correct when the group's ranks are exactly 0..world_size-1.
        return self._tp_rank_list.index(self._rank)

    def dp_local_rank(self) -> int:
        """dp_local_rank
        The local rank number in the current DP process group, i.e. the position of the
        current rank inside its DP rank list.
        Returns:
            int: dp rank number.
        """
        # Position lookup rather than `rank // tp_degree` (same reason as tp_local_rank).
        return self._dp_rank_list.index(self._rank)

    def dp_world_size(self) -> int:
        """dp_world_size
        The world size of the current DP process group.
        Returns:
            int: dp world size
        """
        # Every DP group has exactly dp_degree ranks.
        return self._dp_degree

    def tp_world_size(self) -> int:
        """tp_world_size
        The world size of the current TP process group.
        Returns:
            int: tp world size
        """
        # Every TP group has exactly tp_degree ranks.
        return self._tp_degree

    def dp_process_group(self):
        """dp_process_group
        the pytorch DP process group containing the current rank.
        Returns:
            `torch._C._distributed_c10d.ProcessGroup`: the pytorch DP process group.
        """
        return PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')

    def tp_process_group(self):
        """tp_process_group
        the pytorch TP process group containing the current rank.
        Returns:
            `torch._C._distributed_c10d.ProcessGroup`: the pytorch TP process group.
        """
        return PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')

    def cpu_dp_process_group(self):
        """cpu_dp_process_group
        the pytorch CPU DP process group containing the current rank.
        assert failed if cpu process group is not initialized.
        Returns:
            `torch._C._distributed_c10d.ProcessGroup`: the pytorch DP process group.
        """
        assert self._has_cpu_groups, "cpu process groups are not initialized, call set_cpu_groups() first"
        return PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')

    def cpu_tp_process_group(self):
        """cpu_tp_process_group
        the pytorch CPU TP process group containing the current rank.
        assert failed if cpu process group is not initialized.
        Returns:
            `torch._C._distributed_c10d.ProcessGroup`: the pytorch TP process group.
        """
        assert self._has_cpu_groups, "cpu process groups are not initialized, call set_cpu_groups() first"
        return PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')

    def get_ranks_in_dp(self) -> List[int]:
        """get_ranks_in_dp
        ranks in current dp process group.
        Returns:
            List[int]: a list of rank number.
        """
        return self._dp_rank_list

    def get_ranks_in_tp(self) -> List[int]:
        """get_ranks_in_tp
        ranks in current tp process group.
        Returns:
            List[int]: a list of rank number.
        """
        return self._tp_rank_list