Commit 190bf292 authored by Gabriele Franch's avatar Gabriele Franch

added seeding

parent acca89d5
@@ -6,6 +6,8 @@ import numpy as np
from tqdm import tqdm
import os
import sys
import torch
import random
import core.trainer as trainer
from cikm_inter_dst_predrnn_run_taasss_utils import get_batcher, padding_taasss, change_taasss_dims
@@ -67,6 +69,15 @@ batch_size = args.batch_size
# args.patch_size = 16
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
def schedule_sampling(eta: float, itr: int):
    """
    Return
@@ -158,5 +169,6 @@ def wrapper_train(model: Model):
print("Initializing models", file=sys.stderr)
seed_everything(42)
model = Model(args)
wrapper_train(model)
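
For reference, a minimal standalone sketch of how the `seed_everything` helper added here can be exercised. The function body mirrors the diff above; the `__main__` reproducibility check and its variable names are illustrative assumptions, not part of this commit. Fully deterministic multi-GPU runs often also call `torch.cuda.manual_seed_all(seed)` and set `torch.backends.cudnn.benchmark = False`, which this commit does not do.

```python
import os
import random

import numpy as np
import torch


def seed_everything(seed: int) -> None:
    # Mirror of the helper added in this commit:
    # seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # queued safely even on CPU-only machines
    torch.backends.cudnn.deterministic = True


if __name__ == "__main__":
    # Seeding twice with the same value should reproduce the same draws.
    seed_everything(42)
    first = (random.random(), np.random.rand(), torch.rand(1).item())

    seed_everything(42)
    second = (random.random(), np.random.rand(), torch.rand(1).item())

    assert first == second, "seeded runs should be reproducible"
    print("reproducible draws:", first)
```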