From b5f25eb32a2fa7bc4e80ffac496a87a9636601bb Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Tue, 5 Jul 2022 14:58:28 +0800 Subject: [PATCH] [Tensor] add cpu group to ddp (#1200) --- colossalai/nn/parallel/data_parallel.py | 14 +++++------ colossalai/tensor/dist_spec_mgr.py | 6 ++--- colossalai/tensor/process_group.py | 32 +++++++++++++++++-------- colossalai/tensor/tensor_spec.py | 4 ---- 4 files changed, 30 insertions(+), 26 deletions(-) diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py index a67e4ecde..9de510e3e 100644 --- a/colossalai/nn/parallel/data_parallel.py +++ b/colossalai/nn/parallel/data_parallel.py @@ -54,14 +54,11 @@ class ColoDDP(torch.nn.Module): module (torch.nn.Module): Module to apply DDP. process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses. If it's None, the default data parallel group will be used. Defaults to None. - cpu_process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses for those parameters on CPU. - If it's None, the default CPU data parallel group will be used. Defaults to None. """ def __init__(self, module: torch.nn.Module, process_group: ColoProcessGroup, - cpu_process_group: Optional[dist.ProcessGroup] = None, bucket_cap_mb: int = 25, rebuild_bucket: bool = True) -> None: assert not isinstance(module, ColoDDP) @@ -70,8 +67,9 @@ class ColoDDP(torch.nn.Module): self.comm_stream: torch.cuda.Stream = torch.cuda.Stream() assert process_group - self.process_group = process_group.dp_process_group() - self.dp_world_size = self.process_group.size() + self.process_group = process_group + self.dp_world_size = self.process_group.dp_world_size() + self.reducer = Reducer(bucket_cap_mb) self.rebuild_bucket = rebuild_bucket for p in module.parameters(): @@ -112,7 +110,7 @@ class ColoDDP(torch.nn.Module): self.comm_stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(self.comm_stream): self.reducer.all_reduce_async(grad, - group=self.process_group, + group=self.process_group.dp_process_group(), callback_fn=partial(self._save_grad, p)) grad.record_stream(self.comm_stream) else: @@ -121,8 +119,8 @@ class ColoDDP(torch.nn.Module): else: #TODO(jiaruifang) fixme - raise NotImplementedError - dist.all_reduce(grad, group=self.cpu_process_group) + self.process_group.set_cpu_groups() + dist.all_reduce(grad, group=self.process_group.cpu_dp_process_group()) return grad @staticmethod diff --git a/colossalai/tensor/dist_spec_mgr.py b/colossalai/tensor/dist_spec_mgr.py index 928b0db24..9f9d61c90 100644 --- a/colossalai/tensor/dist_spec_mgr.py +++ b/colossalai/tensor/dist_spec_mgr.py @@ -114,10 +114,8 @@ class DistSpecManager: if world_size == 1: return tensor - assert tensor.device.type == "cuda" and old_dist_spec.process_group.backend == "nccl", \ - "Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll " \ - f"collective function, however, we got {tensor.device.type} device and " \ - f"{old_dist_spec.process_group.backend} backend" + assert tensor.device.type == "cuda", "Currently, only CUDA Tensors are supported for the requested AlltoAll " \ + f"collective function, however, we got {tensor.device.type} device" gather_dim = old_dist_spec.dims[0] scatter_dim = dist_spec.dims[0] diff --git a/colossalai/tensor/process_group.py b/colossalai/tensor/process_group.py index c0141dd90..b2268a241 100644 --- a/colossalai/tensor/process_group.py +++ b/colossalai/tensor/process_group.py @@ -18,7 +18,6 @@ class ProcessGroup: def __init__(self, rank: Optional[int] = None, ranks: Optional[List[int]] = None, - backend: str = 'nccl', tp_degree: Optional[int] = None, dp_degree: Optional[int] = None) -> None: assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized" @@ -32,7 +31,6 @@ class ProcessGroup: else: self._rank_list = ranks - self._backend = backend self._world_size = len(self._rank_list) if dp_degree is None and tp_degree is None: @@ -59,16 +57,26 @@ class ProcessGroup: if rank_id // self._tp_degree == self._rank // self._tp_degree: self._tp_rank_list.append(rank_id) - assert backend == 'nccl' - self._tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list) - self._dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list) + self._tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list, backend='nccl') + self._dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list, backend='nccl') self.logger = get_dist_logger('ProcessGroup') - self.logger.info(f'{self._rank} initialize TP group on {self._tp_rank_list} DP group pn {self._dp_rank_list}') + self.logger.info( + f'{self._rank} NCCL initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}') + + self._has_cpu_groups = False + + def set_cpu_groups(self): + if self.has_cpu_groups: + return + self.logger.info( + f'{self._rank} Gloo initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}') + self._cpu_tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list, backend='gloo') + self._cpu_dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list, backend='gloo') @property - def backend(self): - return self._backend + def has_cpu_groups(self): + return self._has_cpu_groups def __eq__(self, obj: 'ProcessGroup') -> bool: if not isinstance(obj, ProcessGroup): @@ -81,8 +89,6 @@ class ProcessGroup: assert False if self._dp_rank_list != obj._dp_rank_list: assert False - if self._backend != obj._backend: - assert False if self._tp_degree != obj._tp_degree: return False if self._dp_degree != obj._dp_degree: @@ -112,3 +118,9 @@ class ProcessGroup: def tp_process_group(self): return self._tp_process_group + + def cpu_dp_process_group(self): + return self._cpu_dp_process_group + + def cpu_tp_process_group(self): + return self._cpu_tp_process_group diff --git a/colossalai/tensor/tensor_spec.py b/colossalai/tensor/tensor_spec.py index 21d9c3d5c..8d785e3ef 100644 --- a/colossalai/tensor/tensor_spec.py +++ b/colossalai/tensor/tensor_spec.py @@ -17,13 +17,9 @@ class TensorSpec(object): self.compute_spec = compute_spec self.dist_spec = dist_spec - # TODO(jiaruifang) actually need tp process group def get_process_group(self): return self.dist_spec.process_group - def get_process_group_size(self): - return dist.get_world_size(self.dist_spec.process_group.tp_process_group()) - def get_placement(self): return self.dist_spec.placement