mirror of https://github.com/hpcaitech/ColossalAI
[Tensor] add cpu group to ddp (#1200)
parent f7878f465c
commit b5f25eb32a
@@ -54,14 +54,11 @@ class ColoDDP(torch.nn.Module):
         module (torch.nn.Module): Module to apply DDP.
         process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses.
             If it's None, the default data parallel group will be used. Defaults to None.
-        cpu_process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses for those parameters on CPU.
-            If it's None, the default CPU data parallel group will be used. Defaults to None.
     """

     def __init__(self,
                  module: torch.nn.Module,
                  process_group: ColoProcessGroup,
-                 cpu_process_group: Optional[dist.ProcessGroup] = None,
                  bucket_cap_mb: int = 25,
                  rebuild_bucket: bool = True) -> None:
         assert not isinstance(module, ColoDDP)
@@ -70,8 +67,9 @@ class ColoDDP(torch.nn.Module):
         self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
         assert process_group
-        self.process_group = process_group.dp_process_group()
-        self.dp_world_size = self.process_group.size()
+
+        self.process_group = process_group
+        self.dp_world_size = self.process_group.dp_world_size()

         self.reducer = Reducer(bucket_cap_mb)
         self.rebuild_bucket = rebuild_bucket
         for p in module.parameters():
@@ -112,7 +110,7 @@ class ColoDDP(torch.nn.Module):
                 self.comm_stream.wait_stream(torch.cuda.current_stream())
                 with torch.cuda.stream(self.comm_stream):
                     self.reducer.all_reduce_async(grad,
-                                                  group=self.process_group,
+                                                  group=self.process_group.dp_process_group(),
                                                   callback_fn=partial(self._save_grad, p))
                 grad.record_stream(self.comm_stream)
             else:
@@ -121,8 +119,8 @@ class ColoDDP(torch.nn.Module):

         else:
             #TODO(jiaruifang) fixme
-            raise NotImplementedError
-            dist.all_reduce(grad, group=self.cpu_process_group)
+            self.process_group.set_cpu_groups()
+            dist.all_reduce(grad, group=self.process_group.cpu_dp_process_group())
             return grad

     @staticmethod
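For context, a minimal sketch of the dispatch this change enables, assuming an already-initialized torch.distributed job: gradients of CUDA parameters are averaged over an NCCL data-parallel group, while gradients of parameters kept on CPU go through a Gloo group. The helper name `allreduce_grad` and the single-process demo are illustrative only, not part of ColoDDP.

```python
import os

import torch
import torch.distributed as dist


def allreduce_grad(grad: torch.Tensor, cuda_group, cpu_group) -> torch.Tensor:
    # Pick the collective group by the gradient's device, then average it
    # across the data-parallel ranks of that group.
    group = cuda_group if grad.device.type == 'cuda' else cpu_group
    grad = grad / dist.get_world_size(group)
    dist.all_reduce(grad, group=group)
    return grad


if __name__ == '__main__':
    # Single-process demo so the sketch runs anywhere; real jobs launch
    # multiple ranks (e.g. via torchrun) and also build an NCCL group.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=0, world_size=1)
    cpu_group = dist.new_group(backend='gloo')
    print(allreduce_grad(torch.ones(4), cuda_group=None, cpu_group=cpu_group))
    dist.destroy_process_group()
```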
@@ -114,10 +114,8 @@ class DistSpecManager:
         if world_size == 1:
             return tensor

-        assert tensor.device.type == "cuda" and old_dist_spec.process_group.backend == "nccl", \
-            "Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll " \
-            f"collective function, however, we got {tensor.device.type} device and " \
-            f"{old_dist_spec.process_group.backend} backend"
+        assert tensor.device.type == "cuda", "Currently, only CUDA Tensors are supported for the requested AlltoAll " \
+            f"collective function, however, we got {tensor.device.type} device"

         gather_dim = old_dist_spec.dims[0]
         scatter_dim = dist_spec.dims[0]
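The backend half of the old assertion is dropped because the GPU groups created by ProcessGroup are now always NCCL (see the ProcessGroup changes below). If a call site still wants to validate the backend explicitly, torch.distributed exposes it via dist.get_backend; the helper below is a hypothetical sketch, not part of this commit.

```python
import torch
import torch.distributed as dist


def assert_cuda_nccl(tensor: torch.Tensor, group) -> None:
    # Hypothetical helper: reinstate the removed backend check at a call site.
    assert tensor.device.type == "cuda", \
        f"expected a CUDA tensor, got a {tensor.device.type} one"
    assert dist.get_backend(group) == "nccl", \
        f"expected an NCCL process group, got {dist.get_backend(group)}"
```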
@@ -18,7 +18,6 @@ class ProcessGroup:
     def __init__(self,
                  rank: Optional[int] = None,
                  ranks: Optional[List[int]] = None,
-                 backend: str = 'nccl',
                  tp_degree: Optional[int] = None,
                  dp_degree: Optional[int] = None) -> None:
         assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized"
@@ -32,7 +31,6 @@ class ProcessGroup:
         else:
             self._rank_list = ranks

-        self._backend = backend
         self._world_size = len(self._rank_list)

         if dp_degree is None and tp_degree is None:
@@ -59,16 +57,26 @@ class ProcessGroup:
             if rank_id // self._tp_degree == self._rank // self._tp_degree:
                 self._tp_rank_list.append(rank_id)

-        assert backend == 'nccl'
-        self._tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list)
-        self._dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list)
+        self._tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list, backend='nccl')
+        self._dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list, backend='nccl')

         self.logger = get_dist_logger('ProcessGroup')
-        self.logger.info(f'{self._rank} initialize TP group on {self._tp_rank_list} DP group pn {self._dp_rank_list}')
+        self.logger.info(
+            f'{self._rank} NCCL initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
+
+        self._has_cpu_groups = False
+
+    def set_cpu_groups(self):
+        if self.has_cpu_groups:
+            return
+        self.logger.info(
+            f'{self._rank} Gloo initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
+        self._cpu_tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list, backend='gloo')
+        self._cpu_dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list, backend='gloo')

     @property
-    def backend(self):
-        return self._backend
+    def has_cpu_groups(self):
+        return self._has_cpu_groups

     def __eq__(self, obj: 'ProcessGroup') -> bool:
         if not isinstance(obj, ProcessGroup):
@@ -81,8 +89,6 @@ class ProcessGroup:
             assert False
         if self._dp_rank_list != obj._dp_rank_list:
             assert False
-        if self._backend != obj._backend:
-            assert False
         if self._tp_degree != obj._tp_degree:
             return False
         if self._dp_degree != obj._dp_degree:
@@ -112,3 +118,9 @@ class ProcessGroup:

     def tp_process_group(self):
         return self._tp_process_group
+
+    def cpu_dp_process_group(self):
+        return self._cpu_dp_process_group
+
+    def cpu_tp_process_group(self):
+        return self._cpu_tp_process_group
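A minimal sketch of the lazy CPU-group pattern introduced above, using a toy class rather than ColossalAI's ProcessGroup: the GPU groups are built eagerly with the NCCL backend, while the Gloo groups for CPU tensors are created only on first request, so jobs that never reduce CPU gradients pay nothing. The names ToyGroups and set_cpu_group are illustrative assumptions.

```python
import os

import torch.distributed as dist


class ToyGroups:
    """Toy stand-in for the lazy Gloo-group creation shown in set_cpu_groups()."""

    def __init__(self, ranks):
        self._ranks = ranks
        self._cpu_dp_group = None          # built lazily by set_cpu_group()

    @property
    def has_cpu_group(self) -> bool:
        return self._cpu_dp_group is not None

    def set_cpu_group(self) -> None:
        if self.has_cpu_group:
            return
        self._cpu_dp_group = dist.new_group(ranks=self._ranks, backend='gloo')

    def cpu_dp_group(self):
        return self._cpu_dp_group


if __name__ == '__main__':
    # Single-process demo; a real job runs one copy of this per rank.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29501')
    dist.init_process_group('gloo', rank=0, world_size=1)
    groups = ToyGroups(ranks=[0])
    groups.set_cpu_group()                 # Gloo group is created here, on first use
    print(groups.has_cpu_group)            # True
    dist.destroy_process_group()
```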
@@ -17,13 +17,9 @@ class TensorSpec(object):
         self.compute_spec = compute_spec
         self.dist_spec = dist_spec

-    # TODO(jiaruifang) actually need tp process group
     def get_process_group(self):
         return self.dist_spec.process_group

-    def get_process_group_size(self):
-        return dist.get_world_size(self.dist_spec.process_group.tp_process_group())
-
     def get_placement(self):
         return self.dist_spec.placement

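With get_process_group_size() removed, a caller that needs the tensor-parallel world size can query the spec's process group directly. A hedged one-liner, assuming `spec` is a TensorSpec whose process group exposes tp_process_group() as in the diff above:

```python
import torch.distributed as dist


def tp_world_size(spec) -> int:
    # Equivalent to the removed TensorSpec.get_process_group_size() helper.
    return dist.get_world_size(spec.get_process_group().tp_process_group())
```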