mirror of https://github.com/InternLM/InternLM
feat(utils/writer.py): support tensorboard writer (#63)
* feat(utils/writer.py): support tensorboard writer * feat(utils/writer.py): add class comment --------- Co-authored-by: 黄婷 <huangting3@CN0014010744M.local>pull/120/head
parent
c7287e2584
commit
0d3d27cdf4
|
@ -127,6 +127,14 @@ def args_sanity_check():
|
|||
logger.info(f"save_ckpt_folder: {gpc.config.ckpt.save_ckpt_folder}")
|
||||
logger.info(f"checkpoint_every: {gpc.config.ckpt.checkpoint_every}")
|
||||
|
||||
# tensorboard writer config
|
||||
if "enable_tb" not in gpc.config:
|
||||
gpc.config._add_item("enable_tb", True)
|
||||
if "tensorboard_folder" not in gpc.config:
|
||||
gpc.config._add_item("tensorboard_folder", None)
|
||||
if "resume_tb_folder" not in gpc.config:
|
||||
gpc.config._add_item("resume_tb_folder", None)
|
||||
|
||||
# cudnn
|
||||
torch.backends.cudnn.benchmark = gpc.config.get("cudnn_benchmark", False)
|
||||
torch.backends.cudnn.deterministic = gpc.config.get("cudnn_deterministic", False)
|
||||
|
|
|
@ -0,0 +1,140 @@
|
|||
import logging
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
import traceback
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
|
||||
|
||||
def get_tb_log_file_name():
|
||||
if gpc.is_rank_for_log():
|
||||
tb_prefix = "main_" # Indicates a rank with more output information
|
||||
else:
|
||||
tb_prefix = ""
|
||||
|
||||
tb_log_file_name = (
|
||||
f"{tb_prefix}dp={gpc.get_local_rank(ParallelMode.DATA)}_"
|
||||
f"tp={gpc.get_local_rank(ParallelMode.TENSOR)}_pp={gpc.get_local_rank(ParallelMode.PIPELINE)}"
|
||||
)
|
||||
return tb_log_file_name
|
||||
|
||||
|
||||
def copy_ignore_folder(source_path, target_path):
|
||||
os.system(f"cp -r {source_path}/* {target_path}/")
|
||||
|
||||
|
||||
def tb_save_run_info(writer, config_lines, global_step=0):
|
||||
writer.add_text(tag="cmd", text_string=" ".join(sys.argv[:]), global_step=global_step)
|
||||
lines = []
|
||||
for line in config_lines:
|
||||
if line.strip().startswith("#"):
|
||||
continue
|
||||
lines.append(line)
|
||||
writer.add_text(tag="config", text_string="\n".join(lines), global_step=global_step)
|
||||
|
||||
|
||||
def init_tb_writer(
|
||||
launch_time,
|
||||
tensorboard_folder: str,
|
||||
resume_tb_folder: str,
|
||||
step_count: int,
|
||||
config: str,
|
||||
logger: logging.Logger,
|
||||
):
|
||||
tb_log_file_name = get_tb_log_file_name()
|
||||
if not tensorboard_folder:
|
||||
tb_folder = os.path.join(gpc.config.JOB_NAME, launch_time)
|
||||
else:
|
||||
tb_folder = tensorboard_folder
|
||||
|
||||
if gpc.get_global_rank() == 0:
|
||||
if resume_tb_folder is not None:
|
||||
logger.info(f"Try mv tensorboard logs: {resume_tb_folder} to {tb_folder}...")
|
||||
copy_ignore_folder(resume_tb_folder, tb_folder)
|
||||
else:
|
||||
logger.info(f"Login tensorboard logs to: {tb_folder}")
|
||||
|
||||
tb_logdir = os.path.join(tb_folder, tb_log_file_name)
|
||||
writer = SummaryWriter(log_dir=tb_logdir, max_queue=5, purge_step=step_count, flush_secs=3)
|
||||
writer.add_text(tag="job_name", text_string=gpc.config.JOB_NAME, global_step=step_count)
|
||||
writer.add_text(tag="tensorboard_folder", text_string=tb_logdir, global_step=step_count)
|
||||
|
||||
torch.distributed.broadcast_object_list([tb_folder], src=0)
|
||||
else:
|
||||
objects = [None]
|
||||
torch.distributed.broadcast_object_list(objects, src=0)
|
||||
tb_folder = objects[0]
|
||||
tb_logdir = os.path.join(tb_folder, tb_log_file_name)
|
||||
writer = SummaryWriter(log_dir=tb_logdir, max_queue=5, purge_step=step_count, flush_secs=3)
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
tb_save_run_info(
|
||||
writer=writer,
|
||||
config_lines=config,
|
||||
global_step=step_count,
|
||||
)
|
||||
|
||||
writer.add_text(
|
||||
tag=f"mapping_{tb_log_file_name}",
|
||||
text_string=f"file_path={tb_logdir} hostname={socket.gethostname()} device={torch.cuda.current_device()}",
|
||||
global_step=step_count,
|
||||
)
|
||||
writer.add_scaler = partial(writer.add_scalar, new_style=True)
|
||||
|
||||
return writer, tb_logdir
|
||||
|
||||
|
||||
class Writer:
|
||||
"""
|
||||
Customed writer based on tensorboard for recording training metrics.
|
||||
|
||||
Args:
|
||||
launch_time (str): A string representing the launch time of the training.
|
||||
tensorboard_folder (str): A string representing the folder for saving tensorboard logs.
|
||||
resume_tb_folder (str): A string representing the folder for resuming tensorboard logs.
|
||||
step_count (int): An integer representing the step count of the training.
|
||||
config (str): A string representing the configuration of the training.
|
||||
logger (logging.Logger): A logging.Logger object for logging information during training.
|
||||
enable_tb (bool): A boolean indicating whether to enable the tensorboard writer.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
launch_time: str,
|
||||
tensorboard_folder: str = None,
|
||||
resume_tb_folder: str = None,
|
||||
step_count: int = 0,
|
||||
config: str = None,
|
||||
logger: logging.Logger = None,
|
||||
enable_tb: bool = True,
|
||||
) -> None:
|
||||
self.enable_tb = enable_tb
|
||||
self.tb_writer, self.tb_logdir = init_tb_writer(
|
||||
launch_time=launch_time,
|
||||
tensorboard_folder=tensorboard_folder,
|
||||
resume_tb_folder=resume_tb_folder,
|
||||
step_count=step_count,
|
||||
config=config,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
def add_scalar(self, key, value, step):
|
||||
try:
|
||||
if self.enable_tb and self.tb_writer is not None:
|
||||
self.tb_writer.add_scalar(tag=key, scalar_value=value, global_step=step)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
def add_text(self, key, value, step):
|
||||
try:
|
||||
if self.enable_tb and self.tb_writer is not None:
|
||||
self.tb_writer.add_text(tag=key, text_string=value, global_step=step)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
26
train.py
26
train.py
|
@ -54,6 +54,7 @@ from internlm.utils.parallel import (
|
|||
sync_model_param_within_tp,
|
||||
)
|
||||
from internlm.utils.registry import MODEL_INITIALIZER
|
||||
from internlm.utils.writer import Writer
|
||||
|
||||
# global llm logger
|
||||
logger = get_logger(__file__)
|
||||
|
@ -246,6 +247,7 @@ def initialize_optimizer(model: nn.Module):
|
|||
def record_current_batch_training_metrics(
|
||||
get_tflops_func,
|
||||
logger,
|
||||
writer,
|
||||
success_update,
|
||||
batch_count,
|
||||
batch,
|
||||
|
@ -307,12 +309,13 @@ def record_current_batch_training_metrics(
|
|||
infos["smallest_batch"] = min_samples_in_batch
|
||||
infos["adam_beta2"] = beta2_scheduler.get_beta2()
|
||||
|
||||
line = ""
|
||||
for k, v in infos.items():
|
||||
line += f"{k}={v},"
|
||||
|
||||
fwd_bwd_time = round(timer("fwd-bwd").elapsed(), 2)
|
||||
line += f"fwd_bwd_time={fwd_bwd_time}"
|
||||
infos["fwd_bwd_time"] = fwd_bwd_time
|
||||
|
||||
line = ""
|
||||
for key, value in infos.items():
|
||||
line += f"{key}={value},"
|
||||
writer.add_scalar(key=key, value=value, step=train_state.step_count)
|
||||
|
||||
logger.info(line)
|
||||
|
||||
|
@ -355,6 +358,18 @@ def main(args):
|
|||
dist.broadcast_object_list(objs, src=0)
|
||||
current_time = objs[0]
|
||||
|
||||
# initialize customed llm writer
|
||||
with open(args.config, "r") as f:
|
||||
config_lines = f.readlines()
|
||||
writer = Writer(
|
||||
launch_time=current_time,
|
||||
tensorboard_folder=gpc.config.tensorboard_folder,
|
||||
resume_tb_folder=gpc.config.resume_tb_folder,
|
||||
config=config_lines,
|
||||
logger=logger,
|
||||
enable_tb=gpc.config.enable_tb,
|
||||
)
|
||||
|
||||
model_load_path = None
|
||||
if load_resume_ckpt_folder is not None:
|
||||
logger.info(
|
||||
|
@ -469,6 +484,7 @@ def main(args):
|
|||
record_current_batch_training_metrics(
|
||||
get_tflops_func=get_tflops_func,
|
||||
logger=logger,
|
||||
writer=writer,
|
||||
success_update=success_update,
|
||||
batch_count=batch_count,
|
||||
batch=batch,
|
||||
|
|
Loading…
Reference in New Issue