mirror of https://github.com/hpcaitech/ColossalAI
[ddp] ColoDDP uses bucket all-reduce (#1177)
* add reducer * update colo ddp with reducer * polish unit test * polish unit testpull/1182/head
parent
7487215b95
commit
6b2f2ab9bb
|
@ -12,6 +12,7 @@ from typing import Dict, Iterable, List, Optional
|
|||
from colossalai.logging import get_dist_logger
|
||||
from collections import OrderedDict
|
||||
from colossalai.tensor.colo_parameter import ColoParameter
|
||||
from .reducer import Reducer
|
||||
try:
|
||||
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
|
||||
except ImportError:
|
||||
|
@ -61,7 +62,9 @@ class ColoDDP(torch.nn.Module):
|
|||
def __init__(self,
|
||||
module: torch.nn.Module,
|
||||
process_group: Optional[dist.ProcessGroup] = None,
|
||||
cpu_process_group: Optional[dist.ProcessGroup] = None) -> None:
|
||||
cpu_process_group: Optional[dist.ProcessGroup] = None,
|
||||
bucket_cap_mb: int = 25,
|
||||
rebuild_bucket: bool = True) -> None:
|
||||
assert not isinstance(module, ColoDDP)
|
||||
super().__init__()
|
||||
self.module = module
|
||||
|
@ -69,6 +72,8 @@ class ColoDDP(torch.nn.Module):
|
|||
self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
|
||||
self.cpu_process_group = cpu_process_group or gpc.get_cpu_group(ParallelMode.DATA)
|
||||
self.dp_world_size = self.process_group.size()
|
||||
self.reducer = Reducer(bucket_cap_mb)
|
||||
self.rebuild_bucket = rebuild_bucket
|
||||
for p in module.parameters():
|
||||
if getattr(p, '_ddp_to_ignore', False):
|
||||
continue
|
||||
|
@ -87,7 +92,11 @@ class ColoDDP(torch.nn.Module):
|
|||
|
||||
def backward(self, loss: torch.Tensor):
|
||||
loss.backward()
|
||||
with torch.cuda.stream(self.comm_stream):
|
||||
self.reducer.flush()
|
||||
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
||||
if self.rebuild_bucket:
|
||||
self.reducer.free()
|
||||
for p in self.module.parameters():
|
||||
if getattr(p, '_ddp_to_ignore', False):
|
||||
continue
|
||||
|
@ -102,8 +111,9 @@ class ColoDDP(torch.nn.Module):
|
|||
grad = grad / self.dp_world_size
|
||||
self.comm_stream.wait_stream(torch.cuda.current_stream())
|
||||
with torch.cuda.stream(self.comm_stream):
|
||||
dist.all_reduce(grad, group=self.process_group)
|
||||
ColoDDP._save_grad(p, grad)
|
||||
self.reducer.all_reduce_async(grad,
|
||||
group=self.process_group,
|
||||
callback_fn=partial(self._save_grad, p))
|
||||
grad.record_stream(self.comm_stream)
|
||||
else:
|
||||
ColoDDP._save_grad(p, grad)
|
||||
|
|
|
@ -0,0 +1,116 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import functools
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
|
||||
class Bucket:
|
||||
|
||||
def __init__(self, size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup):
|
||||
self.buffer = torch.zeros(size, dtype=dtype, device=device)
|
||||
self.group = group
|
||||
self.offset = 0
|
||||
self.callbacks: List[Callable] = []
|
||||
|
||||
def flush(self) -> None:
|
||||
"""Flush content of the bucket."""
|
||||
if self.offset == 0:
|
||||
assert len(self.callbacks) == 0
|
||||
return
|
||||
# reduce-scatter bucket
|
||||
dist.all_reduce(self.buffer[:self.offset], group=self.group)
|
||||
|
||||
# execute post-reduction callbacks
|
||||
for callback_fn in self.callbacks:
|
||||
callback_fn()
|
||||
# reuse input bucket but allocate a fresh output shard
|
||||
self.offset = 0
|
||||
self.callbacks.clear()
|
||||
self.buffer = torch.zeros_like(self.buffer)
|
||||
|
||||
def alloc(self) -> None:
|
||||
|
||||
if self.buffer.storage().size() == 0:
|
||||
self.buffer.storage().resize_(self.buffer.numel())
|
||||
|
||||
def free(self) -> None:
|
||||
|
||||
assert self.offset == 0 and self.callbacks == [], "Incorrect call of teardown"
|
||||
self.buffer.storage().resize_(0)
|
||||
|
||||
def append(self, tensor: Tensor, callback_fn: Callable):
|
||||
tensor_size = tensor.numel()
|
||||
offset = self.offset
|
||||
self.buffer[offset:offset + tensor_size].copy_(tensor.flatten())
|
||||
self.offset += tensor_size
|
||||
|
||||
# callback will be given the reduced result
|
||||
if callback_fn is not None:
|
||||
result_view = self.buffer[offset:offset + tensor_size].view(tensor.shape)
|
||||
self.callbacks.append(functools.partial(callback_fn, result_view))
|
||||
|
||||
@property
|
||||
def avail_size(self) -> int:
|
||||
return self.buffer.size(0) - self.offset
|
||||
|
||||
|
||||
class Reducer:
|
||||
|
||||
def __init__(self, bucket_size_mb: int = 25):
|
||||
self.bucket_size_mb = bucket_size_mb
|
||||
self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {}
|
||||
|
||||
@torch.no_grad()
|
||||
def all_reduce_async(
|
||||
self,
|
||||
tensor: Tensor,
|
||||
group: ProcessGroup,
|
||||
callback_fn: Optional[Callable] = None,
|
||||
) -> None:
|
||||
bucket_size = self._get_bucket_size(tensor.element_size())
|
||||
|
||||
if tensor.numel() >= bucket_size:
|
||||
dist.all_reduce(tensor, group=group)
|
||||
if callback_fn is not None:
|
||||
callback_fn(tensor)
|
||||
return
|
||||
|
||||
bucket = self._get_bucket(tensor, group)
|
||||
if tensor.numel() > bucket.avail_size:
|
||||
# not enough space remaining in bucket, flush it now
|
||||
bucket.flush()
|
||||
bucket.append(tensor, callback_fn)
|
||||
|
||||
@torch.no_grad()
|
||||
def flush(self) -> None:
|
||||
for bucket in self.buckets.values():
|
||||
bucket.flush()
|
||||
|
||||
@torch.no_grad()
|
||||
def free(self) -> None:
|
||||
for bucket in self.buckets.values():
|
||||
bucket.free()
|
||||
|
||||
@functools.lru_cache()
|
||||
def _get_bucket_size(self, element_size: int) -> int:
|
||||
if self.bucket_size_mb <= 0: # Values <= 0 disable bucketing.
|
||||
return 0
|
||||
MB = 1024 * 1024
|
||||
bucket_size = self.bucket_size_mb * MB / element_size
|
||||
return int(bucket_size)
|
||||
|
||||
def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket:
|
||||
key = (tensor.dtype, tensor.device, group)
|
||||
if key not in self.buckets:
|
||||
bucket_size = self._get_bucket_size(tensor.element_size())
|
||||
self.buckets[key] = Bucket(bucket_size, tensor.dtype, tensor.device, group)
|
||||
self.buckets[key].alloc()
|
||||
return self.buckets[key]
|
|
@ -0,0 +1,55 @@
|
|||
import pytest
|
||||
import colossalai
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.tensor import ChunkManager
|
||||
from functools import partial
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from colossalai.nn.parallel import ZeroDDP, ColoDDP
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
from typing import Callable
|
||||
from collections import OrderedDict
|
||||
from colossalai.nn.parallel.reducer import Reducer
|
||||
import torch.distributed as dist
|
||||
from torch.distributed.distributed_c10d import _get_default_group
|
||||
|
||||
REDUCE_CNT = 0
|
||||
|
||||
|
||||
def check_eq(grad, grad_clone):
|
||||
global REDUCE_CNT
|
||||
print(f'Rank{dist.get_rank()} check {REDUCE_CNT}')
|
||||
REDUCE_CNT += 1
|
||||
assert torch.allclose(grad, grad_clone)
|
||||
|
||||
|
||||
def run_reducer():
|
||||
grads = [torch.rand(64, i + 1, device=get_current_device()) for i in range(10)]
|
||||
grads_clone = [g.clone().detach() for g in grads]
|
||||
for g in grads:
|
||||
dist.all_reduce(g)
|
||||
reducer = Reducer(bucket_size_mb=1)
|
||||
for g, g_clone in zip(grads, grads_clone):
|
||||
reducer.all_reduce_async(g_clone, _get_default_group(), partial(check_eq, g))
|
||||
reducer.flush()
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_reducer()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 2])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_reducer(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_reducer(2)
|
Loading…
Reference in New Issue