From 6b2f2ab9bb410b7edd722dbfe6ee425de1437c31 Mon Sep 17 00:00:00 2001 From: ver217 Date: Wed, 29 Jun 2022 10:34:13 +0800 Subject: [PATCH] [ddp] ColoDDP uses bucket all-reduce (#1177) * add reducer * update colo ddp with reducer * polish unit test * polish unit test --- colossalai/nn/parallel/data_parallel.py | 16 +++- colossalai/nn/parallel/reducer.py | 116 ++++++++++++++++++++++++ tests/test_ddp/test_reducer.py | 55 +++++++++++ 3 files changed, 184 insertions(+), 3 deletions(-) create mode 100644 colossalai/nn/parallel/reducer.py create mode 100644 tests/test_ddp/test_reducer.py diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py index f983a782c..fb98b62f3 100644 --- a/colossalai/nn/parallel/data_parallel.py +++ b/colossalai/nn/parallel/data_parallel.py @@ -12,6 +12,7 @@ from typing import Dict, Iterable, List, Optional from colossalai.logging import get_dist_logger from collections import OrderedDict from colossalai.tensor.colo_parameter import ColoParameter +from .reducer import Reducer try: from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys except ImportError: @@ -61,7 +62,9 @@ class ColoDDP(torch.nn.Module): def __init__(self, module: torch.nn.Module, process_group: Optional[dist.ProcessGroup] = None, - cpu_process_group: Optional[dist.ProcessGroup] = None) -> None: + cpu_process_group: Optional[dist.ProcessGroup] = None, + bucket_cap_mb: int = 25, + rebuild_bucket: bool = True) -> None: assert not isinstance(module, ColoDDP) super().__init__() self.module = module @@ -69,6 +72,8 @@ class ColoDDP(torch.nn.Module): self.process_group = process_group or gpc.get_group(ParallelMode.DATA) self.cpu_process_group = cpu_process_group or gpc.get_cpu_group(ParallelMode.DATA) self.dp_world_size = self.process_group.size() + self.reducer = Reducer(bucket_cap_mb) + self.rebuild_bucket = rebuild_bucket for p in module.parameters(): if getattr(p, '_ddp_to_ignore', False): continue @@ -87,7 +92,11 @@ class ColoDDP(torch.nn.Module): def backward(self, loss: torch.Tensor): loss.backward() + with torch.cuda.stream(self.comm_stream): + self.reducer.flush() torch.cuda.current_stream().wait_stream(self.comm_stream) + if self.rebuild_bucket: + self.reducer.free() for p in self.module.parameters(): if getattr(p, '_ddp_to_ignore', False): continue @@ -102,8 +111,9 @@ class ColoDDP(torch.nn.Module): grad = grad / self.dp_world_size self.comm_stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(self.comm_stream): - dist.all_reduce(grad, group=self.process_group) - ColoDDP._save_grad(p, grad) + self.reducer.all_reduce_async(grad, + group=self.process_group, + callback_fn=partial(self._save_grad, p)) grad.record_stream(self.comm_stream) else: ColoDDP._save_grad(p, grad) diff --git a/colossalai/nn/parallel/reducer.py b/colossalai/nn/parallel/reducer.py new file mode 100644 index 000000000..568705581 --- /dev/null +++ b/colossalai/nn/parallel/reducer.py @@ -0,0 +1,116 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+
+import functools
+from typing import Callable, Dict, List, Optional, Tuple
+
+import torch
+import torch.distributed as dist
+from torch import Tensor
+from torch.distributed import ProcessGroup
+
+
+class Bucket:
+
+    def __init__(self, size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup):
+        self.buffer = torch.zeros(size, dtype=dtype, device=device)
+        self.group = group
+        self.offset = 0
+        self.callbacks: List[Callable] = []
+
+    def flush(self) -> None:
+        """Flush content of the bucket."""
+        if self.offset == 0:
+            assert len(self.callbacks) == 0
+            return
+        # all-reduce the filled portion of the bucket in place
+        dist.all_reduce(self.buffer[:self.offset], group=self.group)
+
+        # execute post-reduction callbacks
+        for callback_fn in self.callbacks:
+            callback_fn()
+        # reset the bucket; a fresh buffer keeps the reduced results handed to callbacks valid
+        self.offset = 0
+        self.callbacks.clear()
+        self.buffer = torch.zeros_like(self.buffer)
+
+    def alloc(self) -> None:
+
+        if self.buffer.storage().size() == 0:
+            self.buffer.storage().resize_(self.buffer.numel())
+
+    def free(self) -> None:
+
+        assert self.offset == 0 and self.callbacks == [], "Incorrect call of teardown"
+        self.buffer.storage().resize_(0)
+
+    def append(self, tensor: Tensor, callback_fn: Callable):
+        tensor_size = tensor.numel()
+        offset = self.offset
+        self.buffer[offset:offset + tensor_size].copy_(tensor.flatten())
+        self.offset += tensor_size
+
+        # callback will be given the reduced result
+        if callback_fn is not None:
+            result_view = self.buffer[offset:offset + tensor_size].view(tensor.shape)
+            self.callbacks.append(functools.partial(callback_fn, result_view))
+
+    @property
+    def avail_size(self) -> int:
+        return self.buffer.size(0) - self.offset
+
+
+class Reducer:
+
+    def __init__(self, bucket_size_mb: int = 25):
+        self.bucket_size_mb = bucket_size_mb
+        self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {}
+
+    @torch.no_grad()
+    def all_reduce_async(
+        self,
+        tensor: Tensor,
+        group: ProcessGroup,
+        callback_fn: Optional[Callable] = None,
+    ) -> None:
+        bucket_size = self._get_bucket_size(tensor.element_size())
+
+        if tensor.numel() >= bucket_size:
+            dist.all_reduce(tensor, group=group)
+            if callback_fn is not None:
+                callback_fn(tensor)
+            return
+
+        bucket = self._get_bucket(tensor, group)
+        if tensor.numel() > bucket.avail_size:
+            # not enough space remaining in bucket, flush it now
+            bucket.flush()
+        bucket.append(tensor, callback_fn)
+
+    @torch.no_grad()
+    def flush(self) -> None:
+        for bucket in self.buckets.values():
+            bucket.flush()
+
+    @torch.no_grad()
+    def free(self) -> None:
+        for bucket in self.buckets.values():
+            bucket.free()
+
+    @functools.lru_cache()
+    def _get_bucket_size(self, element_size: int) -> int:
+        if self.bucket_size_mb <= 0:    # Values <= 0 disable bucketing.
+ return 0 + MB = 1024 * 1024 + bucket_size = self.bucket_size_mb * MB / element_size + return int(bucket_size) + + def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket: + key = (tensor.dtype, tensor.device, group) + if key not in self.buckets: + bucket_size = self._get_bucket_size(tensor.element_size()) + self.buckets[key] = Bucket(bucket_size, tensor.dtype, tensor.device, group) + self.buckets[key].alloc() + return self.buckets[key] diff --git a/tests/test_ddp/test_reducer.py b/tests/test_ddp/test_reducer.py new file mode 100644 index 000000000..64faffd3c --- /dev/null +++ b/tests/test_ddp/test_reducer.py @@ -0,0 +1,55 @@ +import pytest +import colossalai +import torch +import torch.multiprocessing as mp +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils.cuda import get_current_device +from colossalai.utils import free_port +from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.tensor import ChunkManager +from functools import partial +from tests.components_to_test.registry import non_distributed_component_funcs +from colossalai.nn.parallel import ZeroDDP, ColoDDP +from colossalai.gemini.gemini_mgr import GeminiManager +from typing import Callable +from collections import OrderedDict +from colossalai.nn.parallel.reducer import Reducer +import torch.distributed as dist +from torch.distributed.distributed_c10d import _get_default_group + +REDUCE_CNT = 0 + + +def check_eq(grad, grad_clone): + global REDUCE_CNT + print(f'Rank{dist.get_rank()} check {REDUCE_CNT}') + REDUCE_CNT += 1 + assert torch.allclose(grad, grad_clone) + + +def run_reducer(): + grads = [torch.rand(64, i + 1, device=get_current_device()) for i in range(10)] + grads_clone = [g.clone().detach() for g in grads] + for g in grads: + dist.all_reduce(g) + reducer = Reducer(bucket_size_mb=1) + for g, g_clone in zip(grads, grads_clone): + reducer.all_reduce_async(g_clone, _get_default_group(), partial(check_eq, g)) + reducer.flush() + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_reducer() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2]) +@rerun_if_address_is_in_use() +def test_reducer(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_reducer(2)
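A minimal standalone sketch of the call pattern this patch introduces, assuming a single process and the gloo backend only so it can run without GPUs or a launcher: gradients are queued with all_reduce_async as they become ready, the buckets are flushed once backward is done, and free() releases bucket storage, mirroring what ColoDDP.backward does when rebuild_bucket is enabled. The save_grad callback below is a hypothetical stand-in for ColoDDP._save_grad.

import os
from functools import partial

import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group

from colossalai.nn.parallel.reducer import Reducer


def main():
    # Single-process "distributed" setup so the sketch runs standalone on CPU.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group(backend='gloo', rank=0, world_size=1)

    reducer = Reducer(bucket_size_mb=1)
    grads = [torch.rand(64, i + 1) for i in range(10)]

    def save_grad(idx, reduced):
        # Stand-in for ColoDDP._save_grad: receives a view of the reduced bucket.
        print(f'grad {idx} reduced, shape={tuple(reduced.shape)}')

    for i, g in enumerate(grads):
        # Small tensors are copied into a per-(dtype, device, group) bucket;
        # tensors at least one bucket in size are all-reduced immediately.
        reducer.all_reduce_async(g, group=_get_default_group(), callback_fn=partial(save_grad, i))

    reducer.flush()    # reduce whatever is still buffered (ColoDDP.backward does this)
    reducer.free()     # drop bucket storage, as ColoDDP does when rebuild_bucket=True
    dist.destroy_process_group()


if __name__ == '__main__':
    main()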