Commit 25c3d924 authored by Alessia Marcolini's avatar Alessia Marcolini
Browse files

Add reproducibility flags

parent 35a8317c
......@@ -27,7 +27,10 @@ class BaseModel(ABC):
self.is_train = configuration['is_train']
self.use_cuda = torch.cuda.is_available()
self.device = torch.device('cuda:0') if self.use_cuda else torch.device('cpu')
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.enabled = False
self.save_dir = configuration['checkpoint_path']
self.network_names = []
self.loss_names = []
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment