Commit efb4467f authored by Alessia Marcolini's avatar Alessia Marcolini
Browse files

Add ResNet as DAPPER model

parent 1f7c07f7
import torch
import torch.nn as nn
import torch.optim as optim
from DP_classification.models.base_model import BaseModel
from DP_classification.models.dapper_models import ResNet_model
class ClassificationTileModel(BaseModel):
def __init__(self, configuration):
super().__init__(configuration)
self.loss_names = ['classification']
self.network_names = ['ResNet']
self.network = ResNet_model(configuration['version'], configuration['pretrained_imagenet'], configuration['n_classes'], configuration['dropout'])
self.network = self.network.to(self.device)
if self.is_train: # only defined during training time
self.criterion_loss = torch.nn.CrossEntropyLoss()
try:
opt = getattr(optim, configuration['optimizer'])
except AttributeError as e:
print(e)
print('Using Adam optimizer instead.')
opt = optim.Adam
self.optimizer = opt(self.network.parameters(), lr=configuration['lr'])
self.optimizers = [self.optimizer]
# storing predictions and labels for validation
self.val_predictions = []
self.val_labels = []
self.val_images = []
def forward(self):
"""Run forward pass.
"""
self.output = self.network(self.input)
def backward(self):
"""Calculate losses; called in every training iteration.
"""
self.loss = self.criterion_loss(self.output, self.label)
def optimize_parameters(self):
"""Calculate gradients and update network weights.
"""
self.loss.backward() # calculate gradients
self.optimizer.step()
self.optimizer.zero_grad()
torch.cuda.empty_cache()
def test(self):
super().test() # run the forward pass
# save predictions and labels as flat tensors
self.val_images.append(self.input)
self.val_predictions.append(self.output)
self.val_labels.append(self.label)
def post_epoch_callback(self, epoch, visualizer):
self.val_predictions = torch.cat(self.val_predictions, dim=0)
predictions = torch.argmax(self.val_predictions, dim=1)
predictions = torch.flatten(predictions).cpu()
self.val_labels = torch.cat(self.val_labels, dim=0)
labels = torch.flatten(self.val_labels).cpu()
self.val_images = torch.squeeze(torch.cat(self.val_images, dim=0)).cpu()
# Calculate and show accuracy
val_accuracy = accuracy_score(labels, predictions)
metrics = OrderedDict()
metrics['accuracy'] = val_accuracy
visualizer.plot_current_validation_metrics(epoch, metrics)
print('Validation accuracy: {0:.3f}'.format(val_accuracy))
# Here you may do something else with the validation data such as
# displaying the validation images or calculating the ROC curve
self.val_images = []
self.val_predictions = []
self.val_labels = []
import torch
import torch.nn as nn
import torchvision.models as models
from DP_classification.models.base_model import BaseModel
class DAPPER_model(nn.Module):
def __init__(self, n_features_out, adapter_out_features=1000, version='', pretrained=False, n_classes=2, dropout=0.5):
super(DAPPER_model, self).__init__()
self.adapter = nn.Sequential(nn.Linear(n_features_out, adapter_out_features),
nn.ReLU(True),
nn.Dropout(dropout))
self.fc_moba = nn.Sequential(nn.BatchNorm1d(1000),
nn.Linear(1000, 1000),
nn.ReLU(True),
nn.Dropout(dropout),
nn.BatchNorm1d(1000),
nn.Linear(1000, 256),
nn.ReLU(True),
nn.Dropout(dropout))
self.fc_final = nn.Sequential(nn.BatchNorm1d(256),
nn.Linear(256, n_classes))
def forward(self, x):
features = self.adapter(features)
features = self.fc_moba(features)
out = self.fc_final(features)
return out
def _get_conv_layers(self, version='', pretrained=False):
raise NotImplementedError
class ResNet_model(DAPPER_model):
def __init__(self, version='152', pretrained=False, n_classes=2, dropout=0.5):
assert version in ['101', '152', '18', '34', '50'], "ResNet version not supported"
features_extractors, resnet_n_features_out = self._get_conv_layers(version, pretrained)
super(ResNet_model, self).__init__(resnet_n_features_out, 1000, version, pretrained, n_classes, dropout)
self.features_extractors = features_extractors
def _get_conv_layers(self, version='152', pretrained=False):
resnet_init = getattr(models, f'resnet{version}')
resnet = resnet_init(pretrained)
features_extractors = nn.Sequential(resnet.conv1, resnet.bn1,resnet.relu, resnet.maxpool,
resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4,
resnet.avgpool)
resnet_n_features_out = resnet.layer1[0].expansion*512
return features_extractors, resnet_n_features_out
def forward(self, x):
features = self.features_extractors(x)
features = features.view(features.size(0), -1)
out = super().forward(x)
def test(self):
super().test() # run the forward pass
{
"train_dataset_params": {
"dataset_name": "tcga_breast",
"dataset_path": "",
"loader_params": {
"batch_size": 32,
"shuffle": true,
"num_workers": 4,
"pin_memory": true
},
"input_size": [200, 200]
},
"val_dataset_params": {
"dataset_name": "tcga_breast",
"dataset_path": "",
"loader_params": {
"batch_size": 24,
"shuffle": false,
"num_workers": 4,
"pin_memory": true
},
"input_size": [200, 200]
},
"model_params": {
"model_name": "classificationTile",
"version" : "152",
"pretrained_imagenet" : true,
"n_classes" : 2,
"dropout" : 0.5,
"optimizer" : "Adam",
"is_train": true,
"max_epochs": 40,
"lr": 0.01,
"export_path": "",
"checkpoint_path": "",
"load_checkpoint": -1,
"lr_policy": "step",
"lr_decay_iters": 10
},
"visualization_params": {
"name": "2d segmentation"
},
"printout_freq": 10,
"model_update_freq": 1
}
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment