# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools
from typing import Callable, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed import ProcessGroup


class Bucket:
    """Flat staging buffer that batches same-dtype/device tensors into one all-reduce.

    Tensors are copied back to back into ``buffer`` by :meth:`append`. A single
    ``dist.all_reduce`` over the filled prefix is issued by :meth:`flush`, after
    which every registered callback receives its slice of the reduced data.
    """

    def __init__(self, size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup):
        self.buffer = torch.zeros(size, dtype=dtype, device=device)
        self.group = group
        # Number of leading elements of ``buffer`` filled by pending tensors.
        self.offset = 0
        # Zero-argument closures to run after the reduction (built in ``append``).
        self.callbacks: List[Callable] = []

    def flush(self) -> None:
        """All-reduce the pending contents of the bucket and run the callbacks.

        This is a no-op when nothing is pending. Callbacks registered for
        zero-element tensors are still invoked, even though there is no data to
        reduce. Previously they tripped an assertion and were silently lost.
        """
        if self.offset == 0 and not self.callbacks:
            return

        reduced_data = self.offset > 0
        if reduced_data:
            # all-reduce (in place) only the filled prefix of the bucket
            dist.all_reduce(self.buffer[:self.offset], group=self.group)

        # execute post-reduction callbacks
        for callback_fn in self.callbacks:
            callback_fn()

        self.offset = 0
        self.callbacks.clear()
        if reduced_data:
            # Callbacks were handed views into the old buffer and may retain them,
            # so subsequent appends must write into fresh memory.
            self.buffer = torch.zeros_like(self.buffer)

    def alloc(self) -> None:
        """Re-materialize the buffer storage if :meth:`free` released it.

        The restored memory is not re-zeroed. This is harmless because only the
        prefix written by :meth:`append` is ever reduced.
        """
        if self.buffer.storage().size() == 0:
            self.buffer.storage().resize_(self.buffer.numel())

    def free(self) -> None:
        """Release the buffer's storage; the bucket must be empty (already flushed)."""
        assert self.offset == 0 and self.callbacks == [], "Incorrect call of teardown"
        self.buffer.storage().resize_(0)

    def append(self, tensor: Tensor, callback_fn: Optional[Callable]):
        """Copy ``tensor`` into the next free slice of the bucket.

        Args:
            tensor: data to reduce. The caller must ensure it fits in ``avail_size``.
            callback_fn: if not ``None``, called at flush time with a view of the
                reduced data, reshaped to ``tensor.shape``.
        """
        tensor_size = tensor.numel()
        start = self.offset
        end = start + tensor_size
        self.buffer[start:end].copy_(tensor.flatten())
        self.offset = end

        # callback will be given the reduced result
        if callback_fn is not None:
            result_view = self.buffer[start:end].view(tensor.shape)
            self.callbacks.append(functools.partial(callback_fn, result_view))

    @property
    def avail_size(self) -> int:
        """Number of elements that can still be appended before a flush is needed."""
        return self.buffer.size(0) - self.offset


class Reducer:

    def __init__(self, bucket_size_mb: int = 25):
        self.bucket_size_mb = bucket_size_mb
        self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {}

    @torch.no_grad()
    def all_reduce_async(
        self,
        tensor: Tensor,
        group: ProcessGroup,
        callback_fn: Optional[Callable] = None,
    ) -> None:
        bucket_size = self._get_bucket_size(tensor.element_size())

        if tensor.numel() >= bucket_size:
            dist.all_reduce(tensor, group=group)
            if callback_fn is not None:
                callback_fn(tensor)
            return

        bucket = self._get_bucket(tensor, group)
        if tensor.numel() > bucket.avail_size:
            # not enough space remaining in bucket, flush it now
            bucket.flush()
        bucket.append(tensor, callback_fn)

    @torch.no_grad()
    def flush(self) -> None:
        for bucket in self.buckets.values():
            bucket.flush()

    @torch.no_grad()
    def free(self) -> None:
        for bucket in self.buckets.values():
            bucket.free()

    @functools.lru_cache()
    def _get_bucket_size(self, element_size: int) -> int:
        if self.bucket_size_mb <= 0:    # Values <= 0 disable bucketing.
            return 0
        MB = 1024 * 1024
        bucket_size = self.bucket_size_mb * MB / element_size
        return int(bucket_size)

    def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket:
        key = (tensor.dtype, tensor.device, group)
        if key not in self.buckets:
            bucket_size = self._get_bucket_size(tensor.element_size())
            self.buckets[key] = Bucket(bucket_size, tensor.dtype, tensor.device, group)
        self.buckets[key].alloc()
        return self.buckets[key]