Commit 95f130b7 authored by Marco Di Francesco's avatar Marco Di Francesco 🍉
Browse files

Add to device instead of using cuda

parent d4eea390
dataset/
checkpoints/
\ No newline at end of file
checkpoints/
dataset_generated/
\ No newline at end of file
......@@ -22,7 +22,7 @@ parser = argparse.ArgumentParser(
# training/test
parser.add_argument("--is_training", type=int, default=1)
# parser.add_argument('--device', type=str, default='gpu:0')
parser.add_argument('--device', type=str, default='cpu')
# data
parser.add_argument("--dataset_name", type=str, default="radar")
......@@ -53,7 +53,7 @@ parser.add_argument("--sampling_changing_rate", type=float, default=0.00002)
# optimization
parser.add_argument("--lr", type=float, default=0.001)
parser.add_argument("--reverse_input", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--batch_size", type=int, default=1) # 4
parser.add_argument("--max_iterations", type=int, default=80000)
parser.add_argument("--display_interval", type=int, default=200)
parser.add_argument("--test_interval", type=int, default=2000)
......
......@@ -26,8 +26,8 @@ class Model(object):
if configs.model_name in networks_map:
Network = networks_map[configs.model_name]
# self.network = Network(self.num_layers, self.num_hidden, configs).to(configs.device)
self.network = Network(self.num_layers, self.num_hidden, configs).cuda()
self.network = Network(self.num_layers, self.num_hidden, configs).to(configs.device)
# self.network = Network(self.num_layers, self.num_hidden, configs).cuda()
else:
raise ValueError("Name of network unknown %s" % configs.model_name)
if self.configs.is_parallel:
......@@ -53,11 +53,11 @@ class Model(object):
def train(self, frames, mask):
# frames_tensor = torch.FloatTensor(frames).to(self.configs.device)
# mask_tensor = torch.FloatTensor(mask).to(self.configs.device)
frames_tensor = torch.FloatTensor(frames).to(self.configs.device)
mask_tensor = torch.FloatTensor(mask).to(self.configs.device)
frames_tensor = torch.FloatTensor(frames).cuda()
mask_tensor = torch.FloatTensor(mask).cuda()
# frames_tensor = torch.FloatTensor(frames).cuda()
# mask_tensor = torch.FloatTensor(mask).cuda()
self.optimizer.zero_grad()
next_frames = self.network(frames_tensor, mask_tensor)
loss = self.MSE_criterion(
......@@ -69,8 +69,10 @@ class Model(object):
return loss.detach().cpu().numpy()
def test(self, frames, mask):
frames_tensor = torch.FloatTensor(frames).cuda()
mask_tensor = torch.FloatTensor(mask).cuda()
# frames_tensor = torch.FloatTensor(frames).cuda()
# mask_tensor = torch.FloatTensor(mask).cuda()
frames_tensor = torch.FloatTensor(frames).to(self.configs.device)
mask_tensor = torch.FloatTensor(mask).to(self.configs.device)
next_frames = self.network(frames_tensor, mask_tensor)
loss = self.MSE_criterion(
next_frames, frames_tensor[:, 1:]
......
This diff is collapsed.
......@@ -11,8 +11,7 @@ sys.path.append(rootPath)
from core.utils.util import *
from torch.utils import data
# from scipy.misc import imsave,imread
from imageio import imread
from imageio import imread, imsave
from torch.utils.data import DataLoader
import numpy as np
......@@ -20,6 +19,8 @@ import random
import torch
DATASET_DIR = 'dataset/'
class CIKM_Datasets(data.Dataset):
def __init__(self, root_path):
self.root_path = root_path
......@@ -32,7 +33,8 @@ class CIKM_Datasets(data.Dataset):
for file in files:
imgs.append(imread(self.folds + file)[:, :, np.newaxis])
imgs = np.stack(imgs, 0)
imgs = torch.from_numpy(imgs).cuda()
# imgs = torch.from_numpy(imgs).cuda()
imgs = torch.from_numpy(imgs)
in_imgs = imgs[:5]
out_imgs = imgs[5:]
return in_imgs, out_imgs
......@@ -42,7 +44,7 @@ class CIKM_Datasets(data.Dataset):
def data_process(filename, data_type, dim=None, start_point=0):
save_root = "/mnt/A/CIKM2017/CIKM_datasets/" + data_type + "/"
save_root = DATASET_DIR + data_type + "/"
if start_point == 0:
clean_fold(save_root)
......@@ -50,7 +52,7 @@ def data_process(filename, data_type, dim=None, start_point=0):
if data_type == "train":
sample_num = 10000
validation = random.sample(range(1, 10000 + 1), 2000)
save_validation_root = "/mnt/A/CIKM2017/CIKM_datasets/validation/"
save_validation_root = DATASET_DIR + "validation/"
clean_fold(save_validation_root)
elif data_type == "test":
sample_num = 2000 + start_point
......@@ -106,7 +108,7 @@ def data_process(filename, data_type, dim=None, start_point=0):
def sub_sample(batch_size, mode="random", data_type="train", index=None, type=7):
if type not in [4, 5, 6, 7]:
raise ("error")
save_root = "/mnt/A/CIKM2017/CIKM_datasets/" + data_type + "/"
save_root = DATASET_DIR + data_type + "/"
if data_type == "train":
if mode == "random":
imgs = []
......@@ -203,7 +205,7 @@ def sub_sample(batch_size, mode="random", data_type="train", index=None, type=7)
def sample(batch_size, mode="random", data_type="train", index=None):
save_root = "/mnt/A/CIKM2017/CIKM_datasets/" + data_type + "/"
save_root = DATASET_DIR + data_type + "/"
if data_type == "train":
if mode == "random":
imgs = []
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment