Commit 6c5fe9b2 authored by Alessia Marcolini's avatar Alessia Marcolini
Browse files

Modify train_test_indexes_patient_wise according to new dataset methods

parent 1573ec4a
import numpy as np
from sklearn.model_selection import train_test_split
def train_test_indexes_patient_wise(dataset, test_size=0.2, stratify=True):
patients = dataset.get_patients()
def train_test_indexes_patient_wise(dataset, test_size=0.2, stratify=True, seed=1234):
patients = dataset.patients
unique_patients = np.unique(patients)
# print(len(files), len(patients))
......@@ -11,28 +12,38 @@ def train_test_indexes_patient_wise(dataset, test_size=0.2, stratify=True):
patients_labels = []
for patient in unique_patients:
idx = np.where(patient == np.array(patients))[0][0] # index of the first im belonging to the patient
label = dataset.get_labels()[idx]
idx = np.where(patient == np.array(patients))[0][
0
] # index of the first im belonging to the patient
label = dataset.labels[idx]
patients_labels.append(label)
# print(len(patients_labels))
# print(train_patients)
if stratify:
train_patients, test_patients = train_test_split(unique_patients, test_size=test_size, random_state=dataset.seed, stratify=patients_labels)
train_patients, test_patients = train_test_split(
unique_patients,
test_size=test_size,
random_state=seed,
stratify=patients_labels,
)
else:
train_patients, test_patients = train_test_split(unique_patients, test_size=test_size, random_state=dataset.seed)
train_patients, test_patients = train_test_split(
unique_patients, test_size=test_size, random_state=seed
)
train_indexes = []
test_indexes = []
for train_patient in train_patients:
idxs = np.where(train_patient == np.array(patients))[0].tolist()
train_indexes.extend(idxs)
for test_patient in test_patients:
idxs = np.where(test_patient == np.array(patients))[0].tolist()
test_indexes.extend(idxs)
return train_indexes, test_indexes
\ No newline at end of file
return train_indexes, test_indexes
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment