Commit 43c5b9fc authored by Alessia Marcolini

Extract features deep

parent 69401bad
 import os
 from collections import OrderedDict
+import numpy as np
+import pandas as pd
 import torch
 import torch.nn as nn
 import torch.optim as optim
@@ -65,10 +67,10 @@ class ClassificationTileModel(BaseModel):
         self.val_labels = []
         # self.val_images = []

-    def forward(self):
+    def forward(self, features_extraction_only=False):
         """Run forward pass.
         """
-        self.output = self.resnet(self.input)
+        self.output = self.resnet(self.input, features_extraction_only)

         if self.resnet.training:
             # self.train_images.append(self.input)
@@ -179,3 +181,45 @@ class ClassificationTileModel(BaseModel):
         print(f'Tensorboard folder: {log_dir}')
         self.tensorboard_writer = SummaryWriter(log_dir)
+    def extract_features(self, dataset):
+        batch_size = dataset.dataloader.batch_size
+        # one row per batch, holding the flattened features of the whole batch
+        features = torch.zeros(
+            [len(dataset) // batch_size, self.resnet.adapter_out_features * batch_size],
+            dtype=torch.float64,
+        )
+        features = features.to(self.device)
+        filenames = []
+        patients = []
+        labels = []
+
+        with torch.no_grad():
+            for i, data in enumerate(dataset):
+                self.set_input(data)
+                self.forward(features_extraction_only=True)
+                features[i, :] = torch.flatten(self.output)
+
+                filenames.extend(data['filename'])
+                patients.extend(data['patient'])
+                labels.extend(data['label'])
+
+        # move to CPU / NumPy and reshape to one row per tile
+        features = features.cpu().numpy()
+        features_reshaped = features.reshape(
+            features.shape[0] * batch_size, features.shape[1] // batch_size
+        )
+
+        features_pd = pd.DataFrame(features_reshaped)
+        features_pd['filename'] = filenames
+        features_pd['patient'] = patients
+        features_pd['label'] = labels
+        features_pd['label'] = features_pd['label'].astype(np.uint8)
+
+        # reorder columns to have metadata first
+        columns = list(features_pd.columns)
+        columns = columns[-3:] + columns[:-3]
+        features_pd = features_pd[columns]
+
+        return features_pd
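
The reshape at the end of `extract_features` just undoes the per-batch flattening done inside the loop: each row of `features` holds one whole batch flattened, and the final array has one row per tile. A minimal, self-contained NumPy sketch of that bookkeeping (the sizes below are invented for illustration, not the ones used in the repository):

```python
import numpy as np

n_batches, batch_size, n_feat = 3, 4, 5  # illustrative sizes only

# per-tile features we want to end up with: one row per tile
per_tile = np.arange(n_batches * batch_size * n_feat, dtype=np.float64).reshape(
    n_batches * batch_size, n_feat
)

# what the loop stores: one row per batch, the whole batch flattened
per_batch = per_tile.reshape(n_batches, batch_size * n_feat)

# the reshape used in extract_features recovers one row per tile, in order
recovered = per_batch.reshape(
    per_batch.shape[0] * batch_size, per_batch.shape[1] // batch_size
)
assert np.array_equal(recovered, per_tile)
```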
@@ -17,6 +17,8 @@ class DAPPER_model(nn.Module):
     ):
         super(DAPPER_model, self).__init__()
+        self.adapter_out_features = adapter_out_features
+
         self.adapter = nn.Sequential(
             nn.Linear(n_features_out, adapter_out_features),
             nn.ReLU(True),
@@ -36,11 +38,13 @@ class DAPPER_model(nn.Module):
         self.fc_final = nn.Sequential(nn.BatchNorm1d(256), nn.Linear(256, n_classes))

-    def forward(self, x):
-        features = self.adapter(x)
-        features = self.fc_moba(features)
-        out = self.fc_final(features)
-        return out
+    def forward(self, x, features_extraction_only=False):
+        x = self.adapter(x)
+
+        if not features_extraction_only:
+            x = self.fc_moba(x)
+            x = self.fc_final(x)
+        return x
     def _get_conv_layers(self, version='', pretrained=False):
         raise NotImplementedError
@@ -84,11 +88,11 @@ class ResNet_model(DAPPER_model):
         return features_extractors, resnet_n_features_out

-    def forward(self, x):
+    def forward(self, x, features_extraction_only=False):
         features = self.features_extractors(x)
         features = features.view(features.size(0), -1)
-        out = super().forward(features)
+        out = super().forward(features, features_extraction_only)
         return out

     def test(self):
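
The `features_extraction_only` flag simply gates the classification head: when it is set, the adapter output is returned as the feature vector instead of the class scores. A small self-contained sketch of the same pattern (module and layer names here are illustrative, not the ones in this repository):

```python
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    """Toy module mirroring the gating pattern used in DAPPER_model.forward."""

    def __init__(self, n_in=8, n_features=4, n_classes=2):
        super().__init__()
        self.adapter = nn.Sequential(nn.Linear(n_in, n_features), nn.ReLU(True))
        self.head = nn.Linear(n_features, n_classes)

    def forward(self, x, features_extraction_only=False):
        x = self.adapter(x)
        if not features_extraction_only:
            x = self.head(x)  # classification scores
        return x

model = TinyModel().eval()
x = torch.randn(3, 8)
with torch.no_grad():
    print(model(x).shape)                                 # torch.Size([3, 2]) -> class scores
    print(model(x, features_extraction_only=True).shape)  # torch.Size([3, 4]) -> features
```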