# split.py -- patient-wise train/test index splitting (committed by Nicole Bussola).
import numpy as np
from sklearn.model_selection import train_test_split


def train_test_indexes_patient_wise(dataset, test_size=0.2, stratify=True, seed=1234):
    """Split sample indexes into train and test sets without sharing patients.

    The split is performed over the unique patients rather than over the
    individual samples, so every sample of a given patient lands on the same
    side of the split.

    Args:
        dataset: Object exposing ``patients`` and ``labels``. Both are indexed
            by sample position; ``patients`` is assumed to be a 1-D sequence
            of hashable identifiers (TODO confirm against the dataset class).
        test_size: Forwarded to ``train_test_split`` as ``test_size``.
        stratify: When true, the patient split is stratified on each patient's
            label, taken from the first sample belonging to that patient.
        seed: Forwarded to ``train_test_split`` as ``random_state``.

    Returns:
        A ``(train_indexes, test_indexes)`` tuple of lists of sample indexes.
        Indexes are grouped by patient in the order returned by
        ``train_test_split`` and are ascending within each patient.
    """
    # Convert once: the array is needed for the unique pass and for grouping.
    patients = np.asarray(dataset.patients)

    # ``return_index`` yields the index of the first sample of each patient,
    # which is the sample whose label represents the patient.
    unique_patients, first_sample_idxs = np.unique(patients, return_index=True)
    patients_labels = [dataset.labels[idx] for idx in first_sample_idxs]

    # Group sample indexes per patient in a single pass instead of rescanning
    # the whole patient array for every patient.
    indexes_by_patient = {}
    for sample_idx, patient in enumerate(patients):
        indexes_by_patient.setdefault(patient, []).append(sample_idx)

    if stratify:
        train_patients, test_patients = train_test_split(
            unique_patients,
            test_size=test_size,
            random_state=seed,
            stratify=patients_labels,
        )
    else:
        train_patients, test_patients = train_test_split(
            unique_patients, test_size=test_size, random_state=seed
        )

    train_indexes = [
        idx for patient in train_patients for idx in indexes_by_patient[patient]
    ]
    test_indexes = [
        idx for patient in test_patients for idx in indexes_by_patient[patient]
    ]

    return train_indexes, test_indexes