mirror of https://github.com/InternLM/InternLM
add flash tflops
parent
4d83e1021b
commit
8aefb74e02
|
@ -406,11 +406,13 @@ def initialize_llm_profile(profiling: bool = False, start_time: str = None):
|
|||
|
||||
tgs_list = []
|
||||
tflops_list = []
|
||||
tflops_list_2 = []
|
||||
|
||||
|
||||
@llm_timeout(func_name="record_current_batch_training_metrics")
|
||||
def record_current_batch_training_metrics(
|
||||
get_tflops_func,
|
||||
get_tflops_func_2,
|
||||
logger,
|
||||
writer,
|
||||
success_update,
|
||||
|
@ -495,6 +497,7 @@ def record_current_batch_training_metrics(
|
|||
tgs_SMA = round(tgs_statistic["SMA_tg_50"] / tgs_statistic["SMA_time_50"], 2)
|
||||
|
||||
tflops = get_tflops_func((time.time() - start_time))
|
||||
tflops_2 = get_tflops_func_2((time.time() - start_time))
|
||||
|
||||
tgs_origin = round(
|
||||
num_tokens_in_batch
|
||||
|
@ -506,6 +509,7 @@ def record_current_batch_training_metrics(
|
|||
|
||||
infos = {
|
||||
"tflops": tflops,
|
||||
"tflops2": tflops_2,
|
||||
"step": batch_count,
|
||||
"loss": loss.item() - moe_loss.item() if moe_loss is not None else loss.item(),
|
||||
"tgs (tokens/gpu/second)": tgs_origin,
|
||||
|
@ -599,6 +603,7 @@ def record_current_batch_training_metrics(
|
|||
if batch_count >= 5:
|
||||
tgs_list.append(tgs_origin)
|
||||
tflops_list.append(tflops)
|
||||
tflops_list_2.append(tflops_2)
|
||||
if batch_count == gpc.config.data.total_steps - 1:
|
||||
print(tgs_list, flush=True)
|
||||
avg_tgs = sum(tgs_list) / len(tgs_list)
|
||||
|
@ -606,9 +611,17 @@ def record_current_batch_training_metrics(
|
|||
if abs(tgs - avg_tgs) > 400:
|
||||
tgs_list.remove(tgs)
|
||||
print(f"avg_tgs: {sum(tgs_list)/len(tgs_list)}", flush=True)
|
||||
|
||||
print(tflops_list, flush=True)
|
||||
avg_tflops = sum(tflops_list) / len(tflops_list)
|
||||
for tf in tflops_list.copy():
|
||||
if abs(tf - avg_tflops) > 10:
|
||||
tflops_list.remove(tf)
|
||||
print(f"avg_tflops: {sum(tflops_list)/len(tflops_list)}", flush=True)
|
||||
|
||||
print(tflops_list_2, flush=True)
|
||||
avg_tflops_2 = sum(tflops_list_2) / len(tflops_list_2)
|
||||
for tf in tflops_list_2.copy():
|
||||
if abs(tf - avg_tflops_2) > 10:
|
||||
tflops_list_2.remove(tf)
|
||||
print(f"avg_tflops: {sum(tflops_list_2)/len(tflops_list_2)}", flush=True)
|
||||
|
|
14
train.py
14
train.py
|
@ -33,6 +33,7 @@ from internlm.train import (
|
|||
from internlm.utils.common import (
|
||||
BatchSkipper,
|
||||
get_megatron_flops,
|
||||
get_megatron_flops_2,
|
||||
launch_time,
|
||||
parse_args,
|
||||
)
|
||||
|
@ -111,6 +112,18 @@ def main(args):
|
|||
global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
|
||||
mlp_ratio=gpc.config.MLP_RATIO,
|
||||
)
|
||||
|
||||
get_tflops_func_2 = partial(
|
||||
get_megatron_flops_2,
|
||||
checkpoint=gpc.config.model.checkpoint,
|
||||
seq_len=gpc.config.SEQ_LEN,
|
||||
hidden_size=gpc.config.model.hidden_size,
|
||||
num_layers=gpc.config.model.num_layers,
|
||||
vocab_size=gpc.config.model.vocab_size,
|
||||
global_batch_size=gpc.config.data.micro_bsz * gpc.config.data.micro_num * gpc.get_world_size(ParallelMode.DATA),
|
||||
global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
|
||||
mlp_ratio=gpc.config.MLP_RATIO,
|
||||
)
|
||||
|
||||
# get and broadcast current time
|
||||
current_time = launch_time()
|
||||
|
@ -271,6 +284,7 @@ def main(args):
|
|||
# calculate and record the training metrics, eg. loss, accuracy and so on.
|
||||
record_current_batch_training_metrics(
|
||||
get_tflops_func=get_tflops_func,
|
||||
get_tflops_func_2=get_tflops_func_2,
|
||||
logger=logger,
|
||||
writer=writer,
|
||||
success_update=success_update,
|
||||
|
|
Loading…
Reference in New Issue