Commit dee581de authored by Marco Di Francesco's avatar Marco Di Francesco 🍉
Browse files

Remove python checkpoints

parent 083ebb12
import os
import torch
import torch.nn as nn
from torch.optim import Adam
from core.models import predict
class Model(object):
def __init__(self, configs):
self.configs = configs
self.num_hidden = [int(x) for x in configs.num_hidden.split(",")]
self.num_layers = len(self.num_hidden)
networks_map = {
"convlstm": predict.ConvLSTM,
"predrnn": predict.PredRNN,
"predrnn_plus": predict.PredRNN_Plus,
"interact_convlstm": predict.InteractionConvLSTM,
"interact_predrnn": predict.InteractionPredRNN,
"interact_predrnn_plus": predict.InteractionPredRNN_Plus,
"cst_predrnn": predict.CST_PredRNN,
"sst_predrnn": predict.SST_PredRNN,
"dst_predrnn": predict.DST_PredRNN,
"interact_dst_predrnn": predict.InteractionDST_PredRNN,
}
if not configs.model_name in networks_map:
raise ValueError("Name of network unknown %s" % configs.model_name)
Network = networks_map[configs.model_name]
self.network = Network(self.num_layers, self.num_hidden, configs).to(
configs.device
)
# self.network = Network(self.num_layers, self.num_hidden, configs).cuda()
if self.configs.is_parallel:
self.network = nn.DataParallel(self.network)
self.optimizer = Adam(self.network.parameters(), lr=configs.lr)
# TODO: size_average to sum
self.MSE_criterion = nn.MSELoss(size_average=False)
self.MAE_criterion = nn.L1Loss(size_average=False)
def save(self, ite=None):
stats = {}
stats["net_param"] = self.network.state_dict()
torch.save(stats, self.configs.save_dir)
print(f"Saving model to {self.configs.save_dir}")
def load(self):
if os.path.exists(self.configs.save_dir):
stats = torch.load(self.configs.save_dir)
self.network.load_state_dict(stats["net_param"])
print("Model loaded")
else:
print("Training from scratch")
def train(self, frames, mask):
frames_tensor = torch.FloatTensor(frames).to(self.configs.device)
mask_tensor = torch.FloatTensor(mask).to(self.configs.device)
# frames_tensor = torch.FloatTensor(frames).cuda()
# mask_tensor = torch.FloatTensor(mask).cuda()
self.optimizer.zero_grad()
next_frames = self.network(frames_tensor, mask_tensor)
loss = self.MSE_criterion(
next_frames, frames_tensor[:, 1:]
) + self.MAE_criterion(next_frames, frames_tensor[:, 1:])
# 0.02*self.SSIM_criterion(next_frames, frames_tensor[:, 1:])
loss.backward()
self.optimizer.step()
return loss.detach().cpu().numpy()
def test(self, frames, mask):
# frames_tensor = torch.FloatTensor(frames).cuda()
# mask_tensor = torch.FloatTensor(mask).cuda()
frames_tensor = torch.FloatTensor(frames).to(self.configs.device)
mask_tensor = torch.FloatTensor(mask).to(self.configs.device)
next_frames = self.network(frames_tensor, mask_tensor)
loss = self.MSE_criterion(
next_frames, frames_tensor[:, 1:]
) + self.MAE_criterion(next_frames, frames_tensor[:, 1:])
# + 0.02 * self.SSIM_criterion(next_frames, frames_tensor[:, 1:])
return next_frames.detach().cpu().numpy(), loss.detach().cpu().numpy()
import numpy as np
import shutil
import copy
import os
def nor(frames):
new_frames = frames.astype(np.float32)/255.0
return new_frames
def de_nor(frames):
new_frames = copy.deepcopy(frames)
new_frames *= 255.0
new_frames = new_frames.astype(np.uint8)
return new_frames
def normalization(frames,up=80):
new_frames = frames.astype(np.float32)
new_frames /= (up/2)
new_frames -= 1
return new_frames
def denormalization(frames,up=80):
new_frames = copy.deepcopy(frames)
new_frames += 1
new_frames *= (up/2)
new_frames = new_frames.astype(np.uint8)
return new_frames
def clean_fold(path):
if os.path.exists(path):
shutil.rmtree(path)
os.makedirs(path)
else:
os.makedirs(path)
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment