mirror of https://github.com/InternLM/InternLM
parent 607f691e16
commit 794a484666
@@ -4,6 +4,7 @@
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine

import json
from collections import deque
from typing import Iterable, Optional

from internlm.core.engine import Engine
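Note: the deque import is used by the SMA_tg_50_list and SMA_time_50_list fields added to TrainState below; a deque gives O(1) append and popleft, which the 50-step moving window relies on.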
@@ -58,6 +59,24 @@ class TrainState:
        if batch_sampler:
            self.init_batch_sampler(batch_sampler)

        # tgs statistic
        self.tgs_statistic = {
            "sum_step": 0,
            "sum_tg": 0,
            "sum_time": 0,
            "sum_last_tg_10": 0,
            "sum_last_time_10": 0,
            "sum_last_tg_50": 0,
            "sum_last_time_50": 0,
            "SMA_tg_50": 0,
            "SMA_time_50": 0,
            "SMA_tg_50_list": deque(),
            "SMA_time_50_list": deque(),
            "sum_tgs": 0,
            "last_tgs_10": 0,
            "last_tgs_50": 0,
        }

    def init_batch_sampler(self, batch_sampler):
        """
        Args:
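For reference, the SMA_* fields above keep a running sum over the most recent 50 steps: each step's value is appended to a deque and, once the window is full, the oldest value is subtracted from the sum and popped. A minimal standalone sketch of that pattern (illustrative only; the class name and default window size are assumptions, not part of this commit):

from collections import deque

class MovingWindowSum:
    # Running sum over the most recent `window` values, the same bookkeeping
    # used for SMA_tg_50 / SMA_tg_50_list (and the matching time fields).
    def __init__(self, window=50):
        self.window = window
        self.total = 0.0
        self.values = deque()

    def update(self, value):
        # Add the new value, then evict the oldest one once the window is full.
        self.total += value
        self.values.append(value)
        if len(self.values) > self.window:
            self.total -= self.values.popleft()
        return self.total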
@@ -372,9 +372,52 @@ def record_current_batch_training_metrics(
        max_length_in_batch = max([(b[1:] - b[:-1]).max().item() for b in batch[0]["cu_seqlens"]])
        max_samples_in_batch = max([len(b) - 1 for b in batch[0]["cu_seqlens"]])
        min_samples_in_batch = min([len(b) - 1 for b in batch[0]["cu_seqlens"]])

        tk_per_gpu = 0
        time_cost = time.time() - start_time
        tk_per_gpu = round(
            num_tokens_in_batch * gpc.get_world_size(ParallelMode.DATA) / gpc.get_world_size(ParallelMode.GLOBAL),
            4,
        )
        tgs_statistic = train_state.tgs_statistic
        tgs_statistic["sum_step"] += 1
        tgs_statistic["sum_tg"] += tk_per_gpu
        tgs_statistic["sum_time"] += time_cost
        tgs_statistic["sum_last_tg_10"] += tk_per_gpu
        tgs_statistic["sum_last_time_10"] += time_cost
        tgs_statistic["sum_last_tg_50"] += tk_per_gpu
        tgs_statistic["sum_last_time_50"] += time_cost
        tgs_statistic["SMA_tg_50"] += tk_per_gpu
        tgs_statistic["SMA_time_50"] += time_cost
        tgs_statistic["SMA_tg_50_list"].append(tk_per_gpu)
        tgs_statistic["SMA_time_50_list"].append(time_cost)
        if tgs_statistic["sum_step"] > 50:
            tgs_statistic["SMA_tg_50"] -= tgs_statistic["SMA_tg_50_list"][0]
            tgs_statistic["SMA_time_50"] -= tgs_statistic["SMA_time_50_list"][0]
            tgs_statistic["SMA_tg_50_list"].popleft()
            tgs_statistic["SMA_time_50_list"].popleft()

        last_tgs_1 = round(tk_per_gpu / time_cost, 2)
        tgs_statistic["sum_tgs"] += last_tgs_1

        if tgs_statistic["sum_step"] % 10 == 0:
            tgs_statistic["last_tgs_10"] = round(tgs_statistic["sum_last_tg_10"] / tgs_statistic["sum_last_time_10"], 2)
            tgs_statistic["sum_last_tg_10"] = 0
            tgs_statistic["sum_last_time_10"] = 0

        if tgs_statistic["sum_step"] % 50 == 0:
            tgs_statistic["last_tgs_50"] = round(tgs_statistic["sum_last_tg_50"] / tgs_statistic["sum_last_time_50"], 2)
            tgs_statistic["sum_last_tg_50"] = 0
            tgs_statistic["sum_last_time_50"] = 0

        last_tgs_10 = tgs_statistic["last_tgs_10"]
        last_tgs_50 = tgs_statistic["last_tgs_50"]

        tgs_all = round(tgs_statistic["sum_tg"] / tgs_statistic["sum_time"], 2)
        tgs_avg = round(tgs_statistic["sum_tgs"] / tgs_statistic["sum_step"], 2)
        tgs_SMA = round(tgs_statistic["SMA_tg_50"] / tgs_statistic["SMA_time_50"], 2)

        tflops = get_tflops_func((time.time() - start_time))

        tgs_origin = round(
            num_tokens_in_batch
            * gpc.get_world_size(ParallelMode.DATA)
            / gpc.get_world_size(ParallelMode.GLOBAL)
@@ -382,13 +425,17 @@ def record_current_batch_training_metrics(
            2,
        )

        tflops = get_tflops_func((time.time() - start_time))

        infos = {
            "tflops": tflops,
            "step": batch_count,
            "loss": loss.item(),
            "tgs (tokens/gpu/second)": tk_per_gpu,
            "tgs (tokens/gpu/second)": tgs_origin,
            "tgs/last_tgs_1": last_tgs_1,
            "tgs/tgs_all": tgs_all,
            "tgs/tgs_avg": tgs_avg,
            "tgs/tgs_SMA": tgs_SMA,
            "tgs/last_tgs_10": last_tgs_10,
            "tgs/last_tgs_50": last_tgs_50,
            "lr": lr,
            "loss_scale": scaler,
            "grad_norm": grad_norm,
@@ -428,7 +475,7 @@ def record_current_batch_training_metrics(
            "num_consumed_tokens": train_state.num_consumed_tokens,
            "loss": loss.item(),
            "flops": tflops,
            "tgs": tk_per_gpu,
            "tgs": last_tgs_1,
            "acc": acc_perplex["acc"],
            "perplexity": acc_perplex["perplexity"],
            "fwd_bwd_time": fwd_bwd_time,