Commit 984d3261 authored by Marco Di Francesco's avatar Marco Di Francesco 🍉
Browse files

Add training with real train masks

parent 9face97f
......@@ -155,11 +155,12 @@ def wrapper_train(model):
limit = 3
best_iter = None
for itr in range(1, args.max_iterations + 1):
ims = sample(batch_size=batch_size)
print("G1", ims.shape)
ims = padding_CIKM_data(ims)
ims = preprocess.reshape_patch(ims, args.patch_size)
print("G2", ims.shape)
ims = nor(ims)
eta, real_input_flag = schedule_sampling(eta, itr)
......@@ -309,5 +310,5 @@ os.makedirs(args.gen_frm_dir)
print("Initializing models")
model = Model(args)
model.load()
test_mse = wrapper_test(model)
# wrapper_train(model)
# test_mse = wrapper_test(model)
wrapper_train(model)
......@@ -59,12 +59,11 @@ parser.add_argument("--reverse_input", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=1) # 4
parser.add_argument("--max_iterations", type=int, default=80000)
parser.add_argument("--display_interval", type=int, default=200)
parser.add_argument("--test_interval", type=int, default=2000)
parser.add_argument("--test_interval", type=int, default=1) # TO 2000
parser.add_argument("--snapshot_interval", type=int, default=5000)
parser.add_argument("--num_save_samples", type=int, default=10)
parser.add_argument("--n_gpu", type=int, default=0)
args = parser.parse_args()
batch_size = args.batch_size
......@@ -88,30 +87,18 @@ def get_batcher():
sort_meta,
outlier_mask,
shuffle=False,
batch_size=args.batch_size, # TODO: UPDATE ITTTTTTTT
batch_size=args.batch_size, # TODO: UPDATE FROM 1 TO 4
filter_threshold=0,
)
# def padding_CIKM_data(frame_data):
# # (1, 15, 101, 101, 1)
# shape = frame_data.shape
# batch_size = shape[0]
# seq_length = shape[1]
# # (1, 15, 128, 128, 1)
# padding_frame_dat = np.zeros(
# (batch_size, seq_length, args.img_width, args.img_width, args.img_channel)
# )
# padding_frame_dat[:, :, 13:-14, 13:-14, :] = frame_data
# return padding_frame_dat
# def unpadding_CIKM_data(padding_frame_dat):
# return padding_frame_dat[:, :, 13:-14, 13:-14, :]
def schedule_sampling(eta, itr):
def schedule_sampling(eta: float, itr: int):
"""
Return
- Eta: 1 down to 0
- Itr: (1, 9, 32, 32, 16)
"""
print("ITR", itr)
zeros = np.zeros(
(
args.batch_size,
......@@ -167,111 +154,112 @@ def schedule_sampling(eta, itr):
return eta, real_input_flag
# def wrapper_train(model):
# if args.pretrained_model:
# model.load(args.pretrained_model)
# # load data
# # train_input_handle, test_input_handle = datasets_factory.data_provider(
# # args.dataset_name, args.train_data_paths, args.valid_data_paths, args.batch_size, args.img_width,
# # seq_length=args.total_length, is_training=True)
# eta = args.sampling_start_value
# best_mse = math.inf
# tolerate = 0
# limit = 3
# best_iter = None
# for itr in range(1, args.max_iterations + 1):
# ims = sample(batch_size=batch_size)
# ims = padding_CIKM_data(ims)
# ims = preprocess.reshape_patch(ims, args.patch_size)
# ims = nor(ims)
# eta, real_input_flag = schedule_sampling(eta, itr)
# cost = trainer.train(model, ims, real_input_flag, args, itr)
# if itr % args.display_interval == 0:
# print("itr: " + str(itr))
# print("training loss: " + str(cost))
# if itr % args.test_interval == 0:
# print("validation one ")
# valid_mse = wrapper_valid(model)
# print("validation mse is:", str(valid_mse))
# if valid_mse < best_mse:
# best_mse = valid_mse
# best_iter = itr
# tolerate = 0
# model.save()
# else:
# tolerate = tolerate + 1
# if tolerate == limit:
# model.load()
# test_mse = wrapper_test(model)
# print("the best valid mse is:", str(best_mse))
# print("the test mse is ", str(test_mse))
# break
# def wrapper_valid(model):
# loss = 0
# count = 0
# index = 1
# flag = True
# # img_mse, ssim = [], []
# # for i in range(args.total_length - args.input_length):
# # img_mse.append(0)
# # ssim.append(0)
# real_input_flag = np.zeros(
# (
# args.batch_size,
# args.total_length - args.input_length - 1,
# args.img_width // args.patch_size,
# args.img_width // args.patch_size,
# args.patch_size ** 2 * args.img_channel,
# )
# )
# output_length = args.total_length - args.input_length
# while flag:
# dat, (index, b_cup) = sample(batch_size, data_type="validation", index=index)
# dat = nor(dat)
# tars = dat[:, -output_length:]
# ims = padding_CIKM_data(dat)
# ims = preprocess.reshape_patch(ims, args.patch_size)
# img_gen, _ = model.test(ims, real_input_flag)
# img_gen = preprocess.reshape_patch_back(img_gen, args.patch_size)
# img_out = unpadding_CIKM_data(img_gen[:, -output_length:])
# mse = np.mean(np.square(tars - img_out))
# loss = loss + mse
# count = count + 1
# if b_cup == args.batch_size - 1:
# pass
# else:
# flag = False
# return loss / count
def change_taasss_shape(a: np.ndarray) -> np.ndarray:
    """Reorder a (seq, 1, 1, H, W) batch into model layout (1, seq, H, W, 1).

    E.g. (25, 1, 1, 480, 480) -> (1, 25, 480, 480, 1).

    The original used np.squeeze(a), which removes EVERY singleton axis:
    for a length-1 sequence the result was 2-D and the final
    expand_dims(axis=4) raised an AxisError. Squeezing only the two
    known singleton axes keeps the function correct for any seq length.
    """
    a = np.squeeze(a, axis=(1, 2))      # (seq, H, W)
    a = np.expand_dims(a, axis=0)       # add leading batch axis -> (1, seq, H, W)
    return np.expand_dims(a, axis=4)    # add trailing channel axis -> (1, seq, H, W, 1)
def wrapper_train(model: Model):
    """Train `model` on real train masks, validating every `test_interval` iterations.

    Saves on each validation improvement; stops early after `limit`
    consecutive non-improving validations and reports the test MSE.
    """
    # if args.pretrained_model:
    #     model.load(args.pretrained_model)
    eta = args.sampling_start_value  # kept for the (disabled) scheduled sampling below
    best_mse = math.inf
    tolerate = 0
    limit = 3
    batcher = get_batcher()
    for itr in range(1, args.max_iterations + 1):
        frames, _, mask = next(batcher)
        frames = change_taasss_shape(frames)
        print("IMS", frames.shape)
        ims = nor(preprocess.reshape_patch(crop_taasss(frames), args.patch_size))
        mask = preprocess.reshape_patch(crop_taasss(change_taasss_shape(mask)), args.patch_size)
        # eta, mask = schedule_sampling(eta, itr)
        cost = trainer.train(model, ims, mask, args, itr)
        if itr % args.display_interval == 0:
            print("itr: " + str(itr))
            print("training loss: " + str(cost))
        if itr % args.test_interval != 0:
            continue  # guard clause: only validate on the interval
        print("validation one ")
        valid_mse = wrapper_valid(model)
        print("validation mse is:", str(valid_mse))
        if valid_mse < best_mse:
            best_mse = valid_mse
            tolerate = 0
            model.save()
        else:
            tolerate = tolerate + 1
        if tolerate == limit:
            # Early stop: reload the best checkpoint and report test error.
            model.load()
            test_mse = wrapper_test(model)
            print("the best valid mse is:", str(best_mse))
            print("the test mse is ", str(test_mse))
            break
def wrapper_valid(model: Model):
    """Return the mean per-batch MSE over the validation split."""
    total = 0
    batches = 0
    index = 1
    # All-zero sampling flag: the decoder never sees ground truth frames.
    real_input_flag = np.zeros(
        (
            args.batch_size,
            args.total_length - args.input_length - 1,
            args.img_width // args.patch_size,
            args.img_width // args.patch_size,
            args.patch_size ** 2 * args.img_channel,
        )
    )
    output_length = args.total_length - args.input_length
    while True:
        dat, (index, b_cup) = sample(batch_size, data_type="validation", index=index)
        dat = nor(dat)
        tars = dat[:, -output_length:]
        patched = preprocess.reshape_patch(padding_CIKM_data(dat), args.patch_size)
        img_gen, _ = model.test(patched, real_input_flag)
        img_gen = preprocess.reshape_patch_back(img_gen, args.patch_size)
        img_out = unpadding_CIKM_data(img_gen[:, -output_length:])
        total += np.mean(np.square(tars - img_out))
        batches += 1
        if b_cup != args.batch_size - 1:
            break  # last (partial) batch of the split reached
    return total / batches
def crop_taasss(array):
    """Crop to the first 15 frames and the top-left 128x128 spatial patch.

    (The original docstring said "Add padding", but this slices, it does
    not pad — uncrop_taasss is the padding counterpart.)
    """
    return array[:, 0:15, 0:128, 0:128, :]
def uncrop_taasss(array):
    """Undo crop_taasss by zero-padding the spatial dims back to 480x480.

    The cropped content is placed in the top-left corner; everything
    outside it is zero. Batch, sequence, channel and crop sizes are
    taken from the input instead of being hard-coded to (1, 10, ..., 1)
    and 128, so any batch/sequence length round-trips correctly.
    Spatial size 480 is kept hard-coded — it is the full TAASSS frame
    size the dataset uses (480 -> 480, unlike CIKM's 101 -> 128).
    """
    batch, seq = array.shape[0], array.shape[1]
    height, width, channels = array.shape[2], array.shape[3], array.shape[4]
    padded = np.zeros((batch, seq, 480, 480, channels))
    padded[:, :, :height, :width, :] = array
    return padded
def wrapper_test(model):
def wrapper_test(model: Model):
test_save_root = args.gen_frm_dir
clean_fold(test_save_root)
loss = 0
......@@ -279,11 +267,6 @@ def wrapper_test(model):
# index = 1
flag = True
# img_mse, ssim = [], []
# for _ in range(args.total_length - args.input_length):
# img_mse.append(0)
# ssim.append(0)
# Shape: (1, 9, 32, 32, 16)
real_input_flag = np.zeros(
(
......@@ -300,13 +283,11 @@ def wrapper_test(model):
b_cup = 0 # ?????????
while flag:
index += 1
# print("Sample is:", index)
# dat, (index, b_cup) = sample(batch_size, data_type="test", index=index)
print("Sample is:", index)
train_batch, sample_datetimes, train_mask = next(train_model_iter)
# Just because they called it this way
dat = train_batch
# (1, 15, 101, 101, 1)
# (25, 1, 1, 480, 480) to (1, 25, 480, 480, 1)
dat = np.squeeze(dat)
dat = np.expand_dims(dat, axis=0)
......@@ -337,7 +318,6 @@ def wrapper_test(model):
img_out[bat_ind, t - 6, :, :, 0],
)
bat_ind = bat_ind + 1
if b_cup == args.batch_size - 1:
pass
else:
......@@ -357,6 +337,7 @@ os.makedirs(args.gen_frm_dir)
print("Initializing models")
model = Model(args)
print("MODELTYPE", type(model))
model.load()
test_mse = wrapper_test(model)
# wrapper_train(model)
# test_mse = wrapper_test(model)
wrapper_train(model)
from core.models.model_factory import Model
import os.path
import datetime
import numpy as np
from core.utils import preprocess
def train(model, ims, real_input_flag, configs, itr):
def train(model: Model, ims, real_input_flag, configs, itr):
cost = model.train(ims, real_input_flag)
if configs.reverse_input:
ims_rev = np.flip(ims, axis=1).copy()
......@@ -12,5 +13,3 @@ def train(model, ims, real_input_flag, configs, itr):
cost = cost / 2
return cost
......@@ -5,11 +5,13 @@ import os
def nor(frames):
    """Normalise pixel values from [0, 255] into [0, 1] as float32."""
    return frames.astype(np.float32) / 255.0
def de_nor(frames):
"""Pixels * 255"""
new_frames = copy.deepcopy(frames)
new_frames *= 255.0
new_frames = new_frames.astype(np.uint8)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment