Commit 6f3c869c authored by Alessia Marcolini's avatar Alessia Marcolini
Browse files

Unified dataset according to new data structure

parent c88bb9d0
import os
from pathlib import Path
import numpy as np
import pandas as pd
......@@ -6,6 +7,7 @@ import SimpleITK as sitk
import torch
from torch.utils.data import Dataset
from dicom_utils.dicom_utils import augmentation as aug
from dicom_utils.dicom_utils import processing as dup
......@@ -74,6 +76,10 @@ class NumpyCSVDataset(Dataset):
seed=1234,
):
super(NumpyCSVDataset, self).__init__()
if not isinstance(data_dir, Path):
data_dir = Path(data_dir)
self.data_dir = data_dir
self.size = size
self.transforms = transforms
......@@ -82,151 +88,51 @@ class NumpyCSVDataset(Dataset):
np.random.seed(self.seed)
clinical_file = pd.read_csv(label_file, sep=',', dtype=str).sort_values(
by=['patient']
)
unique_patients = list(clinical_file['patient'])
# example filename with augmentation: HN-HGJ-032_8.npy
patients = list(
filter(
lambda patient: patient in unique_patients,
[f.split('.')[0].split('_')[0] for f in sorted(os.listdir(data_dir))],
)
)
self._patients_full = np.array(patients)
self._files_full = np.array(
[
f
for f in sorted(os.listdir(data_dir))
if f.split('.')[0].split('_')[0] in patients
]
) # select patients in clinical file
unique_patients_labels = clinical_file[label_col].values
labels = np.array(
[
clinical_file[clinical_file['patient'] == patient][label_col].values[0]
for patient in patients
]
)
# labels = labels[self._samples]
self._labels_full = labels
self.indexes = np.arange(len(self._files_full))
def __getitem__(self, idx, no_data=False):
label = self._labels_full[self.indexes[idx]]
file = self._files_full[self.indexes[idx]]
patient = self._patients_full[self.indexes[idx]]
data_file = f'{self.data_dir}/{self._files_full[self.indexes[idx]]}'
data = np.load(data_file)
data = self.transforms(data, self.mode, self.size)
data = torch.Tensor(data)
output = {'data': data, 'target': label, 'filename': file, 'patient': patient}
return output
def get_labels(self):
    """Labels in the dataset's current (possibly shuffled) order."""
    ordering = self.indexes
    return self._labels_full[ordering]
def get_files(self):
    """Filenames in the dataset's current (possibly shuffled) order."""
    ordering = self.indexes
    return self._files_full[ordering]
def get_patients(self):
    """Patient IDs in the dataset's current (possibly shuffled) order."""
    ordering = self.indexes
    return self._patients_full[ordering]
def __len__(self):
return len(self.indexes)
def __shuffle__(self):
idx_permut = np.random.permutation(self.__len__())
self._files_full = self._files_full[idx_permut]
self._labels_full = self._labels_full[idx_permut]
self._patients_full = self._patients_full[idx_permut]
self.indexes = self.indexes[idx_permut]
# NOTE(review): this region was a merge/diff mashup — the old file-listing
# constructor and the new CSV-driven one were interleaved (duplicate method
# definitions, unreachable second `return` statements, a local `clinical`
# referenced before assignment, and stray statements sitting at module level
# where `label_file`/`data_dir` do not exist). Reconstructed below as the
# coherent CSV-driven version.
class NumpyCSVDataset_BZ(Dataset):
    """Dataset of .npy volumes whose metadata lives in a clinical CSV.

    The CSV at `label_file` must provide at least the columns 'patient',
    'filename' and `label_col`. Only rows whose `filename` actually exists
    in `data_dir` are kept.
    """

    def __init__(
        self,
        data_dir,
        label_file,
        label_col,
        size,
        transforms=augment_3D,
        mode='train',
        seed=1234,
    ):
        super(NumpyCSVDataset_BZ, self).__init__()
        # Coerce to Path so __getitem__ can join with '/', mirroring NumpyCSVDataset.
        if not isinstance(data_dir, Path):
            data_dir = Path(data_dir)
        self.data_dir = data_dir
        self.size = size
        self.transforms = transforms
        self.mode = mode
        self.seed = seed
        np.random.seed(self.seed)

        clinical = pd.read_csv(label_file, dtype=str).sort_values(by=['patient'])
        available_files = [f for f in os.listdir(data_dir) if f.endswith('.npy')]
        # Filter the clinical file to keep only rows whose file is really on disk.
        clinical = clinical.loc[clinical['filename'].isin(available_files)]

        self._filenames_full = clinical['filename'].values
        self._patients_full = clinical['patient'].values
        self._labels_full = clinical[label_col].values
        self.indices = np.arange(len(self._filenames_full))

    def __getitem__(self, idx):
        """Return one sample as {'data', 'target', 'filename', 'patient'}."""
        # Resolve the (possibly shuffled) position once.
        pos = self.indices[idx]
        label = self._labels_full[pos]
        filename = self._filenames_full[pos]
        patient = self._patients_full[pos]
        data = np.load(self.data_dir / filename)
        data = self.transforms(data, self.mode, self.size)
        data = torch.Tensor(data)
        return {
            'data': data,
            'target': label,
            'filename': filename,
            'patient': patient,
        }

    def get_labels(self):
        """Labels in the dataset's current (possibly shuffled) order."""
        return self._labels_full[self.indices]

    def get_files(self):
        """Filenames in the dataset's current (possibly shuffled) order."""
        return self._filenames_full[self.indices]

    def get_patients(self):
        """Patient IDs in the dataset's current (possibly shuffled) order."""
        return self._patients_full[self.indices]

    def __len__(self):
        return len(self.indices)

    def __shuffle__(self):
        """Apply one shared random permutation to all parallel arrays."""
        idx_permut = np.random.permutation(self.__len__())
        self._filenames_full = self._filenames_full[idx_permut]
        self._labels_full = self._labels_full[idx_permut]
        self._patients_full = self._patients_full[idx_permut]
        self.indices = self.indices[idx_permut]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment