ColossalAI/colossalai/zero/sharded_optim/bookkeeping/gradient_store.py

89 lines
2.8 KiB
Python
Raw Normal View History

from typing import List
from torch import Tensor
from .base_store import BaseStore
class GradientStore(BaseStore):
    """
    Bookkeeping store for averaged gradients, organised per parameter group.

    Averaged gradients are kept in a dict mapping a parameter-group index to a list of
    gradient tensors. The store also keeps references to :class:`AccumulateGrad` objects,
    because reduction hooks may not be attached successfully if those objects are not kept.
    """

    def __init__(self, *args):
        # all positional arguments are forwarded unchanged to BaseStore
        super().__init__(*args)
        # bookkeeping data structures: group_id -> list of averaged gradient tensors
        self._averaged_gradients = {}

        # for backward reduction hooks: holding these references is what keeps the
        # AccumulateGrad objects around (see append_accumulate_grad_object)
        self._grad_acc_objs = []

    def append_accumulate_grad_object(self, obj) -> None:
        """
        Keep :class:`AccumulateGrad` objects. If these objects are not kept, reduction hooks may not
        be attached successfully.

        :param obj: An object of :class:`AccumulateGrad` class
        :type obj: :class:`AccumulateGrad`
        """
        self._grad_acc_objs.append(obj)

    def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]:
        """
        Return the averaged gradients of a parameter group.

        If the group has no entry yet, an empty list is registered for it first. The returned
        list is the store's own list (not a copy), so mutating it mutates the bookkeeping.

        :param group_id: The index of parameter group
        :type group_id: int
        :return: Return the list of averaged gradients of a parameter group. Each element is a gradient, not a parameter.
        :rtype: List[torch.Tensor]
        """
        return self._averaged_gradients.setdefault(group_id, [])

    def append_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None:
        """
        Append an average gradient to the list of averaged gradients of a parameter group.

        The group's list is created on first use.

        :param group_id: The index of a parameter group
        :param tensor: A :class:`torch.Tensor` object
        :type group_id: int
        :type tensor: torch.Tensor
        """
        self._averaged_gradients.setdefault(group_id, []).append(tensor)

    def add_average_gradient_by_group(self, group_id: int, tensor_idx: int, tensor: Tensor) -> None:
        """
        Add ``tensor`` in place to an already-stored averaged gradient of a parameter group.

        The stored tensor at ``tensor_idx`` is updated with ``add_``; no new list entry is
        created.

        :param group_id: The index of a parameter group
        :param tensor_idx: The index of a tensor in the list of averaged gradients
        :param tensor: A :class:`torch.Tensor` object
        :type group_id: int
        :type tensor_idx: int
        :type tensor: torch.Tensor
        :raises KeyError: if no averaged gradients have been registered for ``group_id``
        :raises IndexError: if ``tensor_idx`` is out of range for that group's list
        """
        self._averaged_gradients[group_id][tensor_idx].add_(tensor)

    def reset_average_gradients_by_group(self, group_id: int) -> None:
        """
        Reset the bookkeeping data structure for averaged gradients of one group to an empty list.

        The group's entry is rebound to a fresh list; a list previously obtained from
        :meth:`get_averaged_gradients_by_group` is left as-is and is no longer tracked.

        :param group_id: The index of a parameter group
        :type group_id: int
        """
        self._averaged_gradients[group_id] = []

    def reset_all_average_gradients(self) -> None:
        """
        Drop the averaged gradients of every parameter group.

        The whole group_id -> gradients mapping is replaced with a new empty dict.
        """
        self._averaged_gradients = {}