From a45a91bb843cf0b10b8b014a6ef35e695871f91b Mon Sep 17 00:00:00 2001 From: Guoteng <32697156+SolenoidWGT@users.noreply.github.com> Date: Fri, 11 Aug 2023 17:08:01 +0800 Subject: [PATCH 1/3] feat(ckpt): add auto ckpt load and singal quit (#189) Co-authored-by: wangguoteng.p --- internlm/initialize/launch.py | 6 + internlm/utils/model_checkpoint.py | 321 ++++++++++++++++++++++++----- internlm/utils/storage_manager.py | 27 ++- train.py | 74 ++----- 4 files changed, 309 insertions(+), 119 deletions(-) diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 33b5d15..2447190 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -138,6 +138,12 @@ def args_sanity_check(): if "async_upload_tmp_folder" not in gpc.config.ckpt: gpc.config.ckpt._add_item("async_upload_tmp_folder", "/dev/shm/internlm_tmp_ckpt/") + if "stop_file_path" not in gpc.config.ckpt: + gpc.config._add_item("stop_file_path", None) + + if "load_given_ckpt" not in gpc.config.ckpt: + gpc.config._add_item("load_given_ckpt", False) + if "snapshot_ckpt_folder" not in gpc.config.ckpt: gpc.config.ckpt._add_item("snapshot_ckpt_folder", os.path.join(gpc.config.ckpt.save_ckpt_folder), "snapshot") diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py index 3fe29cc..643fee9 100644 --- a/internlm/utils/model_checkpoint.py +++ b/internlm/utils/model_checkpoint.py @@ -2,7 +2,9 @@ # -*- encoding: utf-8 -*- import copy +import fcntl import os +import socket import time from enum import Enum from typing import Dict @@ -12,6 +14,7 @@ import torch from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.trainer import TrainState +from internlm.monitor import send_alert_message from internlm.solver.optimizer import HybridZeroOptimizer from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger @@ -25,8 +28,6 @@ from internlm.utils.storage_manager import ( logger = get_logger(__file__) -quit_signal_handler = None - class CheckpointType(Enum): NORMAL_CHECKPOINT = 1 @@ -167,44 +168,6 @@ def save_optimizer_checkpoint(optim, state_path): llm_save(os.path.join(state_path, fp), states) -def save_checkpoint(folder, model, optimizer, scheduler, train_state: TrainState, model_config: Dict = None): - """ - Save checkpoint to the given folder path. - """ - - start = time.time() - torch.distributed.barrier() - folder = os.path.join(folder, str(train_state.step_count)) - logger.info( - f"Saving checkpoint to `{folder}` at batch count:{train_state.step_count} from rank:{gpc.get_global_rank()}..." 
- ) - - timer("save-model").start() - save_model_checkpoint(folder=folder, model=model) - timer("save-model").stop() - - timer("save-optimizer").start() - save_optimizer_checkpoint(optim=optimizer, state_path=folder) - timer("save-optimizer").stop() - - if gpc.is_rank_for_log(): - scheduler_states = scheduler.state_dict() - llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states) - - sampler_state = train_state.batch_sampler.state_dict() - llm_save(os.path.join(folder, "sampler.pt"), saved_obj=sampler_state) - llm_save(os.path.join(folder, "context.pt"), saved_obj=train_state.state_dict()) - - if model_config is not None: - llm_save(os.path.join(folder, "model_config.pt"), saved_obj=model_config) - - torch.distributed.barrier() - - if gpc.is_rank_for_log(): - timer.log(["save-model", "save-optimizer"], logger=logger) - logger.info(f"Step: {train_state.step_count}, rank 0 save ckpt use {time.time() - start:.3f} s") - - def load_optimizer_checkpoint(folder, optim): """Load the optimizer state from the local file system or remote object storage Service (OSS). @@ -304,19 +267,12 @@ def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train logger.info(f"reload load_scheduler:{lr_scheduler}") -class CheckpointSaveManager: +class CheckpointManager: """StorageManagerContext""" - def __init__( - self, - ckpt_config, - model, - optimizer, - lr_scheduler, - model_config, - ) -> None: + def __init__(self, ckpt_config, model, model_config, feishu_address=None) -> None: """ - CheckpointSaveManager is used to decide when to store ckpt. If it is an asynchronous + CheckpointManager is used to decide when to store ckpt. If it is an asynchronous upload mode, you must call wait_async_upload_finish at the end of the program to wait for the asynchronous ckpt upload to complete. @@ -332,26 +288,85 @@ class CheckpointSaveManager: self.save_ckpt_folder = ckpt_config.save_ckpt_folder self.snapshot_ckpt_folder = ckpt_config.snapshot_ckpt_folder self.oss_snapshot_freq: int = ckpt_config.oss_snapshot_freq + self.stop_file_path = ckpt_config.stop_file_path + self.load_model_only_folder = ckpt_config.load_model_only_folder + self.feishu_address = feishu_address self.storage_manager = get_storage_manager() self.snapshot_counter = 0 + self.load_optimizer = gpc.config.ckpt.load_optimizer self.model = model - self.optimizer = optimizer - self.lr_scheduler = lr_scheduler self.model_config = model_config + if self.stop_file_path and gpc.get_global_rank() == 0: + dir_path = os.path.dirname(self.stop_file_path) + if dir_path != "" and not os.path.exists(dir_path): + os.makedirs(dir_path) + with open(self.stop_file_path, "w", encoding="utf-8") as f: + f.write("0") + + if not ckpt_config.load_given_ckpt: + latest_ckpt_path = self.query_lastest_ckpt() + self.load_ckpt_folder = latest_ckpt_path if latest_ckpt_path is not None else ckpt_config.load_ckpt_folder + else: + self.load_ckpt_folder = ckpt_config.load_ckpt_folder + + def quit_signal_handler(self, train_state) -> bool: + """ + Exit signal detection function, if we write the exit step in the 'QUIT_FILE_PATH' file, + all ranks will save ckpt and exit. + Negative integer step means save ckpt. + Positive integer step means save ckpt and quit. + + Args: + train_state (TrainState): + Returns: + bool: whether to quit. 
+ """ + if self.stop_file_path is None: + logger.warning("no set stop_file_path") + return + + now_break, now_save_ckpt, save_type = False, False, CheckpointType.NORMAL_CHECKPOINT + with open(self.stop_file_path, "a+", encoding="utf-8") as f: + fcntl.flock(f, fcntl.LOCK_EX) + f.seek(0) + msg = f.read() + fcntl.flock(f, fcntl.LOCK_UN) + action_step = int(msg) + + if action_step < 0 and abs(action_step) == train_state.step_count: + now_save_ckpt = True + + if action_step > 0 and action_step == train_state.step_count: + now_break, now_save_ckpt = True, True + + if action_step != 0 and gpc.is_rank_for_log(): + msg = "Stop" if action_step > 0 else "Save" + action_step = abs(action_step) + if train_state.step_count <= action_step: + if self.feishu_address: + send_alert_message( + address=self.feishu_address, + message=f"training will {msg} at step_count {action_step}!\ +now step_count is {train_state.step_count}", + ) + + return now_break, now_save_ckpt, save_type + def try_save_checkpoint(self, train_state): if not self.enable_save_ckpt: - return + return False save_ckpts, save_type = False, CheckpointType.NORMAL_CHECKPOINT if self.oss_snapshot_freq > 1 and train_state.step_count % self.oss_snapshot_freq == 0: save_ckpts, save_type = True, CheckpointType.SNAPSHOT_CHECKPOINT if train_state.step_count % self.checkpoint_every == 0: save_ckpts, save_type = True, CheckpointType.NORMAL_CHECKPOINT + now_break, singal_save_ckpts, singal_save_type = self.quit_signal_handler(train_state) if save_ckpts is False: - if quit_signal_handler is not None: - save_ckpts, save_type = quit_signal_handler(train_state) + save_ckpts = singal_save_ckpts + save_type = singal_save_type if save_ckpts: # Wait for the previous round of asynchronous upload storage to complete. @@ -361,9 +376,9 @@ class CheckpointSaveManager: self.snapshot_counter = (self.snapshot_counter + 1) % 2 save_ckpt_folder = os.path.join(self.snapshot_ckpt_folder, f"{self.snapshot_counter}") else: - save_ckpt_folder = self.save_ckpt_folder + save_ckpt_folder = os.path.join(self.save_ckpt_folder, str(train_state.step_count)) - save_checkpoint( + self.save_checkpoint( folder=save_ckpt_folder, model=self.model, optimizer=self.optimizer, @@ -372,7 +387,199 @@ class CheckpointSaveManager: model_config=self.model_config, ) + return now_break + def wait_async_upload_finish(self): """wait for all checkpoint uploads to be completed""" self.storage_manager.wait() torch.distributed.barrier() + + def query_latest_snapshot_step_boto3(self): + """query_latest_snapshot_step_boto3 + Returns: + Tuple(str, int): path of latest ckpt and ckpt step, if not found, None will return. 
+ """ + ckpt_list = self.storage_manager.get_fns(self.save_ckpt_folder) + if len(ckpt_list) == 0: + return None, None + + max_normal_step = 0 + ckpt_list = list(map(lambda a: int(a.strip("/")) if a.strip("/").isdigit() else 0, ckpt_list)) + ckpt_list.sort(reverse=True) + for ckpt in ckpt_list: + fns_list = self.storage_manager.get_fns(os.path.join(self.save_ckpt_folder, str(ckpt))) + for fn in fns_list: + if fn.endswith(".step"): + max_normal_step = ckpt + break + if max_normal_step != 0: + break + + max_normal_step = ckpt_list[0] + load_normal_ckpt_path = os.path.join(self.save_ckpt_folder, str(max_normal_step)) + + snapshot_path_0 = os.path.join(self.save_ckpt_folder, "snapshot", "0") + snapshot_path_1 = os.path.join(self.save_ckpt_folder, "snapshot", "1") + ckpt_list_1 = self.storage_manager.get_fns(snapshot_path_0) + ckpt_list_2 = self.storage_manager.get_fns(snapshot_path_1) + max_step_0, max_step_1 = 0, 0 + for ckpt in ckpt_list_1: + ckpt = ckpt.strip("/") + if ckpt.endswith(".step"): + max_step_0 = max(max_step_0, int(ckpt.split(".")[0])) + for ckpt in ckpt_list_2: + ckpt = ckpt.strip("/") + if ckpt.endswith(".step"): + max_step_1 = max(max_step_1, int(ckpt.split(".")[0])) + + snap_load_path = snapshot_path_0 if max_step_0 > max_step_1 else snapshot_path_1 + snap_step = max(max_step_0, max_step_1) + load_path = snap_load_path if snap_step > max_normal_step else load_normal_ckpt_path + load_step = max(snap_step, max_normal_step) + return load_path, load_step + + def query_latest_snapshot_step_local(self): + max_step, max_step_path = 0, None + for root, _, files in os.walk(self.save_ckpt_folder, followlinks=True): + for fn in files: + fn = fn.strip("/") + if fn.endswith(".step"): + # We assume that both normal ckpt and snapshot ckpt will store the '.step' file + # as an integrity flag. + step = int(fn.rsplit(".", maxsplit=1)[0]) + if max_step < step: + max_step = step + max_step_path = root + + return max_step_path, max_step + + def query_lastest_ckpt(self): + latest_checkpoint = None + # Training was automatically restarted by the process, forcing the latest snapshot to be read. 
+ if self.save_ckpt_folder: + if self.save_ckpt_folder.startswith("boto3"): + latest_checkpoint, step = self.query_latest_snapshot_step_boto3() + elif self.save_ckpt_folder.startswith("local"): + latest_checkpoint, step = self.query_latest_snapshot_step_local() + else: + latest_checkpoint, step = None, 0 + + if latest_checkpoint is not None: + if gpc.is_rank_for_log(): + logger.info(f"Found latest ckpt : {latest_checkpoint}, step: {step}") + send_alert_message( + address=self.feishu_address, + message=f"Auto restart resume from ckpt-path: '{latest_checkpoint}', step : {step}", + ) + else: + if gpc.is_rank_for_log(): + send_alert_message( + address=self.feishu_address, + message=f"Can't find snapshot checkpoint, use default load-ckpt path: {latest_checkpoint}", + ) + + return latest_checkpoint + + def try_load_model(self, current_time=""): + model_load_path = None + + if self.load_ckpt_folder and self.load_model_only_folder: + raise ValueError( + "Error, try to use both load_ckpt_folder and load_model_only_folder paths, \ +if you only need to load model weights (for example starting an SFT task for the first time), \ +set load_model_only_folder path, if you need to resume training from ckpt, \ +set load_ckpt_folder or use default value \ +(if is the default value, internlm will try to load the latest ckpt from save_ckpt_folder)" + ) + + if self.load_ckpt_folder: + logger.info( + f"===========Resume training from `{self.load_ckpt_folder}` {current_time} on host:" + f"{socket.gethostname()}===========" + ) + model_load_path = self.load_ckpt_folder + elif self.load_model_only_folder: + logger.info( + f"===========SFT training from `{self.load_model_only_folder}` {current_time} on host:" + f"{socket.gethostname()}===========" + ) + model_load_path = self.load_model_only_folder + else: + logger.info( + f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()}," + f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)}," + f"dp={gpc.get_local_rank(ParallelMode.DATA)}===========" + ) + + # Loading model weights must be done before zero is initialized. + if model_load_path is not None: + load_model_checkpoint(folder=model_load_path, model=self.model) + + def try_resume_traing(self, lr_scheduler, optimizer, lr, train_state, train_dl): + """Attempt to restore the training state of the last ckpt. + + Args: + lr_scheduler (_LRScheduler): lr_scheduler object. + optimizer (Optimizer): optimizer object. + lr (float): learning rate. + train_state (dict): traing states. + train_dl (DataLoader): traning dataloader object + """ + if self.load_ckpt_folder is not None: + # load lr scheduler states. + load_scheduler(self.load_ckpt_folder, lr_scheduler, optimizer, lr, train_state) + # load training states. + load_context(self.load_ckpt_folder, train_dl, train_state) + # load dataloader sampler states. + load_sampler(self.load_ckpt_folder, train_dl.batch_sampler) + # load optimzier states. + if self.load_optimizer: + load_optimizer_checkpoint(self.load_ckpt_folder, optimizer) + + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + + def save_checkpoint(self, folder, model, optimizer, scheduler, train_state: TrainState, model_config: Dict = None): + """ + Save checkpoint to the given folder path. 
+ """ + + start = time.time() + self.set_save_folder(folder, train_state.step_count) + torch.distributed.barrier() + if gpc.is_rank_for_log(): + logger.info(f"Saving checkpoint to `{folder}` at batch count:{train_state.step_count}...") + + timer("save-model").start() + save_model_checkpoint(folder=folder, model=model) + timer("save-model").stop() + + timer("save-optimizer").start() + save_optimizer_checkpoint(optim=optimizer, state_path=folder) + timer("save-optimizer").stop() + + if gpc.is_rank_for_log(): + scheduler_states = scheduler.state_dict() + llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states) + + sampler_state = train_state.batch_sampler.state_dict() + llm_save(os.path.join(folder, "sampler.pt"), saved_obj=sampler_state) + llm_save(os.path.join(folder, "context.pt"), saved_obj=train_state.state_dict()) + + if model_config is not None: + llm_save(os.path.join(folder, "model_config.pt"), saved_obj=model_config) + + torch.distributed.barrier() + + if gpc.is_rank_for_log(): + timer.log(["save-model", "save-optimizer"], logger=logger) + logger.info(f"Step: {train_state.step_count}, rank 0 save ckpt use {time.time() - start:.3f} s") + if self.storage_manager.async_mode is False: + llm_save( + os.path.join(folder, f"{train_state.step_count}.step"), + saved_obj=dict({"step": train_state.step_count}), + ) + + def set_save_folder(self, folder, step): + self.storage_manager.latest_save_folder = folder + self.storage_manager.latest_save_step = step diff --git a/internlm/utils/storage_manager.py b/internlm/utils/storage_manager.py index 481bd28..6984ddc 100644 --- a/internlm/utils/storage_manager.py +++ b/internlm/utils/storage_manager.py @@ -234,13 +234,13 @@ class Boto3Client(StorageClient): """ paginator = handler.client.get_paginator("list_objects_v2") pages = paginator.paginate(Bucket=bucket_name, Prefix=fp) - folder_name_list = [] for page in pages: - for obj in page["Contents"]: - fp: str = obj["Key"] - folder_name_list.append(fp.rsplit("/", maxsplit=1)[1]) - return folder_name_list + if "Contents" in page: + for obj in page["Contents"]: + pth: str = obj["Key"] + folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0]) + return list(set(folder_name_list)) @staticmethod def async_upload_fileobj(handler, bucket_name: str, fp: str, local_nvme_path: str): @@ -391,6 +391,11 @@ class StorageManager(metaclass=SingletonMeta): self.tmp_local_folder = tmp_local_folde self.async_mode = async_mode self.has_warning = False + self._async_loop = None + self._thread_pool = None + self.latest_save_folder = None + self.latest_save_step = 0 + self.async_task_peeding = False if enable_save and self.async_mode: self._async_loop = asyncio.new_event_loop() @@ -485,6 +490,7 @@ class StorageManager(metaclass=SingletonMeta): torch.save(saved_obj, f, pickle_protocol=pickle.HIGHEST_PROTOCOL) self.async_executor(meta.async_upload_fn, *unpack_meta(meta)) os.chmod(tmp_step_file, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO) + self.async_task_peeding = True else: meta.client.sync_upload_fileobj(*unpack_meta(meta), *args, saved_obj=saved_obj, **kwargs) self.upload_count += 1 @@ -560,6 +566,9 @@ class StorageManager(metaclass=SingletonMeta): if not self.async_mode: return + if not self.async_task_peeding: + return + if self._async_loop: self._async_loop.run_until_complete(self._sync_tasks()) @@ -578,10 +587,16 @@ class StorageManager(metaclass=SingletonMeta): self._del_tmp_folder() self._exception_list.clear() self._to_be_del_files.clear() + 
self.async_task_peeding = False if gpc.is_rank_for_log(): - logger.info("all async uploads succeeded!") self.upload_count += 1 + if self.async_mode: + self.save( + os.path.join(self.latest_save_folder, f"{self.latest_save_step}.step"), + saved_obj=dict({"step": self.latest_save_step}), + async_upload=False, + ) storage_manager: StorageManager = None diff --git a/train.py b/train.py index 7e87d05..f1e6fd0 100644 --- a/train.py +++ b/train.py @@ -45,14 +45,7 @@ from internlm.utils.common import ( from internlm.utils.evaluation import evaluate_on_val_dls, switch_sequence_parallel_mode from internlm.utils.logger import get_logger, initialize_uniscale_logger from internlm.utils.megatron_timers import megatron_timer as timer -from internlm.utils.model_checkpoint import ( - CheckpointSaveManager, - load_context, - load_model_checkpoint, - load_optimizer_checkpoint, - load_sampler, - load_scheduler, -) +from internlm.utils.model_checkpoint import CheckpointManager from internlm.utils.parallel import ( get_parallel_log_file_name, is_no_pp_or_last_stage, @@ -428,13 +421,9 @@ def main(args): skip_batches = gpc.config.data.skip_batches total_steps = gpc.config.data.total_steps valid_every = gpc.config.data.valid_every - load_optimizer = gpc.config.ckpt.load_optimizer label_smoothing = gpc.config.loss.label_smoothing lr = gpc.config.adam.lr - load_model_only_folder = gpc.config.ckpt.get("load_model_only_folder", None) - load_resume_ckpt_folder = gpc.config.ckpt.get("load_ckpt_folder", None) - get_tflops_func = partial( get_megatron_flops, checkpoint=gpc.config.model.checkpoint, @@ -470,32 +459,19 @@ def main(args): enable_tb=gpc.config.enable_tb, ) - model_load_path = None - if load_resume_ckpt_folder is not None: - logger.info( - f"===========Resume training from `{load_resume_ckpt_folder}` {current_time} on host:" - f"{socket.gethostname()}===========" - ) - model_load_path = load_resume_ckpt_folder - elif load_model_only_folder is not None: - logger.info( - f"===========SFT training from `{load_model_only_folder}` {current_time} on host:" - f"{socket.gethostname()}===========" - ) - model_load_path = load_model_only_folder - else: - logger.info( - f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()}," - f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)}," - f"dp={gpc.get_local_rank(ParallelMode.DATA)}===========" - ) - # initialize and resume train state train_state = TrainState(gpc.config) # initialize model model = initialize_model() + ckpt_manager = CheckpointManager( + ckpt_config=gpc.config.ckpt, + model=model, + model_config=gpc.config.model, + feishu_address=gpc.config.alert_address, + ) + # initialize loss function criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing) @@ -505,30 +481,12 @@ def main(args): train_state.init_batch_sampler(train_dl) # Loading model weights must be done before zero is initialized. - if model_load_path is not None: - load_model_checkpoint(folder=model_load_path, model=model) + ckpt_manager.try_load_model(current_time) optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model) # Loading other persistent training states. - if load_resume_ckpt_folder is not None: - # load lr scheduler states. - load_scheduler(load_resume_ckpt_folder, lr_scheduler, optimizer, lr, train_state) - # load training states. - load_context(load_resume_ckpt_folder, train_dl, train_state) - # load dataloader sampler states. 
- load_sampler(load_resume_ckpt_folder, train_dl.batch_sampler) - # load optimzier states. - if load_optimizer: - load_optimizer_checkpoint(load_resume_ckpt_folder, optimizer) - - ckpt_save_manager = CheckpointSaveManager( - ckpt_config=gpc.config.ckpt, - model=model, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - model_config=gpc.config.model, - ) + ckpt_manager.try_resume_traing(lr_scheduler, optimizer, lr, train_state, train_dl) # initialize metric for calculating accuracy and perplexity metric = AccPerplex( @@ -649,9 +607,11 @@ def main(args): # checkpoint the training states in specific steps, which is determined by the args "checkpoint_every" # # save batch sampler that tracks the true consumed samples - ckpt_save_manager.try_save_checkpoint(train_state) + now_break = ckpt_manager.try_save_checkpoint(train_state) + if now_break: + break - ckpt_save_manager.wait_async_upload_finish() + ckpt_manager.wait_async_upload_finish() if __name__ == "__main__": @@ -667,8 +627,10 @@ if __name__ == "__main__": try: main(args) except Exception: + format_trace = "" + for line in traceback.format_exc().split("\n")[-10:]: + format_trace += "\n" + line logger.error( - f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}", - exc_info=traceback.format_exc(), + f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}, trace:{format_trace}", ) mm.monitor_exception(alert_address=gpc.config.alert_address, excp_info=traceback.format_exc()) From 5f3133fac873720b9cd5195e4a39cab9d10343d4 Mon Sep 17 00:00:00 2001 From: Sun Peng Date: Fri, 11 Aug 2023 17:12:26 +0800 Subject: [PATCH 2/3] Revert "feat(ckpt): add auto ckpt load and singal quit (#189)" (#192) This reverts commit a45a91bb843cf0b10b8b014a6ef35e695871f91b. --- internlm/initialize/launch.py | 6 - internlm/utils/model_checkpoint.py | 321 +++++------------------------ internlm/utils/storage_manager.py | 27 +-- train.py | 74 +++++-- 4 files changed, 119 insertions(+), 309 deletions(-) diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 2447190..33b5d15 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -138,12 +138,6 @@ def args_sanity_check(): if "async_upload_tmp_folder" not in gpc.config.ckpt: gpc.config.ckpt._add_item("async_upload_tmp_folder", "/dev/shm/internlm_tmp_ckpt/") - if "stop_file_path" not in gpc.config.ckpt: - gpc.config._add_item("stop_file_path", None) - - if "load_given_ckpt" not in gpc.config.ckpt: - gpc.config._add_item("load_given_ckpt", False) - if "snapshot_ckpt_folder" not in gpc.config.ckpt: gpc.config.ckpt._add_item("snapshot_ckpt_folder", os.path.join(gpc.config.ckpt.save_ckpt_folder), "snapshot") diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py index 643fee9..3fe29cc 100644 --- a/internlm/utils/model_checkpoint.py +++ b/internlm/utils/model_checkpoint.py @@ -2,9 +2,7 @@ # -*- encoding: utf-8 -*- import copy -import fcntl import os -import socket import time from enum import Enum from typing import Dict @@ -14,7 +12,6 @@ import torch from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.trainer import TrainState -from internlm.monitor import send_alert_message from internlm.solver.optimizer import HybridZeroOptimizer from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger @@ -28,6 +25,8 @@ from internlm.utils.storage_manager import ( logger = get_logger(__file__) +quit_signal_handler = 
None + class CheckpointType(Enum): NORMAL_CHECKPOINT = 1 @@ -168,6 +167,44 @@ def save_optimizer_checkpoint(optim, state_path): llm_save(os.path.join(state_path, fp), states) +def save_checkpoint(folder, model, optimizer, scheduler, train_state: TrainState, model_config: Dict = None): + """ + Save checkpoint to the given folder path. + """ + + start = time.time() + torch.distributed.barrier() + folder = os.path.join(folder, str(train_state.step_count)) + logger.info( + f"Saving checkpoint to `{folder}` at batch count:{train_state.step_count} from rank:{gpc.get_global_rank()}..." + ) + + timer("save-model").start() + save_model_checkpoint(folder=folder, model=model) + timer("save-model").stop() + + timer("save-optimizer").start() + save_optimizer_checkpoint(optim=optimizer, state_path=folder) + timer("save-optimizer").stop() + + if gpc.is_rank_for_log(): + scheduler_states = scheduler.state_dict() + llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states) + + sampler_state = train_state.batch_sampler.state_dict() + llm_save(os.path.join(folder, "sampler.pt"), saved_obj=sampler_state) + llm_save(os.path.join(folder, "context.pt"), saved_obj=train_state.state_dict()) + + if model_config is not None: + llm_save(os.path.join(folder, "model_config.pt"), saved_obj=model_config) + + torch.distributed.barrier() + + if gpc.is_rank_for_log(): + timer.log(["save-model", "save-optimizer"], logger=logger) + logger.info(f"Step: {train_state.step_count}, rank 0 save ckpt use {time.time() - start:.3f} s") + + def load_optimizer_checkpoint(folder, optim): """Load the optimizer state from the local file system or remote object storage Service (OSS). @@ -267,12 +304,19 @@ def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train logger.info(f"reload load_scheduler:{lr_scheduler}") -class CheckpointManager: +class CheckpointSaveManager: """StorageManagerContext""" - def __init__(self, ckpt_config, model, model_config, feishu_address=None) -> None: + def __init__( + self, + ckpt_config, + model, + optimizer, + lr_scheduler, + model_config, + ) -> None: """ - CheckpointManager is used to decide when to store ckpt. If it is an asynchronous + CheckpointSaveManager is used to decide when to store ckpt. If it is an asynchronous upload mode, you must call wait_async_upload_finish at the end of the program to wait for the asynchronous ckpt upload to complete. 
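
The docstring above, together with the StorageManager changes introduced in patch 1, describes the asynchronous save path: checkpoint files are queued for background upload, wait() drains the queue, and only then is a "<step>.step" marker written into the checkpoint folder as the integrity flag that the resume logic scans for. What follows is a minimal sketch of that pattern, using a thread pool and illustrative names (AsyncCkptWriter) rather than the real StorageManager API; only the latest_save_folder / latest_save_step bookkeeping and the write-marker-after-drain ordering are taken from the patch.

from concurrent.futures import ThreadPoolExecutor
import json
import os


class AsyncCkptWriter:
    """Toy stand-in for the async save path: writes run in the background,
    and a '<step>.step' marker is produced only after every pending write
    for the latest checkpoint has finished."""

    def __init__(self):
        self._pool = ThreadPoolExecutor(max_workers=4)
        self._pending = []
        self.latest_save_folder = None
        self.latest_save_step = 0

    def save(self, folder, step, name, obj):
        os.makedirs(folder, exist_ok=True)
        self.latest_save_folder, self.latest_save_step = folder, step
        path = os.path.join(folder, name)
        self._pending.append(self._pool.submit(self._write, path, obj))

    @staticmethod
    def _write(path, obj):
        with open(path, "w", encoding="utf-8") as f:
            json.dump(obj, f)

    def wait(self):
        for fut in self._pending:
            fut.result()  # surface any upload exception here
        self._pending.clear()
        if self.latest_save_folder is not None:
            # Completion marker: written synchronously, after all async work.
            self._write(
                os.path.join(self.latest_save_folder, f"{self.latest_save_step}.step"),
                {"step": self.latest_save_step},
            )


if __name__ == "__main__":
    w = AsyncCkptWriter()
    w.save("./ckpt/10", 10, "context.json", {"lr": 1e-4})
    w.wait()  # analogous to calling wait_async_upload_finish() at the end of training
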
@@ -288,85 +332,26 @@ class CheckpointManager: self.save_ckpt_folder = ckpt_config.save_ckpt_folder self.snapshot_ckpt_folder = ckpt_config.snapshot_ckpt_folder self.oss_snapshot_freq: int = ckpt_config.oss_snapshot_freq - self.stop_file_path = ckpt_config.stop_file_path - self.load_model_only_folder = ckpt_config.load_model_only_folder - self.feishu_address = feishu_address self.storage_manager = get_storage_manager() self.snapshot_counter = 0 - self.load_optimizer = gpc.config.ckpt.load_optimizer self.model = model + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler self.model_config = model_config - if self.stop_file_path and gpc.get_global_rank() == 0: - dir_path = os.path.dirname(self.stop_file_path) - if dir_path != "" and not os.path.exists(dir_path): - os.makedirs(dir_path) - with open(self.stop_file_path, "w", encoding="utf-8") as f: - f.write("0") - - if not ckpt_config.load_given_ckpt: - latest_ckpt_path = self.query_lastest_ckpt() - self.load_ckpt_folder = latest_ckpt_path if latest_ckpt_path is not None else ckpt_config.load_ckpt_folder - else: - self.load_ckpt_folder = ckpt_config.load_ckpt_folder - - def quit_signal_handler(self, train_state) -> bool: - """ - Exit signal detection function, if we write the exit step in the 'QUIT_FILE_PATH' file, - all ranks will save ckpt and exit. - Negative integer step means save ckpt. - Positive integer step means save ckpt and quit. - - Args: - train_state (TrainState): - Returns: - bool: whether to quit. - """ - if self.stop_file_path is None: - logger.warning("no set stop_file_path") - return - - now_break, now_save_ckpt, save_type = False, False, CheckpointType.NORMAL_CHECKPOINT - with open(self.stop_file_path, "a+", encoding="utf-8") as f: - fcntl.flock(f, fcntl.LOCK_EX) - f.seek(0) - msg = f.read() - fcntl.flock(f, fcntl.LOCK_UN) - action_step = int(msg) - - if action_step < 0 and abs(action_step) == train_state.step_count: - now_save_ckpt = True - - if action_step > 0 and action_step == train_state.step_count: - now_break, now_save_ckpt = True, True - - if action_step != 0 and gpc.is_rank_for_log(): - msg = "Stop" if action_step > 0 else "Save" - action_step = abs(action_step) - if train_state.step_count <= action_step: - if self.feishu_address: - send_alert_message( - address=self.feishu_address, - message=f"training will {msg} at step_count {action_step}!\ -now step_count is {train_state.step_count}", - ) - - return now_break, now_save_ckpt, save_type - def try_save_checkpoint(self, train_state): if not self.enable_save_ckpt: - return False + return save_ckpts, save_type = False, CheckpointType.NORMAL_CHECKPOINT if self.oss_snapshot_freq > 1 and train_state.step_count % self.oss_snapshot_freq == 0: save_ckpts, save_type = True, CheckpointType.SNAPSHOT_CHECKPOINT if train_state.step_count % self.checkpoint_every == 0: save_ckpts, save_type = True, CheckpointType.NORMAL_CHECKPOINT - now_break, singal_save_ckpts, singal_save_type = self.quit_signal_handler(train_state) if save_ckpts is False: - save_ckpts = singal_save_ckpts - save_type = singal_save_type + if quit_signal_handler is not None: + save_ckpts, save_type = quit_signal_handler(train_state) if save_ckpts: # Wait for the previous round of asynchronous upload storage to complete. 
@@ -376,9 +361,9 @@ now step_count is {train_state.step_count}", self.snapshot_counter = (self.snapshot_counter + 1) % 2 save_ckpt_folder = os.path.join(self.snapshot_ckpt_folder, f"{self.snapshot_counter}") else: - save_ckpt_folder = os.path.join(self.save_ckpt_folder, str(train_state.step_count)) + save_ckpt_folder = self.save_ckpt_folder - self.save_checkpoint( + save_checkpoint( folder=save_ckpt_folder, model=self.model, optimizer=self.optimizer, @@ -387,199 +372,7 @@ now step_count is {train_state.step_count}", model_config=self.model_config, ) - return now_break - def wait_async_upload_finish(self): """wait for all checkpoint uploads to be completed""" self.storage_manager.wait() torch.distributed.barrier() - - def query_latest_snapshot_step_boto3(self): - """query_latest_snapshot_step_boto3 - Returns: - Tuple(str, int): path of latest ckpt and ckpt step, if not found, None will return. - """ - ckpt_list = self.storage_manager.get_fns(self.save_ckpt_folder) - if len(ckpt_list) == 0: - return None, None - - max_normal_step = 0 - ckpt_list = list(map(lambda a: int(a.strip("/")) if a.strip("/").isdigit() else 0, ckpt_list)) - ckpt_list.sort(reverse=True) - for ckpt in ckpt_list: - fns_list = self.storage_manager.get_fns(os.path.join(self.save_ckpt_folder, str(ckpt))) - for fn in fns_list: - if fn.endswith(".step"): - max_normal_step = ckpt - break - if max_normal_step != 0: - break - - max_normal_step = ckpt_list[0] - load_normal_ckpt_path = os.path.join(self.save_ckpt_folder, str(max_normal_step)) - - snapshot_path_0 = os.path.join(self.save_ckpt_folder, "snapshot", "0") - snapshot_path_1 = os.path.join(self.save_ckpt_folder, "snapshot", "1") - ckpt_list_1 = self.storage_manager.get_fns(snapshot_path_0) - ckpt_list_2 = self.storage_manager.get_fns(snapshot_path_1) - max_step_0, max_step_1 = 0, 0 - for ckpt in ckpt_list_1: - ckpt = ckpt.strip("/") - if ckpt.endswith(".step"): - max_step_0 = max(max_step_0, int(ckpt.split(".")[0])) - for ckpt in ckpt_list_2: - ckpt = ckpt.strip("/") - if ckpt.endswith(".step"): - max_step_1 = max(max_step_1, int(ckpt.split(".")[0])) - - snap_load_path = snapshot_path_0 if max_step_0 > max_step_1 else snapshot_path_1 - snap_step = max(max_step_0, max_step_1) - load_path = snap_load_path if snap_step > max_normal_step else load_normal_ckpt_path - load_step = max(snap_step, max_normal_step) - return load_path, load_step - - def query_latest_snapshot_step_local(self): - max_step, max_step_path = 0, None - for root, _, files in os.walk(self.save_ckpt_folder, followlinks=True): - for fn in files: - fn = fn.strip("/") - if fn.endswith(".step"): - # We assume that both normal ckpt and snapshot ckpt will store the '.step' file - # as an integrity flag. - step = int(fn.rsplit(".", maxsplit=1)[0]) - if max_step < step: - max_step = step - max_step_path = root - - return max_step_path, max_step - - def query_lastest_ckpt(self): - latest_checkpoint = None - # Training was automatically restarted by the process, forcing the latest snapshot to be read. 
- if self.save_ckpt_folder: - if self.save_ckpt_folder.startswith("boto3"): - latest_checkpoint, step = self.query_latest_snapshot_step_boto3() - elif self.save_ckpt_folder.startswith("local"): - latest_checkpoint, step = self.query_latest_snapshot_step_local() - else: - latest_checkpoint, step = None, 0 - - if latest_checkpoint is not None: - if gpc.is_rank_for_log(): - logger.info(f"Found latest ckpt : {latest_checkpoint}, step: {step}") - send_alert_message( - address=self.feishu_address, - message=f"Auto restart resume from ckpt-path: '{latest_checkpoint}', step : {step}", - ) - else: - if gpc.is_rank_for_log(): - send_alert_message( - address=self.feishu_address, - message=f"Can't find snapshot checkpoint, use default load-ckpt path: {latest_checkpoint}", - ) - - return latest_checkpoint - - def try_load_model(self, current_time=""): - model_load_path = None - - if self.load_ckpt_folder and self.load_model_only_folder: - raise ValueError( - "Error, try to use both load_ckpt_folder and load_model_only_folder paths, \ -if you only need to load model weights (for example starting an SFT task for the first time), \ -set load_model_only_folder path, if you need to resume training from ckpt, \ -set load_ckpt_folder or use default value \ -(if is the default value, internlm will try to load the latest ckpt from save_ckpt_folder)" - ) - - if self.load_ckpt_folder: - logger.info( - f"===========Resume training from `{self.load_ckpt_folder}` {current_time} on host:" - f"{socket.gethostname()}===========" - ) - model_load_path = self.load_ckpt_folder - elif self.load_model_only_folder: - logger.info( - f"===========SFT training from `{self.load_model_only_folder}` {current_time} on host:" - f"{socket.gethostname()}===========" - ) - model_load_path = self.load_model_only_folder - else: - logger.info( - f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()}," - f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)}," - f"dp={gpc.get_local_rank(ParallelMode.DATA)}===========" - ) - - # Loading model weights must be done before zero is initialized. - if model_load_path is not None: - load_model_checkpoint(folder=model_load_path, model=self.model) - - def try_resume_traing(self, lr_scheduler, optimizer, lr, train_state, train_dl): - """Attempt to restore the training state of the last ckpt. - - Args: - lr_scheduler (_LRScheduler): lr_scheduler object. - optimizer (Optimizer): optimizer object. - lr (float): learning rate. - train_state (dict): traing states. - train_dl (DataLoader): traning dataloader object - """ - if self.load_ckpt_folder is not None: - # load lr scheduler states. - load_scheduler(self.load_ckpt_folder, lr_scheduler, optimizer, lr, train_state) - # load training states. - load_context(self.load_ckpt_folder, train_dl, train_state) - # load dataloader sampler states. - load_sampler(self.load_ckpt_folder, train_dl.batch_sampler) - # load optimzier states. - if self.load_optimizer: - load_optimizer_checkpoint(self.load_ckpt_folder, optimizer) - - self.optimizer = optimizer - self.lr_scheduler = lr_scheduler - - def save_checkpoint(self, folder, model, optimizer, scheduler, train_state: TrainState, model_config: Dict = None): - """ - Save checkpoint to the given folder path. 
- """ - - start = time.time() - self.set_save_folder(folder, train_state.step_count) - torch.distributed.barrier() - if gpc.is_rank_for_log(): - logger.info(f"Saving checkpoint to `{folder}` at batch count:{train_state.step_count}...") - - timer("save-model").start() - save_model_checkpoint(folder=folder, model=model) - timer("save-model").stop() - - timer("save-optimizer").start() - save_optimizer_checkpoint(optim=optimizer, state_path=folder) - timer("save-optimizer").stop() - - if gpc.is_rank_for_log(): - scheduler_states = scheduler.state_dict() - llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states) - - sampler_state = train_state.batch_sampler.state_dict() - llm_save(os.path.join(folder, "sampler.pt"), saved_obj=sampler_state) - llm_save(os.path.join(folder, "context.pt"), saved_obj=train_state.state_dict()) - - if model_config is not None: - llm_save(os.path.join(folder, "model_config.pt"), saved_obj=model_config) - - torch.distributed.barrier() - - if gpc.is_rank_for_log(): - timer.log(["save-model", "save-optimizer"], logger=logger) - logger.info(f"Step: {train_state.step_count}, rank 0 save ckpt use {time.time() - start:.3f} s") - if self.storage_manager.async_mode is False: - llm_save( - os.path.join(folder, f"{train_state.step_count}.step"), - saved_obj=dict({"step": train_state.step_count}), - ) - - def set_save_folder(self, folder, step): - self.storage_manager.latest_save_folder = folder - self.storage_manager.latest_save_step = step diff --git a/internlm/utils/storage_manager.py b/internlm/utils/storage_manager.py index 6984ddc..481bd28 100644 --- a/internlm/utils/storage_manager.py +++ b/internlm/utils/storage_manager.py @@ -234,13 +234,13 @@ class Boto3Client(StorageClient): """ paginator = handler.client.get_paginator("list_objects_v2") pages = paginator.paginate(Bucket=bucket_name, Prefix=fp) + folder_name_list = [] for page in pages: - if "Contents" in page: - for obj in page["Contents"]: - pth: str = obj["Key"] - folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0]) - return list(set(folder_name_list)) + for obj in page["Contents"]: + fp: str = obj["Key"] + folder_name_list.append(fp.rsplit("/", maxsplit=1)[1]) + return folder_name_list @staticmethod def async_upload_fileobj(handler, bucket_name: str, fp: str, local_nvme_path: str): @@ -391,11 +391,6 @@ class StorageManager(metaclass=SingletonMeta): self.tmp_local_folder = tmp_local_folde self.async_mode = async_mode self.has_warning = False - self._async_loop = None - self._thread_pool = None - self.latest_save_folder = None - self.latest_save_step = 0 - self.async_task_peeding = False if enable_save and self.async_mode: self._async_loop = asyncio.new_event_loop() @@ -490,7 +485,6 @@ class StorageManager(metaclass=SingletonMeta): torch.save(saved_obj, f, pickle_protocol=pickle.HIGHEST_PROTOCOL) self.async_executor(meta.async_upload_fn, *unpack_meta(meta)) os.chmod(tmp_step_file, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO) - self.async_task_peeding = True else: meta.client.sync_upload_fileobj(*unpack_meta(meta), *args, saved_obj=saved_obj, **kwargs) self.upload_count += 1 @@ -566,9 +560,6 @@ class StorageManager(metaclass=SingletonMeta): if not self.async_mode: return - if not self.async_task_peeding: - return - if self._async_loop: self._async_loop.run_until_complete(self._sync_tasks()) @@ -587,16 +578,10 @@ class StorageManager(metaclass=SingletonMeta): self._del_tmp_folder() self._exception_list.clear() self._to_be_del_files.clear() - 
self.async_task_peeding = False if gpc.is_rank_for_log(): + logger.info("all async uploads succeeded!") self.upload_count += 1 - if self.async_mode: - self.save( - os.path.join(self.latest_save_folder, f"{self.latest_save_step}.step"), - saved_obj=dict({"step": self.latest_save_step}), - async_upload=False, - ) storage_manager: StorageManager = None diff --git a/train.py b/train.py index f1e6fd0..7e87d05 100644 --- a/train.py +++ b/train.py @@ -45,7 +45,14 @@ from internlm.utils.common import ( from internlm.utils.evaluation import evaluate_on_val_dls, switch_sequence_parallel_mode from internlm.utils.logger import get_logger, initialize_uniscale_logger from internlm.utils.megatron_timers import megatron_timer as timer -from internlm.utils.model_checkpoint import CheckpointManager +from internlm.utils.model_checkpoint import ( + CheckpointSaveManager, + load_context, + load_model_checkpoint, + load_optimizer_checkpoint, + load_sampler, + load_scheduler, +) from internlm.utils.parallel import ( get_parallel_log_file_name, is_no_pp_or_last_stage, @@ -421,9 +428,13 @@ def main(args): skip_batches = gpc.config.data.skip_batches total_steps = gpc.config.data.total_steps valid_every = gpc.config.data.valid_every + load_optimizer = gpc.config.ckpt.load_optimizer label_smoothing = gpc.config.loss.label_smoothing lr = gpc.config.adam.lr + load_model_only_folder = gpc.config.ckpt.get("load_model_only_folder", None) + load_resume_ckpt_folder = gpc.config.ckpt.get("load_ckpt_folder", None) + get_tflops_func = partial( get_megatron_flops, checkpoint=gpc.config.model.checkpoint, @@ -459,19 +470,32 @@ def main(args): enable_tb=gpc.config.enable_tb, ) + model_load_path = None + if load_resume_ckpt_folder is not None: + logger.info( + f"===========Resume training from `{load_resume_ckpt_folder}` {current_time} on host:" + f"{socket.gethostname()}===========" + ) + model_load_path = load_resume_ckpt_folder + elif load_model_only_folder is not None: + logger.info( + f"===========SFT training from `{load_model_only_folder}` {current_time} on host:" + f"{socket.gethostname()}===========" + ) + model_load_path = load_model_only_folder + else: + logger.info( + f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()}," + f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)}," + f"dp={gpc.get_local_rank(ParallelMode.DATA)}===========" + ) + # initialize and resume train state train_state = TrainState(gpc.config) # initialize model model = initialize_model() - ckpt_manager = CheckpointManager( - ckpt_config=gpc.config.ckpt, - model=model, - model_config=gpc.config.model, - feishu_address=gpc.config.alert_address, - ) - # initialize loss function criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing) @@ -481,12 +505,30 @@ def main(args): train_state.init_batch_sampler(train_dl) # Loading model weights must be done before zero is initialized. - ckpt_manager.try_load_model(current_time) + if model_load_path is not None: + load_model_checkpoint(folder=model_load_path, model=model) optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model) # Loading other persistent training states. - ckpt_manager.try_resume_traing(lr_scheduler, optimizer, lr, train_state, train_dl) + if load_resume_ckpt_folder is not None: + # load lr scheduler states. + load_scheduler(load_resume_ckpt_folder, lr_scheduler, optimizer, lr, train_state) + # load training states. 
+ load_context(load_resume_ckpt_folder, train_dl, train_state) + # load dataloader sampler states. + load_sampler(load_resume_ckpt_folder, train_dl.batch_sampler) + # load optimzier states. + if load_optimizer: + load_optimizer_checkpoint(load_resume_ckpt_folder, optimizer) + + ckpt_save_manager = CheckpointSaveManager( + ckpt_config=gpc.config.ckpt, + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + model_config=gpc.config.model, + ) # initialize metric for calculating accuracy and perplexity metric = AccPerplex( @@ -607,11 +649,9 @@ def main(args): # checkpoint the training states in specific steps, which is determined by the args "checkpoint_every" # # save batch sampler that tracks the true consumed samples - now_break = ckpt_manager.try_save_checkpoint(train_state) - if now_break: - break + ckpt_save_manager.try_save_checkpoint(train_state) - ckpt_manager.wait_async_upload_finish() + ckpt_save_manager.wait_async_upload_finish() if __name__ == "__main__": @@ -627,10 +667,8 @@ if __name__ == "__main__": try: main(args) except Exception: - format_trace = "" - for line in traceback.format_exc().split("\n")[-10:]: - format_trace += "\n" + line logger.error( - f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}, trace:{format_trace}", + f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}", + exc_info=traceback.format_exc(), ) mm.monitor_exception(alert_address=gpc.config.alert_address, excp_info=traceback.format_exc()) From 4e8bd39d8facc3b2c61ab0e53cc5320da79ec41e Mon Sep 17 00:00:00 2001 From: cx <759046501@qq.com> Date: Fri, 11 Aug 2023 17:46:07 +0800 Subject: [PATCH 3/3] refactor(solver/optimizer): improve optimizer memory (#193) * refactor(solver/optimizer): improve optimizer memory * feat(data): remove useless dataset type ids map --- internlm/data/utils.py | 2 +- internlm/initialize/launch.py | 2 +- internlm/solver/optimizer/hybrid_zero_optim.py | 12 ++++++------ 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/internlm/data/utils.py b/internlm/data/utils.py index a86984a..3eee9d9 100644 --- a/internlm/data/utils.py +++ b/internlm/data/utils.py @@ -5,7 +5,7 @@ import torch from internlm.core.context import global_context as gpc -DATASET_TYPE_IDS_MAP = {"en": 0, "cn": 1, "code": 2, "ja": 3, "ar": 4, "kaoshi": 5} +DATASET_TYPE_IDS_MAP = {"en": 0, "cn": 1} def get_dataset_type_id(path): diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 33b5d15..d3ea708 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -39,7 +39,7 @@ def get_default_parser(): parser.add_argument("--local_rank", type=int, help="local rank on the node") parser.add_argument("--backend", type=str, default="nccl", help="backend for distributed communication") parser.add_argument("--seed", type=int, default=1024) - parser.add_argument("--profiling", default=True, action="store_true", help="enable/diable profiling.") + parser.add_argument("--profiling", default=False, action="store_true", help="enable/disable profiling.") return parser diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 618b772..9d42a98 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -9,6 +9,7 @@ from torch.optim import Optimizer from internlm.core.context import Config, ParallelMode from internlm.core.context import global_context as gpc +from internlm.monitor import send_alert_message from 
internlm.solver.optimizer.store import ( BucketStore, GradientStore, @@ -28,7 +29,6 @@ from internlm.solver.optimizer.utils import ( from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger from internlm.utils.megatron_timers import megatron_timer as timer -from internlm.monitor import send_alert_message from .utils import compute_norm @@ -556,14 +556,16 @@ class HybridZeroOptimizer(BaseOptimizer): # The following operations are performed only on the rank to which parameters are assigned. if not self.param_group_has_params[group_id]: continue - gradients = self._grad_store.get_averaged_gradients_by_group(group_id) # create flat gradient for the flat fp32 params - fp16_avg_grads = gradients - flat_fp16_avg_grads = flatten(fp16_avg_grads) + gradients = self._grad_store.get_averaged_gradients_by_group(group_id) + flat_fp16_avg_grads = flatten(gradients) + self._grad_store.reset_average_gradients_by_group(group_id) + del gradients # release cuda memory dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype) + del flat_fp16_avg_grads # release cuda memory param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape assert ( @@ -573,8 +575,6 @@ class HybridZeroOptimizer(BaseOptimizer): single_grad_partition_groups.append(flat_fp32_avg_grads) device = self._fp32_flat_param_groups_of_current_rank[group_id].device self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device) - self._grad_store._averaged_gradients[group_id] = [] - self._grad_store._averaged_gradients[group_id] = [] # unscale and clip grads # get the global norm
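
The hybrid_zero_optim.py change in patch 3 shown above lowers peak GPU memory by releasing each intermediate as soon as the next one exists: the per-parameter fp16 averaged gradients are flattened, the gradient list is reset and dropped, the flat tensor is cast to fp32, and the fp16 flat copy is dropped before the fp32 gradient is attached to the flat parameter group. A schematic sketch of that pattern with dummy CPU tensors (not the optimizer's real gradient store):

import torch


def build_fp32_flat_grad(fp16_grads):
    """Flatten a list of fp16 gradients into one fp32 tensor while keeping
    at most one temporary alive at a time, mirroring patch 3's early deletes."""
    flat_fp16 = torch.cat([g.reshape(-1) for g in fp16_grads])
    fp16_grads.clear()  # analogous to reset_average_gradients_by_group(group_id)
    del fp16_grads      # drop the per-parameter fp16 copies
    flat_fp32 = flat_fp16.to(torch.float32)
    del flat_fp16       # drop the fp16 flat copy before attaching it as .grad
    return flat_fp32


if __name__ == "__main__":
    grads = [torch.randn(4, 4, dtype=torch.float16), torch.randn(8, dtype=torch.float16)]
    flat = build_fp32_flat_grad(grads)
    print(flat.dtype, flat.numel())  # torch.float32 24

On a GPU the freed tensors only return to PyTorch's caching allocator, but that is precisely what reduces the peak allocation during the fp16-to-fp32 cast.
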