diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py index 46121e4..e423ea6 100644 --- a/internlm/train/training_internlm.py +++ b/internlm/train/training_internlm.py @@ -485,7 +485,8 @@ def record_current_batch_training_metrics( "perplexity": acc_perplex["perplexity"], "fwd_bwd_time": fwd_bwd_time, } - panel_metrics["moe_loss"] = moe_loss.item() + if moe_loss is not None: + panel_metrics["moe_loss"] = moe_loss.item() for norm_key, norm_value in grad_norm.items(): panel_metrics[norm_key] = norm_value