Commit 37544916 authored by Marco Di Francesco's avatar Marco Di Francesco 🍉
Browse files

Add docker file

parent 984d3261
FROM python:3.9-slim
WORKDIR /app
# Copy the dependency manifest first so the pip layer is cached and
# only re-runs when requirements.txt itself changes.
COPY requirements.txt requirements.txt
# --no-cache-dir keeps the pip download cache out of the image layer.
RUN pip install --no-cache-dir -r requirements.txt
# NOTE(review): the source tree is never COPY'd into the image — this
# assumes /app is bind-mounted at run time; confirm or add `COPY . .`.
# Exec form runs python as PID 1 so it receives SIGTERM directly.
CMD ["python", "cikm_inter_dst_predrnn_run_taasss.py"]
\ No newline at end of file
......@@ -23,7 +23,7 @@ parser = argparse.ArgumentParser(
# training/test
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
parser.add_argument("--is_training", type=int, default=1)
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--device", type=str, default="cpu") # cuda
# data
parser.add_argument("--dataset_name", type=str, default="radar")
......
......@@ -22,10 +22,9 @@ parser = argparse.ArgumentParser(
description="PyTorch video prediction model - DST PredRNN"
)
# training/test
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
parser.add_argument("--is_training", type=int, default=1)
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--device", type=str, default="cuda") # cuda
# data
parser.add_argument("--dataset_name", type=str, default="radar")
......@@ -34,8 +33,8 @@ parser.add_argument("--is_parallel", type=bool, default=False)
parser.add_argument("--save_dir", type=str, default="checkpoints/model.ckpt")
parser.add_argument("--gen_frm_dir", type=str, default="dataset_generated/")
parser.add_argument("--input_length", type=int, default=5)
parser.add_argument("--total_length", type=int, default=15) # 25
parser.add_argument("--img_width", type=int, default=128)
parser.add_argument("--total_length", type=int, default=25) # 15
parser.add_argument("--img_width", type=int, default=256) # 512
parser.add_argument("--img_channel", type=int, default=1)
# model
......@@ -58,8 +57,8 @@ parser.add_argument("--lr", type=float, default=0.001)
parser.add_argument("--reverse_input", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=1) # 4
parser.add_argument("--max_iterations", type=int, default=80000)
parser.add_argument("--display_interval", type=int, default=200)
parser.add_argument("--test_interval", type=int, default=1) # TO 2000
parser.add_argument("--display_interval", type=int, default=1) # 200
parser.add_argument("--test_interval", type=int, default=1) # 2000
parser.add_argument("--snapshot_interval", type=int, default=5000)
parser.add_argument("--num_save_samples", type=int, default=10)
parser.add_argument("--n_gpu", type=int, default=0)
......@@ -154,36 +153,63 @@ def schedule_sampling(eta: float, itr: int):
return eta, real_input_flag
def change_taasss_dims(a: np.ndarray) -> np.ndarray:
    """Rearrange a batch from (25, 1, 1, 480, 480) to (1, 25, 480, 480, 1).

    All singleton axes are dropped first, then a leading batch axis and a
    trailing channel axis are re-inserted.
    """
    squeezed = np.squeeze(a)
    batched = np.expand_dims(squeezed, axis=0)
    return np.expand_dims(batched, axis=4)
def padding_taasss(array: np.ndarray) -> np.ndarray:
    """Center-pad a frame batch into a zero canvas.

    Takes the top-left (240, 240) crop of each frame and places it at
    offset 8 in a zero-filled canvas:
    (1, 25, H, W, 1) with H, W >= 240  ->  (1, 25, 256, 256, 1).

    The commented-out lines are the 512- and 128-pixel variants of the
    same scheme, kept for experimentation.

    NOTE(review): the original docstring claimed 512->480 shapes, which
    does not match the active code path.
    """
    # zeros = np.zeros((1, 25, 512, 512, 1))
    # zeros[:, :, 16:496, 16:496, :] = array
    zeros = np.zeros((1, 25, 256, 256, 1))
    zeros[:, :, 8:248, 8:248, :] = array[:, :, :240, :240, :]
    # zeros = np.zeros((1, 25, 128, 128, 1))
    # zeros[:, :, 4:124, 4:124, :] = array[:, :, :120, :120, :]
    return zeros
def unpadding_taasss(array: np.ndarray) -> np.ndarray:
    """Inverse of `padding_taasss`: crop the centered content back out.

    (1, T, 256, 256, 1)  ->  (1, T, 240, 240, 1)

    The commented-out lines are the 512- and 128-pixel variants.

    NOTE(review): the original docstring claimed 480->512 shapes, which
    does not match the active slice.
    """
    # return array[:, :, 16:496, 16:496, :]
    return array[:, :, 8:248, 8:248, :]
    # return array[:, :, 4:124, 4:124, :]
def wrapper_train(model: Model):
# if args.pretrained_model:
# model.load(args.pretrained_model)
eta = args.sampling_start_value
# eta = args.sampling_start_value
best_mse = math.inf
tolerate = 0
limit = 3
train_model_iter = get_batcher()
for itr in range(1, args.max_iterations + 1):
train_batch, _, train_mask = next(train_model_iter)
train_batch = change_taasss_shape(train_batch)
print("IMS", train_batch.shape)
ims = crop_taasss(train_batch)
ims = preprocess.reshape_patch(ims, args.patch_size)
ims = nor(ims)
train_mask = change_taasss_shape(train_mask)
train_mask = crop_taasss(train_mask)
train_mask = preprocess.reshape_patch(train_mask, args.patch_size)
# eta, train_mask = schedule_sampling(eta, itr)
cost = trainer.train(model, ims, train_mask, args, itr)
imgs, _, masks = next(train_model_iter)
imgs = change_taasss_dims(imgs)
imgs = padding_taasss(imgs)
imgs = preprocess.reshape_patch(imgs, args.patch_size)
imgs = nor(imgs)
masks = change_taasss_dims(masks)
masks = padding_taasss(masks)
masks = preprocess.reshape_patch(masks, args.patch_size)
# eta, masks = schedule_sampling(eta, itr)
cost = trainer.train(model, imgs, masks, args)
if itr % args.display_interval == 0:
print("itr: " + str(itr))
print("training loss: " + str(cost))
......@@ -246,19 +272,6 @@ def wrapper_valid(model: Model):
return loss / count
def crop_taasss(array):
    """Crop a batch to the first 15 frames and the top-left 128x128 window.

    (B, T, H, W, C) -> (B, min(T, 15), min(H, 128), min(W, 128), C).

    NOTE(review): the original docstring said "Add padding", but this
    function crops; it never pads.
    """
    return array[:, :15, :128, :128, :]
def uncrop_taasss(array):
    """Embed a (1, 10, 128, 128, 1) output into a zero (1, 10, 480, 480, 1) canvas.

    Inverse of `crop_taasss` for the generated frames: the cropped region
    is written back to the top-left corner and the rest stays zero.

    NOTE(review): the original docstring said "Remove padding", but this
    function zero-pads; it is the opposite operation.
    """
    # They had to go from 101 to 128, we did 480 to 480
    zeros = np.zeros((1, 10, 480, 480, 1))
    zeros[:, :, :128, :128, :] = array
    return zeros
def wrapper_test(model: Model):
test_save_root = args.gen_frm_dir
clean_fold(test_save_root)
......@@ -295,12 +308,12 @@ def wrapper_test(model: Model):
dat = nor(dat)
tars = dat[:, -output_length:]
ims = crop_taasss(dat)
ims = padding_taasss(dat)
ims = preprocess.reshape_patch(ims, args.patch_size)
ims = ims.astype(np.float64)
img_gen, _ = model.test(ims, real_input_flag)
img_gen = preprocess.reshape_patch_back(img_gen, args.patch_size)
img_out = uncrop_taasss(img_gen[:, -output_length:])
img_out = unpadding_taasss(img_gen[:, -output_length:])
mse = np.mean(np.square(tars - img_out))
print(index, "MSE", mse)
......
......@@ -26,6 +26,7 @@ class Model(object):
if configs.model_name not in networks_map:
raise ValueError("Name of network unknown %s" % configs.model_name)
Network = networks_map[configs.model_name]
print("BEF NET")
self.network = Network(self.num_layers, self.num_hidden, configs).to(
configs.device
)
......@@ -53,14 +54,15 @@ class Model(object):
print("Training from scratch")
def train(self, frames, mask):
frames_tensor = torch.FloatTensor(frames).to(self.configs.device)
mask_tensor = torch.FloatTensor(mask).to(self.configs.device)
# frames_tensor = torch.FloatTensor(frames).cuda()
# mask_tensor = torch.FloatTensor(mask).cuda()
self.optimizer.zero_grad()
print("BEFORE NET")
next_frames = self.network(frames_tensor, mask_tensor)
print("AFTER NET")
loss = self.MSE_criterion(
next_frames, frames_tensor[:, 1:]
) + self.MAE_criterion(next_frames, frames_tensor[:, 1:])
......
......@@ -65,7 +65,9 @@ class ConvLSTM(nn.Module):
c_t = []
for i in range(self.num_layers):
zeros = torch.zeros([batch, self.num_hidden[i], height, width]).to(self.configs.device)
zeros = torch.zeros([batch, self.num_hidden[i], height, width]).to(
self.configs.device
)
# zeros = torch.zeros([batch, self.num_hidden[i], height, width]).cuda()
h_t.append(zeros)
c_t.append(zeros)
......@@ -145,12 +147,16 @@ class PredRNN(nn.Module):
c_t = []
for i in range(self.num_layers):
zeros = torch.zeros([batch, self.num_hidden[i], height, width]).to(self.configs.device)
zeros = torch.zeros([batch, self.num_hidden[i], height, width]).to(
self.configs.device
)
# zeros = torch.zeros([batch, self.num_hidden[i], height, width]).cuda()
h_t.append(zeros)
c_t.append(zeros)
memory = torch.zeros([batch, self.num_hidden[0], height, width]).to(self.configs.device)
memory = torch.zeros([batch, self.num_hidden[0], height, width]).to(
self.configs.device
)
# memory = torch.zeros([batch, self.num_hidden[0], height, width]).cuda()
for t in range(self.configs.total_length - 1):
......@@ -246,15 +252,21 @@ class PredRNN_Plus(nn.Module):
c_t_wide = []
for i in range(self.num_layers):
zeros = torch.zeros([batch, self.num_hidden[i], height, width]).to(self.configs.device)
zeros = torch.zeros([batch, self.num_hidden[i], height, width]).to(
self.configs.device
)
# zeros = torch.zeros([batch, self.num_hidden[i], height, width]).cuda()
# num_hidden_in = self.deep_num_hidden[i-1]
h_t.append(zeros)
c_t.append(zeros)
memory = torch.zeros([batch, self.num_hidden[-1], height, width]).to(self.configs.device)
memory = torch.zeros([batch, self.num_hidden[-1], height, width]).to(
self.configs.device
)
# memory = torch.zeros([batch, self.num_hidden[-1], height, width]).cuda()
z_t = torch.zeros([batch, self.num_hidden[0], height, width]).to(self.configs.device)
z_t = torch.zeros([batch, self.num_hidden[0], height, width]).to(
self.configs.device
)
if is_training:
seq_length = self.configs.total_length
......@@ -342,7 +354,9 @@ class InteractionConvLSTM(nn.Module):
c_t = []
for i in range(self.num_layers):
zeros = torch.zeros([batch, self.num_hidden[i], height, width]).to(self.configs.device)
zeros = torch.zeros([batch, self.num_hidden[i], height, width]).to(
self.configs.device
)
# zeros = torch.zeros([batch, self.num_hidden[i], height, width]).cuda()
h_t.append(zeros)
c_t.append(zeros)
......@@ -423,12 +437,16 @@ class InteractionPredRNN(nn.Module):
c_t = []
for i in range(self.num_layers):
zeros = torch.zeros([batch, self.num_hidden[i], height, width]).to(self.configs.device)
zeros = torch.zeros([batch, self.num_hidden[i], height, width]).to(
self.configs.device
)
# zeros = torch.zeros([batch, self.num_hidden[i], height, width]).cuda()
h_t.append(zeros)
c_t.append(zeros)
memory = torch.zeros([batch, self.num_hidden[0], height, width]).to(self.configs.device)
memory = torch.zeros([batch, self.num_hidden[0], height, width]).to(
self.configs.device
)
# memory = torch.zeros([batch, self.num_hidden[0], height, width]).cuda()
for t in range(self.configs.total_length - 1):
......@@ -525,15 +543,21 @@ class InteractionPredRNN_Plus(nn.Module):
c_t_wide = []
for i in range(self.num_layers):
zeros = torch.zeros([batch, self.num_hidden[i], height, width]).to(self.configs.device)
zeros = torch.zeros([batch, self.num_hidden[i], height, width]).to(
self.configs.device
)
# zeros = torch.zeros([batch, self.num_hidden[i], height, width]).cuda()
# num_hidden_in = self.deep_num_hidden[i-1]
h_t.append(zeros)
c_t.append(zeros)
memory = torch.zeros([batch, self.num_hidden[-1], height, width]).to(self.configs.device)
memory = torch.zeros([batch, self.num_hidden[-1], height, width]).to(
self.configs.device
)
memory = torch.zeros([batch, self.num_hidden[-1], height, width]).cuda()
z_t = torch.zeros([batch, self.num_hidden[0], height, width]).to(self.configs.device)
z_t = torch.zeros([batch, self.num_hidden[0], height, width]).to(
self.configs.device
)
# z_t = torch.zeros([batch, self.num_hidden[0], height, width]).cuda()
if is_training:
......@@ -628,16 +652,22 @@ class DST_PredRNN(nn.Module):
c_t_wide = []
c_t_history = []
for i in range(self.num_layers):
zeros = torch.zeros([batch, self.num_hidden[i], height, width]).to(self.configs.device)
zeros = torch.zeros([batch, self.num_hidden[i], height, width]).to(
self.configs.device
)
# zeros = torch.zeros([batch, self.num_hidden[i], height, width]).cuda()
# num_hidden_in = self.deep_num_hidden[i-1]
h_t.append(zeros)
c_t.append(zeros)
c_t_history.append(zeros.unsqueeze(1))
memory = torch.zeros([batch, self.num_hidden[-1], height, width]).to(self.configs.device)
memory = torch.zeros([batch, self.num_hidden[-1], height, width]).to(
self.configs.device
)
# memory = torch.zeros([batch, self.num_hidden[-1], height, width]).cuda()
z_t = torch.zeros([batch, self.num_hidden[0], height, width]).to(self.configs.device)
z_t = torch.zeros([batch, self.num_hidden[0], height, width]).to(
self.configs.device
)
# z_t = torch.zeros([batch, self.num_hidden[0], height, width]).cuda()
if is_training:
......@@ -732,16 +762,22 @@ class SST_PredRNN(nn.Module):
c_t_wide = []
c_t_history = []
for i in range(self.num_layers):
zeros = torch.zeros([batch, self.num_hidden[i], height, width]).to(self.configs.device)
zeros = torch.zeros([batch, self.num_hidden[i], height, width]).to(
self.configs.device
)
# zeros = torch.zeros([batch, self.num_hidden[i], height, width]).cuda()
# num_hidden_in = self.deep_num_hidden[i-1]
h_t.append(zeros)
c_t.append(zeros)
c_t_history.append(zeros.unsqueeze(1))
memory = torch.zeros([batch, self.num_hidden[-1], height, width]).to(self.configs.device)
memory = torch.zeros([batch, self.num_hidden[-1], height, width]).to(
self.configs.device
)
memory = torch.zeros([batch, self.num_hidden[-1], height, width]).cuda()
z_t = torch.zeros([batch, self.num_hidden[0], height, width]).to(self.configs.device)
z_t = torch.zeros([batch, self.num_hidden[0], height, width]).to(
self.configs.device
)
z_t = torch.zeros([batch, self.num_hidden[0], height, width]).cuda()
if is_training:
......@@ -836,16 +872,22 @@ class CST_PredRNN(nn.Module):
c_t_wide = []
c_t_history = []
for i in range(self.num_layers):
zeros = torch.zeros([batch, self.num_hidden[i], height, width]).to(self.configs.device)
zeros = torch.zeros([batch, self.num_hidden[i], height, width]).to(
self.configs.device
)
# zeros = torch.zeros([batch, self.num_hidden[i], height, width]).cuda()
# num_hidden_in = self.deep_num_hidden[i-1]
h_t.append(zeros)
c_t.append(zeros)
c_t_history.append(zeros.unsqueeze(1))
memory = torch.zeros([batch, self.num_hidden[-1], height, width]).to(self.configs.device)
memory = torch.zeros([batch, self.num_hidden[-1], height, width]).to(
self.configs.device
)
# memory = torch.zeros([batch, self.num_hidden[-1], height, width]).cuda()
z_t = torch.zeros([batch, self.num_hidden[0], height, width]).to(self.configs.device)
z_t = torch.zeros([batch, self.num_hidden[0], height, width]).to(
self.configs.device
)
# z_t = torch.zeros([batch, self.num_hidden[0], height, width]).cuda()
if is_training:
......@@ -887,8 +929,8 @@ class CST_PredRNN(nn.Module):
class InteractionDST_PredRNN(nn.Module):
def __init__(self, num_layers, num_hidden, configs):
print("IN NET")
super(InteractionDST_PredRNN, self).__init__()
self.configs = configs
self.frame_channel = (
configs.img_channel * configs.patch_size * configs.patch_size
......@@ -933,7 +975,7 @@ class InteractionDST_PredRNN(nn.Module):
batch = frames.shape[0]
height = frames.shape[3]
width = frames.shape[4]
print("FORWARD")
next_frames = []
h_t = []
c_t = []
......@@ -941,16 +983,22 @@ class InteractionDST_PredRNN(nn.Module):
c_t_wide = []
c_t_history = []
for i in range(self.num_layers):
zeros = torch.zeros([batch, self.num_hidden[i], height, width]).to(self.configs.device)
zeros = torch.zeros([batch, self.num_hidden[i], height, width]).to(
self.configs.device
)
# zeros = torch.zeros([batch, self.num_hidden[i], height, width]).cuda()
# num_hidden_in = self.deep_num_hidden[i-1]
h_t.append(zeros)
c_t.append(zeros)
c_t_history.append(zeros.unsqueeze(1))
memory = torch.zeros([batch, self.num_hidden[-1], height, width]).to(self.configs.device)
memory = torch.zeros([batch, self.num_hidden[-1], height, width]).to(
self.configs.device
)
# memory = torch.zeros([batch, self.num_hidden[-1], height, width]).cuda()
z_t = torch.zeros([batch, self.num_hidden[0], height, width]).to(self.configs.device)
z_t = torch.zeros([batch, self.num_hidden[0], height, width]).to(
self.configs.device
)
# z_t = torch.zeros([batch, self.num_hidden[0], height, width]).cuda()
if is_training:
......@@ -986,5 +1034,5 @@ class InteractionDST_PredRNN(nn.Module):
next_frames = (
torch.stack(next_frames, dim=0).permute(1, 0, 3, 4, 2).contiguous()
)
print("END F")
return next_frames
from core.models.model_factory import Model
import os.path
import datetime
import numpy as np
from core.utils import preprocess
def train(model: Model, ims, real_input_flag, configs, itr):
def train(model: Model, ims, real_input_flag, configs, itr=None):
cost = model.train(ims, real_input_flag)
if configs.reverse_input:
ims_rev = np.flip(ims, axis=1).copy()
......
......@@ -4,13 +4,13 @@ import copy
import os
def nor(frames: np.ndarray) -> np.ndarray:
    """Normalise pixel intensities from [0, 255] into [0, 1] as float32."""
    return frames.astype(np.float32) / 255.0
def de_nor(frames):
def de_nor(frames: np.ndarray) -> np.ndarray:
"""Pixels * 255"""
new_frames = copy.deepcopy(frames)
new_frames *= 255.0
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment