fix layer norm with pp

pull/449/head
JiaoPL 2023-10-26 10:07:43 +08:00
parent e900a1e45f
commit 9ac5ab3101
1 changed file with 6 additions and 4 deletions

View File

@ -539,8 +539,9 @@ def record_current_batch_training_metrics(
for group_name, layer_group in param_norms.items():
if layer_group:
for layer_name, param_group in layer_group.items():
title = f"param_norm/{group_name}/{layer_name}"
writer.add_scalars(key=title, value=param_group, step=train_state.step_count)
for param_name, param_value in param_group.items():
title = f"param_norm/{group_name}/{layer_name}/{param_name}"
writer.add_scalar(key=title, value=param_value, step=train_state.step_count)
for group_name, value in layer_zero_grad_count.items():
if value:
title = f"laye_zero_grad/{group_name}"
@ -548,8 +549,9 @@ def record_current_batch_training_metrics(
for group_name, layer_group in param_zero_grad_count.items():
if layer_group:
for layer_name, param_group in layer_group.items():
title = f"param_zero_grad/{group_name}/{layer_name}"
writer.add_scalars(key=title, value=param_group, step=train_state.step_count)
for param_name, param_value in param_group.items():
title = f"param_zero_grad/{group_name}/{layer_name}/{param_name}"
writer.add_scalar(key=title, value=param_value, step=train_state.step_count)
del grad_norm["layer_norms"]
del grad_norm["param_norms"]
del grad_norm["layer_zero_grad"]