From 9ac5ab3101c2aa3d08d6b42aae8bab6bb7f11957 Mon Sep 17 00:00:00 2001
From: JiaoPL
Date: Thu, 26 Oct 2023 10:07:43 +0800
Subject: [PATCH] fix layer norm with pp

---
 internlm/train/training_internlm.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index 42377f6..21925a0 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -539,8 +539,9 @@ def record_current_batch_training_metrics(
         for group_name, layer_group in param_norms.items():
             if layer_group:
                 for layer_name, param_group in layer_group.items():
-                    title = f"param_norm/{group_name}/{layer_name}"
-                    writer.add_scalars(key=title, value=param_group, step=train_state.step_count)
+                    for param_name, param_value in param_group.items():
+                        title = f"param_norm/{group_name}/{layer_name}/{param_name}"
+                        writer.add_scalar(key=title, value=param_value, step=train_state.step_count)
         for group_name, value in layer_zero_grad_count.items():
             if value:
                 title = f"laye_zero_grad/{group_name}"
@@ -548,8 +549,9 @@ def record_current_batch_training_metrics(
         for group_name, layer_group in param_zero_grad_count.items():
             if layer_group:
                 for layer_name, param_group in layer_group.items():
-                    title = f"param_zero_grad/{group_name}/{layer_name}"
-                    writer.add_scalars(key=title, value=param_group, step=train_state.step_count)
+                    for param_name, param_value in param_group.items():
+                        title = f"param_zero_grad/{group_name}/{layer_name}/{param_name}"
+                        writer.add_scalar(key=title, value=param_value, step=train_state.step_count)
         del grad_norm["layer_norms"]
         del grad_norm["param_norms"]
         del grad_norm["layer_zero_grad"]
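
For reference, a minimal sketch of the per-parameter logging pattern the hunks above switch to: instead of passing a whole {param_name: value} dict to one add_scalars call per layer, the nested dict is flattened and each parameter gets its own scalar entry. The Writer class, the sample param_norms dict, and step_count below are hypothetical stand-ins for illustration; only the {group: {layer: {param: value}}} shape and the add_scalar(key=..., value=..., step=...) call are taken from the diff itself.

    # Hypothetical stand-in for InternLM's writer; only the keyword-argument
    # signature add_scalar(key=..., value=..., step=...) from the diff is assumed.
    class Writer:
        def add_scalar(self, key: str, value: float, step: int) -> None:
            print(f"step={step} {key}={value}")

    # Illustrative nested structure: {group_name: {layer_name: {param_name: norm}}}
    param_norms = {
        "default": {
            "blocks.0": {"norm1.weight": 0.98, "wqkv.weight": 1.73},
            "blocks.1": {"norm1.weight": 1.01},
        }
    }

    writer = Writer()
    step_count = 42  # stand-in for train_state.step_count

    # The removed code called writer.add_scalars(...) once per layer with the
    # whole param dict as the value; the fix iterates the innermost dict and
    # emits one scalar per parameter, so every logged curve is a single value.
    for group_name, layer_group in param_norms.items():
        if layer_group:
            for layer_name, param_group in layer_group.items():
                for param_name, param_value in param_group.items():
                    title = f"param_norm/{group_name}/{layer_name}/{param_name}"
                    writer.add_scalar(key=title, value=param_value, step=step_count)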