from typing import List

from torch import Tensor

from .base_store import BaseStore


class GradientStore(BaseStore):
    """
    Bookkeeping store for averaged gradients, grouped by parameter-group index,
    and for the :class:`AccumulateGrad` objects used by backward reduction hooks.

    Averaged gradients are kept in a dict that maps a parameter-group index to a
    list of gradient tensors.
    """

    def __init__(self, *args):
        # All positional arguments are forwarded unchanged to BaseStore.
        super().__init__(*args)

        # group_id -> list of averaged gradient tensors for that group
        self._averaged_gradients = {}

        # AccumulateGrad objects are kept here; if they are not kept, reduction
        # hooks may not be attached successfully.
        self._grad_acc_objs = []

    def append_accumulate_grad_object(self, obj) -> None:
        """
        Keep :class:`AccumulateGrad` objects. If these objects are not kept, reduction hooks may not
        be attached successfully.

        :param obj: An object of :class:`AccumulateGrad` class
        :type obj: :class:`AccumulateGrad`
        """
        self._grad_acc_objs.append(obj)

    def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]:
        """
        Return the averaged gradients of a parameter group.

        If the group has no entry yet, an empty list is registered for it first.
        The returned list is the stored list itself (not a copy), so appending to
        or modifying it changes the store.

        :param group_id: The index of parameter group
        :type group_id: int

        :return: The list of averaged gradients of the parameter group. Each element is a gradient, not a parameter.
        :rtype: List[torch.Tensor]
        """
        return self._averaged_gradients.setdefault(group_id, [])

    def append_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None:
        """
        Append an averaged gradient to the list of a parameter group.

        The group's list is created if it does not exist yet. The tensor object
        itself is stored (no copy is made).

        :param group_id: The index of a parameter group
        :param tensor: A :class:`torch.Tensor` object
        :type group_id: int
        :type tensor: torch.Tensor
        """
        self._averaged_gradients.setdefault(group_id, []).append(tensor)

    def add_average_gradient_by_group(self, group_id: int, tensor_idx: int, tensor: Tensor) -> None:
        """
        Accumulate ``tensor`` into an already-stored averaged gradient, in place.

        The stored tensor at position ``tensor_idx`` of the group's list is
        updated with ``Tensor.add_``; no new list entry is created.

        :param group_id: The index of a parameter group
        :param tensor_idx: The index of a tensor in the list of averaged gradients
        :param tensor: A :class:`torch.Tensor` object
        :type group_id: int
        :type tensor_idx: int
        :type tensor: torch.Tensor

        :raises KeyError: If ``group_id`` has no entry in the store.
        :raises IndexError: If ``tensor_idx`` is out of range for the group's list.
        """
        self._averaged_gradients[group_id][tensor_idx].add_(tensor)

    def reset_average_gradients_by_group(self, group_id: int) -> None:
        """
        Reset the averaged gradients of one parameter group to an empty list.

        A new list is bound to the group; a list previously returned by
        :meth:`get_averaged_gradients_by_group` is not cleared and keeps its
        old contents.

        :param group_id: The index of a parameter group
        :type group_id: int
        """
        self._averaged_gradients[group_id] = []

    def reset_all_average_gradients(self) -> None:
        """
        Drop the averaged gradients of every parameter group.

        The bookkeeping dict is replaced by a new empty dict. Lists previously
        returned by :meth:`get_averaged_gradients_by_group` are not cleared.
        """
        self._averaged_gradients = {}