mirror of https://github.com/InternLM/InternLM
parent 220953d7e5
commit 5539f9db50
@@ -206,6 +206,14 @@ class FSDPadaptOptimizer(BaseOptimizer):
         grad_scaler = states["grad_scaler"]
         self.grad_scaler.load_state_dict(grad_scaler)
         optim_states = states["base_optim_states"]
+
+        if gpc.config.get("only_load_lr", False):
+            if gpc.is_rank_for_log():
+                logger.info("Only load lr in param_groups, skip loading weights in optimizer...")
+            for pg1, pg2 in zip(self.optim.param_groups, optim_states["param_groups"]):
+                pg1["lr"] = pg2["lr"]
+            return
+
         self.optim.load_state_dict(optim_states)
 
         # load fp32 optimizer weight
@@ -943,6 +943,14 @@ class HybridZeroOptimizer(BaseOptimizer):
         grad_scaler = states["grad_scaler"]
         self.grad_scaler.load_state_dict(grad_scaler)
         optim_states = states["base_optim_states"]
+
+        if gpc.config.get("only_load_lr", False):
+            if gpc.is_rank_for_log():
+                logger.info("Only load lr in param_groups, skip loading weights in optimizer...")
+            for pg1, pg2 in zip(self.optim.param_groups, optim_states["param_groups"]):
+                pg1["lr"] = pg2["lr"]
+            return
+
         self.optim.load_state_dict(optim_states)
 
         # load fp32 model weight.
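Both hunks add the same fast path, so one illustration covers them: below is a minimal, runnable sketch (not the repository's code) of what the new only_load_lr branch does inside load_state_dict. It copies only the learning rate from each checkpointed param_group and returns before the full optimizer state would be restored. The toy SGD optimizer and the ckpt_states dict are hypothetical stand-ins for self.optim and the states loaded from disk.

import torch

param = torch.nn.Parameter(torch.zeros(4))
optim = torch.optim.SGD([param], lr=0.1)

# Hypothetical checkpoint payload, shaped like states["base_optim_states"]:
# by the time it was saved, the scheduler had decayed the lr to 0.01.
ckpt_states = {"base_optim_states": {"state": {}, "param_groups": [{"lr": 0.01}]}}

only_load_lr = True  # stands in for gpc.config.get("only_load_lr", False)

optim_states = ckpt_states["base_optim_states"]
if only_load_lr:
    # lr-only path: copy the lr across matching param_groups and stop,
    # leaving the optimizer's weights and per-param state untouched.
    for pg1, pg2 in zip(optim.param_groups, optim_states["param_groups"]):
        pg1["lr"] = pg2["lr"]
else:
    optim.load_state_dict(optim_states)

print(optim.param_groups[0]["lr"])  # 0.01: lr restored, everything else kept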
@@ -231,7 +231,7 @@ def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState):
     # load training states.
     load_context(load_ckpt_folder, train_state)
 
-    # load optimzier states.
+    # load optimizer states.
     if load_content.need_load(CheckpointLoadContent.OPIMIZER):
         load_optimizer_checkpoint(load_ckpt_folder, ckpt_mm.optimizer)
         load_content_str += f"{CheckpointLoadContent.OPIMIZER}, "
@@ -248,6 +248,12 @@ def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState):
             if gpc.is_rank_for_log():
                 logger.warning("CheckpointManager has no 'lr_scheduler', skip reload lr_scheduler checkpoint!")
 
+    if not load_content.need_load(CheckpointLoadContent.OPIMIZER):
+        if ckpt_mm.lr_scheduler and train_state:
+            gpc.config.only_load_lr = True
+            load_optimizer_checkpoint(load_ckpt_folder, ckpt_mm.optimizer)
+            gpc.config.only_load_lr = False
+
     # load dataloader sampler states.
     if load_content.need_load(CheckpointLoadContent.SAMPLER):
         if hasattr(train_state, "batch_sampler") and not isinstance(
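Reading this last hunk together with the first two: when the optimizer checkpoint is not requested but an lr_scheduler and train_state exist, try_load_internlm_ckpt briefly raises gpc.config.only_load_lr so that load_optimizer_checkpoint restores only the learning rates. A hedged sketch of that set-call-reset pattern, with Config and load_optimizer_checkpoint_stub as hypothetical stand-ins for gpc.config and the real loader:

class Config(dict):
    # Dict with attribute access, loosely mimicking gpc.config.
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__

config = Config()

def load_optimizer_checkpoint_stub(folder: str) -> None:
    # The real call chain ends in the optimizers' load_state_dict above,
    # which consults the flag and takes the lr-only fast path when set.
    print(f"loading {folder!r} with only_load_lr={config.get('only_load_lr', False)}")

config.only_load_lr = True
load_optimizer_checkpoint_stub("ckpt_folder")
config.only_load_lr = False  # reset so later loads restore full state

The diff resets the flag unconditionally after the call; wrapping the call in try/finally would keep the flag from leaking if loading raised.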