diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py
index d63ed7a..4b3f7d5 100644
--- a/internlm/utils/model_checkpoint.py
+++ b/internlm/utils/model_checkpoint.py
@@ -556,6 +556,18 @@ def load_optimizer_checkpoint(folder, optim):
             f"Please check whether loading ckpts are saved with the HybridZeroOptimizer."
         )
 
+    # compatible with checkpoints saved by old code that only had one param group; align them with both current param groups
+    if len(states["base_optim_states"]["param_groups"]) == 1:
+        for group in optim.param_groups:
+            # the newly added group has no params, so append the current group as a placeholder
+            if len(group["params"]) == 0:
+                states["base_optim_states"]["param_groups"].append(group)
+            # for the original group, back-fill the attributes added in recent updates
+            else:
+                saved_group = states["base_optim_states"]["param_groups"][0]
+                saved_group["dp_mode"] = group["dp_mode"]
+                saved_group["dtype"] = group["dtype"]
+
     optim.load_state_dict(states)
     del states
     torch.cuda.empty_cache()
@@ -598,6 +610,10 @@ def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, train_state: TrainSt
     lr_scheduler.load_state_dict(scheduler_states)
     lr_scheduler.last_epoch = train_state.step_count + 1
 
+    # compatible with checkpoints saved by old code that only had one param group
+    if len(base_lrs) == 1:
+        base_lrs = base_lrs * len(optimizer.param_groups)
+
     ratios = [learning_rate / lr for lr in base_lrs]
     for idx, param_group in enumerate(optimizer.param_groups):
        param_group["lr"] = param_group["lr"] * ratios[idx]
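
For reviewers, a minimal standalone sketch of the compatibility alignment the two hunks perform, using plain dicts in place of the real HybridZeroOptimizer state. The names and values below (old_states, current_param_groups, the "zero"/"data" dp_mode strings) are illustrative assumptions, not taken from the codebase; only the two-group layout and the dp_mode/dtype back-fill mirror the patch.

# Illustrative sketch only: plain dicts stand in for the saved optimizer state.
old_states = {
    "base_optim_states": {
        "param_groups": [
            # old checkpoints stored a single group without the new attributes
            {"params": [0, 1, 2], "lr": 1e-4, "weight_decay": 0.01},
        ],
    },
}

# current optimizer layout: the original group plus a newly added, empty group
current_param_groups = [
    {"params": [0, 1, 2], "lr": 1e-4, "weight_decay": 0.01, "dp_mode": "zero", "dtype": "bf16"},
    {"params": [], "lr": 1e-4, "weight_decay": 0.0, "dp_mode": "data", "dtype": "fp32"},
]

saved_groups = old_states["base_optim_states"]["param_groups"]
if len(saved_groups) == 1:
    for group in current_param_groups:
        if len(group["params"]) == 0:
            # the new group holds no params, so the current group serves as a placeholder
            saved_groups.append(group)
        else:
            # back-fill attributes the old checkpoint never stored
            saved_groups[0]["dp_mode"] = group["dp_mode"]
            saved_groups[0]["dtype"] = group["dtype"]

assert len(saved_groups) == 2
assert saved_groups[0]["dp_mode"] == "zero"

# the scheduler-side fix in the second hunk is the same idea applied to base_lrs:
base_lrs = [1e-4]
if len(base_lrs) == 1:
    base_lrs = base_lrs * len(current_param_groups)
assert base_lrs == [1e-4, 1e-4]

After the alignment, the saved state has one entry per current param group, so optim.load_state_dict(states) and the learning-rate ratio loop in load_scheduler can index the groups one-to-one.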