Commit ad017f30 authored by Marco Di Francesco's avatar Marco Di Francesco 🍉
Browse files

Update inter dst predrnn taasss

parent 102819f5
FROM nvcr.io/nvidia/pytorch:21.05-py3
WORKDIR /app
# Install OpenCV runtime dependencies (ffmpeg, libsm6, libxext6 are required by cv2).
# DEBIAN_FRONTEND is set only for this RUN so the noninteractive flag does not
# leak into the runtime environment of the final image.
RUN DEBIAN_FRONTEND=noninteractive apt-get update \
    && apt-get install -y --no-install-recommends \
        ffmpeg \
        libsm6 \
        libxext6 \
    && rm -rf /var/lib/apt/lists/*
# Copy only the dependency manifest first so the pip layer stays cached
# until requirements.txt itself changes.
COPY requirements.txt requirements.txt
RUN pip install --no-cache-dir -r requirements.txt
......
......@@ -156,14 +156,12 @@ def wrapper_train(model):
best_iter = None
for itr in range(1, args.max_iterations + 1):
ims = sample(batch_size=batch_size)
print("G1", ims.shape)
ims = padding_CIKM_data(ims)
ims = preprocess.reshape_patch(ims, args.patch_size)
print("G2", ims.shape)
ims = nor(ims)
eta, real_input_flag = schedule_sampling(eta, itr)
print("IMS", ims.shape)
cost = trainer.train(model, ims, real_input_flag, args, itr)
if itr % args.display_interval == 0:
......@@ -213,7 +211,6 @@ def wrapper_valid(model):
)
output_length = args.total_length - args.input_length
while flag:
dat, (index, b_cup) = sample(batch_size, data_type="validation", index=index)
dat = nor(dat)
tars = dat[:, -output_length:]
......
......@@ -22,7 +22,6 @@ parser = argparse.ArgumentParser(
description="PyTorch video prediction model - DST PredRNN"
)
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
parser.add_argument("--is_training", type=int, default=1)
parser.add_argument("--device", type=str, default="cuda") # cuda
......@@ -34,7 +33,7 @@ parser.add_argument("--save_dir", type=str, default="checkpoints/model.ckpt")
parser.add_argument("--gen_frm_dir", type=str, default="dataset_generated/")
parser.add_argument("--input_length", type=int, default=5)
parser.add_argument("--total_length", type=int, default=25) # 15
parser.add_argument("--img_width", type=int, default=256) # 512
parser.add_argument("--img_width", type=int, default=512)
parser.add_argument("--img_channel", type=int, default=1)
# model
......@@ -81,7 +80,7 @@ def get_batcher():
metadata["end_datetime"] = pd.to_datetime(metadata["end_datetime"])
sort_meta = metadata.sample(frac=1)
return infinite_batcher(
batcher = infinite_batcher(
all_data,
sort_meta,
outlier_mask,
......@@ -89,6 +88,11 @@ def get_batcher():
batch_size=args.batch_size, # TODO: UPDATE FROM 1 TO 4
filter_threshold=0,
)
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# Filter images in the ROI
for imgs, _, masks in batcher:
yield imgs * ~masks
def schedule_sampling(eta: float, itr: int):
......@@ -97,7 +101,6 @@ def schedule_sampling(eta: float, itr: int):
- Eta: 1 down to 0
- Itr: (1, 9, 32, 32, 16)
"""
print("ITR", itr)
zeros = np.zeros(
(
args.batch_size,
......@@ -167,11 +170,11 @@ def padding_taasss(array: np.ndarray) -> np.ndarray:
to
(1, 25, 480, 480, 1)
"""
# zeros = np.zeros((1, 25, 512, 512, 1))
# zeros[:, :, 16:496, 16:496, :] = array
zeros = np.zeros((1, 25, 512, 512, 1))
zeros[:, :, 16:496, 16:496, :] = array
zeros = np.zeros((1, 25, 256, 256, 1))
zeros[:, :, 8:248, 8:248, :] = array[:, :, :240, :240, :]
# zeros = np.zeros((1, 25, 256, 256, 1))
# zeros[:, :, 8:248, 8:248, :] = array[:, :, :240, :240, :]
# zeros = np.zeros((1, 25, 128, 128, 1))
# zeros[:, :, 4:124, 4:124, :] = array[:, :, :120, :120, :]
......@@ -185,38 +188,33 @@ def unpadding_taasss(array: np.ndarray) -> np.ndarray:
to
(1, 25, 512, 512, 1)
"""
# return array[:, :, 16:496, 16:496, :]
return array[:, :, 8:248, 8:248, :]
return array[:, :, 16:496, 16:496, :]
# return array[:, :, 8:248, 8:248, :]
# return array[:, :, 4:124, 4:124, :]
def wrapper_train(model: Model):
# if args.pretrained_model:
# model.load(args.pretrained_model)
# eta = args.sampling_start_value
eta = args.sampling_start_value
best_mse = math.inf
tolerate = 0
limit = 3
train_model_iter = get_batcher()
iterator = get_batcher()
for itr in range(1, args.max_iterations + 1):
imgs, _, masks = next(train_model_iter)
imgs = next(iterator)
imgs = change_taasss_dims(imgs)
imgs = padding_taasss(imgs)
imgs = preprocess.reshape_patch(imgs, args.patch_size)
imgs = nor(imgs)
    # Should already be 0 to 1
# imgs = nor(imgs)
masks = change_taasss_dims(masks)
masks = padding_taasss(masks)
masks = preprocess.reshape_patch(masks, args.patch_size)
# eta, masks = schedule_sampling(eta, itr)
cost = trainer.train(model, imgs, masks, args)
eta, real_input_flag = schedule_sampling(eta, itr)
cost = trainer.train(model, imgs, real_input_flag, args)
if itr % args.display_interval == 0:
print("itr: " + str(itr))
print("training loss: " + str(cost))
if itr % args.test_interval == 0:
print("validation one ")
valid_mse = wrapper_valid(model)
valid_mse = wrapper_valid(model, iterator)
print("validation mse is:", str(valid_mse))
if valid_mse < best_mse:
......@@ -234,12 +232,8 @@ def wrapper_train(model: Model):
break
def wrapper_valid(model: Model):
def wrapper_valid(model: Model, iterator):
loss = 0
count = 0
index = 1
flag = True
real_input_flag = np.zeros(
(
args.batch_size,
......@@ -250,37 +244,34 @@ def wrapper_valid(model: Model):
)
)
output_length = args.total_length - args.input_length
while flag:
dat, (index, b_cup) = sample(batch_size, data_type="validation", index=index)
dat = nor(dat)
tars = dat[:, -output_length:]
ims = padding_CIKM_data(dat)
ims = preprocess.reshape_patch(ims, args.patch_size)
img_gen, _ = model.test(ims, real_input_flag)
# TODO: understand if 50 steps is right
steps = 50
for _ in range(steps):
imgs = next(iterator)
imgs = change_taasss_dims(imgs)
tars = imgs[:, -output_length:]
# TODO: REMOVE IT ONE THE IMAGE IS FULL
# TODO: TEST ITTTT
print("TARS", tars.shape)
tars = tars[:, :, :120, :120, :]
print("TARS", tars.shape)
imgs = padding_taasss(imgs)
        # Should already be 0 to 1
# imgs = nor(imgs)
imgs = preprocess.reshape_patch(imgs, args.patch_size)
img_gen, _ = model.test(imgs, real_input_flag)
img_gen = preprocess.reshape_patch_back(img_gen, args.patch_size)
img_out = unpadding_CIKM_data(img_gen[:, -output_length:])
img_out = unpadding_taasss(img_gen[:, -output_length:])
print("SHAPE", tars.shape, img_out.shape)
mse = np.mean(np.square(tars - img_out))
loss = loss + mse
count = count + 1
if b_cup == args.batch_size - 1:
pass
else:
flag = False
return loss / count
print("LOSS", loss, "MSE", mse)
return loss / steps
def wrapper_test(model: Model):
test_save_root = args.gen_frm_dir
clean_fold(test_save_root)
loss = 0
count = 0
# index = 1
flag = True
# Shape: (1, 9, 32, 32, 16)
real_input_flag = np.zeros(
(
args.batch_size,
......@@ -291,22 +282,17 @@ def wrapper_test(model: Model):
)
)
output_length = args.total_length - args.input_length
train_model_iter = get_batcher()
index = 1
b_cup = 0 # ?????????
while flag:
index += 1
iterator = get_batcher()
steps = 10
for index in range(steps):
print("Sample is:", index)
train_batch, sample_datetimes, train_mask = next(train_model_iter)
# Just because they called it this way
dat = train_batch
dat = next(iterator)
# (25, 1, 1, 480, 480) to (1, 25, 480, 480, 1)
dat = np.squeeze(dat)
dat = np.expand_dims(dat, axis=0)
dat = np.expand_dims(dat, axis=4)
dat = nor(dat)
        # Should already be 0 to 1
# dat = nor(dat)
tars = dat[:, -output_length:]
ims = padding_taasss(dat)
ims = preprocess.reshape_patch(ims, args.patch_size)
......@@ -316,28 +302,24 @@ def wrapper_test(model: Model):
img_out = unpadding_taasss(img_gen[:, -output_length:])
mse = np.mean(np.square(tars - img_out))
print(index, "MSE", mse)
img_out = de_nor(img_out)
        # Should already be 0 to 1
# img_out = de_nor(img_out)
loss += mse
count += 1
bat_ind = 0
for ind in range(index - batch_size, index, 1):
save_fold = test_save_root + "sample_" + str(ind) + "/"
clean_fold(save_fold)
save_fold = test_save_root / f"sample_{ind}"
for t in range(6, 16, 1):
imsave(
save_fold + "img_" + str(t) + ".png",
save_fold / f"img_{t}.png",
img_out[bat_ind, t - 6, :, :, 0],
)
bat_ind = bat_ind + 1
if b_cup == args.batch_size - 1:
pass
else:
flag = False
return loss / count
bat_ind += 1
return loss / steps
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
save_dir = Path(args.save_dir).parent
if save_dir.exists():
shutil.rmtree(save_dir)
......
......@@ -60,9 +60,7 @@ class Model(object):
# frames_tensor = torch.FloatTensor(frames).cuda()
# mask_tensor = torch.FloatTensor(mask).cuda()
self.optimizer.zero_grad()
print("BEFORE NET")
next_frames = self.network(frames_tensor, mask_tensor)
print("AFTER NET")
loss = self.MSE_criterion(
next_frames, frames_tensor[:, 1:]
) + self.MAE_criterion(next_frames, frames_tensor[:, 1:])
......
......@@ -975,7 +975,6 @@ class InteractionDST_PredRNN(nn.Module):
batch = frames.shape[0]
height = frames.shape[3]
width = frames.shape[4]
print("FORWARD")
next_frames = []
h_t = []
c_t = []
......@@ -1034,5 +1033,4 @@ class InteractionDST_PredRNN(nn.Module):
next_frames = (
torch.stack(next_frames, dim=0).permute(1, 0, 3, 4, 2).contiguous()
)
print("END F")
return next_frames
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment