mirror of https://github.com/InternLM/InternLM
fix(writer): fix tensorboard resume bug (#229)
parent 04c02a61b2
commit 2acb278e1f
@@ -38,6 +38,11 @@ class TrainState:
         # Total step count
         self.total_steps: int = config.data.total_steps
 
+        # resume tensorboard folder, needs to be loaded from a checkpoint or set manually.
+        self.resume_tb_folder = config.resume_tb_folder
+
+        self.tensorboard_folder = config.tensorboard_folder
+
     def init_batch_sampler(self, train_dl):
         # Copy of the batch sampler from the DataLoader
         self.batch_sampler = train_dl.batch_sampler.copy()
@@ -76,6 +81,9 @@ class TrainState:
         self.batch_sampler = train_dl.batch_sampler.copy()
         self.batch_sampler_iter = iter(self.batch_sampler)
 
+        # resume tensorboard from the older tensorboard_folder
+        self.resume_tb_folder = other_stuffs.get("tensorboard_folder", None)
+
     def state_dict(self):
         return {
             "batch_count": self.batch_count,
@@ -83,6 +91,7 @@ class TrainState:
             "num_consumed_tokens": self.num_consumed_tokens,
             "inf_nan_skip_batches": self.inf_nan_skip_batches,
             "step_count": self.step_count,
+            "tensorboard_folder": self.tensorboard_folder,
         }
 
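Taken together, the three TrainState hunks make the tensorboard folder round-trip through the checkpoint: the folder the current run writes to is saved under "tensorboard_folder" in state_dict(), and on reload it resurfaces as resume_tb_folder, the folder whose logs should be carried over. A minimal sketch of that round trip, using a stripped-down stand-in for TrainState (the TrainStateSketch name and the SimpleNamespace config are illustrative only, not InternLM code):

# Minimal sketch of the resume_tb_folder round trip; the real TrainState
# carries many more fields. SimpleNamespace stands in for gpc.config.
from types import SimpleNamespace


class TrainStateSketch:
    def __init__(self, config):
        self.step_count = 0
        self.resume_tb_folder = config.resume_tb_folder
        self.tensorboard_folder = config.tensorboard_folder

    def state_dict(self):
        # The folder we are writing to now is saved under "tensorboard_folder".
        return {"step_count": self.step_count, "tensorboard_folder": self.tensorboard_folder}

    def load_state_dict(self, other_stuffs):
        # On resume, the previous run's folder becomes the folder to copy from.
        self.step_count = other_stuffs["step_count"]
        self.resume_tb_folder = other_stuffs.get("tensorboard_folder", None)


# Run 1 writes to run1/tb; its checkpoint remembers that path.
run1 = TrainStateSketch(SimpleNamespace(resume_tb_folder=None, tensorboard_folder="run1/tb"))
ckpt = run1.state_dict()

# Run 2 writes to run2/tb but learns from the checkpoint where run 1's logs live.
run2 = TrainStateSketch(SimpleNamespace(resume_tb_folder=None, tensorboard_folder="run2/tb"))
run2.load_state_dict(ckpt)
assert run2.resume_tb_folder == "run1/tb"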
@@ -200,6 +200,10 @@ and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'"
         "resume_tb_folder", os.environ["resume_tb_folder"] if "resume_tb_folder" in os.environ else None
     )
 
+    if gpc.is_rank_for_log():
+        logger.info(f"tensorboard_folder: {gpc.config.tensorboard_folder}")
+        logger.info(f"resume_tb_folder: {gpc.config.resume_tb_folder}")
+
     # cudnn
     torch.backends.cudnn.benchmark = gpc.config.get("cudnn_benchmark", False)
     torch.backends.cudnn.deterministic = gpc.config.get("cudnn_deterministic", False)
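The inline conditional in the hunk above, os.environ["resume_tb_folder"] if "resume_tb_folder" in os.environ else None, is equivalent to a plain os.environ.get; a quick illustrative sketch (not code from the commit):

import os

# Sketch of the fallback pattern: take the folder from the environment when
# the launcher exported it, otherwise leave it unset (None).
resume_tb_folder = os.environ.get("resume_tb_folder")
print(resume_tb_folder)  # None unless the launch script exported it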
@@ -38,10 +38,21 @@ def init_tb_writer(
     tb_folder = tensorboard_folder
 
     if gpc.get_global_rank() == 0:
+        # If we don't load a ckpt, 'resume_tb_folder' is set to the tensorboard
+        # dir of the last task by 'make_launch_script.sh'.
+        # If we do load a ckpt, 'resume_tb_folder' is overwritten by the
+        # reloaded 'train_state.resume_tb_folder'.
         if resume_tb_folder is not None:
-            logger.info(f"Try mv tensorboard logs: {resume_tb_folder} to {tb_folder}...")
-            os.system(f"cp -r {resume_tb_folder}/* {tb_folder}/")
-            os.system(f"chmod -R +w {tb_folder}/")
+            assert len(resume_tb_folder) > 0 and resume_tb_folder != "/"
+            if not os.path.exists(resume_tb_folder):
+                logger.error(
+                    f"Can't find resume_tb_folder {resume_tb_folder}, "
+                    "please make sure this folder is located on the local file system."
+                )
+            else:
+                logger.info(f"Try to mv tensorboard logs: {resume_tb_folder} to {tb_folder}...")
+                os.system(f"cp -r {resume_tb_folder}/* {tb_folder}/")
+                os.system(f"chmod -R +w {tb_folder}/")
+        else:
+            logger.info(f"Logging tensorboard logs to: {tb_folder}")
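The new guard refuses an empty or root resume_tb_folder before running cp -r, and downgrades a missing source directory from a silent no-op to an error log. The same guard-then-copy logic can be expressed without shelling out, using shutil instead of os.system; this is an alternative sketch, not the code the commit ships (the helper name copy_resumed_tb_logs is made up):

import logging
import os
import shutil

logger = logging.getLogger(__name__)


def copy_resumed_tb_logs(resume_tb_folder: str, tb_folder: str) -> None:
    """Sketch: copy a previous run's tensorboard logs into the current folder."""
    # Refuse obviously dangerous sources, mirroring the commit's assert.
    assert resume_tb_folder and resume_tb_folder != "/"
    if not os.path.exists(resume_tb_folder):
        logger.error("Can't find resume_tb_folder %s on the local file system.", resume_tb_folder)
        return
    # dirs_exist_ok lets us merge into an already-created tb_folder (Python 3.8+).
    shutil.copytree(resume_tb_folder, tb_folder, dirs_exist_ok=True)
    # Make the copied event files writable for the owner, like `chmod -R +w`.
    for root, _dirs, files in os.walk(tb_folder):
        for name in files:
            path = os.path.join(root, name)
            os.chmod(path, os.stat(path).st_mode | 0o200)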
train.py (29 changed lines)
@@ -96,20 +96,6 @@ def main(args):
     # initialize customized llm logger
     uniscale_logger = initialize_llm_logger(start_time=current_time)
 
-    # initialize customized llm writer
-    with open(args.config, "r") as f:
-        config_lines = f.readlines()
-    writer = Writer(
-        job_name=gpc.config.JOB_NAME,
-        launch_time=current_time,
-        file_name=get_parallel_log_file_name(),
-        tensorboard_folder=gpc.config.tensorboard_folder,
-        resume_tb_folder=gpc.config.resume_tb_folder,
-        config=config_lines,
-        logger=logger,
-        enable_tb=gpc.config.enable_tb,
-    )
-
     # initialize and resume train state
     train_state = TrainState(gpc.config)
@@ -139,6 +125,21 @@ def main(args):
     # Loading other persistent training states.
     ckpt_manager.try_resume_training(lr_scheduler, optimizer, lr, train_state, train_dl)
 
+    # initialize customized llm writer
+    with open(args.config, "r") as f:
+        config_lines = f.readlines()
+    writer = Writer(
+        job_name=gpc.config.JOB_NAME,
+        launch_time=current_time,
+        file_name=get_parallel_log_file_name(),
+        tensorboard_folder=gpc.config.tensorboard_folder,
+        resume_tb_folder=train_state.resume_tb_folder,  # resume from ckpt.
+        step_count=train_state.step_count,  # resume from ckpt.
+        config=config_lines,
+        logger=logger,
+        enable_tb=gpc.config.enable_tb,
+    )
+
     # initialize metric for calculating accuracy and perplexity
     metric = AccPerplex(
         device=torch.cuda.current_device(),
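The actual bug fix in train.py is the reordering: Writer used to be constructed before the checkpoint was loaded, so it could only see gpc.config.resume_tb_folder; it now runs after ckpt_manager.try_resume_training, consuming the resume_tb_folder and step_count recovered into train_state. Passing the resumed step count matters because torch's SummaryWriter accepts it as purge_step, discarding previously logged events at or after that step so a rerun after a crash does not leave duplicate, diverging curves. Whether InternLM's Writer forwards step_count as purge_step is not visible in this diff; the sketch below only shows the standard mechanism (paths and values are made up):

from torch.utils.tensorboard import SummaryWriter

# Sketch only: resume TensorBoard logging at a restored step.
resumed_step = 1200  # illustrative; would come from train_state.step_count
writer = SummaryWriter(log_dir="run/tb", purge_step=resumed_step)
writer.add_scalar("train/loss", 2.31, global_step=resumed_step)
writer.close()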