classificationTile_model.py 5.51 KB
Newer Older
1
2
from collections import OrderedDict

Alessia Marcolini's avatar
Alessia Marcolini committed
3
4
5
import torch
import torch.nn as nn
import torch.optim as optim
6
from sklearn.metrics import accuracy_score
7
8
from sklearn.metrics import matthews_corrcoef as mcor
from sklearn.metrics import roc_auc_score
Alessia Marcolini's avatar
Alessia Marcolini committed
9
10
11

from DP_classification.models.base_model import BaseModel
from DP_classification.models.dapper_models import ResNet_model
12
from DP_classification.utils import transfer_to_device
Alessia Marcolini's avatar
Alessia Marcolini committed
13
14
15
16
17
18


class ClassificationTileModel(BaseModel):
    """Tile classifier built on ``ResNet_model`` with per-epoch metric reporting.

    Network outputs and labels are buffered batch by batch (training batches
    in ``forward``, validation batches in ``test``) and turned into accuracy
    and Matthews correlation coefficient in ``post_epoch_callback``.

    NOTE(review): ``use_cuda``, ``device`` and ``is_train`` are read but never
    assigned in this class; presumably ``BaseModel.__init__`` sets them --
    confirm against the base class.
    """

    def __init__(self, configuration):
        """Build the network, the loss and (when training) the optimizer.

        Args:
            configuration: mapping read for the keys ``'version'``,
                ``'pretrained_imagenet'``, ``'n_classes'`` and ``'dropout'``
                (forwarded positionally to ``ResNet_model``) and, when
                ``self.is_train`` is set, ``'optimizer'`` (attribute name
                looked up on ``torch.optim``) and ``'lr'``.
        """
        super().__init__(configuration)

        self.loss_names = ['train_classification']
        self.val_losses = []
        self.n_batches_val = 0
        # Mean validation loss of the last epoch processed by
        # post_epoch_callback; None until one has been computed.
        self.loss_val_avg = None

        self.network_names = ['ResNet']

        self.resnet = ResNet_model(
            configuration['version'],
            configuration['pretrained_imagenet'],
            configuration['n_classes'],
            configuration['dropout'],
        )

        if self.use_cuda:
            self.multigpu = True
        elif not hasattr(self, 'multigpu'):
            # Nothing in this class assigns the flag on the CPU path; default
            # it so the check below cannot raise AttributeError.
            self.multigpu = False

        if self.multigpu and torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            self.resnet = nn.DataParallel(self.resnet)

        self.resnet = self.resnet.to(self.device)

        # test() evaluates the loss as well, so the criterion has to exist
        # even when the model is not built for training.
        self.criterion_loss = torch.nn.CrossEntropyLoss()

        if self.is_train:  # the optimizer is only needed at training time
            try:
                opt = getattr(optim, configuration['optimizer'])
            except AttributeError as e:
                print(e)
                print('Using Adam optimizer instead.')
                opt = optim.Adam

            self.optimizer = opt(self.resnet.parameters(), lr=configuration['lr'])
            self.optimizers = [self.optimizer]

        # storing predictions and labels during training
        self.train_predictions = []
        self.train_labels = []

        # storing predictions and labels for validation
        self.val_predictions = []
        self.val_labels = []

    def forward(self):
        """Run the forward pass on the batch stored by ``set_input``.

        The result is kept in ``self.output``. While the network is in
        training mode the output and label are also buffered for the
        end-of-epoch metrics.
        """
        self.output = self.resnet(self.input)

        if self.resnet.training:
            # Buffer a detached tensor: only the values are needed for the
            # metrics, and a non-detached output would keep a reference to
            # its autograd graph for the whole epoch.
            self.train_predictions.append(self.output.detach())
            self.train_labels.append(self.label)

    def backward(self):
        """Calculate losses; called in every training iteration.

        Stores the cross-entropy between ``self.output`` and ``self.label``
        in ``self.loss_train_classification``; gradients are computed later
        in ``optimize_parameters``.
        """
        self.loss_train_classification = self.criterion_loss(self.output, self.label)

    def optimize_parameters(self):
        """Calculate gradients and update network weights."""
        self.loss_train_classification.backward()  # calculate gradients
        self.optimizer.step()
        # Clear the gradients so they do not accumulate into the next step.
        self.optimizer.zero_grad()
        torch.cuda.empty_cache()

    def test(self):
        """Evaluate the current batch and buffer its output, label and loss."""
        super().test()  # run the forward pass

        # Input images are deliberately not buffered: no ``val_images`` list
        # is created or reset anywhere in this class, so appending to it
        # would fail (or grow without bound if a base class provided one).
        output = self.output.detach()
        self.val_predictions.append(output)
        self.val_labels.append(self.label)

        self.val_losses.append(self.criterion_loss(output, self.label))
        self.n_batches_val += 1

    @staticmethod
    def _gather_epoch(labels, predictions):
        """Concatenate per-batch buffers into flat CPU tensors.

        Returns:
            ``(labels, predicted)`` where ``predicted`` is the argmax over
            dimension 1 of the concatenated network outputs.
        """
        predicted = torch.argmax(torch.cat(predictions, dim=0), dim=1)
        flat_labels = torch.flatten(torch.cat(labels, dim=0)).cpu()
        return flat_labels, torch.flatten(predicted).cpu()

    def post_epoch_callback(self, epoch):
        """Report metrics for the buffered batches, then clear every buffer.

        Training metrics are reported only while the network is in training
        mode, validation metrics whenever validation batches were buffered.
        Empty buffers are skipped instead of raising.

        Args:
            epoch: index of the epoch that just finished (currently unused).
        """
        # NOTE(review): if the caller switches the network to eval mode
        # before invoking this callback, training metrics are never printed;
        # confirm against the training loop.
        if self.resnet.training and self.train_predictions:
            train_labels, train_predictions = self._gather_epoch(
                self.train_labels, self.train_predictions
            )
            self.compute_metrics(
                train_labels, train_predictions, verbose=True, split='Training'
            )

        if self.val_predictions:
            val_labels, val_predictions = self._gather_epoch(
                self.val_labels, self.val_predictions
            )
            self.compute_metrics(val_labels, val_predictions, verbose=True)

        # Here you may do something else with the validation data such as
        # displaying the validation images or calculating the ROC curve

        if self.val_losses:
            val_losses_sum = sum(self.val_losses).item()
            # One loss is appended per validation batch in test().
            self.loss_val_avg = val_losses_sum / len(self.val_losses)
            print(val_losses_sum)

        # Always start the next epoch from empty buffers; clearing the
        # training buffers only in training mode would let them grow across
        # epochs whenever this callback runs in eval mode.
        self.val_losses = []
        self.n_batches_val = 0
        self.train_predictions = []
        self.train_labels = []
        self.val_predictions = []
        self.val_labels = []

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Args:
            input: mapping with the keys ``'image'`` and ``'label'``; both
                values are moved to ``self.device`` via ``transfer_to_device``.
        """
        self.input = transfer_to_device(input['image'], self.device)
        self.label = transfer_to_device(input['label'], self.device)

    def compute_metrics(self, labels, predictions, verbose=False, split='Validation'):
        """Compute accuracy and Matthews correlation coefficient.

        Args:
            labels: ground-truth class indices.
            predictions: predicted class indices, aligned with ``labels``.
            verbose: when true, print one line per metric.
            split: label prefixed to each printed line; the default keeps the
                historical ``Validation ...`` output.

        Returns:
            ``OrderedDict`` with the keys ``'accuracy'`` and ``'MCC'``.
        """
        metrics = OrderedDict()
        metrics['accuracy'] = accuracy_score(labels, predictions)
        metrics['MCC'] = mcor(labels, predictions)

        if verbose:
            for metric, value in metrics.items():
                print(f'{split} {metric}: {value}')

        return metrics