pull/546/head
lijiaxing 2023-12-20 17:22:21 +08:00
parent 55ef29df80
commit 55c7dd513d
4 changed files with 19 additions and 32 deletions

View File

@ -41,6 +41,7 @@ ckpt = dict(
checkpoint_every=CHECKPOINT_EVERY, # if checkpoint_every is not set, auto checkpoint_every will be executed checkpoint_every=CHECKPOINT_EVERY, # if checkpoint_every is not set, auto checkpoint_every will be executed
async_upload=True, # async ckpt upload. (only work for boto3 ckpt) async_upload=True, # async ckpt upload. (only work for boto3 ckpt)
async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload.
auto_save_time=1200, # time to control oss_snapshot_freq. If not set, 1200 seconds by default.
) )
TRAIN_FOLDER = None # "/path/to/dataset" TRAIN_FOLDER = None # "/path/to/dataset"

View File

@ -169,10 +169,8 @@ def args_sanity_check():
# Saving checkpoint args. # Saving checkpoint args.
if ckpt.enable_save_ckpt: if ckpt.enable_save_ckpt:
# assert "checkpoint_every" in ckpt, "If enable save checkpoint, must give checkpoint_every in config.data!" assert "checkpoint_every" in ckpt, "If enable save checkpoint, must give checkpoint_every in config.data!"
# assert ckpt.checkpoint_every > 0 assert ckpt.checkpoint_every > 0
if "checkpoint_every" not in ckpt or ckpt.checkpoint_every <= 0:
ckpt.checkpoint_every = "auto"
assert "save_ckpt_folder" in ckpt, "If enable save checkpoint, must give save_ckpt_folder in config.data!" assert "save_ckpt_folder" in ckpt, "If enable save checkpoint, must give save_ckpt_folder in config.data!"
if "async_upload" not in ckpt: if "async_upload" not in ckpt:
@ -194,13 +192,14 @@ def args_sanity_check():
if not ckpt.async_upload: if not ckpt.async_upload:
ckpt._add_item("async_upload_tmp_folder", None) ckpt._add_item("async_upload_tmp_folder", None)
if ckpt.checkpoint_every != "auto": if "oss_snapshot_freq" not in ckpt:
if "oss_snapshot_freq" not in ckpt: ckpt._add_item("oss_snapshot_freq", -1)
ckpt._add_item("oss_snapshot_freq", "auto")
else:
ckpt.oss_snapshot_freq = "auto"
else: else:
ckpt.oss_snapshot_freq = float("inf") ckpt.oss_snapshot_freq = -1
if "auto_save_time" not in ckpt:
ckpt._add_item("auto_save_time", 1200)
else: else:
ckpt._add_item("checkpoint_every", float("inf")) ckpt._add_item("checkpoint_every", float("inf"))
ckpt._add_item("oss_snapshot_freq", float("inf")) ckpt._add_item("oss_snapshot_freq", float("inf"))

View File

@ -428,7 +428,7 @@ def record_current_batch_training_metrics(
acc_perplex = metric.get_metric() acc_perplex = metric.get_metric()
# compute auto save_frequency # compute auto save_frequency
if success_update and (gpc.config.ckpt.checkpoint_every == "auto" or gpc.config.ckpt.oss_snapshot_freq == "auto"): if success_update and gpc.config.ckpt.oss_snapshot_freq <= 0:
ckpt_statistic = train_state.ckpt_statistic ckpt_statistic = train_state.ckpt_statistic
ckpt_statistic["total_step"] += 1 ckpt_statistic["total_step"] += 1
@ -443,7 +443,7 @@ def record_current_batch_training_metrics(
# compute save_frequency # compute save_frequency
if gpc.get_global_rank() == 0: if gpc.get_global_rank() == 0:
avg_step_time = ckpt_statistic["sum_time"] / ckpt_statistic["sum_step"] avg_step_time = ckpt_statistic["sum_time"] / ckpt_statistic["sum_step"]
check_time = int(os.getenv("LLM_CKPT_SAVE_TIME", "1200")) check_time = gpc.config.ckpt.auto_save_time
save_frequency = torch.tensor( save_frequency = torch.tensor(
[int(10 * -(-check_time // (avg_step_time * 10)))], device=torch.device("cuda") [int(10 * -(-check_time // (avg_step_time * 10)))], device=torch.device("cuda")
) )
@ -455,15 +455,11 @@ def record_current_batch_training_metrics(
save_frequency = int(save_frequency[0]) save_frequency = int(save_frequency[0])
# assign save_frequency # assign save_frequency
# when the "checkpoint_every" is "auto", no snapshot will be performed # when the "save_frequency" is larger than the "checkpoint_every" passed in, no snapshot will be performed
# when the "save_frequency" is less than the "checkpoint_every" passed in, no snapshot will be performed if save_frequency < gpc.config.ckpt.checkpoint_every:
if gpc.config.ckpt.checkpoint_every == "auto": gpc.config.ckpt.oss_snapshot_freq = save_frequency
gpc.config.ckpt.checkpoint_every = save_frequency
else: else:
if save_frequency < gpc.config.ckpt.checkpoint_every: gpc.config.ckpt.oss_snapshot_freq = float("inf")
gpc.config.ckpt.oss_snapshot_freq = save_frequency
else:
gpc.config.ckpt.oss_snapshot_freq = float("inf")
if success_update and gpc.is_rank_for_log(): if success_update and gpc.is_rank_for_log():
lr = optimizer.param_groups[0]["lr"] lr = optimizer.param_groups[0]["lr"]

View File

@ -971,11 +971,7 @@ now step_count is {train_state.step_count}",
def is_now_to_save_ckpt(self, train_state) -> (bool, CheckpointSaveType, bool): def is_now_to_save_ckpt(self, train_state) -> (bool, CheckpointSaveType, bool):
save_ckpts, save_type, now_break = False, CheckpointSaveType.NORMAL_CHECKPOINT, False save_ckpts, save_type, now_break = False, CheckpointSaveType.NORMAL_CHECKPOINT, False
if ( if self.oss_snapshot_freq > 1 and train_state.step_count % self.oss_snapshot_freq == 0:
self.oss_snapshot_freq != "auto"
and self.oss_snapshot_freq > 1
and train_state.step_count % self.oss_snapshot_freq == 0
):
save_ckpts, save_type = True, CheckpointSaveType.SNAPSHOT_CHECKPOINT save_ckpts, save_type = True, CheckpointSaveType.SNAPSHOT_CHECKPOINT
if train_state.step_count % self.checkpoint_every == 0 or train_state.step_count == train_state.total_steps: if train_state.step_count % self.checkpoint_every == 0 or train_state.step_count == train_state.total_steps:
save_ckpts, save_type = True, CheckpointSaveType.NORMAL_CHECKPOINT save_ckpts, save_type = True, CheckpointSaveType.NORMAL_CHECKPOINT
@ -990,13 +986,8 @@ now step_count is {train_state.step_count}",
if not self.enable_save_ckpt: if not self.enable_save_ckpt:
return False return False
if self.checkpoint_every == "auto": if self.oss_snapshot_freq <= 0:
if gpc.config.ckpt.checkpoint_every == "auto": if gpc.config.ckpt.oss_snapshot_freq > 0:
return False
else:
self.checkpoint_every = gpc.config.ckpt.checkpoint_every
elif self.oss_snapshot_freq == "auto":
if gpc.config.ckpt.oss_snapshot_freq != "auto":
self.oss_snapshot_freq = gpc.config.ckpt.oss_snapshot_freq self.oss_snapshot_freq = gpc.config.ckpt.oss_snapshot_freq
save_ckpts, save_type, now_break = self.is_now_to_save_ckpt(train_state) save_ckpts, save_type, now_break = self.is_now_to_save_ckpt(train_state)