[zero] trivial zero optimizer refactoring (#2869)

* Fix mionr grad store interface

* Apply lint
pull/2914/head
YH 2 years ago committed by GitHub
parent dbc01b9c04
commit 7b13f7db18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -6,6 +6,7 @@ from .base_store import BaseStore
class GradientStore(BaseStore):
def __init__(self, *args):
super().__init__(*args)
# bookkeeping data structures
@ -56,9 +57,7 @@ class GradientStore(BaseStore):
else:
self._averaged_gradients[group_id] = [tensor]
def add_average_gradient_by_group(
self, group_id: int, tensor_idx: int, tensor: Tensor
) -> None:
def add_average_gradient_by_group(self, group_id: int, tensor_idx: int, tensor: Tensor) -> None:
"""
Add an average gradient to the list of averaged gradients of a parameter group
@ -81,3 +80,9 @@ class GradientStore(BaseStore):
"""
self._averaged_gradients[group_id] = []
def reset_all_average_gradients(self) -> None:
"""
Reset the bookkeeping data structure for averaged gradients to an empty list
"""
self._averaged_gradients = dict()

@ -416,7 +416,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
:param set_to_none: Whether set the gradient to None. Default value is True.
:type set_to_none: bool
"""
for group_id, param_group in self._fp16_param_groups.items():
for _, param_group in self._fp16_param_groups.items():
for param in param_group:
if set_to_none:
param.grad = None
@ -438,7 +438,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
# update loss scale if overflow occurs
if found_inf:
self._grad_store._averaged_gradients = dict()
self._grad_store.reset_all_average_gradients()
self.zero_grad()
return
@ -448,7 +448,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
for group_id in range(self.num_param_groups):
# compute norm
norm_group = compute_norm(gradients=self._grad_store._averaged_gradients[group_id],
norm_group = compute_norm(gradients=self._grad_store.get_averaged_gradients_by_group(group_id),
params=self._param_store.get_fp16_params_by_rank_group(group_id=group_id,
rank=self._local_rank),
dp_group=self._dp_torch_group,
@ -469,8 +469,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
single_grad_partition_groups.append(flat_fp32_avg_grads)
device = self._fp32_flat_param_groups_of_current_rank[group_id].device
self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
self._grad_store._averaged_gradients[group_id] = []
self._grad_store._averaged_gradients[group_id] = []
self._grad_store.reset_average_gradients_by_group(group_id)
# unscale and clip grads
global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
@ -546,28 +545,22 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
def _sync_grad(self):
# update param already reduced flag
reduction_states = self._param_store.get_param_reduction_states()
for tensor, state in reduction_states.items():
for tensor, _ in reduction_states.items():
reduction_states[tensor] = False
# accumulate gradient
for group_id in range(self.num_param_groups):
param_group = self._param_store.get_fp16_params_by_rank_group(self._local_rank, group_id)
avg_gradients_group = self._grad_store.get_averaged_gradients_by_group(
group_id
)
avg_gradients_group = self._grad_store.get_averaged_gradients_by_group(group_id)
param_idx = 0
for param in param_group:
if param.grad is not None:
if len(avg_gradients_group) == param_idx:
self._grad_store.append_average_gradient_by_group(
group_id, param.grad
)
self._grad_store.append_average_gradient_by_group(group_id, param.grad)
else:
self._grad_store.add_average_gradient_by_group(
group_id, param_idx, param.grad
)
self._grad_store.add_average_gradient_by_group(group_id, param_idx, param.grad)
param_idx += 1
# the gradients needed are stored in the avg_gradients buffer
@ -594,4 +587,4 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
# only need to reduce the gradients
# left in the communication bucket
for reduce_rank in range(self._world_size):
self._run_reduction(reduce_rank)
self._run_reduction(reduce_rank)

Loading…
Cancel
Save