Commit 9bb0ecb7 authored by Alessia Marcolini

Move digital pathology pipeline to MPBA/digital-pathology-classification

parent a17fa855
from pathlib import Path
DATA_PATH = Path('data')
TENSORBOARD_LOG_DIR = Path('tb-runs')
DATASET_SOURCE = 'tcga_breast'
SVS_DIR = DATA_PATH / DATASET_SOURCE / 'svs'
EXPERIMENT_DATA_DIR = DATA_PATH / DATASET_SOURCE / 'ER'
TILES_DIR = EXPERIMENT_DATA_DIR / 'ER_balanced_tiles_level_0'
ORIGINAL_LABEL_COLUMN = 'patient.breast_carcinoma_estrogen_receptor_status'
N_SPLITS = 10
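# Expected on-disk layout, derived from the paths above:
#
#   data/
#   └── tcga_breast/
#       ├── svs/                            # whole-slide images
#       └── ER/
#           └── ER_balanced_tiles_level_0/  # extracted tiles for the ER task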
"""This package includes all the modules related to data loading and preprocessing.
To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
"""
import importlib
from torch.utils import data
from DP_classification.datasets.base_dataset import BaseDataset
from DP_classification.datasets.tile_dataset import TileDataset
def find_dataset_using_name(dataset_name):
"""Import the module "data/[dataset_name]_dataset.py".
In the file, the class called DatasetNameDataset() will
be instantiated. It has to be a subclass of BaseDataset,
and it is case-insensitive.
"""
dataset_filename = "DP_classification.datasets." + dataset_name + "_dataset"
datasetlib = importlib.import_module(dataset_filename)
dataset = None
target_dataset_name = dataset_name.replace('_', '') + 'dataset'
for name, cls in datasetlib.__dict__.items():
if name.lower() == target_dataset_name.lower() and issubclass(cls, BaseDataset):
dataset = cls
if dataset is None:
raise NotImplementedError(
'In {0}.py, there should be a subclass of BaseDataset with class name that matches {1} in lowercase.'.format(
dataset_filename, target_dataset_name
)
)
return dataset
def create_dataset(configuration, *args, **kwargs):
"""Create a dataset given the configuration (loaded from the json file).
This function wraps the class CustomDatasetDataLoader.
This is the main interface between this package and train.py/validate.py
Example:
        from DP_classification.datasets import create_dataset
        dataset = create_dataset(configuration, *args, **kwargs)
"""
data_loader = CustomDatasetDataLoader(configuration, *args, **kwargs)
dataset = data_loader.load_data()
return dataset
class CustomDatasetDataLoader:
"""Wrapper class of Dataset class that performs multi-threaded data loading
according to the configuration.
"""
def __init__(self, configuration, *args, **kwargs):
self.configuration = configuration
dataset_class = find_dataset_using_name(configuration['dataset_name'])
self.dataset = dataset_class(configuration, *args, **kwargs)
print("dataset [{0}] was created".format(type(self.dataset).__name__))
# if we use custom collation, define it as a staticmethod in the dataset class
custom_collate_fn = getattr(self.dataset, "collate_fn", None)
if callable(custom_collate_fn):
self.dataloader = data.DataLoader(
self.dataset,
**configuration['loader_params'],
collate_fn=custom_collate_fn
)
else:
self.dataloader = data.DataLoader(
self.dataset, **configuration['loader_params']
)
def load_data(self):
return self
def get_custom_dataloader(self, custom_configuration):
"""Get a custom dataloader (e.g. for exporting the model).
This dataloader may use different configurations than the
default train_dataloader and val_dataloader.
"""
        custom_collate_fn = getattr(self.dataset, "collate_fn", None)
        if callable(custom_collate_fn):
            custom_dataloader = data.DataLoader(
                self.dataset,
                **custom_configuration['loader_params'],
                collate_fn=custom_collate_fn
            )
        else:
            custom_dataloader = data.DataLoader(
                self.dataset, **custom_configuration['loader_params']
            )
return custom_dataloader
    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.dataset)
    def __iter__(self):
        """Yield batches of data."""
        for batch in self.dataloader:  # 'batch' avoids shadowing torch.utils.data
            yield batch
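# Example usage (a sketch; the keys 'dataset_name' and 'loader_params' are the
# ones consumed above, all concrete values are illustrative):
#
#   configuration = {
#       'dataset_name': 'tile',
#       'input_size': [224, 224],
#       'loader_params': {'batch_size': 32, 'shuffle': True, 'num_workers': 4},
#   }
#   dataset = create_dataset(configuration, TILES_DIR, metadata_tiles)
#   for batch in dataset:
#       images, labels = batch['image'], batch['label']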
"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets. Also
includes some transformation functions.
"""
from abc import ABC, abstractmethod
import torch.utils.data as data
class BaseDataset(data.Dataset, ABC):
"""This class is an abstract base class (ABC) for datasets.
"""
def __init__(self, configuration):
"""Initialize the class; save the configuration in the class.
"""
self.configuration = configuration
@abstractmethod
def __len__(self):
"""Return the total number of images in the dataset."""
return 0
@abstractmethod
def __getitem__(self, index):
"""Return a data point (usually data and labels in
a supervised setting).
"""
pass
def pre_epoch_callback(self, epoch):
"""Callback to be called before every epoch.
"""
pass
def post_epoch_callback(self, epoch):
"""Callback to be called after every epoch.
"""
pass
from pathlib import Path
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from DP_classification.datasets.base_dataset import BaseDataset
class TileDataset(BaseDataset):
    def __init__(
        self, configuration, dataset_dir, metadata_tiles, mode='train', seed=0
    ):
        """Initialize the dataset from a tiles metadata table.

        `metadata_tiles` is expected to provide the columns 'tile_filename',
        'label', 'patient', 'mode' and, optionally, 'split_id'.
        """
        assert mode in ['train', 'test']
        super().__init__(configuration)
        self.dataset_dir = Path(dataset_dir)
        self.mode = mode
        self.seed = seed
        np.random.seed(self.seed)
if 'split_id' in configuration:
metadata_tiles = metadata_tiles[
metadata_tiles['split_id'] == configuration['split_id']
]
        # Keep only the tiles belonging to the requested mode ('train' or 'test')
        metadata_tiles = metadata_tiles[metadata_tiles['mode'] == mode]
self.filenames = metadata_tiles['tile_filename'].values
self.labels = metadata_tiles['label'].values
self.patients = metadata_tiles['patient'].values
def __getitem__(self, i):
filename = self.filenames[i]
img_path = self.dataset_dir / filename
image = Image.open(img_path).convert('RGB')
image_out = self.__transform__(image)
image.close()
label = self.labels[i]
patient = self.patients[i]
sample = {
'image': image_out,
'label': label,
'filename': filename,
'patient': patient,
}
return sample
def __shuffle__(self):
idx_permut = np.random.permutation(self.__len__())
self.filenames = self.filenames[idx_permut]
self.labels = self.labels[idx_permut]
self.patients = self.patients[idx_permut]
def __len__(self):
return len(self.filenames)
def __transform__(self, image):
# TODO: compute mean and std of GTEX dataset
normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
if self.mode == 'train':
transformations = transforms.Compose(
[
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
# transforms.RandomRotation(90), #TODO avoid borders
transforms.RandomResizedCrop(self.configuration['input_size'][0]),
transforms.ToTensor(),
normalize,
]
)
image = transformations(image)
else:
transformations = transforms.Compose(
[
transforms.CenterCrop(size=self.configuration['input_size']),
transforms.ToTensor(),
normalize,
]
)
image = transformations(image)
return image
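# Example usage (a sketch; all concrete values are illustrative). The metadata
# table must carry the columns read in __init__: 'tile_filename', 'label',
# 'patient', 'mode' and, when the configuration sets 'split_id', 'split_id':
#
#   import pandas as pd
#   metadata_tiles = pd.DataFrame({
#       'tile_filename': ['tile_0001.png'],
#       'label': [1],
#       'patient': ['TCGA-XX-0001'],
#       'mode': ['train'],
#       'split_id': [0],
#   })
#   ds = TileDataset({'input_size': [224, 224], 'split_id': 0},
#                    TILES_DIR, metadata_tiles, mode='train')
#   sample = ds[0]  # {'image': <tensor>, 'label': 1, 'filename': ..., 'patient': ...}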
"""This package contains modules related to objective functions, optimizations, and network architectures.
To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
You need to implement the following five functions:
-- <__init__>: initialize the class; first call BaseModel.__init__(self, configuration).
-- <set_input>: unpack data from dataset and apply preprocessing.
-- <forward>: produce intermediate results.
-- <optimize_parameters>: calculate loss, gradients, and update network weights.
In the function <__init__>, you need to define four lists:
-- self.network_names (str list): define networks used in our training.
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them.
"""
import importlib
from DP_classification.models.base_model import BaseModel
def find_model_using_name(model_name):
"""Import the module "models/[model_name]_model.py".
In the file, the class called DatasetNameModel() will
be instantiated. It has to be a subclass of BaseModel,
and it is case-insensitive.
"""
model_filename = "DP_classification.models." + model_name + "_model"
modellib = importlib.import_module(model_filename)
model = None
target_model_name = model_name.replace('_', '') + 'model'
for name, cls in modellib.__dict__.items():
if name.lower() == target_model_name.lower() and issubclass(cls, BaseModel):
model = cls
    if model is None:
        raise NotImplementedError(
            'In {0}.py, there should be a subclass of BaseModel with class name that matches {1} in lowercase.'.format(
                model_filename, target_model_name
            )
        )
return model
def create_model(configuration):
"""Create a model given the configuration.
This is the main interface between this package and train.py/validate.py
"""
model = find_model_using_name(configuration['model_name'])
instance = model(configuration)
print("model [{0}] was created".format(type(instance).__name__))
return instance
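# Example usage (a sketch; all values are illustrative, and a concrete model
# class may require additional configuration keys):
#
#   configuration = {
#       'model_name': 'classification_tile',
#       'is_train': True,
#       'checkpoint_path': 'checkpoints',
#       'load_checkpoint': -1,
#   }
#   model = create_model(configuration)
#   model.setup()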
import os
from abc import ABC, abstractmethod
from collections import OrderedDict
from pathlib import Path
import torch
from DP_classification.utils import get_scheduler, transfer_to_device
class BaseModel(ABC):
"""This class is an abstract base class (ABC) for models.
"""
def __init__(self, configuration):
"""Initialize the BaseModel class.
Parameters:
configuration: Configuration dictionary.
        When creating your custom class, you need to implement your own initialization.
        In this function, you should first call <BaseModel.__init__(self, configuration)>.
Then, you need to define these lists:
-- self.network_names (str list): define networks used in our training.
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network.
If two networks are updated at the same time, you can use itertools.chain to group
them. See cycle_gan_model.py for an example.
"""
self.configuration = configuration
self.is_train = configuration['is_train']
self.use_cuda = torch.cuda.is_available()
self.device = torch.device('cuda:0') if self.use_cuda else torch.device('cpu')
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.enabled = False
self.save_dir = configuration['checkpoint_path']
self.network_names = []
self.loss_names = []
self.optimizers = []
self.visual_names = []
def __setattr__(self, name, value):
if name == 'save_dir':
if not isinstance(value, Path):
value = Path(value)
super().__setattr__(name, value)
if not value.exists():
os.makedirs(value)
else:
super().__setattr__(name, value)
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
The implementation here is just a basic setting of input and label. You may implement
other functionality in your own model.
"""
self.input = transfer_to_device(input[0], self.device)
self.label = transfer_to_device(input[1], self.device)
@abstractmethod
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
pass
@abstractmethod
def optimize_parameters(self):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
pass
def setup(self, verbose=False):
"""Load and print networks; create schedulers.
"""
if self.configuration['load_checkpoint'] >= 0:
last_checkpoint = self.configuration['load_checkpoint']
else:
last_checkpoint = -1
if last_checkpoint >= 0:
# enable restarting training
self.load_networks(last_checkpoint)
if self.is_train:
self.load_optimizers(last_checkpoint)
for o in self.optimizers:
o.param_groups[0]['lr'] = o.param_groups[0][
'initial_lr'
] # reset learning rate
self.schedulers = [
get_scheduler(optimizer, self.configuration)
for optimizer in self.optimizers
]
if last_checkpoint > 0:
for s in self.schedulers:
for _ in range(last_checkpoint):
s.step()
if verbose:
self.print_networks()
    def train(self):
        """Set all the networks to train mode."""
for name in self.network_names:
if isinstance(name, str):
net = getattr(self, name.lower())
net.train()
    def eval(self):
        """Set all the networks to eval mode (used at test time)."""
for name in self.network_names:
if isinstance(name, str):
net = getattr(self, name.lower())
net.eval()
def test(self):
"""Forward function used in test time.
This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
"""
with torch.no_grad():
self.forward()
def update_learning_rate(self):
"""Update learning rates for all the networks; called at the end of every epoch"""
for scheduler in self.schedulers:
scheduler.step()
lr = self.optimizers[0].param_groups[0]['lr']
print('learning rate = {0:.7f}'.format(lr))
def save_networks(self, epoch, filename=None):
"""Save all the networks to the disk.
"""
        for name in self.network_names:
            if isinstance(name, str):
                if filename is None:
                    save_filename = '{0}_net_{1}.pth'.format(epoch, name)
                else:
                    save_filename = filename
save_path = self.save_dir / 'weights' / save_filename
net = getattr(self, name.lower())
if self.use_cuda:
torch.save(net.cpu().state_dict(), save_path)
net.to(self.device)
else:
torch.save(net.cpu().state_dict(), save_path)
def load_networks(self, epoch):
"""Load all the networks from the disk.
"""
for name in self.network_names:
if isinstance(name, str):
load_filename = '{0}_net_{1}.pth'.format(epoch, name)
load_path = self.save_dir / 'weights' / load_filename
net = getattr(self, name.lower())
if isinstance(net, torch.nn.DataParallel):
net = net.module
print('loading the model from {0}'.format(load_path))
state_dict = torch.load(load_path, map_location=self.device)
if hasattr(state_dict, '_metadata'):
del state_dict._metadata
net.load_state_dict(state_dict)
def save_optimizers(self, epoch):
"""Save all the optimizers to the disk for restarting training.
"""
for i, optimizer in enumerate(self.optimizers):
save_filename = '{0}_optimizer_{1}.pth'.format(epoch, i)
save_path = self.save_dir / 'weights' / save_filename
torch.save(optimizer.state_dict(), save_path)
def load_optimizers(self, epoch):
"""Load all the optimizers from the disk.
"""
for i, optimizer in enumerate(self.optimizers):
load_filename = '{0}_optimizer_{1}.pth'.format(epoch, i)
load_path = self.save_dir / 'weights' / load_filename
print('loading the optimizer from {0}'.format(load_path))
state_dict = torch.load(load_path)
if hasattr(state_dict, '_metadata'):
del state_dict._metadata
optimizer.load_state_dict(state_dict)
def print_networks(self):
"""Print the total number of parameters in the network and network architecture.
"""
print('Networks initialized')
for name in self.network_names:
if isinstance(name, str):
                net = getattr(self, name.lower())
num_params = 0
for param in net.parameters():
num_params += param.numel()
print(net)
print(
'[Network {0}] Total number of parameters : {1:.3f} M'.format(
name, num_params / 1e6
)
)
    def set_requires_grad(self, requires_grad=False):
        """Set requires_grad for all the networks to avoid unnecessary computations.
        """
for name in self.network_names:
if isinstance(name, str):
net = getattr(self, name.lower())
for param in net.parameters():
param.requires_grad = requires_grad
def get_current_losses(self):
"""Return traning losses / errors. train.py will print out these errors on console"""
errors_ret = dict()
for name in self.loss_names:
if isinstance(name, str):
errors_ret[name] = float(
getattr(self, 'loss_' + name)
) # float(...) works for both scalar tensor and float number
return errors_ret
def pre_epoch_callback(self, epoch):
pass
def post_epoch_callback(self, epoch, visualizer):
pass
def get_hyperparam_result(self):
"""Returns the final training result for hyperparameter tuning (e.g. best
validation loss).
"""
pass
def export(self):
"""Exports all the networks of the model using JIT tracing. Requires that the
input is set.
"""
for name in self.network_names:
if isinstance(name, str):
net = getattr(self, name.lower())
export_path = os.path.join(
self.configuration['export_path'],
'exported_net_{}.pth'.format(name),
)
if isinstance(
self.input, list
): # we have to modify the input for tracing
self.input = [tuple(self.input)]
traced_script_module = torch.jit.trace(net, self.input)
traced_script_module.save(export_path)
def get_current_visuals(self):
"""Return visualization images. train.py will display these images."""
visual_ret = OrderedDict()
for name in self.visual_names:
if isinstance(name, str):
visual_ret[name] = getattr(self, name.lower())
return visual_ret
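# Checkpoint round-trip supported by the methods above (a sketch):
#
#   model.save_networks(epoch)      # -> <checkpoint_path>/weights/<epoch>_net_<name>.pth
#   model.save_optimizers(epoch)    # -> <checkpoint_path>/weights/<epoch>_optimizer_<i>.pth
#   # later, with configuration['load_checkpoint'] = epoch:
#   model.setup()                   # reloads weights and optimizers, fast-forwards schedulers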
import os
from collections import OrderedDict
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score
from sklearn.metrics import matthews_corrcoef as mcor
from sklearn.metrics import roc_auc_score
from torch.utils.tensorboard import SummaryWriter
from DP_classification.models.base_model import BaseModel
from DP_classification.models.dapper_models import ResNet_model
from DP_classification.utils import transfer_to_device
class ClassificationTileModel(BaseModel):
def __init__(self, configuration):
super().__init__(configuration)