mirror of https://github.com/InternLM/InternLM
feat(train/training_internlm.py): remove abnormal tgs when calculating avg tgs
parent
229cc5c68c
commit
4e99a7fdbc
|
@@ -576,4 +576,8 @@ def record_current_batch_training_metrics(
|
||||||
tgs_list.append(tgs_origin)
|
tgs_list.append(tgs_origin)
|
||||||
if batch_count == gpc.config.data.total_steps - 1:
|
if batch_count == gpc.config.data.total_steps - 1:
|
||||||
print(tgs_list, flush=True)
|
print(tgs_list, flush=True)
|
||||||
|
avg_tgs = sum(tgs_list) / len(tgs_list)
|
||||||
|
for tgs in tgs_list.copy():
|
||||||
|
if abs(tgs - avg_tgs) > 1000:
|
||||||
|
tgs_list.remove(tgs)
|
||||||
print(f"avg_tgs: {sum(tgs_list)/len(tgs_list)}", flush=True)
|
print(f"avg_tgs: {sum(tgs_list)/len(tgs_list)}", flush=True)
|
||||||
|
|
Loading…
Reference in New Issue