mirror of https://github.com/hpcaitech/ColossalAI
89 lines
2.8 KiB
Python
89 lines
2.8 KiB
Python
from typing import List
|
|
|
|
from torch import Tensor
|
|
|
|
from .base_store import BaseStore
|
|
|
|
|
|
class GradientStore(BaseStore):
    """
    Bookkeeping store for averaged gradients and gradient-accumulation objects.

    Averaged gradients are kept per parameter group in a dict that maps a group index to a
    list of tensors. :class:`AccumulateGrad` objects are retained in a list so that the
    reduction hooks attached to them are not lost.
    """

    def __init__(self, *args):
        # All positional arguments are forwarded unchanged to BaseStore.
        super().__init__(*args)

        # bookkeeping data structures: group_id -> list of averaged gradient tensors
        self._averaged_gradients = {}

        # for backward reduction hooks; references are held here so the hooked
        # AccumulateGrad objects stay alive (see append_accumulate_grad_object)
        self._grad_acc_objs = []

    def append_accumulate_grad_object(self, obj):
        """
        Keep :class:`AccumulateGrad` objects. If these objects are not kept, reduction hooks may not
        be attached successfully.

        :param obj: An object of :class:`AccumulateGrad` class
        :type obj: :class:`AccumulateGrad`
        """
        self._grad_acc_objs.append(obj)

    def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]:
        """
        Return average gradients of a parameter group.

        If the group has no entry yet, an empty list is created and stored first. The list
        returned is the one held by the store (not a copy), so mutating it mutates the store.

        :param group_id: The index of parameter group
        :type group_id: int

        :return: Return the list of averaged gradients of a parameter group. Each element is a gradient, not a parameter.
        :rtype: List[torch.Tensor]
        """
        return self._averaged_gradients.setdefault(group_id, [])

    def append_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None:
        """
        Append an average gradient to the list of averaged gradients of a parameter group.

        The group's list is created on first use.

        :param group_id: The index of a parameter group
        :param tensor: A :class:`torch.Tensor` object
        :type group_id: int
        :type tensor: torch.Tensor
        """
        self._averaged_gradients.setdefault(group_id, []).append(tensor)

    def add_average_gradient_by_group(self, group_id: int, tensor_idx: int, tensor: Tensor) -> None:
        """
        Add an average gradient to the list of averaged gradients of a parameter group.

        The addition is done in place (``Tensor.add_``) on the tensor already stored at
        ``tensor_idx``; no new list entry is created.

        :param group_id: The index of a parameter group
        :param tensor_idx: The index of a tensor in the list of averaged gradients
        :param tensor: A :class:`torch.Tensor` object
        :type group_id: int
        :type tensor_idx: int
        :type tensor: torch.Tensor
        :raises KeyError: If ``group_id`` has no entry in the store.
        :raises IndexError: If ``tensor_idx`` is out of range for the group's list.
        """
        self._averaged_gradients[group_id][tensor_idx].add_(tensor)

    def reset_average_gradients_by_group(self, group_id: int) -> None:
        """
        Reset the bookkeeping data structure for averaged gradients of one group to an empty list.

        :param group_id: The index of a parameter group
        :type group_id: int
        """
        self._averaged_gradients[group_id] = []

    def reset_all_average_gradients(self) -> None:
        """
        Reset the bookkeeping data structure for averaged gradients of all groups to an empty dict.

        Stored :class:`AccumulateGrad` objects are not affected.
        """
        self._averaged_gradients = {}
|