mirror of https://github.com/InternLM/InternLM
				
				
				
			fix bug on logger
							parent
							
								
									05a3b2a3be
								
							
						
					
					
						commit
						a8dd77ce76
					
				| 
						 | 
				
			
			@ -329,7 +329,6 @@ class PipelineScheduler(BaseScheduler):
 | 
			
		|||
                else:
 | 
			
		||||
                    engine.backward_by_grad([output_obj, moe_loss], [output_obj_grad, None])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
        # Collect the grad of the input_obj.
 | 
			
		||||
        input_obj_grad = None
 | 
			
		||||
        if input_obj is not None:
 | 
			
		||||
| 
						 | 
				
			
			@ -1312,11 +1311,11 @@ class InterleavedPipelineScheduler(PipelineScheduler):
 | 
			
		|||
 | 
			
		||||
        if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
 | 
			
		||||
            self._accum_loss = torch.zeros(1, device=get_current_device())
 | 
			
		||||
        if return_loss:
 | 
			
		||||
            self._accum_moe_loss = torch.zeros(1, device=get_current_device())
 | 
			
		||||
        self._accum_moe_loss = torch.zeros(1, device=get_current_device())
 | 
			
		||||
 | 
			
		||||
        if return_output_label:
 | 
			
		||||
            self._return_tensors = []
 | 
			
		||||
 | 
			
		||||
        
 | 
			
		||||
        if forward_only:
 | 
			
		||||
            self._forward_only_step(engine, moe_loss_coeff)
 | 
			
		||||
        else:
 | 
			
		||||
| 
						 | 
				
			
			@ -1327,14 +1326,15 @@ class InterleavedPipelineScheduler(PipelineScheduler):
 | 
			
		|||
        else:
 | 
			
		||||
            output, label = (None, None)
 | 
			
		||||
 | 
			
		||||
        accum_loss = self._accum_loss
 | 
			
		||||
        accum_loss += self._accum_moe_loss
 | 
			
		||||
 | 
			
		||||
        logger.info(f"{gpc.get_local_rank(ParallelMode.PIPELINE)}, moe_loss: {self._accum_moe_loss.item()}")
 | 
			
		||||
 | 
			
		||||
        dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
 | 
			
		||||
        accum_moe_loss = self._accum_moe_loss
 | 
			
		||||
 | 
			
		||||
        accum_loss = self._accum_loss
 | 
			
		||||
        if return_loss:
 | 
			
		||||
            accum_loss += self._accum_moe_loss
 | 
			
		||||
 | 
			
		||||
        self._clear_state()
 | 
			
		||||
 | 
			
		||||
        return output, label, accum_loss, accum_moe_loss
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue