Commit 83317576 authored by Alessia Marcolini's avatar Alessia Marcolini
Browse files

Now use tensorboard

parent 86090f01
......@@ -17,6 +17,7 @@ data
# Results folders
results*
tb-runs
# Archived scripts
archive
......
from pathlib import Path
DATA_PATH = Path('data')
TENSORBOARD_LOG_DIR = Path('tb-runs')
DATASET_SOURCE = 'tcga_breast'
......
import os
from collections import OrderedDict
import torch
......@@ -6,6 +7,7 @@ import torch.optim as optim
from sklearn.metrics import accuracy_score
from sklearn.metrics import matthews_corrcoef as mcor
from sklearn.metrics import roc_auc_score
from torch.utils.tensorboard import SummaryWriter
from DP_classification.models.base_model import BaseModel
from DP_classification.models.dapper_models import ResNet_model
......@@ -16,6 +18,8 @@ class ClassificationTileModel(BaseModel):
def __init__(self, configuration):
super().__init__(configuration)
self.global_iteration = -1
self.loss_names = ['train_classification']
self.val_losses = []
self.n_batches_val = 0
......@@ -76,6 +80,11 @@ class ClassificationTileModel(BaseModel):
"""
self.loss_train_classification = self.criterion_loss(self.output, self.label)
if self.tensorboard_writer is not None:
self.tensorboard_writer.add_scalar(
f'train_loss', self.loss_train_classification, self.global_iteration
)
def optimize_parameters(self):
"""Calculate gradients and update network weights.
"""
......@@ -88,7 +97,7 @@ class ClassificationTileModel(BaseModel):
super().test() # run the forward pass
# save predictions and labels as flat tensors
self.val_images.append(self.input)
# self.val_images.append(self.input)
self.val_predictions.append(self.output)
self.val_labels.append(self.label)
......@@ -126,6 +135,10 @@ class ClassificationTileModel(BaseModel):
val_losses_sum = sum(self.val_losses).item()
loss_val_avg = val_losses_sum / self.n_batches_val
if self.tensorboard_writer is not None:
self.tensorboard_writer.add_scalar(
f'val_loss', loss_val_avg, self.global_iteration
)
print(val_losses_sum)
self.val_losses = []
self.n_batches_val = 0
......@@ -158,3 +171,11 @@ class ClassificationTileModel(BaseModel):
if verbose:
for metric in metrics:
print(f'Validation {metric}: {metrics[metric]}')
def configure_tensorboard(self, log_dir):
if not os.path.exists(log_dir):
os.makedirs(log_dir)
print(f'Tensorboard folder: {log_dir}')
self.tensorboard_writer = SummaryWriter(log_dir)
......@@ -26,5 +26,6 @@ dependencies:
- pytorch
- torchvision
- black
- tensorboard
prefix: //anaconda/envs/inf
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment