Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
Marco Di Francesco
IDA LSTM
Commits
e0d98520
Commit
e0d98520
authored
Jul 29, 2021
by
Marco Di Francesco
🍉
Browse files
Add support for single array loading in taasss
parent
104bdaa0
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
cikm_inter_dst_predrnn_run_taasss_test.ipynb
View file @
e0d98520
This diff is collapsed.
Click to expand it.
This diff is collapsed.
Click to expand it.
cikm_inter_dst_predrnn_run_taasss_utils.py
View file @
e0d98520
...
...
@@ -25,11 +25,9 @@ def get_batcher(args):
sort_meta
,
outlier_mask
,
shuffle
=
False
,
batch_size
=
args
.
batch_size
,
# TODO: UPDATE FROM 1 TO 4
batch_size
=
args
.
batch_size
,
filter_threshold
=
0
,
)
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# Filter images in the ROI
for
imgs
,
_
,
masks
in
batcher
:
yield
imgs
*
~
masks
...
...
@@ -65,3 +63,4 @@ def unpadding_taasss(array: np.ndarray, args) -> np.ndarray:
return
array
[:,
:,
8
:
248
,
8
:
248
,
:]
elif
args
.
img_width
==
128
:
return
array
[:,
:,
4
:
124
,
4
:
124
,
:]
raise
ValueError
core/models/model_factory.py
View file @
e0d98520
...
...
@@ -5,6 +5,7 @@ from torch.optim import Adam
from
core.models
import
predict
from
pathlib
import
Path
class
Model
(
object
):
def
__init__
(
self
,
configs
):
self
.
configs
=
configs
...
...
@@ -49,7 +50,7 @@ class Model(object):
def
load
(
self
,
path
):
assert
os
.
path
.
exists
(
path
),
"Weights dir does not exist"
stats
=
torch
.
load
(
path
)
stats
=
torch
.
load
(
path
,
map_location
=
torch
.
device
(
self
.
configs
.
device
)
)
self
.
network
.
load_state_dict
(
stats
[
"net_param"
])
print
(
"Model loaded"
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment