pantdipendra
committed on
Update app.py

app.py CHANGED
@@ -1,11 +1,10 @@
-
-# Load the training CSV once
@@ -17,74 +16,52 @@ class ModelPredictor:
-        # (Adjust as needed for the columns you actually have.)
-            "YOWRCONC": ["
-            "YOSEEDOC": ["
-            "YOWRHRS": ["
-            "YO_MDEA5": ["Others
-            "YOWRCHR": ["
-            "YOWRLSIN": ["
-                         "Experienced changes in appetite or weight"],
-            "YODPLSIN": ["Never lost interest and felt bored", "Lost interest and felt bored"],
-            "YOWRELES": ["Did not eat less than usual", "Ate less than usual"],
-            "YODSCEV": ["Had fewer severe symptoms of depression", "Had more severe symptoms of depression"],
-            "YOPB2WK": ["Did not experience uneasy feelings lasting every day for 2+ weeks or longer",
-                        "Experienced uneasy feelings lasting every day for 2+ weeks or longer"],
-            "YO_MDEA2": ["Did not have issues with physical and mental well-being every day for 2 weeks or longer",
-                         "Had issues with physical and mental well-being every day for 2 weeks or longer"]
-            filepath = self.model_path +
-            with open(filepath,
-                models.append(model)
-            pred = model.predict(user_input)
-            pred = np.array(pred).flatten()
-            predictions.append(pred)
-        return predictions
-        combined_predictions = np.concatenate(predictions)
-        majority_vote = np.bincount(combined_predictions).argmax()
-        return majority_vote
-    # Based on Equal Interval and Percentage-Based Method
-    # Severe: 13 to 16 votes (upper 25%)
-    # Moderate: 9 to 12 votes (upper-middle 25%)
-    # Low: 5 to 8 votes (lower-middle 25%)
-    # Very Low: 0 to 4 votes (lower 25%)
@@ -95,7 +72,7 @@ class ModelPredictor:
-# 2)
@@ -110,22 +87,36 @@ predictor = ModelPredictor(model_path, model_filenames)
-        if arg == '' or arg is None:
-# 4)
-    YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
@@ -159,29 +150,21 @@ def predict(
-    # 2) Calculate majority vote (0 or 1) across all models
-    # 3) Count how many 1's in all predictions combined
-    majority_vote_count = sum([1 for pred in np.concatenate(predictions) if pred == 1])
-    # 4) Evaluate severity
-    results = {
-    prediction_groups = {
@@ -192,313 +175,198 @@ def predict(
-                found_group = True
-        severity += " (Unknown prediction count is high. Please consult with a human.)"
-    # =============== ADDITIONAL FEATURES ===============
-    # A) Total Patient Count
-        "All subsequent analyses refer to these patients."
-        bar_input_data,
-        x="Feature",
-        y="Count",
-        title="Number of Patients with the Same Value for Each Input Feature",
-        labels={"Feature": "Input Feature", "Count": "Number of Patients"}
-            title="Number of Patients with the Same Predicted Label"
-            labels={"Model": "Predicted Column", "Count": "Patient Count"}
-    demonstration_labels = [fn.split('.')[0] for fn in model_filenames[:3]]  # first 3 labels as a sample
-    # We'll build a single figure with "facet_col" = label and "facet_row" = feature (small sample)
-    # The approach: for each (feature, label), group by (feature_value, label_value) -> count.
-    # Then we combine them into one big DataFrame with "feature" & "label" columns for Plotly facets.
-    dist_rows = []
-    for feat in demonstration_features:
-        big_dist_df = pd.concat(
-        # We can re-map numeric to user-friendly text for "feat" if desired, but each feature might have a different mapping.
-        # For now, we just show numeric codes. Real usage would do a reverse mapping if feasible.
-        # For the label (0,1), we can map to short strings if we want (like "Label0" / "Label1"), or a direct numeric.
-            x=big_dist_df.columns[0],
-            color=big_dist_df.columns[1],
-            labels={
-                big_dist_df.columns[0]: "Feature Value",
-                big_dist_df.columns[1]: "Label Value"
-            }
-    # E) Nearest Neighbors
-    # We'll build it after the code block below for each column.
-    # 2. Invert label mappings from predictor.prediction_map if needed
-    #    For each label column, 0 => first string, 1 => second string
-    #    We'll store them in a dict: reverse_label_mapping[label_col][0 or 1] => string
-    reverse_label_mapping = {}
-    for lbl, str_list in predictor.prediction_map.items():
-        # str_list[0] => for 0, str_list[1] => for 1
-        reverse_label_mapping[lbl] = {
-            0: str_list[0],
-            1: str_list[1]
-        }
-    # Build the reverse input mapping from the provided dictionary
-    # We'll define that dictionary below to ensure we can invert it:
-    input_mapping = {
-        'YNURSMDE': {"Yes": 1, "No": 0},
-        'YMDEYR': {"Yes": 1, "No": 2},
-        'YSOCMDE': {"Yes": 1, "No": 0},
-        'YMDESUD5ANYO': {"SUD only, no MDE": 1, "MDE only, no SUD": 2, "SUD and MDE": 3, "Neither SUD or MDE": 4},
-        'YMSUD5YANY': {"Yes": 1, "No": 0},
-        'YUSUITHK': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
-        'YMDETXRX': {"Yes": 1, "No": 0},
-        'YUSUITHKYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
-        'YMDERSUD5ANY': {"Yes": 1, "No": 0},
-        'YUSUIPLNYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
-        'YCOUNMDE': {"Yes": 1, "No": 0},
-        'YPSY1MDE': {"Yes": 1, "No": 0},
-        'YHLTMDE': {"Yes": 1, "No": 0},
-        'YDOCMDE': {"Yes": 1, "No": 0},
-        'YPSY2MDE': {"Yes": 1, "No": 0},
-        'YMDEHARX': {"Yes": 1, "No": 0},
-        'LVLDIFMEM2': {"No Difficulty": 1, "Some difficulty": 2, "A lot of difficulty or cannot do at all": 3},
-        'MDEIMPY': {"Yes": 1, "No": 2},
-        'YMDEHPO': {"Yes": 1, "No": 0},
-        'YMIMS5YANY': {"Yes": 1, "No": 0},
-        'YMDEIMAD5YR': {"Yes": 1, "No": 0},
-        'YMIUD5YANY': {"Yes": 1, "No": 0},
-        'YMDEHPRX': {"Yes": 1, "No": 0},
-        'YMIMI5YANY': {"Yes": 1, "No": 0},
-        'YUSUIPLN': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
-        'YTXMDEYR': {"Yes": 1, "No": 0},
-        'YMDEAUD5YR': {"Yes": 1, "No": 0},
-        'YRXMDEYR': {"Yes": 1, "No": 0},
-        'YMDELT': {"Yes": 1, "No": 2}
-    }
-    # Build the reverse mapping for each column
-    for col, fwd_map in input_mapping.items():
-        reverse_input_mapping[col] = {v: k for k, v in fwd_map.items()}
-    # 3. Calculate Hamming distance for each row
-    #    We'll consider the columns in user_input for comparison
-    features_to_compare = list(user_input.columns)
-    subset_df = df[features_to_compare].copy()
-    user_series = user_input.iloc[0]
-        "### Nearest Neighbors (Simple Hamming Distance)\n"
-        f"We searched for the top **{K}** patients whose features most closely match your input.\n\n"
-        "> **Note**: “Nearest neighbor” methods for high-dimensional or purely categorical data can be non-trivial. "
-        "This demo simply uses a Hamming distance over all input features and picks K=5 neighbors. "
-        "In a real application, you would refine which features are most relevant, how to encode them, "
-        "and how many neighbors to select.\n\n"
-        "Below is a brief overview of each neighbor's input-feature values and one example label (`YOWRCONC`).\n\n"
-        + "\n".join(nn_rows)
-    )
-    # F) Co-occurrence Plot from the previous example (kept for completeness)
-    if all(col in df.columns for col in ["YMDEYR", "YMDERSUD5ANY", "YOWRCONC"]):
-        co_occ_data = df.groupby(["YMDEYR", "YMDERSUD5ANY", "YOWRCONC"]).size().reset_index(name="count")
-        fig_co_occ = px.bar(
-            co_occ_data,
-            x="YMDEYR",
-            y="count",
-            color="YOWRCONC",
-            facet_col="YMDERSUD5ANY",
-            title="Co-Occurrence Plot: YMDEYR and YMDERSUD5ANY vs YOWRCONC"
-        )
-    # RETURN EVERYTHING
-    # We have 8 outputs:
-    #  1) Prediction Results (Textbox)
-    #  2) Mental Health Severity (Textbox)
-    #  3) Total Patient Count (Markdown)
-    #  4) Distribution Plot (for multiple input features vs. multiple labels)
-    #  5) Nearest Neighbors Summary (Markdown)
-    #  6) Co-Occurrence Plot
-    #  7) Bar Chart for input features
-    #  8) Bar Chart for predicted labels
-    # =======================
-        severity,
-        fig_dist,
-        fig_co_occ,
-# 5) MAPPING user
-    'YMDESUD5ANYO': {"SUD only
-    'YUSUITHK': {"Yes": 1, "No": 2, "
-    'YUSUITHKYR': {"Yes": 1, "No": 2, "
-    'YUSUIPLNYR': {"Yes": 1, "No": 2, "
-    'LVLDIFMEM2': {"No Difficulty": 1, "Some
@@ -506,7 +374,7 @@ input_mapping = {
-    'YUSUIPLN': {"Yes": 1, "No": 2, "
@@ -514,89 +382,93 @@ input_mapping = {
-# 6) GRADIO INTERFACE
-# We have 8 outputs in total:
-# 1) Prediction Results
-# 2) Mental Health Severity
-# 3) Total Patient Count
-# 4) Distribution Plot
-# 5) Nearest Neighbors
-# 6) Co-Occurrence Plot
-# 7) Bar Chart for input features
-# 8) Bar Chart for predicted labels
-    gr.Dropdown(list(input_mapping['YMDEYR'].keys()), label="YMDEYR:
-    gr.Dropdown(list(input_mapping['YMDERSUD5ANY'].keys()), label="YMDERSUD5ANY: MDE
-    gr.Dropdown(list(input_mapping['YMDEIMAD5YR'].keys()), label="YMDEIMAD5YR: MDE
-    gr.Dropdown(list(input_mapping['YMIMS5YANY'].keys()), label="YMIMS5YANY: MDE
-    gr.Dropdown(list(input_mapping['YMDELT'].keys()), label="YMDELT:
-    gr.Dropdown(list(input_mapping['YMDEHARX'].keys()), label="YMDEHARX:
-    gr.Dropdown(list(input_mapping['YMDEHPRX'].keys()), label="YMDEHPRX:
-    gr.Dropdown(list(input_mapping['YMDETXRX'].keys()), label="YMDETXRX:
-    gr.Dropdown(list(input_mapping['YMDEHPO'].keys()), label="YMDEHPO:
-    gr.Dropdown(list(input_mapping['YMDEAUD5YR'].keys()), label="YMDEAUD5YR: MDE +
-    gr.Dropdown(list(input_mapping['YMIMI5YANY'].keys()), label="YMIMI5YANY: MDE
-    gr.Dropdown(list(input_mapping['YMIUD5YANY'].keys()), label="YMIUD5YANY: MDE + ILL
-    gr.Dropdown(list(input_mapping['YMDESUD5ANYO'].keys()), label="YMDESUD5ANYO: MDE vs
-    gr.Dropdown(list(input_mapping['YNURSMDE'].keys()), label="YNURSMDE:
-    gr.Dropdown(list(input_mapping['YSOCMDE'].keys()), label="YSOCMDE:
-    gr.Dropdown(list(input_mapping['YCOUNMDE'].keys()), label="YCOUNMDE:
-    gr.Dropdown(list(input_mapping['YPSY1MDE'].keys()), label="YPSY1MDE:
-    gr.Dropdown(list(input_mapping['YPSY2MDE'].keys()), label="YPSY2MDE:
-    gr.Dropdown(list(input_mapping['YHLTMDE'].keys()), label="YHLTMDE:
-    gr.Dropdown(list(input_mapping['YDOCMDE'].keys()), label="YDOCMDE:
-    gr.Dropdown(list(input_mapping['YTXMDEYR'].keys()), label="YTXMDEYR:
-    # Suicidal
-    gr.Dropdown(list(input_mapping['YUSUITHKYR'].keys()), label="YUSUITHKYR:
-    gr.Dropdown(list(input_mapping['YUSUIPLNYR'].keys()), label="YUSUIPLNYR:
-    gr.Dropdown(list(input_mapping['YUSUITHK'].keys()), label="YUSUITHK:
-    gr.Dropdown(list(input_mapping['YUSUIPLN'].keys()), label="YUSUIPLN:
-    gr.Dropdown(list(input_mapping['MDEIMPY'].keys()), label="MDEIMPY:
-    gr.Dropdown(list(input_mapping['LVLDIFMEM2'].keys()), label="LVLDIFMEM2:
-    gr.Dropdown(list(input_mapping['YMSUD5YANY'].keys()), label="YMSUD5YANY: MDE +
-    gr.Dropdown(list(input_mapping['YRXMDEYR'].keys()), label="YRXMDEYR:
-    gr.Textbox(label="Prediction Results", lines=
-    gr.Textbox(label="Mental Health Severity", lines=
-    gr.Plot(label="Distribution Plot (Sample
-    gr.Markdown(label="Nearest Neighbors
-    gr.Plot(label="Co-
-# 7) WRAPPER
-    YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
@@ -608,7 +480,7 @@ def predict_with_text(
-    # Map
@@ -641,36 +513,34 @@ def predict_with_text(
-    return predict(
-# Optional custom CSS
-.gradio-container .form .form-group label {
-    color: #1B1212 !important;
-}
-.gradio-container .output-textbox,
-.gradio-container .output-textbox textarea {
-    color: #1B1212 !important;
-}
-.gradio-container .label,
-.gradio-container .input-label {
-    color: #1B1212 !important;
-}
-######################################
-# 8) LAUNCH
-######################################
-    fn=predict_with_text,
-    outputs=outputs,
-    title="
-    css=custom_css

@@ -1,11 +1,10 @@
 import pickle
 import gradio as gr
 import numpy as np
 import pandas as pd
 import plotly.express as px

+# Load the training CSV once.
 df = pd.read_csv("X_train_Y_Train_merged_train.csv")

 ######################################
@@ -17,74 +16,52 @@ class ModelPredictor:
         self.model_filenames = model_filenames
         self.models = self.load_models()
         # Mapping from label column to human-readable strings for 0/1
         self.prediction_map = {
+            "YOWRCONC": ["No difficulty concentrating", "Had difficulty concentrating"],
+            "YOSEEDOC": ["No need to see doctor", "Needed to see doctor"],
+            "YOWRHRS": ["No trouble sleeping", "Had trouble sleeping"],
+            "YO_MDEA5": ["Others didn't notice restlessness", "Others noticed restlessness"],
+            "YOWRCHR": ["Not sad beyond cheering", "Felt so sad no one could cheer up"],
+            "YOWRLSIN": ["Never felt bored/lost interest", "Felt bored/lost interest"],
+            "YODPPROB": ["No other problems for 2+ weeks", "Had other problems for 2+ weeks"],
+            "YOWRPROB": ["No worst time feeling", "Felt worst time ever"],
+            "YODPR2WK": ["No depressed feelings for 2+ wks", "Depressed feelings for 2+ wks"],
+            "YOWRDEPR": ["Not sad or depressed most days", "Sad or depressed most days"],
+            "YODPDISC": ["Mood not depressed overall", "Mood depressed overall (discrepancy)"],
+            "YOLOSEV": ["Did not lose interest in activities", "Lost interest in activities"],
+            "YOWRDCSN": ["Could make decisions", "Could not make decisions"],
+            "YODSMMDE": ["No 2+ week depression episodes", "Had 2+ week depression episodes"],
+            "YO_MDEA3": ["No appetite/weight changes", "Yes appetite/weight changes"],
+            "YODPLSIN": ["Never bored/lost interest", "Often bored/lost interest"],
+            "YOWRELES": ["Did not eat less", "Ate less than usual"],
+            "YODSCEV": ["Fewer severe symptoms", "More severe symptoms"],
+            "YOPB2WK": ["No uneasy feelings daily 2+ wks", "Uneasy feelings daily 2+ wks"],
+            "YO_MDEA2": ["No issues physical/mental daily", "Issues physical/mental daily 2+ wks"]
         }

     def load_models(self):
         models = []
+        for fn in self.model_filenames:
+            filepath = self.model_path + fn
+            with open(filepath, "rb") as file:
+                models.append(pickle.load(file))
         return models

     def make_predictions(self, user_input):
+        """Return list of numpy arrays, each array either [0] or [1]."""
+        preds = []
+        for m in self.models:
+            out = m.predict(user_input)
+            preds.append(np.array(out).flatten())
+        return preds

     def get_majority_vote(self, predictions):
+        """Flatten all predictions and find 0 or 1 with majority."""
+        combined = np.concatenate(predictions)
+        return np.bincount(combined).argmax()
+
     def evaluate_severity(self, majority_vote_count):
+        """Heuristic: Based on 16 total models, 0-4=Very Low, 5-8=Low, 9-12=Moderate, 13-16=Severe."""
         if majority_vote_count >= 13:
             return "Mental health severity: Severe"
         elif majority_vote_count >= 9:
@@ -95,7 +72,7 @@ class ModelPredictor:
             return "Mental health severity: Very Low"

 ######################################
+# 2) CONFIGURATIONS
 ######################################
 model_filenames = [
     "YOWRCONC.pkl", "YOSEEDOC.pkl", "YO_MDEA5.pkl", "YOWRLSIN.pkl",
@@ -110,22 +87,36 @@ predictor = ModelPredictor(model_path, model_filenames)
 # 3) INPUT VALIDATION
 ######################################
 def validate_inputs(*args):
+    # Just ensure all required (non-co-occurrence) fields are picked
     for arg in args:
+        if arg == '' or arg is None:
             return False
     return True

 ######################################
+# 4) PREDICTION FUNCTION
 ######################################
 def predict(
+    # Original required features
     YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
     YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
     YMDESUD5ANYO, YNURSMDE, YSOCMDE, YCOUNMDE, YPSY1MDE, YPSY2MDE,
     YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
+    YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR,
+    # **New** optional picks for co-occurrence
+    co_occ_feature1, co_occ_feature2, co_occ_label
 ):
+    """
+    Main function that:
+    - Predicts with the 16 models
+    - Aggregates results
+    - Produces severity
+    - Returns distribution & bar charts
+    - Finds K=2 Nearest Neighbors
+    - Produces *one* co-occurrence plot based on user-chosen columns
+    """
+
+    # 1) Build user_input for models
     user_input_data = {
         'YNURSMDE': [int(YNURSMDE)],
         'YMDEYR': [int(YMDEYR)],
@@ -159,29 +150,21 @@ def predict(
     }
     user_input = pd.DataFrame(user_input_data)

+    # 2) Model Predictions
     predictions = predictor.make_predictions(user_input)
     majority_vote = predictor.get_majority_vote(predictions)
+    majority_vote_count = np.sum(np.concatenate(predictions) == 1)
     severity = predictor.evaluate_severity(majority_vote_count)

+    # 3) Summarize textual results
+    results_by_group = {
         "Concentration_and_Decision_Making": [],
         "Sleep_and_Energy_Levels": [],
         "Mood_and_Emotional_State": [],
         "Appetite_and_Weight_Changes": [],
         "Duration_and_Severity_of_Depression_Symptoms": []
     }
+    group_map = {
         "Concentration_and_Decision_Making": ["YOWRCONC", "YOWRDCSN"],
         "Sleep_and_Energy_Levels": ["YOWRHRS", "YO_MDEA5", "YOWRELES", "YO_MDEA2"],
         "Mood_and_Emotional_State": ["YOWRCHR", "YOWRLSIN", "YOWRDEPR", "YODPDISC",
@@ -192,313 +175,198 @@ def predict(
                                                          "YOPB2WK"]
     }

+    # Convert each model's 0/1 to text
+    grouped_output_lines = []
+    for i, pred_array in enumerate(predictions):
+        col_name = model_filenames[i].split(".")[0]  # e.g., "YOWRCONC"
+        val = pred_array[0]
+        if col_name in predictor.prediction_map and val in [0, 1]:
+            text = predictor.prediction_map[col_name][val]
+            out_line = f"{col_name}: {text}"
         else:
+            out_line = f"{col_name}: Prediction={val}"
+
+        # Find group
+        placed = False
+        for g_key, g_cols in group_map.items():
+            if col_name in g_cols:
+                results_by_group[g_key].append(out_line)
+                placed = True
                 break
+        if not placed:
+            # If it didn't fall into any known group, skip or handle
             pass

+    # Format into a single string
+    for group_label, pred_lines in results_by_group.items():
+        if pred_lines:
+            grouped_output_lines.append(f"Group {group_label}:")
+            grouped_output_lines.append("\n".join(pred_lines))
+            grouped_output_lines.append("")
+
+    if len(grouped_output_lines) == 0:
+        final_result_text = "No predictions made. Check inputs."
+    else:
+        final_result_text = "\n".join(grouped_output_lines).strip()
+
+    # 4) Additional Features
+    # A) Total patient count
     total_patients = len(df)
+    total_count_md = (
         "### Total Patient Count\n"
+        f"**{total_patients}** total patients in the dataset."
     )

+    # B) Bar chart of how many have same inputs
     input_counts = {}
+    for c in user_input_data.keys():
+        v = user_input_data[c][0]
+        input_counts[c] = len(df[df[c] == v])
+    df_input_counts = pd.DataFrame({"Feature": list(input_counts.keys()), "Count": list(input_counts.values())})
+    fig_input_bar = px.bar(
+        df_input_counts,
+        x="Feature",
+        y="Count",
+        title="Number of Patients with the Same Value for Each Input Feature"
     )
+    fig_input_bar.update_layout(xaxis={"categoryorder": "total descending"})

+    # C) Bar chart for predicted labels
     label_counts = {}
+    for i, pred_array in enumerate(predictions):
+        col_name = model_filenames[i].split(".")[0]
+        val = pred_array[0]
+        if val in [0, 1]:
+            label_counts[col_name] = len(df[df[col_name] == val])
+
     if len(label_counts) > 0:
+        df_label_counts = pd.DataFrame({
+            "Label Column": list(label_counts.keys()),
             "Count": list(label_counts.values())
         })
+        fig_label_bar = px.bar(
+            df_label_counts,
+            x="Label Column",
             y="Count",
+            title="Number of Patients with the Same Predicted Label"
         )
     else:
+        fig_label_bar = px.bar(title="No valid predicted labels to display")
+
+    # D) Simple Distribution Plot (demo for first 3 labels & 4 inputs)
+    # (Unchanged from prior approach; you can remove if you prefer.)
+    sample_feats = list(user_input_data.keys())[:4]
+    sample_labels = [fn.split(".")[0] for fn in model_filenames[:3]]
+    dist_segments = []
+    for feat in sample_feats:
         if feat not in df.columns:
             continue
+        for lbl in sample_labels:
             if lbl not in df.columns:
                 continue
+            temp_g = df.groupby([feat, lbl]).size().reset_index(name="count")
+            temp_g["feature"] = feat
+            temp_g["label"] = lbl
+            dist_segments.append(temp_g)
+    if len(dist_segments) > 0:
+        big_dist_df = pd.concat(dist_segments, ignore_index=True)
         fig_dist = px.bar(
             big_dist_df,
+            x=big_dist_df.columns[0],
             y="count",
+            color=big_dist_df.columns[1],
             facet_row="feature",
             facet_col="label",
+            title="Sample Distribution Plot (first 4 features vs first 3 labels)"
         )
+        fig_dist.update_layout(height=700)
     else:
+        fig_dist = px.bar(title="No distribution plot generated (columns not found).")
+
+    # E) Nearest Neighbors with K=2
+    # We keep K=2, but for *all* label columns, we show their actual 0/1 or mapped text
+    # (same approach as before).
+    # ... [omitted here for brevity, or replicate your existing code for K=2 nearest neighbors] ...
+    # We'll do a short version to keep focus on co-occ:
+    # ---------------------------------------------------------------------
+    # Build Hamming distance across user_input columns
+    columns_for_distance = list(user_input.columns)
+    sub_df = df[columns_for_distance].copy()
+    user_row = user_input.iloc[0]
     distances = []
+    for idx, row_ in sub_df.iterrows():
+        dist_ = sum(row_[col] != user_row[col] for col in columns_for_distance)
+        distances.append(dist_)
+    df_dist = df.copy()
+    df_dist["distance"] = distances
+    # Sort ascending, pick K=2
+    K = 2
+    nearest_neighbors = df_dist.sort_values("distance", ascending=True).head(K)
+
+    # Summarize in Markdown
+    nn_md = ["### Nearest Neighbors (K=2)"]
+    nn_md.append("(In a real application, you'd refine which features matter, how to encode them, etc.)\n")
+    for irow in nearest_neighbors.itertuples():
+        nn_md.append(f"- **Neighbor ID {irow.Index}**: distance={irow.distance}")
+    nn_md_str = "\n".join(nn_md)
+
+    # F) Co-occurrence Plot for user-chosen feature1, feature2, label
+    # If the user picks "None" or doesn't pick valid columns, skip or fallback.
+    if (co_occ_feature1 is not None and co_occ_feature1 != "None" and
+        co_occ_feature2 is not None and co_occ_feature2 != "None" and
+        co_occ_label is not None and co_occ_label != "None"):
+        # Check if these columns are in df
+        if (co_occ_feature1 in df.columns and
+            co_occ_feature2 in df.columns and
+            co_occ_label in df.columns):
+            # Group by [co_occ_feature1, co_occ_feature2, co_occ_label]
+            co_data = df.groupby([co_occ_feature1, co_occ_feature2, co_occ_label]).size().reset_index(name="count")
+            fig_co_occ = px.bar(
+                co_data,
+                x=co_occ_feature1,
+                y="count",
+                color=co_occ_label,
+                facet_col=co_occ_feature2,
+                title=f"Co-occurrence: {co_occ_feature1} & {co_occ_feature2} vs {co_occ_label}"
+            )
+        else:
+            fig_co_occ = px.bar(title="One or more selected columns not found in dataframe.")
     else:
+        fig_co_occ = px.bar(title="No co-occurrence plot (choose two features + one label).")
+
+    # Return all 8 outputs
     return (
+        final_result_text,   # (1) Predictions
+        severity,            # (2) Severity
+        total_count_md,      # (3) Total patient count
+        fig_dist,            # (4) Distribution Plot
+        nn_md_str,           # (5) Nearest Neighbors
+        fig_co_occ,          # (6) Co-occurrence
+        fig_input_bar,       # (7) Bar Chart (input features)
+        fig_label_bar        # (8) Bar Chart (labels)
     )

 ######################################
+# 5) MAPPING (user -> int)
 ######################################
 input_mapping = {
     'YNURSMDE': {"Yes": 1, "No": 0},
     'YMDEYR': {"Yes": 1, "No": 2},
     'YSOCMDE': {"Yes": 1, "No": 0},
+    'YMDESUD5ANYO': {"SUD only": 1, "MDE only": 2, "SUD & MDE": 3, "Neither": 4},
     'YMSUD5YANY': {"Yes": 1, "No": 0},
+    'YUSUITHK': {"Yes": 1, "No": 2, "Unsure": 3, "Don't want to answer": 4},
     'YMDETXRX': {"Yes": 1, "No": 0},
+    'YUSUITHKYR': {"Yes": 1, "No": 2, "Unsure": 3, "Don't want to answer": 4},
     'YMDERSUD5ANY': {"Yes": 1, "No": 0},
+    'YUSUIPLNYR': {"Yes": 1, "No": 2, "Unsure": 3, "Don't want to answer": 4},
     'YCOUNMDE': {"Yes": 1, "No": 0},
     'YPSY1MDE': {"Yes": 1, "No": 0},
     'YHLTMDE': {"Yes": 1, "No": 0},
     'YDOCMDE': {"Yes": 1, "No": 0},
     'YPSY2MDE': {"Yes": 1, "No": 0},
     'YMDEHARX': {"Yes": 1, "No": 0},
+    'LVLDIFMEM2': {"No Difficulty": 1, "Some Difficulty": 2, "A lot or cannot do": 3},
     'MDEIMPY': {"Yes": 1, "No": 2},
     'YMDEHPO': {"Yes": 1, "No": 0},
     'YMIMS5YANY': {"Yes": 1, "No": 0},
@@ -506,7 +374,7 @@ input_mapping = {
     'YMIUD5YANY': {"Yes": 1, "No": 0},
     'YMDEHPRX': {"Yes": 1, "No": 0},
     'YMIMI5YANY': {"Yes": 1, "No": 0},
+    'YUSUIPLN': {"Yes": 1, "No": 2, "Unsure": 3, "Don't want to answer": 4},
     'YTXMDEYR': {"Yes": 1, "No": 0},
     'YMDEAUD5YR': {"Yes": 1, "No": 0},
     'YRXMDEYR': {"Yes": 1, "No": 0},
@@ -514,89 +382,93 @@ input_mapping = {
 }

 ######################################
+# 6) THE GRADIO INTERFACE
 ######################################
 import gradio as gr

+# (A) The original required inputs
+original_inputs = [
+    gr.Dropdown(list(input_mapping['YMDEYR'].keys()), label="YMDEYR: Past Year MDE?"),
+    gr.Dropdown(list(input_mapping['YMDERSUD5ANY'].keys()), label="YMDERSUD5ANY: MDE or SUD - ANY?"),
+    gr.Dropdown(list(input_mapping['YMDEIMAD5YR'].keys()), label="YMDEIMAD5YR: MDE + ALCOHOL?"),
+    gr.Dropdown(list(input_mapping['YMIMS5YANY'].keys()), label="YMIMS5YANY: MDE + SUBSTANCE?"),
+    gr.Dropdown(list(input_mapping['YMDELT'].keys()), label="YMDELT: MDE in Lifetime?"),
+    gr.Dropdown(list(input_mapping['YMDEHARX'].keys()), label="YMDEHARX: Saw Health Prof + Meds?"),
+    gr.Dropdown(list(input_mapping['YMDEHPRX'].keys()), label="YMDEHPRX: Saw Health Prof or Meds?"),
+    gr.Dropdown(list(input_mapping['YMDETXRX'].keys()), label="YMDETXRX: Received Treatment?"),
+    gr.Dropdown(list(input_mapping['YMDEHPO'].keys()), label="YMDEHPO: Saw Health Prof Only?"),
+    gr.Dropdown(list(input_mapping['YMDEAUD5YR'].keys()), label="YMDEAUD5YR: MDE + Alcohol Use?"),
+    gr.Dropdown(list(input_mapping['YMIMI5YANY'].keys()), label="YMIMI5YANY: MDE + ILL Drug Use?"),
+    gr.Dropdown(list(input_mapping['YMIUD5YANY'].keys()), label="YMIUD5YANY: MDE + ILL Drug Use?"),
+    gr.Dropdown(list(input_mapping['YMDESUD5ANYO'].keys()), label="YMDESUD5ANYO: MDE vs SUD vs BOTH vs NEITHER"),

     # Consultations
+    gr.Dropdown(list(input_mapping['YNURSMDE'].keys()), label="YNURSMDE: Nurse/OT about MDE?"),
+    gr.Dropdown(list(input_mapping['YSOCMDE'].keys()), label="YSOCMDE: Social Worker?"),
+    gr.Dropdown(list(input_mapping['YCOUNMDE'].keys()), label="YCOUNMDE: Counselor?"),
+    gr.Dropdown(list(input_mapping['YPSY1MDE'].keys()), label="YPSY1MDE: Psychologist?"),
+    gr.Dropdown(list(input_mapping['YPSY2MDE'].keys()), label="YPSY2MDE: Psychiatrist?"),
+    gr.Dropdown(list(input_mapping['YHLTMDE'].keys()), label="YHLTMDE: Health Prof?"),
+    gr.Dropdown(list(input_mapping['YDOCMDE'].keys()), label="YDOCMDE: GP/Family MD?"),
+    gr.Dropdown(list(input_mapping['YTXMDEYR'].keys()), label="YTXMDEYR: Doctor/Health Prof?"),
+
+    # Suicidal
+    gr.Dropdown(list(input_mapping['YUSUITHKYR'].keys()), label="YUSUITHKYR: Serious Suicide Thoughts?"),
+    gr.Dropdown(list(input_mapping['YUSUIPLNYR'].keys()), label="YUSUIPLNYR: Made Plans?"),
+    gr.Dropdown(list(input_mapping['YUSUITHK'].keys()), label="YUSUITHK: Suicide Thoughts (12 mo)?"),
+    gr.Dropdown(list(input_mapping['YUSUIPLN'].keys()), label="YUSUIPLN: Made Plans (12 mo)?"),

     # Impairments
+    gr.Dropdown(list(input_mapping['MDEIMPY'].keys()), label="MDEIMPY: Severe Role Impairment?"),
+    gr.Dropdown(list(input_mapping['LVLDIFMEM2'].keys()), label="LVLDIFMEM2: Difficulty Remembering/Concentrating?"),
+    gr.Dropdown(list(input_mapping['YMSUD5YANY'].keys()), label="YMSUD5YANY: MDE + Substance?"),
+    gr.Dropdown(list(input_mapping['YRXMDEYR'].keys()), label="YRXMDEYR: Used Meds for MDE (12 mo)?"),
 ]

+# (B) The new co-occurrence inputs
+# We'll give them defaults of "None" to indicate no selection.
+all_cols = ["None"] + df.columns.tolist()  # 'None' plus the actual columns from your df
+co_occ_feature1 = gr.Dropdown(all_cols, label="Co-Occ Feature 1", value="None")
+co_occ_feature2 = gr.Dropdown(all_cols, label="Co-Occ Feature 2", value="None")
+all_label_cols = ["None"] + list(predictor.prediction_map.keys())  # e.g., "YOWRCONC", "YOWRHRS", ...
+co_occ_label = gr.Dropdown(all_label_cols, label="Co-Occ Label", value="None")
+
+# Combine them into a single input list
+inputs = original_inputs + [co_occ_feature1, co_occ_feature2, co_occ_label]
+
+# 8 outputs as before
 outputs = [
+    gr.Textbox(label="Prediction Results", lines=15),
+    gr.Textbox(label="Mental Health Severity", lines=2),
     gr.Markdown(label="Total Patient Count"),
+    gr.Plot(label="Distribution Plot (Sample)"),
+    gr.Markdown(label="Nearest Neighbors (K=2)"),
+    gr.Plot(label="Co-occurrence Plot"),
+    gr.Plot(label="Same Value Bar (Inputs)"),
+    gr.Plot(label="Predicted Label Bar")
 ]

 ######################################
+# 7) WRAPPER
 ######################################
 def predict_with_text(
+    # match the function signature exactly (29 required + 3 for co-occ)
     YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
     YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
     YMDESUD5ANYO, YNURSMDE, YSOCMDE, YCOUNMDE, YPSY1MDE, YPSY2MDE,
     YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
+    YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR,
+    co_occ_feature1, co_occ_feature2, co_occ_label
 ):
+    # Validate the original 29 fields
+    valid = validate_inputs(
         YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
         YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
         YMDESUD5ANYO, YNURSMDE, YSOCMDE, YCOUNMDE, YPSY1MDE, YPSY2MDE,
         YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
         YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
+    )
+    if not valid:
         return (
             "Please select all required fields.",
             "Validation Error",
@@ -608,7 +480,7 @@ def predict_with_text(
             None
         )

+    # Map to numeric
     user_inputs = {
         'YNURSMDE': input_mapping['YNURSMDE'][YNURSMDE],
         'YMDEYR': input_mapping['YMDEYR'][YMDEYR],
@@ -641,36 +513,34 @@ def predict_with_text(
         'YMDELT': input_mapping['YMDELT'][YMDELT]
     }

+    # Call the core predict function with the co-occ choices as well
+    return predict(
+        **user_inputs,
+        co_occ_feature1=co_occ_feature1,
+        co_occ_feature2=co_occ_feature2,
+        co_occ_label=co_occ_label
+    )

 custom_css = """
+.gradio-container * {
+    color: #1B1212 !important;
+}
 """

 interface = gr.Interface(
+    fn=predict_with_text,
     inputs=inputs,
+    outputs=outputs,
+    title="Mental Health Screening (NSDUH) with Selective Co-Occurrence",
+    css=custom_css,
+    description="""
+    **Instructions**:
+    1. Fill out all required fields regarding MDE/Substance Use/Consultations/Suicidal/Impairments.
+    2. (Optional) Choose 2 features and 1 label for the *Co-occurrence* plot.
+       - If you do not select them (or leave them as "None"), that plot will be skipped.
+    3. Click "Submit" to get predictions, severity, distribution plots, nearest neighbors, and your custom co-occurrence chart.
+    """
 )

 if __name__ == "__main__":