mirror of https://github.com/InternLM/InternLM

fix(metric): add metric dtype control

parent 81ffb3d824
commit 649af64c59
@@ -26,7 +26,8 @@ class AccPerplex:
         self.device = device
         self.right = torch.Tensor([0]).to(device=device)
         self.total = torch.Tensor([0]).to(device=device)
-        self.total_log_probs = torch.Tensor([0]).to(device=device, dtype=torch.float)
+        self.metric_dtype = torch.float if gpc.config.get("metric_dtype", "fp32") == "fp32" else None
+        self.total_log_probs = torch.Tensor([0]).to(device=device, dtype=self.metric_dtype)
         self.tp_pg = tp_pg
         self.dp_pg = dp_pg
         self.tp_local_rank = torch.distributed.get_rank(self.tp_pg)
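
Note: the new `metric_dtype` attribute defaults to fp32, so existing configs keep the old always-cast behavior, while any other `metric_dtype` value disables the explicit cast (passing dtype=None to Tensor.to() leaves the tensor's dtype unchanged). A minimal sketch of the selection logic, with a plain dict standing in for `gpc.config` (the `resolve_metric_dtype` helper is illustrative, not part of the repository):

    import torch

    def resolve_metric_dtype(config: dict):
        # Mirrors the __init__ change: "fp32" (also the default) maps to
        # torch.float; anything else returns None, and dtype=None in
        # Tensor.to() keeps the tensor's current dtype.
        return torch.float if config.get("metric_dtype", "fp32") == "fp32" else None

    assert resolve_metric_dtype({}) is torch.float                   # default: cast to fp32
    assert resolve_metric_dtype({"metric_dtype": "fp32"}) is torch.float
    assert resolve_metric_dtype({"metric_dtype": "bf16"}) is None    # skip the cast

Since the value is read via `gpc.config.get(...)`, it would presumably be set as a top-level `metric_dtype = "fp32"` entry in the training config; that placement is an assumption from the call site, not confirmed by this diff.
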
@@ -128,8 +129,9 @@ class AccPerplex:
             # All reduce is needed to get the chunks from other GPUs.
             torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
 
-            predicted_logits = predicted_logits.to(dtype=torch.float)
-            shift_logits = shift_logits.to(dtype=torch.float)
+            if self.metric_dtype is not None:
+                predicted_logits = predicted_logits.to(dtype=self.metric_dtype)
+                shift_logits = shift_logits.to(dtype=self.metric_dtype)
 
             pred_exp_logits = torch.exp(predicted_logits)
             # Sum of exponential of logits along vocab dimension across all GPUs.
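
Note: with the cast now gated on `self.metric_dtype`, the default config still accumulates the perplexity statistics in fp32 (the safer choice, since `torch.exp` on half-precision logits can overflow or lose precision), while any other `metric_dtype` setting keeps the logits in their working dtype. A self-contained sketch of the gated path with stand-in tensors (names and shapes are illustrative only):

    import torch

    metric_dtype = torch.float  # as resolved in __init__; None would skip the cast

    # Stand-ins for the all-reduced logits, in the model's working precision.
    predicted_logits = torch.randn(4, dtype=torch.bfloat16)
    shift_logits = torch.randn(4, 8, dtype=torch.bfloat16)

    if metric_dtype is not None:
        predicted_logits = predicted_logits.to(dtype=metric_dtype)
        shift_logits = shift_logits.to(dtype=metric_dtype)

    # The exp/sum that follow run in whichever dtype was chosen.
    pred_exp_logits = torch.exp(predicted_logits)
    print(pred_exp_logits.dtype)  # torch.float32 under the default config
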