Refactor method of grad store (#2687)

pull/2738/head^2
YH 2023-02-15 23:27:58 +09:00 committed by GitHub
parent 43dffdaba5
commit ae86a29e23
2 changed files with 31 additions and 10 deletions


@@ -6,7 +6,6 @@ from .base_store import BaseStore
 class GradientStore(BaseStore):

     def __init__(self, *args):
         super().__init__(*args)
         # bookkeeping data structures
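For orientation, a minimal sketch of the bookkeeping layout the methods below operate on, assuming `_averaged_gradients` maps a parameter-group index to a list of gradient tensors and `_grad_acc_objs` merely keeps hook-related objects referenced (the BaseStore internals are not part of this diff):

    # Sketch only: assumed shape of the bookkeeping structures, not code from this commit.
    from typing import Dict, List

    from torch import Tensor

    _averaged_gradients: Dict[int, List[Tensor]] = {}  # group_id -> averaged gradient tensors
    _grad_acc_objs: List = []  # keeps AccumulateGrad objects referenced so hooks stay attached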
@@ -15,7 +14,7 @@ class GradientStore(BaseStore):
         # for backward reduction hooks
         self._grad_acc_objs = []

-    def add_accumulate_grad_object(self, obj):
+    def append_accumulate_grad_object(self, obj):
         """
         Keep :class:`AccumulateGrad` objects. If these objects are not kept, reduction hooks may not
         be attached successfully.
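The rename above is cosmetic; the reason the store keeps these objects at all is the usual PyTorch hook idiom sketched below. The registration pattern is a common one for reducer-style hooks, not code from this commit, and the names are illustrative:

    # Sketch only: why AccumulateGrad objects must stay referenced.
    import torch

    param = torch.nn.Parameter(torch.randn(4))

    # The AccumulateGrad node is reached through a throwaway view; if nothing keeps a
    # reference to it, the node and the hook attached to it may be garbage collected.
    grad_acc = param.expand_as(param).grad_fn.next_functions[0][0]
    grad_acc.register_hook(lambda *unused: None)  # a reduction hook would go here

    kept_objects = []              # role played by GradientStore._grad_acc_objs
    kept_objects.append(grad_acc)  # what append_accumulate_grad_object does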
@@ -36,10 +35,12 @@ class GradientStore(BaseStore):
         :return: Return the list of averaged gradients of a parameter group. Each element is a gradient, not a parameter.
         :rtype: List[torch.Tensor]
         """
+        if group_id not in self._averaged_gradients:
+            self._averaged_gradients[group_id] = []
         return self._averaged_gradients[group_id]

-    def add_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None:
+    def append_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None:
         """
         Append an average gradient to the list of averaged gradients of a parameter group
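Taken together, the two changes in this hunk let a caller ask for a group's gradient list without pre-seeding the dict. A minimal sketch of the lazy-initialisation behaviour, mirrored on a plain dict rather than the real store:

    # Sketch only: the lazy-init accessor pattern on a plain dict.
    from typing import Dict, List

    import torch
    from torch import Tensor

    averaged_gradients: Dict[int, List[Tensor]] = {}

    def get_averaged_gradients_by_group(group_id: int) -> List[Tensor]:
        # A missing group now comes back as an empty, registered list, so callers
        # no longer need their own "if group_id not in ..." check.
        if group_id not in averaged_gradients:
            averaged_gradients[group_id] = []
        return averaged_gradients[group_id]

    grads = get_averaged_gradients_by_group(0)  # [] on first access
    grads.append(torch.zeros(2))                # visible through the shared dict
    assert len(averaged_gradients[0]) == 1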
@@ -55,6 +56,22 @@ class GradientStore(BaseStore):
         else:
             self._averaged_gradients[group_id] = [tensor]

+    def add_average_gradient_by_group(
+        self, group_id: int, tensor_idx: int, tensor: Tensor
+    ) -> None:
+        """
+        Add an average gradient to the list of averaged gradients of a parameter group
+
+        :param group_id: The index of a parameter group
+        :param tensor_idx: The index of a tensor in the list of averaged gradients
+        :param tensor: A :class:`torch.Tensor` object
+        :type group_id: int
+        :type tensor_idx: int
+        :type tensor: torch.Tensor
+        """
+
+        self._averaged_gradients[group_id][tensor_idx].add_(tensor)
+
     def reset_average_gradients_by_group(self, group_id: int) -> None:
         """
         Reset the bookkeeping data structure for averaged gradients to an empty list
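The new method reuses the name that existed before the rename, but with a different contract: `append_average_gradient_by_group` stores the first gradient seen for a position, while `add_average_gradient_by_group` accumulates in place into the tensor already stored at `tensor_idx`. A short sketch of the two semantics on a plain dict:

    # Sketch only: append vs. in-place accumulate, mirrored on a plain dict.
    import torch

    averaged_gradients = {0: []}

    # append_average_gradient_by_group: first gradient for this parameter position
    averaged_gradients[0].append(torch.ones(3))

    # add_average_gradient_by_group: later gradients accumulate in place at tensor_idx
    averaged_gradients[0][0].add_(torch.full((3,), 2.0))

    assert torch.equal(averaged_gradients[0][0], torch.full((3,), 3.0))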


@@ -550,20 +550,24 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
             reduction_states[tensor] = False

         # accumulate gradient
-        avg_gradients = self._grad_store._averaged_gradients
         for group_id in range(self.num_param_groups):
             param_group = self._param_store.get_fp16_params_by_rank_group(self._local_rank, group_id)
-            if group_id not in avg_gradients:
-                avg_gradients[group_id] = []
+            avg_gradients_group = self._grad_store.get_averaged_gradients_by_group(
+                group_id
+            )

             param_idx = 0
             for param in param_group:
                 if param.grad is not None:
-                    if len(avg_gradients[group_id]) == param_idx:
-                        avg_gradients[group_id].append(param.grad)
+                    if len(avg_gradients_group) == param_idx:
+                        self._grad_store.append_average_gradient_by_group(
+                            group_id, param.grad
+                        )
                     else:
-                        avg_gradients[group_id][param_idx].add_(param.grad)
+                        self._grad_store.add_average_gradient_by_group(
+                            group_id, param_idx, param.grad
+                        )
                     param_idx += 1

         # the gradients needed are stored in the avg_gradients buffer
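With the store API in place, the optimizer no longer reaches into `_averaged_gradients` directly. A condensed sketch of the append-or-accumulate loop above, with `grad_store`, `group_id` and `param_group` as stand-ins for the real optimizer attributes:

    # Sketch only: the accumulation loop delegated to the GradientStore API.
    def accumulate_grads(grad_store, group_id, param_group):
        avg_gradients_group = grad_store.get_averaged_gradients_by_group(group_id)
        param_idx = 0
        for param in param_group:
            if param.grad is not None:
                if len(avg_gradients_group) == param_idx:
                    # first accumulation for this position: store the gradient tensor
                    grad_store.append_average_gradient_by_group(group_id, param.grad)
                else:
                    # later accumulations: add into the stored tensor in place
                    grad_store.add_average_gradient_by_group(group_id, param_idx, param.grad)
                param_idx += 1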
@@ -590,4 +594,4 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         # only need to reduce the gradients
         # left in the communication bucket
         for reduce_rank in range(self._world_size):
             self._run_reduction(reduce_rank)