from typing import List

from torch import Tensor

from .base_store import BaseStore


class GradientStore(BaseStore):
    """
    Bookkeeping store for per-parameter-group averaged gradients.

    It also holds references to :class:`AccumulateGrad` objects used for backward
    reduction hooks (see :meth:`append_accumulate_grad_object`).
    """

    def __init__(self, *args) -> None:
        super().__init__(*args)
        # bookkeeping data structures: group_id -> list of averaged gradient tensors
        self._averaged_gradients = {}

        # for backward reduction hooks; references are held here so the hooks
        # stay attached (see append_accumulate_grad_object)
        self._grad_acc_objs = []

    def append_accumulate_grad_object(self, obj) -> None:
        """
        Keep :class:`AccumulateGrad` objects. If these objects are not kept, reduction hooks may
        not be attached successfully.

        :param obj: An object of :class:`AccumulateGrad` class
        :type obj: :class:`AccumulateGrad`
        """
        self._grad_acc_objs.append(obj)

    def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]:
        """
        Return the averaged gradients of a parameter group.

        If nothing has been recorded for ``group_id`` yet, an empty list is registered for it
        first. The returned list is the one stored in this object, so mutating it mutates the
        store.

        :param group_id: The index of parameter group
        :type group_id: int

        :return: Return the list of averaged gradients of a parameter group. Each element is a
            gradient, not a parameter.
        :rtype: List[torch.Tensor]
        """
        return self._averaged_gradients.setdefault(group_id, [])

    def append_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None:
        """
        Append an average gradient to the list of averaged gradients of a parameter group.

        The list for ``group_id`` is created if it does not exist yet.

        :param group_id: The index of a parameter group
        :param tensor: A :class:`torch.Tensor` object
        :type group_id: int
        :type tensor: torch.Tensor
        """
        self._averaged_gradients.setdefault(group_id, []).append(tensor)

    def add_average_gradient_by_group(
        self, group_id: int, tensor_idx: int, tensor: Tensor
    ) -> None:
        """
        Accumulate ``tensor`` in place into an already-recorded averaged gradient of a
        parameter group.

        Unlike :meth:`append_average_gradient_by_group`, this does not grow the list. It calls
        ``add_`` on the existing tensor at position ``tensor_idx``.

        :param group_id: The index of a parameter group
        :param tensor_idx: The index of a tensor in the list of averaged gradients
        :param tensor: A :class:`torch.Tensor` object added into the stored tensor
        :type group_id: int
        :type tensor_idx: int
        :type tensor: torch.Tensor
        :raises KeyError: if no averaged gradients have been recorded for ``group_id``
        :raises IndexError: if ``tensor_idx`` is out of range for that group's list
        """
        self._averaged_gradients[group_id][tensor_idx].add_(tensor)

    def reset_average_gradients_by_group(self, group_id: int) -> None:
        """
        Reset the bookkeeping data structure for averaged gradients to an empty list.

        A new list is bound for ``group_id``. Lists previously returned by
        :meth:`get_averaged_gradients_by_group` are not cleared; they are simply no longer
        referenced by the store.

        :param group_id: The index of a parameter group
        :type group_id: int
        """
        self._averaged_gradients[group_id] = []