2022-03-15 02:05:38 +00:00
|
|
|
import math
from typing import List
|
2022-11-08 06:57:18 +00:00
|
|
|
|
2022-03-15 02:05:38 +00:00
|
|
|
from torch import Tensor
|
|
|
|
|
|
|
|
|
|
|
|
def has_inf_or_nan(tensor) -> bool:
    """Check if tensor has inf or nan values.

    The tensor is reduced to a single Python float with ``sum``. Any ``nan`` or
    ``+/-inf`` element makes that sum non-finite.

    NOTE: the reduction is done in float32. A sum of very large *finite* values
    can itself overflow to inf and would then be reported as True.

    Args:
        tensor (:class:`torch.Tensor`): a torch tensor object

    Returns:
        bool: Whether the tensor has inf or nan. True for yes and False for no.

    Raises:
        RuntimeError: re-raised when the reduction or conversion fails for a
            reason other than a non-convertible (overflowing) value.
    """
    try:
        # Upcast before summing. If ``tensor`` is half precision, ``.float()`` makes
        # an extra copy. It is still needed because ``.sum()`` returns a one-element
        # tensor of the same dtype as its input.
        tensor_sum = float(tensor.float().sum())
    except RuntimeError as instance:
        # Only a failed float conversion means "overflow". Any other RuntimeError
        # must propagate.
        # ``str(instance)`` is used instead of ``instance.args[0]``: the latter
        # raises IndexError for an argument-less RuntimeError and TypeError when
        # the first argument is not a string, masking the original error.
        if "value cannot be converted" not in str(instance):
            raise
        return True
    # math.isfinite is False for +inf, -inf and nan alike.
    return not math.isfinite(tensor_sum)
|
|
|
|
|
|
|
|
|
|
|
|
def zero_gard_by_list(tensor_list: List[Tensor], set_to_none: bool = True) -> None:
    """Clear the ``.grad`` of every tensor in ``tensor_list``.

    The per-tensor logic is copied from ``torch.optim.optimizer``
    (``Optimizer.zero_grad``).

    Args:
        tensor_list: tensors whose gradients should be cleared. Tensors whose
            ``.grad`` is None are skipped.
        set_to_none: if True, drop each gradient by assigning ``None`` to
            ``.grad``. Otherwise, detach the existing gradient from autograd
            and fill it with zeros in place.
    """
    for tensor in tensor_list:
        grad = tensor.grad
        if grad is None:
            continue
        if set_to_none:
            tensor.grad = None
            continue
        # Make sure the gradient is not part of an autograd graph before it is
        # mutated in place: detach it if it has a grad_fn, otherwise just turn off
        # requires_grad.
        if grad.grad_fn is None:
            grad.requires_grad_(False)
        else:
            grad.detach_()
        grad.zero_()
|