Commit 1573ec4a authored by Alessia Marcolini's avatar Alessia Marcolini
Browse files

Handle subdatasets with property decorators

parent 6f3c869c
...@@ -95,17 +95,17 @@ class NumpyCSVDataset(Dataset): ...@@ -95,17 +95,17 @@ class NumpyCSVDataset(Dataset):
# filter the clinical file in order to keep files that are really on disk # filter the clinical file in order to keep files that are really on disk
clinical = clinical.loc[clinical['filename'].isin(available_files)] clinical = clinical.loc[clinical['filename'].isin(available_files)]
self._filenames_full = clinical['filename'].values self._filenames = clinical['filename'].values
self._patients_full = clinical['patient'].values self._patients = clinical['patient'].values
self._labels_full = clinical[label_col].values self._labels = clinical[label_col].values
self.indices = np.arange(len(self._filenames_full)) self.indices = np.arange(len(self._filenames))
def __getitem__(self, idx): def __getitem__(self, idx):
label = self._labels_full[self.indices[idx]] label = self._labels[self.indices[idx]]
filename = self._filenames_full[self.indices[idx]] filename = self._filenames[self.indices[idx]]
patient = self._patients_full[self.indices[idx]] patient = self._patients[self.indices[idx]]
data_file = self.data_dir / self._filenames_full[self.indices[idx]] data_file = self.data_dir / self._filenames[self.indices[idx]]
data = np.load(data_file) data = np.load(data_file)
data = self.transforms(data, self.mode, self.size) data = self.transforms(data, self.mode, self.size)
...@@ -118,21 +118,24 @@ class NumpyCSVDataset(Dataset): ...@@ -118,21 +118,24 @@ class NumpyCSVDataset(Dataset):
} }
return output return output
def get_labels(self): @property
return self._labels_full[self.indices] def labels(self):
return self._labels[self.indices]
def get_files(self): @property
return self._filenames_full[self.indices] def filenames(self):
return self._filenames[self.indices]
def get_patients(self): @property
return self._patients_full[self.indices] def patients(self):
return self._patients[self.indices]
def __len__(self): def __len__(self):
return len(self.indices) return len(self.indices)
def __shuffle__(self): def __shuffle__(self):
idx_permut = np.random.permutation(self.__len__()) idx_permut = np.random.permutation(self.__len__())
self._filenames_full = self._filenames_full[idx_permut] self._filenames = self._filenames[idx_permut]
self._labels_full = self._labels_full[idx_permut] self._labels = self._labels[idx_permut]
self._patients_full = self._patients_full[idx_permut] self._patients = self._patients[idx_permut]
self.indices = self.indices[idx_permut] self.indices = self.indices[idx_permut]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment