[Tensor] add cpu group to ddp (#1200)

pull/1207/head
Jiarui Fang 2022-07-05 14:58:28 +08:00 committed by GitHub
parent f7878f465c
commit b5f25eb32a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 30 additions and 26 deletions

View File

@ -54,14 +54,11 @@ class ColoDDP(torch.nn.Module):
module (torch.nn.Module): Module to apply DDP.
process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses.
If it's None, the default data parallel group will be used. Defaults to None.
cpu_process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses for those parameters on CPU.
If it's None, the default CPU data parallel group will be used. Defaults to None.
"""
def __init__(self,
module: torch.nn.Module,
process_group: ColoProcessGroup,
cpu_process_group: Optional[dist.ProcessGroup] = None,
bucket_cap_mb: int = 25,
rebuild_bucket: bool = True) -> None:
assert not isinstance(module, ColoDDP)
@ -70,8 +67,9 @@ class ColoDDP(torch.nn.Module):
self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
assert process_group
self.process_group = process_group.dp_process_group()
self.dp_world_size = self.process_group.size()
self.process_group = process_group
self.dp_world_size = self.process_group.dp_world_size()
self.reducer = Reducer(bucket_cap_mb)
self.rebuild_bucket = rebuild_bucket
for p in module.parameters():
@ -112,7 +110,7 @@ class ColoDDP(torch.nn.Module):
self.comm_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.comm_stream):
self.reducer.all_reduce_async(grad,
group=self.process_group,
group=self.process_group.dp_process_group(),
callback_fn=partial(self._save_grad, p))
grad.record_stream(self.comm_stream)
else:
@ -121,8 +119,8 @@ class ColoDDP(torch.nn.Module):
else:
#TODO(jiaruifang) fixme
raise NotImplementedError
dist.all_reduce(grad, group=self.cpu_process_group)
self.process_group.set_cpu_groups()
dist.all_reduce(grad, group=self.process_group.cpu_dp_process_group())
return grad
@staticmethod

View File

@ -114,10 +114,8 @@ class DistSpecManager:
if world_size == 1:
return tensor
assert tensor.device.type == "cuda" and old_dist_spec.process_group.backend == "nccl", \
"Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll " \
f"collective function, however, we got {tensor.device.type} device and " \
f"{old_dist_spec.process_group.backend} backend"
assert tensor.device.type == "cuda", "Currently, only CUDA Tensors are supported for the requested AlltoAll " \
f"collective function, however, we got {tensor.device.type} device"
gather_dim = old_dist_spec.dims[0]
scatter_dim = dist_spec.dims[0]

View File

@ -18,7 +18,6 @@ class ProcessGroup:
def __init__(self,
rank: Optional[int] = None,
ranks: Optional[List[int]] = None,
backend: str = 'nccl',
tp_degree: Optional[int] = None,
dp_degree: Optional[int] = None) -> None:
assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized"
@ -32,7 +31,6 @@ class ProcessGroup:
else:
self._rank_list = ranks
self._backend = backend
self._world_size = len(self._rank_list)
if dp_degree is None and tp_degree is None:
@ -59,16 +57,26 @@ class ProcessGroup:
if rank_id // self._tp_degree == self._rank // self._tp_degree:
self._tp_rank_list.append(rank_id)
assert backend == 'nccl'
self._tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list)
self._dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list)
self._tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list, backend='nccl')
self._dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list, backend='nccl')
self.logger = get_dist_logger('ProcessGroup')
self.logger.info(f'{self._rank} initialize TP group on {self._tp_rank_list} DP group pn {self._dp_rank_list}')
self.logger.info(
f'{self._rank} NCCL initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
self._has_cpu_groups = False
def set_cpu_groups(self):
if self.has_cpu_groups:
return
self.logger.info(
f'{self._rank} Gloo initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
self._cpu_tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list, backend='gloo')
self._cpu_dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list, backend='gloo')
@property
def backend(self):
return self._backend
def has_cpu_groups(self):
return self._has_cpu_groups
def __eq__(self, obj: 'ProcessGroup') -> bool:
if not isinstance(obj, ProcessGroup):
@ -81,8 +89,6 @@ class ProcessGroup:
assert False
if self._dp_rank_list != obj._dp_rank_list:
assert False
if self._backend != obj._backend:
assert False
if self._tp_degree != obj._tp_degree:
return False
if self._dp_degree != obj._dp_degree:
@ -112,3 +118,9 @@ class ProcessGroup:
def tp_process_group(self):
return self._tp_process_group
def cpu_dp_process_group(self):
return self._cpu_dp_process_group
def cpu_tp_process_group(self):
return self._cpu_tp_process_group

View File

@ -17,13 +17,9 @@ class TensorSpec(object):
self.compute_spec = compute_spec
self.dist_spec = dist_spec
# TODO(jiaruifang) actually need tp process group
def get_process_group(self):
return self.dist_spec.process_group
def get_process_group_size(self):
return dist.get_world_size(self.dist_spec.process_group.tp_process_group())
def get_placement(self):
return self.dist_spec.placement