Marco Di Francesco / IDA LSTM

Commit f01e1734, authored Sep 08, 2021 by Gabriele Franch

updated model save/load to include schedule sampling settings

Parent: 4c92f32b
Changes: 2 files
cikm_inter_dst_predrnn_run_taasss.py
@@ -130,9 +130,12 @@ def schedule_sampling(eta: float, itr: int):
 def wrapper_train(model: Model):
+    itr = 1
     eta = args.sampling_start_value
+    if args.pretrained_model:
+        itr, eta = model.load(args.pretrained_model)
     iterator = get_batcher(args)
-    progress_bar = tqdm(range(1, args.max_iterations + 1))
+    progress_bar = tqdm(range(itr, args.max_iterations + 1))
     for itr in progress_bar:
         imgs = next(iterator)
         imgs = change_taasss_dims(imgs)
@@ -143,7 +146,7 @@ def wrapper_train(model: Model):
         progress_bar.set_description(f"Loss: {cost}")
         if itr % args.test_interval == 0:
-            model.save(itr)
+            model.save(itr, eta)
 print("Initializing models")
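The eta threaded through save() and load() here is the scheduled-sampling probability that schedule_sampling(eta, itr) anneals over training; persisting it means a resumed run continues the anneal instead of restarting it from args.sampling_start_value. The decay itself is outside this diff; a minimal sketch of the PredRNN-style linear anneal such code typically uses (the stop iteration and changing rate below are illustrative assumptions, not values from this repository):

SAMPLING_STOP_ITER = 50_000      # assumed constant; the real value comes from args
SAMPLING_CHANGING_RATE = 2e-5    # assumed constant; the real value comes from args

def anneal_eta(eta: float, itr: int) -> float:
    """Linearly decrease eta each iteration; hold it at 0 after the stop iteration."""
    if itr < SAMPLING_STOP_ITER:
        return max(0.0, eta - SAMPLING_CHANGING_RATE)
    return 0.0

Together with the range(itr, ...) change above, restoring (itr, eta) lets both the progress bar and the sampling schedule pick up exactly where the checkpoint left off.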
core/models/model_factory.py
@@ -39,24 +39,33 @@ class Model(object):
         self.MSE_criterion = nn.MSELoss(size_average=False)
         self.MAE_criterion = nn.L1Loss(size_average=False)

-    def save(self, itr):
-        state_dict = {
+    def save(self, itr, eta):
+        if self.configs.is_parallel:
+            model_states = self.network.module.state_dict()
+        else:
+            model_states = self.network.state_dict()
+        save_dict = {
             'iter': itr,
-            'model_state_dict': self.network.state_dict(),
+            'eta': eta,
+            'model_state_dict': model_states,
             'optimizer_state_dict': self.optimizer.state_dict(),
-            "net_param": self.network.state_dict()
         }
         save_path = Path("/") / "data1" / "IDA_LSTM_checkpoints"
         save_path.mkdir(exist_ok=True)
         save_path = save_path / f"{itr}.pth"
-        torch.save(state_dict, save_path)
+        torch.save(save_dict, save_path)

     def load(self, path):
         assert os.path.exists(path), "Weights dir does not exist"
         stats = torch.load(path, map_location=torch.device(self.configs.device))
-        self.network.load_state_dict(stats["model_state_dict"])
+        if self.configs.is_parallel:
+            self.network.module.load_state_dict(stats["model_state_dict"])
+        else:
+            self.network.load_state_dict(stats["model_state_dict"])
         self.optimizer.load_state_dict(stats["optimizer_state_dict"])
         print("Model loaded")
+        return stats['iter'], stats['eta']

     def train(self, frames, mask):
         frames_tensor = torch.FloatTensor(frames).to(self.configs.device)
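The new is_parallel branches account for torch.nn.DataParallel, which stores the wrapped network under a .module attribute and therefore prefixes every state-dict key with "module.". Saving self.network.module.state_dict() keeps the un-prefixed keys, so the same checkpoint loads whether or not the network is wrapped when it is read back. A self-contained sketch of the key naming (the toy Linear layer is only for illustration):

import torch.nn as nn

net = nn.Linear(4, 2)
wrapped = nn.DataParallel(net)

print(list(net.state_dict().keys()))      # ['weight', 'bias']
print(list(wrapped.state_dict().keys()))  # ['module.weight', 'module.bias']

The dropped "net_param" entry had duplicated model_state_dict (both read self.network.state_dict()), which is presumably why it was removed along with this change.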