feat(train/training_internlm.py): remove abnormal tgs when calculating avg tgs

pull/407/head
huangting4201 2023-10-17 11:30:44 +08:00
parent 229cc5c68c
commit 4e99a7fdbc
1 changed files with 4 additions and 0 deletions

View File

@ -576,4 +576,8 @@ def record_current_batch_training_metrics(
tgs_list.append(tgs_origin)
if batch_count == gpc.config.data.total_steps - 1:
print(tgs_list, flush=True)
avg_tgs = sum(tgs_list) / len(tgs_list)
for tgs in tgs_list.copy():
if abs(tgs - avg_tgs) > 1000:
tgs_list.remove(tgs)
print(f"avg_tgs: {sum(tgs_list)/len(tgs_list)}", flush=True)