From a45a91bb843cf0b10b8b014a6ef35e695871f91b Mon Sep 17 00:00:00 2001 From: Guoteng <32697156+SolenoidWGT@users.noreply.github.com> Date: Fri, 11 Aug 2023 17:08:01 +0800 Subject: [PATCH] feat(ckpt): add auto ckpt load and singal quit (#189) Co-authored-by: wangguoteng.p --- internlm/initialize/launch.py | 6 + internlm/utils/model_checkpoint.py | 321 ++++++++++++++++++++++++----- internlm/utils/storage_manager.py | 27 ++- train.py | 74 ++----- 4 files changed, 309 insertions(+), 119 deletions(-) diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 33b5d15..2447190 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -138,6 +138,12 @@ def args_sanity_check(): if "async_upload_tmp_folder" not in gpc.config.ckpt: gpc.config.ckpt._add_item("async_upload_tmp_folder", "/dev/shm/internlm_tmp_ckpt/") + if "stop_file_path" not in gpc.config.ckpt: + gpc.config._add_item("stop_file_path", None) + + if "load_given_ckpt" not in gpc.config.ckpt: + gpc.config._add_item("load_given_ckpt", False) + if "snapshot_ckpt_folder" not in gpc.config.ckpt: gpc.config.ckpt._add_item("snapshot_ckpt_folder", os.path.join(gpc.config.ckpt.save_ckpt_folder), "snapshot") diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py index 3fe29cc..643fee9 100644 --- a/internlm/utils/model_checkpoint.py +++ b/internlm/utils/model_checkpoint.py @@ -2,7 +2,9 @@ # -*- encoding: utf-8 -*- import copy +import fcntl import os +import socket import time from enum import Enum from typing import Dict @@ -12,6 +14,7 @@ import torch from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.trainer import TrainState +from internlm.monitor import send_alert_message from internlm.solver.optimizer import HybridZeroOptimizer from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger @@ -25,8 +28,6 @@ from internlm.utils.storage_manager import ( logger = get_logger(__file__) -quit_signal_handler = None - class CheckpointType(Enum): NORMAL_CHECKPOINT = 1 @@ -167,44 +168,6 @@ def save_optimizer_checkpoint(optim, state_path): llm_save(os.path.join(state_path, fp), states) -def save_checkpoint(folder, model, optimizer, scheduler, train_state: TrainState, model_config: Dict = None): - """ - Save checkpoint to the given folder path. - """ - - start = time.time() - torch.distributed.barrier() - folder = os.path.join(folder, str(train_state.step_count)) - logger.info( - f"Saving checkpoint to `{folder}` at batch count:{train_state.step_count} from rank:{gpc.get_global_rank()}..." 
- ) - - timer("save-model").start() - save_model_checkpoint(folder=folder, model=model) - timer("save-model").stop() - - timer("save-optimizer").start() - save_optimizer_checkpoint(optim=optimizer, state_path=folder) - timer("save-optimizer").stop() - - if gpc.is_rank_for_log(): - scheduler_states = scheduler.state_dict() - llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states) - - sampler_state = train_state.batch_sampler.state_dict() - llm_save(os.path.join(folder, "sampler.pt"), saved_obj=sampler_state) - llm_save(os.path.join(folder, "context.pt"), saved_obj=train_state.state_dict()) - - if model_config is not None: - llm_save(os.path.join(folder, "model_config.pt"), saved_obj=model_config) - - torch.distributed.barrier() - - if gpc.is_rank_for_log(): - timer.log(["save-model", "save-optimizer"], logger=logger) - logger.info(f"Step: {train_state.step_count}, rank 0 save ckpt use {time.time() - start:.3f} s") - - def load_optimizer_checkpoint(folder, optim): """Load the optimizer state from the local file system or remote object storage Service (OSS). @@ -304,19 +267,12 @@ def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train logger.info(f"reload load_scheduler:{lr_scheduler}") -class CheckpointSaveManager: +class CheckpointManager: """StorageManagerContext""" - def __init__( - self, - ckpt_config, - model, - optimizer, - lr_scheduler, - model_config, - ) -> None: + def __init__(self, ckpt_config, model, model_config, feishu_address=None) -> None: """ - CheckpointSaveManager is used to decide when to store ckpt. If it is an asynchronous + CheckpointManager is used to decide when to store ckpt. If it is an asynchronous upload mode, you must call wait_async_upload_finish at the end of the program to wait for the asynchronous ckpt upload to complete. @@ -332,26 +288,85 @@ class CheckpointSaveManager: self.save_ckpt_folder = ckpt_config.save_ckpt_folder self.snapshot_ckpt_folder = ckpt_config.snapshot_ckpt_folder self.oss_snapshot_freq: int = ckpt_config.oss_snapshot_freq + self.stop_file_path = ckpt_config.stop_file_path + self.load_model_only_folder = ckpt_config.load_model_only_folder + self.feishu_address = feishu_address self.storage_manager = get_storage_manager() self.snapshot_counter = 0 + self.load_optimizer = gpc.config.ckpt.load_optimizer self.model = model - self.optimizer = optimizer - self.lr_scheduler = lr_scheduler self.model_config = model_config + if self.stop_file_path and gpc.get_global_rank() == 0: + dir_path = os.path.dirname(self.stop_file_path) + if dir_path != "" and not os.path.exists(dir_path): + os.makedirs(dir_path) + with open(self.stop_file_path, "w", encoding="utf-8") as f: + f.write("0") + + if not ckpt_config.load_given_ckpt: + latest_ckpt_path = self.query_lastest_ckpt() + self.load_ckpt_folder = latest_ckpt_path if latest_ckpt_path is not None else ckpt_config.load_ckpt_folder + else: + self.load_ckpt_folder = ckpt_config.load_ckpt_folder + + def quit_signal_handler(self, train_state) -> bool: + """ + Exit signal detection function, if we write the exit step in the 'QUIT_FILE_PATH' file, + all ranks will save ckpt and exit. + Negative integer step means save ckpt. + Positive integer step means save ckpt and quit. + + Args: + train_state (TrainState): + Returns: + bool: whether to quit. 
+ """ + if self.stop_file_path is None: + logger.warning("no set stop_file_path") + return + + now_break, now_save_ckpt, save_type = False, False, CheckpointType.NORMAL_CHECKPOINT + with open(self.stop_file_path, "a+", encoding="utf-8") as f: + fcntl.flock(f, fcntl.LOCK_EX) + f.seek(0) + msg = f.read() + fcntl.flock(f, fcntl.LOCK_UN) + action_step = int(msg) + + if action_step < 0 and abs(action_step) == train_state.step_count: + now_save_ckpt = True + + if action_step > 0 and action_step == train_state.step_count: + now_break, now_save_ckpt = True, True + + if action_step != 0 and gpc.is_rank_for_log(): + msg = "Stop" if action_step > 0 else "Save" + action_step = abs(action_step) + if train_state.step_count <= action_step: + if self.feishu_address: + send_alert_message( + address=self.feishu_address, + message=f"training will {msg} at step_count {action_step}!\ +now step_count is {train_state.step_count}", + ) + + return now_break, now_save_ckpt, save_type + def try_save_checkpoint(self, train_state): if not self.enable_save_ckpt: - return + return False save_ckpts, save_type = False, CheckpointType.NORMAL_CHECKPOINT if self.oss_snapshot_freq > 1 and train_state.step_count % self.oss_snapshot_freq == 0: save_ckpts, save_type = True, CheckpointType.SNAPSHOT_CHECKPOINT if train_state.step_count % self.checkpoint_every == 0: save_ckpts, save_type = True, CheckpointType.NORMAL_CHECKPOINT + now_break, singal_save_ckpts, singal_save_type = self.quit_signal_handler(train_state) if save_ckpts is False: - if quit_signal_handler is not None: - save_ckpts, save_type = quit_signal_handler(train_state) + save_ckpts = singal_save_ckpts + save_type = singal_save_type if save_ckpts: # Wait for the previous round of asynchronous upload storage to complete. @@ -361,9 +376,9 @@ class CheckpointSaveManager: self.snapshot_counter = (self.snapshot_counter + 1) % 2 save_ckpt_folder = os.path.join(self.snapshot_ckpt_folder, f"{self.snapshot_counter}") else: - save_ckpt_folder = self.save_ckpt_folder + save_ckpt_folder = os.path.join(self.save_ckpt_folder, str(train_state.step_count)) - save_checkpoint( + self.save_checkpoint( folder=save_ckpt_folder, model=self.model, optimizer=self.optimizer, @@ -372,7 +387,199 @@ class CheckpointSaveManager: model_config=self.model_config, ) + return now_break + def wait_async_upload_finish(self): """wait for all checkpoint uploads to be completed""" self.storage_manager.wait() torch.distributed.barrier() + + def query_latest_snapshot_step_boto3(self): + """query_latest_snapshot_step_boto3 + Returns: + Tuple(str, int): path of latest ckpt and ckpt step, if not found, None will return. 
+ """ + ckpt_list = self.storage_manager.get_fns(self.save_ckpt_folder) + if len(ckpt_list) == 0: + return None, None + + max_normal_step = 0 + ckpt_list = list(map(lambda a: int(a.strip("/")) if a.strip("/").isdigit() else 0, ckpt_list)) + ckpt_list.sort(reverse=True) + for ckpt in ckpt_list: + fns_list = self.storage_manager.get_fns(os.path.join(self.save_ckpt_folder, str(ckpt))) + for fn in fns_list: + if fn.endswith(".step"): + max_normal_step = ckpt + break + if max_normal_step != 0: + break + + max_normal_step = ckpt_list[0] + load_normal_ckpt_path = os.path.join(self.save_ckpt_folder, str(max_normal_step)) + + snapshot_path_0 = os.path.join(self.save_ckpt_folder, "snapshot", "0") + snapshot_path_1 = os.path.join(self.save_ckpt_folder, "snapshot", "1") + ckpt_list_1 = self.storage_manager.get_fns(snapshot_path_0) + ckpt_list_2 = self.storage_manager.get_fns(snapshot_path_1) + max_step_0, max_step_1 = 0, 0 + for ckpt in ckpt_list_1: + ckpt = ckpt.strip("/") + if ckpt.endswith(".step"): + max_step_0 = max(max_step_0, int(ckpt.split(".")[0])) + for ckpt in ckpt_list_2: + ckpt = ckpt.strip("/") + if ckpt.endswith(".step"): + max_step_1 = max(max_step_1, int(ckpt.split(".")[0])) + + snap_load_path = snapshot_path_0 if max_step_0 > max_step_1 else snapshot_path_1 + snap_step = max(max_step_0, max_step_1) + load_path = snap_load_path if snap_step > max_normal_step else load_normal_ckpt_path + load_step = max(snap_step, max_normal_step) + return load_path, load_step + + def query_latest_snapshot_step_local(self): + max_step, max_step_path = 0, None + for root, _, files in os.walk(self.save_ckpt_folder, followlinks=True): + for fn in files: + fn = fn.strip("/") + if fn.endswith(".step"): + # We assume that both normal ckpt and snapshot ckpt will store the '.step' file + # as an integrity flag. + step = int(fn.rsplit(".", maxsplit=1)[0]) + if max_step < step: + max_step = step + max_step_path = root + + return max_step_path, max_step + + def query_lastest_ckpt(self): + latest_checkpoint = None + # Training was automatically restarted by the process, forcing the latest snapshot to be read. 
+ if self.save_ckpt_folder: + if self.save_ckpt_folder.startswith("boto3"): + latest_checkpoint, step = self.query_latest_snapshot_step_boto3() + elif self.save_ckpt_folder.startswith("local"): + latest_checkpoint, step = self.query_latest_snapshot_step_local() + else: + latest_checkpoint, step = None, 0 + + if latest_checkpoint is not None: + if gpc.is_rank_for_log(): + logger.info(f"Found latest ckpt : {latest_checkpoint}, step: {step}") + send_alert_message( + address=self.feishu_address, + message=f"Auto restart resume from ckpt-path: '{latest_checkpoint}', step : {step}", + ) + else: + if gpc.is_rank_for_log(): + send_alert_message( + address=self.feishu_address, + message=f"Can't find snapshot checkpoint, use default load-ckpt path: {latest_checkpoint}", + ) + + return latest_checkpoint + + def try_load_model(self, current_time=""): + model_load_path = None + + if self.load_ckpt_folder and self.load_model_only_folder: + raise ValueError( + "Error, try to use both load_ckpt_folder and load_model_only_folder paths, \ +if you only need to load model weights (for example starting an SFT task for the first time), \ +set load_model_only_folder path, if you need to resume training from ckpt, \ +set load_ckpt_folder or use default value \ +(if is the default value, internlm will try to load the latest ckpt from save_ckpt_folder)" + ) + + if self.load_ckpt_folder: + logger.info( + f"===========Resume training from `{self.load_ckpt_folder}` {current_time} on host:" + f"{socket.gethostname()}===========" + ) + model_load_path = self.load_ckpt_folder + elif self.load_model_only_folder: + logger.info( + f"===========SFT training from `{self.load_model_only_folder}` {current_time} on host:" + f"{socket.gethostname()}===========" + ) + model_load_path = self.load_model_only_folder + else: + logger.info( + f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()}," + f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)}," + f"dp={gpc.get_local_rank(ParallelMode.DATA)}===========" + ) + + # Loading model weights must be done before zero is initialized. + if model_load_path is not None: + load_model_checkpoint(folder=model_load_path, model=self.model) + + def try_resume_traing(self, lr_scheduler, optimizer, lr, train_state, train_dl): + """Attempt to restore the training state of the last ckpt. + + Args: + lr_scheduler (_LRScheduler): lr_scheduler object. + optimizer (Optimizer): optimizer object. + lr (float): learning rate. + train_state (dict): traing states. + train_dl (DataLoader): traning dataloader object + """ + if self.load_ckpt_folder is not None: + # load lr scheduler states. + load_scheduler(self.load_ckpt_folder, lr_scheduler, optimizer, lr, train_state) + # load training states. + load_context(self.load_ckpt_folder, train_dl, train_state) + # load dataloader sampler states. + load_sampler(self.load_ckpt_folder, train_dl.batch_sampler) + # load optimzier states. + if self.load_optimizer: + load_optimizer_checkpoint(self.load_ckpt_folder, optimizer) + + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + + def save_checkpoint(self, folder, model, optimizer, scheduler, train_state: TrainState, model_config: Dict = None): + """ + Save checkpoint to the given folder path. 
+ """ + + start = time.time() + self.set_save_folder(folder, train_state.step_count) + torch.distributed.barrier() + if gpc.is_rank_for_log(): + logger.info(f"Saving checkpoint to `{folder}` at batch count:{train_state.step_count}...") + + timer("save-model").start() + save_model_checkpoint(folder=folder, model=model) + timer("save-model").stop() + + timer("save-optimizer").start() + save_optimizer_checkpoint(optim=optimizer, state_path=folder) + timer("save-optimizer").stop() + + if gpc.is_rank_for_log(): + scheduler_states = scheduler.state_dict() + llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states) + + sampler_state = train_state.batch_sampler.state_dict() + llm_save(os.path.join(folder, "sampler.pt"), saved_obj=sampler_state) + llm_save(os.path.join(folder, "context.pt"), saved_obj=train_state.state_dict()) + + if model_config is not None: + llm_save(os.path.join(folder, "model_config.pt"), saved_obj=model_config) + + torch.distributed.barrier() + + if gpc.is_rank_for_log(): + timer.log(["save-model", "save-optimizer"], logger=logger) + logger.info(f"Step: {train_state.step_count}, rank 0 save ckpt use {time.time() - start:.3f} s") + if self.storage_manager.async_mode is False: + llm_save( + os.path.join(folder, f"{train_state.step_count}.step"), + saved_obj=dict({"step": train_state.step_count}), + ) + + def set_save_folder(self, folder, step): + self.storage_manager.latest_save_folder = folder + self.storage_manager.latest_save_step = step diff --git a/internlm/utils/storage_manager.py b/internlm/utils/storage_manager.py index 481bd28..6984ddc 100644 --- a/internlm/utils/storage_manager.py +++ b/internlm/utils/storage_manager.py @@ -234,13 +234,13 @@ class Boto3Client(StorageClient): """ paginator = handler.client.get_paginator("list_objects_v2") pages = paginator.paginate(Bucket=bucket_name, Prefix=fp) - folder_name_list = [] for page in pages: - for obj in page["Contents"]: - fp: str = obj["Key"] - folder_name_list.append(fp.rsplit("/", maxsplit=1)[1]) - return folder_name_list + if "Contents" in page: + for obj in page["Contents"]: + pth: str = obj["Key"] + folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0]) + return list(set(folder_name_list)) @staticmethod def async_upload_fileobj(handler, bucket_name: str, fp: str, local_nvme_path: str): @@ -391,6 +391,11 @@ class StorageManager(metaclass=SingletonMeta): self.tmp_local_folder = tmp_local_folde self.async_mode = async_mode self.has_warning = False + self._async_loop = None + self._thread_pool = None + self.latest_save_folder = None + self.latest_save_step = 0 + self.async_task_peeding = False if enable_save and self.async_mode: self._async_loop = asyncio.new_event_loop() @@ -485,6 +490,7 @@ class StorageManager(metaclass=SingletonMeta): torch.save(saved_obj, f, pickle_protocol=pickle.HIGHEST_PROTOCOL) self.async_executor(meta.async_upload_fn, *unpack_meta(meta)) os.chmod(tmp_step_file, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO) + self.async_task_peeding = True else: meta.client.sync_upload_fileobj(*unpack_meta(meta), *args, saved_obj=saved_obj, **kwargs) self.upload_count += 1 @@ -560,6 +566,9 @@ class StorageManager(metaclass=SingletonMeta): if not self.async_mode: return + if not self.async_task_peeding: + return + if self._async_loop: self._async_loop.run_until_complete(self._sync_tasks()) @@ -578,10 +587,16 @@ class StorageManager(metaclass=SingletonMeta): self._del_tmp_folder() self._exception_list.clear() self._to_be_del_files.clear() + 
self.async_task_peeding = False if gpc.is_rank_for_log(): - logger.info("all async uploads succeeded!") self.upload_count += 1 + if self.async_mode: + self.save( + os.path.join(self.latest_save_folder, f"{self.latest_save_step}.step"), + saved_obj=dict({"step": self.latest_save_step}), + async_upload=False, + ) storage_manager: StorageManager = None diff --git a/train.py b/train.py index 7e87d05..f1e6fd0 100644 --- a/train.py +++ b/train.py @@ -45,14 +45,7 @@ from internlm.utils.common import ( from internlm.utils.evaluation import evaluate_on_val_dls, switch_sequence_parallel_mode from internlm.utils.logger import get_logger, initialize_uniscale_logger from internlm.utils.megatron_timers import megatron_timer as timer -from internlm.utils.model_checkpoint import ( - CheckpointSaveManager, - load_context, - load_model_checkpoint, - load_optimizer_checkpoint, - load_sampler, - load_scheduler, -) +from internlm.utils.model_checkpoint import CheckpointManager from internlm.utils.parallel import ( get_parallel_log_file_name, is_no_pp_or_last_stage, @@ -428,13 +421,9 @@ def main(args): skip_batches = gpc.config.data.skip_batches total_steps = gpc.config.data.total_steps valid_every = gpc.config.data.valid_every - load_optimizer = gpc.config.ckpt.load_optimizer label_smoothing = gpc.config.loss.label_smoothing lr = gpc.config.adam.lr - load_model_only_folder = gpc.config.ckpt.get("load_model_only_folder", None) - load_resume_ckpt_folder = gpc.config.ckpt.get("load_ckpt_folder", None) - get_tflops_func = partial( get_megatron_flops, checkpoint=gpc.config.model.checkpoint, @@ -470,32 +459,19 @@ def main(args): enable_tb=gpc.config.enable_tb, ) - model_load_path = None - if load_resume_ckpt_folder is not None: - logger.info( - f"===========Resume training from `{load_resume_ckpt_folder}` {current_time} on host:" - f"{socket.gethostname()}===========" - ) - model_load_path = load_resume_ckpt_folder - elif load_model_only_folder is not None: - logger.info( - f"===========SFT training from `{load_model_only_folder}` {current_time} on host:" - f"{socket.gethostname()}===========" - ) - model_load_path = load_model_only_folder - else: - logger.info( - f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()}," - f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)}," - f"dp={gpc.get_local_rank(ParallelMode.DATA)}===========" - ) - # initialize and resume train state train_state = TrainState(gpc.config) # initialize model model = initialize_model() + ckpt_manager = CheckpointManager( + ckpt_config=gpc.config.ckpt, + model=model, + model_config=gpc.config.model, + feishu_address=gpc.config.alert_address, + ) + # initialize loss function criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing) @@ -505,30 +481,12 @@ def main(args): train_state.init_batch_sampler(train_dl) # Loading model weights must be done before zero is initialized. - if model_load_path is not None: - load_model_checkpoint(folder=model_load_path, model=model) + ckpt_manager.try_load_model(current_time) optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model) # Loading other persistent training states. - if load_resume_ckpt_folder is not None: - # load lr scheduler states. - load_scheduler(load_resume_ckpt_folder, lr_scheduler, optimizer, lr, train_state) - # load training states. - load_context(load_resume_ckpt_folder, train_dl, train_state) - # load dataloader sampler states. 
- load_sampler(load_resume_ckpt_folder, train_dl.batch_sampler) - # load optimzier states. - if load_optimizer: - load_optimizer_checkpoint(load_resume_ckpt_folder, optimizer) - - ckpt_save_manager = CheckpointSaveManager( - ckpt_config=gpc.config.ckpt, - model=model, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - model_config=gpc.config.model, - ) + ckpt_manager.try_resume_traing(lr_scheduler, optimizer, lr, train_state, train_dl) # initialize metric for calculating accuracy and perplexity metric = AccPerplex( @@ -649,9 +607,11 @@ def main(args): # checkpoint the training states in specific steps, which is determined by the args "checkpoint_every" # # save batch sampler that tracks the true consumed samples - ckpt_save_manager.try_save_checkpoint(train_state) + now_break = ckpt_manager.try_save_checkpoint(train_state) + if now_break: + break - ckpt_save_manager.wait_async_upload_finish() + ckpt_manager.wait_async_upload_finish() if __name__ == "__main__": @@ -667,8 +627,10 @@ if __name__ == "__main__": try: main(args) except Exception: + format_trace = "" + for line in traceback.format_exc().split("\n")[-10:]: + format_trace += "\n" + line logger.error( - f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}", - exc_info=traceback.format_exc(), + f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}, trace:{format_trace}", ) mm.monitor_exception(alert_address=gpc.config.alert_address, excp_info=traceback.format_exc())
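
A note on configuration, for readers wiring this change into a training config: the sketch below collects the gpc.config.ckpt fields that this patch reads into one place. Only the key names are taken from the patch; every value shown is a placeholder assumption, and any key left out falls back to the default that args_sanity_check fills in.

# Illustrative ckpt config sketch; values are placeholders, key names follow this patch.
ckpt = dict(
    enable_save_ckpt=True,
    save_ckpt_folder="local:/path/to/ckpts",   # or a "boto3:..." path for object storage
    checkpoint_every=500,                      # normal checkpoint interval, in steps
    oss_snapshot_freq=50,                      # snapshot interval, in steps; snapshots alternate between snapshot/0 and snapshot/1
    load_optimizer=True,                       # restore optimizer states when resuming
    load_given_ckpt=False,                     # False: auto-resume from the newest complete ckpt under save_ckpt_folder
    load_ckpt_folder=None,                     # explicit resume path, used when no auto-resume ckpt is found
    load_model_only_folder=None,               # weights-only load (e.g. a first SFT run); mutually exclusive with load_ckpt_folder
    stop_file_path="/path/to/llm_stop_file",   # polled by quit_signal_handler, see the note below
    async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/",
)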
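
The stop-file mechanism added in quit_signal_handler works on a single integer stored in stop_file_path (a file that must be visible to every rank): 0, which rank 0 writes at startup, is a no-op; a negative value -N asks all ranks to save a checkpoint when step_count reaches N; a positive value N asks them to save and then exit at step N. The comparison is an exact match against train_state.step_count, so the requested step has to lie ahead of the current step. Below is a minimal writer-side sketch under those assumptions; the helper and the path are illustrative and not part of the patch.

import fcntl


def request_ckpt(stop_file_path: str, step: int, quit_after_save: bool = False) -> None:
    """Ask a running job to checkpoint at `step`: positive means save-and-quit, negative means save-only."""
    action_step = step if quit_after_save else -step
    with open(stop_file_path, "w", encoding="utf-8") as f:
        fcntl.flock(f, fcntl.LOCK_EX)  # the same advisory lock quit_signal_handler takes while reading
        f.write(str(action_step))
        f.flush()
        fcntl.flock(f, fcntl.LOCK_UN)


# Example: save a checkpoint and stop all ranks once step_count reaches 2000.
# request_ckpt("/path/to/llm_stop_file", 2000, quit_after_save=True)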
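
On the "<step>.step" marker: in synchronous mode save_checkpoint writes it only after the model, optimizer and scheduler files have been saved, and in async mode StorageManager writes it only once all pending uploads have been synchronized, while query_latest_snapshot_step_local treats only folders containing such a marker as auto-resume candidates. The marker therefore acts as a completion flag that keeps a half-written checkpoint from being resumed. A tiny local-only check in the same spirit, illustrative and not part of the patch:

import os


def has_step_marker(ckpt_dir: str, step: int) -> bool:
    """Return True if the '<step>.step' completion marker exists in a local checkpoint folder."""
    return os.path.isfile(os.path.join(ckpt_dir, f"{step}.step"))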