Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
Marco Di Francesco
IDA LSTM
Commits
dee581de
Commit
dee581de
authored
Jun 16, 2021
by
Marco Di Francesco
🍉
Browse files
Remove python checkpoints
parent
083ebb12
Changes
8
Hide whitespace changes
Inline
Side-by-side
core/models/.ipynb_checkpoints/model_factory-checkpoint.py
deleted
100644 → 0
View file @
083ebb12
import
os
import
torch
import
torch.nn
as
nn
from
torch.optim
import
Adam
from
core.models
import
predict
class
Model
(
object
):
def
__init__
(
self
,
configs
):
self
.
configs
=
configs
self
.
num_hidden
=
[
int
(
x
)
for
x
in
configs
.
num_hidden
.
split
(
","
)]
self
.
num_layers
=
len
(
self
.
num_hidden
)
networks_map
=
{
"convlstm"
:
predict
.
ConvLSTM
,
"predrnn"
:
predict
.
PredRNN
,
"predrnn_plus"
:
predict
.
PredRNN_Plus
,
"interact_convlstm"
:
predict
.
InteractionConvLSTM
,
"interact_predrnn"
:
predict
.
InteractionPredRNN
,
"interact_predrnn_plus"
:
predict
.
InteractionPredRNN_Plus
,
"cst_predrnn"
:
predict
.
CST_PredRNN
,
"sst_predrnn"
:
predict
.
SST_PredRNN
,
"dst_predrnn"
:
predict
.
DST_PredRNN
,
"interact_dst_predrnn"
:
predict
.
InteractionDST_PredRNN
,
}
if
not
configs
.
model_name
in
networks_map
:
raise
ValueError
(
"Name of network unknown %s"
%
configs
.
model_name
)
Network
=
networks_map
[
configs
.
model_name
]
self
.
network
=
Network
(
self
.
num_layers
,
self
.
num_hidden
,
configs
).
to
(
configs
.
device
)
# self.network = Network(self.num_layers, self.num_hidden, configs).cuda()
if
self
.
configs
.
is_parallel
:
self
.
network
=
nn
.
DataParallel
(
self
.
network
)
self
.
optimizer
=
Adam
(
self
.
network
.
parameters
(),
lr
=
configs
.
lr
)
# TODO: size_average to sum
self
.
MSE_criterion
=
nn
.
MSELoss
(
size_average
=
False
)
self
.
MAE_criterion
=
nn
.
L1Loss
(
size_average
=
False
)
def
save
(
self
,
ite
=
None
):
stats
=
{}
stats
[
"net_param"
]
=
self
.
network
.
state_dict
()
torch
.
save
(
stats
,
self
.
configs
.
save_dir
)
print
(
f
"Saving model to
{
self
.
configs
.
save_dir
}
"
)
def
load
(
self
):
if
os
.
path
.
exists
(
self
.
configs
.
save_dir
):
stats
=
torch
.
load
(
self
.
configs
.
save_dir
)
self
.
network
.
load_state_dict
(
stats
[
"net_param"
])
print
(
"Model loaded"
)
else
:
print
(
"Training from scratch"
)
def
train
(
self
,
frames
,
mask
):
frames_tensor
=
torch
.
FloatTensor
(
frames
).
to
(
self
.
configs
.
device
)
mask_tensor
=
torch
.
FloatTensor
(
mask
).
to
(
self
.
configs
.
device
)
# frames_tensor = torch.FloatTensor(frames).cuda()
# mask_tensor = torch.FloatTensor(mask).cuda()
self
.
optimizer
.
zero_grad
()
next_frames
=
self
.
network
(
frames_tensor
,
mask_tensor
)
loss
=
self
.
MSE_criterion
(
next_frames
,
frames_tensor
[:,
1
:]
)
+
self
.
MAE_criterion
(
next_frames
,
frames_tensor
[:,
1
:])
# 0.02*self.SSIM_criterion(next_frames, frames_tensor[:, 1:])
loss
.
backward
()
self
.
optimizer
.
step
()
return
loss
.
detach
().
cpu
().
numpy
()
def
test
(
self
,
frames
,
mask
):
# frames_tensor = torch.FloatTensor(frames).cuda()
# mask_tensor = torch.FloatTensor(mask).cuda()
frames_tensor
=
torch
.
FloatTensor
(
frames
).
to
(
self
.
configs
.
device
)
mask_tensor
=
torch
.
FloatTensor
(
mask
).
to
(
self
.
configs
.
device
)
next_frames
=
self
.
network
(
frames_tensor
,
mask_tensor
)
loss
=
self
.
MSE_criterion
(
next_frames
,
frames_tensor
[:,
1
:]
)
+
self
.
MAE_criterion
(
next_frames
,
frames_tensor
[:,
1
:])
# + 0.02 * self.SSIM_criterion(next_frames, frames_tensor[:, 1:])
return
next_frames
.
detach
().
cpu
().
numpy
(),
loss
.
detach
().
cpu
().
numpy
()
core/models/__pycache__/__init__.cpython-36.pyc
deleted
100644 → 0
View file @
083ebb12
File deleted
core/models/__pycache__/__init__.cpython-38.pyc
deleted
100644 → 0
View file @
083ebb12
File deleted
core/models/__pycache__/model_factory.cpython-36.pyc
deleted
100644 → 0
View file @
083ebb12
File deleted
core/models/__pycache__/model_factory.cpython-38.pyc
deleted
100644 → 0
View file @
083ebb12
File deleted
core/models/__pycache__/predict.cpython-36.pyc
deleted
100644 → 0
View file @
083ebb12
File deleted
core/models/__pycache__/predict.cpython-38.pyc
deleted
100644 → 0
View file @
083ebb12
File deleted
core/utils/.ipynb_checkpoints/util-checkpoint.py
deleted
100644 → 0
View file @
083ebb12
import
numpy
as
np
import
shutil
import
copy
import
os
def
nor
(
frames
):
new_frames
=
frames
.
astype
(
np
.
float32
)
/
255.0
return
new_frames
def
de_nor
(
frames
):
new_frames
=
copy
.
deepcopy
(
frames
)
new_frames
*=
255.0
new_frames
=
new_frames
.
astype
(
np
.
uint8
)
return
new_frames
def
normalization
(
frames
,
up
=
80
):
new_frames
=
frames
.
astype
(
np
.
float32
)
new_frames
/=
(
up
/
2
)
new_frames
-=
1
return
new_frames
def
denormalization
(
frames
,
up
=
80
):
new_frames
=
copy
.
deepcopy
(
frames
)
new_frames
+=
1
new_frames
*=
(
up
/
2
)
new_frames
=
new_frames
.
astype
(
np
.
uint8
)
return
new_frames
def
clean_fold
(
path
):
if
os
.
path
.
exists
(
path
):
shutil
.
rmtree
(
path
)
os
.
makedirs
(
path
)
else
:
os
.
makedirs
(
path
)
\ No newline at end of file
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment