#!/usr/bin/env python

import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER
from ._base_gradient_handler import BaseGradientHandler
from ...context.parallel_mode import ParallelMode


@GRADIENT_HANDLER.register_module
class DataParallelGradientHandler(BaseGradientHandler):
    """A helper class to handle all-reduce operations in a data parallel group.
    A all-reduce collective communication will be operated in 
    :func:`handle_gradient` among a data parallel group.
    For better performance, it bucketizes the gradients of all parameters that are 
    the same type to improve the efficiency of communication.
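
    Example:
        A minimal usage sketch (illustrative assumption: the handler is
        constructed with the model and optimizer, as in
        :class:`BaseGradientHandler`, after the parallel context has been
        initialized)::

            handler = DataParallelGradientHandler(model, optimizer)
            loss.backward()            # compute local gradients
            handler.handle_gradient()  # average them across the data parallel group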
    """

    def handle_gradient(self):
        """A method running a all-reduce operation in a data parallel group.
        """
        # TODO: add memory buffer
        if gpc.data_parallel_size > 1:
            # bucketize and all-reduce
            buckets = {}
            # Pack parameters into buckets keyed by parameter data type.
            for param in self._model.parameters():
                if param.requires_grad and param.grad is not None:
                    tp = param.data.type()
                    buckets.setdefault(tp, []).append(param)
                    # param.main_grad = param.grad

            # For each bucket, all-reduce and copy all-reduced grads.
            for bucket in buckets.values():
                grads = [param.grad.data for param in bucket]
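                # Flatten this bucket's grads into one contiguous buffer so the
                # whole bucket is reduced with a single collective call, and
                # pre-divide by the data parallel world size so the summed
                # result equals the averaged gradient.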
                coalesced = _flatten_dense_tensors(grads)
                coalesced /= gpc.get_world_size(ParallelMode.DATA)

                dist.all_reduce(
                    coalesced, group=gpc.get_group(ParallelMode.DATA))
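                # Unflatten the reduced buffer back into per-tensor views and
                # copy the synchronized values into the original grad tensors.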
                for buf, synced in zip(grads, _unflatten_dense_tensors(
                        coalesced, grads)):
                    buf.copy_(synced)