Commit 1dd689cc authored by Alessia Marcolini's avatar Alessia Marcolini
Browse files

Fix missing imports, remove visualizer

parent 4ec53493
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score
from DP_classification.models.base_model import BaseModel
from DP_classification.models.dapper_models import ResNet_model
......@@ -74,7 +77,7 @@ class ClassificationTileModel(BaseModel):
def post_epoch_callback(self, epoch, visualizer):
def post_epoch_callback(self, epoch):
self.val_predictions =, dim=0)
predictions = torch.argmax(self.val_predictions, dim=1)
predictions = torch.flatten(predictions).cpu()
......@@ -90,7 +93,6 @@ class ClassificationTileModel(BaseModel):
metrics = OrderedDict()
metrics['accuracy'] = val_accuracy
visualizer.plot_current_validation_metrics(epoch, metrics)
print('Validation accuracy: {0:.3f}'.format(val_accuracy))
# Here you may do something else with the validation data such as
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment