from typing import List

from torch import Tensor

from .base_store import BaseStore


class GradientStore(BaseStore):
    def __init__(self, *args, partition_grad: bool = False):
        super().__init__(*args)
        """
        self._grads_of_params maps a parameter to its gradient slices

        data structure:
        {
            group_id: {
                param_id: [grad_rank0, grad_rank1, ...]
            }
        }
        """
        self._grads_of_params = dict()
        # for zero2 (partition_grad=True), only the local rank's slice is stored,
        # i.e. `param_id: [grad_local_rank]`, so index 0 is the working slice
        self._working_index = 0 if partition_grad else self._local_rank

        self.grad_to_param_mapping = dict()
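        # Illustrative sketch (assumes two ranks and partition_grad=False): once both
        # ranks' slices of a parameter with id `pid` in group 0 have been appended,
        #   self._grads_of_params      == {0: {pid: [grad_rank0, grad_rank1]}}
        #   self.grad_to_param_mapping == {id(grad_rank0): pid, id(grad_rank1): pid}
        # and self._working_index selects the slice this rank works on.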

    def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List:
        """Return list of gradient slices of a specific parameter

        Args:
            group_id (int): The index of a parameter group
            param_id (int): The id of a parameter

        Returns:
            List: the list of gradient slices of a parameter.
        """

        if group_id in self._grads_of_params:
            if param_id in self._grads_of_params[group_id]:
                return self._grads_of_params[group_id][param_id]
        # the param has no grad, for instance, in layer drop
        return []

    def append_gradients_by_param_id(self, grad: Tensor, group_id: int, param_id: int):
        """Append a gradient slice to the parameter's gradient slice list

        Args:
            grad (Tensor): The gradient slice to append to the list
            group_id (int): The index of a parameter group
            param_id (int): The id of a parameter
        """

        if group_id not in self._grads_of_params:
            self._grads_of_params[group_id] = dict()
        if param_id not in self._grads_of_params[group_id]:
            self._grads_of_params[group_id][param_id] = [grad]
        else:
            self._grads_of_params[group_id][param_id].append(grad)

        self.grad_to_param_mapping[id(grad)] = param_id
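        # Note: in the unpartitioned case (ZeRO-1) the slice list documented in
        # __init__ is indexed by rank, so `self._working_index = self._local_rank`
        # picks out the slice appended for this rank.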

    def add_gradients_by_param_id(self, grad: Tensor, grad_idx: int, group_id: int, param_id: int):
        """Add a gradient slice onto an existing slice of the parameter's gradient

        Used when no_sync is not activated.

        Args:
            grad (Tensor): The gradient to add onto the existing slice
            grad_idx (int): The index of the existing slice
            group_id (int): The index of a parameter group
            param_id (int): The id of a parameter
        """

        self._grads_of_params[group_id][param_id][grad_idx].add_(grad)
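        # Unlike append_gradients_by_param_id, this accumulates in place:
        # roughly `slices[grad_idx] += grad` rather than growing the slice list.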

    def get_working_grads_by_group_id(self, group_id: int) -> List:
        """Return the list of working gradient slices in the group

        Args:
            group_id (int): The index of a parameter group

        Returns:
            List: the list of working gradient slices in the group
        """

        grad_list = []
        for param_grads in self._grads_of_params[group_id].values():
            grad_list.append(param_grads[self._working_index])

        return grad_list

    def get_working_grad_by_param_id(self, param_id: int) -> Tensor:
        """Return the working gradient for the specified parameter.

        Args:
            param_id (int): The id of the parameter.

        Returns:
            Tensor: The working gradient slice for the specified param_id.
        """

        for group in self._grads_of_params.values():
            if param_id in group.keys():
                return group[param_id][self._working_index]

        raise KeyError(f"Working gradient for param_id {param_id} not found.")

    def reset_grads_by_group_id(self, group_id: int):
        """Drop all gradient slices recorded for the given parameter group."""
        self._grads_of_params[group_id] = dict()

    def reset_all_gradients(self):
        """Drop all recorded gradient slices for every parameter group."""
        self._grads_of_params = dict()

    def get_param_id_for_grad(self, grad: Tensor) -> int:
        """Return the id of the parameter that the gradient slice belongs to

        Args:
            grad (Tensor): the gradient slice

        Returns:
            int: the id of the parameter that the gradient slice belongs to
        """

        return self.grad_to_param_mapping[id(grad)]
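

# Usage sketch (illustrative only): `process_group`, `grad_slice`, and `param_id`
# are assumptions standing in for whatever the optimizer passes; BaseStore is
# expected to derive `self._local_rank` from the process group.
#
#   store = GradientStore(process_group, partition_grad=False)
#   store.append_gradients_by_param_id(grad_slice, group_id=0, param_id=param_id)
#   working = store.get_working_grads_by_group_id(0)   # this rank's slice per param
#   assert store.get_param_id_for_grad(grad_slice) == param_id
#   store.reset_grads_by_group_id(0)                    # clear after the optimizer step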