mirror of https://github.com/InternLM/InternLM
				
				
				
			add dtype condition for post hook
							parent
							
								
									5a0b3d5d9a
								
							
						
					
					
						commit
						883160a558
					
				| 
						 | 
				
			
			@ -165,13 +165,13 @@ class NaiveAMPModel(nn.Module):
 | 
			
		|||
            assert isinstance(outputs, (Tensor, tuple))
 | 
			
		||||
            if isinstance(outputs, tuple):
 | 
			
		||||
                for output_data_ in outputs:
 | 
			
		||||
                    if isinstance(output_data_, Tensor):
 | 
			
		||||
                    if isinstance(output_data_, Tensor) and output_data_.dtype is not self.dtype:
 | 
			
		||||
                        outputs_.append(output_data_.to(self.dtype))
 | 
			
		||||
                    else:
 | 
			
		||||
                        outputs_.append(output_data_)
 | 
			
		||||
                return tuple(outputs_)
 | 
			
		||||
            else:
 | 
			
		||||
                return outputs.to(self.dtype)
 | 
			
		||||
                return outputs.to(self.dtype) if outputs.dtype is not self.dtype else outputs
 | 
			
		||||
 | 
			
		||||
        # just want to share same for loop for ModuleList and Module
 | 
			
		||||
        if isinstance(self.model, nn.ModuleList):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue