mirror of https://github.com/hpcaitech/ColossalAI
67 lines
2.0 KiB
Python
67 lines
2.0 KiB
Python
|
from typing import List
|
||
|
from torch import Tensor
|
||
|
from .base_store import BaseStore
|
||
|
|
||
|
|
||
|
class GradientStore(BaseStore):
    """
    Bookkeeping store for averaged gradients, indexed by parameter group, and for the
    :class:`AccumulateGrad` objects that must be kept for backward reduction hooks.

    All positional arguments are forwarded unchanged to :class:`BaseStore`.
    """

    def __init__(self, *args):
        super().__init__(*args)

        # bookkeeping data structures
        # maps a parameter group index to the list of averaged gradients of that group
        self._averaged_gradients = {}

        # for backward reduction hooks
        self._grad_acc_objs = []

    def add_accumulate_grad_object(self, obj):
        """
        Keep :class:`AccumulateGrad` objects. If these objects are not kept, reduction hooks may not
        be attached successfully.

        :param obj: An object of :class:`AccumulateGrad` class
        :type obj: :class:`AccumulateGrad`
        """

        self._grad_acc_objs.append(obj)

    def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]:
        """
        Return average gradients of a parameter group

        A group that has no recorded gradients yet yields an empty list (which is also
        registered for that group) instead of raising :class:`KeyError`, matching the
        empty state produced by :meth:`reset_average_gradients_by_group`.

        :param group_id: The index of parameter group
        :type group_id: int

        :return: Return the list of averaged gradients of a parameter group. Each element is a gradient, not a parameter.
        :rtype: List[torch.Tensor]
        """

        # setdefault stores the empty list, so the returned object is the same list
        # that later add_average_gradient_by_group calls will append to.
        return self._averaged_gradients.setdefault(group_id, [])

    def add_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None:
        """
        Append an average gradient to the list of averaged gradients of a parameter group

        The list for ``group_id`` is created on first use.

        :param group_id: The index of a parameter group
        :param tensor: A :class:`torch.Tensor` object
        :type group_id: int
        :type tensor: torch.Tensor

        """

        self._averaged_gradients.setdefault(group_id, []).append(tensor)

    def reset_average_gradients_by_group(self, group_id: int) -> None:
        """
        Reset the bookkeeping data structure for averaged gradients to an empty list

        A fresh list is bound to ``group_id``; lists previously returned for this group
        are not modified.

        :param group_id: The index of a parameter group
        :type group_id: int
        """

        self._averaged_gradients[group_id] = []
|
||
|
|
||
|
|