Commit 6f3c869c authored by Alessia Marcolini's avatar Alessia Marcolini
Browse files

Unified dataset according to new data structure

parent c88bb9d0
import os import os
from pathlib import Path
import numpy as np import numpy as np
import pandas as pd import pandas as pd
...@@ -6,6 +7,7 @@ import SimpleITK as sitk ...@@ -6,6 +7,7 @@ import SimpleITK as sitk
import torch import torch
from torch.utils.data import Dataset from torch.utils.data import Dataset
from dicom_utils.dicom_utils import augmentation as aug from dicom_utils.dicom_utils import augmentation as aug
from dicom_utils.dicom_utils import processing as dup from dicom_utils.dicom_utils import processing as dup
...@@ -74,6 +76,10 @@ class NumpyCSVDataset(Dataset): ...@@ -74,6 +76,10 @@ class NumpyCSVDataset(Dataset):
seed=1234, seed=1234,
): ):
super(NumpyCSVDataset, self).__init__() super(NumpyCSVDataset, self).__init__()
if not isinstance(data_dir, Path):
data_dir = Path(data_dir)
self.data_dir = data_dir self.data_dir = data_dir
self.size = size self.size = size
self.transforms = transforms self.transforms = transforms
...@@ -82,151 +88,51 @@ class NumpyCSVDataset(Dataset): ...@@ -82,151 +88,51 @@ class NumpyCSVDataset(Dataset):
np.random.seed(self.seed) np.random.seed(self.seed)
clinical_file = pd.read_csv(label_file, sep=',', dtype=str).sort_values( clinical = pd.read_csv(label_file, dtype=str).sort_values(by=['patient'])
by=['patient']
)
unique_patients = list(clinical_file['patient'])
# example filename with augmentation: HN-HGJ-032_8.npy
patients = list(
filter(
lambda patient: patient in unique_patients,
[f.split('.')[0].split('_')[0] for f in sorted(os.listdir(data_dir))],
)
)
self._patients_full = np.array(patients)
self._files_full = np.array(
[
f
for f in sorted(os.listdir(data_dir))
if f.split('.')[0].split('_')[0] in patients
]
) # select patients in clinical file
unique_patients_labels = clinical_file[label_col].values
labels = np.array(
[
clinical_file[clinical_file['patient'] == patient][label_col].values[0]
for patient in patients
]
)
# labels = labels[self._samples]
self._labels_full = labels
self.indexes = np.arange(len(self._files_full))
def __getitem__(self, idx, no_data=False):
label = self._labels_full[self.indexes[idx]]
file = self._files_full[self.indexes[idx]]
patient = self._patients_full[self.indexes[idx]]
data_file = f'{self.data_dir}/{self._files_full[self.indexes[idx]]}'
data = np.load(data_file)
data = self.transforms(data, self.mode, self.size)
data = torch.Tensor(data)
output = {'data': data, 'target': label, 'filename': file, 'patient': patient}
return output
def get_labels(self):
return self._labels_full[self.indexes]
def get_files(self):
return self._files_full[self.indexes]
def get_patients(self):
return self._patients_full[self.indexes]
def __len__(self):
return len(self.indexes)
def __shuffle__(self):
idx_permut = np.random.permutation(self.__len__())
self._files_full = self._files_full[idx_permut]
self._labels_full = self._labels_full[idx_permut]
self._patients_full = self._patients_full[idx_permut]
self.indexes = self.indexes[idx_permut]
available_files = [f for f in os.listdir(data_dir) if f.endswith('.npy')]
class NumpyCSVDataset_BZ(Dataset): # filter the clinical file in order to keep files that are really on disk
def __init__( clinical = clinical.loc[clinical['filename'].isin(available_files)]
self,
data_dir,
label_file,
label_col,
size,
transforms=augment_3D,
mode='train',
seed=1234,
):
super(NumpyCSVDataset_BZ, self).__init__()
self.data_dir = data_dir
self.size = size
self.transforms = transforms
self.mode = mode
self.seed = seed
np.random.seed(self.seed) self._filenames_full = clinical['filename'].values
self._patients_full = clinical['patient'].values
self._labels_full = clinical[label_col].values
clinical_file = pd.read_csv(label_file, sep=';', dtype=str).sort_values( self.indices = np.arange(len(self._filenames_full))
by=['patient']
)
unique_patients = list(clinical_file['patient']) def __getitem__(self, idx):
# example filename with augmentation: HN-HGJ-032_8.npy label = self._labels_full[self.indices[idx]]
patients = list( filename = self._filenames_full[self.indices[idx]]
filter( patient = self._patients_full[self.indices[idx]]
lambda patient: patient in unique_patients, data_file = self.data_dir / self._filenames_full[self.indices[idx]]
[f.split('.')[0].split('_')[0] for f in sorted(os.listdir(data_dir))],
)
)
self._patients_full = np.array(patients)
self._files_full = np.array(
[
f
for f in sorted(os.listdir(data_dir))
if f.split('.')[0].split('_')[0] in patients
]
) # select patients in clinical file
unique_patients_labels = clinical_file[label_col].values
labels = np.array(
[
clinical_file[clinical_file['patient'] == patient][label_col].values[0]
for patient in patients
]
)
# labels = labels[self._samples]
self._labels_full = labels
self.indexes = np.arange(len(self._files_full))
def __getitem__(self, idx, no_data=False):
label = self._labels_full[self.indexes[idx]]
file = self._files_full[self.indexes[idx]]
patient = self._patients_full[self.indexes[idx]]
data_file = f'{self.data_dir}/{self._files_full[self.indexes[idx]]}'
data = np.load(data_file) data = np.load(data_file)
data = self.transforms(data, self.mode, self.size) data = self.transforms(data, self.mode, self.size)
data = torch.Tensor(data) data = torch.Tensor(data)
output = {'data': data, 'target': label, 'filename': file, 'patient': patient} output = {
'data': data,
'target': label,
'filename': filename,
'patient': patient,
}
return output return output
def get_labels(self): def get_labels(self):
return self._labels_full[self.indexes] return self._labels_full[self.indices]
def get_files(self): def get_files(self):
return self._files_full[self.indexes] return self._filenames_full[self.indices]
def get_patients(self): def get_patients(self):
return self._patients_full[self.indexes] return self._patients_full[self.indices]
def __len__(self): def __len__(self):
return len(self.indexes) return len(self.indices)
def __shuffle__(self): def __shuffle__(self):
idx_permut = np.random.permutation(self.__len__()) idx_permut = np.random.permutation(self.__len__())
self._files_full = self._files_full[idx_permut] self._filenames_full = self._filenames_full[idx_permut]
self._labels_full = self._labels_full[idx_permut] self._labels_full = self._labels_full[idx_permut]
self._patients_full = self._patients_full[idx_permut] self._patients_full = self._patients_full[idx_permut]
self.indexes = self.indexes[idx_permut] self.indices = self.indices[idx_permut]
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment