Commit b41421af authored by Gabriele Franch's avatar Gabriele Franch
Browse files

Fixx ida lstm in taasss dataset

parent e0d98520
/data1/IDA_LSTM_checkpoints
\ No newline at end of file
......@@ -6,10 +6,12 @@ import numpy as np
from tqdm import tqdm
import core.trainer as trainer
from cikm_inter_dst_predrnn_run_taasss_utils import get_batcher, padding_taasss
from cikm_inter_dst_predrnn_run_taasss_utils import get_batcher, padding_taasss, change_taasss_dims
from core.models.model_factory import Model
from core.utils import preprocess
from matplotlib import pyplot as plt
import scipy.misc
from PIL import Image
parser = argparse.ArgumentParser(
description="PyTorch video prediction model - DST PredRNN"
......@@ -126,14 +128,6 @@ def schedule_sampling(eta: float, itr: int):
)
return eta, real_input_flag
def change_taasss_dims(a: np.ndarray) -> np.ndarray:
"""(25, 1, 1, 480, 480) to (1, 25, 480, 480, 1)"""
a = np.squeeze(a)
a = np.expand_dims(a, axis=0)
return np.expand_dims(a, axis=4)
def wrapper_train(model: Model):
eta = args.sampling_start_value
iterator = get_batcher(args)
......@@ -143,8 +137,6 @@ def wrapper_train(model: Model):
imgs = change_taasss_dims(imgs)
imgs = padding_taasss(imgs, args)
imgs = preprocess.reshape_patch(imgs, args.patch_size)
# Should already by 0 to 1
# imgs = nor(imgs)
eta, real_input_flag = schedule_sampling(eta, itr)
cost = trainer.train(model, imgs, real_input_flag, args)
progress_bar.set_description(f"Loss: {cost}")
......
This diff is collapsed.
This diff is collapsed.
......@@ -14,7 +14,9 @@ def get_batcher(args):
"r",
libver="latest",
)
outlier_mask = cv2.imread(str(data_dir / "mask.png"), 0)
mask_path = data_dir / "mask.png"
assert mask_path.exists(), "Mask does not exist"
outlier_mask = cv2.imread(str(mask_path), 0)
metadata = pd.read_csv(metadata_file, index_col="id")
metadata["start_datetime"] = pd.to_datetime(metadata["start_datetime"])
metadata["end_datetime"] = pd.to_datetime(metadata["end_datetime"])
......@@ -30,7 +32,7 @@ def get_batcher(args):
)
# Filter images in the ROI
for imgs, _, masks in batcher:
yield imgs * ~masks
yield imgs * masks
def padding_taasss(array: np.ndarray, args) -> np.ndarray:
......@@ -64,3 +66,8 @@ def unpadding_taasss(array: np.ndarray, args) -> np.ndarray:
elif args.img_width == 128:
return array[:, :, 4:124, 4:124, :]
raise ValueError
def change_taasss_dims(a: np.ndarray) -> np.ndarray:
"""(25, 1, 1, 480, 480) to (1, 25, 480, 480, 1)"""
a = np.squeeze(a)
a = np.expand_dims(a, axis=0)
return np.expand_dims(a, axis=4)
......@@ -6,4 +6,5 @@ imageio
pandas
h5py
tqdm
matplotlib
\ No newline at end of file
matplotlib
pysteps
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment