Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
Marco Di Francesco
IDA LSTM
Commits
95f130b7
Commit
95f130b7
authored
May 05, 2021
by
Marco Di Francesco
🍉
Browse files
Add to device instead of using cuda
parent
d4eea390
Changes
5
Expand all
Hide whitespace changes
Inline
Side-by-side
.gitignore
View file @
95f130b7
dataset/
checkpoints/
\ No newline at end of file
checkpoints/
dataset_generated/
\ No newline at end of file
cikm_inter_dst_predrnn_run.py
View file @
95f130b7
...
...
@@ -22,7 +22,7 @@ parser = argparse.ArgumentParser(
# training/test
parser
.
add_argument
(
"--is_training"
,
type
=
int
,
default
=
1
)
#
parser.add_argument('--device', type=str, default='
g
pu
:0
')
parser
.
add_argument
(
'--device'
,
type
=
str
,
default
=
'
c
pu'
)
# data
parser
.
add_argument
(
"--dataset_name"
,
type
=
str
,
default
=
"radar"
)
...
...
@@ -53,7 +53,7 @@ parser.add_argument("--sampling_changing_rate", type=float, default=0.00002)
# optimization
parser
.
add_argument
(
"--lr"
,
type
=
float
,
default
=
0.001
)
parser
.
add_argument
(
"--reverse_input"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
default
=
4
)
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
default
=
1
)
#
4
parser
.
add_argument
(
"--max_iterations"
,
type
=
int
,
default
=
80000
)
parser
.
add_argument
(
"--display_interval"
,
type
=
int
,
default
=
200
)
parser
.
add_argument
(
"--test_interval"
,
type
=
int
,
default
=
2000
)
...
...
core/models/model_factory.py
View file @
95f130b7
...
...
@@ -26,8 +26,8 @@ class Model(object):
if
configs
.
model_name
in
networks_map
:
Network
=
networks_map
[
configs
.
model_name
]
#
self.network = Network(self.num_layers, self.num_hidden, configs).to(configs.device)
self
.
network
=
Network
(
self
.
num_layers
,
self
.
num_hidden
,
configs
).
cuda
()
self
.
network
=
Network
(
self
.
num_layers
,
self
.
num_hidden
,
configs
).
to
(
configs
.
device
)
#
self.network = Network(self.num_layers, self.num_hidden, configs).cuda()
else
:
raise
ValueError
(
"Name of network unknown %s"
%
configs
.
model_name
)
if
self
.
configs
.
is_parallel
:
...
...
@@ -53,11 +53,11 @@ class Model(object):
def
train
(
self
,
frames
,
mask
):
#
frames_tensor = torch.FloatTensor(frames).to(self.configs.device)
#
mask_tensor = torch.FloatTensor(mask).to(self.configs.device)
frames_tensor
=
torch
.
FloatTensor
(
frames
).
to
(
self
.
configs
.
device
)
mask_tensor
=
torch
.
FloatTensor
(
mask
).
to
(
self
.
configs
.
device
)
frames_tensor
=
torch
.
FloatTensor
(
frames
).
cuda
()
mask_tensor
=
torch
.
FloatTensor
(
mask
).
cuda
()
#
frames_tensor = torch.FloatTensor(frames).cuda()
#
mask_tensor = torch.FloatTensor(mask).cuda()
self
.
optimizer
.
zero_grad
()
next_frames
=
self
.
network
(
frames_tensor
,
mask_tensor
)
loss
=
self
.
MSE_criterion
(
...
...
@@ -69,8 +69,10 @@ class Model(object):
return
loss
.
detach
().
cpu
().
numpy
()
def
test
(
self
,
frames
,
mask
):
frames_tensor
=
torch
.
FloatTensor
(
frames
).
cuda
()
mask_tensor
=
torch
.
FloatTensor
(
mask
).
cuda
()
# frames_tensor = torch.FloatTensor(frames).cuda()
# mask_tensor = torch.FloatTensor(mask).cuda()
frames_tensor
=
torch
.
FloatTensor
(
frames
).
to
(
self
.
configs
.
device
)
mask_tensor
=
torch
.
FloatTensor
(
mask
).
to
(
self
.
configs
.
device
)
next_frames
=
self
.
network
(
frames_tensor
,
mask_tensor
)
loss
=
self
.
MSE_criterion
(
next_frames
,
frames_tensor
[:,
1
:]
...
...
core/models/predict.py
View file @
95f130b7
This diff is collapsed.
Click to expand it.
data_provider/CIKM/data_iterator.py
View file @
95f130b7
...
...
@@ -11,8 +11,7 @@ sys.path.append(rootPath)
from
core.utils.util
import
*
from
torch.utils
import
data
# from scipy.misc import imsave,imread
from
imageio
import
imread
from
imageio
import
imread
,
imsave
from
torch.utils.data
import
DataLoader
import
numpy
as
np
...
...
@@ -20,6 +19,8 @@ import random
import
torch
DATASET_DIR
=
'dataset/'
class
CIKM_Datasets
(
data
.
Dataset
):
def
__init__
(
self
,
root_path
):
self
.
root_path
=
root_path
...
...
@@ -32,7 +33,8 @@ class CIKM_Datasets(data.Dataset):
for
file
in
files
:
imgs
.
append
(
imread
(
self
.
folds
+
file
)[:,
:,
np
.
newaxis
])
imgs
=
np
.
stack
(
imgs
,
0
)
imgs
=
torch
.
from_numpy
(
imgs
).
cuda
()
# imgs = torch.from_numpy(imgs).cuda()
imgs
=
torch
.
from_numpy
(
imgs
)
in_imgs
=
imgs
[:
5
]
out_imgs
=
imgs
[
5
:]
return
in_imgs
,
out_imgs
...
...
@@ -42,7 +44,7 @@ class CIKM_Datasets(data.Dataset):
def
data_process
(
filename
,
data_type
,
dim
=
None
,
start_point
=
0
):
save_root
=
"/mnt/A/CIKM2017/CIKM_datasets/"
+
data_type
+
"/"
save_root
=
DATASET_DIR
+
data_type
+
"/"
if
start_point
==
0
:
clean_fold
(
save_root
)
...
...
@@ -50,7 +52,7 @@ def data_process(filename, data_type, dim=None, start_point=0):
if
data_type
==
"train"
:
sample_num
=
10000
validation
=
random
.
sample
(
range
(
1
,
10000
+
1
),
2000
)
save_validation_root
=
"/mnt/A/CIKM2017/CIKM_datasets/
validation/"
save_validation_root
=
DATASET_DIR
+
"
validation/"
clean_fold
(
save_validation_root
)
elif
data_type
==
"test"
:
sample_num
=
2000
+
start_point
...
...
@@ -106,7 +108,7 @@ def data_process(filename, data_type, dim=None, start_point=0):
def
sub_sample
(
batch_size
,
mode
=
"random"
,
data_type
=
"train"
,
index
=
None
,
type
=
7
):
if
type
not
in
[
4
,
5
,
6
,
7
]:
raise
(
"error"
)
save_root
=
"/mnt/A/CIKM2017/CIKM_datasets/"
+
data_type
+
"/"
save_root
=
DATASET_DIR
+
data_type
+
"/"
if
data_type
==
"train"
:
if
mode
==
"random"
:
imgs
=
[]
...
...
@@ -203,7 +205,7 @@ def sub_sample(batch_size, mode="random", data_type="train", index=None, type=7)
def
sample
(
batch_size
,
mode
=
"random"
,
data_type
=
"train"
,
index
=
None
):
save_root
=
"/mnt/A/CIKM2017/CIKM_datasets/"
+
data_type
+
"/"
save_root
=
DATASET_DIR
+
data_type
+
"/"
if
data_type
==
"train"
:
if
mode
==
"random"
:
imgs
=
[]
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment