Commit e0d98520 authored by Marco Di Francesco's avatar Marco Di Francesco 🍉
Browse files

Add support for single array loading in taasss

parent 104bdaa0
This diff is collapsed.
This diff is collapsed.
......@@ -25,11 +25,9 @@ def get_batcher(args):
sort_meta,
outlier_mask,
shuffle=False,
batch_size=args.batch_size, # TODO: UPDATE FROM 1 TO 4
batch_size=args.batch_size,
filter_threshold=0,
)
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# Filter images in the ROI
for imgs, _, masks in batcher:
yield imgs * ~masks
......@@ -65,3 +63,4 @@ def unpadding_taasss(array: np.ndarray, args) -> np.ndarray:
return array[:, :, 8:248, 8:248, :]
elif args.img_width == 128:
return array[:, :, 4:124, 4:124, :]
raise ValueError
......@@ -5,6 +5,7 @@ from torch.optim import Adam
from core.models import predict
from pathlib import Path
class Model(object):
def __init__(self, configs):
self.configs = configs
......@@ -49,7 +50,7 @@ class Model(object):
def load(self, path):
assert os.path.exists(path), "Weights dir does not exist"
stats = torch.load(path)
stats = torch.load(path, map_location=torch.device(self.configs.device))
self.network.load_state_dict(stats["net_param"])
print("Model loaded")
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment