mirror of https://github.com/InternLM/InternLM
fix(writer): fix tensorboard resume bug (#229)
parent
04c02a61b2
commit
2acb278e1f
|
@ -38,6 +38,11 @@ class TrainState:
|
||||||
# Total step count
|
# Total step count
|
||||||
self.total_steps: int = config.data.total_steps
|
self.total_steps: int = config.data.total_steps
|
||||||
|
|
||||||
|
# resume tensorboard folder, need load from checkpoint or set manually.
|
||||||
|
self.resume_tb_folder = config.resume_tb_folder
|
||||||
|
|
||||||
|
self.tensorboard_folder = config.tensorboard_folder
|
||||||
|
|
||||||
def init_batch_sampler(self, train_dl):
|
def init_batch_sampler(self, train_dl):
|
||||||
# Copy of the batch sampler from the DataLoader
|
# Copy of the batch sampler from the DataLoader
|
||||||
self.batch_sampler = train_dl.batch_sampler.copy()
|
self.batch_sampler = train_dl.batch_sampler.copy()
|
||||||
|
@ -76,6 +81,9 @@ class TrainState:
|
||||||
self.batch_sampler = train_dl.batch_sampler.copy()
|
self.batch_sampler = train_dl.batch_sampler.copy()
|
||||||
self.batch_sampler_iter = iter(self.batch_sampler)
|
self.batch_sampler_iter = iter(self.batch_sampler)
|
||||||
|
|
||||||
|
# resume tensorboard from older tensorboard_folder
|
||||||
|
self.resume_tb_folder = other_stuffs.get("tensorboard_folder", None)
|
||||||
|
|
||||||
def state_dict(self):
|
def state_dict(self):
|
||||||
return {
|
return {
|
||||||
"batch_count": self.batch_count,
|
"batch_count": self.batch_count,
|
||||||
|
@ -83,6 +91,7 @@ class TrainState:
|
||||||
"num_consumed_tokens": self.num_consumed_tokens,
|
"num_consumed_tokens": self.num_consumed_tokens,
|
||||||
"inf_nan_skip_batches": self.inf_nan_skip_batches,
|
"inf_nan_skip_batches": self.inf_nan_skip_batches,
|
||||||
"step_count": self.step_count,
|
"step_count": self.step_count,
|
||||||
|
"tensorboard_folder": self.tensorboard_folder,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -200,6 +200,10 @@ and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'"
|
||||||
"resume_tb_folder", os.environ["resume_tb_folder"] if "resume_tb_folder" in os.environ else None
|
"resume_tb_folder", os.environ["resume_tb_folder"] if "resume_tb_folder" in os.environ else None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if gpc.is_rank_for_log():
|
||||||
|
logger.info(f"tensorboard_folder: {gpc.config.tensorboard_folder}")
|
||||||
|
logger.info(f"resume_tb_folder: {gpc.config.resume_tb_folder}")
|
||||||
|
|
||||||
# cudnn
|
# cudnn
|
||||||
torch.backends.cudnn.benchmark = gpc.config.get("cudnn_benchmark", False)
|
torch.backends.cudnn.benchmark = gpc.config.get("cudnn_benchmark", False)
|
||||||
torch.backends.cudnn.deterministic = gpc.config.get("cudnn_deterministic", False)
|
torch.backends.cudnn.deterministic = gpc.config.get("cudnn_deterministic", False)
|
||||||
|
|
|
@ -38,10 +38,21 @@ def init_tb_writer(
|
||||||
tb_folder = tensorboard_folder
|
tb_folder = tensorboard_folder
|
||||||
|
|
||||||
if gpc.get_global_rank() == 0:
|
if gpc.get_global_rank() == 0:
|
||||||
|
# If we don't load ckpt, 'resume_tb_folder' is set as the tensorboard
|
||||||
|
# dir of the last task by 'make_launch_script.sh'.
|
||||||
|
# If we load ckpt, 'resume_tb_folder' will be overwritten as the
|
||||||
|
# reloaded 'train_state.resume_tb_folder'.s
|
||||||
if resume_tb_folder is not None:
|
if resume_tb_folder is not None:
|
||||||
logger.info(f"Try mv tensorboard logs: {resume_tb_folder} to {tb_folder}...")
|
assert len(resume_tb_folder) > 0 and resume_tb_folder != "/"
|
||||||
os.system(f"cp -r {resume_tb_folder}/* {tb_folder}/")
|
if not os.path.exists(resume_tb_folder):
|
||||||
os.system(f"chmod -R +w {tb_folder}/")
|
logger.error(
|
||||||
|
f"Can't found resume_tb_folder{resume_tb_folder}, \
|
||||||
|
please make sure this folder is located at local file system."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(f"Try mv tensorboard logs: {resume_tb_folder} to {tb_folder}... ")
|
||||||
|
os.system(f"cp -r {resume_tb_folder}/* {tb_folder}/")
|
||||||
|
os.system(f"chmod -R +w {tb_folder}/")
|
||||||
else:
|
else:
|
||||||
logger.info(f"Login tensorboard logs to: {tb_folder}")
|
logger.info(f"Login tensorboard logs to: {tb_folder}")
|
||||||
|
|
||||||
|
|
29
train.py
29
train.py
|
@ -96,20 +96,6 @@ def main(args):
|
||||||
# initialize customed llm logger
|
# initialize customed llm logger
|
||||||
uniscale_logger = initialize_llm_logger(start_time=current_time)
|
uniscale_logger = initialize_llm_logger(start_time=current_time)
|
||||||
|
|
||||||
# initialize customed llm writer
|
|
||||||
with open(args.config, "r") as f:
|
|
||||||
config_lines = f.readlines()
|
|
||||||
writer = Writer(
|
|
||||||
job_name=gpc.config.JOB_NAME,
|
|
||||||
launch_time=current_time,
|
|
||||||
file_name=get_parallel_log_file_name(),
|
|
||||||
tensorboard_folder=gpc.config.tensorboard_folder,
|
|
||||||
resume_tb_folder=gpc.config.resume_tb_folder,
|
|
||||||
config=config_lines,
|
|
||||||
logger=logger,
|
|
||||||
enable_tb=gpc.config.enable_tb,
|
|
||||||
)
|
|
||||||
|
|
||||||
# initialize and resume train state
|
# initialize and resume train state
|
||||||
train_state = TrainState(gpc.config)
|
train_state = TrainState(gpc.config)
|
||||||
|
|
||||||
|
@ -139,6 +125,21 @@ def main(args):
|
||||||
# Loading other persistent training states.
|
# Loading other persistent training states.
|
||||||
ckpt_manager.try_resume_training(lr_scheduler, optimizer, lr, train_state, train_dl)
|
ckpt_manager.try_resume_training(lr_scheduler, optimizer, lr, train_state, train_dl)
|
||||||
|
|
||||||
|
# initialize customed llm writer
|
||||||
|
with open(args.config, "r") as f:
|
||||||
|
config_lines = f.readlines()
|
||||||
|
writer = Writer(
|
||||||
|
job_name=gpc.config.JOB_NAME,
|
||||||
|
launch_time=current_time,
|
||||||
|
file_name=get_parallel_log_file_name(),
|
||||||
|
tensorboard_folder=gpc.config.tensorboard_folder,
|
||||||
|
resume_tb_folder=train_state.resume_tb_folder, # resume from ckpt.
|
||||||
|
step_count=train_state.step_count, # resume from ckpt.
|
||||||
|
config=config_lines,
|
||||||
|
logger=logger,
|
||||||
|
enable_tb=gpc.config.enable_tb,
|
||||||
|
)
|
||||||
|
|
||||||
# initialize metric for calculating accuracy and perplexity
|
# initialize metric for calculating accuracy and perplexity
|
||||||
metric = AccPerplex(
|
metric = AccPerplex(
|
||||||
device=torch.cuda.current_device(),
|
device=torch.cuda.current_device(),
|
||||||
|
|
Loading…
Reference in New Issue