Commit 65b19d15 authored by Alessia Marcolini
Browse files

Implementation of __getitem__, __shuffle__ and __transform__

parent e4c2c3c0
from pathlib import Path
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from DP_classification.datasets.base_dataset import BaseDataset
class TileDataset(BaseDataset):
    """Dataset of image tiles listed in a metadata table.

    Each sample is loaded from ``dataset_dir / tile_filename``, converted to
    RGB, passed through the mode-specific transformation pipeline and returned
    together with its label and filename.

    Attributes:
        dataset_dir: Directory containing the tile images, as a ``Path``.
        mode: Either ``'train'`` or ``'test'``; selects both the rows kept
            from the metadata table and the transformation pipeline.
        seed: Seed passed to ``np.random.seed`` at construction time.
        filenames: Array of tile filenames, one per sample.
        labels: Array of labels cast to ``str``, aligned with ``filenames``.
        patients: Array of patient identifiers, aligned with ``filenames``.
    """

    def __init__(
        self, configuration, dataset_dir, metadata_tiles, mode='train', seed=0
    ):
        """Select the tiles belonging to ``mode`` and store their metadata.

        Args:
            configuration: Mapping forwarded to ``BaseDataset.__init__``. If
                it contains ``'split_id'``, only metadata rows with that
                split id are kept. ``'input_size'`` is read later by
                ``__transform__``.
            dataset_dir: Directory containing the tile images.
            metadata_tiles: Table indexed by column name and boolean mask
                (presumably a pandas DataFrame -- confirm against callers)
                with columns ``'mode'``, ``'tile_filename'``, ``'label'``,
                ``'patient'`` and, optionally, ``'split_id'``.
            mode: ``'train'`` or ``'test'``.
            seed: Seed for NumPy's global random generator.

        Raises:
            ValueError: If ``mode`` is neither ``'train'`` nor ``'test'``.
        """
        # An explicit exception (rather than ``assert``) keeps the check alive
        # under ``python -O``, where an invalid mode would otherwise silently
        # produce an empty dataset.
        if mode not in ('train', 'test'):
            raise ValueError(
                "mode must be 'train' or 'test', got {!r}".format(mode)
            )

        super().__init__(configuration)

        self.dataset_dir = Path(dataset_dir)
        self.mode = mode
        self.seed = seed
        # NOTE: this seeds NumPy's *global* generator, which ``__shuffle__``
        # draws from; other users of ``np.random`` are affected as well.
        np.random.seed(self.seed)

        if 'split_id' in configuration:
            metadata_tiles = metadata_tiles[
                metadata_tiles['split_id'] == configuration['split_id']
            ]

        # Keep only 'train' ('test') tiles
        metadata_tiles = metadata_tiles[metadata_tiles['mode'] == mode]

        self.filenames = metadata_tiles['tile_filename'].values
        self.labels = metadata_tiles['label'].astype('str').values
        self.patients = metadata_tiles['patient'].values

    def __getitem__(self, i):
        """Load and transform the ``i``-th tile.

        Args:
            i: Index into ``filenames`` / ``labels``.

        Returns:
            A dict with keys ``'image'`` (the transformed tile), ``'label'``
            and ``'filename'``.
        """
        filename = self.filenames[i]
        img_path = self.dataset_dir / filename

        # The context manager releases the file handle even when decoding or
        # conversion fails.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')

        # Free the converted image whether or not the transform succeeds.
        try:
            image_out = self.__transform__(image)
        finally:
            image.close()

        return {
            'image': image_out,
            'label': self.labels[i],
            'filename': filename,
        }

    def __shuffle__(self):
        """Apply one random permutation to filenames, labels and patients.

        The same permutation is used for all three arrays so they stay
        aligned. Randomness comes from NumPy's global generator.
        """
        idx_permut = np.random.permutation(len(self))
        self.filenames = self.filenames[idx_permut]
        self.labels = self.labels[idx_permut]
        self.patients = self.patients[idx_permut]

    def __len__(self):
        """Return the number of tiles in the dataset."""
        return len(self.filenames)

    def __transform__(self, image):
        """Apply the mode-specific transformation pipeline to ``image``.

        In ``'train'`` mode the image is randomly flipped horizontally and
        vertically and random-resized-cropped; otherwise it is center-cropped.
        Both pipelines end with ``ToTensor`` followed by normalization.

        Args:
            image: Image to transform, as produced in ``__getitem__``.

        Returns:
            The output of the composed torchvision transformations.
        """
        # TODO: compute mean and std of GTEX dataset
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        input_size = self.configuration['input_size']

        if self.mode == 'train':
            # NOTE(review): the train crop uses only input_size[0] while the
            # test crop uses the full input_size -- confirm this is intended
            # for non-square sizes.
            # TODO: add transforms.RandomRotation(90) once borders are avoided
            steps = [
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                transforms.RandomResizedCrop(input_size[0]),
            ]
        else:
            steps = [transforms.CenterCrop(size=input_size)]

        # Tensor conversion and normalization are shared by both modes.
        steps.extend([transforms.ToTensor(), normalize])

        return transforms.Compose(steps)(image)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment