Commit b5d5fd43 authored by Alessia Marcolini
Browse files

Handle subdatasets directly in NumpyCSVDataset

parent 39abf503
......@@ -69,20 +69,20 @@ class NumpyCSVDataset(Dataset):
unique_patients = list(clinical_file['Patient #'])
# example filename with augmentation: HN-HGJ-032_8.npy
patients = list(filter(lambda patient: patient in unique_patients, [f.split('.')[0].split('_')[0] for f in sorted(os.listdir(data_dir))]))
self.patients = patients
self._patients_full = np.array(patients)
self._files = np.array([f for f in sorted(os.listdir(data_dir)) if f.split('.')[0].split('_')[0] in patients]) # select patients in clinical file
self._files_full = np.array([f for f in sorted(os.listdir(data_dir)) if f.split('.')[0].split('_')[0] in patients]) # select patients in clinical file
unique_patients_labels = clinical_file[label_col].values
labels = np.array([clinical_file[clinical_file['Patient #']==patient][label_col].values[0] for patient in patients])
# labels = labels[self._samples]
self._labels = labels
self._indexes = np.arange(len(self._files))
self._labels_full = labels
self.indexes = np.arange(len(self._files_full))
def __getitem__(self, idx, no_data=False):
label = self._labels[self._indexes[idx]]
file = self._files[self._indexes[idx]]
data_file = f'{self.data_dir}/{self._files[self._indexes[idx]]}'
label = self._labels_full[self.indexes[idx]]
file = self._files_full[self.indexes[idx]]
data_file = f'{self.data_dir}/{self._files_full[self.indexes[idx]]}'
data = np.load(data_file)
data = self.transforms(data, self.mode, self.size)
......@@ -91,7 +91,21 @@ class NumpyCSVDataset(Dataset):
return output
def get_labels(self):
    """Return the labels array in the dataset's current index order.

    Uses ``self.indexes`` (a permutation over the full label array) so the
    result reflects any shuffling applied via ``__shuffle__``.
    """
    # NOTE: the pre-refactor attributes (_labels / _indexes) were renamed in
    # this commit to _labels_full / indexes; only the new names exist now.
    return self._labels_full[self.indexes]
def get_files(self):
    """Return the data file names, ordered by the current index permutation."""
    files, order = self._files_full, self.indexes
    return files[order]
def get_patients(self):
    """Return the patient identifiers, ordered by the current index permutation."""
    order = self.indexes
    return self._patients_full[order]
def __len__(self):
    """Return the number of samples currently exposed by the dataset.

    ``self.indexes`` is the active permutation over the full file list,
    so its length is the dataset length.
    """
    # NOTE: the pre-refactor attribute (_indexes) was renamed to indexes in
    # this commit; only the new name exists now.
    return len(self.indexes)
def __shuffle__(self):
    """Shuffle the dataset in place.

    A single random permutation is applied consistently to the file,
    label and patient arrays as well as to ``indexes``, so sample/label
    correspondence is preserved.
    """
    perm = np.random.permutation(self.__len__())
    for attr in ('_files_full', '_labels_full', '_patients_full', 'indexes'):
        setattr(self, attr, getattr(self, attr)[perm])
\ No newline at end of file
import numpy as np
from sklearn.model_selection import train_test_split
def train_test_indexes_patient_wise(dataset, test_size=0.2, seed=0, stratify=True):
def train_test_indexes_patient_wise(dataset, test_size=0.2, stratify=True):
files = dataset._files
patients = np.unique(dataset.patients)
files = dataset.get_files()
patients = dataset.get_patients()
unique_patients = np.unique(patients)
# print(len(files), len(patients))
patients_labels = []
for patient in patients:
idx = np.where(patient == np.array(dataset.patients))[0][0] # index of the first im belonging to the patient
for patient in unique_patients:
idx = np.where(patient == np.array(patients))[0][0] # index of the first im belonging to the patient
label = dataset.get_labels()[idx]
patients_labels.append(label)
......@@ -28,11 +29,11 @@ def train_test_indexes_patient_wise(dataset, test_size=0.2, seed=0, stratify=Tru
for train_patient in train_patients:
idxs = np.where(train_patient == np.array(dataset.patients))[0].tolist()
idxs = np.where(train_patient == np.array(patients))[0].tolist()
train_indexes.extend(idxs)
for test_patient in test_patients:
idxs = np.where(test_patient == np.array(dataset.patients))[0].tolist()
idxs = np.where(test_patient == np.array(patients))[0].tolist()
test_indexes.extend(idxs)
return train_indexes, test_indexes
return train_indexes, test_indexes
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment