mirror of https://github.com/hpcaitech/ColossalAI

[Tensor] add cpu group to ddp (#1200)
parent f7878f465c
commit b5f25eb32a
@@ -54,14 +54,11 @@ class ColoDDP(torch.nn.Module):
         module (torch.nn.Module): Module to apply DDP.
         process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses.
             If it's None, the default data parallel group will be used. Defaults to None.
-        cpu_process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses for those parameters on CPU.
-            If it's None, the default CPU data parallel group will be used. Defaults to None.
     """

     def __init__(self,
                  module: torch.nn.Module,
                  process_group: ColoProcessGroup,
-                 cpu_process_group: Optional[dist.ProcessGroup] = None,
                  bucket_cap_mb: int = 25,
                  rebuild_bucket: bool = True) -> None:
         assert not isinstance(module, ColoDDP)
@@ -70,8 +67,9 @@ class ColoDDP(torch.nn.Module):
         self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
+        assert process_group

-        self.process_group = process_group.dp_process_group()
-        self.dp_world_size = self.process_group.size()
+        self.process_group = process_group
+        self.dp_world_size = self.process_group.dp_world_size()

         self.reducer = Reducer(bucket_cap_mb)
         self.rebuild_bucket = rebuild_bucket
         for p in module.parameters():
@@ -112,7 +110,7 @@ class ColoDDP(torch.nn.Module):
                 self.comm_stream.wait_stream(torch.cuda.current_stream())
                 with torch.cuda.stream(self.comm_stream):
                     self.reducer.all_reduce_async(grad,
-                                                  group=self.process_group,
+                                                  group=self.process_group.dp_process_group(),
                                                   callback_fn=partial(self._save_grad, p))
                 grad.record_stream(self.comm_stream)
             else:
@@ -121,8 +119,8 @@ class ColoDDP(torch.nn.Module):

         else:
             #TODO(jiaruifang) fixme
-            raise NotImplementedError
-            dist.all_reduce(grad, group=self.cpu_process_group)
+            self.process_group.set_cpu_groups()
+            dist.all_reduce(grad, group=self.process_group.cpu_dp_process_group())
             return grad

     @staticmethod
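Taken together, the ColoDDP changes keep the whole ColoProcessGroup on the module and pick a communication group per gradient: CUDA gradients go through the asynchronous reducer on the NCCL data-parallel group, while CPU gradients are all-reduced over a lazily created Gloo group. A minimal sketch of that dispatch, assuming torch.distributed is already initialized and pg is a ColoProcessGroup as used above; the helper name reduce_grad_sketch is illustrative and not part of the commit:

import torch
import torch.distributed as dist

def reduce_grad_sketch(grad: torch.Tensor, pg) -> torch.Tensor:
    # Illustrative only: mirrors the CUDA/CPU branches in ColoDDP above.
    if grad.device.type == "cuda":
        # CUDA gradients are reduced over the NCCL data-parallel group.
        dist.all_reduce(grad, group=pg.dp_process_group())
    else:
        # CPU gradients need Gloo; set_cpu_groups() builds the groups on first use.
        pg.set_cpu_groups()
        dist.all_reduce(grad, group=pg.cpu_dp_process_group())
    return grad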
@@ -114,10 +114,8 @@ class DistSpecManager:
         if world_size == 1:
             return tensor

-        assert tensor.device.type == "cuda" and old_dist_spec.process_group.backend == "nccl", \
-            "Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll " \
-            f"collective function, however, we got {tensor.device.type} device and " \
-            f"{old_dist_spec.process_group.backend} backend"
+        assert tensor.device.type == "cuda", "Currently, only CUDA Tensors are supported for the requested AlltoAll " \
+            f"collective function, however, we got {tensor.device.type} device"

         gather_dim = old_dist_spec.dims[0]
         scatter_dim = dist_spec.dims[0]
@@ -18,7 +18,6 @@ class ProcessGroup:
     def __init__(self,
                  rank: Optional[int] = None,
                  ranks: Optional[List[int]] = None,
-                 backend: str = 'nccl',
                  tp_degree: Optional[int] = None,
                  dp_degree: Optional[int] = None) -> None:
         assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized"
@@ -32,7 +31,6 @@ class ProcessGroup:
         else:
             self._rank_list = ranks

-        self._backend = backend
         self._world_size = len(self._rank_list)

         if dp_degree is None and tp_degree is None:
@@ -59,16 +57,26 @@ class ProcessGroup:
             if rank_id // self._tp_degree == self._rank // self._tp_degree:
                 self._tp_rank_list.append(rank_id)

-        assert backend == 'nccl'
-        self._tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list)
-        self._dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list)
+        self._tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list, backend='nccl')
+        self._dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list, backend='nccl')

         self.logger = get_dist_logger('ProcessGroup')
-        self.logger.info(f'{self._rank} initialize TP group on {self._tp_rank_list} DP group pn {self._dp_rank_list}')
+        self.logger.info(
+            f'{self._rank} NCCL initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')

+        self._has_cpu_groups = False
+
+    def set_cpu_groups(self):
+        if self.has_cpu_groups:
+            return
+        self.logger.info(
+            f'{self._rank} Gloo initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
+        self._cpu_tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list, backend='gloo')
+        self._cpu_dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list, backend='gloo')
+
     @property
-    def backend(self):
-        return self._backend
+    def has_cpu_groups(self):
+        return self._has_cpu_groups

     def __eq__(self, obj: 'ProcessGroup') -> bool:
         if not isinstance(obj, ProcessGroup):
@@ -81,8 +89,6 @@ class ProcessGroup:
             assert False
         if self._dp_rank_list != obj._dp_rank_list:
             assert False
-        if self._backend != obj._backend:
-            assert False
         if self._tp_degree != obj._tp_degree:
             return False
         if self._dp_degree != obj._dp_degree:
@@ -112,3 +118,9 @@ class ProcessGroup:

     def tp_process_group(self):
         return self._tp_process_group
+
+    def cpu_dp_process_group(self):
+        return self._cpu_dp_process_group
+
+    def cpu_tp_process_group(self):
+        return self._cpu_tp_process_group
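With these accessors, one ProcessGroup can drive both backends: the NCCL TP/DP groups are created eagerly in __init__, and the Gloo groups only when set_cpu_groups() is called. A minimal usage sketch, assuming torch.distributed is already initialized across 4 ranks; the import path and the chosen degrees are illustrative assumptions:

import torch
import torch.distributed as dist
from colossalai.tensor import ProcessGroup  # assumed import path

# Hypothetical 2x2 layout: tensor-parallel degree 2, data-parallel degree 2.
pg = ProcessGroup(rank=dist.get_rank(), ranks=list(range(4)), tp_degree=2, dp_degree=2)

# CUDA tensors use the eagerly created NCCL groups.
cuda_t = torch.ones(4, device="cuda")
dist.all_reduce(cuda_t, group=pg.dp_process_group())

# CPU tensors require the Gloo groups, which are built on demand.
pg.set_cpu_groups()
cpu_t = torch.ones(4)
dist.all_reduce(cpu_t, group=pg.cpu_dp_process_group())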
@@ -17,13 +17,9 @@ class TensorSpec(object):
        self.compute_spec = compute_spec
        self.dist_spec = dist_spec

    # TODO(jiaruifang) actually need tp process group
    def get_process_group(self):
        return self.dist_spec.process_group

    def get_process_group_size(self):
        return dist.get_world_size(self.dist_spec.process_group.tp_process_group())

    def get_placement(self):
        return self.dist_spec.placement