Update app.py
Browse files
app.py
CHANGED
@@ -5,11 +5,10 @@ import plotly.express as px
|
|
5 |
import gradio as gr
|
6 |
|
7 |
######################################
|
8 |
-
# 1)
|
9 |
######################################
|
10 |
-
df = pd.read_csv("X_train_Y_Train_merged_train.csv")
|
11 |
|
12 |
-
# List of model filenames (adjust if needed)
|
13 |
model_filenames = [
|
14 |
"YOWRCONC.pkl", "YOSEEDOC.pkl", "YO_MDEA5.pkl", "YOWRLSIN.pkl",
|
15 |
"YODPPROB.pkl", "YOWRPROB.pkl", "YODPR2WK.pkl", "YOWRDEPR.pkl",
|
@@ -18,70 +17,60 @@ model_filenames = [
|
|
18 |
]
|
19 |
model_path = "models/"
|
20 |
|
21 |
-
|
22 |
-
######################################
|
23 |
-
# 2) Model Predictor
|
24 |
-
######################################
|
25 |
class ModelPredictor:
|
26 |
def __init__(self, model_path, model_filenames):
|
27 |
self.model_path = model_path
|
28 |
self.model_filenames = model_filenames
|
29 |
self.models = self.load_models()
|
30 |
-
# Mapping from label column to
|
31 |
self.prediction_map = {
|
32 |
-
"YOWRCONC": ["
|
33 |
-
"YOSEEDOC": ["Did not feel
|
34 |
-
"YOWRHRS": ["
|
35 |
-
"YO_MDEA5": ["
|
36 |
-
"YOWRCHR": ["Did not feel so sad", "Felt so sad nothing
|
37 |
-
"YOWRLSIN": ["
|
38 |
-
"YODPPROB": ["No other
|
39 |
-
"YOWRPROB": ["Did not have
|
40 |
-
"YODPR2WK": ["No
|
41 |
"YOWRDEPR": ["Did not feel depressed mostly everyday", "Felt depressed mostly everyday"],
|
42 |
-
"YODPDISC": ["Mood not depressed overall", "Mood depressed overall
|
43 |
-
"YOLOSEV": ["
|
44 |
"YOWRDCSN": ["Was able to make decisions", "Was unable to make decisions"],
|
45 |
-
"YODSMMDE": ["
|
46 |
"YO_MDEA3": ["No appetite/weight changes", "Had appetite/weight changes"],
|
47 |
"YODPLSIN": ["Never bored/lost interest", "Felt bored/lost interest"],
|
48 |
"YOWRELES": ["Did not eat less than usual", "Ate less than usual"],
|
49 |
"YODSCEV": ["Fewer severe symptoms", "More severe symptoms"],
|
50 |
-
"YOPB2WK": ["No uneasy feelings 2+ weeks", "Had uneasy feelings 2+ weeks"],
|
51 |
-
"YO_MDEA2": ["No
|
52 |
}
|
53 |
|
54 |
def load_models(self):
|
55 |
-
|
56 |
-
for
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
def make_predictions(self, user_input):
|
64 |
"""
|
65 |
-
|
66 |
-
The i-th array corresponds to the i-th model in self.models.
|
67 |
"""
|
68 |
predictions = []
|
69 |
for model in self.models:
|
70 |
-
|
71 |
-
predictions.append(
|
72 |
return predictions
|
73 |
|
74 |
def get_majority_vote(self, predictions):
|
75 |
-
"""
|
76 |
-
Flatten all predictions from all models, combine them,
|
77 |
-
then find the majority class (0 or 1).
|
78 |
-
"""
|
79 |
combined = np.concatenate(predictions)
|
80 |
-
|
81 |
-
return
|
82 |
|
83 |
-
|
84 |
-
|
85 |
if majority_vote_count >= 13:
|
86 |
return "Mental Health Severity: Severe"
|
87 |
elif majority_vote_count >= 9:
|
@@ -91,22 +80,52 @@ class ModelPredictor:
|
|
91 |
else:
|
92 |
return "Mental Health Severity: Very Low"
|
93 |
|
|
|
94 |
|
95 |
######################################
|
96 |
-
#
|
97 |
######################################
|
98 |
def validate_inputs(*args):
|
99 |
for arg in args:
|
100 |
-
if arg
|
101 |
return False
|
102 |
return True
|
103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
|
105 |
######################################
|
106 |
-
#
|
107 |
######################################
|
108 |
-
predictor = ModelPredictor(model_path, model_filenames)
|
109 |
-
|
110 |
def predict(
|
111 |
YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
|
112 |
YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
|
@@ -114,7 +133,7 @@ def predict(
|
|
114 |
YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
|
115 |
YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
|
116 |
):
|
117 |
-
# Validate
|
118 |
if not validate_inputs(
|
119 |
YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
|
120 |
YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
|
@@ -133,53 +152,50 @@ def predict(
|
|
133 |
None
|
134 |
)
|
135 |
|
136 |
-
#
|
137 |
-
|
138 |
-
'YNURSMDE': [
|
139 |
-
'YMDEYR': [
|
140 |
-
'YSOCMDE': [
|
141 |
-
'YMDESUD5ANYO': [
|
142 |
-
'YMSUD5YANY': [
|
143 |
-
'YUSUITHK': [
|
144 |
-
'YMDETXRX': [
|
145 |
-
'YUSUITHKYR': [
|
146 |
-
'YMDERSUD5ANY': [
|
147 |
-
'YUSUIPLNYR': [
|
148 |
-
'YCOUNMDE': [
|
149 |
-
'YPSY1MDE': [
|
150 |
-
'YHLTMDE': [
|
151 |
-
'YDOCMDE': [
|
152 |
-
'YPSY2MDE': [
|
153 |
-
'YMDEHARX': [
|
154 |
-
'LVLDIFMEM2': [
|
155 |
-
'MDEIMPY': [
|
156 |
-
'YMDEHPO': [
|
157 |
-
'YMIMS5YANY': [
|
158 |
-
'YMDEIMAD5YR': [
|
159 |
-
'YMIUD5YANY': [
|
160 |
-
'YMDEHPRX': [
|
161 |
-
'YMIMI5YANY': [
|
162 |
-
'YUSUIPLN': [
|
163 |
-
'YTXMDEYR': [
|
164 |
-
'YMDEAUD5YR': [
|
165 |
-
'YRXMDEYR': [
|
166 |
-
'YMDELT': [
|
167 |
}
|
168 |
-
|
169 |
|
170 |
-
#
|
171 |
-
predictions = predictor.make_predictions(
|
172 |
-
|
173 |
-
# 2) Majority vote
|
174 |
majority_vote = predictor.get_majority_vote(predictions)
|
|
|
|
|
|
|
|
|
175 |
|
176 |
-
#
|
177 |
-
num_ones = sum(np.concatenate(predictions) == 1)
|
178 |
-
|
179 |
-
# 4) Severity
|
180 |
-
severity = predictor.evaluate_severity(num_ones)
|
181 |
-
|
182 |
-
# 5) Group textual results
|
183 |
groups = {
|
184 |
"Concentration_and_Decision_Making": ["YOWRCONC", "YOWRDCSN"],
|
185 |
"Sleep_and_Energy_Levels": ["YOWRHRS", "YO_MDEA5", "YOWRELES", "YO_MDEA2"],
|
@@ -191,186 +207,145 @@ def predict(
|
|
191 |
"YOPB2WK"]
|
192 |
}
|
193 |
|
194 |
-
|
195 |
for i, arr in enumerate(predictions):
|
196 |
-
|
197 |
-
|
198 |
-
if
|
199 |
-
|
200 |
else:
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
for gname, gcols in groups.items():
|
205 |
-
if
|
206 |
-
|
207 |
-
|
208 |
break
|
209 |
-
# If not found_group, we do nothing (skip or put in a "misc" group)
|
210 |
-
|
211 |
-
final_str = []
|
212 |
-
for gname, items in grouped_text.items():
|
213 |
-
if items:
|
214 |
-
final_str.append(f"**{gname.replace('_',' ')}**")
|
215 |
-
final_str.append("\n".join(items))
|
216 |
-
final_str.append("\n")
|
217 |
-
final_str = "\n".join(final_str).strip()
|
218 |
-
if not final_str:
|
219 |
-
final_str = "No predictions made. Please check inputs."
|
220 |
-
|
221 |
-
# Additional info
|
222 |
-
total_patients = len(df)
|
223 |
-
total_patient_markdown = (
|
224 |
-
f"### Total Patient Count\nThere are **{total_patients}** patients in the dataset."
|
225 |
-
)
|
226 |
-
|
227 |
-
# A) Bar chart for input features
|
228 |
-
same_val_counts = {}
|
229 |
-
for col, val_list in user_input_data.items():
|
230 |
-
val_ = val_list[0]
|
231 |
-
same_val_counts[col] = len(df[df[col] == val_])
|
232 |
-
bar_input_df = pd.DataFrame({"Feature": list(same_val_counts.keys()),
|
233 |
-
"Count": list(same_val_counts.values())})
|
234 |
-
fig_bar_input = px.bar(
|
235 |
-
bar_input_df, x="Feature", y="Count",
|
236 |
-
title="Number of Patients with Same Input Feature Values"
|
237 |
-
)
|
238 |
-
fig_bar_input.update_layout(width=800, height=500)
|
239 |
|
240 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
241 |
label_counts = {}
|
242 |
for i, arr in enumerate(predictions):
|
243 |
-
|
244 |
pred_val = arr[0]
|
245 |
if pred_val in [0,1]:
|
246 |
-
|
247 |
-
|
248 |
if label_counts:
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
else:
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
#
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
for
|
263 |
-
if
|
264 |
continue
|
265 |
-
for
|
266 |
-
if
|
267 |
continue
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
if
|
273 |
-
|
274 |
fig_dist = px.bar(
|
275 |
-
|
276 |
-
x=
|
277 |
y="count",
|
278 |
-
color=
|
279 |
facet_row="feature",
|
280 |
facet_col="label",
|
281 |
-
title="Distribution
|
282 |
)
|
283 |
-
fig_dist.update_layout(width=
|
284 |
else:
|
285 |
fig_dist = px.bar(title="Distribution plot not generated.")
|
286 |
|
287 |
-
#
|
288 |
-
|
289 |
-
|
290 |
-
# We won't produce a co-occurrence plot by default here, so set to None
|
291 |
-
co_occurrence_placeholder = None
|
292 |
|
293 |
-
# Return the 8 outputs
|
294 |
return (
|
295 |
-
final_str,
|
296 |
-
|
297 |
-
|
298 |
-
fig_dist,
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
)
|
304 |
|
305 |
-
|
306 |
-
######################################
|
307 |
-
# 5) Input Mapping
|
308 |
-
######################################
|
309 |
-
input_mapping = {
|
310 |
-
'YNURSMDE': {"Yes": 1, "No": 0},
|
311 |
-
'YMDEYR': {"Yes": 1, "No": 2},
|
312 |
-
'YSOCMDE': {"Yes": 1, "No": 0},
|
313 |
-
'YMDESUD5ANYO': {"SUD only, no MDE": 1, "MDE only, no SUD": 2, "SUD and MDE": 3, "Neither SUD or MDE": 4},
|
314 |
-
'YMSUD5YANY': {"Yes": 1, "No": 0},
|
315 |
-
'YUSUITHK': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
|
316 |
-
'YMDETXRX': {"Yes": 1, "No": 0},
|
317 |
-
'YUSUITHKYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
|
318 |
-
'YMDERSUD5ANY': {"Yes": 1, "No": 0},
|
319 |
-
'YUSUIPLNYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
|
320 |
-
'YCOUNMDE': {"Yes": 1, "No": 0},
|
321 |
-
'YPSY1MDE': {"Yes": 1, "No": 0},
|
322 |
-
'YHLTMDE': {"Yes": 1, "No": 0},
|
323 |
-
'YDOCMDE': {"Yes": 1, "No": 0},
|
324 |
-
'YPSY2MDE': {"Yes": 1, "No": 0},
|
325 |
-
'YMDEHARX': {"Yes": 1, "No": 0},
|
326 |
-
'LVLDIFMEM2': {"No Difficulty": 1, "Some difficulty": 2, "A lot of difficulty or cannot do at all": 3},
|
327 |
-
'MDEIMPY': {"Yes": 1, "No": 2},
|
328 |
-
'YMDEHPO': {"Yes": 1, "No": 0},
|
329 |
-
'YMIMS5YANY': {"Yes": 1, "No": 0},
|
330 |
-
'YMDEIMAD5YR': {"Yes": 1, "No": 0},
|
331 |
-
'YMIUD5YANY': {"Yes": 1, "No": 0},
|
332 |
-
'YMDEHPRX': {"Yes": 1, "No": 0},
|
333 |
-
'YMIMI5YANY': {"Yes": 1, "No": 0},
|
334 |
-
'YUSUIPLN': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
|
335 |
-
'YTXMDEYR': {"Yes": 1, "No": 0},
|
336 |
-
'YMDEAUD5YR': {"Yes": 1, "No": 0},
|
337 |
-
'YRXMDEYR': {"Yes": 1, "No": 0},
|
338 |
-
'YMDELT': {"Yes": 1, "No": 2}
|
339 |
-
}
|
340 |
-
|
341 |
-
|
342 |
######################################
|
343 |
-
#
|
344 |
######################################
|
345 |
def co_occurrence_plot(feature1, feature2, label_col):
|
346 |
"""
|
347 |
-
|
348 |
"""
|
349 |
-
if not feature1 or not feature2 or not label_col:
|
350 |
return px.bar(title="Please select all three fields.")
|
351 |
if feature1 not in df.columns or feature2 not in df.columns or label_col not in df.columns:
|
352 |
return px.bar(title="Selected columns not found in the dataset.")
|
353 |
|
354 |
-
|
355 |
fig = px.bar(
|
356 |
-
|
357 |
x=feature1,
|
358 |
y="count",
|
359 |
color=label_col,
|
360 |
facet_col=feature2,
|
361 |
-
title=f"Co-
|
362 |
)
|
363 |
-
fig.update_layout(width=
|
364 |
return fig
|
365 |
|
366 |
-
|
367 |
######################################
|
368 |
-
#
|
369 |
######################################
|
370 |
-
with gr.Blocks(css=".gradio-container {max-width:
|
371 |
-
|
372 |
with gr.Tab("Prediction"):
|
373 |
-
#
|
374 |
YMDEYR_dd = gr.Dropdown(list(input_mapping['YMDEYR'].keys()), label="YMDEYR")
|
375 |
YMDERSUD5ANY_dd = gr.Dropdown(list(input_mapping['YMDERSUD5ANY'].keys()), label="YMDERSUD5ANY")
|
376 |
YMDEIMAD5YR_dd = gr.Dropdown(list(input_mapping['YMDEIMAD5YR'].keys()), label="YMDEIMAD5YR")
|
@@ -395,7 +370,7 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
|
|
395 |
YDOCMDE_dd = gr.Dropdown(list(input_mapping['YDOCMDE'].keys()), label="YDOCMDE")
|
396 |
YTXMDEYR_dd = gr.Dropdown(list(input_mapping['YTXMDEYR'].keys()), label="YTXMDEYR")
|
397 |
|
398 |
-
# Suicidal
|
399 |
YUSUITHKYR_dd = gr.Dropdown(list(input_mapping['YUSUITHKYR'].keys()), label="YUSUITHKYR")
|
400 |
YUSUIPLNYR_dd = gr.Dropdown(list(input_mapping['YUSUIPLNYR'].keys()), label="YUSUIPLNYR")
|
401 |
YUSUITHK_dd = gr.Dropdown(list(input_mapping['YUSUITHK'].keys()), label="YUSUITHK")
|
@@ -407,10 +382,10 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
|
|
407 |
YMSUD5YANY_dd = gr.Dropdown(list(input_mapping['YMSUD5YANY'].keys()), label="YMSUD5YANY")
|
408 |
YRXMDEYR_dd = gr.Dropdown(list(input_mapping['YRXMDEYR'].keys()), label="YRXMDEYR")
|
409 |
|
410 |
-
#
|
411 |
predict_btn = gr.Button("Predict")
|
412 |
|
413 |
-
#
|
414 |
out_pred_res = gr.Textbox(label="Prediction Results", lines=8)
|
415 |
out_sev = gr.Textbox(label="Mental Health Severity", lines=2)
|
416 |
out_count = gr.Markdown(label="Total Patient Count")
|
@@ -420,7 +395,7 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
|
|
420 |
out_bar_input = gr.Plot(label="Input Feature Counts")
|
421 |
out_bar_labels = gr.Plot(label="Predicted Label Counts")
|
422 |
|
423 |
-
#
|
424 |
predict_btn.click(
|
425 |
fn=predict,
|
426 |
inputs=[
|
@@ -436,21 +411,20 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
|
|
436 |
]
|
437 |
)
|
438 |
|
439 |
-
# ------------- SECOND TAB (CO-OCCURRENCE) -------------
|
440 |
with gr.Tab("Co-occurrence"):
|
441 |
-
gr.Markdown("##
|
442 |
with gr.Row():
|
443 |
-
|
444 |
-
|
445 |
label_dd = gr.Dropdown(sorted(df.columns), label="Label Column")
|
446 |
-
|
|
|
447 |
|
448 |
-
|
449 |
-
co_occ_btn.click(
|
450 |
fn=co_occurrence_plot,
|
451 |
-
inputs=[
|
452 |
-
outputs=
|
453 |
)
|
454 |
|
455 |
-
#
|
456 |
-
demo.launch(
|
|
|
5 |
import gradio as gr
|
6 |
|
7 |
######################################
|
8 |
+
# 1) LOAD DATA & MODELS
|
9 |
######################################
|
10 |
+
df = pd.read_csv("X_train_Y_Train_merged_train.csv") # Make sure the CSV is present
|
11 |
|
|
|
12 |
model_filenames = [
|
13 |
"YOWRCONC.pkl", "YOSEEDOC.pkl", "YO_MDEA5.pkl", "YOWRLSIN.pkl",
|
14 |
"YODPPROB.pkl", "YOWRPROB.pkl", "YODPR2WK.pkl", "YOWRDEPR.pkl",
|
|
|
17 |
]
|
18 |
model_path = "models/"
|
19 |
|
|
|
|
|
|
|
|
|
20 |
class ModelPredictor:
|
21 |
def __init__(self, model_path, model_filenames):
|
22 |
self.model_path = model_path
|
23 |
self.model_filenames = model_filenames
|
24 |
self.models = self.load_models()
|
25 |
+
# Mapping from each label column to a list: [meaning_of_0, meaning_of_1]
|
26 |
self.prediction_map = {
|
27 |
+
"YOWRCONC": ["No difficulty concentrating", "Had difficulty concentrating"],
|
28 |
+
"YOSEEDOC": ["Did not feel need for doctor", "Felt need for doctor"],
|
29 |
+
"YOWRHRS": ["No trouble sleeping", "Had trouble sleeping"],
|
30 |
+
"YO_MDEA5": ["No restlessness/lethargy noted", "Others noticed restlessness/lethargy"],
|
31 |
+
"YOWRCHR": ["Did not feel so sad", "Felt so sad that nothing cheered up"],
|
32 |
+
"YOWRLSIN": ["No boredom/loss of interest", "Bored/lost interest in everything"],
|
33 |
+
"YODPPROB": ["No other 2+ week problems", "Had other 2+ week problems"],
|
34 |
+
"YOWRPROB": ["Did not have worst feeling ever", "Had worst time feeling"],
|
35 |
+
"YODPR2WK": ["No 2+ weeks of these feelings", "Had 2+ weeks of these feelings"],
|
36 |
"YOWRDEPR": ["Did not feel depressed mostly everyday", "Felt depressed mostly everyday"],
|
37 |
+
"YODPDISC": ["Mood not depressed overall", "Mood depressed overall discrepancy"],
|
38 |
+
"YOLOSEV": ["No loss of interest in enjoyable things", "Lost interest in enjoyable things"],
|
39 |
"YOWRDCSN": ["Was able to make decisions", "Was unable to make decisions"],
|
40 |
+
"YODSMMDE": ["No 2+ week depression episodes", "Had 2+ week depression episodes"],
|
41 |
"YO_MDEA3": ["No appetite/weight changes", "Had appetite/weight changes"],
|
42 |
"YODPLSIN": ["Never bored/lost interest", "Felt bored/lost interest"],
|
43 |
"YOWRELES": ["Did not eat less than usual", "Ate less than usual"],
|
44 |
"YODSCEV": ["Fewer severe symptoms", "More severe symptoms"],
|
45 |
+
"YOPB2WK": ["No uneasy feelings for 2+ weeks", "Had uneasy feelings for 2+ weeks"],
|
46 |
+
"YO_MDEA2": ["No daily well-being issues", "Daily well-being issues for 2+ weeks"]
|
47 |
}
|
48 |
|
49 |
def load_models(self):
|
50 |
+
loaded = []
|
51 |
+
for fname in self.model_filenames:
|
52 |
+
with open(self.model_path + fname, "rb") as f:
|
53 |
+
model = pickle.load(f)
|
54 |
+
loaded.append(model)
|
55 |
+
return loaded
|
56 |
+
|
57 |
+
def make_predictions(self, user_input: pd.DataFrame):
|
|
|
58 |
"""
|
59 |
+
Return list of arrays, each array is [0] or [1].
|
|
|
60 |
"""
|
61 |
predictions = []
|
62 |
for model in self.models:
|
63 |
+
out = model.predict(user_input)
|
64 |
+
predictions.append(out.flatten())
|
65 |
return predictions
|
66 |
|
67 |
def get_majority_vote(self, predictions):
|
|
|
|
|
|
|
|
|
68 |
combined = np.concatenate(predictions)
|
69 |
+
# find 0 or 1 that is most frequent
|
70 |
+
return np.bincount(combined).argmax()
|
71 |
|
72 |
+
def evaluate_severity(self, majority_vote_count: int) -> str:
|
73 |
+
# Simple thresholds
|
74 |
if majority_vote_count >= 13:
|
75 |
return "Mental Health Severity: Severe"
|
76 |
elif majority_vote_count >= 9:
|
|
|
80 |
else:
|
81 |
return "Mental Health Severity: Very Low"
|
82 |
|
83 |
+
predictor = ModelPredictor(model_path, model_filenames)
|
84 |
|
85 |
######################################
|
86 |
+
# 2) VALIDATION, INPUT MAPPING
|
87 |
######################################
|
88 |
def validate_inputs(*args):
|
89 |
for arg in args:
|
90 |
+
if not arg: # empty or None
|
91 |
return False
|
92 |
return True
|
93 |
|
94 |
+
input_mapping = {
|
95 |
+
'YNURSMDE': {"Yes": 1, "No": 0},
|
96 |
+
'YMDEYR': {"Yes": 1, "No": 2},
|
97 |
+
'YSOCMDE': {"Yes": 1, "No": 0},
|
98 |
+
'YMDESUD5ANYO': {"SUD only, no MDE": 1, "MDE only, no SUD": 2, "SUD and MDE": 3, "Neither SUD or MDE": 4},
|
99 |
+
'YMSUD5YANY': {"Yes": 1, "No": 0},
|
100 |
+
'YUSUITHK': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
|
101 |
+
'YMDETXRX': {"Yes": 1, "No": 0},
|
102 |
+
'YUSUITHKYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
|
103 |
+
'YMDERSUD5ANY': {"Yes": 1, "No": 0},
|
104 |
+
'YUSUIPLNYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
|
105 |
+
'YCOUNMDE': {"Yes": 1, "No": 0},
|
106 |
+
'YPSY1MDE': {"Yes": 1, "No": 0},
|
107 |
+
'YHLTMDE': {"Yes": 1, "No": 0},
|
108 |
+
'YDOCMDE': {"Yes": 1, "No": 0},
|
109 |
+
'YPSY2MDE': {"Yes": 1, "No": 0},
|
110 |
+
'YMDEHARX': {"Yes": 1, "No": 0},
|
111 |
+
'LVLDIFMEM2': {"No Difficulty": 1, "Some difficulty": 2, "A lot of difficulty or cannot do at all": 3},
|
112 |
+
'MDEIMPY': {"Yes": 1, "No": 2},
|
113 |
+
'YMDEHPO': {"Yes": 1, "No": 0},
|
114 |
+
'YMIMS5YANY': {"Yes": 1, "No": 0},
|
115 |
+
'YMDEIMAD5YR': {"Yes": 1, "No": 0},
|
116 |
+
'YMIUD5YANY': {"Yes": 1, "No": 0},
|
117 |
+
'YMDEHPRX': {"Yes": 1, "No": 0},
|
118 |
+
'YMIMI5YANY': {"Yes": 1, "No": 0},
|
119 |
+
'YUSUIPLN': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
|
120 |
+
'YTXMDEYR': {"Yes": 1, "No": 0},
|
121 |
+
'YMDEAUD5YR': {"Yes": 1, "No": 0},
|
122 |
+
'YRXMDEYR': {"Yes": 1, "No": 0},
|
123 |
+
'YMDELT': {"Yes": 1, "No": 2}
|
124 |
+
}
|
125 |
|
126 |
######################################
|
127 |
+
# 3) PREDICT FUNCTION
|
128 |
######################################
|
|
|
|
|
129 |
def predict(
|
130 |
YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
|
131 |
YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
|
|
|
133 |
YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
|
134 |
YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
|
135 |
):
|
136 |
+
# 1) Validate
|
137 |
if not validate_inputs(
|
138 |
YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
|
139 |
YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
|
|
|
152 |
None
|
153 |
)
|
154 |
|
155 |
+
# 2) Map user-friendly -> numeric
|
156 |
+
user_input_dict = {
|
157 |
+
'YNURSMDE': input_mapping['YNURSMDE'][YNURSMDE],
|
158 |
+
'YMDEYR': input_mapping['YMDEYR'][YMDEYR],
|
159 |
+
'YSOCMDE': input_mapping['YSOCMDE'][YSOCMDE],
|
160 |
+
'YMDESUD5ANYO': input_mapping['YMDESUD5ANYO'][YMDESUD5ANYO],
|
161 |
+
'YMSUD5YANY': input_mapping['YMSUD5YANY'][YMSUD5YANY],
|
162 |
+
'YUSUITHK': input_mapping['YUSUITHK'][YUSUITHK],
|
163 |
+
'YMDETXRX': input_mapping['YMDETXRX'][YMDETXRX],
|
164 |
+
'YUSUITHKYR': input_mapping['YUSUITHKYR'][YUSUITHKYR],
|
165 |
+
'YMDERSUD5ANY': input_mapping['YMDERSUD5ANY'][YMDERSUD5ANY],
|
166 |
+
'YUSUIPLNYR': input_mapping['YUSUIPLNYR'][YUSUIPLNYR],
|
167 |
+
'YCOUNMDE': input_mapping['YCOUNMDE'][YCOUNMDE],
|
168 |
+
'YPSY1MDE': input_mapping['YPSY1MDE'][YPSY1MDE],
|
169 |
+
'YHLTMDE': input_mapping['YHLTMDE'][YHLTMDE],
|
170 |
+
'YDOCMDE': input_mapping['YDOCMDE'][YDOCMDE],
|
171 |
+
'YPSY2MDE': input_mapping['YPSY2MDE'][YPSY2MDE],
|
172 |
+
'YMDEHARX': input_mapping['YMDEHARX'][YMDEHARX],
|
173 |
+
'LVLDIFMEM2': input_mapping['LVLDIFMEM2'][LVLDIFMEM2],
|
174 |
+
'MDEIMPY': input_mapping['MDEIMPY'][MDEIMPY],
|
175 |
+
'YMDEHPO': input_mapping['YMDEHPO'][YMDEHPO],
|
176 |
+
'YMIMS5YANY': input_mapping['YMIMS5YANY'][YMIMS5YANY],
|
177 |
+
'YMDEIMAD5YR': input_mapping['YMDEIMAD5YR'][YMDEIMAD5YR],
|
178 |
+
'YMIUD5YANY': input_mapping['YMIUD5YANY'][YMIUD5YANY],
|
179 |
+
'YMDEHPRX': input_mapping['YMDEHPRX'][YMDEHPRX],
|
180 |
+
'YMIMI5YANY': input_mapping['YMIMI5YANY'][YMIMI5YANY],
|
181 |
+
'YUSUIPLN': input_mapping['YUSUIPLN'][YUSUIPLN],
|
182 |
+
'YTXMDEYR': input_mapping['YTXMDEYR'][YTXMDEYR],
|
183 |
+
'YMDEAUD5YR': input_mapping['YMDEAUD5YR'][YMDEAUD5YR],
|
184 |
+
'YRXMDEYR': input_mapping['YRXMDEYR'][YRXMDEYR],
|
185 |
+
'YMDELT': input_mapping['YMDELT'][YMDELT]
|
186 |
}
|
187 |
+
user_df = pd.DataFrame(user_input_dict, index=[0])
|
188 |
|
189 |
+
# 3) Make predictions
|
190 |
+
predictions = predictor.make_predictions(user_df)
|
191 |
+
# majority
|
|
|
192 |
majority_vote = predictor.get_majority_vote(predictions)
|
193 |
+
# how many are '1'
|
194 |
+
count_ones = sum(np.concatenate(predictions) == 1)
|
195 |
+
# severity
|
196 |
+
severity_msg = predictor.evaluate_severity(count_ones)
|
197 |
|
198 |
+
# 4) Format textual results for each group (just as an example)
|
|
|
|
|
|
|
|
|
|
|
|
|
199 |
groups = {
|
200 |
"Concentration_and_Decision_Making": ["YOWRCONC", "YOWRDCSN"],
|
201 |
"Sleep_and_Energy_Levels": ["YOWRHRS", "YO_MDEA5", "YOWRELES", "YO_MDEA2"],
|
|
|
207 |
"YOPB2WK"]
|
208 |
}
|
209 |
|
210 |
+
group_text = {g: [] for g in groups}
|
211 |
for i, arr in enumerate(predictions):
|
212 |
+
label_col = model_filenames[i].split('.')[0] # e.g. 'YOWRCONC'
|
213 |
+
val = arr[0]
|
214 |
+
if label_col in predictor.prediction_map and val in [0,1]:
|
215 |
+
text_label = predictor.prediction_map[label_col][val]
|
216 |
else:
|
217 |
+
text_label = f"Prediction={val}"
|
218 |
+
# see which group
|
219 |
+
found = False
|
220 |
for gname, gcols in groups.items():
|
221 |
+
if label_col in gcols:
|
222 |
+
group_text[gname].append(f"{label_col} => {text_label}")
|
223 |
+
found = True
|
224 |
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
225 |
|
226 |
+
# build final results
|
227 |
+
final_str_parts = []
|
228 |
+
for gname, lines in group_text.items():
|
229 |
+
if lines:
|
230 |
+
final_str_parts.append(f"**{gname.replace('_',' ')}**")
|
231 |
+
final_str_parts.append("\n".join(lines))
|
232 |
+
final_str_parts.append("")
|
233 |
+
if not final_str_parts:
|
234 |
+
final_str = "No predictions made or no matching group columns."
|
235 |
+
else:
|
236 |
+
final_str = "\n".join(final_str_parts)
|
237 |
+
|
238 |
+
# 5) Additional features
|
239 |
+
# total patients
|
240 |
+
total_count = len(df)
|
241 |
+
total_count_md = f"### Total Patient Count\nWe have **{total_count}** patients in the dataset."
|
242 |
+
|
243 |
+
# bar chart for input features
|
244 |
+
input_counts = {}
|
245 |
+
for col, val_ in user_input_dict.items():
|
246 |
+
# only 1 item
|
247 |
+
v = val_
|
248 |
+
# how many have that value?
|
249 |
+
matched = len(df[df[col] == v])
|
250 |
+
input_counts[col] = matched
|
251 |
+
bar_in_df = pd.DataFrame({"Feature": list(input_counts.keys()),
|
252 |
+
"Count": list(input_counts.values())})
|
253 |
+
fig_in = px.bar(bar_in_df, x="Feature", y="Count",
|
254 |
+
title="Number of Patients with Same Input Feature Values")
|
255 |
+
fig_in.update_layout(width=700, height=400)
|
256 |
+
|
257 |
+
# bar chart for predicted labels
|
258 |
label_counts = {}
|
259 |
for i, arr in enumerate(predictions):
|
260 |
+
lblcol = model_filenames[i].split('.')[0]
|
261 |
pred_val = arr[0]
|
262 |
if pred_val in [0,1]:
|
263 |
+
# how many in df have that label?
|
264 |
+
label_counts[lblcol] = len(df[df[lblcol] == pred_val])
|
265 |
if label_counts:
|
266 |
+
bar_lbl_df = pd.DataFrame({"Label": list(label_counts.keys()),
|
267 |
+
"Count": list(label_counts.values())})
|
268 |
+
fig_lbl = px.bar(bar_lbl_df, x="Label", y="Count",
|
269 |
+
title="Number of Patients with the Same Predicted Label")
|
270 |
+
fig_lbl.update_layout(width=700, height=400)
|
271 |
else:
|
272 |
+
fig_lbl = px.bar(title="No valid predicted labels to display.")
|
273 |
+
fig_lbl.update_layout(width=700, height=400)
|
274 |
+
|
275 |
+
# distribution plot (just a small sample)
|
276 |
+
feat_sample = list(user_input_dict.keys())[:3]
|
277 |
+
label_sample = [mf.split('.')[0] for mf in model_filenames[:2]]
|
278 |
+
rows = []
|
279 |
+
for f_ in feat_sample:
|
280 |
+
if f_ not in df.columns:
|
281 |
continue
|
282 |
+
for l_ in label_sample:
|
283 |
+
if l_ not in df.columns:
|
284 |
continue
|
285 |
+
sub_g = df.groupby([f_, l_]).size().reset_index(name="count")
|
286 |
+
sub_g["feature"] = f_
|
287 |
+
sub_g["label"] = l_
|
288 |
+
rows.append(sub_g)
|
289 |
+
if rows:
|
290 |
+
big_df = pd.concat(rows, ignore_index=True)
|
291 |
fig_dist = px.bar(
|
292 |
+
big_df,
|
293 |
+
x=big_df.columns[0], # feature value
|
294 |
y="count",
|
295 |
+
color=big_df.columns[1], # label value
|
296 |
facet_row="feature",
|
297 |
facet_col="label",
|
298 |
+
title="Distribution (Sample Input Features vs Sample Labels)"
|
299 |
)
|
300 |
+
fig_dist.update_layout(width=900, height=600)
|
301 |
else:
|
302 |
fig_dist = px.bar(title="Distribution plot not generated.")
|
303 |
|
304 |
+
# nearest neighbors or co-occ placeholder
|
305 |
+
nn_md = "Nearest neighbors / advanced metrics not implemented in this version."
|
306 |
+
co_occ_placeholder = None
|
|
|
|
|
307 |
|
|
|
308 |
return (
|
309 |
+
final_str, # 1) Prediction Results
|
310 |
+
severity_msg, # 2) Mental Health Severity
|
311 |
+
total_count_md, # 3) Total Patient Count
|
312 |
+
fig_dist, # 4) Distribution Plot
|
313 |
+
nn_md, # 5) Nearest Neighbors (Markdown)
|
314 |
+
co_occ_placeholder, # 6) Co-occurrence Plot
|
315 |
+
fig_in, # 7) Bar Chart for input features
|
316 |
+
fig_lbl # 8) Bar Chart for predicted labels
|
317 |
)
|
318 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
319 |
######################################
|
320 |
+
# 4) CO-OCCURRENCE FUNCTION
|
321 |
######################################
|
322 |
def co_occurrence_plot(feature1, feature2, label_col):
|
323 |
"""
|
324 |
+
Create a bar chart for co-occurrence among feature1, feature2, and label_col.
|
325 |
"""
|
326 |
+
if (not feature1) or (not feature2) or (not label_col):
|
327 |
return px.bar(title="Please select all three fields.")
|
328 |
if feature1 not in df.columns or feature2 not in df.columns or label_col not in df.columns:
|
329 |
return px.bar(title="Selected columns not found in the dataset.")
|
330 |
|
331 |
+
grouped = df.groupby([feature1, feature2, label_col]).size().reset_index(name="count")
|
332 |
fig = px.bar(
|
333 |
+
grouped,
|
334 |
x=feature1,
|
335 |
y="count",
|
336 |
color=label_col,
|
337 |
facet_col=feature2,
|
338 |
+
title=f"Co-occurrence: {feature1}, {feature2} vs {label_col}"
|
339 |
)
|
340 |
+
fig.update_layout(width=900, height=600)
|
341 |
return fig
|
342 |
|
|
|
343 |
######################################
|
344 |
+
# 5) BUILD GRADIO UI
|
345 |
######################################
|
346 |
+
with gr.Blocks(css=".gradio-container {max-width: 1100px;}") as demo:
|
|
|
347 |
with gr.Tab("Prediction"):
|
348 |
+
# Input fields in the same order as predict(...)
|
349 |
YMDEYR_dd = gr.Dropdown(list(input_mapping['YMDEYR'].keys()), label="YMDEYR")
|
350 |
YMDERSUD5ANY_dd = gr.Dropdown(list(input_mapping['YMDERSUD5ANY'].keys()), label="YMDERSUD5ANY")
|
351 |
YMDEIMAD5YR_dd = gr.Dropdown(list(input_mapping['YMDEIMAD5YR'].keys()), label="YMDEIMAD5YR")
|
|
|
370 |
YDOCMDE_dd = gr.Dropdown(list(input_mapping['YDOCMDE'].keys()), label="YDOCMDE")
|
371 |
YTXMDEYR_dd = gr.Dropdown(list(input_mapping['YTXMDEYR'].keys()), label="YTXMDEYR")
|
372 |
|
373 |
+
# Suicidal
|
374 |
YUSUITHKYR_dd = gr.Dropdown(list(input_mapping['YUSUITHKYR'].keys()), label="YUSUITHKYR")
|
375 |
YUSUIPLNYR_dd = gr.Dropdown(list(input_mapping['YUSUIPLNYR'].keys()), label="YUSUIPLNYR")
|
376 |
YUSUITHK_dd = gr.Dropdown(list(input_mapping['YUSUITHK'].keys()), label="YUSUITHK")
|
|
|
382 |
YMSUD5YANY_dd = gr.Dropdown(list(input_mapping['YMSUD5YANY'].keys()), label="YMSUD5YANY")
|
383 |
YRXMDEYR_dd = gr.Dropdown(list(input_mapping['YRXMDEYR'].keys()), label="YRXMDEYR")
|
384 |
|
385 |
+
# Button
|
386 |
predict_btn = gr.Button("Predict")
|
387 |
|
388 |
+
# 8 outputs
|
389 |
out_pred_res = gr.Textbox(label="Prediction Results", lines=8)
|
390 |
out_sev = gr.Textbox(label="Mental Health Severity", lines=2)
|
391 |
out_count = gr.Markdown(label="Total Patient Count")
|
|
|
395 |
out_bar_input = gr.Plot(label="Input Feature Counts")
|
396 |
out_bar_labels = gr.Plot(label="Predicted Label Counts")
|
397 |
|
398 |
+
# Connect
|
399 |
predict_btn.click(
|
400 |
fn=predict,
|
401 |
inputs=[
|
|
|
411 |
]
|
412 |
)
|
413 |
|
|
|
414 |
with gr.Tab("Co-occurrence"):
|
415 |
+
gr.Markdown("## Co-Occurrence Plot\nSelect two features + one label to see a distribution.")
|
416 |
with gr.Row():
|
417 |
+
feat1_dd = gr.Dropdown(sorted(df.columns), label="Feature 1")
|
418 |
+
feat2_dd = gr.Dropdown(sorted(df.columns), label="Feature 2")
|
419 |
label_dd = gr.Dropdown(sorted(df.columns), label="Label Column")
|
420 |
+
generate_btn = gr.Button("Generate Plot")
|
421 |
+
co_occ_output = gr.Plot()
|
422 |
|
423 |
+
generate_btn.click(
|
|
|
424 |
fn=co_occurrence_plot,
|
425 |
+
inputs=[feat1_dd, feat2_dd, label_dd],
|
426 |
+
outputs=co_occ_output
|
427 |
)
|
428 |
|
429 |
+
# Launch
|
430 |
+
demo.launch()
|