mirror of https://github.com/hpcaitech/ColossalAI
201 lines
8.2 KiB
Python
201 lines
8.2 KiB
Python
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||
|
#
|
||
|
# This source code is licensed under the BSD license found in the
|
||
|
# LICENSE file in the root directory of this source tree.
|
||
|
|
||
|
import functools
|
||
|
import os
|
||
|
from typing import Callable, Dict, List, Optional, Tuple
|
||
|
|
||
|
import torch
|
||
|
import torch.distributed as dist
|
||
|
from torch import Tensor
|
||
|
from torch.distributed import ProcessGroup
|
||
|
|
||
|
# TODO: Remove the toggle-enable_nccl_base_collectives when github open issue #801 is resolved.
# Setting the environment variable ENABLE_NCCL_BASE_COLLECTIVES to exactly "0" disables the
# flat-buffer ``dist._reduce_scatter_base`` path (used only when torch provides it) and forces
# the list-based ``dist.reduce_scatter`` fallback. Any other value, or leaving it unset, enables it.
enable_nccl_base_collectives = os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") != "0"
|
||
|
|
||
|
|
||
|
class Bucket:
    """Staging buffer that packs several small reduce-scatters into one collective.

    ``buffer`` has shape ``(group.size(), shard_size)``. :meth:`append` copies one
    tensor per rank into the next free columns (tracked by ``offset``) and queues
    a callback. :meth:`flush` reduce-scatters the filled columns into
    ``output_shard`` and then runs the queued callbacks on views of that result.
    """

    def __init__(self, shard_size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup):
        self.buffer = torch.zeros((group.size(), shard_size), dtype=dtype, device=device)
        self.group = group
        # Number of leading columns of ``buffer`` that currently hold data.
        self.offset = 0
        # Zero-argument callables, each already bound to a view of ``output_shard``.
        self.callbacks: List[Callable] = []
        # Receives this rank's reduced shard (same length as one row of ``buffer``).
        self.output_shard = torch.zeros_like(self.buffer[0])

    def flush(self) -> None:
        """Flush content of the bucket.

        Reduce-scatters ``buffer[:, :offset]`` into ``output_shard[:offset]``,
        runs every queued callback and resets the bucket. A fresh
        ``output_shard`` is allocated afterwards, so result views already handed
        to callbacks are not overwritten by the next flush.
        """
        if self.offset == 0:
            # Nothing to communicate. Callbacks can still be queued here when
            # only zero-element tensors were appended (``offset`` did not
            # advance). Their result views are empty, so run them directly
            # instead of failing an assertion and dropping them.
            for callback_fn in self.callbacks:
                callback_fn()
            self.callbacks.clear()
            return

        # reduce-scatter bucket
        if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives:
            # Flat-buffer collective on the filled columns (made contiguous first).
            dist._reduce_scatter_base(self.output_shard[:self.offset],
                                      self.buffer[:, :self.offset].contiguous(),
                                      group=self.group)
        else:
            # List-based fallback: one input tensor per rank, i.e. one row of the buffer.
            dist.reduce_scatter(self.output_shard[:self.offset],
                                list(self.buffer[:, :self.offset].unbind(0)),
                                group=self.group)

        # execute post-reduction callbacks
        for callback_fn in self.callbacks:
            callback_fn()

        # reuse input bucket but allocate a fresh output shard
        self.buffer[:, :self.offset].zero_()
        self.offset = 0
        self.callbacks.clear()
        self.output_shard = torch.zeros_like(self.buffer[0])

    def alloc(self) -> None:
        """Setup the buffers if they are not allocated.

        Regrows the storage of ``buffer`` and ``output_shard`` after :meth:`free`
        shrank it to zero. Pairing ``alloc`` with ``free`` lets callers hold the
        bucket memory only while it is needed (for example during the backward
        pass), which leaves more memory for other phases such as activations in
        the forward pass.
        """
        for tensor in [self.buffer, self.output_shard]:
            # A zero-sized storage means ``free`` released it. Regrow it in place
            # so the existing tensor object keeps using the same storage.
            # NOTE(review): the regrown memory is presumably uninitialised. That
            # looks harmless because flush only reduces columns written by
            # append. Confirm before relying on zeroed contents.
            if tensor.storage().size() == 0:
                tensor.storage().resize_(tensor.size().numel())

    def free(self) -> None:
        """Tear down the bucket by freeing the memory.

        Shrinks the storage of ``buffer`` and ``output_shard`` to zero elements
        but keeps the tensor objects and their shapes, so :meth:`alloc` can
        regrow them later. The bucket must be empty (flushed) first.
        """
        assert self.offset == 0 and self.callbacks == [], "Incorrect call of teardown"
        for tensor in [self.buffer, self.output_shard]:
            tensor.storage().resize_(0)

    def append(self, tensor_list: List[Tensor], callback_fn: Optional[Callable]):
        """Copy one tensor per rank into the next free columns of the bucket.

        Args:
            tensor_list: ``group.size()`` tensors of identical shape. Tensor ``i``
                is written into row ``i`` of ``buffer``.
            callback_fn: optional function that :meth:`flush` calls with a view
                of this rank's reduced result, shaped like ``tensor_list[0]``.

        The caller must ensure the data fits, i.e.
        ``tensor_list[0].numel() <= buffer.size(1) - offset``.
        """
        # copy data from input_list into bucket
        tensor_size = tensor_list[0].numel()
        # Row i of ``stacked_input`` is tensor_list[i] flattened.
        stacked_input = torch.stack(tensor_list).view(self.group.size(), tensor_size)
        offset = self.offset
        self.buffer[:, offset:offset + tensor_size].copy_(stacked_input)
        self.offset += tensor_size

        # callback will be given the reduced result
        if callback_fn is not None:
            result_view = self.output_shard[offset:offset + tensor_size].view_as(tensor_list[0])
            self.callbacks.append(functools.partial(callback_fn, result_view))
|
||
|
|
||
|
|
||
|
class ReduceScatterBucketer:
|
||
|
"""
|
||
|
Helper for bucketing multiple reduce-scatter operations on small tensors
|
||
|
into larger reduce-scatter ops to improve communication efficiency.
|
||
|
|
||
|
Usage::
|
||
|
|
||
|
bucketer = ReduceScatterBucketer()
|
||
|
bucketer.reduce_scatter_async(
|
||
|
small_tensors, callback_fn=lambda result: print("small")
|
||
|
)
|
||
|
bucketer.reduce_scatter_async(
|
||
|
big_tensors, callback_fn=lambda result: print("big")
|
||
|
)
|
||
|
bucketer.reduce_scatter_async(
|
||
|
more_small_tensors, callback_fn=lambda result: print("small2")
|
||
|
)
|
||
|
bucketer.flush() # callbacks only guaranteed to be called after flush()
|
||
|
# Example output (note that it is out of order, due to bucketing):
|
||
|
# big
|
||
|
# small
|
||
|
# small2
|
||
|
|
||
|
Args:
|
||
|
bucket_size_mb (int, Optional): bucket size for communicating. Buckets
|
||
|
are sub-divided based on world_size. Values <= 0 disable bucketing.
|
||
|
"""
|
||
|
|
||
|
def __init__(self, bucket_size_mb: int = 25):
|
||
|
self.bucket_size_mb = bucket_size_mb
|
||
|
self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {}
|
||
|
|
||
|
@torch.no_grad()
|
||
|
def reduce_scatter_async(
|
||
|
self,
|
||
|
input_list: List[Tensor],
|
||
|
group: ProcessGroup,
|
||
|
callback_fn: Optional[Callable] = None,
|
||
|
) -> None:
|
||
|
"""
|
||
|
Reduce-scatter a list of tensors asynchronously, so smaller reductions
|
||
|
can be bucketed together. The given callback (``callback_fn``) will be
|
||
|
called with the reduced result at some later time. Call ``flush()`` to
|
||
|
force all queued ops and callbacks to be executed.
|
||
|
|
||
|
Note that large inputs will be reduced immediately, and this function
|
||
|
may also flush the relevant bucket to make room for ``input_list``.
|
||
|
|
||
|
Args:
|
||
|
input_list (List[Tensor]): list of tensors to reduce-scatter. List
|
||
|
should contain ``group.size()`` tensors and each tensor should
|
||
|
have identical shape, dtype and device.
|
||
|
group (ProcessGroup): process group for reduction
|
||
|
callback_fn (Callable, Optional): callback function to call after
|
||
|
the reduction executes. Function will be called with a single
|
||
|
argument corresponding to the reduced result.
|
||
|
"""
|
||
|
world_size = group.size()
|
||
|
|
||
|
assert (len(input_list) == world_size
|
||
|
), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})"
|
||
|
|
||
|
first_input = input_list[0]
|
||
|
first_input_size = first_input.numel()
|
||
|
|
||
|
bucket_shard_size = self._get_shard_size(first_input.element_size(), world_size)
|
||
|
if first_input_size > bucket_shard_size:
|
||
|
# TODO: investigate how to avoid using torch.cat (because it seems to be slow for CPU tensors)
|
||
|
# input is too big to fit in the bucket, reduce-scatter directly
|
||
|
output = torch.zeros_like(input_list[0])
|
||
|
if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives:
|
||
|
input_flattened = torch.cat(input_list)
|
||
|
dist._reduce_scatter_base(output, input_flattened, group=group)
|
||
|
else:
|
||
|
# fallback
|
||
|
dist.reduce_scatter(output, input_list, group=group)
|
||
|
if callback_fn is not None:
|
||
|
callback_fn(output)
|
||
|
return
|
||
|
|
||
|
bucket = self._get_bucket(first_input, group)
|
||
|
if first_input_size > bucket.buffer.size(1) - bucket.offset:
|
||
|
# not enough space remaining in bucket, flush it now
|
||
|
bucket.flush()
|
||
|
bucket.append(input_list, callback_fn)
|
||
|
|
||
|
@torch.no_grad()
|
||
|
def flush(self) -> None:
|
||
|
"""Reduce-scatter any partial buckets."""
|
||
|
for bucket in self.buckets.values():
|
||
|
bucket.flush()
|
||
|
|
||
|
@torch.no_grad()
|
||
|
def free(self) -> None:
|
||
|
"""Free buffers from all buckets."""
|
||
|
for bucket in self.buckets.values():
|
||
|
bucket.free()
|
||
|
|
||
|
@functools.lru_cache()
|
||
|
def _get_shard_size(self, element_size: int, num_shards: int) -> int:
|
||
|
if self.bucket_size_mb <= 0: # Values <= 0 disable bucketing.
|
||
|
return 0
|
||
|
MB = 1024 * 1024
|
||
|
bucket_size = self.bucket_size_mb * MB / element_size
|
||
|
return int(bucket_size // num_shards)
|
||
|
|
||
|
def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket:
|
||
|
key = (tensor.dtype, tensor.device, group)
|
||
|
if key not in self.buckets:
|
||
|
# buckets are divided into world_size pieces, bucket.data shaped (world_size, shard_size)
|
||
|
world_size = group.size()
|
||
|
shard_size = self._get_shard_size(tensor.element_size(), world_size)
|
||
|
self.buckets[key] = Bucket(shard_size, tensor.dtype, tensor.device, group)
|
||
|
self.buckets[key].alloc()
|
||
|
return self.buckets[key]
|