Commit 15da5952 authored by Alessia Marcolini's avatar Alessia Marcolini
Browse files

Change loss naming convention + black formatting

parent e8cbaafc
......@@ -4,6 +4,7 @@ import torch.optim as optim
from DP_classification.models.base_model import BaseModel
from DP_classification.models.dapper_models import ResNet_model
from DP_classification.utils import transfer_to_device
class ClassificationTileModel(BaseModel):
......@@ -13,8 +14,13 @@ class ClassificationTileModel(BaseModel):
self.loss_names = ['classification']
self.network_names = ['ResNet']
self.network = ResNet_model(configuration['version'], configuration['pretrained_imagenet'], configuration['n_classes'], configuration['dropout'])
self.network = self.network.to(self.device)
self.resnet = ResNet_model(
configuration['version'],
configuration['pretrained_imagenet'],
configuration['n_classes'],
configuration['dropout'],
)
self.resnet = self.resnet.to(self.device)
if self.is_train: # only defined during training time
self.criterion_loss = torch.nn.CrossEntropyLoss()
......@@ -25,7 +31,8 @@ class ClassificationTileModel(BaseModel):
print(e)
print('Using Adam optimizer instead.')
opt = optim.Adam
self.optimizer = opt(self.network.parameters(), lr=configuration['lr'])
self.optimizer = opt(self.resnet.parameters(), lr=configuration['lr'])
self.optimizers = [self.optimizer]
# storing predictions and labels for validation
......@@ -36,30 +43,29 @@ class ClassificationTileModel(BaseModel):
def forward(self):
"""Run forward pass.
"""
self.output = self.network(self.input)
self.output = self.resnet(self.input)
def backward(self):
"""Calculate losses; called in every training iteration.
"""
self.loss = self.criterion_loss(self.output, self.label)
self.loss_classification = self.criterion_loss(self.output, self.label)
def optimize_parameters(self):
"""Calculate gradients and update network weights.
"""
self.loss.backward() # calculate gradients
self.loss_classification.backward() # calculate gradients
self.optimizer.step()
self.optimizer.zero_grad()
torch.cuda.empty_cache()
def test(self):
super().test() # run the forward pass
super().test() # run the forward pass
# save predictions and labels as flat tensors
self.val_images.append(self.input)
self.val_predictions.append(self.output)
self.val_labels.append(self.label)
def post_epoch_callback(self, epoch, visualizer):
self.val_predictions = torch.cat(self.val_predictions, dim=0)
predictions = torch.argmax(self.val_predictions, dim=1)
......@@ -85,3 +91,9 @@ class ClassificationTileModel(BaseModel):
self.val_images = []
self.val_predictions = []
self.val_labels = []
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
"""
self.input = transfer_to_device(input['image'], self.device)
self.label = transfer_to_device(input['label'], self.device)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment