Commit 326ca2d5 authored by Alessia Marcolini's avatar Alessia Marcolini
Browse files

Fix class stratification of split

parent a714181f
......@@ -17,8 +17,11 @@ def train_test_indexes_patient_wise(dataset, test_size=0.2, seed=0, stratify=Tru
# print(len(patients_labels))
#print(train_patients)
train_patients, test_patients = train_test_split(patients, test_size=test_size, random_state=seed)
# print(train_patients)
if stratify:
train_patients, test_patients = train_test_split(unique_patients, test_size=test_size, random_state=dataset.seed, stratify=patients_labels)
else:
train_patients, test_patients = train_test_split(unique_patients, test_size=test_size, random_state=dataset.seed)
train_indexes = []
test_indexes = []
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment