Marco Di Francesco / IDA LSTM / Commits

Commit 0812e9fd, authored Jul 23, 2021 by Gabriele Franch

Separate train and test in taasss dataset

parent a5b77efb

Changes 9
cikm_inter_dst_predrnn_run.py

```python
@@ -29,7 +29,7 @@ parser.add_argument("--device", type=str, default="cpu")  # cuda
parser.add_argument("--dataset_name", type=str, default="radar")
parser.add_argument("--r", type=int, default=4)
parser.add_argument("--is_parallel", type=bool, default=False)
parser.add_argument("--save_dir", type=str, default="checkpoints/model.ckpt")
parser.add_argument("--save_dir", type=str, default="checkpoints/inter_dst_predrnn")
parser.add_argument("--gen_frm_dir", type=str, default="dataset_generated/")
parser.add_argument("--input_length", type=int, default=5)
parser.add_argument("--total_length", type=int, default=15)
```

...
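A side note on the flags shared by both run scripts: argparse's `type=bool` does not do what it suggests, because `bool("False")` is `True` (any non-empty string is truthy), so passing `--is_parallel False` on the command line still yields `True`. A sketch of the usual workaround, not part of this commit:

```python
import argparse

parser = argparse.ArgumentParser()
# bool("False") == True, so type=bool silently mis-parses "--is_parallel False".
# An explicit string-to-bool converter avoids that:
parser.add_argument(
    "--is_parallel",
    type=lambda s: s.lower() in ("1", "true", "yes"),
    default=False,
)
print(parser.parse_args(["--is_parallel", "False"]).is_parallel)  # False
```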
cikm_inter_dst_predrnn_run_taasss.py

```python
import os
# import os
import argparse
import math
import shutil
# import math
# import shutil
import numpy as np
import core.trainer as trainer
from core.models.model_factory import Model
from core.utils import preprocess
from data_provider.CIKM.data_iterator import clean_fold, sample, imsave
from core.utils.util import nor, de_nor
from data_provider.CIKM.taasss import infinite_batcher
from pathlib import Path
import h5py
import cv2
import pandas as pd
# from data_provider.CIKM.data_iterator import clean_fold, sample, imsave
# from core.utils.util import nor, de_nor
from tqdm import tqdm
from cikm_inter_dst_predrnn_run_taasss_utils import (
    padding_taasss,
    get_batcher,
)

# -----------------------------------------------------------------------------
parser = argparse.ArgumentParser(description="PyTorch video prediction model - DST PredRNN")
```

...
```python
@@ -29,11 +28,11 @@ parser.add_argument("--device", type=str, default="cuda")  # cuda
parser.add_argument("--dataset_name", type=str, default="radar")
parser.add_argument("--r", type=int, default=4)
parser.add_argument("--is_parallel", type=bool, default=False)
parser.add_argument("--save_dir", type=str, default="checkpoints/model.ckpt")
# parser.add_argument("--save_dir", type=str, default="checkpoints/inter_dst_predrnn")
parser.add_argument("--gen_frm_dir", type=str, default="dataset_generated/")
parser.add_argument("--input_length", type=int, default=5)
parser.add_argument("--total_length", type=int, default=25)  # 15
parser.add_argument("--img_width", type=int, default=512)
parser.add_argument("--img_width", type=int, default=512)  # 512
parser.add_argument("--img_channel", type=int, default=1)

# model
```

...
```python
@@ -42,7 +41,7 @@ parser.add_argument("--pretrained_model", type=str, default="")
parser.add_argument("--num_hidden", type=str, default="64,64,64,64")
parser.add_argument("--filter_size", type=int, default=5)
parser.add_argument("--stride", type=int, default=1)
parser.add_argument("--patch_size", type=int, default=4)
# parser.add_argument("--patch_size", type=int, default=4)
parser.add_argument("--layer_norm", type=int, default=1)

# scheduled sampling
```

...
```python
@@ -56,43 +55,21 @@ parser.add_argument("--lr", type=float, default=0.001)
parser.add_argument("--reverse_input", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=1)  # 4
parser.add_argument("--max_iterations", type=int, default=80000)
parser.add_argument("--display_interval", type=int, default=1)  # 200
parser.add_argument("--test_interval", type=int, default=1)  # 2000
parser.add_argument("--snapshot_interval", type=int, default=5000)
# parser.add_argument("--display_interval", type=int, default=1)  # 200
parser.add_argument("--test_interval", type=int, default=5000)  # 5000
# parser.add_argument("--snapshot_interval", type=int, default=5000)
parser.add_argument("--num_save_samples", type=int, default=10)
parser.add_argument("--n_gpu", type=int, default=0)

args = parser.parse_args()
batch_size = args.batch_size


def get_batcher():
    data_dir = Path("/") / "data2" / "franch" / "meteotn_traindata"
    metadata_file = data_dir / "run_metadata.csv"
    all_data = h5py.File(
        data_dir / "hdf_archives" / "all_data.hdf5",
        "r",
        libver="latest",
    )
    outlier_mask = cv2.imread(str(data_dir / "mask.png"), 0)
    metadata = pd.read_csv(metadata_file, index_col="id")
    metadata["start_datetime"] = pd.to_datetime(metadata["start_datetime"])
    metadata["end_datetime"] = pd.to_datetime(metadata["end_datetime"])
    sort_meta = metadata.sample(frac=1)
    batcher = infinite_batcher(
        all_data,
        sort_meta,
        outlier_mask,
        shuffle=False,
        batch_size=args.batch_size,
        # TODO: UPDATE FROM 1 TO 4
        filter_threshold=0,
    )
    # os.environ["CUDA_VISIBLE_DEVICES"] = "1"
    # Filter images in the ROI
    for imgs, _, masks in batcher:
        yield imgs * ~masks


if args.img_width == 128:
    args.patch_size = 4
elif args.img_width == 256:
    args.patch_size = 8
elif args.img_width == 512:
    args.patch_size = 16
```
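The if-chain above ties `patch_size` to `img_width` so that every frame folds into a 32×32 grid of patches. For orientation, a PredRNN-style `reshape_patch` moves each `patch_size` × `patch_size` block into the channel dimension; below is a minimal numpy sketch under that assumption, not necessarily the repo's exact `preprocess.reshape_patch`:

```python
import numpy as np

def reshape_patch_sketch(imgs: np.ndarray, patch_size: int) -> np.ndarray:
    # [batch, seq, H, W, C] -> [batch, seq, H/p, W/p, p*p*C]
    b, s, h, w, c = imgs.shape
    p = patch_size
    x = imgs.reshape(b, s, h // p, p, w // p, p, c)
    x = x.transpose(0, 1, 2, 4, 3, 5, 6)
    return x.reshape(b, s, h // p, w // p, p * p * c)

# With img_width=512 and patch_size=16, a (1, 25, 512, 512, 1) batch
# becomes (1, 25, 32, 32, 256), matching the real_input_flag shapes below.
```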
```python
def schedule_sampling(eta: float, itr: int):
    ...
```
```python
@@ -163,176 +140,25 @@ def change_taasss_dims(a: np.ndarray) -> np.ndarray:
    return np.expand_dims(a, axis=4)


def padding_taasss(array: np.ndarray) -> np.ndarray:
    """Pad (1, 25, 480, 480, 1) to (1, 25, 512, 512, 1)."""
    zeros = np.zeros((1, 25, 512, 512, 1))
    zeros[:, :, 16:496, 16:496, :] = array
    # zeros = np.zeros((1, 25, 256, 256, 1))
    # zeros[:, :, 8:248, 8:248, :] = array[:, :, :240, :240, :]
    # zeros = np.zeros((1, 25, 128, 128, 1))
    # zeros[:, :, 4:124, 4:124, :] = array[:, :, :120, :120, :]
    return zeros


def unpadding_taasss(array: np.ndarray) -> np.ndarray:
    """Crop (1, 25, 512, 512, 1) back to (1, 25, 480, 480, 1)."""
    return array[:, :, 16:496, 16:496, :]
    # return array[:, :, 8:248, 8:248, :]
    # return array[:, :, 4:124, 4:124, :]
```
```python
def wrapper_train(model: Model):
    eta = args.sampling_start_value
    best_mse = math.inf
    tolerate = 0
    limit = 3
    iterator = get_batcher()
    for itr in range(1, args.max_iterations + 1):
    iterator = get_batcher(args)
    progress_bar = tqdm(range(1, args.max_iterations + 1))
    for itr in progress_bar:
        imgs = next(iterator)
        imgs = change_taasss_dims(imgs)
        imgs = padding_taasss(imgs)
        imgs = padding_taasss(imgs, args)
        imgs = preprocess.reshape_patch(imgs, args.patch_size)
        # Should already be 0 to 1
        # imgs = nor(imgs)
        eta, real_input_flag = schedule_sampling(eta, itr)
        cost = trainer.train(model, imgs, real_input_flag, args)
        if itr % args.display_interval == 0:
            print("itr: " + str(itr))
            print("training loss: " + str(cost))
        progress_bar.set_description(f"Loss: {cost}")
        if itr % args.test_interval == 0:
            valid_mse = wrapper_valid(model, iterator)
            print("validation mse is:", str(valid_mse))
            if valid_mse < best_mse:
                best_mse = valid_mse
                tolerate = 0
                model.save()
            else:
                tolerate = tolerate + 1
            if tolerate == limit:
                model.load()
                test_mse = wrapper_test(model)
                print("the best valid mse is:", str(best_mse))
                print("the test mse is ", str(test_mse))
                break
```
```python
def wrapper_valid(model: Model, iterator):
    loss = 0
    real_input_flag = np.zeros(
        (
            args.batch_size,
            args.total_length - args.input_length - 1,
            args.img_width // args.patch_size,
            args.img_width // args.patch_size,
            args.patch_size ** 2 * args.img_channel,
        )
    )
    output_length = args.total_length - args.input_length
    # TODO: understand if 50 steps is right
    steps = 50
    for _ in range(steps):
        imgs = next(iterator)
        imgs = change_taasss_dims(imgs)
        tars = imgs[:, -output_length:]
        # TODO: REMOVE IT ONCE THE IMAGE IS FULL
        # TODO: TEST ITTTT
        print("TARS", tars.shape)
        tars = tars[:, :, :120, :120, :]
        print("TARS", tars.shape)
        imgs = padding_taasss(imgs)
        # Should already be 0 to 1
        # imgs = nor(imgs)
        imgs = preprocess.reshape_patch(imgs, args.patch_size)
        img_gen, _ = model.test(imgs, real_input_flag)
        img_gen = preprocess.reshape_patch_back(img_gen, args.patch_size)
        img_out = unpadding_taasss(img_gen[:, -output_length:])
        print("SHAPE", tars.shape, img_out.shape)
        mse = np.mean(np.square(tars - img_out))
        loss = loss + mse
        print("LOSS", loss, "MSE", mse)
    return loss / steps
```
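Plugging this file's defaults into the `real_input_flag` shape used above (batch_size=1, total_length=25, input_length=5, img_width=512, hence patch_size=16 from the earlier if-chain), the arithmetic works out as follows:

```python
batch, total_length, input_length, img_width, patch_size, channels = 1, 25, 5, 512, 16, 1
flag_shape = (
    batch,
    total_length - input_length - 1,  # 19 decoding steps that could be teacher-forced
    img_width // patch_size,          # 32
    img_width // patch_size,          # 32
    patch_size ** 2 * channels,       # 256
)
print(flag_shape)  # (1, 19, 32, 32, 256); all zeros means no ground truth is fed during validation
```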
```python
def wrapper_test(model: Model):
    test_save_root = args.gen_frm_dir
    loss = 0
    real_input_flag = np.zeros(
        (
            args.batch_size,
            args.total_length - args.input_length - 1,
            args.img_width // args.patch_size,
            args.img_width // args.patch_size,
            args.patch_size ** 2 * args.img_channel,
        )
    )
    output_length = args.total_length - args.input_length
    iterator = get_batcher()
    steps = 10
    for index in range(steps):
        print("Sample is:", index)
        dat = next(iterator)
        # (25, 1, 1, 480, 480) to (1, 25, 480, 480, 1)
        dat = np.squeeze(dat)
        dat = np.expand_dims(dat, axis=0)
        dat = np.expand_dims(dat, axis=4)
        # Should already be 0 to 1
        # dat = nor(dat)
        tars = dat[:, -output_length:]
        ims = padding_taasss(dat)
        ims = preprocess.reshape_patch(ims, args.patch_size)
        ims = ims.astype(np.float64)
        img_gen, _ = model.test(ims, real_input_flag)
        img_gen = preprocess.reshape_patch_back(img_gen, args.patch_size)
        img_out = unpadding_taasss(img_gen[:, -output_length:])
        mse = np.mean(np.square(tars - img_out))
        # Should already be 0 to 1
        # img_out = de_nor(img_out)
        loss += mse
        bat_ind = 0
        for ind in range(index - batch_size, index, 1):
            save_fold = test_save_root / f"sample_{ind}"
            for t in range(6, 16, 1):
                imsave(
                    save_fold / f"img_{t}.png",
                    img_out[bat_ind, t - 6, :, :, 0],
                )
            bat_ind += 1
    return loss / steps
```
```python
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
save_dir = Path(args.save_dir).parent
if save_dir.exists():
    shutil.rmtree(save_dir)
save_dir.mkdir()
model.save(itr)

# Remove dataset_generated directory
if os.path.exists(args.gen_frm_dir):
    shutil.rmtree(args.gen_frm_dir)
os.makedirs(args.gen_frm_dir)

print("Initializing models")
model = Model(args)
print("MODELTYPE", type(model))
model.load()
# test_mse = wrapper_test(model)
wrapper_train(model)
```
cikm_inter_dst_predrnn_run_taasss_test.ipynb
0 → 100644

This diff is collapsed.
cikm_inter_dst_predrnn_run_taasss_utils.py
0 → 100644

```python
import numpy as np
from pathlib import Path
import pandas as pd
import h5py
import cv2

from data_provider.CIKM.taasss import infinite_batcher


def get_batcher(args):
    data_dir = Path("/") / "data1" / "meteotn_data_2010_2016"
    metadata_file = data_dir / "run_metadata.csv"
    all_data = h5py.File(
        data_dir / "hdf_archives" / "all_data.hdf5",
        "r",
        libver="latest",
    )
    outlier_mask = cv2.imread(str(data_dir / "mask.png"), 0)
    metadata = pd.read_csv(metadata_file, index_col="id")
    metadata["start_datetime"] = pd.to_datetime(metadata["start_datetime"])
    metadata["end_datetime"] = pd.to_datetime(metadata["end_datetime"])
    sort_meta = metadata.sample(frac=1)
    batcher = infinite_batcher(
        all_data,
        sort_meta,
        outlier_mask,
        shuffle=False,
        batch_size=args.batch_size,
        # TODO: UPDATE FROM 1 TO 4
        filter_threshold=0,
    )
    # os.environ["CUDA_VISIBLE_DEVICES"] = "1"
    # Filter images in the ROI
    for imgs, _, masks in batcher:
        yield imgs * ~masks
```
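Since `get_batcher` is a generator wrapping an infinite batcher, callers pull batches with `next()` rather than iterating to exhaustion, and `imgs * ~masks` zeroes out pixels flagged by the boolean outlier mask. A sketch of typical consumption, where `SimpleNamespace` stands in for the parsed args:

```python
from types import SimpleNamespace

batcher = get_batcher(SimpleNamespace(batch_size=1))  # infinite generator
imgs = next(batcher)                                  # one masked batch per call
```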
```python
def padding_taasss(array: np.ndarray, args) -> np.ndarray:
    """Pad (1, 25, 480, 480, 1) to (1, 25, img_width, img_width, 1)."""
    zeros = np.zeros((1, 25, args.img_width, args.img_width, 1))
    if args.img_width == 512:
        zeros[:, :, 16:496, 16:496, :] = array
    elif args.img_width == 256:
        zeros[:, :, 8:248, 8:248, :] = array[:, :, :240, :240, :]
    elif args.img_width == 128:
        zeros[:, :, 4:124, 4:124, :] = array[:, :, :120, :120, :]
    return zeros


def unpadding_taasss(array: np.ndarray, args) -> np.ndarray:
    """Crop (1, 25, img_width, img_width, 1) back to the unpadded frame."""
    if args.img_width == 512:
        return array[:, :, 16:496, 16:496, :]
    elif args.img_width == 256:
        return array[:, :, 8:248, 8:248, :]
    elif args.img_width == 128:
        return array[:, :, 4:124, 4:124, :]
```
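A quick way to sanity-check the pair: at img_width=512 the pad/crop round trip is lossless, while at 256 and 128 the input is first cropped (to 240 or 120 pixels per side) and so loses data. A small check, assuming the (1, 25, 480, 480, 1) input shape:

```python
import numpy as np
from types import SimpleNamespace

args = SimpleNamespace(img_width=512)
frames = np.random.rand(1, 25, 480, 480, 1)

padded = padding_taasss(frames, args)      # (1, 25, 512, 512, 1)
restored = unpadding_taasss(padded, args)  # (1, 25, 480, 480, 1)

assert restored.shape == frames.shape
assert np.allclose(restored, frames)       # lossless only at 512
```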
core/models/model_factory.py

...

```python
@@ -3,7 +3,7 @@ import torch
import torch.nn as nn
from torch.optim import Adam
from core.models import predict
from pathlib import Path


class Model(object):
    def __init__(self, configs):
```

...

```python
@@ -26,7 +26,6 @@ class Model(object):
        if configs.model_name not in networks_map:
            raise ValueError("Name of network unknown %s" % configs.model_name)
        Network = networks_map[configs.model_name]
        print("BEF NET")
        self.network = Network(self.num_layers, self.num_hidden, configs).to(configs.device)
```

...

```python
@@ -39,19 +38,20 @@ class Model(object):
        self.MSE_criterion = nn.MSELoss(size_average=False)
        self.MAE_criterion = nn.L1Loss(size_average=False)

    def save(self, ite=None):
    def save(self, itr):
        stats = {}
        stats["net_param"] = self.network.state_dict()
        torch.save(stats, self.configs.save_dir)
        print(f"Saving model to {self.configs.save_dir}")
        save_path = Path("/") / "data1" / "IDA_LSTM_checkpoints"
        save_path.mkdir(exist_ok=True)
        save_path = save_path / f"{itr}.pth"
        torch.save(stats, save_path)
        # print(f"Saving model to {save_path}")

    def load(self):
        if os.path.exists(self.configs.save_dir):
            stats = torch.load(self.configs.save_dir)
            self.network.load_state_dict(stats["net_param"])
            print("Model loaded")
        else:
            print("Training from scratch")

    def load(self, path):
        assert os.path.exists(path), "Weights dir does not exist"
        stats = torch.load(path)
        self.network.load_state_dict(stats["net_param"])
        print("Model loaded")

    def train(self, frames, mask):
        frames_tensor = torch.FloatTensor(frames).to(self.configs.device)
```

...
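Both save paths above reduce to serializing the network's `state_dict` inside a small dict keyed by `"net_param"`; a self-contained sketch of the same pattern, with a stand-in module and a hypothetical path rather than the repo's:

```python
import torch
import torch.nn as nn

net = nn.Linear(4, 4)  # stand-in for the IDA-LSTM network
torch.save({"net_param": net.state_dict()}, "/tmp/0001.pth")  # hypothetical path

restored = nn.Linear(4, 4)
stats = torch.load("/tmp/0001.pth")
restored.load_state_dict(stats["net_param"])  # mirrors Model.load
```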
core/models/predict.py

...

```python
@@ -929,7 +929,6 @@ class CST_PredRNN(nn.Module):
class InteractionDST_PredRNN(nn.Module):
    def __init__(self, num_layers, num_hidden, configs):
        print("IN NET")
        super(InteractionDST_PredRNN, self).__init__()
        self.configs = configs
        self.frame_channel = (
```

...
requirements.txt

```
@@ -4,4 +4,6 @@ jpype1
scipy
imageio
pandas
h5py
\ No newline at end of file
h5py
tqdm
matplotlib
\ No newline at end of file
```
test.ipynb
deleted 100644 → 0

This diff is collapsed.
train.ipynb
deleted 100644 → 0
%% Cell type:code id: tags:

```python
def wrapper_train(model):
    if args.pretrained_model:
        model.load(args.pretrained_model)
    # load data
    # train_input_handle, test_input_handle = datasets_factory.data_provider(
    #     args.dataset_name, args.train_data_paths, args.valid_data_paths, args.batch_size,
    #     args.img_width, seq_length=args.total_length, is_training=True)

    eta = args.sampling_start_value
    best_mse = math.inf
    tolerate = 0
    limit = 3
    best_iter = None
    for itr in range(1, args.max_iterations + 1):
        ims = sample(batch_size=batch_size)
        ims = padding_CIKM_data(ims)

        ims = preprocess.reshape_patch(ims, args.patch_size)
        ims = nor(ims)
        eta, real_input_flag = schedule_sampling(eta, itr)

        cost = trainer.train(model, ims, real_input_flag, args, itr)

        if itr % args.display_interval == 0:
            print("itr: " + str(itr))
            print("training loss: " + str(cost))

        if itr % args.test_interval == 0:
            print("validation one ")
            valid_mse = wrapper_valid(model)
            print("validation mse is:", str(valid_mse))

            if valid_mse < best_mse:
                best_mse = valid_mse
                best_iter = itr
                tolerate = 0
                model.save()
            else:
                tolerate = tolerate + 1

            if tolerate == limit:
                model.load()
                test_mse = wrapper_test(model)
                print("the best valid mse is:", str(best_mse))
                print("the test mse is ", str(test_mse))
                break


def wrapper_valid(model):
    loss = 0
    count = 0
    index = 1
    flag = True
    # img_mse, ssim = [], []

    # for i in range(args.total_length - args.input_length):
    #     img_mse.append(0)
    #     ssim.append(0)

    real_input_flag = np.zeros(
        (
            args.batch_size,
            args.total_length - args.input_length - 1,
            args.img_width // args.patch_size,
            args.img_width // args.patch_size,
            args.patch_size ** 2 * args.img_channel,
        )
    )
    output_length = args.total_length - args.input_length
    while flag:
        dat, (index, b_cup) = sample(batch_size, data_type="validation", index=index)
        dat = nor(dat)
        tars = dat[:, -output_length:]
        ims = padding_CIKM_data(dat)

        ims = preprocess.reshape_patch(ims, args.patch_size)
        img_gen, _ = model.test(ims, real_input_flag)
        img_gen = preprocess.reshape_patch_back(img_gen, args.patch_size)
        img_out = unpadding_CIKM_data(img_gen[:, -output_length:])

        mse = np.mean(np.square(tars - img_out))
        loss = loss + mse
        count = count + 1
        if b_cup == args.batch_size - 1:
            pass
        else:
            flag = False

    return loss / count
```

%% Cell type:code id: tags:

```python
# if args.is_training:
#     wrapper_train(model)
# else:
#     wrapper_test(model)
```