From 29779c75f05653b57a04870c674bc08e1057dc65 Mon Sep 17 00:00:00 2001 From: Guoteng <32697156+SolenoidWGT@users.noreply.github.com> Date: Wed, 23 Aug 2023 14:17:45 +0800 Subject: [PATCH] feat(ckpt): add auto ckpt load and singal quit (#216) Co-authored-by: wangguoteng.p --- internlm/initialize/launch.py | 107 +++++---- internlm/utils/model_checkpoint.py | 352 ++++++++++++++++++++++++----- internlm/utils/storage_manager.py | 71 +++--- internlm/utils/writer.py | 7 +- train.py | 68 ++---- 5 files changed, 424 insertions(+), 181 deletions(-) diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 986d1f7..5dff0e6 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -108,67 +108,96 @@ def args_sanity_check(): logger.info(f"valid_every: {data.valid_every}") # processing the checkpoint config - if "enable_save_ckpt" not in gpc.config.ckpt: - gpc.config.ckpt._add_item("enable_save_ckpt", False) + ckpt = gpc.config.ckpt + if "enable_save_ckpt" not in ckpt: + ckpt._add_item("enable_save_ckpt", False) - if "checkpoint_every" not in gpc.config.ckpt or gpc.config.ckpt.checkpoint_every <= 0: - gpc.config.ckpt._add_item("checkpoint_every", float("inf")) + # Saving checkpoint args. + if ckpt.enable_save_ckpt: + assert "checkpoint_every" in ckpt, "If enable save checkpoint, must give checkpoint_every in config.data!" + assert ckpt.checkpoint_every > 0 + assert "save_ckpt_folder" in ckpt, "If enable save checkpoint, must give save_ckpt_folder in config.data!" - if "load_optimizer" not in gpc.config.ckpt: - gpc.config.ckpt._add_item("load_optimizer", True) + if "async_upload" not in ckpt: + ckpt._add_item("async_upload", False) # async defalut is False. + else: + if ckpt.async_upload: + assert "save_ckpt_folder" in ckpt + if "boto3:" not in ckpt.save_ckpt_folder: + if gpc.is_rank_for_log(): + logger.warning( + "Storing ckpt on file system does not support asynchronous storage, will use sync save!" + ) + ckpt.async_upload = False + else: + if "async_upload_tmp_folder" not in ckpt: + ckpt._add_item("async_upload_tmp_folder", "/dev/shm/internlm_tmp_ckpt/") - if "save_ckpt_folder" not in gpc.config.ckpt: - gpc.config.ckpt._add_item("save_ckpt_folder", None) + if not ckpt.async_upload: + ckpt._add_item("async_upload_tmp_folder", None) - if "load_ckpt_folder" not in gpc.config.ckpt: - gpc.config.ckpt._add_item("load_ckpt_folder", None) + if "snapshot_ckpt_folder" not in ckpt: + ckpt._add_item("snapshot_ckpt_folder", os.path.join(ckpt.save_ckpt_folder, "snapshot")) - if "load_model_only_folder" not in gpc.config.ckpt: - gpc.config.ckpt._add_item("load_model_only_folder", None) + if "oss_snapshot_freq" not in ckpt: + ckpt._add_item("oss_snapshot_freq", float("inf")) # if oss_snapshot_freq not given, we disable. + else: + ckpt._add_item("checkpoint_every", float("inf")) + ckpt._add_item("oss_snapshot_freq", float("inf")) + ckpt._add_item("save_ckpt_folder", None) + ckpt._add_item("async_upload", False) + ckpt._add_item("async_upload_tmp_folder", None) + ckpt._add_item("snapshot_ckpt_folder", None) + ckpt._add_item("snapshot_ckpt_folder", None) - if "async_upload" not in gpc.config.ckpt: - gpc.config.ckpt._add_item("async_upload", False) + # Loading checkpoint args. 
+ if "load_model_only_folder" not in ckpt: + ckpt._add_item("load_model_only_folder", None) - if "async_upload_tmp_folder" not in gpc.config.ckpt: - gpc.config.ckpt._add_item("async_upload_tmp_folder", "/dev/shm/internlm_tmp_ckpt/") + if "load_ckpt_folder" not in ckpt: + ckpt._add_item("load_ckpt_folder", None) - if gpc.config.ckpt.async_upload: - assert "save_ckpt_folder" in gpc.config.ckpt - if "boto3:" not in gpc.config.ckpt.save_ckpt_folder: - if gpc.is_rank_for_log(): - logger.warning("Storing ckpt on file system does not support asynchronous storage, will use sync save!") - gpc.config.ckpt.async_upload = False + if "load_optimizer" not in ckpt: + ckpt._add_item("load_optimizer", True) - if "snapshot_ckpt_folder" not in gpc.config.ckpt: - gpc.config.ckpt._add_item("snapshot_ckpt_folder", os.path.join(gpc.config.ckpt.save_ckpt_folder, "snapshot")) + if "stop_file_path" not in ckpt: + ckpt._add_item("stop_file_path", None) - if "oss_snapshot_freq" not in gpc.config.ckpt and gpc.config.ckpt.checkpoint_every != float("inf"): - gpc.config.ckpt._add_item("oss_snapshot_freq", gpc.config.ckpt.checkpoint_every / 2) - assert gpc.config.ckpt.oss_snapshot_freq > 0 + if "load_given_ckpt" not in ckpt: + # If 'load_given_ckpt' is not given, we set it to False, so internlm can have opportunity + # to auto-load latest checkpoint. + ckpt._add_item("load_given_ckpt", False) - assert not ( - gpc.config.ckpt.load_ckpt_folder is not None and gpc.config.ckpt.load_model_only_folder is not None - ), "'load_ckpt_folder' and 'load_model_only_folder' cannot be set at the same time." + if ckpt.load_given_ckpt: + # Priority: load_given_ckpt(True) > latest_checkpoint > load_model_only_folder + if ckpt.load_ckpt_folder and ckpt.load_model_only_folder: + logger.warning( + "Detect 'load_ckpt_folder' and 'load_model_only_folder' set at the same time, \ +and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'" + ) + ckpt.load_model_only_folder = None if gpc.is_rank_for_log(): logger.info("+" * 15 + " Ckpt Info " + "+" * 15) # pylint: disable=W1201 - logger.info(f"is enable save ckpt: {gpc.config.ckpt.enable_save_ckpt}") - logger.info(f"save_ckpt_folder: {gpc.config.ckpt.save_ckpt_folder}") - logger.info(f"checkpoint_every: {gpc.config.ckpt.checkpoint_every}") - logger.info(f"async_upload: {gpc.config.ckpt.async_upload}") - if gpc.config.ckpt.async_upload: - logger.info(f"async_upload_tmp_folder: {gpc.config.ckpt.async_upload_tmp_folder}") + logger.info(f"is enable save ckpt: {ckpt.enable_save_ckpt}") + logger.info(f"save_ckpt_folder: {ckpt.save_ckpt_folder}") + logger.info(f"checkpoint_every: {ckpt.checkpoint_every}") + logger.info(f"load_given_ckpt: {ckpt.load_given_ckpt}") # initialization storage manager - init_storage_manager(gpc.config.ckpt) + init_storage_manager(ckpt) # tensorboard writer config if "enable_tb" not in gpc.config: gpc.config._add_item("enable_tb", True) if "tensorboard_folder" not in gpc.config: - gpc.config._add_item("tensorboard_folder", None) + gpc.config._add_item( + "tensorboard_folder", os.environ["tensorboard_folder"] if "tensorboard_folder" in os.environ else None + ) if "resume_tb_folder" not in gpc.config: - gpc.config._add_item("resume_tb_folder", None) + gpc.config._add_item( + "resume_tb_folder", os.environ["resume_tb_folder"] if "resume_tb_folder" in os.environ else None + ) # cudnn torch.backends.cudnn.benchmark = gpc.config.get("cudnn_benchmark", False) diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py index 3fe29cc..3dca7c5 
100644 --- a/internlm/utils/model_checkpoint.py +++ b/internlm/utils/model_checkpoint.py @@ -2,7 +2,9 @@ # -*- encoding: utf-8 -*- import copy +import fcntl import os +import socket import time from enum import Enum from typing import Dict @@ -12,6 +14,7 @@ import torch from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.trainer import TrainState +from internlm.monitor import send_alert_message from internlm.solver.optimizer import HybridZeroOptimizer from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger @@ -25,8 +28,6 @@ from internlm.utils.storage_manager import ( logger = get_logger(__file__) -quit_signal_handler = None - class CheckpointType(Enum): NORMAL_CHECKPOINT = 1 @@ -167,44 +168,6 @@ def save_optimizer_checkpoint(optim, state_path): llm_save(os.path.join(state_path, fp), states) -def save_checkpoint(folder, model, optimizer, scheduler, train_state: TrainState, model_config: Dict = None): - """ - Save checkpoint to the given folder path. - """ - - start = time.time() - torch.distributed.barrier() - folder = os.path.join(folder, str(train_state.step_count)) - logger.info( - f"Saving checkpoint to `{folder}` at batch count:{train_state.step_count} from rank:{gpc.get_global_rank()}..." - ) - - timer("save-model").start() - save_model_checkpoint(folder=folder, model=model) - timer("save-model").stop() - - timer("save-optimizer").start() - save_optimizer_checkpoint(optim=optimizer, state_path=folder) - timer("save-optimizer").stop() - - if gpc.is_rank_for_log(): - scheduler_states = scheduler.state_dict() - llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states) - - sampler_state = train_state.batch_sampler.state_dict() - llm_save(os.path.join(folder, "sampler.pt"), saved_obj=sampler_state) - llm_save(os.path.join(folder, "context.pt"), saved_obj=train_state.state_dict()) - - if model_config is not None: - llm_save(os.path.join(folder, "model_config.pt"), saved_obj=model_config) - - torch.distributed.barrier() - - if gpc.is_rank_for_log(): - timer.log(["save-model", "save-optimizer"], logger=logger) - logger.info(f"Step: {train_state.step_count}, rank 0 save ckpt use {time.time() - start:.3f} s") - - def load_optimizer_checkpoint(folder, optim): """Load the optimizer state from the local file system or remote object storage Service (OSS). @@ -304,19 +267,12 @@ def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train logger.info(f"reload load_scheduler:{lr_scheduler}") -class CheckpointSaveManager: +class CheckpointManager: """StorageManagerContext""" - def __init__( - self, - ckpt_config, - model, - optimizer, - lr_scheduler, - model_config, - ) -> None: + def __init__(self, ckpt_config, model, model_config, feishu_address=None) -> None: """ - CheckpointSaveManager is used to decide when to store ckpt. If it is an asynchronous + CheckpointManager is used to decide when to store ckpt. If it is an asynchronous upload mode, you must call wait_async_upload_finish at the end of the program to wait for the asynchronous ckpt upload to complete. 
@@ -332,26 +288,95 @@ class CheckpointSaveManager: self.save_ckpt_folder = ckpt_config.save_ckpt_folder self.snapshot_ckpt_folder = ckpt_config.snapshot_ckpt_folder self.oss_snapshot_freq: int = ckpt_config.oss_snapshot_freq + self.stop_file_path = ckpt_config.stop_file_path + self.load_model_only_folder = ckpt_config.load_model_only_folder + self.feishu_address = feishu_address self.storage_manager = get_storage_manager() self.snapshot_counter = 0 + self.load_optimizer = gpc.config.ckpt.load_optimizer self.model = model - self.optimizer = optimizer - self.lr_scheduler = lr_scheduler self.model_config = model_config + if self.stop_file_path and gpc.get_global_rank() == 0: + dir_path = os.path.dirname(self.stop_file_path) + if dir_path != "" and not os.path.exists(dir_path): + os.makedirs(dir_path) + with open(self.stop_file_path, "w", encoding="utf-8") as f: + f.write("0") + + if ckpt_config.load_given_ckpt is False: + # Priority: load_given_ckpt(True) > latest_checkpoint > load_model_only_folder + latest_ckpt_path = self.query_lastest_ckpt() + if latest_ckpt_path: + self.load_ckpt_folder = latest_ckpt_path + else: + # At this time, we have to load model init weights and train from step 0. + self.load_ckpt_folder = self.load_model_only_folder + else: + self.load_ckpt_folder = ckpt_config.load_ckpt_folder + + if gpc.is_rank_for_log(): + logger.info(f"load_ckpt_folder will set to :'{self.load_ckpt_folder}'") + if self.stop_file_path is None: + logger.warning("no set stop_file_path, quit_signal_handler is disable") + + def quit_signal_handler(self, train_state) -> bool: + """ + Exit signal detection function, if we write the exit step in the 'QUIT_FILE_PATH' file, + all ranks will save ckpt and exit. + Negative integer step means save ckpt. + Positive integer step means save ckpt and quit. + + Args: + train_state (TrainState): + Returns: + bool: whether to quit. 
+ """ + now_break, now_save_ckpt, save_type = False, False, CheckpointType.NORMAL_CHECKPOINT + + if self.stop_file_path is None: + return now_break, now_save_ckpt, save_type + + with open(self.stop_file_path, "a+", encoding="utf-8") as f: + fcntl.flock(f, fcntl.LOCK_EX) + f.seek(0) + msg = f.read() + fcntl.flock(f, fcntl.LOCK_UN) + action_step = int(msg) + + if action_step < 0 and abs(action_step) == train_state.step_count: + now_save_ckpt = True + + if action_step > 0 and action_step == train_state.step_count: + now_break, now_save_ckpt = True, True + + if action_step != 0 and gpc.is_rank_for_log(): + msg = "Stop" if action_step > 0 else "Save" + action_step = abs(action_step) + if train_state.step_count <= action_step: + if self.feishu_address: + send_alert_message( + address=self.feishu_address, + message=f"training will {msg} at step_count {action_step}!\ +now step_count is {train_state.step_count}", + ) + + return now_break, now_save_ckpt, save_type + def try_save_checkpoint(self, train_state): if not self.enable_save_ckpt: - return + return False save_ckpts, save_type = False, CheckpointType.NORMAL_CHECKPOINT if self.oss_snapshot_freq > 1 and train_state.step_count % self.oss_snapshot_freq == 0: save_ckpts, save_type = True, CheckpointType.SNAPSHOT_CHECKPOINT if train_state.step_count % self.checkpoint_every == 0: save_ckpts, save_type = True, CheckpointType.NORMAL_CHECKPOINT + now_break, singal_save_ckpts, singal_save_type = self.quit_signal_handler(train_state) if save_ckpts is False: - if quit_signal_handler is not None: - save_ckpts, save_type = quit_signal_handler(train_state) + save_ckpts = singal_save_ckpts + save_type = singal_save_type if save_ckpts: # Wait for the previous round of asynchronous upload storage to complete. @@ -361,9 +386,9 @@ class CheckpointSaveManager: self.snapshot_counter = (self.snapshot_counter + 1) % 2 save_ckpt_folder = os.path.join(self.snapshot_ckpt_folder, f"{self.snapshot_counter}") else: - save_ckpt_folder = self.save_ckpt_folder + save_ckpt_folder = os.path.join(self.save_ckpt_folder, str(train_state.step_count)) - save_checkpoint( + self.save_checkpoint( folder=save_ckpt_folder, model=self.model, optimizer=self.optimizer, @@ -372,7 +397,220 @@ class CheckpointSaveManager: model_config=self.model_config, ) + return now_break + def wait_async_upload_finish(self): """wait for all checkpoint uploads to be completed""" self.storage_manager.wait() torch.distributed.barrier() + + def query_latest_snapshot_step_boto3(self): + """query_latest_snapshot_step_boto3 + Returns: + Tuple(str, int): path of latest ckpt and ckpt step, if not found, None will return. 
+ """ + ckpt_list = self.storage_manager.get_fns(self.save_ckpt_folder) + if len(ckpt_list) == 0: + return None, None + + max_normal_step = 0 + ckpt_list = list(map(lambda a: int(a.strip("/")) if a.strip("/").isdigit() else 0, ckpt_list)) + ckpt_list.sort(reverse=True) + for ckpt in ckpt_list: + fns_list = self.storage_manager.get_fns(os.path.join(self.save_ckpt_folder, str(ckpt))) + for fn in fns_list: + if fn.endswith(".step"): + max_normal_step = ckpt + break + if max_normal_step != 0: + break + + max_normal_step = ckpt_list[0] + load_normal_ckpt_path = os.path.join(self.save_ckpt_folder, str(max_normal_step)) + + snapshot_path_0 = os.path.join(self.save_ckpt_folder, "snapshot", "0") + snapshot_path_1 = os.path.join(self.save_ckpt_folder, "snapshot", "1") + ckpt_list_1 = self.storage_manager.get_fns(snapshot_path_0) + ckpt_list_2 = self.storage_manager.get_fns(snapshot_path_1) + max_step_0, max_step_1 = 0, 0 + for ckpt in ckpt_list_1: + ckpt = ckpt.strip("/") + if ckpt.endswith(".step"): + max_step_0 = max(max_step_0, int(ckpt.split(".")[0])) + for ckpt in ckpt_list_2: + ckpt = ckpt.strip("/") + if ckpt.endswith(".step"): + max_step_1 = max(max_step_1, int(ckpt.split(".")[0])) + + snap_load_path = snapshot_path_0 if max_step_0 > max_step_1 else snapshot_path_1 + snap_step = max(max_step_0, max_step_1) + load_path = snap_load_path if snap_step > max_normal_step else load_normal_ckpt_path + load_step = max(snap_step, max_normal_step) + return load_path, load_step + + def query_latest_snapshot_step_local(self): + max_step, max_step_path = 0, None + for root, _, files in os.walk(self.save_ckpt_folder, followlinks=True): + for fn in files: + fn = fn.strip("/") + if fn.endswith(".step"): + # We assume that both normal ckpt and snapshot ckpt will store the '.step' file + # as an integrity flag. + step = int(fn.rsplit(".", maxsplit=1)[0]) + if max_step < step: + max_step = step + max_step_path = root + + return max_step_path, max_step + + def query_lastest_ckpt(self): + latest_checkpoint = None + # Training was automatically restarted by the process, forcing the latest snapshot to be read. 
+ if self.save_ckpt_folder: + if self.save_ckpt_folder.startswith("boto3"): + latest_checkpoint, step = self.query_latest_snapshot_step_boto3() + elif self.save_ckpt_folder.startswith("local"): + latest_checkpoint, step = self.query_latest_snapshot_step_local() + else: + latest_checkpoint, step = None, 0 + + if latest_checkpoint is not None: + if gpc.is_rank_for_log(): + logger.info(f"Found latest ckpt : {latest_checkpoint}, step: {step}") + send_alert_message( + address=self.feishu_address, + message=f"Auto restart resume from ckpt-path: '{latest_checkpoint}', step : {step}", + ) + else: + if gpc.is_rank_for_log(): + send_alert_message( + address=self.feishu_address, + message=f"Can't find snapshot checkpoint, use default load-ckpt path: {latest_checkpoint}", + ) + + return latest_checkpoint + + def try_load_model(self, current_time=""): + model_load_path = None + + if self.load_ckpt_folder and self.load_model_only_folder: + raise ValueError( + "Error, try to use both load_ckpt_folder and load_model_only_folder paths, \ +if you only need to load model weights (for example starting an SFT task for the first time), \ +set load_model_only_folder path, if you need to resume training from ckpt, \ +set load_ckpt_folder or use default value \ +(if is the default value, internlm will try to load the latest ckpt from save_ckpt_folder)" + ) + + if self.load_ckpt_folder: + if gpc.is_rank_for_log(): + logger.info( + f"===========Resume training from `{self.load_ckpt_folder}` {current_time} on host:" + f"{socket.gethostname()}===========" + ) + model_load_path = self.load_ckpt_folder + elif self.load_model_only_folder: + if gpc.is_rank_for_log(): + logger.info( + f"===========Load Model from `{self.load_model_only_folder}` {current_time} on host:" + f"{socket.gethostname()}===========" + ) + model_load_path = self.load_model_only_folder + else: + if gpc.is_rank_for_log(): + logger.info( + f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()}," + f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)}," + f"dp={gpc.get_local_rank(ParallelMode.DATA)}===========" + ) + + # Loading model weights must be done before zero is initialized. + if model_load_path is not None: + load_model_checkpoint(folder=model_load_path, model=self.model) + + def try_resume_training(self, lr_scheduler, optimizer, lr, train_state, train_dl): + """Attempt to restore the training state of the last ckpt. + + Args: + lr_scheduler (_LRScheduler): lr_scheduler object. + optimizer (Optimizer): optimizer object. + lr (float): learning rate. + train_state (dict): traing states. + train_dl (DataLoader): traning dataloader object + """ + if self.load_ckpt_folder is not None: + # load optimzier states. + if self.load_optimizer: + load_optimizer_checkpoint(self.load_ckpt_folder, optimizer) + # load lr scheduler states. + load_scheduler(self.load_ckpt_folder, lr_scheduler, optimizer, lr, train_state) + # load training states. + load_context(self.load_ckpt_folder, train_dl, train_state) + # load dataloader sampler states. 
+ if hasattr(train_state, "batch_sampler") and not isinstance( + train_state.batch_sampler, torch.utils.data.sampler.BatchSampler + ): + load_sampler(self.load_ckpt_folder, train_dl.batch_sampler) + if hasattr(train_state, "data_state_dict"): + train_dl.dataset.load_state_dict( + llm_load(os.path.join(self.load_ckpt_folder, "sampler_0.pt")), ckpt_path=self.load_ckpt_folder + ) + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + + def save_checkpoint(self, folder, model, optimizer, scheduler, train_state: TrainState, model_config: Dict = None): + """ + Save checkpoint to the given folder path. + """ + + start = time.time() + self.set_save_folder(folder, train_state.step_count) + torch.distributed.barrier() + if gpc.is_rank_for_log(): + logger.info(f"Saving checkpoint to `{folder}` at batch count:{train_state.step_count}...") + + timer("save-model").start() + save_model_checkpoint(folder=folder, model=model) + timer("save-model").stop() + + timer("save-optimizer").start() + save_optimizer_checkpoint(optim=optimizer, state_path=folder) + timer("save-optimizer").stop() + + if ( + hasattr(train_state, "data_state_dict") + and gpc.get_local_rank(ParallelMode.TENSOR) == 0 + and gpc.get_local_rank(ParallelMode.PIPELINE) == 0 + ): + llm_save( + os.path.join(folder, f"sampler_{gpc.get_local_rank(ParallelMode.DATA)}.pt"), + saved_obj=train_state.data_state_dict, + ) + + if gpc.is_rank_for_log(): + scheduler_states = scheduler.state_dict() + llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states) + if hasattr(train_state, "batch_sampler") and not isinstance( + train_state.batch_sampler, torch.utils.data.sampler.BatchSampler + ): + sampler_state = train_state.batch_sampler.state_dict() + llm_save(os.path.join(folder, "sampler.pt"), saved_obj=sampler_state) + llm_save(os.path.join(folder, "context.pt"), saved_obj=train_state.state_dict()) + + if model_config is not None: + llm_save(os.path.join(folder, "model_config.pt"), saved_obj=model_config) + + torch.distributed.barrier() + + if gpc.is_rank_for_log(): + timer.log(["save-model", "save-optimizer"], logger=logger) + logger.info(f"Step: {train_state.step_count}, rank 0 save ckpt use {time.time() - start:.3f} s") + if self.storage_manager.async_mode is False: + llm_save( + os.path.join(folder, f"{train_state.step_count}.step"), + saved_obj=dict({"step": train_state.step_count}), + ) + + def set_save_folder(self, folder, step): + self.storage_manager.latest_save_folder = folder + self.storage_manager.latest_save_step = step diff --git a/internlm/utils/storage_manager.py b/internlm/utils/storage_manager.py index c9b42ea..c7b71f4 100644 --- a/internlm/utils/storage_manager.py +++ b/internlm/utils/storage_manager.py @@ -15,8 +15,6 @@ from asyncio.tasks import ALL_COMPLETED from datetime import datetime from typing import Any, Awaitable, Callable, Dict, List, Union -import boto3 -import botocore import torch import torch.distributed as dist @@ -24,6 +22,13 @@ from internlm.core.context import global_context as gpc from internlm.utils.common import SingletonMeta from internlm.utils.logger import get_logger +try: + import boto3 + import botocore +except ImportError: + pass + + logger = get_logger(__file__) boto3_url_re = re.compile(r"([^\.]+)\.([\d\.]+)") @@ -234,13 +239,13 @@ class Boto3Client(StorageClient): """ paginator = handler.client.get_paginator("list_objects_v2") pages = paginator.paginate(Bucket=bucket_name, Prefix=fp) - folder_name_list = [] for page in pages: - for obj in page["Contents"]: - fp: str = obj["Key"] - 
folder_name_list.append(fp.rsplit("/", maxsplit=1)[1]) - return folder_name_list + if "Contents" in page: + for obj in page["Contents"]: + pth: str = obj["Key"] + folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0]) + return list(set(folder_name_list)) @staticmethod def async_upload_fileobj(handler, bucket_name: str, fp: str, local_nvme_path: str): @@ -391,6 +396,11 @@ class StorageManager(metaclass=SingletonMeta): self.tmp_local_folder = tmp_local_folder self.async_mode = async_mode self.has_warning = False + self._async_loop = None + self._thread_pool = None + self.latest_save_folder = None + self.latest_save_step = 0 + self.async_task_peeding = False if enable_save and self.async_mode: self._async_loop = asyncio.new_event_loop() @@ -485,6 +495,7 @@ class StorageManager(metaclass=SingletonMeta): torch.save(saved_obj, f, pickle_protocol=pickle.HIGHEST_PROTOCOL) self.async_executor(meta.async_upload_fn, *unpack_meta(meta)) os.chmod(tmp_step_file, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO) + self.async_task_peeding = True else: meta.client.sync_upload_fileobj(*unpack_meta(meta), *args, saved_obj=saved_obj, **kwargs) self.upload_count += 1 @@ -523,23 +534,22 @@ class StorageManager(metaclass=SingletonMeta): pass async def _sync_tasks(self) -> Awaitable[None]: - if not self._async_stack: - return - - await asyncio.wait(self._async_stack, return_when=ALL_COMPLETED) - - for task in self._async_stack: - try: - task.exception() - except InvalidStateError: - continue - except Exception as e: - file_id = len(self._exception_list) - self._exception_list.append((e, file_id)) - - logger.error(f"File: {self._to_be_del_files[file_id]}, " f"upload failed with {e}") - - self._async_stack.clear() + if self._async_stack: + await asyncio.wait(self._async_stack, return_when=ALL_COMPLETED) + count = 0 + while self._async_stack: + t = self._async_stack[0] + try: + e = t.exception() + if e: + self._exception_list.append((e, count)) + logger.error(f"File:{self._to_be_del_files[count]}, upload failed for {e}") + # raise e + count += 1 + self._async_stack.pop(0) + except InvalidStateError: + # Not finished. 
https://docs.python.org/3/library/asyncio-task.html#asyncio.Task.exception + pass def async_executor(self, fn: Callable, *args, **kwargs) -> None: """ @@ -559,11 +569,14 @@ class StorageManager(metaclass=SingletonMeta): if not self.async_mode: return + if not self.async_task_peeding: + return + if self._async_loop: self._async_loop.run_until_complete(self._sync_tasks()) if self._exception_list: - for file_id, error_msg in self._exception_list: + for error_msg, file_id in self._exception_list: logger.error( f"Node:{socket.gethostname()}, Error: Checkpoint {self._to_be_del_files[file_id]} " f"failed on step {self.upload_count}: {error_msg}" @@ -577,10 +590,16 @@ class StorageManager(metaclass=SingletonMeta): self._del_tmp_folder() self._exception_list.clear() self._to_be_del_files.clear() + self.async_task_peeding = False if gpc.is_rank_for_log(): - logger.info("all async uploads succeeded!") self.upload_count += 1 + if self.async_mode: + self.save( + os.path.join(self.latest_save_folder, f"{self.latest_save_step}.step"), + saved_obj=dict({"step": self.latest_save_step}), + async_upload=False, + ) storage_manager: StorageManager = None diff --git a/internlm/utils/writer.py b/internlm/utils/writer.py index 311c6b3..5ea0680 100644 --- a/internlm/utils/writer.py +++ b/internlm/utils/writer.py @@ -11,10 +11,6 @@ from torch.utils.tensorboard import SummaryWriter from internlm.core.context import global_context as gpc -def copy_ignore_folder(source_path, target_path): - os.system(f"cp -r {source_path}/* {target_path}/") - - def tb_save_run_info(writer, config_lines, global_step=0): writer.add_text(tag="cmd", text_string=" ".join(sys.argv[:]), global_step=global_step) lines = [] @@ -44,7 +40,8 @@ def init_tb_writer( if gpc.get_global_rank() == 0: if resume_tb_folder is not None: logger.info(f"Try mv tensorboard logs: {resume_tb_folder} to {tb_folder}...") - copy_ignore_folder(resume_tb_folder, tb_folder) + os.system(f"cp -r {resume_tb_folder}/* {tb_folder}/") + os.system(f"chmod -R +w {tb_folder}/") else: logger.info(f"Login tensorboard logs to: {tb_folder}") diff --git a/train.py b/train.py index d23a30a..306fcdf 100644 --- a/train.py +++ b/train.py @@ -47,14 +47,7 @@ from internlm.utils.common import ( from internlm.utils.evaluation import evaluate_on_val_dls from internlm.utils.logger import get_logger, initialize_uniscale_logger from internlm.utils.megatron_timers import megatron_timer as timer -from internlm.utils.model_checkpoint import ( - CheckpointSaveManager, - load_context, - load_model_checkpoint, - load_optimizer_checkpoint, - load_sampler, - load_scheduler, -) +from internlm.utils.model_checkpoint import CheckpointManager from internlm.utils.parallel import ( get_parallel_log_file_name, is_no_pp_or_last_stage, @@ -462,13 +455,9 @@ def main(args): skip_batches = gpc.config.data.skip_batches total_steps = gpc.config.data.total_steps valid_every = gpc.config.data.valid_every - load_optimizer = gpc.config.ckpt.load_optimizer label_smoothing = gpc.config.loss.label_smoothing lr = gpc.config.adam.lr - load_model_only_folder = gpc.config.ckpt.get("load_model_only_folder", None) - load_resume_ckpt_folder = gpc.config.ckpt.get("load_ckpt_folder", None) - get_tflops_func = partial( get_megatron_flops, checkpoint=gpc.config.model.checkpoint, @@ -504,32 +493,19 @@ def main(args): enable_tb=gpc.config.enable_tb, ) - model_load_path = None - if load_resume_ckpt_folder is not None: - logger.info( - f"===========Resume training from `{load_resume_ckpt_folder}` {current_time} on host:" - 
f"{socket.gethostname()}===========" - ) - model_load_path = load_resume_ckpt_folder - elif load_model_only_folder is not None: - logger.info( - f"===========SFT training from `{load_model_only_folder}` {current_time} on host:" - f"{socket.gethostname()}===========" - ) - model_load_path = load_model_only_folder - else: - logger.info( - f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()}," - f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)}," - f"dp={gpc.get_local_rank(ParallelMode.DATA)}===========" - ) - # initialize and resume train state train_state = TrainState(gpc.config) # initialize model model = initialize_model() + ckpt_manager = CheckpointManager( + ckpt_config=gpc.config.ckpt, + model=model, + model_config=gpc.config.model, + feishu_address=gpc.config.alert_address, + ) + # initialize loss function criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing) @@ -539,30 +515,12 @@ def main(args): train_state.init_batch_sampler(train_dl) # Loading model weights must be done before zero is initialized. - if model_load_path is not None: - load_model_checkpoint(folder=model_load_path, model=model) + ckpt_manager.try_load_model(current_time) optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model) # Loading other persistent training states. - if load_resume_ckpt_folder is not None: - # load lr scheduler states. - load_scheduler(load_resume_ckpt_folder, lr_scheduler, optimizer, lr, train_state) - # load training states. - load_context(load_resume_ckpt_folder, train_dl, train_state) - # load dataloader sampler states. - load_sampler(load_resume_ckpt_folder, train_dl.batch_sampler) - # load optimzier states. - if load_optimizer: - load_optimizer_checkpoint(load_resume_ckpt_folder, optimizer) - - ckpt_save_manager = CheckpointSaveManager( - ckpt_config=gpc.config.ckpt, - model=model, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - model_config=gpc.config.model, - ) + ckpt_manager.try_resume_training(lr_scheduler, optimizer, lr, train_state, train_dl) # initialize metric for calculating accuracy and perplexity metric = AccPerplex( @@ -700,14 +658,16 @@ def main(args): # checkpoint the training states in specific steps, which is determined by the args "checkpoint_every" # # save batch sampler that tracks the true consumed samples - ckpt_save_manager.try_save_checkpoint(train_state) + now_break = ckpt_manager.try_save_checkpoint(train_state) + if now_break: + break if memory_profiler is not None: memory_profiler.step() prof.step() - ckpt_save_manager.wait_async_upload_finish() + ckpt_manager.wait_async_upload_finish() if __name__ == "__main__":