Commit 39abf503 authored by Alessia Marcolini's avatar Alessia Marcolini
Browse files

Specify seed in dataset

parent 326ca2d5
......@@ -54,12 +54,15 @@ def augment_3D_HN(image, mode, size):
return image
class NumpyCSVDataset(Dataset):
def __init__(self, data_dir, label_file, label_col, size, augmentation_function=augment_3D, mode='train'):
def __init__(self, data_dir, label_file, label_col, size, transforms=augment_3D, mode='train', seed=1234):
super(NumpyCSVDataset, self).__init__()
self.data_dir = data_dir
self.size = size
self.augmentation = augmentation_function
self.transforms = transforms
self.mode = mode
self.seed = seed
np.random.seed(self.seed)
clinical_file = pd.read_csv(label_file, sep=',', dtype=str).sort_values(by=['Patient #'])
......@@ -82,7 +85,7 @@ class NumpyCSVDataset(Dataset):
data_file = f'{self.data_dir}/{self._files[self._indexes[idx]]}'
data = np.load(data_file)
data = self.augmentation(data, self.mode, self.size)
data = self.transforms(data, self.mode, self.size)
data = torch.Tensor(data)
output = {'data': data, 'target': label, 'filename': file}
return output
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment