Commit e751d783 authored by Nicole Bussola's avatar Nicole Bussola Committed by Alessia Marcolini
Browse files

add patient-wise splitting

parent ab5c881d
import numpy as np
from sklearn.model_selection import train_test_split
def train_test_indexes_patient_wise(dataset, test_size=0.2, seed=0, stratify=True):
files = dataset._files
patients = np.unique(dataset.patients)
# print(len(files), len(patients))
patients_labels = []
for patient in patients:
idx = np.where(patient == np.array(dataset.patients))[0][0] # index of the first im belonging to the patient
label = dataset.get_labels()[idx]
# print(len(patients_labels))
train_patients, test_patients = train_test_split(patients, test_size=test_size, random_state=seed)
train_indexes = []
test_indexes = []
for train_patient in train_patients:
idxs = np.where(train_patient == np.array(dataset.patients))[0].tolist()
for test_patient in test_patients:
idxs = np.where(test_patient == np.array(dataset.patients))[0].tolist()
return train_indexes, test_indexes
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment