mirror of https://github.com/InternLM/InternLM
compatible with old ckpt (#418)
parent
eeef07934a
commit
aa5e34d815
|
@ -556,6 +556,18 @@ def load_optimizer_checkpoint(folder, optim):
|
||||||
f"Please check whether loading ckpts are saved with the HybridZeroOptimizer."
|
f"Please check whether loading ckpts are saved with the HybridZeroOptimizer."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# compatible with old code that only have one param group, need to align with both parameter groups
|
||||||
|
if len(states["base_optim_states"]["param_groups"]) == 1:
|
||||||
|
for group in optim.param_groups:
|
||||||
|
# for new added empty group, since it has no params, just create it fakely
|
||||||
|
if len(group["params"]) == 0:
|
||||||
|
states["base_optim_states"]["param_groups"].append(group)
|
||||||
|
# for origin group, create new added attributes in recent updates
|
||||||
|
else:
|
||||||
|
saved_group = states["base_optim_states"]["param_groups"][0]
|
||||||
|
saved_group["dp_mode"] = group["dp_mode"]
|
||||||
|
saved_group["dtype"] = group["dtype"]
|
||||||
|
|
||||||
optim.load_state_dict(states)
|
optim.load_state_dict(states)
|
||||||
del states
|
del states
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
@ -598,6 +610,10 @@ def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, train_state: TrainSt
|
||||||
lr_scheduler.load_state_dict(scheduler_states)
|
lr_scheduler.load_state_dict(scheduler_states)
|
||||||
lr_scheduler.last_epoch = train_state.step_count + 1
|
lr_scheduler.last_epoch = train_state.step_count + 1
|
||||||
|
|
||||||
|
# compatible with old code that only have one param group
|
||||||
|
if len(base_lrs) == 1:
|
||||||
|
base_lrs = base_lrs * len(optimizer.param_groups)
|
||||||
|
|
||||||
ratios = [learning_rate / lr for lr in base_lrs]
|
ratios = [learning_rate / lr for lr in base_lrs]
|
||||||
for idx, param_group in enumerate(optimizer.param_groups):
|
for idx, param_group in enumerate(optimizer.param_groups):
|
||||||
param_group["lr"] = param_group["lr"] * ratios[idx]
|
param_group["lr"] = param_group["lr"] * ratios[idx]
|
||||||
|
|
Loading…
Reference in New Issue