diff --git a/internlm/core/trainer.py b/internlm/core/trainer.py
index 18a8f6f..6c747aa 100644
--- a/internlm/core/trainer.py
+++ b/internlm/core/trainer.py
@@ -4,6 +4,7 @@
 # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
 
 import json
+from collections import deque
 from typing import Iterable, Optional
 
 from internlm.core.engine import Engine
@@ -58,6 +59,24 @@ class TrainState:
         if batch_sampler:
             self.init_batch_sampler(batch_sampler)
 
+        # tgs statistic
+        self.tgs_statistic = {
+            "sum_step": 0,
+            "sum_tg": 0,
+            "sum_time": 0,
+            "sum_last_tg_10": 0,
+            "sum_last_time_10": 0,
+            "sum_last_tg_50": 0,
+            "sum_last_time_50": 0,
+            "SMA_tg_50": 0,
+            "SMA_time_50": 0,
+            "SMA_tg_50_list": deque(),
+            "SMA_time_50_list": deque(),
+            "sum_tgs": 0,
+            "last_tgs_10": 0,
+            "last_tgs_50": 0,
+        }
+
     def init_batch_sampler(self, batch_sampler):
         """
         Args:
diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index a24317e..e08d4ec 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -372,9 +372,52 @@ def record_current_batch_training_metrics(
         max_length_in_batch = max([(b[1:] - b[:-1]).max().item() for b in batch[0]["cu_seqlens"]])
         max_samples_in_batch = max([len(b) - 1 for b in batch[0]["cu_seqlens"]])
         min_samples_in_batch = min([len(b) - 1 for b in batch[0]["cu_seqlens"]])
-
-        tk_per_gpu = 0
+        time_cost = time.time() - start_time
         tk_per_gpu = round(
+            num_tokens_in_batch * gpc.get_world_size(ParallelMode.DATA) / gpc.get_world_size(ParallelMode.GLOBAL),
+            4,
+        )
+        tgs_statistic = train_state.tgs_statistic
+        tgs_statistic["sum_step"] += 1
+        tgs_statistic["sum_tg"] += tk_per_gpu
+        tgs_statistic["sum_time"] += time_cost
+        tgs_statistic["sum_last_tg_10"] += tk_per_gpu
+        tgs_statistic["sum_last_time_10"] += time_cost
+        tgs_statistic["sum_last_tg_50"] += tk_per_gpu
+        tgs_statistic["sum_last_time_50"] += time_cost
+        tgs_statistic["SMA_tg_50"] += tk_per_gpu
+        tgs_statistic["SMA_time_50"] += time_cost
+        tgs_statistic["SMA_tg_50_list"].append(tk_per_gpu)
+        tgs_statistic["SMA_time_50_list"].append(time_cost)
+        if tgs_statistic["sum_step"] > 50:
+            tgs_statistic["SMA_tg_50"] -= tgs_statistic["SMA_tg_50_list"][0]
+            tgs_statistic["SMA_time_50"] -= tgs_statistic["SMA_time_50_list"][0]
+            tgs_statistic["SMA_tg_50_list"].popleft()
+            tgs_statistic["SMA_time_50_list"].popleft()
+
+        last_tgs_1 = round(tk_per_gpu / time_cost, 2)
+        tgs_statistic["sum_tgs"] += last_tgs_1
+
+        if tgs_statistic["sum_step"] % 10 == 0:
+            tgs_statistic["last_tgs_10"] = round(tgs_statistic["sum_last_tg_10"] / tgs_statistic["sum_last_time_10"], 2)
+            tgs_statistic["sum_last_tg_10"] = 0
+            tgs_statistic["sum_last_time_10"] = 0
+
+        if tgs_statistic["sum_step"] % 50 == 0:
+            tgs_statistic["last_tgs_50"] = round(tgs_statistic["sum_last_tg_50"] / tgs_statistic["sum_last_time_50"], 2)
+            tgs_statistic["sum_last_tg_50"] = 0
+            tgs_statistic["sum_last_time_50"] = 0
+
+        last_tgs_10 = tgs_statistic["last_tgs_10"]
+        last_tgs_50 = tgs_statistic["last_tgs_50"]
+
+        tgs_all = round(tgs_statistic["sum_tg"] / tgs_statistic["sum_time"], 2)
+        tgs_avg = round(tgs_statistic["sum_tgs"] / tgs_statistic["sum_step"], 2)
+        tgs_SMA = round(tgs_statistic["SMA_tg_50"] / tgs_statistic["SMA_time_50"], 2)
+
+        tflops = get_tflops_func((time.time() - start_time))
+
+        tgs_origin = round(
             num_tokens_in_batch
             * gpc.get_world_size(ParallelMode.DATA)
             / gpc.get_world_size(ParallelMode.GLOBAL)
@@ -382,13 +425,17 @@ def record_current_batch_training_metrics(
             2,
         )
 
-        tflops = get_tflops_func((time.time() - start_time))
-
         infos = {
             "tflops": tflops,
             "step": batch_count,
             "loss": loss.item(),
-            "tgs (tokens/gpu/second)": tk_per_gpu,
+            "tgs (tokens/gpu/second)": tgs_origin,
+            "tgs/last_tgs_1": last_tgs_1,
+            "tgs/tgs_all": tgs_all,
+            "tgs/tgs_avg": tgs_avg,
+            "tgs/tgs_SMA": tgs_SMA,
+            "tgs/last_tgs_10": last_tgs_10,
+            "tgs/last_tgs_50": last_tgs_50,
             "lr": lr,
             "loss_scale": scaler,
             "grad_norm": grad_norm,
@@ -428,7 +475,7 @@ def record_current_batch_training_metrics(
                 "num_consumed_tokens": train_state.num_consumed_tokens,
                 "loss": loss.item(),
                 "flops": tflops,
-                "tgs": tk_per_gpu,
+                "tgs": last_tgs_1,
                 "acc": acc_perplex["acc"],
                 "perplexity": acc_perplex["perplexity"],
                 "fwd_bwd_time": fwd_bwd_time,