mirror of https://github.com/hpcaitech/ColossalAI
[zero] trivial zero optimizer refactoring (#2869)
* Fix minor grad store interface

* Apply lint
parent dbc01b9c04
commit 7b13f7db18
@@ -6,6 +6,7 @@ from .base_store import BaseStore
 
 class GradientStore(BaseStore):
 
     def __init__(self, *args):
         super().__init__(*args)
         # bookkeeping data structures
@@ -56,9 +57,7 @@ class GradientStore(BaseStore):
         else:
             self._averaged_gradients[group_id] = [tensor]
 
-    def add_average_gradient_by_group(
-        self, group_id: int, tensor_idx: int, tensor: Tensor
-    ) -> None:
+    def add_average_gradient_by_group(self, group_id: int, tensor_idx: int, tensor: Tensor) -> None:
         """
         Add an average gradient to the list of averaged gradients of a parameter group
@@ -81,3 +80,9 @@ class GradientStore(BaseStore):
         """
 
         self._averaged_gradients[group_id] = []
+
+    def reset_all_average_gradients(self) -> None:
+        """
+        Reset the bookkeeping data structure for averaged gradients to an empty list
+        """
+        self._averaged_gradients = dict()
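Taken together, the gradient-store hunks replace direct pokes at `_averaged_gradients` with named bookkeeping methods. Below is a minimal, hypothetical sketch of that interface, not the real `GradientStore` (which inherits from `BaseStore` and takes process-group arguments): the dict-of-lists layout and the two reset bodies come from the hunks above, while the getter, append, and add bodies are assumptions written for illustration.

from typing import Dict, List

from torch import Tensor


class _GradientStoreSketch:
    """Toy stand-in for GradientStore's averaged-gradient bookkeeping (not the real class)."""

    def __init__(self) -> None:
        # group_id -> list of averaged gradient tensors, one slot per parameter
        self._averaged_gradients: Dict[int, List[Tensor]] = dict()

    def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]:
        # assumption: returns (and lazily creates) the group's buffer
        return self._averaged_gradients.setdefault(group_id, [])

    def append_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None:
        # assumption: the first accumulation step appends a fresh slot
        self._averaged_gradients.setdefault(group_id, []).append(tensor)

    def add_average_gradient_by_group(self, group_id: int, tensor_idx: int, tensor: Tensor) -> None:
        # assumption: later steps accumulate in place into the existing slot
        self._averaged_gradients[group_id][tensor_idx].add_(tensor)

    def reset_average_gradients_by_group(self, group_id: int) -> None:
        # mirrors the context line shown in the hunk above
        self._averaged_gradients[group_id] = []

    def reset_all_average_gradients(self) -> None:
        # mirrors the method added in this commit
        self._averaged_gradients = dict()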
@@ -416,7 +416,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         :param set_to_none: Whether set the gradient to None. Default value is True.
         :type set_to_none: bool
         """
-        for group_id, param_group in self._fp16_param_groups.items():
+        for _, param_group in self._fp16_param_groups.items():
             for param in param_group:
                 if set_to_none:
                     param.grad = None
@@ -438,7 +438,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
 
         # update loss scale if overflow occurs
         if found_inf:
-            self._grad_store._averaged_gradients = dict()
+            self._grad_store.reset_all_average_gradients()
             self.zero_grad()
             return
@@ -448,7 +448,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
 
         for group_id in range(self.num_param_groups):
             # compute norm
-            norm_group = compute_norm(gradients=self._grad_store._averaged_gradients[group_id],
+            norm_group = compute_norm(gradients=self._grad_store.get_averaged_gradients_by_group(group_id),
                                       params=self._param_store.get_fp16_params_by_rank_group(group_id=group_id,
                                                                                               rank=self._local_rank),
                                       dp_group=self._dp_torch_group,
@@ -469,8 +469,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
             single_grad_partition_groups.append(flat_fp32_avg_grads)
             device = self._fp32_flat_param_groups_of_current_rank[group_id].device
             self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
-            self._grad_store._averaged_gradients[group_id] = []
-            self._grad_store._averaged_gradients[group_id] = []
+            self._grad_store.reset_average_gradients_by_group(group_id)
 
         # unscale and clip grads
         global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
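For reference, the two reset granularities the optimizer now routes through the store, illustrated on a plain dict standing in for `_averaged_gradients` (strings stand in for gradient tensors; this is a toy, not the library code):

# group_id -> list of averaged gradients
averaged_gradients = {0: ["g0", "g1"], 1: ["g2"]}

# reset_average_gradients_by_group(0): clear one group's buffer, keep the others
averaged_gradients[0] = []
assert averaged_gradients == {0: [], 1: ["g2"]}

# reset_all_average_gradients(): drop every group's buffer (the overflow path above)
averaged_gradients = dict()
assert averaged_gradients == {}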
@@ -546,28 +545,22 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
     def _sync_grad(self):
         # update param already reduced flag
         reduction_states = self._param_store.get_param_reduction_states()
-        for tensor, state in reduction_states.items():
+        for tensor, _ in reduction_states.items():
             reduction_states[tensor] = False
 
         # accumulate gradient
         for group_id in range(self.num_param_groups):
             param_group = self._param_store.get_fp16_params_by_rank_group(self._local_rank, group_id)
 
-            avg_gradients_group = self._grad_store.get_averaged_gradients_by_group(
-                group_id
-            )
+            avg_gradients_group = self._grad_store.get_averaged_gradients_by_group(group_id)
 
             param_idx = 0
             for param in param_group:
                 if param.grad is not None:
                     if len(avg_gradients_group) == param_idx:
-                        self._grad_store.append_average_gradient_by_group(
-                            group_id, param.grad
-                        )
+                        self._grad_store.append_average_gradient_by_group(group_id, param.grad)
                     else:
-                        self._grad_store.add_average_gradient_by_group(
-                            group_id, param_idx, param.grad
-                        )
+                        self._grad_store.add_average_gradient_by_group(group_id, param_idx, param.grad)
                     param_idx += 1
 
         # the gradients needed are stored in the avg_gradients buffer
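The `_sync_grad` hunk is where the accessor methods are exercised: the optimizer fetches a group's buffer once, then either appends a fresh slot (first sync of a parameter) or accumulates by index on later syncs. A short usage sketch against the `_GradientStoreSketch` toy defined earlier; the parameters and gradients are fabricated for illustration, and the accumulate-in-place behaviour of `add_average_gradient_by_group` is an assumption:

import torch

store = _GradientStoreSketch()
group_id = 0
params = [torch.nn.Parameter(torch.randn(4)) for _ in range(3)]

# pretend backward produced a gradient for every parameter
for p in params:
    p.grad = torch.ones_like(p)

avg_gradients_group = store.get_averaged_gradients_by_group(group_id)

param_idx = 0
for param in params:
    if param.grad is not None:
        if len(avg_gradients_group) == param_idx:
            # first sync: no slot exists for this parameter yet
            store.append_average_gradient_by_group(group_id, param.grad)
        else:
            # later syncs: accumulate into the existing slot
            store.add_average_gradient_by_group(group_id, param_idx, param.grad)
        param_idx += 1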