diff --git a/colossalai/zero/sharded_optim/bookkeeping/gradient_store.py b/colossalai/zero/sharded_optim/bookkeeping/gradient_store.py
index b166752cc..942d7186e 100644
--- a/colossalai/zero/sharded_optim/bookkeeping/gradient_store.py
+++ b/colossalai/zero/sharded_optim/bookkeeping/gradient_store.py
@@ -6,6 +6,7 @@ from .base_store import BaseStore
 
 
 class GradientStore(BaseStore):
+
     def __init__(self, *args):
         super().__init__(*args)
         # bookkeeping data structures
@@ -56,9 +57,7 @@ class GradientStore(BaseStore):
         else:
             self._averaged_gradients[group_id] = [tensor]
 
-    def add_average_gradient_by_group(
-        self, group_id: int, tensor_idx: int, tensor: Tensor
-    ) -> None:
+    def add_average_gradient_by_group(self, group_id: int, tensor_idx: int, tensor: Tensor) -> None:
         """
         Add an average gradient to the list of averaged gradients of a parameter group
 
@@ -81,3 +80,9 @@ class GradientStore(BaseStore):
 
         """
         self._averaged_gradients[group_id] = []
+
+    def reset_all_average_gradients(self) -> None:
+        """
+        Reset the bookkeeping data structure for averaged gradients to an empty list
+        """
+        self._averaged_gradients = dict()
diff --git a/colossalai/zero/sharded_optim/low_level_optim.py b/colossalai/zero/sharded_optim/low_level_optim.py
index f5e03ce28..502b1c4d9 100644
--- a/colossalai/zero/sharded_optim/low_level_optim.py
+++ b/colossalai/zero/sharded_optim/low_level_optim.py
@@ -416,7 +416,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         :param set_to_none: Whether set the gradient to None. Default value is True.
         :type set_to_none: bool
         """
-        for group_id, param_group in self._fp16_param_groups.items():
+        for _, param_group in self._fp16_param_groups.items():
             for param in param_group:
                 if set_to_none:
                     param.grad = None
@@ -438,7 +438,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
 
         # update loss scale if overflow occurs
         if found_inf:
-            self._grad_store._averaged_gradients = dict()
+            self._grad_store.reset_all_average_gradients()
             self.zero_grad()
             return
 
@@ -448,7 +448,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
 
         for group_id in range(self.num_param_groups):
             # compute norm
-            norm_group = compute_norm(gradients=self._grad_store._averaged_gradients[group_id],
+            norm_group = compute_norm(gradients=self._grad_store.get_averaged_gradients_by_group(group_id),
                                       params=self._param_store.get_fp16_params_by_rank_group(group_id=group_id,
                                                                                              rank=self._local_rank),
                                       dp_group=self._dp_torch_group,
@@ -469,8 +469,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
             single_grad_partition_groups.append(flat_fp32_avg_grads)
             device = self._fp32_flat_param_groups_of_current_rank[group_id].device
             self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
-            self._grad_store._averaged_gradients[group_id] = []
-            self._grad_store._averaged_gradients[group_id] = []
+            self._grad_store.reset_average_gradients_by_group(group_id)
 
         # unscale and clip grads
         global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
@@ -546,28 +545,22 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
 
     def _sync_grad(self):
         # update param already reduced flag
         reduction_states = self._param_store.get_param_reduction_states()
-        for tensor, state in reduction_states.items():
+        for tensor, _ in reduction_states.items():
             reduction_states[tensor] = False
 
         # accumulate gradient
         for group_id in range(self.num_param_groups):
             param_group = self._param_store.get_fp16_params_by_rank_group(self._local_rank, group_id)
-            avg_gradients_group = self._grad_store.get_averaged_gradients_by_group(
-                group_id
-            )
+            avg_gradients_group = self._grad_store.get_averaged_gradients_by_group(group_id)
 
             param_idx = 0
             for param in param_group:
                 if param.grad is not None:
                     if len(avg_gradients_group) == param_idx:
-                        self._grad_store.append_average_gradient_by_group(
-                            group_id, param.grad
-                        )
+                        self._grad_store.append_average_gradient_by_group(group_id, param.grad)
                     else:
-                        self._grad_store.add_average_gradient_by_group(
-                            group_id, param_idx, param.grad
-                        )
+                        self._grad_store.add_average_gradient_by_group(group_id, param_idx, param.grad)
                     param_idx += 1
 
         # the gradients needed are stored in the avg_gradients buffer
@@ -594,4 +587,4 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         # only need to reduce the gradients
         # left in the communication bucket
         for reduce_rank in range(self._world_size):
-            self._run_reduction(reduce_rank)
\ No newline at end of file
+            self._run_reduction(reduce_rank)
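Note on the pattern: the patch replaces every direct touch of GradientStore._averaged_gradients in low_level_optim.py with the store's own bookkeeping methods (get_averaged_gradients_by_group, append_average_gradient_by_group, add_average_gradient_by_group, reset_average_gradients_by_group, and the new reset_all_average_gradients). The snippet below is only a minimal sketch of that accessor pattern, not the ColossalAI implementation: plain floats stand in for torch.Tensor gradients, BaseStore is omitted, and the method bodies other than reset_all_average_gradients are assumptions for illustration.

# Minimal sketch of the bookkeeping interface exercised by this patch.
# Assumption: illustration only; the real GradientStore derives from BaseStore
# and stores torch.Tensor objects rather than floats.
from typing import Dict, List


class GradientStore:

    def __init__(self) -> None:
        # group_id -> list of averaged gradients for that parameter group
        self._averaged_gradients: Dict[int, List[float]] = dict()

    def get_averaged_gradients_by_group(self, group_id: int) -> List[float]:
        return self._averaged_gradients.setdefault(group_id, [])

    def append_average_gradient_by_group(self, group_id: int, tensor: float) -> None:
        self._averaged_gradients.setdefault(group_id, []).append(tensor)

    def add_average_gradient_by_group(self, group_id: int, tensor_idx: int, tensor: float) -> None:
        # accumulate into an existing slot instead of appending a new one
        self._averaged_gradients[group_id][tensor_idx] += tensor

    def reset_average_gradients_by_group(self, group_id: int) -> None:
        self._averaged_gradients[group_id] = []

    def reset_all_average_gradients(self) -> None:
        # same effect as the method added by this patch
        self._averaged_gradients = dict()


# Caller side, mirroring how low_level_optim.py now uses the store:
store = GradientStore()
store.append_average_gradient_by_group(group_id=0, tensor=1.0)
store.add_average_gradient_by_group(group_id=0, tensor_idx=0, tensor=0.5)
print(store.get_averaged_gradients_by_group(0))    # [1.5]
store.reset_all_average_gradients()                # e.g. after a gradient overflow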