mirror of https://github.com/InternLM/InternLM
				
				
				
fix when resuming lr_scheduler without loading optimizer

parent d418eba094
commit 483bd706dd
@@ -206,6 +206,14 @@ class FSDPadaptOptimizer(BaseOptimizer):
         grad_scaler = states["grad_scaler"]
         self.grad_scaler.load_state_dict(grad_scaler)
         optim_states = states["base_optim_states"]
+
+        if gpc.config.get("only_load_lr", False):
+            if gpc.is_rank_for_log():
+                logger.info("Only load lr in param_groups, skip loading weights in optimizer...")
+            for pg1, pg2 in zip(self.optim.param_groups, optim_states["param_groups"]):
+                pg1["lr"] = pg2["lr"]
+            return
+
         self.optim.load_state_dict(optim_states)
 
         # load fp32 optimizer weight
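A minimal sketch of the branch added above, outside the InternLM wrappers: a plain torch.optim.SGD stands in for the wrapped base optimizer, a local only_load_lr argument replaces the gpc.config lookup, and `saved` is a hypothetical excerpt of states["base_optim_states"]. Only the per-group lr is copied before the early return; momentum/Adam moments are left untouched.

import torch

def load_base_optim_states(optim, base_optim_states, only_load_lr=False):
    if only_load_lr:
        # Copy only the learning rate of each param group; any momentum
        # buffers / Adam moments in base_optim_states are ignored.
        for pg1, pg2 in zip(optim.param_groups, base_optim_states["param_groups"]):
            pg1["lr"] = pg2["lr"]
        return
    # Normal path: restore the full optimizer state.
    optim.load_state_dict(base_optim_states)

# Usage: seed only the lr from a (pretend) checkpointed state dict.
params = [torch.nn.Parameter(torch.zeros(2))]
optim = torch.optim.SGD(params, lr=1e-3)
saved = {"param_groups": [{"lr": 2.5e-4}]}            # hypothetical checkpoint excerpt
load_base_optim_states(optim, saved, only_load_lr=True)
print(optim.param_groups[0]["lr"])                    # 0.00025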
@@ -943,6 +943,14 @@ class HybridZeroOptimizer(BaseOptimizer):
         grad_scaler = states["grad_scaler"]
         self.grad_scaler.load_state_dict(grad_scaler)
         optim_states = states["base_optim_states"]
+
+        if gpc.config.get("only_load_lr", False):
+            if gpc.is_rank_for_log():
+                logger.info("Only load lr in param_groups, skip loading weights in optimizer...")
+            for pg1, pg2 in zip(self.optim.param_groups, optim_states["param_groups"]):
+                pg1["lr"] = pg2["lr"]
+            return
+
         self.optim.load_state_dict(optim_states)
 
         # load fp32 model weight.
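The same guard is added to HybridZeroOptimizer. The failure mode both guards address can be reproduced with stock PyTorch: schedulers such as ExponentialLR derive the next lr from the optimizer's current param_group lr, so resuming a scheduler onto a freshly built optimizer restarts the decay from the initial lr unless the group lr is seeded from the checkpoint. This is an illustration only, using torch.optim.lr_scheduler.ExponentialLR; InternLM's own scheduler does not appear in this diff.

import torch
from torch.optim.lr_scheduler import ExponentialLR

param = torch.nn.Parameter(torch.zeros(1))

# Train for a few steps: lr decays 1.0 -> 0.125.
opt = torch.optim.SGD([param], lr=1.0)
sched = ExponentialLR(opt, gamma=0.5)
for _ in range(3):
    opt.step()
    sched.step()
opt_state, sched_state = opt.state_dict(), sched.state_dict()

# Resume the scheduler onto a *fresh* optimizer, without touching its lr:
opt2 = torch.optim.SGD([param], lr=1.0)
sched2 = ExponentialLR(opt2, gamma=0.5)
sched2.load_state_dict(sched_state)
opt2.step(); sched2.step()
print(opt2.param_groups[0]["lr"])    # 0.5 -- decay restarted from the initial lr

# Seed only the lr from the optimizer checkpoint (what the patch enables):
opt3 = torch.optim.SGD([param], lr=1.0)
sched3 = ExponentialLR(opt3, gamma=0.5)
for pg1, pg2 in zip(opt3.param_groups, opt_state["param_groups"]):
    pg1["lr"] = pg2["lr"]            # 0.125, as at save time
sched3.load_state_dict(sched_state)
opt3.step(); sched3.step()
print(opt3.param_groups[0]["lr"])    # 0.0625 -- schedule continues correctly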
@@ -231,7 +231,7 @@ def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState):
         # load training states.
         load_context(load_ckpt_folder, train_state)
 
-        # load optimzier states.
+        # load optimizer states.
         if load_content.need_load(CheckpointLoadContent.OPIMIZER):
             load_optimizer_checkpoint(load_ckpt_folder, ckpt_mm.optimizer)
             load_content_str += f"{CheckpointLoadContent.OPIMIZER}, "
@@ -248,6 +248,12 @@ def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState):
                 if gpc.is_rank_for_log():
                     logger.warning("CheckpointManager has no 'lr_scheduler', skip reload lr_scheduler checkpoint!")
 
+            if not load_content.need_load(CheckpointLoadContent.OPIMIZER):
+                if ckpt_mm.lr_scheduler and train_state:
+                    gpc.config.only_load_lr = True
+                    load_optimizer_checkpoint(load_ckpt_folder, ckpt_mm.optimizer)
+                    gpc.config.only_load_lr = False
+
         # load dataloader sampler states.
         if load_content.need_load(CheckpointLoadContent.SAMPLER):
             if hasattr(train_state, "batch_sampler") and not isinstance(
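A sketch of the driver-side toggle above, with a SimpleNamespace standing in for gpc.config and a stub in place of load_optimizer_checkpoint (both hypothetical). Wrapping the call in try/finally is a small hardening idea, not part of the patch: it keeps the transient only_load_lr flag from leaking into later loads if the checkpoint read raises.

from types import SimpleNamespace

config = SimpleNamespace(only_load_lr=False)    # stand-in for gpc.config

def load_optim_ckpt_stub(folder, optimizer):
    # Stand-in for load_optimizer_checkpoint(); the real loader reads the
    # checkpoint and ends up in the load_state_dict() paths patched above.
    print(f"loading optimizer states from {folder}, only_load_lr={config.only_load_lr}")

def resume_lr_only(folder, optimizer):
    """Restore just the per-group lr so a resumed lr_scheduler stays consistent."""
    config.only_load_lr = True
    try:
        load_optim_ckpt_stub(folder, optimizer)
    finally:
        # Always clear the transient flag, even if loading raises.
        config.only_load_lr = False

resume_lr_only("ckpt/step_1000", optimizer=None)    # hypothetical folder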