diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 3fc3338..01b40ab 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -639,9 +639,9 @@ class HybridZeroOptimizer(BaseOptimizer): groups_param_norms = [] group_param_zero_grad_count = [] group_vocab_norms = [] - batch_count = gpc.config.batch_count + batch_count = gpc.config.get("batch_count") interval_steps = grad_profiling_config.get("interval_steps", 1) - is_profiling = batch_count % interval_steps == 0 + is_profiling = batch_count % interval_steps == 0 if batch_count is not None else False for group_id in range(self.num_param_groups): groups_norms.append(self._compute_norm_with_stage(group_id=group_id)) if is_profiling: