fix when resuming lr_scheduler without loading optimizer (#565)

pull/570/head v0.2.1dev20240102
Yang Gao 2023-12-29 20:22:39 +08:00 committed by GitHub
parent 220953d7e5
commit 5539f9db50
3 changed files with 23 additions and 1 deletion
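For context on what the diff implements: PyTorch-style lr schedulers read and write the learning rate through optimizer.param_groups, so resuming only the lr_scheduler against a freshly constructed optimizer would leave a stale lr. The two optimizer hunks below therefore copy just the lr values from the optimizer checkpoint when the new only_load_lr flag is set. A minimal standalone sketch of that pattern, using a plain torch.optim.SGD instead of InternLM's wrapped optimizers (the model, optimizer, and flag names here are illustrative, not repository code):

    import torch

    # Fresh model/optimizer for the resumed run (illustrative, not InternLM code).
    model = torch.nn.Linear(4, 4)
    optim = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

    # Pretend this dict was read back from the checkpoint folder.
    saved_states = {"base_optim_states": optim.state_dict()}
    saved_states["base_optim_states"]["param_groups"][0]["lr"] = 5e-4  # lr at save time

    only_load_lr = True  # plays the role of gpc.config.only_load_lr in the diff

    if only_load_lr:
        # Restore only the learning rate so the lr_scheduler can resume,
        # leaving momentum buffers and other optimizer state freshly initialized.
        for pg1, pg2 in zip(optim.param_groups, saved_states["base_optim_states"]["param_groups"]):
            pg1["lr"] = pg2["lr"]
    else:
        # Full restore of the optimizer state.
        optim.load_state_dict(saved_states["base_optim_states"])

    print(optim.param_groups[0]["lr"])  # -> 0.0005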


@@ -206,6 +206,14 @@ class FSDPadaptOptimizer(BaseOptimizer):
         grad_scaler = states["grad_scaler"]
         self.grad_scaler.load_state_dict(grad_scaler)
         optim_states = states["base_optim_states"]
+        if gpc.config.get("only_load_lr", False):
+            if gpc.is_rank_for_log():
+                logger.info("Only load lr in param_groups, skip loading weights in optimizer...")
+            for pg1, pg2 in zip(self.optim.param_groups, optim_states["param_groups"]):
+                pg1["lr"] = pg2["lr"]
+            return
         self.optim.load_state_dict(optim_states)
         # load fp32 optimizer weight

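The same guard is added to FSDPadaptOptimizer above and to HybridZeroOptimizer in the next hunk. Purely as an illustration of one way the duplication could be factored out (load_lr_only is a hypothetical helper, not repository code):

    def load_lr_only(optimizer, saved_param_groups, log_fn=None):
        """Hypothetical shared helper: copy only the learning rate from saved
        param_groups into an already-constructed optimizer, skipping all other
        optimizer state."""
        if log_fn is not None:
            log_fn("Only load lr in param_groups, skip loading weights in optimizer...")
        for pg, saved_pg in zip(optimizer.param_groups, saved_param_groups):
            pg["lr"] = saved_pg["lr"]

With such a helper, the guard in both load_state_dict methods would shrink to a call like load_lr_only(self.optim, optim_states["param_groups"]) followed by return.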

@@ -943,6 +943,14 @@ class HybridZeroOptimizer(BaseOptimizer):
         grad_scaler = states["grad_scaler"]
         self.grad_scaler.load_state_dict(grad_scaler)
         optim_states = states["base_optim_states"]
+        if gpc.config.get("only_load_lr", False):
+            if gpc.is_rank_for_log():
+                logger.info("Only load lr in param_groups, skip loading weights in optimizer...")
+            for pg1, pg2 in zip(self.optim.param_groups, optim_states["param_groups"]):
+                pg1["lr"] = pg2["lr"]
+            return
         self.optim.load_state_dict(optim_states)
         # load fp32 model weight.


@@ -231,7 +231,7 @@ def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState):
         # load training states.
         load_context(load_ckpt_folder, train_state)
-        # load optimzier states.
+        # load optimizer states.
         if load_content.need_load(CheckpointLoadContent.OPIMIZER):
             load_optimizer_checkpoint(load_ckpt_folder, ckpt_mm.optimizer)
             load_content_str += f"{CheckpointLoadContent.OPIMIZER}, "
@@ -248,6 +248,12 @@ def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState):
             if gpc.is_rank_for_log():
                 logger.warning("CheckpointManager has no 'lr_scheduler', skip reload lr_scheduler checkpoint!")
+        if not load_content.need_load(CheckpointLoadContent.OPIMIZER):
+            if ckpt_mm.lr_scheduler and train_state:
+                gpc.config.only_load_lr = True
+                load_optimizer_checkpoint(load_ckpt_folder, ckpt_mm.optimizer)
+                gpc.config.only_load_lr = False
         # load dataloader sampler states.
         if load_content.need_load(CheckpointLoadContent.SAMPLER):
             if hasattr(train_state, "batch_sampler") and not isinstance(
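The hunk above toggles gpc.config.only_load_lr around load_optimizer_checkpoint so that the optimizers' load_state_dict takes the lr-only path. Below is a standalone sketch of that flag-toggling pattern, with a plain dict standing in for gpc.config and a stub in place of load_optimizer_checkpoint (both stand-ins are assumptions, not repository code); wrapping the toggle in try/finally is one way to guarantee the flag is cleared even if loading fails:

    config = {}  # stand-in for gpc.config

    def load_optimizer_checkpoint_stub(folder, optimizer):
        # Stand-in for load_optimizer_checkpoint: it would read the optimizer
        # shard from `folder` and call optimizer.load_state_dict(...), which
        # honors the only_load_lr flag as shown in the first two hunks.
        print(f"loading optimizer states from {folder}, only_load_lr={config.get('only_load_lr', False)}")

    def resume_lr_without_optimizer(folder, optimizer):
        config["only_load_lr"] = True
        try:
            load_optimizer_checkpoint_stub(folder, optimizer)
        finally:
            # Clear the flag even if loading raises, so later full restores
            # are not silently downgraded to lr-only.
            config["only_load_lr"] = False

    resume_lr_without_optimizer("path/to/ckpt", optimizer=None)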