Commit 4b221185 authored by Alessia Marcolini's avatar Alessia Marcolini
Browse files

Add BZ dataset

parent 0b3fd442
import os

import numpy as np
import pandas as pd
import SimpleITK as sitk
import torch
from torch.utils.data import Dataset

from dicom_utils.dicom_utils import augmentation as aug
from dicom_utils.dicom_utils import processing as dup
def augment_3D(image, mode, size):
    """Augment and resample a channel-first 3D volume.

    Parameters
    ----------
    image : np.ndarray
        4D array of shape (channels, D, H, W); each channel is converted to a
        SimpleITK image and processed independently.
    mode : str
        When ``'train'``, morphological augmentation and Gaussian noise are
        applied; any other value resamples only.
    size : int
        Edge length of the cubic output volume.

    Returns
    -------
    np.ndarray
        Array of shape (channels, size, size, size).
    """
    n_channels = image.shape[0]
    image_seq = [sitk.GetImageFromArray(image[i, :, :, :]) for i in range(n_channels)]
    if mode == 'train':
        # morphological augmentation
        image_seq = aug.augment_morph(image_seq)
        # add gaussian noise
        fg = aug.get_gauss_noise()
        image_seq = [fg.Execute(img) for img in image_seq]
    # resample every channel to the requested cubic size
    image_seq = [dup.resample(img, size=(size, size, size)) for img in image_seq]
    image_seq = [sitk.GetArrayFromImage(img) for img in image_seq]
    # re-stack channels into a single channel-first array
    image = np.stack(image_seq, axis=0)
    return image
def augment_3D_HN(image, mode, size):
    """Augment a 2-channel head-and-neck CT/PET volume.

    Channel 0 is treated as CT and normalized over the [-1000, 3000] intensity
    window; channel 1 is treated as PET and normalized over [0, 50]. Both end
    up scaled to [0, 1] before augmentation/resampling.

    Parameters
    ----------
    image : np.ndarray
        4D array of shape (2, D, H, W): (CT, PET).
    mode : str
        When ``'train'``, morphological augmentation is applied
        (Gaussian noise is deliberately disabled for this dataset).
    size : int
        Edge length of the cubic output volume.

    Returns
    -------
    np.ndarray
        Array of shape (2, size, size, size).
    """

    def normalize_range(image, range_pixel):
        # Clip intensities to [range_pixel[0], range_pixel[1]] (the +/-5000
        # bounds are just wide sentinels) and rescale linearly to [0, 1].
        image = sitk.Threshold(
            image, lower=-5000, upper=range_pixel[1], outsideValue=range_pixel[1]
        )
        image = sitk.Threshold(
            image, lower=range_pixel[0], upper=5000, outsideValue=range_pixel[0]
        )
        image = (image - range_pixel[0]) / (range_pixel[1] - range_pixel[0])
        return image

    image_CT = sitk.GetImageFromArray(image[0, :, :, :])
    image_PT = sitk.GetImageFromArray(image[1, :, :, :])

    # normalize each modality over its own clinical intensity window
    image_CT = normalize_range(image_CT, [-1000, 3000])
    image_PT = normalize_range(image_PT, [0, 50])

    image_seq = [image_CT, image_PT]
    if mode == 'train':
        # morphological augmentation
        image_seq = aug.augment_morph(image_seq)
        # gaussian noise intentionally disabled for HN data
        # fg = aug.get_gauss_noise()
        # image_seq = [fg.Execute(image) for image in image_seq]
    image_seq = [dup.resample(img, size=(size, size, size)) for img in image_seq]
    image_seq = [sitk.GetArrayFromImage(img) for img in image_seq]
    image = np.stack(image_seq, axis=0)
    return image
class NumpyCSVDataset(Dataset):
    """Dataset of 3D ``.npy`` volumes labeled from a comma-separated CSV.

    Filenames in ``data_dir`` are expected to look like
    ``<patient>_<augmentation>.npy`` (e.g. ``HN-HGJ-032_8.npy``); the patient
    id (text before the first ``_``) is matched against the ``patient`` column
    of ``label_file``. One sample is produced per matching file, so augmented
    copies of the same patient each become their own sample.
    """

    def __init__(
        self,
        data_dir,
        label_file,
        label_col,
        size,
        transforms=augment_3D,
        mode='train',
        seed=1234,
    ):
        """
        Parameters
        ----------
        data_dir : str
            Directory containing the ``.npy`` volumes.
        label_file : str
            Path to the comma-separated clinical CSV (must have a
            ``patient`` column).
        label_col : str
            Name of the CSV column used as target label.
        size : int
            Cubic output size passed to ``transforms``.
        transforms : callable
            ``transforms(data, mode, size)`` applied to each loaded volume.
        mode : str
            Forwarded to ``transforms`` ('train' enables augmentation).
        seed : int
            Seed for numpy's global RNG (affects ``__shuffle__``).
        """
        super(NumpyCSVDataset, self).__init__()
        self.data_dir = data_dir
        self.size = size
        self.transforms = transforms
        self.mode = mode
        self.seed = seed
        np.random.seed(self.seed)

        clinical_file = pd.read_csv(label_file, sep=',', dtype=str).sort_values(
            by=['patient']
        )
        unique_patients = list(clinical_file['patient'])

        # example filename with augmentation: HN-HGJ-032_8.npy
        # one entry per file on disk whose patient id is in the clinical file
        patients = list(
            filter(
                lambda patient: patient in unique_patients,
                [f.split('.')[0].split('_')[0] for f in sorted(os.listdir(data_dir))],
            )
        )
        self._patients_full = np.array(patients)
        # select the files whose patient id appears in the clinical file
        self._files_full = np.array(
            [
                f
                for f in sorted(os.listdir(data_dir))
                if f.split('.')[0].split('_')[0] in patients
            ]
        )
        # one label per file, looked up by patient id
        self._labels_full = np.array(
            [
                clinical_file[clinical_file['patient'] == patient][label_col].values[0]
                for patient in patients
            ]
        )
        self.indexes = np.arange(len(self._files_full))

    def __getitem__(self, idx, no_data=False):
        """Load, transform and return one sample as a dict with keys
        ``data`` (torch.Tensor), ``target``, ``filename`` and ``patient``."""
        label = self._labels_full[self.indexes[idx]]
        file = self._files_full[self.indexes[idx]]
        patient = self._patients_full[self.indexes[idx]]

        data_file = f'{self.data_dir}/{file}'
        data = np.load(data_file)
        data = self.transforms(data, self.mode, self.size)
        data = torch.Tensor(data)

        return {'data': data, 'target': label, 'filename': file, 'patient': patient}

    def get_labels(self):
        # labels in the current (possibly shuffled) order
        return self._labels_full[self.indexes]

    def get_files(self):
        # filenames in the current (possibly shuffled) order
        return self._files_full[self.indexes]

    def get_patients(self):
        # patient ids in the current (possibly shuffled) order
        return self._patients_full[self.indexes]

    def __len__(self):
        return len(self.indexes)

    def __shuffle__(self):
        # apply one random permutation consistently to all parallel arrays
        idx_permut = np.random.permutation(len(self))
        self._files_full = self._files_full[idx_permut]
        self._labels_full = self._labels_full[idx_permut]
        self._patients_full = self._patients_full[idx_permut]
        self.indexes = self.indexes[idx_permut]
class NumpyCSVDataset_BZ(Dataset):
    """BZ-cohort variant of NumpyCSVDataset.

    Identical file/label matching scheme, but the clinical file is
    SEMICOLON-separated (``sep=';'``) rather than comma-separated.
    """

    def __init__(
        self,
        data_dir,
        label_file,
        label_col,
        size,
        transforms=augment_3D,
        mode='train',
        seed=1234,
    ):
        """
        Parameters
        ----------
        data_dir : str
            Directory containing the ``.npy`` volumes.
        label_file : str
            Path to the semicolon-separated clinical CSV (must have a
            ``patient`` column).
        label_col : str
            Name of the CSV column used as target label.
        size : int
            Cubic output size passed to ``transforms``.
        transforms : callable
            ``transforms(data, mode, size)`` applied to each loaded volume.
        mode : str
            Forwarded to ``transforms`` ('train' enables augmentation).
        seed : int
            Seed for numpy's global RNG.
        """
        super(NumpyCSVDataset_BZ, self).__init__()
        self.data_dir = data_dir
        self.size = size
        self.transforms = transforms
        self.mode = mode
        self.seed = seed
        np.random.seed(self.seed)

        # NOTE: the BZ clinical file uses ';' as field separator
        clinical_file = pd.read_csv(label_file, sep=';', dtype=str).sort_values(
            by=['patient']
        )
        unique_patients = list(clinical_file['patient'])

        # example filename with augmentation: HN-HGJ-032_8.npy
        patients = list(
            filter(
                lambda patient: patient in unique_patients,
                [f.split('.')[0].split('_')[0] for f in sorted(os.listdir(data_dir))],
            )
        )
        self._patients_full = np.array(patients)
        # select the files whose patient id appears in the clinical file
        self._files_full = np.array(
            [
                f
                for f in sorted(os.listdir(data_dir))
                if f.split('.')[0].split('_')[0] in patients
            ]
        )
        # one label per file, looked up by patient id
        self._labels_full = np.array(
            [
                clinical_file[clinical_file['patient'] == patient][label_col].values[0]
                for patient in patients
            ]
        )
        self.indexes = np.arange(len(self._files_full))
...@@ -85,28 +206,27 @@ class NumpyCSVDataset(Dataset): ...@@ -85,28 +206,27 @@ class NumpyCSVDataset(Dataset):
patient = self._patients_full[self.indexes[idx]] patient = self._patients_full[self.indexes[idx]]
data_file = f'{self.data_dir}/{self._files_full[self.indexes[idx]]}' data_file = f'{self.data_dir}/{self._files_full[self.indexes[idx]]}'
data = np.load(data_file) data = np.load(data_file)
data = self.transforms(data, self.mode, self.size) data = self.transforms(data, self.mode, self.size)
data = torch.Tensor(data) data = torch.Tensor(data)
output = {'data': data, 'target': label, 'filename': file, 'patient': patient} output = {'data': data, 'target': label, 'filename': file, 'patient': patient}
return output return output
def get_labels(self): def get_labels(self):
return self._labels_full[self.indexes] return self._labels_full[self.indexes]
def get_files(self): def get_files(self):
return self._files_full[self.indexes] return self._files_full[self.indexes]
def get_patients(self): def get_patients(self):
return self._patients_full[self.indexes] return self._patients_full[self.indexes]
def __len__(self): def __len__(self):
return len(self.indexes) return len(self.indexes)
def __shuffle__(self): def __shuffle__(self):
idx_permut = np.random.permutation(self.__len__()) idx_permut = np.random.permutation(self.__len__())
self._files_full = self._files_full[idx_permut] self._files_full = self._files_full[idx_permut]
self._labels_full = self._labels_full[idx_permut] self._labels_full = self._labels_full[idx_permut]
self._patients_full = self._patients_full[idx_permut] self._patients_full = self._patients_full[idx_permut]
self.indexes = self.indexes[idx_permut] self.indexes = self.indexes[idx_permut]
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment