add flash tflops

pull/456/head
yingtongxiong 2023-10-26 20:33:12 +08:00
parent 4d83e1021b
commit 8aefb74e02
2 changed files with 27 additions and 0 deletions

View File

@ -406,11 +406,13 @@ def initialize_llm_profile(profiling: bool = False, start_time: str = None):
tgs_list = [] tgs_list = []
tflops_list = [] tflops_list = []
tflops_list_2 = []
@llm_timeout(func_name="record_current_batch_training_metrics") @llm_timeout(func_name="record_current_batch_training_metrics")
def record_current_batch_training_metrics( def record_current_batch_training_metrics(
get_tflops_func, get_tflops_func,
get_tflops_func_2,
logger, logger,
writer, writer,
success_update, success_update,
@ -495,6 +497,7 @@ def record_current_batch_training_metrics(
tgs_SMA = round(tgs_statistic["SMA_tg_50"] / tgs_statistic["SMA_time_50"], 2) tgs_SMA = round(tgs_statistic["SMA_tg_50"] / tgs_statistic["SMA_time_50"], 2)
tflops = get_tflops_func((time.time() - start_time)) tflops = get_tflops_func((time.time() - start_time))
tflops_2 = get_tflops_func_2((time.time() - start_time))
tgs_origin = round( tgs_origin = round(
num_tokens_in_batch num_tokens_in_batch
@ -506,6 +509,7 @@ def record_current_batch_training_metrics(
infos = { infos = {
"tflops": tflops, "tflops": tflops,
"tflops2": tflops_2,
"step": batch_count, "step": batch_count,
"loss": loss.item() - moe_loss.item() if moe_loss is not None else loss.item(), "loss": loss.item() - moe_loss.item() if moe_loss is not None else loss.item(),
"tgs (tokens/gpu/second)": tgs_origin, "tgs (tokens/gpu/second)": tgs_origin,
@ -599,6 +603,7 @@ def record_current_batch_training_metrics(
if batch_count >= 5: if batch_count >= 5:
tgs_list.append(tgs_origin) tgs_list.append(tgs_origin)
tflops_list.append(tflops) tflops_list.append(tflops)
tflops_list_2.append(tflops_2)
if batch_count == gpc.config.data.total_steps - 1: if batch_count == gpc.config.data.total_steps - 1:
print(tgs_list, flush=True) print(tgs_list, flush=True)
avg_tgs = sum(tgs_list) / len(tgs_list) avg_tgs = sum(tgs_list) / len(tgs_list)
@ -606,9 +611,17 @@ def record_current_batch_training_metrics(
if abs(tgs - avg_tgs) > 400: if abs(tgs - avg_tgs) > 400:
tgs_list.remove(tgs) tgs_list.remove(tgs)
print(f"avg_tgs: {sum(tgs_list)/len(tgs_list)}", flush=True) print(f"avg_tgs: {sum(tgs_list)/len(tgs_list)}", flush=True)
print(tflops_list, flush=True) print(tflops_list, flush=True)
avg_tflops = sum(tflops_list) / len(tflops_list) avg_tflops = sum(tflops_list) / len(tflops_list)
for tf in tflops_list.copy(): for tf in tflops_list.copy():
if abs(tf - avg_tflops) > 10: if abs(tf - avg_tflops) > 10:
tflops_list.remove(tf) tflops_list.remove(tf)
print(f"avg_tflops: {sum(tflops_list)/len(tflops_list)}", flush=True) print(f"avg_tflops: {sum(tflops_list)/len(tflops_list)}", flush=True)
print(tflops_list_2, flush=True)
avg_tflops_2 = sum(tflops_list_2) / len(tflops_list_2)
for tf in tflops_list_2.copy():
if abs(tf - avg_tflops_2) > 10:
tflops_list_2.remove(tf)
print(f"avg_tflops: {sum(tflops_list_2)/len(tflops_list_2)}", flush=True)

View File

@ -33,6 +33,7 @@ from internlm.train import (
from internlm.utils.common import ( from internlm.utils.common import (
BatchSkipper, BatchSkipper,
get_megatron_flops, get_megatron_flops,
get_megatron_flops_2,
launch_time, launch_time,
parse_args, parse_args,
) )
@ -111,6 +112,18 @@ def main(args):
global_world_size=gpc.get_world_size(ParallelMode.GLOBAL), global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
mlp_ratio=gpc.config.MLP_RATIO, mlp_ratio=gpc.config.MLP_RATIO,
) )
get_tflops_func_2 = partial(
get_megatron_flops_2,
checkpoint=gpc.config.model.checkpoint,
seq_len=gpc.config.SEQ_LEN,
hidden_size=gpc.config.model.hidden_size,
num_layers=gpc.config.model.num_layers,
vocab_size=gpc.config.model.vocab_size,
global_batch_size=gpc.config.data.micro_bsz * gpc.config.data.micro_num * gpc.get_world_size(ParallelMode.DATA),
global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
mlp_ratio=gpc.config.MLP_RATIO,
)
# get and broadcast current time # get and broadcast current time
current_time = launch_time() current_time = launch_time()
@ -271,6 +284,7 @@ def main(args):
# calculate and record the training metrics, eg. loss, accuracy and so on. # calculate and record the training metrics, eg. loss, accuracy and so on.
record_current_batch_training_metrics( record_current_batch_training_metrics(
get_tflops_func=get_tflops_func, get_tflops_func=get_tflops_func,
get_tflops_func_2=get_tflops_func_2,
logger=logger, logger=logger,
writer=writer, writer=writer,
success_update=success_update, success_update=success_update,