From 29dd401071020cc0efa6e4978b3550c6d6693c81 Mon Sep 17 00:00:00 2001
From: huangting4201 <1538303371@qq.com>
Date: Thu, 24 Aug 2023 17:46:27 +0800
Subject: [PATCH] fix(train.py): fix overflow grad norm error (#230)

---
 internlm/solver/optimizer/hybrid_zero_optim.py | 2 +-
 train.py                                       | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index ffe56dd..8bdeccf 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -585,7 +585,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             )
             self._grad_store._averaged_gradients = dict()
             self.zero_grad()
-            return False, None
+            return False, norms
 
         # copy the grad of fp16 param to fp32 param
         single_grad_partition_groups = []
diff --git a/train.py b/train.py
index a224f67..2ac6de6 100644
--- a/train.py
+++ b/train.py
@@ -235,7 +235,7 @@ def main(args):
             train_state.step_count += 1
         else:
             train_state.inf_nan_skip_batches += 1  # record the amount of updating parameters unsuccessfully.
-        if -99.0 in grad_norm_groups and gpc.is_rank_for_log():  # -99.0 encodes a specific failure case
+        if -1 in grad_norm_groups and gpc.is_rank_for_log():  # -1 encodes a specific failure case
             logger.warning(f"Warning: skip parameter update at step {batch_count}.")
             send_alert_message(
                 address=gpc.config.alert_address,
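
Note (not part of the patch): the sketch below illustrates the failure mode this commit fixes. Before the change, `HybridZeroOptimizer.step()` returned `(False, None)` when an fp16 gradient overflow forced the update to be skipped, so a membership test like `-1 in grad_norm_groups` in `train.py` raised `TypeError: argument of type 'NoneType' is not iterable`. The function names and the example norm values below are illustrative assumptions; only the `(success_update, norms)` return convention and the sentinel check mirror the diff.

```python
# Hypothetical stand-ins for HybridZeroOptimizer.step(); only the
# (success_update, grad_norm_groups) return convention comes from the diff.

def step_before_fix():
    # Pre-fix behavior: on overflow the grad norms were dropped entirely.
    return False, None

def step_after_fix():
    # Post-fix behavior: the per-group norms are returned, with -1 as the
    # sentinel for an inf/NaN gradient norm (example values, assumed).
    return False, [-1, 0.37]

# Caller-side sentinel check, as in train.py after the patch:
success_update, grad_norm_groups = step_after_fix()
if not success_update and -1 in grad_norm_groups:
    print("Warning: skip parameter update.")  # works: norms is a list

# The same check against the pre-fix return value crashes:
success_update, grad_norm_groups = step_before_fix()
try:
    -1 in grad_norm_groups
except TypeError as exc:
    print(f"pre-fix error: {exc}")  # argument of type 'NoneType' is not iterable
```

Returning the norms even on a skipped step lets the caller both count the skipped batch and, via the sentinel value, log and alert on the overflow case instead of crashing.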