mirror of https://github.com/InternLM/InternLM
add flash tflops
parent
4d83e1021b
commit
8aefb74e02
|
@ -406,11 +406,13 @@ def initialize_llm_profile(profiling: bool = False, start_time: str = None):
|
||||||
|
|
||||||
tgs_list = []
|
tgs_list = []
|
||||||
tflops_list = []
|
tflops_list = []
|
||||||
|
tflops_list_2 = []
|
||||||
|
|
||||||
|
|
||||||
@llm_timeout(func_name="record_current_batch_training_metrics")
|
@llm_timeout(func_name="record_current_batch_training_metrics")
|
||||||
def record_current_batch_training_metrics(
|
def record_current_batch_training_metrics(
|
||||||
get_tflops_func,
|
get_tflops_func,
|
||||||
|
get_tflops_func_2,
|
||||||
logger,
|
logger,
|
||||||
writer,
|
writer,
|
||||||
success_update,
|
success_update,
|
||||||
|
@ -495,6 +497,7 @@ def record_current_batch_training_metrics(
|
||||||
tgs_SMA = round(tgs_statistic["SMA_tg_50"] / tgs_statistic["SMA_time_50"], 2)
|
tgs_SMA = round(tgs_statistic["SMA_tg_50"] / tgs_statistic["SMA_time_50"], 2)
|
||||||
|
|
||||||
tflops = get_tflops_func((time.time() - start_time))
|
tflops = get_tflops_func((time.time() - start_time))
|
||||||
|
tflops_2 = get_tflops_func_2((time.time() - start_time))
|
||||||
|
|
||||||
tgs_origin = round(
|
tgs_origin = round(
|
||||||
num_tokens_in_batch
|
num_tokens_in_batch
|
||||||
|
@ -506,6 +509,7 @@ def record_current_batch_training_metrics(
|
||||||
|
|
||||||
infos = {
|
infos = {
|
||||||
"tflops": tflops,
|
"tflops": tflops,
|
||||||
|
"tflops2": tflops_2,
|
||||||
"step": batch_count,
|
"step": batch_count,
|
||||||
"loss": loss.item() - moe_loss.item() if moe_loss is not None else loss.item(),
|
"loss": loss.item() - moe_loss.item() if moe_loss is not None else loss.item(),
|
||||||
"tgs (tokens/gpu/second)": tgs_origin,
|
"tgs (tokens/gpu/second)": tgs_origin,
|
||||||
|
@ -599,6 +603,7 @@ def record_current_batch_training_metrics(
|
||||||
if batch_count >= 5:
|
if batch_count >= 5:
|
||||||
tgs_list.append(tgs_origin)
|
tgs_list.append(tgs_origin)
|
||||||
tflops_list.append(tflops)
|
tflops_list.append(tflops)
|
||||||
|
tflops_list_2.append(tflops_2)
|
||||||
if batch_count == gpc.config.data.total_steps - 1:
|
if batch_count == gpc.config.data.total_steps - 1:
|
||||||
print(tgs_list, flush=True)
|
print(tgs_list, flush=True)
|
||||||
avg_tgs = sum(tgs_list) / len(tgs_list)
|
avg_tgs = sum(tgs_list) / len(tgs_list)
|
||||||
|
@ -606,9 +611,17 @@ def record_current_batch_training_metrics(
|
||||||
if abs(tgs - avg_tgs) > 400:
|
if abs(tgs - avg_tgs) > 400:
|
||||||
tgs_list.remove(tgs)
|
tgs_list.remove(tgs)
|
||||||
print(f"avg_tgs: {sum(tgs_list)/len(tgs_list)}", flush=True)
|
print(f"avg_tgs: {sum(tgs_list)/len(tgs_list)}", flush=True)
|
||||||
|
|
||||||
print(tflops_list, flush=True)
|
print(tflops_list, flush=True)
|
||||||
avg_tflops = sum(tflops_list) / len(tflops_list)
|
avg_tflops = sum(tflops_list) / len(tflops_list)
|
||||||
for tf in tflops_list.copy():
|
for tf in tflops_list.copy():
|
||||||
if abs(tf - avg_tflops) > 10:
|
if abs(tf - avg_tflops) > 10:
|
||||||
tflops_list.remove(tf)
|
tflops_list.remove(tf)
|
||||||
print(f"avg_tflops: {sum(tflops_list)/len(tflops_list)}", flush=True)
|
print(f"avg_tflops: {sum(tflops_list)/len(tflops_list)}", flush=True)
|
||||||
|
|
||||||
|
print(tflops_list_2, flush=True)
|
||||||
|
avg_tflops_2 = sum(tflops_list_2) / len(tflops_list_2)
|
||||||
|
for tf in tflops_list_2.copy():
|
||||||
|
if abs(tf - avg_tflops_2) > 10:
|
||||||
|
tflops_list_2.remove(tf)
|
||||||
|
print(f"avg_tflops: {sum(tflops_list_2)/len(tflops_list_2)}", flush=True)
|
||||||
|
|
14
train.py
14
train.py
|
@ -33,6 +33,7 @@ from internlm.train import (
|
||||||
from internlm.utils.common import (
|
from internlm.utils.common import (
|
||||||
BatchSkipper,
|
BatchSkipper,
|
||||||
get_megatron_flops,
|
get_megatron_flops,
|
||||||
|
get_megatron_flops_2,
|
||||||
launch_time,
|
launch_time,
|
||||||
parse_args,
|
parse_args,
|
||||||
)
|
)
|
||||||
|
@ -111,6 +112,18 @@ def main(args):
|
||||||
global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
|
global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
|
||||||
mlp_ratio=gpc.config.MLP_RATIO,
|
mlp_ratio=gpc.config.MLP_RATIO,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
get_tflops_func_2 = partial(
|
||||||
|
get_megatron_flops_2,
|
||||||
|
checkpoint=gpc.config.model.checkpoint,
|
||||||
|
seq_len=gpc.config.SEQ_LEN,
|
||||||
|
hidden_size=gpc.config.model.hidden_size,
|
||||||
|
num_layers=gpc.config.model.num_layers,
|
||||||
|
vocab_size=gpc.config.model.vocab_size,
|
||||||
|
global_batch_size=gpc.config.data.micro_bsz * gpc.config.data.micro_num * gpc.get_world_size(ParallelMode.DATA),
|
||||||
|
global_world_size=gpc.get_world_size(ParallelMode.GLOBAL),
|
||||||
|
mlp_ratio=gpc.config.MLP_RATIO,
|
||||||
|
)
|
||||||
|
|
||||||
# get and broadcast current time
|
# get and broadcast current time
|
||||||
current_time = launch_time()
|
current_time = launch_time()
|
||||||
|
@ -271,6 +284,7 @@ def main(args):
|
||||||
# calculate and record the training metrics, eg. loss, accuracy and so on.
|
# calculate and record the training metrics, eg. loss, accuracy and so on.
|
||||||
record_current_batch_training_metrics(
|
record_current_batch_training_metrics(
|
||||||
get_tflops_func=get_tflops_func,
|
get_tflops_func=get_tflops_func,
|
||||||
|
get_tflops_func_2=get_tflops_func_2,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
writer=writer,
|
writer=writer,
|
||||||
success_update=success_update,
|
success_update=success_update,
|
||||||
|
|
Loading…
Reference in New Issue