compatible with old ckpt (#418)

Wenwen Qu 2023-10-17 17:25:36 +08:00 committed by GitHub
parent eeef07934a
commit aa5e34d815
1 changed file with 16 additions and 0 deletions


@@ -556,6 +556,18 @@ def load_optimizer_checkpoint(folder, optim):
             f"Please check whether loading ckpts are saved with the HybridZeroOptimizer."
         )
 
+    # compatible with old ckpts that saved only one param group; align with the current param groups
+    if len(states["base_optim_states"]["param_groups"]) == 1:
+        for group in optim.param_groups:
+            # a newly added group is empty (it has no params), so fake its saved state from the live group
+            if len(group["params"]) == 0:
+                states["base_optim_states"]["param_groups"].append(group)
+            # for the original group, backfill the attributes introduced in recent updates
+            else:
+                saved_group = states["base_optim_states"]["param_groups"][0]
+                saved_group["dp_mode"] = group["dp_mode"]
+                saved_group["dtype"] = group["dtype"]
+
     optim.load_state_dict(states)
     del states
     torch.cuda.empty_cache()
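
The padding above is needed because torch.optim.Optimizer.load_state_dict rejects a state dict whose number of param groups differs from the live optimizer's. Below is a minimal, self-contained sketch of the same alignment, using a plain torch.optim.AdamW as a stand-in for the HybridZeroOptimizer's base optimizer; the "dp_mode"/"dtype" values here are illustrative assumptions, not the project's defaults.

import torch

param = torch.nn.Parameter(torch.zeros(2))

# Old checkpoint: saved back when only one param group existed.
old_optim = torch.optim.AdamW([param], lr=1e-3)
states = {"base_optim_states": old_optim.state_dict()}

# Current optimizer: the original group plus a newly added, still-empty group.
optim = torch.optim.AdamW([param], lr=1e-3)
optim.param_groups[0].update({"dp_mode": "zero1", "dtype": torch.float32})  # illustrative values
optim.add_param_group({"params": [], "dp_mode": "zero1", "dtype": torch.float32})

# The compatibility shim from the diff above.
if len(states["base_optim_states"]["param_groups"]) == 1:
    for group in optim.param_groups:
        if len(group["params"]) == 0:
            states["base_optim_states"]["param_groups"].append(group)
        else:
            saved_group = states["base_optim_states"]["param_groups"][0]
            saved_group["dp_mode"] = group["dp_mode"]
            saved_group["dtype"] = group["dtype"]

# A plain optimizer takes the inner dict; the HybridZeroOptimizer wraps it
# under "base_optim_states" and is handed `states` directly, as in the diff.
optim.load_state_dict(states["base_optim_states"])  # succeeds: 2 groups on both sides

Appending the live empty group verbatim is safe because load_state_dict only requires the per-group param counts to match, and they are zero on both sides for the new group.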
@@ -598,6 +610,10 @@ def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, train_state: TrainState
     lr_scheduler.load_state_dict(scheduler_states)
     lr_scheduler.last_epoch = train_state.step_count + 1
 
+    # compatible with old ckpts that saved only one param group: broadcast the single base LR
+    if len(base_lrs) == 1:
+        base_lrs = base_lrs * len(optimizer.param_groups)
+
     ratios = [learning_rate / lr for lr in base_lrs]
     for idx, param_group in enumerate(optimizer.param_groups):
         param_group["lr"] = param_group["lr"] * ratios[idx]