fix(writer): fix tensorboard resume bug (#229)

pull/231/head
Guoteng 2023-08-24 17:38:39 +08:00 committed by GitHub
parent 04c02a61b2
commit 2acb278e1f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 42 additions and 17 deletions

View File

@ -38,6 +38,11 @@ class TrainState:
# Total step count
self.total_steps: int = config.data.total_steps
# Folder of the TensorBoard run to resume from; must be loaded from a checkpoint or set manually.
self.resume_tb_folder = config.resume_tb_folder
self.tensorboard_folder = config.tensorboard_folder
def init_batch_sampler(self, train_dl):
    """Store an independent copy of the DataLoader's batch sampler.

    A private copy is kept on the train state so it can be iterated and
    checkpointed without being affected by the live DataLoader object.

    Args:
        train_dl: DataLoader whose ``batch_sampler`` supports ``.copy()``.
    """
    # Snapshot rather than alias: mutations of the DataLoader's sampler
    # must not leak into the saved training state.
    self.batch_sampler = train_dl.batch_sampler.copy()
@ -76,6 +81,9 @@ class TrainState:
self.batch_sampler = train_dl.batch_sampler.copy()
self.batch_sampler_iter = iter(self.batch_sampler)
# Resume TensorBoard from the older tensorboard_folder recorded in the checkpoint.
self.resume_tb_folder = other_stuffs.get("tensorboard_folder", None)
def state_dict(self):
return {
"batch_count": self.batch_count,
@ -83,6 +91,7 @@ class TrainState:
"num_consumed_tokens": self.num_consumed_tokens,
"inf_nan_skip_batches": self.inf_nan_skip_batches,
"step_count": self.step_count,
"tensorboard_folder": self.tensorboard_folder,
}

View File

@ -200,6 +200,10 @@ and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'"
"resume_tb_folder", os.environ["resume_tb_folder"] if "resume_tb_folder" in os.environ else None
)
if gpc.is_rank_for_log():
logger.info(f"tensorboard_folder: {gpc.config.tensorboard_folder}")
logger.info(f"resume_tb_folder: {gpc.config.resume_tb_folder}")
# cudnn
torch.backends.cudnn.benchmark = gpc.config.get("cudnn_benchmark", False)
torch.backends.cudnn.deterministic = gpc.config.get("cudnn_deterministic", False)

View File

@ -38,10 +38,21 @@ def init_tb_writer(
tb_folder = tensorboard_folder
if gpc.get_global_rank() == 0:
# If we don't load ckpt, 'resume_tb_folder' is set as the tensorboard
# dir of the last task by 'make_launch_script.sh'.
# If we load ckpt, 'resume_tb_folder' will be overwritten with the
# reloaded 'train_state.resume_tb_folder'.
if resume_tb_folder is not None:
logger.info(f"Try mv tensorboard logs: {resume_tb_folder} to {tb_folder}...")
os.system(f"cp -r {resume_tb_folder}/* {tb_folder}/")
os.system(f"chmod -R +w {tb_folder}/")
assert len(resume_tb_folder) > 0 and resume_tb_folder != "/"
if not os.path.exists(resume_tb_folder):
logger.error(
f"Can't found resume_tb_folder{resume_tb_folder}, \
please make sure this folder is located at local file system."
)
else:
logger.info(f"Try mv tensorboard logs: {resume_tb_folder} to {tb_folder}... ")
os.system(f"cp -r {resume_tb_folder}/* {tb_folder}/")
os.system(f"chmod -R +w {tb_folder}/")
else:
logger.info(f"Login tensorboard logs to: {tb_folder}")

View File

@ -96,20 +96,6 @@ def main(args):
# initialize customed llm logger
uniscale_logger = initialize_llm_logger(start_time=current_time)
# initialize customed llm writer
with open(args.config, "r") as f:
config_lines = f.readlines()
writer = Writer(
job_name=gpc.config.JOB_NAME,
launch_time=current_time,
file_name=get_parallel_log_file_name(),
tensorboard_folder=gpc.config.tensorboard_folder,
resume_tb_folder=gpc.config.resume_tb_folder,
config=config_lines,
logger=logger,
enable_tb=gpc.config.enable_tb,
)
# initialize and resume train state
train_state = TrainState(gpc.config)
@ -139,6 +125,21 @@ def main(args):
# Loading other persistent training states.
ckpt_manager.try_resume_training(lr_scheduler, optimizer, lr, train_state, train_dl)
# initialize customed llm writer
with open(args.config, "r") as f:
config_lines = f.readlines()
writer = Writer(
job_name=gpc.config.JOB_NAME,
launch_time=current_time,
file_name=get_parallel_log_file_name(),
tensorboard_folder=gpc.config.tensorboard_folder,
resume_tb_folder=train_state.resume_tb_folder, # resume from ckpt.
step_count=train_state.step_count, # resume from ckpt.
config=config_lines,
logger=logger,
enable_tb=gpc.config.enable_tb,
)
# initialize metric for calculating accuracy and perplexity
metric = AccPerplex(
device=torch.cuda.current_device(),