Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
MPBA Radiomics
RADLER
Commits
6c5fe9b2
Commit
6c5fe9b2
authored
Mar 11, 2020
by
Alessia Marcolini
Browse files
Modify train_test_indexes_patient_wise according to new dataset methods
parent
1573ec4a
Changes
1
Hide whitespace changes
Inline
Side-by-side
split.py
View file @
6c5fe9b2
import
numpy
as
np
from
sklearn.model_selection
import
train_test_split
def
train_test_indexes_patient_wise
(
dataset
,
test_size
=
0.2
,
stratify
=
True
):
patients
=
dataset
.
get_patients
()
def
train_test_indexes_patient_wise
(
dataset
,
test_size
=
0.2
,
stratify
=
True
,
seed
=
1234
):
patients
=
dataset
.
patients
unique_patients
=
np
.
unique
(
patients
)
# print(len(files), len(patients))
...
...
@@ -11,28 +12,38 @@ def train_test_indexes_patient_wise(dataset, test_size=0.2, stratify=True):
patients_labels
=
[]
for
patient
in
unique_patients
:
idx
=
np
.
where
(
patient
==
np
.
array
(
patients
))[
0
][
0
]
# index of the first im belonging to the patient
label
=
dataset
.
get_labels
()[
idx
]
idx
=
np
.
where
(
patient
==
np
.
array
(
patients
))[
0
][
0
]
# index of the first im belonging to the patient
label
=
dataset
.
labels
[
idx
]
patients_labels
.
append
(
label
)
# print(len(patients_labels))
# print(train_patients)
if
stratify
:
train_patients
,
test_patients
=
train_test_split
(
unique_patients
,
test_size
=
test_size
,
random_state
=
dataset
.
seed
,
stratify
=
patients_labels
)
train_patients
,
test_patients
=
train_test_split
(
unique_patients
,
test_size
=
test_size
,
random_state
=
seed
,
stratify
=
patients_labels
,
)
else
:
train_patients
,
test_patients
=
train_test_split
(
unique_patients
,
test_size
=
test_size
,
random_state
=
dataset
.
seed
)
train_patients
,
test_patients
=
train_test_split
(
unique_patients
,
test_size
=
test_size
,
random_state
=
seed
)
train_indexes
=
[]
test_indexes
=
[]
for
train_patient
in
train_patients
:
idxs
=
np
.
where
(
train_patient
==
np
.
array
(
patients
))[
0
].
tolist
()
train_indexes
.
extend
(
idxs
)
for
test_patient
in
test_patients
:
idxs
=
np
.
where
(
test_patient
==
np
.
array
(
patients
))[
0
].
tolist
()
test_indexes
.
extend
(
idxs
)
return
train_indexes
,
test_indexes
\ No newline at end of file
return
train_indexes
,
test_indexes
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment