2022-11-29 05:00:30 +00:00
|
|
|
import random
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
|
|
def seed_all(seed, cuda_deterministic=False):
    """Seed the Python, NumPy and PyTorch RNGs for reproducible runs.

    Seeds ``random``, NumPy's global RNG (``np.random``) and PyTorch. When
    CUDA is available, every CUDA device is seeded as well and cuDNN is
    configured according to ``cuda_deterministic``.

    Args:
        seed: Seed value passed unchanged to ``random.seed``,
            ``np.random.seed``, ``torch.manual_seed`` and
            ``torch.cuda.manual_seed_all``. NumPy requires it to be an
            integer in ``[0, 2**32 - 1]``.
        cuda_deterministic: If True, force deterministic cuDNN kernels and
            disable cuDNN auto-tuning (slower, more reproducible). If False,
            allow non-deterministic kernels and enable auto-tuning (faster,
            less reproducible). Only takes effect when CUDA is available.

    Note:
        This does not set ``PYTHONHASHSEED`` or call
        ``torch.use_deterministic_algorithms``. Full determinism may need
        those as well.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        # manual_seed_all covers every visible device, including the
        # current one, so a separate torch.cuda.manual_seed is unnecessary.
        torch.cuda.manual_seed_all(seed)

        # Speed vs. reproducibility trade-off for cuDNN.
        if cuda_deterministic:  # slower, more reproducible
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
        else:  # faster, less reproducible
            torch.backends.cudnn.deterministic = False
            torch.backends.cudnn.benchmark = True
|