Commit 86090f01 authored by Alessia Marcolini's avatar Alessia Marcolini
Browse files

Store training and validation loss

parent d87485ec
......@@ -16,7 +16,10 @@ class ClassificationTileModel(BaseModel):
def __init__(self, configuration):
super().__init__(configuration)
self.loss_names = ['classification']
self.loss_names = ['train_classification']
self.val_losses = []
self.n_batches_val = 0
self.network_names = ['ResNet']
self.resnet = ResNet_model(
......@@ -48,25 +51,35 @@ class ClassificationTileModel(BaseModel):
self.optimizer = opt(self.resnet.parameters(), lr=configuration['lr'])
self.optimizers = [self.optimizer]
# storing predictions and labels during training
self.train_predictions = []
self.train_labels = []
# self.train_images = []
# storing predictions and labels for validation
self.val_predictions = []
self.val_labels = []
self.val_images = []
# self.val_images = []
def forward(self):
"""Run forward pass.
"""
self.output = self.resnet(self.input)
if self.resnet.training:
# self.train_images.append(self.input)
self.train_predictions.append(self.output)
self.train_labels.append(self.label)
def backward(self):
"""Calculate losses; called in every training iteration.
"""
self.loss_classification = self.criterion_loss(self.output, self.label)
self.loss_train_classification = self.criterion_loss(self.output, self.label)
def optimize_parameters(self):
"""Calculate gradients and update network weights.
"""
self.loss_classification.backward() # calculate gradients
self.loss_train_classification.backward() # calculate gradients
self.optimizer.step()
self.optimizer.zero_grad()
torch.cuda.empty_cache()
......@@ -79,6 +92,9 @@ class ClassificationTileModel(BaseModel):
self.val_predictions.append(self.output)
self.val_labels.append(self.label)
self.val_losses.append(self.criterion_loss(self.output, self.label))
self.n_batches_val += 1
def post_epoch_callback(self, epoch):
if self.resnet.training:
......@@ -107,7 +123,18 @@ class ClassificationTileModel(BaseModel):
# Here you may do something else with the validation data such as
# displaying the validation images or calculating the ROC curve
self.val_images = []
val_losses_sum = sum(self.val_losses).item()
loss_val_avg = val_losses_sum / self.n_batches_val
print(val_losses_sum)
self.val_losses = []
self.n_batches_val = 0
if self.resnet.training:
# self.train_images = []
self.train_predictions = []
self.train_labels = []
# self.val_images = []
self.val_predictions = []
self.val_labels = []
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment