Commit 6c5fe9b2 authored by Alessia Marcolini
Browse files

Modify train_test_indexes_patient_wise according to new dataset methods

parent 1573ec4a
import numpy as np import numpy as np
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
def train_test_indexes_patient_wise(dataset, test_size=0.2, stratify=True, seed=1234):
    """Split sample indexes into train and test sets without splitting a patient.

    The split is performed on the unique patient identifiers, so every sample
    of a given patient ends up on the same side of the split.

    Args:
        dataset: Object exposing ``patients`` (one patient identifier per
            sample) and ``labels`` (one label per sample, indexable by
            sample position).
        test_size: Forwarded to ``train_test_split`` as ``test_size``; it is
            applied to the unique patients, not to the individual samples.
        stratify: When true, the patient split is stratified on one label
            per patient.
        seed: Forwarded to ``train_test_split`` as ``random_state``.

    Returns:
        A ``(train_indexes, test_indexes)`` tuple of lists of integer sample
        positions. Within each list the indexes are grouped by patient, in
        the order ``train_test_split`` returned the patients.
    """
    # Convert once instead of rebuilding the array inside every loop iteration.
    patients = np.asarray(dataset.patients)
    labels = dataset.labels

    # ``return_index`` yields the position of the first sample of each patient.
    unique_patients, first_positions = np.unique(patients, return_index=True)

    # One label per patient, taken from that patient's first sample.
    # NOTE(review): presumably every sample of a patient shares the same
    # label; confirm against the dataset implementation.
    patients_labels = [labels[position] for position in first_positions]

    # ``stratify=None`` is the library default, i.e. a non-stratified split.
    train_patients, test_patients = train_test_split(
        unique_patients,
        test_size=test_size,
        random_state=seed,
        stratify=patients_labels if stratify else None,
    )

    def _sample_indexes(selected_patients):
        # Collect, patient by patient, the positions of all their samples.
        indexes = []
        for patient in selected_patients:
            indexes.extend(np.where(patient == patients)[0].tolist())
        return indexes

    return _sample_indexes(train_patients), _sample_indexes(test_patients)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment