Commit 41981e06 authored by Marco Di Francesco's avatar Marco Di Francesco 🍉
Browse files

Fix image library

parent 95f130b7
dataset/
checkpoints/
dataset_generated/
\ No newline at end of file
dataset_generated/
.venv
\ No newline at end of file
{
"python.pythonPath": ".venv/bin/python3"
}
\ No newline at end of file
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import argparse
import math
import shutil
......@@ -9,9 +8,10 @@ import numpy as np
import torch
import core.trainer as trainer
from core.models.model_factory import *
from core.models.model_factory import Model
from core.utils import preprocess
from data_provider.CIKM.data_iterator import *
from data_provider.CIKM.data_iterator import clean_fold, sample, imsave
from core.utils.util import nor, de_nor
from pathlib import Path
......@@ -21,8 +21,9 @@ parser = argparse.ArgumentParser(
)
# training/test
# NOTE(review): this conflicts with the CUDA_VISIBLE_DEVICES="0" assignment at
# the top of the file — the scrape kept both the old and new diff lines;
# confirm which GPU selection is intended.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
parser.add_argument("--is_training", type=int, default=1)
# NOTE(review): '--device' is registered twice below (old and new diff lines
# both kept); argparse raises ArgumentError on the duplicate registration —
# only the "cuda" default is intended.
parser.add_argument('--device', type=str, default='cpu')
parser.add_argument("--device", type=str, default="cuda")
# data
parser.add_argument("--dataset_name", type=str, default="radar")
......@@ -67,9 +68,12 @@ batch_size = args.batch_size
def padding_CIKM_data(frame_data):
# (1, 15, 101, 101, 1)
shape = frame_data.shape
batch_size = shape[0]
seq_length = shape[1]
# (1, 15, 128, 128, 1)
padding_frame_dat = np.zeros(
(batch_size, seq_length, args.img_width, args.img_width, args.img_channel)
)
......@@ -137,18 +141,65 @@ def schedule_sampling(eta, itr):
return eta, real_input_flag
def wrapper_test(model):
    # NOTE(review): diff residue — this is the removed header of the old
    # wrapper_test. The full, current wrapper_test is defined further down in
    # this file and, being later, shadows this stub.
    test_save_root = args.gen_frm_dir
    clean_fold(test_save_root)
def wrapper_train(model):
    """Run the training loop with scheduled sampling.

    Validates every ``args.test_interval`` iterations, checkpoints on each
    improvement, and early-stops (restoring the best checkpoint and running
    the test split) after three consecutive non-improving validations.
    """
    if args.pretrained_model:
        model.load(args.pretrained_model)

    eta = args.sampling_start_value
    best_mse = math.inf
    best_iter = None
    patience = 0
    patience_limit = 3

    for itr in range(1, args.max_iterations + 1):
        # Build one normalized, patch-reshaped training batch.
        batch = sample(batch_size=batch_size)
        batch = padding_CIKM_data(batch)
        batch = preprocess.reshape_patch(batch, args.patch_size)
        batch = nor(batch)

        eta, real_input_flag = schedule_sampling(eta, itr)
        cost = trainer.train(model, batch, real_input_flag, args, itr)

        if itr % args.display_interval == 0:
            print("itr: " + str(itr))
            print("training loss: " + str(cost))

        if itr % args.test_interval != 0:
            continue

        print("validation one ")
        valid_mse = wrapper_valid(model)
        print("validation mse is:", str(valid_mse))

        if valid_mse < best_mse:
            best_mse = valid_mse
            best_iter = itr
            patience = 0
            model.save()
        else:
            patience += 1

        if patience == patience_limit:
            # Restore the best checkpoint and report final test performance.
            model.load()
            test_mse = wrapper_test(model)
            print("the best valid mse is:", str(best_mse))
            print("the test mse is ", str(test_mse))
            break
def wrapper_valid(model):
    """Compute the mean MSE over the validation split.

    NOTE(review): this body still carries unified-diff residue from the
    scrape (hunk markers, duplicated old/new lines) and is not runnable
    as-is.
    """
    loss = 0
    count = 0
    index = 1  # cursor into the validation set, advanced by sample()
    flag = True
    img_mse, ssim = [], []
    # img_mse, ssim = [], []
    # per-frame accumulators over the predicted horizon
    for i in range(args.total_length - args.input_length):
        img_mse.append(0)
        ssim.append(0)
    # for i in range(args.total_length - args.input_length):
    # img_mse.append(0)
    # ssim.append(0)
    # NOTE(review): the np.zeros argument list was elided by the diff hunk
    # marker below — the intended shape is not visible here.
    real_input_flag = np.zeros(
        (
......@@ -161,33 +212,20 @@ def wrapper_test(model):
    )
    output_length = args.total_length - args.input_length
    while flag:
        # NOTE(review): old ("test") and new ("validation") diff lines were
        # both kept; only the "validation" call is intended here.
        dat, (index, b_cup) = sample(batch_size, data_type="test", index=index)
        dat, (index, b_cup) = sample(batch_size, data_type="validation", index=index)
        dat = nor(dat)
        tars = dat[:, -output_length:]  # ground-truth target frames
        ims = padding_CIKM_data(dat)
        print("index is:", str(index))
        ims = preprocess.reshape_patch(ims, args.patch_size)
        img_gen, _ = model.test(ims, real_input_flag)
        img_gen = preprocess.reshape_patch_back(img_gen, args.patch_size)
        img_out = unpadding_CIKM_data(img_gen[:, -output_length:])
        # MSE measured in normalized space, before de-normalization.
        mse = np.mean(np.square(tars - img_out))
        img_out = de_nor(img_out)
        loss = loss + mse
        count = count + 1
        # NOTE(review): the PNG-saving loop below is removed-diff residue from
        # the old wrapper_test; test_save_root is not defined in this function.
        bat_ind = 0
        for ind in range(index - batch_size, index, 1):
            save_fold = test_save_root + "sample_" + str(ind) + "/"
            clean_fold(save_fold)
            for t in range(6, 16, 1):
                imsave(
                    save_fold + "img_" + str(t) + ".png",
                    img_out[bat_ind, t - 6, :, :, 0],
                )
            bat_ind = bat_ind + 1
        if b_cup == args.batch_size - 1:
            pass
        else:
......@@ -196,17 +234,20 @@ def wrapper_test(model):
    return loss / count
def wrapper_valid(model):
# NOTE(review): the two `def` lines here are diff residue (the function was
# renamed wrapper_valid -> wrapper_test in this hunk); only wrapper_test is
# intended.
def wrapper_test(model):
    """Run the model over the test split, save the predicted frames as PNGs
    under args.gen_frm_dir, and return the mean MSE.

    NOTE(review): this body still carries unified-diff residue (hunk markers,
    duplicated old/new lines) and is not runnable as-is.
    """
    test_save_root = args.gen_frm_dir
    clean_fold(test_save_root)  # wipe and recreate the output folder
    loss = 0
    count = 0
    index = 1  # cursor into the test set, advanced by sample()
    flag = True
    img_mse, ssim = [], []
    for i in range(args.total_length - args.input_length):
        img_mse.append(0)
        ssim.append(0)
    # img_mse, ssim = [], []
    # for _ in range(args.total_length - args.input_length):
    # img_mse.append(0)
    # ssim.append(0)
    # Shape: (1, 9, 32, 32, 16)
    # NOTE(review): the rest of the np.zeros argument list was elided by the
    # diff hunk marker below.
    real_input_flag = np.zeros(
        (
            args.batch_size,
......@@ -218,20 +259,33 @@ def wrapper_valid(model):
    )
    output_length = args.total_length - args.input_length
    while flag:
        # NOTE(review): old ("validation") and new ("test") diff lines were
        # both kept; only the "test" call is intended here.
        dat, (index, b_cup) = sample(batch_size, data_type="validation", index=index)
        dat, (index, b_cup) = sample(batch_size, data_type="test", index=index)
        dat = nor(dat)
        tars = dat[:, -output_length:]  # ground-truth target frames
        ims = padding_CIKM_data(dat)
        ims = preprocess.reshape_patch(ims, args.patch_size)
        img_gen, _ = model.test(ims, real_input_flag)
        img_gen = preprocess.reshape_patch_back(img_gen, args.patch_size)
        img_out = unpadding_CIKM_data(img_gen[:, -output_length:])
        # MSE measured in normalized space, before de-normalization.
        mse = np.mean(np.square(tars - img_out))
        # NOTE(review): the old and new accumulation lines were both kept by
        # the scrape — as flattened text this would double-count mse.
        loss = loss + mse
        count = count + 1
        img_out = de_nor(img_out)
        loss += mse
        count += 1
        bat_ind = 0
        print("index is:", index)
        # Persist the predicted horizon (frames 6..15) for each sample.
        for ind in range(index - batch_size, index, 1):
            save_fold = test_save_root + "sample_" + str(ind) + "/"
            clean_fold(save_fold)
            for t in range(6, 16, 1):
                imsave(
                    save_fold + "img_" + str(t) + ".png",
                    img_out[bat_ind, t - 6, :, :, 0],
                )
            bat_ind = bat_ind + 1
        if b_cup == args.batch_size - 1:
            pass
        else:
......@@ -240,76 +294,18 @@ def wrapper_valid(model):
    return loss / count
def wrapper_train(model):
    """Run the training loop with scheduled sampling and early stopping.

    Validates every args.test_interval iterations, checkpoints on each
    improvement, and stops (restoring the best checkpoint and running the
    test split) after `limit` consecutive non-improving validations.

    NOTE(review): duplicate of the wrapper_train that appears earlier in this
    flattened diff; being the later definition, this is the one in effect.
    """
    if args.pretrained_model:
        model.load(args.pretrained_model)
    # load data
    # train_input_handle, test_input_handle = datasets_factory.data_provider(
    #     args.dataset_name, args.train_data_paths, args.valid_data_paths, args.batch_size, args.img_width,
    #     seq_length=args.total_length, is_training=True)
    eta = args.sampling_start_value
    best_mse = math.inf
    tolerate = 0  # consecutive validations without improvement
    limit = 3  # early-stopping patience
    best_iter = None  # NOTE(review): written below but never read
    for itr in range(1, args.max_iterations + 1):
        # Build one padded, patch-reshaped, normalized training batch.
        ims = sample(batch_size=batch_size)
        ims = padding_CIKM_data(ims)
        ims = preprocess.reshape_patch(ims, args.patch_size)
        ims = nor(ims)
        eta, real_input_flag = schedule_sampling(eta, itr)
        cost = trainer.train(model, ims, real_input_flag, args, itr)
        if itr % args.display_interval == 0:
            print("itr: " + str(itr))
            print("training loss: " + str(cost))
        if itr % args.test_interval == 0:
            print("validation one ")
            valid_mse = wrapper_valid(model)
            print("validation mse is:", str(valid_mse))
            if valid_mse < best_mse:
                best_mse = valid_mse
                best_iter = itr
                tolerate = 0
                model.save()  # checkpoint only on improvement
            else:
                tolerate = tolerate + 1
            if tolerate == limit:
                model.load()  # restore the best checkpoint
                test_mse = wrapper_test(model)
                print("the best valid mse is:", str(best_mse))
                print("the test mse is ", str(test_mse))
                break
# --- script entry: rebuild the checkpoint directory, load, and evaluate ---
# FIX: the original ended with bare `chckp_dir.mkdir` / `save_dir.mkdir`
# (attribute access, never called), so the directory was never recreated.
# The diff residue also kept both the old (`chckp_dir`) and new (`save_dir`)
# copies of the rmtree/mkdir sequence, and invoked wrapper_test twice;
# a single, correct sequence is kept here.
save_dir = Path(args.save_dir).parent
if save_dir.exists():
    shutil.rmtree(save_dir)
save_dir.mkdir(parents=True, exist_ok=True)

print("Initializing models")
model = Model(args)
model.load()
test_mse = wrapper_test(model)
print("test mse is:", str(test_mse))
import os
import torch
import torch.nn as nn
from torch.optim import Adam
from core.models import predict
class Model(object):
    """Wrapper around a spatiotemporal prediction network.

    Builds the network selected by ``configs.model_name``, owns its Adam
    optimizer and its sum-reduced MSE+MAE training criterion, and handles
    checkpointing to ``configs.save_dir``.
    """

    def __init__(self, configs):
        self.configs = configs
        # e.g. num_hidden="64,64,64,64" -> [64, 64, 64, 64]
        self.num_hidden = [int(x) for x in configs.num_hidden.split(",")]
        self.num_layers = len(self.num_hidden)
        networks_map = {
            "convlstm": predict.ConvLSTM,
            "predrnn": predict.PredRNN,
            "predrnn_plus": predict.PredRNN_Plus,
            "interact_convlstm": predict.InteractionConvLSTM,
            "interact_predrnn": predict.InteractionPredRNN,
            "interact_predrnn_plus": predict.InteractionPredRNN_Plus,
            "cst_predrnn": predict.CST_PredRNN,
            "sst_predrnn": predict.SST_PredRNN,
            "dst_predrnn": predict.DST_PredRNN,
            "interact_dst_predrnn": predict.InteractionDST_PredRNN,
        }
        # Idiom fix: was `if not configs.model_name in networks_map`.
        if configs.model_name not in networks_map:
            raise ValueError("Name of network unknown %s" % configs.model_name)
        Network = networks_map[configs.model_name]
        self.network = Network(self.num_layers, self.num_hidden, configs).to(
            configs.device
        )
        if self.configs.is_parallel:
            self.network = nn.DataParallel(self.network)
        self.optimizer = Adam(self.network.parameters(), lr=configs.lr)
        # FIX: `size_average=False` is deprecated; reduction="sum" is the
        # exact modern equivalent (resolves the old "size_average to sum" TODO).
        self.MSE_criterion = nn.MSELoss(reduction="sum")
        self.MAE_criterion = nn.L1Loss(reduction="sum")

    def save(self, ite=None):
        """Serialize the network weights to configs.save_dir.

        `ite` is kept for interface compatibility; it is unused.
        """
        stats = {"net_param": self.network.state_dict()}
        torch.save(stats, self.configs.save_dir)
        print(f"Saving model to {self.configs.save_dir}")

    def load(self):
        """Load weights from configs.save_dir if present, else train from scratch."""
        if os.path.exists(self.configs.save_dir):
            # map_location lets a checkpoint saved on one device load on
            # another (e.g. a CUDA checkpoint on a CPU-only machine).
            stats = torch.load(
                self.configs.save_dir, map_location=self.configs.device
            )
            self.network.load_state_dict(stats["net_param"])
            print("Model loaded")
        else:
            print("Training from scratch")

    def train(self, frames, mask):
        """Run one optimization step.

        frames: ground-truth sequence batch; mask: scheduled-sampling flags.
        Returns the scalar loss as a numpy value.
        """
        frames_tensor = torch.FloatTensor(frames).to(self.configs.device)
        mask_tensor = torch.FloatTensor(mask).to(self.configs.device)
        self.optimizer.zero_grad()
        next_frames = self.network(frames_tensor, mask_tensor)
        # Summed MSE + MAE against all frames after the first time step.
        loss = self.MSE_criterion(
            next_frames, frames_tensor[:, 1:]
        ) + self.MAE_criterion(next_frames, frames_tensor[:, 1:])
        loss.backward()
        self.optimizer.step()
        return loss.detach().cpu().numpy()

    def test(self, frames, mask):
        """Forward pass without an optimizer step.

        Returns (predicted_frames, loss), both as numpy arrays.
        """
        frames_tensor = torch.FloatTensor(frames).to(self.configs.device)
        mask_tensor = torch.FloatTensor(mask).to(self.configs.device)
        next_frames = self.network(frames_tensor, mask_tensor)
        loss = self.MSE_criterion(
            next_frames, frames_tensor[:, 1:]
        ) + self.MAE_criterion(next_frames, frames_tensor[:, 1:])
        return next_frames.detach().cpu().numpy(), loss.detach().cpu().numpy()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment