Commit d87485ec authored by Alessia Marcolini

Add compute_metrics function

parent 80325d10
@@ -4,6 +4,8 @@ import torch
 import torch.nn as nn
 import torch.optim as optim
 from sklearn.metrics import accuracy_score
+from sklearn.metrics import matthews_corrcoef as mcor
+from sklearn.metrics import roc_auc_score
 from DP_classification.models.base_model import BaseModel
 from DP_classification.models.dapper_models import ResNet_model
@@ -78,23 +80,30 @@ class ClassificationTileModel(BaseModel):
         self.val_labels.append(self.label)

     def post_epoch_callback(self, epoch):
-        self.val_predictions = torch.cat(self.val_predictions, dim=0)
-        predictions = torch.argmax(self.val_predictions, dim=1)
-        predictions = torch.flatten(predictions).cpu()
-
-        self.val_labels = torch.cat(self.val_labels, dim=0)
-        labels = torch.flatten(self.val_labels).cpu()
-
-        if self.resnet.training:
-            self.val_images = torch.squeeze(torch.cat(self.val_images, dim=0)).cpu()
-
-        # Calculate and show accuracy
-        val_accuracy = accuracy_score(labels, predictions)
-
-        metrics = OrderedDict()
-        metrics['accuracy'] = val_accuracy
-
-        print('Validation accuracy: {0:.3f}'.format(val_accuracy))
+        self.train_predictions = torch.cat(self.train_predictions, dim=0)
+        train_predictions = torch.argmax(self.train_predictions, dim=1)
+        train_predictions = torch.flatten(train_predictions).cpu()
+
+        self.train_labels = torch.cat(self.train_labels, dim=0)
+        train_labels = torch.flatten(self.train_labels).cpu()
+        # self.train_images = torch.squeeze(torch.cat(self.train_images, dim=0)).cpu()
+
+        self.compute_metrics(train_labels, train_predictions, verbose=True)
+
+        self.val_predictions = torch.cat(self.val_predictions, dim=0)
+        val_predictions = torch.argmax(self.val_predictions, dim=1)
+        val_predictions = torch.flatten(val_predictions).cpu()
+
+        self.val_labels = torch.cat(self.val_labels, dim=0)
+        val_labels = torch.flatten(self.val_labels).cpu()
+        # self.val_images = torch.squeeze(torch.cat(self.val_images, dim=0)).cpu()
+
+        self.compute_metrics(val_labels, val_predictions, verbose=True)

         # Here you may do something else with the validation data such as
         # displaying the validation images or calculating the ROC curve
@@ -107,3 +116,18 @@ class ClassificationTileModel(BaseModel):
         """
         self.input = transfer_to_device(input['image'], self.device)
         self.label = transfer_to_device(input['label'], self.device)
+
+    def compute_metrics(self, labels, predictions, verbose=False):
+        # Calculate accuracy and Matthews correlation coefficient
+        accuracy = accuracy_score(labels, predictions)
+        MCC = mcor(labels, predictions)
+        # auc = roc_auc_score(labels, predictions, average="micro")
+
+        metrics = OrderedDict()
+        metrics['accuracy'] = accuracy
+        metrics['MCC'] = MCC
+        # metrics['auc'] = auc
+
+        if verbose:
+            for metric in metrics:
+                print(f'{metric}: {metrics[metric]:.3f}')
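The new compute_metrics reports accuracy and Matthews correlation for both the training and validation splits; the AUC line stays commented out, most likely because roc_auc_score expects continuous scores rather than the argmax'd class labels passed here. The following is a minimal standalone sketch, not part of the commit, that exercises the same flow on synthetic data (assuming a binary tile classifier) and shows one way AUC could be computed from softmax probabilities:

# Minimal sketch (not from the commit): same metric computation on
# synthetic stand-ins for the model's accumulated epoch outputs.
import torch
from collections import OrderedDict
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.metrics import matthews_corrcoef as mcor

# Hypothetical data: logits of shape (num_tiles, num_classes) and
# integer ground-truth labels, as accumulated over one epoch.
logits = torch.randn(32, 2)
labels = torch.randint(0, 2, (32,))

# Same post-processing as post_epoch_callback: argmax over classes,
# flatten, and move to CPU before handing tensors to sklearn.
predictions = torch.flatten(torch.argmax(logits, dim=1)).cpu()

metrics = OrderedDict()
metrics['accuracy'] = accuracy_score(labels, predictions)
metrics['MCC'] = mcor(labels, predictions)

# roc_auc_score needs continuous scores, not hard class predictions.
# For the binary case, the positive-class softmax probability works:
scores = torch.softmax(logits, dim=1)[:, 1]
metrics['auc'] = roc_auc_score(labels, scores)

for metric, value in metrics.items():
    print(f'{metric}: {value:.3f}')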