Commit 0812e9fd authored by Gabriele Franch

Separate train and test in taasss dataset

parent a5b77efb
@@ -29,7 +29,7 @@ parser.add_argument("--device", type=str, default="cpu") # cuda
parser.add_argument("--dataset_name", type=str, default="radar")
parser.add_argument("--r", type=int, default=4)
parser.add_argument("--is_parallel", type=bool, default=False)
parser.add_argument("--save_dir", type=str, default="checkpoints/model.ckpt")
parser.add_argument("--save_dir", type=str, default="checkpoints/inter_dst_predrnn")
parser.add_argument("--gen_frm_dir", type=str, default="dataset_generated/")
parser.add_argument("--input_length", type=int, default=5)
parser.add_argument("--total_length", type=int, default=15)
......
import os
# import os
import argparse
import math
import shutil
# import math
# import shutil
import numpy as np
import core.trainer as trainer
from core.models.model_factory import Model
from core.utils import preprocess
from data_provider.CIKM.data_iterator import clean_fold, sample, imsave
from core.utils.util import nor, de_nor
from data_provider.CIKM.taasss import infinite_batcher
from pathlib import Path
import h5py
import cv2
import pandas as pd
# from data_provider.CIKM.data_iterator import clean_fold, sample, imsave
# from core.utils.util import nor, de_nor
from tqdm import tqdm
from cikm_inter_dst_predrnn_run_taasss_utils import (
padding_taasss,
get_batcher,
)
# -----------------------------------------------------------------------------
parser = argparse.ArgumentParser(
description="PyTorch video prediction model - DST PredRNN"
)
@@ -29,11 +28,11 @@ parser.add_argument("--device", type=str, default="cuda") # cuda
parser.add_argument("--dataset_name", type=str, default="radar")
parser.add_argument("--r", type=int, default=4)
parser.add_argument("--is_parallel", type=bool, default=False)
parser.add_argument("--save_dir", type=str, default="checkpoints/model.ckpt")
# parser.add_argument("--save_dir", type=str, default="checkpoints/inter_dst_predrnn")
parser.add_argument("--gen_frm_dir", type=str, default="dataset_generated/")
parser.add_argument("--input_length", type=int, default=5)
parser.add_argument("--total_length", type=int, default=25) # 15
parser.add_argument("--img_width", type=int, default=512)
parser.add_argument("--img_width", type=int, default=512) # 512
parser.add_argument("--img_channel", type=int, default=1)
# model
@@ -42,7 +41,7 @@ parser.add_argument("--pretrained_model", type=str, default="")
parser.add_argument("--num_hidden", type=str, default="64,64,64,64")
parser.add_argument("--filter_size", type=int, default=5)
parser.add_argument("--stride", type=int, default=1)
parser.add_argument("--patch_size", type=int, default=4)
# parser.add_argument("--patch_size", type=int, default=4)
parser.add_argument("--layer_norm", type=int, default=1)
# scheduled sampling
@@ -56,43 +55,21 @@ parser.add_argument("--lr", type=float, default=0.001)
parser.add_argument("--reverse_input", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=1) # 4
parser.add_argument("--max_iterations", type=int, default=80000)
parser.add_argument("--display_interval", type=int, default=1) # 200
parser.add_argument("--test_interval", type=int, default=1) # 2000
parser.add_argument("--snapshot_interval", type=int, default=5000)
# parser.add_argument("--display_interval", type=int, default=1) # 200
parser.add_argument("--test_interval", type=int, default=5000) # 5000
# parser.add_argument("--snapshot_interval", type=int, default=5000)
parser.add_argument("--num_save_samples", type=int, default=10)
parser.add_argument("--n_gpu", type=int, default=0)
args = parser.parse_args()
batch_size = args.batch_size
def get_batcher():
data_dir = Path("/") / "data2" / "franch" / "meteotn_traindata"
metadata_file = data_dir / "run_metadata.csv"
all_data = h5py.File(
data_dir / "hdf_archives" / "all_data.hdf5",
"r",
libver="latest",
)
outlier_mask = cv2.imread(str(data_dir / "mask.png"), 0)
metadata = pd.read_csv(metadata_file, index_col="id")
metadata["start_datetime"] = pd.to_datetime(metadata["start_datetime"])
metadata["end_datetime"] = pd.to_datetime(metadata["end_datetime"])
sort_meta = metadata.sample(frac=1)
batcher = infinite_batcher(
all_data,
sort_meta,
outlier_mask,
shuffle=False,
batch_size=args.batch_size, # TODO: UPDATE FROM 1 TO 4
filter_threshold=0,
)
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# Filter images in the ROI
for imgs, _, masks in batcher:
yield imgs * ~masks
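The `yield imgs * ~masks` line above is the whole outlier filter: assuming `masks` is a boolean array marking bad pixels, `~masks` inverts it and the multiplication zeroes exactly the flagged pixels. A minimal sketch with assumed shapes (taken from the comment in `wrapper_test` below):

``` python
import numpy as np

# Sketch of the masking step: `masks` is True on outlier pixels,
# so multiplying by ~masks zeroes those pixels and keeps the rest.
imgs = np.random.rand(25, 1, 1, 480, 480).astype(np.float32)
masks = np.zeros(imgs.shape, dtype=bool)
masks[..., :10, :10] = True  # pretend the top-left corner is bad

cleaned = imgs * ~masks
assert cleaned[..., :10, :10].max() == 0.0  # flagged pixels are zeroed
```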
if args.img_width == 128:
args.patch_size = 4
elif args.img_width == 256:
args.patch_size = 8
elif args.img_width == 512:
args.patch_size = 16
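This if/elif chain keeps the patched grid at 32×32 for every image width (128/4 = 256/8 = 512/16 = 32). The chosen `patch_size` feeds `preprocess.reshape_patch`, which in PredRNN-style code is a space-to-depth fold: each `patch_size × patch_size` block moves into the channel axis so the recurrent cells see a smaller spatial grid. A sketch of the idea (`reshape_patch_sketch` is a hypothetical stand-in, assumed to mirror `core.utils.preprocess.reshape_patch` up to channel ordering):

``` python
import numpy as np

def reshape_patch_sketch(imgs: np.ndarray, patch_size: int) -> np.ndarray:
    """Space-to-depth fold: (B, T, H, W, C) -> (B, T, H/ps, W/ps, ps*ps*C)."""
    b, t, h, w, c = imgs.shape
    ps = patch_size
    x = imgs.reshape(b, t, h // ps, ps, w // ps, ps, c)
    x = x.transpose(0, 1, 2, 4, 3, 5, 6)  # bring the two ps axes together
    return x.reshape(b, t, h // ps, w // ps, ps * ps * c)

# With img_width=512 and patch_size=16, a (1, 25, 512, 512, 1) batch becomes
# (1, 25, 32, 32, 256) -- matching the real_input_flag shapes built below.
assert reshape_patch_sketch(np.zeros((1, 25, 512, 512, 1)), 16).shape == (1, 25, 32, 32, 256)
```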
def schedule_sampling(eta: float, itr: int):
@@ -163,176 +140,25 @@ def change_taasss_dims(a: np.ndarray) -> np.ndarray:
return np.expand_dims(a, axis=4)
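Only the last line of `change_taasss_dims` survives the diff; judging from the equivalent squeeze/expand sequence spelled out in `wrapper_test` below, the full function plausibly reads:

``` python
import numpy as np

def change_taasss_dims_sketch(a: np.ndarray) -> np.ndarray:
    """Hypothetical reconstruction: (25, 1, 1, 480, 480) -> (1, 25, 480, 480, 1)."""
    a = np.squeeze(a)                 # (25, 480, 480)
    a = np.expand_dims(a, axis=0)     # (1, 25, 480, 480)
    return np.expand_dims(a, axis=4)  # (1, 25, 480, 480, 1)
```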
def padding_taasss(array: np.ndarray) -> np.ndarray:
"""
Add padding
(1, 25, 480, 480, 1)
to
(1, 25, 512, 512, 1)
"""
zeros = np.zeros((1, 25, 512, 512, 1))
zeros[:, :, 16:496, 16:496, :] = array
# zeros = np.zeros((1, 25, 256, 256, 1))
# zeros[:, :, 8:248, 8:248, :] = array[:, :, :240, :240, :]
# zeros = np.zeros((1, 25, 128, 128, 1))
# zeros[:, :, 4:124, 4:124, :] = array[:, :, :120, :120, :]
return zeros
def unpadding_taasss(array: np.ndarray) -> np.ndarray:
"""
Remove padding
(1, 25, 512, 512, 1)
to
(1, 25, 480, 480, 1)
"""
return array[:, :, 16:496, 16:496, :]
# return array[:, :, 8:248, 8:248, :]
# return array[:, :, 4:124, 4:124, :]
def wrapper_train(model: Model):
eta = args.sampling_start_value
best_mse = math.inf
tolerate = 0
limit = 3
iterator = get_batcher()
for itr in range(1, args.max_iterations + 1):
iterator = get_batcher(args)
progress_bar = tqdm(range(1, args.max_iterations + 1))
for itr in progress_bar:
imgs = next(iterator)
imgs = change_taasss_dims(imgs)
imgs = padding_taasss(imgs)
imgs = padding_taasss(imgs, args)
imgs = preprocess.reshape_patch(imgs, args.patch_size)
# Should already be 0 to 1
# imgs = nor(imgs)
eta, real_input_flag = schedule_sampling(eta, itr)
cost = trainer.train(model, imgs, real_input_flag, args)
if itr % args.display_interval == 0:
print("itr: " + str(itr))
print("training loss: " + str(cost))
progress_bar.set_description(f"Loss: {cost}")
if itr % args.test_interval == 0:
valid_mse = wrapper_valid(model, iterator)
print("validation mse is:", str(valid_mse))
if valid_mse < best_mse:
best_mse = valid_mse
tolerate = 0
model.save()
else:
tolerate = tolerate + 1
if tolerate == limit:
model.load()
test_mse = wrapper_test(model)
print("the best valid mse is:", str(best_mse))
print("the test mse is ", str(test_mse))
break
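`wrapper_train` implements early stopping with a patience of `limit = 3` validations: the best checkpoint is saved on every improvement, and training stops after three non-improving validations in a row. A condensed restatement of that logic (hypothetical helper, same behaviour):

``` python
def early_stop_step(valid_mse: float, best_mse: float, tolerate: int, limit: int = 3):
    """One validation step of the patience logic in wrapper_train.
    Returns (best_mse, tolerate, should_stop)."""
    if valid_mse < best_mse:
        return valid_mse, 0, False  # new best: save the model, reset patience
    tolerate += 1
    return best_mse, tolerate, tolerate >= limit  # stop once patience is exhausted
```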
def wrapper_valid(model: Model, iterator):
loss = 0
real_input_flag = np.zeros(
(
args.batch_size,
args.total_length - args.input_length - 1,
args.img_width // args.patch_size,
args.img_width // args.patch_size,
args.patch_size ** 2 * args.img_channel,
)
)
output_length = args.total_length - args.input_length
# TODO: understand if 50 steps is right
steps = 50
for _ in range(steps):
imgs = next(iterator)
imgs = change_taasss_dims(imgs)
tars = imgs[:, -output_length:]
# TODO: REMOVE IT ONCE THE IMAGE IS FULL
# TODO: TEST IT
print("TARS", tars.shape)
tars = tars[:, :, :120, :120, :]
print("TARS", tars.shape)
imgs = padding_taasss(imgs)
# Should already be 0 to 1
# imgs = nor(imgs)
imgs = preprocess.reshape_patch(imgs, args.patch_size)
img_gen, _ = model.test(imgs, real_input_flag)
img_gen = preprocess.reshape_patch_back(img_gen, args.patch_size)
img_out = unpadding_taasss(img_gen[:, -output_length:])
print("SHAPE", tars.shape, img_out.shape)
mse = np.mean(np.square(tars - img_out))
loss = loss + mse
print("LOSS", loss, "MSE", mse)
return loss / steps
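The shape prints and TODOs above suggest `tars` and `img_out` can disagree spatially (the targets are cropped to 120 px while the unpadded output is 480 px wide for `img_width=512`). A hypothetical guard that crops both to their common extent before taking the MSE:

``` python
import numpy as np

def aligned_mse(tars: np.ndarray, img_out: np.ndarray) -> float:
    """Hypothetical helper: crop two (B, T, H, W, C) arrays to their common
    spatial extent so the subtraction cannot shape-mismatch."""
    h = min(tars.shape[2], img_out.shape[2])
    w = min(tars.shape[3], img_out.shape[3])
    return float(np.mean(np.square(tars[:, :, :h, :w, :] - img_out[:, :, :h, :w, :])))
```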
def wrapper_test(model: Model):
test_save_root = Path(args.gen_frm_dir)  # Path, so the `/` joins below work
loss = 0
real_input_flag = np.zeros(
(
args.batch_size,
args.total_length - args.input_length - 1,
args.img_width // args.patch_size,
args.img_width // args.patch_size,
args.patch_size ** 2 * args.img_channel,
)
)
output_length = args.total_length - args.input_length
iterator = get_batcher()
steps = 10
for index in range(steps):
print("Sample is:", index)
dat = next(iterator)
# (25, 1, 1, 480, 480) to (1, 25, 480, 480, 1)
dat = np.squeeze(dat)
dat = np.expand_dims(dat, axis=0)
dat = np.expand_dims(dat, axis=4)
# Should already be 0 to 1
# dat = nor(dat)
tars = dat[:, -output_length:]
ims = padding_taasss(dat)
ims = preprocess.reshape_patch(ims, args.patch_size)
ims = ims.astype(np.float64)
img_gen, _ = model.test(ims, real_input_flag)
img_gen = preprocess.reshape_patch_back(img_gen, args.patch_size)
img_out = unpadding_taasss(img_gen[:, -output_length:])
mse = np.mean(np.square(tars - img_out))
# Should already be 0 to 1
# img_out = de_nor(img_out)
loss += mse
bat_ind = 0
for ind in range(index - batch_size, index, 1):
save_fold = test_save_root / f"sample_{ind}"
for t in range(6, 16, 1):
imsave(
save_fold / f"img_{t}.png",
img_out[bat_ind, t - 6, :, :, 0],
)
bat_ind += 1
return loss / steps
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
save_dir = Path(args.save_dir).parent
if save_dir.exists():
shutil.rmtree(save_dir)
save_dir.mkdir()
model.save(itr)
# Remove dataset_generated directory
if os.path.exists(args.gen_frm_dir):
shutil.rmtree(args.gen_frm_dir)
os.makedirs(args.gen_frm_dir)
print("Initializing models")
model = Model(args)
print("MODELTYPE", type(model))
model.load()
# test_mse = wrapper_test(model)
wrapper_train(model)
This diff is collapsed.
This diff is collapsed.
import numpy as np
from pathlib import Path
import pandas as pd
import h5py
import cv2
from data_provider.CIKM.taasss import infinite_batcher
def get_batcher(args):
data_dir = Path("/") / "data1" / "meteotn_data_2010_2016"
metadata_file = data_dir / "run_metadata.csv"
all_data = h5py.File(
data_dir / "hdf_archives" / "all_data.hdf5",
"r",
libver="latest",
)
outlier_mask = cv2.imread(str(data_dir / "mask.png"), 0)
metadata = pd.read_csv(metadata_file, index_col="id")
metadata["start_datetime"] = pd.to_datetime(metadata["start_datetime"])
metadata["end_datetime"] = pd.to_datetime(metadata["end_datetime"])
sort_meta = metadata.sample(frac=1)
batcher = infinite_batcher(
all_data,
sort_meta,
outlier_mask,
shuffle=False,
batch_size=args.batch_size, # TODO: UPDATE FROM 1 TO 4
filter_threshold=0,
)
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# Filter images in the ROI
for imgs, _, masks in batcher:
yield imgs * ~masks
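A minimal smoke test for the generator above (an assumed usage; `argparse.Namespace` stands in for the parsed CLI args):

``` python
import argparse

args = argparse.Namespace(batch_size=1)
batcher = get_batcher(args)
imgs = next(batcher)           # one masked radar sequence from the HDF5 archive
print(imgs.shape, imgs.dtype)  # e.g. (25, 1, 1, 480, 480), per the comment in wrapper_test
```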
def padding_taasss(array: np.ndarray, args) -> np.ndarray:
"""
Add padding:
(1, 25, 480, 480, 1) -> (1, 25, img_width, img_width, 1).
For img_width 256/128 the frame is first cropped to 240/120 px.
"""
zeros = np.zeros((1, 25, args.img_width, args.img_width, 1))
if args.img_width == 512:
zeros[:, :, 16:496, 16:496, :] = array
elif args.img_width == 256:
zeros[:, :, 8:248, 8:248, :] = array[:, :, :240, :240, :]
elif args.img_width == 128:
zeros[:, :, 4:124, 4:124, :] = array[:, :, :120, :120, :]
return zeros
def unpadding_taasss(array: np.ndarray, args) -> np.ndarray:
"""
Remove padding:
(1, 25, img_width, img_width, 1) -> (1, 25, crop, crop, 1),
where crop is 480/240/120 for img_width 512/256/128.
"""
if args.img_width == 512:
return array[:, :, 16:496, 16:496, :]
elif args.img_width == 256:
return array[:, :, 8:248, 8:248, :]
elif args.img_width == 128:
return array[:, :, 4:124, 4:124, :]
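For `img_width=512` the two functions are exact inverses, since the full 480×480 frame is embedded; for 256/128 the frame is cropped first, so only the cropped region round-trips. A quick property check under those assumptions:

``` python
import argparse
import numpy as np

args = argparse.Namespace(img_width=512)
frame = np.random.rand(1, 25, 480, 480, 1)
# unpadding should invert padding exactly at full resolution
assert np.array_equal(unpadding_taasss(padding_taasss(frame, args), args), frame)
```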
@@ -3,7 +3,7 @@ import torch
import torch.nn as nn
from torch.optim import Adam
from core.models import predict
from pathlib import Path
class Model(object):
def __init__(self, configs):
@@ -26,7 +26,6 @@ class Model(object):
if configs.model_name not in networks_map:
raise ValueError("Name of network unknown %s" % configs.model_name)
Network = networks_map[configs.model_name]
print("BEF NET")
self.network = Network(self.num_layers, self.num_hidden, configs).to(
configs.device
)
@@ -39,19 +38,20 @@ class Model(object):
self.MSE_criterion = nn.MSELoss(size_average=False)
self.MAE_criterion = nn.L1Loss(size_average=False)
def save(self, ite=None):
def save(self, itr):
stats = {}
stats["net_param"] = self.network.state_dict()
torch.save(stats, self.configs.save_dir)
print(f"Saving model to {self.configs.save_dir}")
save_path = Path("/") / "data1" / "IDA_LSTM_checkpoints"
save_path.mkdir(exist_ok=True)
save_path = save_path / f"{itr}.pth"
torch.save(stats, save_path)
# print(f"Saving model to {save_path}")
def load(self):
if os.path.exists(self.configs.save_dir):
stats = torch.load(self.configs.save_dir)
self.network.load_state_dict(stats["net_param"])
print("Model loaded")
else:
print("Training from scratch")
def load(self, path):
assert os.path.exists(path), "Weights file does not exist"
stats = torch.load(path)
self.network.load_state_dict(stats["net_param"])
print("Model loaded")
def train(self, frames, mask):
frames_tensor = torch.FloatTensor(frames).to(self.configs.device)
......
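With this change, `save()` writes numbered snapshots to a fixed checkpoint directory while `load()` takes an explicit path instead of reading `args.save_dir`. A hypothetical resume flow, assuming `args` is the parsed config from the run script:

``` python
model = Model(args)
model.save(itr=5000)  # writes /data1/IDA_LSTM_checkpoints/5000.pth
model.load("/data1/IDA_LSTM_checkpoints/5000.pth")  # restores the net_param state dict
```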
@@ -929,7 +929,6 @@ class CST_PredRNN(nn.Module):
class InteractionDST_PredRNN(nn.Module):
def __init__(self, num_layers, num_hidden, configs):
print("IN NET")
super(InteractionDST_PredRNN, self).__init__()
self.configs = configs
self.frame_channel = (
......
@@ -4,4 +4,6 @@ jpype1
scipy
imageio
pandas
h5py
\ No newline at end of file
h5py
tqdm
matplotlib
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"source": [
"def wrapper_train(model):\n",
" if args.pretrained_model:\n",
" model.load(args.pretrained_model)\n",
" # load data\n",
" # train_input_handle, test_input_handle = datasets_factory.data_provider(\n",
" # args.dataset_name, args.train_data_paths, args.valid_data_paths, args.batch_size, args.img_width,\n",
" # seq_length=args.total_length, is_training=True)\n",
"\n",
" eta = args.sampling_start_value\n",
" best_mse = math.inf\n",
" tolerate = 0\n",
" limit = 3\n",
" best_iter = None\n",
" for itr in range(1, args.max_iterations + 1):\n",
"\n",
" ims = sample(batch_size=batch_size)\n",
" ims = padding_CIKM_data(ims)\n",
"\n",
" ims = preprocess.reshape_patch(ims, args.patch_size)\n",
" ims = nor(ims)\n",
" eta, real_input_flag = schedule_sampling(eta, itr)\n",
"\n",
" cost = trainer.train(model, ims, real_input_flag, args, itr)\n",
"\n",
" if itr % args.display_interval == 0:\n",
" print(\"itr: \" + str(itr))\n",
" print(\"training loss: \" + str(cost))\n",
"\n",
" if itr % args.test_interval == 0:\n",
" print(\"validation one \")\n",
" valid_mse = wrapper_valid(model)\n",
" print(\"validation mse is:\", str(valid_mse))\n",
"\n",
" if valid_mse < best_mse:\n",
" best_mse = valid_mse\n",
" best_iter = itr\n",
" tolerate = 0\n",
" model.save()\n",
" else:\n",
" tolerate = tolerate + 1\n",
"\n",
" if tolerate == limit:\n",
" model.load()\n",
" test_mse = wrapper_test(model)\n",
" print(\"the best valid mse is:\", str(best_mse))\n",
" print(\"the test mse is \", str(test_mse))\n",
" break\n",
"\n",
"\n",
"def wrapper_valid(model):\n",
" loss = 0\n",
" count = 0\n",
" index = 1\n",
" flag = True\n",
" # img_mse, ssim = [], []\n",
"\n",
" # for i in range(args.total_length - args.input_length):\n",
" # img_mse.append(0)\n",
" # ssim.append(0)\n",
"\n",
" real_input_flag = np.zeros(\n",
" (\n",
" args.batch_size,\n",
" args.total_length - args.input_length - 1,\n",
" args.img_width // args.patch_size,\n",
" args.img_width // args.patch_size,\n",
" args.patch_size ** 2 * args.img_channel,\n",
" )\n",
" )\n",
" output_length = args.total_length - args.input_length\n",
" while flag:\n",
"\n",
" dat, (index, b_cup) = sample(batch_size, data_type=\"validation\", index=index)\n",
" dat = nor(dat)\n",
" tars = dat[:, -output_length:]\n",
" ims = padding_CIKM_data(dat)\n",
"\n",
" ims = preprocess.reshape_patch(ims, args.patch_size)\n",
" img_gen, _ = model.test(ims, real_input_flag)\n",
" img_gen = preprocess.reshape_patch_back(img_gen, args.patch_size)\n",
" img_out = unpadding_CIKM_data(img_gen[:, -output_length:])\n",
"\n",
" mse = np.mean(np.square(tars - img_out))\n",
" loss = loss + mse\n",
" count = count + 1\n",
" if b_cup == args.batch_size - 1:\n",
" pass\n",
" else:\n",
" flag = False\n",
"\n",
" return loss / count\n",
"\n"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# if args.is_training:\n",
"# wrapper_train(model)\n",
"# else:\n",
"# wrapper_test(model)"
],
"outputs": [],
"metadata": {}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
\ No newline at end of file
%% Cell type:code id: tags:
``` python
def wrapper_train(model):
if args.pretrained_model:
model.load(args.pretrained_model)
# load data
# train_input_handle, test_input_handle = datasets_factory.data_provider(
# args.dataset_name, args.train_data_paths, args.valid_data_paths, args.batch_size, args.img_width,
# seq_length=args.total_length, is_training=True)
eta = args.sampling_start_value
best_mse = math.inf
tolerate = 0
limit = 3
best_iter = None
for itr in range(1, args.max_iterations + 1):
ims = sample(batch_size=batch_size)
ims = padding_CIKM_data(ims)
ims = preprocess.reshape_patch(ims, args.patch_size)
ims = nor(ims)
eta, real_input_flag = schedule_sampling(eta, itr)
cost = trainer.train(model, ims, real_input_flag, args, itr)
if itr % args.display_interval == 0:
print("itr: " + str(itr))
print("training loss: " + str(cost))
if itr % args.test_interval == 0:
print("validation one ")
valid_mse = wrapper_valid(model)
print("validation mse is:", str(valid_mse))
if valid_mse < best_mse: