2022-11-11 01:26:40 +00:00
|
|
|
from typing import List
|
|
|
|
|
|
|
|
from torch import Tensor
|
2023-06-30 07:30:50 +00:00
|
|
|
from torch._utils import _flatten_dense_tensors
|
2022-11-11 01:26:40 +00:00
|
|
|
|
|
|
|
from .base_store import BaseStore
|
|
|
|
|
|
|
|
|
|
|
|
class GradientStore(BaseStore):
    """Bookkeeping for the gradient slices recorded for each parameter.

    Slices are kept per parameter group and per parameter id::

        self._grads_of_params = {
            group_id: {
                param_id: [grad_rank0, grad_rank1, ...],
            },
        }

    With ``partition_grad=True`` (ZeRO-2) each parameter's list is
    ``[grad_local_rank]``, so the working slice is at index 0. Otherwise the
    working slice is taken at index ``self._local_rank``. That attribute is not
    set in this class; presumably ``BaseStore.__init__`` provides it.
    """

    def __init__(self, *args, partition_grad: bool = False):
        super().__init__(*args)
        # Maps each parameter to its gradient slices:
        #   {group_id: {param_id: [grad_rank0, grad_rank1, ...]}}
        self._grads_of_params = dict()
        # For ZeRO-2 (partition_grad=True) each list is `[grad_local_rank]`,
        # so the slice this rank works on is always at index 0.
        self._working_index = 0 if partition_grad else self._local_rank

    def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List:
        """Return the list of gradient slices of a specific parameter.

        Args:
            group_id (int): The index of a parameter group
            param_id (int): The id of a parameter

        Returns:
            List: the list of gradient slices of the parameter. This is the
            list object held by the store, so in-place changes to it are
            visible to the store. A new empty list is returned when the
            parameter (or its whole group) has no gradient, for instance in
            layer drop.
        """
        return self._grads_of_params.get(group_id, {}).get(param_id, [])

    def append_gradients_by_param_id(self, grad: Tensor, group_id: int, param_id: int):
        """Append a gradient slice to the parameter's gradient slice list.

        The group and parameter entries are created on first use.

        Args:
            grad (Tensor): The gradient slice to append to list
            group_id (int): The index of a parameter group
            param_id (int): The id of a parameter
        """
        group_grads = self._grads_of_params.setdefault(group_id, dict())
        group_grads.setdefault(param_id, []).append(grad)

    def add_gradients_by_param_id(self, grad: Tensor, grad_idx: int, group_id: int, param_id: int):
        """For old gradient accumulation, not in use now.

        Add a gradient slice in place onto an existing slice of the
        parameter's gradient. The group, parameter and slice index must
        already exist; otherwise ``KeyError`` / ``IndexError`` propagates.

        Args:
            grad (Tensor): The gradient slice to add onto the existing slice
            grad_idx (int): The index of the existing slice
            group_id (int): The index of a parameter group
            param_id (int): The id of a parameter
        """
        self._grads_of_params[group_id][param_id][grad_idx].add_(grad)

    def get_working_grads_by_group_id(self, group_id: int) -> List:
        """Return the list of working gradient slices in the group.

        For every parameter recorded in the group, this picks the slice at
        ``self._working_index``.

        Args:
            group_id (int): The index of a parameter group

        Returns:
            List: the working gradient slices in the group, in parameter
            insertion order. The list is empty if no gradient has been
            recorded for the group.
        """
        # A group with no recorded gradients (e.g. all of its parameters were
        # unused) has no entry. Return an empty list, consistent with
        # get_partitioned_gradients_by_param_id, instead of raising KeyError.
        group_grads = self._grads_of_params.get(group_id, {})
        return [param_grads[self._working_index] for param_grads in group_grads.values()]

    def reset_grads_by_group_id(self, group_id: int):
        """Drop all gradient slices recorded for ``group_id``.

        The group is left with an empty mapping.
        """
        self._grads_of_params[group_id] = dict()

    def reset_all_gradients(self):
        """Drop the gradient slices of every group."""
        self._grads_of_params = dict()
|