mirror of https://github.com/InternLM/InternLM
feat(train/training_internlm.py): remove abnormal tgs when calculating avg tgs
parent
229cc5c68c
commit
4e99a7fdbc
|
@ -576,4 +576,8 @@ def record_current_batch_training_metrics(
|
|||
tgs_list.append(tgs_origin)
|
||||
if batch_count == gpc.config.data.total_steps - 1:
|
||||
print(tgs_list, flush=True)
|
||||
avg_tgs = sum(tgs_list) / len(tgs_list)
|
||||
for tgs in tgs_list.copy():
|
||||
if abs(tgs - avg_tgs) > 1000:
|
||||
tgs_list.remove(tgs)
|
||||
print(f"avg_tgs: {sum(tgs_list)/len(tgs_list)}", flush=True)
|
||||
|
|
Loading…
Reference in New Issue