Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
Marco Di Francesco
IDA LSTM
Commits
9988109c
Commit
9988109c
authored
Jun 16, 2021
by
Marco Di Francesco
🍉
Browse files
Fix test model for taasss
parent
dee581de
Changes
3
Hide whitespace changes
Inline
Side-by-side
.gitignore
View file @
9988109c
dataset/
checkpoints/
dataset_generated/
.venv
\ No newline at end of file
.venv
# Python files
__pycache__/
.pyc
.ipynb_checkpoints
cikm_inter_dst_predrnn_run.py
View file @
9988109c
...
...
@@ -5,6 +5,7 @@ import math
import
shutil
import
numpy
as
np
import
torch
import
core.trainer
as
trainer
from
core.models.model_factory
import
Model
...
...
@@ -263,11 +264,11 @@ def wrapper_test(model):
tars
=
dat
[:,
-
output_length
:]
ims
=
padding_CIKM_data
(
dat
)
ims
=
preprocess
.
reshape_patch
(
ims
,
args
.
patch_size
)
print
(
ims
.
shape
,
real_input_flag
.
shape
)
img_gen
,
_
=
model
.
test
(
ims
,
real_input_flag
)
img_gen
=
preprocess
.
reshape_patch_back
(
img_gen
,
args
.
patch_size
)
img_out
=
unpadding_CIKM_data
(
img_gen
[:,
-
output_length
:])
mse
=
np
.
mean
(
np
.
square
(
tars
-
img_out
))
print
(
"MSE"
,
mse
)
img_out
=
de_nor
(
img_out
)
loss
+=
mse
...
...
@@ -275,6 +276,7 @@ def wrapper_test(model):
bat_ind
=
0
print
(
"index is:"
,
index
)
print
(
"INDEX"
,
index
,
b_cup
)
for
ind
in
range
(
index
-
batch_size
,
index
,
1
):
save_fold
=
test_save_root
+
"sample_"
+
str
(
ind
)
+
"/"
...
...
cikm_inter_dst_predrnn_run_taasss.py
View file @
9988109c
...
...
@@ -34,7 +34,7 @@ parser.add_argument("--is_parallel", type=bool, default=False)
parser
.
add_argument
(
"--save_dir"
,
type
=
str
,
default
=
"checkpoints/model.ckpt"
)
parser
.
add_argument
(
"--gen_frm_dir"
,
type
=
str
,
default
=
"dataset_generated/"
)
parser
.
add_argument
(
"--input_length"
,
type
=
int
,
default
=
5
)
parser
.
add_argument
(
"--total_length"
,
type
=
int
,
default
=
15
)
# 25
parser
.
add_argument
(
"--total_length"
,
type
=
int
,
default
=
15
)
# 25
parser
.
add_argument
(
"--img_width"
,
type
=
int
,
default
=
128
)
parser
.
add_argument
(
"--img_channel"
,
type
=
int
,
default
=
1
)
...
...
@@ -264,6 +264,13 @@ def crop_taasss(array):
return
array
[:,
:
15
,
:
128
,
:
128
,
:]
def
uncrop_taasss
(
array
):
# They had to go from 101 to 128, we did 480 to 480
zeros
=
np
.
zeros
((
1
,
10
,
480
,
480
,
1
))
zeros
[:,
:,
:
128
,
:
128
,
:]
=
array
return
zeros
def
wrapper_test
(
model
):
test_save_root
=
args
.
gen_frm_dir
clean_fold
(
test_save_root
)
...
...
@@ -288,10 +295,14 @@ def wrapper_test(model):
)
)
output_length
=
args
.
total_length
-
args
.
input_length
train_model_iter
=
get_batcher
()
index
=
1
b_cup
=
0
# ?????????
while
flag
:
index
+=
1
# print("Sample is:", index)
# dat, (index, b_cup) = sample(batch_size, data_type="test", index=index)
train_model_iter
=
get_batcher
()
train_batch
,
sample_datetimes
,
train_mask
=
next
(
train_model_iter
)
# Just because they called it this way
dat
=
train_batch
...
...
@@ -305,13 +316,13 @@ def wrapper_test(model):
tars
=
dat
[:,
-
output_length
:]
ims
=
crop_taasss
(
dat
)
ims
=
preprocess
.
reshape_patch
(
ims
,
args
.
patch_size
)
print
(
ims
.
shape
,
real_input_flag
.
shape
)
img_gen
,
_
=
model
.
test
(
train_batch
,
real_input_flag
)
#
img_gen = preprocess.reshape_patch_back(img_gen, args.patch_size)
#
img_out = un
padding_CIKM_data
(img_gen[:, -output_length:])
ims
=
ims
.
astype
(
np
.
float64
)
img_gen
,
_
=
model
.
test
(
ims
,
real_input_flag
)
img_gen
=
preprocess
.
reshape_patch_back
(
img_gen
,
args
.
patch_size
)
img_out
=
un
crop_taasss
(
img_gen
[:,
-
output_length
:])
mse
=
np
.
mean
(
np
.
square
(
tars
-
img_out
))
print
(
index
,
"MSE"
,
mse
)
img_out
=
de_nor
(
img_out
)
loss
+=
mse
count
+=
1
...
...
@@ -331,7 +342,6 @@ def wrapper_test(model):
pass
else
:
flag
=
False
return
loss
/
count
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment