from typing import Iterable

import torch.distributed as dist
import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def bucket_allreduce(param_list: Iterable[nn.Parameter], group=None):
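    """All-reduce the gradients of ``param_list`` in type buckets.

    Gradients are grouped by tensor type, flattened into one contiguous
    buffer per bucket, averaged across all ranks in ``group``, and copied
    back into each ``param.grad`` in place.
    """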
    # get communication world size
    comm_size = dist.get_world_size(group)
    # Group parameters that have gradients into buckets keyed by tensor type,
    # so each all-reduce coalesces tensors of a single type.
    buckets = {}
    for param in param_list:
        if param.requires_grad and param.grad is not None:
            tp = param.data.type()
            buckets.setdefault(tp, []).append(param)

    # For each bucket, flatten the grads into one buffer, average it across
    # ranks, and copy the synced values back into the original grad tensors.
    for tp in buckets:
        bucket = buckets[tp]
        grads = [param.grad.data for param in bucket]
        coalesced = _flatten_dense_tensors(grads)
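        # Pre-divide by the world size so the sum all-reduce yields an average.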
        coalesced /= comm_size

        dist.all_reduce(coalesced, group=group)
        for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
            buf.copy_(synced)
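

# A minimal single-process sanity check, kept as a sketch: it assumes the
# "gloo" backend is available and that port 29500 is free on localhost.
# These choices are illustrative, not part of the utility above.
if __name__ == "__main__":
    import os

    import torch

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    layer = nn.Linear(4, 4)
    layer(torch.randn(2, 4)).sum().backward()
    grad_before = layer.weight.grad.clone()

    bucket_allreduce(layer.parameters())

    # With a single rank, averaging over world_size == 1 leaves grads unchanged.
    assert torch.allclose(layer.weight.grad, grad_before)
    dist.destroy_process_group()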