mirror of https://github.com/InternLM/InternLM
compatible with old ckpt (#418)
parent eeef07934a
commit aa5e34d815
@@ -556,6 +556,18 @@ def load_optimizer_checkpoint(folder, optim):
                 f"Please check whether loading ckpts are saved with the HybridZeroOptimizer."
             )
 
+        # compatible with old ckpts that only have one param group; align them with the current param groups
+        if len(states["base_optim_states"]["param_groups"]) == 1:
+            for group in optim.param_groups:
+                # a newly added group has no params yet, so append it to the saved groups as a placeholder
+                if len(group["params"]) == 0:
+                    states["base_optim_states"]["param_groups"].append(group)
+                # for the original group, create the attributes added in recent updates
+                else:
+                    saved_group = states["base_optim_states"]["param_groups"][0]
+                    saved_group["dp_mode"] = group["dp_mode"]
+                    saved_group["dtype"] = group["dtype"]
+
     optim.load_state_dict(states)
     del states
     torch.cuda.empty_cache()
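For context, a minimal self-contained sketch (not InternLM code) of why the padding above is needed: a torch optimizer refuses to load a state dict whose number of param groups differs from its own, so the single saved group of an old checkpoint must be aligned with the new two-group layout before load_state_dict() succeeds. A plain torch.optim.AdamW stands in for the wrapped base optimizer; the name w_fp32 and the dp_mode/dtype group options are illustrative assumptions.

import torch

w_fp32 = torch.nn.Parameter(torch.randn(4, 4))
optim = torch.optim.AdamW([{"params": [w_fp32], "dp_mode": "zero", "dtype": torch.float32}])

# pretend this came from an old checkpoint that only had a single param group
old_states = optim.state_dict()

# the current optimizer now carries an extra, for-now empty param group
optim.add_param_group({"params": [], "dp_mode": "data", "dtype": torch.float32})

try:
    optim.load_state_dict(old_states)  # fails: 1 saved group vs 2 current groups
except ValueError as exc:
    print("direct load fails:", exc)

# align the old checkpoint with the current groups, mirroring the patch above
for group in optim.param_groups:
    if len(group["params"]) == 0:
        old_states["param_groups"].append(group)  # placeholder entry for the empty group
    else:
        old_states["param_groups"][0]["dp_mode"] = group["dp_mode"]
        old_states["param_groups"][0]["dtype"] = group["dtype"]

optim.load_state_dict(old_states)  # now succeeds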
@@ -598,6 +610,10 @@ def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, train_state: TrainSt
     lr_scheduler.load_state_dict(scheduler_states)
     lr_scheduler.last_epoch = train_state.step_count + 1
 
+    # compatible with old ckpts that only stored one base lr: broadcast it to every param group
+    if len(base_lrs) == 1:
+        base_lrs = base_lrs * len(optimizer.param_groups)
+
     ratios = [learning_rate / lr for lr in base_lrs]
     for idx, param_group in enumerate(optimizer.param_groups):
         param_group["lr"] = param_group["lr"] * ratios[idx]
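A worked numeric sketch of the broadcast and rescaling (all values assumed for illustration, not taken from any checkpoint): suppose an old checkpoint stored a single base lr of 4e-4, the resumed run is configured with learning_rate = 2e-4, and both current param groups sit at 1e-4 after warmup.

learning_rate = 2e-4            # lr from the current config
base_lrs = [4e-4]               # the single base lr stored by an old checkpoint
param_group_lrs = [1e-4, 1e-4]  # current lr of each of the two param groups

# broadcast the single saved base lr to every param group, as in the patch above
if len(base_lrs) == 1:
    base_lrs = base_lrs * len(param_group_lrs)

# rescale each group's lr so the schedule continues as if it had been built
# with the newly configured learning_rate from the start
ratios = [learning_rate / lr for lr in base_lrs]
param_group_lrs = [lr * r for lr, r in zip(param_group_lrs, ratios)]
print(param_group_lrs)          # [5e-05, 5e-05]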