Commit 9988109c authored by Marco Di Francesco's avatar Marco Di Francesco 🍉
Browse files

Fix test model for taasss

parent dee581de
dataset/
checkpoints/
dataset_generated/
.venv
\ No newline at end of file
.venv
# Python files
__pycache__/
.pyc
.ipynb_checkpoints
......@@ -5,6 +5,7 @@ import math
import shutil
import numpy as np
import torch
import core.trainer as trainer
from core.models.model_factory import Model
......@@ -263,11 +264,11 @@ def wrapper_test(model):
tars = dat[:, -output_length:]
ims = padding_CIKM_data(dat)
ims = preprocess.reshape_patch(ims, args.patch_size)
print(ims.shape, real_input_flag.shape)
img_gen, _ = model.test(ims, real_input_flag)
img_gen = preprocess.reshape_patch_back(img_gen, args.patch_size)
img_out = unpadding_CIKM_data(img_gen[:, -output_length:])
mse = np.mean(np.square(tars - img_out))
print("MSE", mse)
img_out = de_nor(img_out)
loss += mse
......@@ -275,6 +276,7 @@ def wrapper_test(model):
bat_ind = 0
print("index is:", index)
print("INDEX", index, b_cup)
for ind in range(index - batch_size, index, 1):
save_fold = test_save_root + "sample_" + str(ind) + "/"
......
......@@ -34,7 +34,7 @@ parser.add_argument("--is_parallel", type=bool, default=False)
parser.add_argument("--save_dir", type=str, default="checkpoints/model.ckpt")
parser.add_argument("--gen_frm_dir", type=str, default="dataset_generated/")
parser.add_argument("--input_length", type=int, default=5)
parser.add_argument("--total_length", type=int, default=15) # 25
parser.add_argument("--total_length", type=int, default=15) # 25
parser.add_argument("--img_width", type=int, default=128)
parser.add_argument("--img_channel", type=int, default=1)
......@@ -264,6 +264,13 @@ def crop_taasss(array):
return array[:, :15, :128, :128, :]
def uncrop_taasss(array):
    """Embed a cropped model prediction back into full 480x480 frames.

    Inverse of the cropping step: the model works on the top-left
    128x128 corner of each frame, and this zero-pads the prediction back
    to the original 480x480 spatial size so downstream code (MSE against
    the uncropped targets, saving samples) sees full-size frames.

    Args:
        array: prediction of shape (batch, time, 128, 128, channels).

    Returns:
        Array of shape (batch, time, 480, 480, channels) with `array`
        occupying the top-left 128x128 corner and zeros elsewhere.
    """
    # Derive batch/time/channel sizes from the input rather than
    # hard-coding (1, 10, ..., 1): the original version raised a
    # broadcast error for any batch size other than 1.
    batch, steps, _, _, channels = array.shape
    # Preserve the input dtype (identical result for the float64 inputs
    # the caller produces, but avoids a silent upcast otherwise).
    padded = np.zeros((batch, steps, 480, 480, channels), dtype=array.dtype)
    padded[:, :, :128, :128, :] = array
    return padded
def wrapper_test(model):
test_save_root = args.gen_frm_dir
clean_fold(test_save_root)
......@@ -288,10 +295,14 @@ def wrapper_test(model):
)
)
output_length = args.total_length - args.input_length
train_model_iter = get_batcher()
index = 1
b_cup = 0 # ?????????
while flag:
index += 1
# print("Sample is:", index)
# dat, (index, b_cup) = sample(batch_size, data_type="test", index=index)
train_model_iter = get_batcher()
train_batch, sample_datetimes, train_mask = next(train_model_iter)
# Just because they called it this way
dat = train_batch
......@@ -305,13 +316,13 @@ def wrapper_test(model):
tars = dat[:, -output_length:]
ims = crop_taasss(dat)
ims = preprocess.reshape_patch(ims, args.patch_size)
print(ims.shape, real_input_flag.shape)
img_gen, _ = model.test(train_batch, real_input_flag)
# img_gen = preprocess.reshape_patch_back(img_gen, args.patch_size)
# img_out = unpadding_CIKM_data(img_gen[:, -output_length:])
ims = ims.astype(np.float64)
img_gen, _ = model.test(ims, real_input_flag)
img_gen = preprocess.reshape_patch_back(img_gen, args.patch_size)
img_out = uncrop_taasss(img_gen[:, -output_length:])
mse = np.mean(np.square(tars - img_out))
print(index, "MSE", mse)
img_out = de_nor(img_out)
loss += mse
count += 1
......@@ -331,7 +342,6 @@ def wrapper_test(model):
pass
else:
flag = False
return loss / count
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment