[Tensor] add cpu group to ddp (#1200)

Jiarui Fang 2022-07-05 14:58:28 +08:00 committed by GitHub
parent f7878f465c
commit b5f25eb32a
4 changed files with 30 additions and 26 deletions

@@ -54,14 +54,11 @@ class ColoDDP(torch.nn.Module):
         module (torch.nn.Module): Module to apply DDP.
         process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses.
             If it's None, the default data parallel group will be used. Defaults to None.
-        cpu_process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses for those parameters on CPU.
-            If it's None, the default CPU data parallel group will be used. Defaults to None.
     """

     def __init__(self,
                  module: torch.nn.Module,
                  process_group: ColoProcessGroup,
-                 cpu_process_group: Optional[dist.ProcessGroup] = None,
                  bucket_cap_mb: int = 25,
                  rebuild_bucket: bool = True) -> None:
         assert not isinstance(module, ColoDDP)
@@ -70,8 +67,9 @@ class ColoDDP(torch.nn.Module):
         self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
         assert process_group
-        self.process_group = process_group.dp_process_group()
-        self.dp_world_size = self.process_group.size()
+        self.process_group = process_group
+        self.dp_world_size = self.process_group.dp_world_size()
         self.reducer = Reducer(bucket_cap_mb)
         self.rebuild_bucket = rebuild_bucket
         for p in module.parameters():
@@ -112,7 +110,7 @@ class ColoDDP(torch.nn.Module):
             self.comm_stream.wait_stream(torch.cuda.current_stream())
             with torch.cuda.stream(self.comm_stream):
                 self.reducer.all_reduce_async(grad,
-                                              group=self.process_group,
+                                              group=self.process_group.dp_process_group(),
                                               callback_fn=partial(self._save_grad, p))
             grad.record_stream(self.comm_stream)
         else:
@@ -121,8 +119,8 @@ class ColoDDP(torch.nn.Module):
         else:
             #TODO(jiaruifang) fixme
-            raise NotImplementedError
-            dist.all_reduce(grad, group=self.cpu_process_group)
+            self.process_group.set_cpu_groups()
+            dist.all_reduce(grad, group=self.process_group.cpu_dp_process_group())
         return grad

     @staticmethod
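
After this change, ColoDDP no longer takes a separate cpu_process_group; it keeps the whole ColoProcessGroup and resolves the communicator per device when gradients are reduced. A minimal usage sketch, assuming torch.distributed is already initialized and a pure data-parallel layout (the model and degrees below are illustrative, not from this commit):

import torch

pg = ProcessGroup(tp_degree=1, dp_degree=4)    # note: no `backend` argument any more
model = ColoDDP(torch.nn.Linear(8, 8).cuda(), process_group=pg)

# During backward, gradients are reduced per device:
#   CUDA grads -> pg.dp_process_group()                           (NCCL, built eagerly)
#   CPU  grads -> pg.set_cpu_groups(); pg.cpu_dp_process_group()  (Gloo, built lazily)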

@@ -114,10 +114,8 @@ class DistSpecManager:
         if world_size == 1:
             return tensor

-        assert tensor.device.type == "cuda" and old_dist_spec.process_group.backend == "nccl", \
-            "Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll " \
-            f"collective function, however, we got {tensor.device.type} device and " \
-            f"{old_dist_spec.process_group.backend} backend"
+        assert tensor.device.type == "cuda", "Currently, only CUDA Tensors are supported for the requested AlltoAll " \
+            f"collective function, however, we got {tensor.device.type} device"

         gather_dim = old_dist_spec.dims[0]
         scatter_dim = dist_spec.dims[0]
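
The backend half of the old assertion is dropped because ProcessGroup no longer carries a user-chosen backend; the NCCL/Gloo split now lives inside ProcessGroup itself (see the hunks below), so the all-to-all path can only gate on the tensor's device. A hypothetical standalone version of the surviving guard (the helper name is ours, not the repo's):

def check_alltoall_device(tensor):
    # Mirrors the new assertion: NCCL all-to-all still requires CUDA
    # tensors, but the backend itself is no longer inspected here.
    assert tensor.device.type == "cuda", \
        "Currently, only CUDA Tensors are supported for the requested AlltoAll " \
        f"collective function, however, we got {tensor.device.type} device"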

@@ -18,7 +18,6 @@ class ProcessGroup:
     def __init__(self,
                  rank: Optional[int] = None,
                  ranks: Optional[List[int]] = None,
-                 backend: str = 'nccl',
                  tp_degree: Optional[int] = None,
                  dp_degree: Optional[int] = None) -> None:
         assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized"
@@ -32,7 +31,6 @@ class ProcessGroup:
         else:
             self._rank_list = ranks

-        self._backend = backend
         self._world_size = len(self._rank_list)

         if dp_degree is None and tp_degree is None:
@@ -59,16 +57,26 @@ class ProcessGroup:
             if rank_id // self._tp_degree == self._rank // self._tp_degree:
                 self._tp_rank_list.append(rank_id)

-        assert backend == 'nccl'
-        self._tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list)
-        self._dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list)
+        self._tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list, backend='nccl')
+        self._dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list, backend='nccl')

         self.logger = get_dist_logger('ProcessGroup')
-        self.logger.info(f'{self._rank} initialize TP group on {self._tp_rank_list} DP group pn {self._dp_rank_list}')
+        self.logger.info(
+            f'{self._rank} NCCL initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
+        self._has_cpu_groups = False
+
+    def set_cpu_groups(self):
+        if self.has_cpu_groups:
+            return
+        self.logger.info(
+            f'{self._rank} Gloo initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
+        self._cpu_tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list, backend='gloo')
+        self._cpu_dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list, backend='gloo')

     @property
-    def backend(self):
-        return self._backend
+    def has_cpu_groups(self):
+        return self._has_cpu_groups
@@ -81,8 +89,6 @@ class ProcessGroup:
             assert False
         if self._dp_rank_list != obj._dp_rank_list:
             assert False
-        if self._backend != obj._backend:
-            assert False
         if self._tp_degree != obj._tp_degree:
             return False
         if self._dp_degree != obj._dp_degree:
@@ -112,3 +118,9 @@ class ProcessGroup:
     def tp_process_group(self):
         return self._tp_process_group
+
+    def cpu_dp_process_group(self):
+        return self._cpu_dp_process_group
+
+    def cpu_tp_process_group(self):
+        return self._cpu_tp_process_group
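
The Gloo groups are built on demand instead of in __init__, so ranks that never reduce CPU gradients never create them. One caveat: as shown above, set_cpu_groups() guards on has_cpu_groups but never sets _has_cpu_groups to True, so repeated calls would keep creating fresh Gloo groups; a complete implementation would presumably flip the flag after the groups are created. A usage sketch under that assumption:

import torch

pg = ProcessGroup(tp_degree=2, dp_degree=2)    # NCCL TP/DP groups built here

pg.set_cpu_groups()                            # Gloo groups built lazily
t = torch.ones(4)                              # CPU tensor
torch.distributed.all_reduce(t, group=pg.cpu_dp_process_group())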

@@ -17,13 +17,9 @@ class TensorSpec(object):
         self.compute_spec = compute_spec
         self.dist_spec = dist_spec

-    # TODO(jiaruifang) actually need tp process group
     def get_process_group(self):
         return self.dist_spec.process_group

-    def get_process_group_size(self):
-        return dist.get_world_size(self.dist_spec.process_group.tp_process_group())
-
     def get_placement(self):
         return self.dist_spec.placement
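
With get_process_group_size() gone (together with its TODO), TensorSpec only hands back the ProcessGroup itself. A caller that still needs the tensor-parallel world size would presumably compose the remaining accessors, e.g.:

# Hypothetical replacement for the deleted helper, given a TensorSpec `spec`:
world_size = torch.distributed.get_world_size(spec.get_process_group().tp_process_group())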