2023-11-02 02:21:24 +00:00
|
|
|
import torch
|
2023-12-14 09:52:05 +00:00
|
|
|
|
|
|
|
|
2024-07-09 08:14:00 +00:00
|
|
|
def loose_close(a, b, dtype: torch.dtype = torch.float32, name="", *, rtol=None, atol=None):
    """Assert that two tensors are approximately equal after casting to ``dtype``.

    Both tensors are detached and cast to ``dtype``; ``b`` is then moved to
    ``a``'s device and the pair is compared with ``torch.allclose``.
    Tolerances default to a per-dtype table (looser for the 16-bit formats)
    and can be overridden individually.

    Args:
        a: First tensor; its device is used for the comparison.
        b: Second tensor, moved to ``a``'s device before comparing.
        dtype: Precision used for the comparison. Must be ``torch.float16``,
            ``torch.bfloat16`` or ``torch.float32``.
        name: Label prefixed to the failure message.
        rtol: Optional relative tolerance; defaults to the per-dtype value.
        atol: Optional absolute tolerance; defaults to the per-dtype value.

    Raises:
        AssertionError: If ``dtype`` is unsupported or the tensors are not
            close. Raised explicitly (not via ``assert``) so the check still
            runs under ``python -O``.
    """
    # (rtol, atol) defaults per comparison dtype.
    default_tolerances = {
        torch.float16: (5e-2, 5e-4),
        torch.bfloat16: (4e-3, 4e-3),
        torch.float32: (1e-5, 1e-5),
    }
    if dtype not in default_tolerances:
        raise AssertionError(f"{name}: unsupported dtype {dtype}")
    default_rtol, default_atol = default_tolerances[dtype]
    rtol = default_rtol if rtol is None else rtol
    atol = default_atol if atol is None else atol

    a = a.detach().to(dtype)
    b = b.detach().to(dtype).to(a.device)

    # NOTE: torch.allclose broadcasts a against b; shapes are not checked here.
    if not torch.allclose(a, b, rtol=rtol, atol=atol):
        # Means alone can match for differing tensors, so also report the
        # largest elementwise deviation (computed in float32 for portability).
        max_abs_diff = (a.float() - b.float()).abs().max().item()
        raise AssertionError(
            f"{name} not close {a.mean()} {b.mean()} (max abs diff {max_abs_diff})"
        )
|