diff --git a/colossalai/zero/zero_optimizer.py b/colossalai/zero/zero_optimizer.py index e5a2f9f90..bc23123b0 100644 --- a/colossalai/zero/zero_optimizer.py +++ b/colossalai/zero/zero_optimizer.py @@ -139,6 +139,7 @@ class ZeroOptimizer(ColossalaiOptimizer): self._update_params_ptr() ret = self.optim.step(*args, **kwargs) self._register_states() + self.zero_grad() self._update_fp16_params() return ret