mirror of https://github.com/InternLM/InternLM
auto save
parent bbb5651582
commit 81f51fd0ff
@@ -38,10 +38,9 @@ ckpt = dict(
     # If you want to initialize your model weights from another model, you must set `auto_resume` to False.
     # If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None.
     auto_resume=True,
-    checkpoint_every=CHECKPOINT_EVERY,
+    checkpoint_every=CHECKPOINT_EVERY, # if checkpoint_every is not set, auto checkpoint_every will be executed
     async_upload=True, # async ckpt upload. (only work for boto3 ckpt)
     async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload.
-    oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
 )
 
 TRAIN_FOLDER = None # "/path/to/dataset"
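Taken together with the args_sanity_check() change further down (a missing or non-positive checkpoint_every becomes "auto"), the config can now simply omit the save frequency. A minimal sketch of such a ckpt dict; the folder path and flag values here are illustrative and not part of this commit:

```python
# Illustrative ckpt config relying on the new auto frequency (placeholder values, not from the commit).
ckpt = dict(
    enable_save_ckpt=True,
    save_ckpt_folder="local:llm_ckpts",  # assumed placeholder path
    auto_resume=True,
    # checkpoint_every is omitted, so args_sanity_check() will set it to "auto"; the training loop
    # later replaces "auto" with a step count derived from the measured step time and LLM_CKPT_SAVE_TIME.
    async_upload=True,
    async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/",
    # oss_snapshot_freq is also left unset and resolved at runtime.
)
```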
@@ -78,6 +78,13 @@ class TrainState:
             "last_tgs_50": 0,
         }
 
+        # ckpt statistic
+        self.ckpt_statistic = {
+            "total_step": 0,
+            "sum_step": 0,
+            "sum_time": 0,
+        }
+
     def init_batch_sampler(self, batch_sampler):
         """
         Args:
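The three counters added to TrainState are filled in later by record_current_batch_training_metrics(). A self-contained sketch of the averaging they support, with a dummy sleep standing in for a training step; in the real code the accumulation runs only on global rank 0 and is additionally gated on batch_count >= 5:

```python
import time

# Sketch of the bookkeeping behind ckpt_statistic (illustrative only).
ckpt_statistic = {"total_step": 0, "sum_step": 0, "sum_time": 0}

for _ in range(10):
    start_time = time.time()
    time.sleep(0.01)  # dummy workload standing in for one training step
    ckpt_statistic["total_step"] += 1
    if ckpt_statistic["total_step"] < 10:
        ckpt_statistic["sum_step"] += 1
        ckpt_statistic["sum_time"] += time.time() - start_time

avg_step_time = ckpt_statistic["sum_time"] / ckpt_statistic["sum_step"]
print(f"average step time over {ckpt_statistic['sum_step']} steps: {avg_step_time:.3f}s")
```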
@@ -169,8 +169,10 @@ def args_sanity_check():
 
     # Saving checkpoint args.
     if ckpt.enable_save_ckpt:
-        assert "checkpoint_every" in ckpt, "If enable save checkpoint, must give checkpoint_every in config.data!"
-        assert ckpt.checkpoint_every > 0
+        # assert "checkpoint_every" in ckpt, "If enable save checkpoint, must give checkpoint_every in config.data!"
+        # assert ckpt.checkpoint_every > 0
+        if "checkpoint_every" not in ckpt or ckpt.checkpoint_every <= 0:
+            ckpt.checkpoint_every = "auto"
         assert "save_ckpt_folder" in ckpt, "If enable save checkpoint, must give save_ckpt_folder in config.data!"
 
         if "async_upload" not in ckpt:
@@ -192,8 +194,13 @@ def args_sanity_check():
         if not ckpt.async_upload:
             ckpt._add_item("async_upload_tmp_folder", None)
 
-        if "oss_snapshot_freq" not in ckpt:
-            ckpt._add_item("oss_snapshot_freq", float("inf")) # if oss_snapshot_freq not given, we disable.
+        if ckpt.checkpoint_every != "auto":
+            if "oss_snapshot_freq" not in ckpt:
+                ckpt._add_item("oss_snapshot_freq", "auto")
+            else:
+                ckpt.oss_snapshot_freq = "auto"
+        else:
+            ckpt.oss_snapshot_freq = float("inf")
     else:
         ckpt._add_item("checkpoint_every", float("inf"))
         ckpt._add_item("oss_snapshot_freq", float("inf"))
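Restating the two args_sanity_check() hunks as one function makes the resulting defaults easier to see. This is an illustrative condensation of the diff above, not the repository's API:

```python
def resolve_ckpt_frequencies(ckpt: dict) -> dict:
    """Illustrative condensation of the args_sanity_check() changes above (not the repo's API)."""
    if ckpt.get("enable_save_ckpt"):
        if not ckpt.get("checkpoint_every") or ckpt["checkpoint_every"] <= 0:
            ckpt["checkpoint_every"] = "auto"   # resolved at runtime from the measured step time
        if ckpt["checkpoint_every"] != "auto":
            ckpt["oss_snapshot_freq"] = "auto"  # snapshot frequency is also resolved at runtime
        else:
            ckpt["oss_snapshot_freq"] = float("inf")  # no snapshots while checkpoint_every itself is auto
    else:
        ckpt["checkpoint_every"] = float("inf")
        ckpt["oss_snapshot_freq"] = float("inf")
    return ckpt

print(resolve_ckpt_frequencies({"enable_save_ckpt": True, "checkpoint_every": 50}))
# -> checkpoint_every stays 50, oss_snapshot_freq becomes "auto"
```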
@@ -427,6 +427,44 @@ def record_current_batch_training_metrics(
     if gpc.is_no_pp_or_last_stage():
         acc_perplex = metric.get_metric()
 
+    # compute auto save_frequency
+    if success_update and (gpc.config.ckpt.checkpoint_every == "auto" or gpc.config.ckpt.oss_snapshot_freq == "auto"):
+        ckpt_statistic = train_state.ckpt_statistic
+        ckpt_statistic["total_step"] += 1
+
+        # only global rank 0 needs to compute save_frequency
+        if gpc.get_global_rank() == 0 and ckpt_statistic["total_step"] < 10 and batch_count >= 5:
+            ckpt_statistic["sum_step"] += 1
+            ckpt_statistic["sum_time"] += time.time() - start_time
+
+        # broadcast and assign save_frequency at the 10th step from start
+        elif ckpt_statistic["total_step"] == 10:
+
+            # compute save_frequency
+            if gpc.get_global_rank() == 0:
+                avg_step_time = ckpt_statistic["sum_time"] / ckpt_statistic["sum_step"]
+                check_time = int(os.getenv("LLM_CKPT_SAVE_TIME", "1200"))
+                save_frequency = torch.tensor(
+                    [int(10 * -(-check_time // (avg_step_time * 10)))], device=torch.device("cuda")
+                )
+            else:
+                save_frequency = torch.tensor([-1], device=torch.device("cuda"))
+
+            ranks = gpc.get_ranks_in_group(ParallelMode.GLOBAL)
+            dist.broadcast(save_frequency, src=ranks[0], group=gpc.get_group(ParallelMode.GLOBAL))
+            save_frequency = int(save_frequency[0])
+
+            # assign save_frequency
+            # when the "checkpoint_every" is "auto", no snapshot will be implemented
+            # when the "save_frequency" is not less than the "checkpoint_every" passed in, no snapshot will be implemented
+            if gpc.config.ckpt.checkpoint_every == "auto":
+                gpc.config.ckpt.checkpoint_every = save_frequency
+            else:
+                if save_frequency < gpc.config.ckpt.checkpoint_every:
+                    gpc.config.ckpt.oss_snapshot_freq = save_frequency
+                else:
+                    gpc.config.ckpt.oss_snapshot_freq = float("inf")
+
     if success_update and gpc.is_rank_for_log():
         lr = optimizer.param_groups[0]["lr"]
         if hasattr(trainer.engine.optimizer, "grad_scaler"):
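The expression int(10 * -(-check_time // (avg_step_time * 10))) rounds the desired save interval up to a whole multiple of 10 steps. A small worked sketch with assumed timings (7 s per step and the default LLM_CKPT_SAVE_TIME of 1200 s):

```python
# Worked example of the auto save_frequency formula with assumed numbers.
avg_step_time = 7.0  # assumed measured average, seconds per step
check_time = 1200    # default LLM_CKPT_SAVE_TIME, i.e. save roughly every 20 minutes

# -(-a // b) is ceiling division, so this is ceil(1200 / 70) * 10 = 18 * 10 = 180.
save_frequency = int(10 * -(-check_time // (avg_step_time * 10)))
assert save_frequency == 180  # a checkpoint roughly every 180 * 7 s ≈ 21 minutes
```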
@@ -440,6 +478,7 @@ def record_current_batch_training_metrics(
         max_samples_in_batch = max([len(b) - 1 for b in batch[0]["cu_seqlens"]])
         min_samples_in_batch = min([len(b) - 1 for b in batch[0]["cu_seqlens"]])
         time_cost = time.time() - start_time
+
         tk_per_gpu = round(
             num_tokens_in_batch * gpc.get_world_size(ParallelMode.DATA) / gpc.get_world_size(ParallelMode.GLOBAL),
             4,
@@ -507,6 +546,7 @@ def record_current_batch_training_metrics(
             "loss_scale": scaler,
             "grad_norm": grad_norm,
         }
 
         if moe_loss is not None:
             infos["moe_loss"] = moe_loss.item()
+
@@ -971,7 +971,11 @@ now step_count is {train_state.step_count}",
 
     def is_now_to_save_ckpt(self, train_state) -> (bool, CheckpointSaveType, bool):
         save_ckpts, save_type, now_break = False, CheckpointSaveType.NORMAL_CHECKPOINT, False
-        if self.oss_snapshot_freq > 1 and train_state.step_count % self.oss_snapshot_freq == 0:
+        if (
+            self.oss_snapshot_freq != "auto"
+            and self.oss_snapshot_freq > 1
+            and train_state.step_count % self.oss_snapshot_freq == 0
+        ):
             save_ckpts, save_type = True, CheckpointSaveType.SNAPSHOT_CHECKPOINT
         if train_state.step_count % self.checkpoint_every == 0 or train_state.step_count == train_state.total_steps:
             save_ckpts, save_type = True, CheckpointSaveType.NORMAL_CHECKPOINT
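Once the "auto" placeholders have been replaced by concrete step counts, the decision above behaves as before. A standalone sketch with assumed frequencies (180-step snapshots, 360-step full checkpoints, 1000 total steps):

```python
# Illustrative decision logic after "auto" has been resolved (values assumed, not from the commit).
oss_snapshot_freq, checkpoint_every, total_steps = 180, 360, 1000

def what_happens_at(step_count: int) -> str:
    save_type = "none"
    if oss_snapshot_freq != "auto" and oss_snapshot_freq > 1 and step_count % oss_snapshot_freq == 0:
        save_type = "snapshot"
    if step_count % checkpoint_every == 0 or step_count == total_steps:
        save_type = "normal"  # a full checkpoint takes precedence over a snapshot
    return save_type

assert what_happens_at(180) == "snapshot"
assert what_happens_at(360) == "normal"
assert what_happens_at(1000) == "normal"  # the last step always saves
assert what_happens_at(181) == "none"
```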
@@ -986,6 +990,15 @@ now step_count is {train_state.step_count}",
         if not self.enable_save_ckpt:
             return False
 
+        if self.checkpoint_every == "auto":
+            if gpc.config.ckpt.checkpoint_every == "auto":
+                return False
+            else:
+                self.checkpoint_every = gpc.config.ckpt.checkpoint_every
+        elif self.oss_snapshot_freq == "auto":
+            if gpc.config.ckpt.oss_snapshot_freq != "auto":
+                self.oss_snapshot_freq = gpc.config.ckpt.oss_snapshot_freq
+
         save_ckpts, save_type, now_break = self.is_now_to_save_ckpt(train_state)
 
         if save_ckpts:
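The manager-side check above and the training-loop hunk earlier form a small handshake: the loop rewrites gpc.config.ckpt once the average step time is known, and the checkpoint manager copies the resolved value the next time it considers saving, declining to save while checkpoint_every is still the "auto" placeholder. A compressed sketch of that ordering, with assumed numbers:

```python
# Assumed ordering of the "auto" handshake (illustrative, not code from the commit).
ckpt = {"checkpoint_every": "auto", "oss_snapshot_freq": float("inf")}  # after args_sanity_check()

# steps 1-9: manager checks see checkpoint_every == "auto" and decline to save;
#            rank 0 accumulates per-step wall-clock time in train_state.ckpt_statistic.
# step 10:   the loop computes save_frequency (e.g. 180), broadcasts it, and overwrites the placeholder.
ckpt["checkpoint_every"] = 180

# next check: the manager copies 180 into self.checkpoint_every and normal saving resumes.
```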