feat: more tgs (#310)

* feat:more tgs

* feat:add more tgs

* feat:more tgs
pull/314/head
jiaxingli 2023-09-15 18:56:11 +08:00 committed by GitHub
parent 607f691e16
commit 794a484666
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 72 additions and 6 deletions

View File

@ -4,6 +4,7 @@
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
import json
from collections import deque
from typing import Iterable, Optional
from internlm.core.engine import Engine
@ -58,6 +59,24 @@ class TrainState:
if batch_sampler:
self.init_batch_sampler(batch_sampler)
# tgs statistic
self.tgs_statistic = {
"sum_step": 0,
"sum_tg": 0,
"sum_time": 0,
"sum_last_tg_10": 0,
"sum_last_time_10": 0,
"sum_last_tg_50": 0,
"sum_last_time_50": 0,
"SMA_tg_50": 0,
"SMA_time_50": 0,
"SMA_tg_50_list": deque(),
"SMA_time_50_list": deque(),
"sum_tgs": 0,
"last_tgs_10": 0,
"last_tgs_50": 0,
}
def init_batch_sampler(self, batch_sampler):
"""
Args:

View File

@ -372,9 +372,52 @@ def record_current_batch_training_metrics(
max_length_in_batch = max([(b[1:] - b[:-1]).max().item() for b in batch[0]["cu_seqlens"]])
max_samples_in_batch = max([len(b) - 1 for b in batch[0]["cu_seqlens"]])
min_samples_in_batch = min([len(b) - 1 for b in batch[0]["cu_seqlens"]])
tk_per_gpu = 0
time_cost = time.time() - start_time
tk_per_gpu = round(
num_tokens_in_batch * gpc.get_world_size(ParallelMode.DATA) / gpc.get_world_size(ParallelMode.GLOBAL),
4,
)
tgs_statistic = train_state.tgs_statistic
tgs_statistic["sum_step"] += 1
tgs_statistic["sum_tg"] += tk_per_gpu
tgs_statistic["sum_time"] += time_cost
tgs_statistic["sum_last_tg_10"] += tk_per_gpu
tgs_statistic["sum_last_time_10"] += time_cost
tgs_statistic["sum_last_tg_50"] += tk_per_gpu
tgs_statistic["sum_last_time_50"] += time_cost
tgs_statistic["SMA_tg_50"] += tk_per_gpu
tgs_statistic["SMA_time_50"] += time_cost
tgs_statistic["SMA_tg_50_list"].append(tk_per_gpu)
tgs_statistic["SMA_time_50_list"].append(time_cost)
if tgs_statistic["sum_step"] > 50:
tgs_statistic["SMA_tg_50"] -= tgs_statistic["SMA_tg_50_list"][0]
tgs_statistic["SMA_time_50"] -= tgs_statistic["SMA_time_50_list"][0]
tgs_statistic["SMA_tg_50_list"].popleft()
tgs_statistic["SMA_time_50_list"].popleft()
last_tgs_1 = round(tk_per_gpu / time_cost, 2)
tgs_statistic["sum_tgs"] += last_tgs_1
if tgs_statistic["sum_step"] % 10 == 0:
tgs_statistic["last_tgs_10"] = round(tgs_statistic["sum_last_tg_10"] / tgs_statistic["sum_last_time_10"], 2)
tgs_statistic["sum_last_tg_10"] = 0
tgs_statistic["sum_last_time_10"] = 0
if tgs_statistic["sum_step"] % 50 == 0:
tgs_statistic["last_tgs_50"] = round(tgs_statistic["sum_last_tg_50"] / tgs_statistic["sum_last_time_50"], 2)
tgs_statistic["sum_last_tg_50"] = 0
tgs_statistic["sum_last_time_50"] = 0
last_tgs_10 = tgs_statistic["last_tgs_10"]
last_tgs_50 = tgs_statistic["last_tgs_50"]
tgs_all = round(tgs_statistic["sum_tg"] / tgs_statistic["sum_time"], 2)
tgs_avg = round(tgs_statistic["sum_tgs"] / tgs_statistic["sum_step"], 2)
tgs_SMA = round(tgs_statistic["SMA_tg_50"] / tgs_statistic["SMA_time_50"], 2)
tflops = get_tflops_func((time.time() - start_time))
tgs_origin = round(
num_tokens_in_batch
* gpc.get_world_size(ParallelMode.DATA)
/ gpc.get_world_size(ParallelMode.GLOBAL)
@ -382,13 +425,17 @@ def record_current_batch_training_metrics(
2,
)
tflops = get_tflops_func((time.time() - start_time))
infos = {
"tflops": tflops,
"step": batch_count,
"loss": loss.item(),
"tgs (tokens/gpu/second)": tk_per_gpu,
"tgs (tokens/gpu/second)": tgs_origin,
"tgs/last_tgs_1": last_tgs_1,
"tgs/tgs_all": tgs_all,
"tgs/tgs_avg": tgs_avg,
"tgs/tgs_SMA": tgs_SMA,
"tgs/last_tgs_10": last_tgs_10,
"tgs/last_tgs_50": last_tgs_50,
"lr": lr,
"loss_scale": scaler,
"grad_norm": grad_norm,
@ -428,7 +475,7 @@ def record_current_batch_training_metrics(
"num_consumed_tokens": train_state.num_consumed_tokens,
"loss": loss.item(),
"flops": tflops,
"tgs": tk_per_gpu,
"tgs": last_tgs_1,
"acc": acc_perplex["acc"],
"perplexity": acc_perplex["perplexity"],
"fwd_bwd_time": fwd_bwd_time,