scatter_interact.py 8.61 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
"""
This Python script takes as input a csv with patient information and
data from medical imaging, and
plots a 2-dim scatter plot of the raw data reduced using umap.
It also provides a tool for visualizing the actual image
by selecting a point on the scatter plot.
RUN: run with the bokeh server: 'bokeh serve <script path>',
then go to the specified local url
LOCATION: csv in an upper level folder called 'vendor_plot',
data in an upper level folder called 'bbox64'
REQUIREMENTS:
numpy
pandas (sort of)
bokeh==1.4.0
holoviews==1.12.7
umap
PARAMETER: (important ones)
modality 'PT' / 'CT'
"""
import os
import numpy as np
import pandas as pd
from holoviews import opts
import holoviews as hv
import holoviews.plotting.bokeh
from holoviews.streams import Selection1D
import umap
from pathlib import Path
hv.extension('bokeh')
renderer = hv.renderer('bokeh')


def scan_explore(datapath_val,
                 datapath_bz,
                 subject_id,
                 slice_id=0,
                 modality='CT'):
    """
    Build a holoviews image showing one slice of one subject scan.
    IN:
    datapath_val: folder (path object) searched first for '<subject_id>.npy'
    datapath_bz: fallback folder (path object), used only when the file
                 is not found in datapath_val
    subject_id: subject identifier, i.e. the file name without extension
    slice_id: index of the slice to show (converted with int())
    modality: 'CT' (channel 0 of the stored array) or
              'PT' (channel 1 of the stored array)
    OUT:
    hv.Image of the selected 2d slice, titled with subject id and modality
    RAISES:
    ValueError: modality is neither 'CT' nor 'PT'
    FileNotFoundError: the file is missing from both folders
    """
    print(f"Subject ID:{subject_id} Slice(x):{slice_id} Modality:{modality}")
    # validate the modality up front: previously an unknown value only
    # printed a message and then failed later on an unbound variable
    channel_by_modality = {'CT': 0, 'PT': 1}
    if modality not in channel_by_modality:
        raise ValueError(
            f"Invalid scan modality {modality!r}: expected 'CT' or 'PT'.")
    mod_id = channel_by_modality[modality]
    slice_id = int(slice_id)
    data_ext = ".npy"

    # the stored array is indexed as [channel, slice, :, :]
    object_path_val = datapath_val/(subject_id + data_ext)  # path op
    object_path_bz = datapath_bz/(subject_id + data_ext)
    try:
        img_slice = np.load(object_path_val)[mod_id, slice_id, :, :]
    except FileNotFoundError:
        # subject not in the first folder: look in the fallback folder
        img_slice = np.load(object_path_bz)[mod_id, slice_id, :, :]

    return hv.Image(img_slice).opts(title=f"id:{subject_id}  mod:{modality}",
                                    fontsize={'title': 10})


def scan_wrap(datapath_val, datapath_bz, points_data,
              index, slice_id=0, modality='CT'):
    """
    Translate a point selection into the scan image of that subject.
    IN:
    datapath_val, datapath_bz: data folders forwarded to scan_explore
    points_data: object whose .data attribute is a table with a
                 'Subject' column (accessed positionally through .iloc)
    index: None, an empty list, or a list whose first element is the
           position of the selected point
    slice_id, modality: forwarded to scan_explore
    OUT:
    whatever scan_explore returns for the selected subject
    """
    print("Index supplied ", index)
    # no selection (None or empty list) falls back to the first point;
    # otherwise only the first selected position is used
    nothing_selected = index is None or index == []
    pt_index = 0 if nothing_selected else index[0]

    selected_row = points_data.data.iloc[pt_index]
    return scan_explore(datapath_val, datapath_bz,
                        selected_row['Subject'], slice_id, modality)


def get_data_lbl(path, file_list, path_orig_df, modality):
    """
    Extract data and manufacturer label from data folder and info file
    IN:
    path: data folder (path object)
    file_list: list of file names with extension, located in path
    path_orig_df: our dataset with file info; the columns read here are
                  'Modality', 'Subject ID' and 'Manufacturer'
    modality: either PT (for PET) or CT
    OUT:
    pandas dataframe with one row per file and cols:
    data: 3d-numpy array of the selected modality
    manufacturer: scan device manufacturer
    dataset: dataset id
    subject_id: scanned subject id
    RAISES:
    ValueError: modality is neither 'CT' nor 'PT', a file name does not
                follow one of the two supported formats, or a subject has
                no row for this modality in path_orig_df
    """
    # validate up front instead of printing and failing later on an
    # unbound variable
    channel_by_modality = {"CT": 0, "PT": 1}
    if modality not in channel_by_modality:
        raise ValueError(
            f"Please enter a valid modality parameter: got {modality!r}, "
            "expected 'CT' or 'PT'")
    mod_id = channel_by_modality[modality]

    df_sel_mod = path_orig_df.loc[path_orig_df["Modality"] == modality]
    data = []
    lbl = []
    dataset_id = []
    subject_id = []
    for elem in file_list:
        # subj id removing extension
        subj_id = elem.split(".")[0]
        ds_split = subj_id.split("-")
        if len(ds_split) == 3:
            # three dash-separated parts: the middle one is the dataset id
            dataset_tmp = ds_split[1]
            # look up the manufacturer in the info table
            vendor_matches = df_sel_mod.loc[
                df_sel_mod["Subject ID"] == subj_id, "Manufacturer"]
            if vendor_matches.empty:
                raise ValueError(
                    f"No {modality} entry for subject {subj_id!r} "
                    "in the info table")
            vendor = vendor_matches.values[0]
        elif len(ds_split) == 1:  # bz unique with different format
            dataset_tmp = "BZ_0"
            vendor = "Philips"  # bolzano manufacturer
        else:
            # without this branch the labels of the previous file would be
            # silently reused for a file with an unknown name format
            raise ValueError(
                f"Unrecognized subject id format: {subj_id!r} "
                f"(from file {elem!r})")
        # load patient data, keeping only the channel of this modality
        data_tmp = np.load(path/elem)[mod_id, :, :, :]
        # save to lists
        data.append(data_tmp)
        lbl.append(vendor)
        dataset_id.append(dataset_tmp)
        subject_id.append(subj_id)
    return pd.DataFrame(data={'data': data, 'manufacturer': lbl,
                              'dataset': dataset_id, 'subject_id': subject_id})


def get_scatter(df, group, ds):
    """
    Scatter plot of the points of one manufacturer / dataset pair.
    IN:
    df: dataframe with cols 'red_x', 'red_y', 'manufacturer',
        'dataset' and 'subject_id'
    group: manufacturer value to keep
    ds: dataset value to keep
    OUT:
    hv.Scatter with key dims x_u, y_u and
    value dims Manufacturer, Dataset, Subject
    """
    # keep only the rows matching both the manufacturer and the dataset
    row_mask = (df['manufacturer'] == group) & (df['dataset'] == ds)
    subset = df.loc[row_mask]
    # column order must follow kdims first, then vdims
    source_cols = ('red_x', 'red_y', 'manufacturer', 'dataset', 'subject_id')
    columns = tuple(subset[name] for name in source_cols)
    return hv.Scatter(columns,
                      kdims=['x_u', 'y_u'],
                      vdims=['Manufacturer', 'Dataset', 'Subject'])


def get_scatter_full(df):
    """
    Scatter plot of all the points, with no filtering.
    IN:
    df: dataframe with cols 'red_x', 'red_y', 'manufacturer',
        'dataset' and 'subject_id'
    OUT:
    hv.Scatter with key dims x_u, y_u and
    value dims Manufacturer, Dataset, Subject
    """
    # column order must follow kdims first, then vdims
    source_cols = ('red_x', 'red_y', 'manufacturer', 'dataset', 'subject_id')
    columns = tuple(df[name] for name in source_cols)
    return hv.Scatter(columns,
                      kdims=['x_u', 'y_u'],
                      vdims=['Manufacturer', 'Dataset', 'Subject'])
#


# script configuration
modality = "CT"   # modality used for the 2d dimensionality-reduction plot
# every input lives in a 'data' folder two levels above the working directory
data_folder = Path.cwd()/'..'/'..'/'data'
info_df_path = data_folder/'HN_val'/'processed'/'path_original_data.csv'
DATAPATH = data_folder/'HN_val'/'processed'/'bbox'/'bbox_64'
DATAPATH_BZ = data_folder/'HN_BZ'/'processed'/'bbox'/'bbox_64'
# file names of both data folders, in sorted order
file_list = sorted(os.listdir(DATAPATH))
file_list_bz = sorted(os.listdir(DATAPATH_BZ))
# table with the per-file information
path_orig_df = pd.read_csv(info_df_path)
print("Environment set.")

# prepare data and dim red
# load the scans of each input folder into a dataframe
df_val = get_data_lbl(DATAPATH, file_list, path_orig_df, modality)
df_bz = get_data_lbl(DATAPATH_BZ, file_list_bz, path_orig_df, modality)
# stack the two dataframes vertically with a fresh 0..n-1 index;
# pd.concat replaces DataFrame.append, which was deprecated in pandas 1.4
# and removed in pandas 2.0
df = pd.concat([df_val, df_bz], ignore_index=True)
# flatten the 3d array
df['flat_data'] = [x.flatten('C') for x in df['data']]
df.drop(columns=['data'], inplace=True)  # remove data col to save memory
# reduce every flattened scan to 2 components
umap_euclid = umap.UMAP(metric="euclidean", n_components=2,
                        n_neighbors=20, min_dist=.2)
reduced_data = umap_euclid.fit_transform(df['flat_data'].tolist())
df['red_x'] = reduced_data[:, 0]
df['red_y'] = reduced_data[:, 1]
df.drop(columns=['flat_data'], inplace=True)  # remove flat_data col

print("Dim red completed")

# set up the layered scatterplot
# compute possible combination in the dataset
# use them to generate all the relevant sub-plot to be stacked

# compute unique co occurrences of manufacturer and
# dataset id actually present in the data
comb_vendor_ds = np.unique(df['manufacturer'] + '-' + df['dataset'])
# split on the LAST '-' only: dataset ids produced by get_data_lbl never
# contain '-', but a manufacturer name may, and a plain split would then
# yield more than two parts and break the (group, ds) unpacking below
comb_vendor_ds = [x.rsplit('-', 1) for x in comb_vendor_ds]
scatter_dict = {(group, ds): get_scatter(df, group, ds)
                for group, ds in comb_vendor_ds}


# overlayed plots with different group, give color/legend can be hovered over
overlay = hv.NdOverlay(scatter_dict, kdims=['Manufacturer', 'dataset'])
overlay.opts(opts.Scatter(alpha=0.75, size=7, width=650,
                          height=650, tools=['hover']))

# one transparent layer just for point index information with tapping widgets
scatter_title = f"Umap 2d plot HN mod: {modality}"
p_ref = get_scatter_full(df).opts(opts.Scatter(title=scatter_title,
                                              alpha=0.0, size=7,
                                              width=650, height=650,
                                              tools=['tap']))
s = overlay * p_ref
# attach stream to overlayed plot
# stream depend on tap widget, unique selection widget and
# output a list with its index
# in this case it will account for index over the last layered
# scatterplot (in this case the complete one)
sel_stream = Selection1D(source=s, index=[0])
print("overlay plot and stream have been set.")

# dynamic map for the scan image: the selected point index comes from the
# selection stream, slice id and modality come from the key dimensions.
# The wrapper below exposes exactly those three arguments and binds the
# remaining scan_wrap arguments (data folders and points data).
def _scan_callback(index, slice_id, modality):
    return scan_wrap(datapath_val=DATAPATH,
                     datapath_bz=DATAPATH_BZ,
                     points_data=p_ref,
                     index=index,
                     slice_id=slice_id,
                     modality=modality)


scan_dmap = hv.DynamicMap(_scan_callback,
                          streams=[sel_stream],
                          kdims=['slice_id', 'modality'])
# value domains for the key dimensions that are not provided by a stream
# NOTE(review): the slice_id range ends at 64; if the stored volumes have
# 64 slices (as the 'bbox_64' folder name suggests) the last valid index
# would be 63 - confirm against the data
scan_dmap = (scan_dmap.redim.range(slice_id=(0, 64))
             .redim.values(modality=['CT', 'PT']))
# rendering
layout = s + scan_dmap  # scatter plot next to the scan image
doc = renderer.server_doc(layout)  # document served by the bokeh server
doc.title = 'HN Scan explorer'  # page title