mirror of https://github.com/InternLM/InternLM
				
				
				
			compatible with old ckpt (#418)
							parent
							
								
									eeef07934a
								
							
						
					
					
						commit
						aa5e34d815
					
				| 
						 | 
				
			
			@ -556,6 +556,18 @@ def load_optimizer_checkpoint(folder, optim):
 | 
			
		|||
                    f"Please check whether loading ckpts are saved with the HybridZeroOptimizer."
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
    # compatible with old code that only have one param group, need to align with both parameter groups
 | 
			
		||||
    if len(states["base_optim_states"]["param_groups"]) == 1:
 | 
			
		||||
        for group in optim.param_groups:
 | 
			
		||||
            # for new added empty group, since it has no params, just create it fakely
 | 
			
		||||
            if len(group["params"]) == 0:
 | 
			
		||||
                states["base_optim_states"]["param_groups"].append(group)
 | 
			
		||||
            # for origin group, create new added attributes in recent updates
 | 
			
		||||
            else:
 | 
			
		||||
                saved_group = states["base_optim_states"]["param_groups"][0]
 | 
			
		||||
                saved_group["dp_mode"] = group["dp_mode"]
 | 
			
		||||
                saved_group["dtype"] = group["dtype"]
 | 
			
		||||
 | 
			
		||||
    optim.load_state_dict(states)
 | 
			
		||||
    del states
 | 
			
		||||
    torch.cuda.empty_cache()
 | 
			
		||||
| 
						 | 
				
			
			@ -598,6 +610,10 @@ def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, train_state: TrainSt
 | 
			
		|||
    lr_scheduler.load_state_dict(scheduler_states)
 | 
			
		||||
    lr_scheduler.last_epoch = train_state.step_count + 1
 | 
			
		||||
 | 
			
		||||
    # compatible with old code that only have one param group
 | 
			
		||||
    if len(base_lrs) == 1:
 | 
			
		||||
        base_lrs = base_lrs * len(optimizer.param_groups)
 | 
			
		||||
 | 
			
		||||
    ratios = [learning_rate / lr for lr in base_lrs]
 | 
			
		||||
    for idx, param_group in enumerate(optimizer.param_groups):
 | 
			
		||||
        param_group["lr"] = param_group["lr"] * ratios[idx]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue