dataset.py 4.16 KB
Newer Older
Alessia Marcolini's avatar
Alessia Marcolini committed
1
import os
2
from pathlib import Path
Alessia Marcolini's avatar
Alessia Marcolini committed
3

damiana s's avatar
damiana s committed
4
5
6
import numpy as np
import pandas as pd
import SimpleITK as sitk
Alessia Marcolini's avatar
Alessia Marcolini committed
7
import torch
damiana s's avatar
damiana s committed
8
9
from torch.utils.data import Dataset

10

Alessia Marcolini's avatar
Alessia Marcolini committed
11
12
13
14
from dicom_utils.dicom_utils import augmentation as aug
from dicom_utils.dicom_utils import processing as dup


damiana s's avatar
damiana s committed
15
16
def augment_3D(image, mode, size):
    """Optionally augment a multi-channel 3D volume and resample it to a cube.

    Each channel ``image[c, :, :, :]`` is wrapped in a SimpleITK image. When
    ``mode == 'train'`` the channel list is passed through
    ``aug.augment_morph`` and then a single noise filter obtained from
    ``aug.get_gauss_noise()`` is executed on every channel. Finally each
    channel is resampled with ``dup.resample(..., size=(size, size, size))``,
    converted back to a NumPy array, and the channels are stacked along a new
    leading axis.

    Args:
        image: array indexed as ``image[c, :, :, :]``; its first axis is
            iterated as the channel axis.
        mode: augmentation runs only when this equals ``'train'``.
        size: edge length passed to ``dup.resample`` for all three axes.

    Returns:
        np.ndarray produced by ``np.stack(..., axis=0)`` over the resampled
        channels.
    """
    channel_count = image.shape[0]
    volumes = []
    for c in range(channel_count):
        volumes.append(sitk.GetImageFromArray(image[c, :, :, :]))

    if mode == 'train':
        # Morphological augmentation first, then Gaussian noise; one noise
        # filter instance is created and shared by all channels.
        volumes = aug.augment_morph(volumes)
        noise = aug.get_gauss_noise()
        volumes = [noise.Execute(vol) for vol in volumes]

    cube = (size, size, size)
    resampled = [dup.resample(vol, size=cube) for vol in volumes]
    arrays = [sitk.GetArrayFromImage(vol) for vol in resampled]
    return np.stack(arrays, axis=0)
Andrea Bizzego's avatar
Andrea Bizzego committed
31

Alessia Marcolini's avatar
Alessia Marcolini committed
32

Andrea Bizzego's avatar
Andrea Bizzego committed
33
34
def normalize_range(array, range_pixel):
    """Clip ``array`` to ``[low, high]`` and rescale it linearly to ``[0, 1]``.

    Values below ``low`` map to 0 and values above ``high`` map to 1, whatever
    their magnitude. Non-floating inputs are converted to ``float32`` first so
    the rescaling is never done in integer arithmetic; floating inputs keep
    their dtype. The input array is not modified.

    Args:
        array: array-like of intensities (any shape).
        range_pixel: ``(low, high)`` window; assumed ``high > low``.

    Returns:
        np.ndarray with the same shape as ``array`` and values in ``[0, 1]``.
    """
    low, high = range_pixel[0], range_pixel[1]
    values = np.asarray(array)
    if not np.issubdtype(values.dtype, np.floating):
        values = values.astype(np.float32)
    # np.clip clamps both ends symmetrically, so there are no hidden fixed
    # bounds (such as -5000/5000) to fall outside of.
    clipped = np.clip(values, low, high)
    return (clipped - low) / (high - low)


def augment_3D_HN(image, mode, size, ct_range=(-1000, 3000), pt_range=(0, 50)):
    """Window/normalize a two-channel (CT, PET) volume, augment it and resample it.

    Channel 0 is treated as CT and channel 1 as PET. Each channel is clipped
    to its intensity window and rescaled to ``[0, 1]`` (see
    ``normalize_range``) and then wrapped in a SimpleITK image. When
    ``mode == 'train'`` the pair is passed through ``aug.augment_morph``.
    Both channels are then resampled with
    ``dup.resample(..., size=(size, size, size))`` and stacked along a new
    leading axis. Any channels beyond index 1 are ignored.

    Args:
        image: array indexed as ``image[c, :, :, :]`` with at least 2 channels.
        mode: morphological augmentation runs only when this equals ``'train'``.
        size: edge length passed to ``dup.resample`` for all three axes.
        ct_range: ``(low, high)`` window for channel 0 (default ``(-1000, 3000)``).
        pt_range: ``(low, high)`` window for channel 1 (default ``(0, 50)``).

    Returns:
        np.ndarray of shape ``(2, ...)`` built with ``np.stack(..., axis=0)``.
    """
    image_CT = sitk.GetImageFromArray(normalize_range(image[0, :, :, :], ct_range))
    image_PT = sitk.GetImageFromArray(normalize_range(image[1, :, :, :], pt_range))
    image_seq = [image_CT, image_PT]

    if mode == 'train':
        image_seq = aug.augment_morph(image_seq)
        # NOTE: unlike augment_3D, no Gaussian noise is added here (the call
        # was disabled in this pipeline).

    image_seq = [dup.resample(img, size=(size, size, size)) for img in image_seq]
    image_seq = [sitk.GetArrayFromImage(img) for img in image_seq]
    return np.stack(image_seq, axis=0)
Alessia Marcolini's avatar
Alessia Marcolini committed
65
66


damiana s's avatar
damiana s committed
67
class NumpyCSVDataset(Dataset):
    """PyTorch dataset of ``.npy`` volumes described by a clinical CSV file.

    The CSV is read with every column as ``str`` and sorted by its
    ``patient`` column. Only rows whose ``filename`` names a ``.npy`` file
    present directly in ``data_dir`` are kept. Per-row ``filename``,
    ``patient`` and ``label_col`` values are stored in parallel arrays.

    ``indices`` selects and orders the rows exposed through ``__getitem__``,
    ``__len__`` and the ``labels`` / ``filenames`` / ``patients``
    properties. It starts as ``np.arange(n_rows)``; callers may replace it
    (presumably to select a subset such as a split — confirm with callers).
    """

    def __init__(
        self,
        data_dir,
        clinical_file,
        label_col,
        size,
        transforms=augment_3D,
        mode='train',
        seed=1234,
    ):
        """Build the dataset index from ``clinical_file`` and ``data_dir``.

        Args:
            data_dir: directory holding the ``.npy`` files (str or Path).
            clinical_file: path to the CSV; must contain ``patient`` and
                ``filename`` columns plus ``label_col``.
            label_col: CSV column used as the target. Values are strings
                because the CSV is read with ``dtype=str``.
            size: forwarded to ``transforms`` as its third argument.
            transforms: callable ``(array, mode, size) -> array``.
            mode: forwarded to ``transforms`` as its second argument.
            seed: used to seed NumPy's *global* RNG (``np.random.seed``).
        """
        super(NumpyCSVDataset, self).__init__()

        if not isinstance(data_dir, Path):
            data_dir = Path(data_dir)

        self.data_dir = data_dir
        self.size = size
        self.transforms = transforms
        self.mode = mode
        self.seed = seed

        # Global side effect: also affects any other np.random consumer.
        np.random.seed(self.seed)

        clinical = pd.read_csv(clinical_file, dtype=str).sort_values(by=['patient'])

        available_files = [f for f in os.listdir(data_dir) if f.endswith('.npy')]

        # Keep only the rows whose file is really on disk.
        clinical = clinical.loc[clinical['filename'].isin(available_files)]

        self._filenames = clinical['filename'].values
        self._patients = clinical['patient'].values
        self._labels = clinical[label_col].values

        self.indices = np.arange(len(self._filenames))

    def __getitem__(self, idx):
        """Load, transform and return the sample at position ``idx`` of ``indices``.

        Returns:
            dict with ``'data'`` (``torch.Tensor`` built from the transformed
            array), ``'target'`` (label string from the CSV), ``'filename'``
            and ``'patient'``.
        """
        row = self.indices[idx]
        filename = self._filenames[row]

        data = np.load(self.data_dir / filename)
        data = self.transforms(data, self.mode, self.size)

        return {
            'data': torch.Tensor(data),
            'target': self._labels[row],
            'filename': filename,
            'patient': self._patients[row],
        }

    @property
    def labels(self):
        """Labels of the active rows, in ``indices`` order."""
        return self._labels[self.indices]

    @property
    def filenames(self):
        """Filenames of the active rows, in ``indices`` order."""
        return self._filenames[self.indices]

    @property
    def patients(self):
        """Patient identifiers of the active rows, in ``indices`` order."""
        return self._patients[self.indices]

    def __len__(self):
        """Number of active rows (``len(self.indices)``)."""
        return len(self.indices)

    def __shuffle__(self):
        """Randomly reorder ``indices`` using NumPy's global RNG.

        Only ``indices`` is permuted. The underlying ``_filenames``,
        ``_labels`` and ``_patients`` arrays are left untouched, so a subset
        assigned to ``indices`` stays valid and every field stays aligned.
        """
        order = np.random.permutation(len(self.indices))
        self.indices = np.asarray(self.indices)[order]