mirror of https://github.com/hpcaitech/ColossalAI
[ddp] ColoDDP uses bucket all-reduce (#1177)
* add reducer * update colo ddp with reducer * polish unit test * polish unit testpull/1182/head
parent
7487215b95
commit
6b2f2ab9bb
|
@ -12,6 +12,7 @@ from typing import Dict, Iterable, List, Optional
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from colossalai.tensor.colo_parameter import ColoParameter
|
from colossalai.tensor.colo_parameter import ColoParameter
|
||||||
|
from .reducer import Reducer
|
||||||
try:
|
try:
|
||||||
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
|
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
@ -61,7 +62,9 @@ class ColoDDP(torch.nn.Module):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
module: torch.nn.Module,
|
module: torch.nn.Module,
|
||||||
process_group: Optional[dist.ProcessGroup] = None,
|
process_group: Optional[dist.ProcessGroup] = None,
|
||||||
cpu_process_group: Optional[dist.ProcessGroup] = None) -> None:
|
cpu_process_group: Optional[dist.ProcessGroup] = None,
|
||||||
|
bucket_cap_mb: int = 25,
|
||||||
|
rebuild_bucket: bool = True) -> None:
|
||||||
assert not isinstance(module, ColoDDP)
|
assert not isinstance(module, ColoDDP)
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.module = module
|
self.module = module
|
||||||
|
@ -69,6 +72,8 @@ class ColoDDP(torch.nn.Module):
|
||||||
self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
|
self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
|
||||||
self.cpu_process_group = cpu_process_group or gpc.get_cpu_group(ParallelMode.DATA)
|
self.cpu_process_group = cpu_process_group or gpc.get_cpu_group(ParallelMode.DATA)
|
||||||
self.dp_world_size = self.process_group.size()
|
self.dp_world_size = self.process_group.size()
|
||||||
|
self.reducer = Reducer(bucket_cap_mb)
|
||||||
|
self.rebuild_bucket = rebuild_bucket
|
||||||
for p in module.parameters():
|
for p in module.parameters():
|
||||||
if getattr(p, '_ddp_to_ignore', False):
|
if getattr(p, '_ddp_to_ignore', False):
|
||||||
continue
|
continue
|
||||||
|
@ -87,7 +92,11 @@ class ColoDDP(torch.nn.Module):
|
||||||
|
|
||||||
def backward(self, loss: torch.Tensor):
|
def backward(self, loss: torch.Tensor):
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
with torch.cuda.stream(self.comm_stream):
|
||||||
|
self.reducer.flush()
|
||||||
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
||||||
|
if self.rebuild_bucket:
|
||||||
|
self.reducer.free()
|
||||||
for p in self.module.parameters():
|
for p in self.module.parameters():
|
||||||
if getattr(p, '_ddp_to_ignore', False):
|
if getattr(p, '_ddp_to_ignore', False):
|
||||||
continue
|
continue
|
||||||
|
@ -102,8 +111,9 @@ class ColoDDP(torch.nn.Module):
|
||||||
grad = grad / self.dp_world_size
|
grad = grad / self.dp_world_size
|
||||||
self.comm_stream.wait_stream(torch.cuda.current_stream())
|
self.comm_stream.wait_stream(torch.cuda.current_stream())
|
||||||
with torch.cuda.stream(self.comm_stream):
|
with torch.cuda.stream(self.comm_stream):
|
||||||
dist.all_reduce(grad, group=self.process_group)
|
self.reducer.all_reduce_async(grad,
|
||||||
ColoDDP._save_grad(p, grad)
|
group=self.process_group,
|
||||||
|
callback_fn=partial(self._save_grad, p))
|
||||||
grad.record_stream(self.comm_stream)
|
grad.record_stream(self.comm_stream)
|
||||||
else:
|
else:
|
||||||
ColoDDP._save_grad(p, grad)
|
ColoDDP._save_grad(p, grad)
|
||||||
|
|
|
@ -0,0 +1,116 @@
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import functools
|
||||||
|
from typing import Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.distributed import ProcessGroup
|
||||||
|
|
||||||
|
|
||||||
|
class Bucket:
|
||||||
|
|
||||||
|
def __init__(self, size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup):
|
||||||
|
self.buffer = torch.zeros(size, dtype=dtype, device=device)
|
||||||
|
self.group = group
|
||||||
|
self.offset = 0
|
||||||
|
self.callbacks: List[Callable] = []
|
||||||
|
|
||||||
|
def flush(self) -> None:
|
||||||
|
"""Flush content of the bucket."""
|
||||||
|
if self.offset == 0:
|
||||||
|
assert len(self.callbacks) == 0
|
||||||
|
return
|
||||||
|
# reduce-scatter bucket
|
||||||
|
dist.all_reduce(self.buffer[:self.offset], group=self.group)
|
||||||
|
|
||||||
|
# execute post-reduction callbacks
|
||||||
|
for callback_fn in self.callbacks:
|
||||||
|
callback_fn()
|
||||||
|
# reuse input bucket but allocate a fresh output shard
|
||||||
|
self.offset = 0
|
||||||
|
self.callbacks.clear()
|
||||||
|
self.buffer = torch.zeros_like(self.buffer)
|
||||||
|
|
||||||
|
def alloc(self) -> None:
|
||||||
|
|
||||||
|
if self.buffer.storage().size() == 0:
|
||||||
|
self.buffer.storage().resize_(self.buffer.numel())
|
||||||
|
|
||||||
|
def free(self) -> None:
|
||||||
|
|
||||||
|
assert self.offset == 0 and self.callbacks == [], "Incorrect call of teardown"
|
||||||
|
self.buffer.storage().resize_(0)
|
||||||
|
|
||||||
|
def append(self, tensor: Tensor, callback_fn: Callable):
|
||||||
|
tensor_size = tensor.numel()
|
||||||
|
offset = self.offset
|
||||||
|
self.buffer[offset:offset + tensor_size].copy_(tensor.flatten())
|
||||||
|
self.offset += tensor_size
|
||||||
|
|
||||||
|
# callback will be given the reduced result
|
||||||
|
if callback_fn is not None:
|
||||||
|
result_view = self.buffer[offset:offset + tensor_size].view(tensor.shape)
|
||||||
|
self.callbacks.append(functools.partial(callback_fn, result_view))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def avail_size(self) -> int:
|
||||||
|
return self.buffer.size(0) - self.offset
|
||||||
|
|
||||||
|
|
||||||
|
class Reducer:
|
||||||
|
|
||||||
|
def __init__(self, bucket_size_mb: int = 25):
|
||||||
|
self.bucket_size_mb = bucket_size_mb
|
||||||
|
self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {}
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def all_reduce_async(
|
||||||
|
self,
|
||||||
|
tensor: Tensor,
|
||||||
|
group: ProcessGroup,
|
||||||
|
callback_fn: Optional[Callable] = None,
|
||||||
|
) -> None:
|
||||||
|
bucket_size = self._get_bucket_size(tensor.element_size())
|
||||||
|
|
||||||
|
if tensor.numel() >= bucket_size:
|
||||||
|
dist.all_reduce(tensor, group=group)
|
||||||
|
if callback_fn is not None:
|
||||||
|
callback_fn(tensor)
|
||||||
|
return
|
||||||
|
|
||||||
|
bucket = self._get_bucket(tensor, group)
|
||||||
|
if tensor.numel() > bucket.avail_size:
|
||||||
|
# not enough space remaining in bucket, flush it now
|
||||||
|
bucket.flush()
|
||||||
|
bucket.append(tensor, callback_fn)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def flush(self) -> None:
|
||||||
|
for bucket in self.buckets.values():
|
||||||
|
bucket.flush()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def free(self) -> None:
|
||||||
|
for bucket in self.buckets.values():
|
||||||
|
bucket.free()
|
||||||
|
|
||||||
|
@functools.lru_cache()
|
||||||
|
def _get_bucket_size(self, element_size: int) -> int:
|
||||||
|
if self.bucket_size_mb <= 0: # Values <= 0 disable bucketing.
|
||||||
|
return 0
|
||||||
|
MB = 1024 * 1024
|
||||||
|
bucket_size = self.bucket_size_mb * MB / element_size
|
||||||
|
return int(bucket_size)
|
||||||
|
|
||||||
|
def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket:
|
||||||
|
key = (tensor.dtype, tensor.device, group)
|
||||||
|
if key not in self.buckets:
|
||||||
|
bucket_size = self._get_bucket_size(tensor.element_size())
|
||||||
|
self.buckets[key] = Bucket(bucket_size, tensor.dtype, tensor.device, group)
|
||||||
|
self.buckets[key].alloc()
|
||||||
|
return self.buckets[key]
|
|
@ -0,0 +1,55 @@
|
||||||
|
import pytest
|
||||||
|
import colossalai
|
||||||
|
import torch
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
from colossalai.testing import rerun_if_address_is_in_use
|
||||||
|
from colossalai.utils.cuda import get_current_device
|
||||||
|
from colossalai.utils import free_port
|
||||||
|
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||||
|
from colossalai.tensor import ChunkManager
|
||||||
|
from functools import partial
|
||||||
|
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||||
|
from colossalai.nn.parallel import ZeroDDP, ColoDDP
|
||||||
|
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||||
|
from typing import Callable
|
||||||
|
from collections import OrderedDict
|
||||||
|
from colossalai.nn.parallel.reducer import Reducer
|
||||||
|
import torch.distributed as dist
|
||||||
|
from torch.distributed.distributed_c10d import _get_default_group
|
||||||
|
|
||||||
|
REDUCE_CNT = 0
|
||||||
|
|
||||||
|
|
||||||
|
def check_eq(grad, grad_clone):
|
||||||
|
global REDUCE_CNT
|
||||||
|
print(f'Rank{dist.get_rank()} check {REDUCE_CNT}')
|
||||||
|
REDUCE_CNT += 1
|
||||||
|
assert torch.allclose(grad, grad_clone)
|
||||||
|
|
||||||
|
|
||||||
|
def run_reducer():
|
||||||
|
grads = [torch.rand(64, i + 1, device=get_current_device()) for i in range(10)]
|
||||||
|
grads_clone = [g.clone().detach() for g in grads]
|
||||||
|
for g in grads:
|
||||||
|
dist.all_reduce(g)
|
||||||
|
reducer = Reducer(bucket_size_mb=1)
|
||||||
|
for g, g_clone in zip(grads, grads_clone):
|
||||||
|
reducer.all_reduce_async(g_clone, _get_default_group(), partial(check_eq, g))
|
||||||
|
reducer.flush()
|
||||||
|
|
||||||
|
|
||||||
|
def run_dist(rank, world_size, port):
|
||||||
|
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
|
run_reducer()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.dist
|
||||||
|
@pytest.mark.parametrize('world_size', [1, 2])
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
|
def test_reducer(world_size):
|
||||||
|
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||||
|
mp.spawn(run_func, nprocs=world_size)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_reducer(2)
|
Loading…
Reference in New Issue