diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 2f53963..19219f7 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -817,6 +817,8 @@ class HybridZeroOptimizer(BaseOptimizer): self.optim.step() # release the fp32 grad for group_id in range(self.num_param_groups): + if not self.param_group_has_params[group_id]: + continue release_param_grad(self._fp32_orig_param_groups_of_current_rank[group_id]) if self._enable_memory_balance and self._memory_balance_role == 1: