From 5539f9db5055307b6b7d7ce88b07d90d8136692b Mon Sep 17 00:00:00 2001
From: Yang Gao
Date: Fri, 29 Dec 2023 20:22:39 +0800
Subject: [PATCH] fix when resuming lr_scheduler without loading optimizer
 (#565)

---
 internlm/solver/optimizer/fsdp_optimizer.py    | 8 ++++++++
 internlm/solver/optimizer/hybrid_zero_optim.py | 8 ++++++++
 internlm/utils/model_checkpoint.py             | 8 +++++++-
 3 files changed, 23 insertions(+), 1 deletion(-)

diff --git a/internlm/solver/optimizer/fsdp_optimizer.py b/internlm/solver/optimizer/fsdp_optimizer.py
index ab15917..b5bf457 100644
--- a/internlm/solver/optimizer/fsdp_optimizer.py
+++ b/internlm/solver/optimizer/fsdp_optimizer.py
@@ -206,6 +206,14 @@ class FSDPadaptOptimizer(BaseOptimizer):
         grad_scaler = states["grad_scaler"]
         self.grad_scaler.load_state_dict(grad_scaler)
         optim_states = states["base_optim_states"]
+
+        if gpc.config.get("only_load_lr", False):
+            if gpc.is_rank_for_log():
+                logger.info("Only load lr in param_groups, skip loading weights in optimizer...")
+            for pg1, pg2 in zip(self.optim.param_groups, optim_states["param_groups"]):
+                pg1["lr"] = pg2["lr"]
+            return
+
         self.optim.load_state_dict(optim_states)
 
         # load fp32 optimizer weight
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index c4b87d7..a1e848e 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -943,6 +943,14 @@ class HybridZeroOptimizer(BaseOptimizer):
         grad_scaler = states["grad_scaler"]
         self.grad_scaler.load_state_dict(grad_scaler)
         optim_states = states["base_optim_states"]
+
+        if gpc.config.get("only_load_lr", False):
+            if gpc.is_rank_for_log():
+                logger.info("Only load lr in param_groups, skip loading weights in optimizer...")
+            for pg1, pg2 in zip(self.optim.param_groups, optim_states["param_groups"]):
+                pg1["lr"] = pg2["lr"]
+            return
+
         self.optim.load_state_dict(optim_states)
 
         # load fp32 model weight.
diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py
index 9b64e3b..b9326de 100644
--- a/internlm/utils/model_checkpoint.py
+++ b/internlm/utils/model_checkpoint.py
@@ -231,7 +231,7 @@ def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState):
     # load training states.
     load_context(load_ckpt_folder, train_state)
 
-    # load optimzier states.
+    # load optimizer states.
     if load_content.need_load(CheckpointLoadContent.OPIMIZER):
         load_optimizer_checkpoint(load_ckpt_folder, ckpt_mm.optimizer)
         load_content_str += f"{CheckpointLoadContent.OPIMIZER}, "
@@ -248,6 +248,12 @@ def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState):
             if gpc.is_rank_for_log():
                 logger.warning("CheckpointManager has no 'lr_scheduler', skip reload lr_scheduler checkpoint!")
 
+        if not load_content.need_load(CheckpointLoadContent.OPIMIZER):
+            if ckpt_mm.lr_scheduler and train_state:
+                gpc.config.only_load_lr = True
+                load_optimizer_checkpoint(load_ckpt_folder, ckpt_mm.optimizer)
+                gpc.config.only_load_lr = False
+
     # load dataloader sampler states.
     if load_content.need_load(CheckpointLoadContent.SAMPLER):
         if hasattr(train_state, "batch_sampler") and not isinstance(
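
Note (not part of the patch): a minimal sketch of what the only_load_lr
branch effectively does when resuming. Only the learning rate is copied
from the checkpointed param_groups into a freshly built optimizer;
momentum buffers, step counts, and other optimizer state are deliberately
skipped, so the lr_scheduler can resume from the checkpointed lr while
the optimizer itself starts fresh. The names `live_optim` and
`ckpt_states` are stand-ins for illustration, not identifiers from the
codebase.

    import torch

    # Freshly built optimizer (stand-in for self.optim after model init).
    param = torch.nn.Parameter(torch.zeros(2))
    live_optim = torch.optim.SGD([param], lr=0.1)

    # Minimal stand-in for states["base_optim_states"] from the checkpoint.
    ckpt_states = {"param_groups": [{"lr": 5e-4}]}

    # Copy only "lr"; everything else is left untouched, mirroring the
    # early return before self.optim.load_state_dict(optim_states).
    for pg1, pg2 in zip(live_optim.param_groups, ckpt_states["param_groups"]):
        pg1["lr"] = pg2["lr"]

    assert live_optim.param_groups[0]["lr"] == 5e-4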