Commit 0ae27f32 authored by Alessia Marcolini's avatar Alessia Marcolini
Browse files

Add TileDataset - to be continued

parent 8233abec
......@@ -3,8 +3,11 @@
To add a custom dataset class called 'dummy', you need to add a file called '' and define a subclass 'DummyDataset' inherited from BaseDataset.
import importlib
from torch.utils import data
from datasets.base_dataset import BaseDataset
from DP_classification.datasets.base_dataset import BaseDataset
from DP_classification.datasets.tile_dataset import TileDataset
def find_dataset_using_name(dataset_name):
......@@ -14,7 +17,7 @@ def find_dataset_using_name(dataset_name):
be instantiated. It has to be a subclass of BaseDataset,
and it is case-insensitive.
dataset_filename = "datasets." + dataset_name + "_dataset"
dataset_filename = "DP_classification.datasets." + dataset_name + "_dataset"
datasetlib = importlib.import_module(dataset_filename)
dataset = None
from DP_classification.datasets.base_dataset import BaseDataset
class TileDataset(BaseDataset):
def __init__(self, dataset_dir, metadata_tiles, configuration, mode, seed=0):
assert mode in ['train', 'test']
self.dataset_dir = dataset_dir
self.mode = mode
if 'split_id' in configuration:
metadata_tiles = metadata_tiles[metadata_tiles['split_id'] == configuration['split_id']]
# Keep only 'train' ('test') tiles
metadata_tiles = metadata_tiles[metadata_tiles['mode'] == mode]
self.filenames = metadata_tiles['tile_filename'].values
self.labels = metadata_tiles['label'].values
self.patients = metadata_tiles['patient'].values
def __getitem__(self, index):
def __len__(self):
return len(self.filenames)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment