pantdipendra
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -8,17 +8,16 @@ import plotly.express as px
|
|
8 |
# Load the training CSV once (outside the functions so it is read only once).
|
9 |
df = pd.read_csv("X_train_Y_Train_merged_train.csv")
|
10 |
|
11 |
-
|
12 |
-
# 1)
|
13 |
-
|
14 |
class ModelPredictor:
|
15 |
def __init__(self, model_path, model_filenames):
|
16 |
self.model_path = model_path
|
17 |
self.model_filenames = model_filenames
|
18 |
self.models = self.load_models()
|
19 |
-
|
20 |
-
#
|
21 |
-
# If you have more labels, expand this dictionary accordingly.
|
22 |
self.prediction_map = {
|
23 |
"YOWRCONC": ["Did not have difficulty concentrating", "Had difficulty concentrating"],
|
24 |
"YOSEEDOC": ["Did not feel the need to see a doctor", "Felt the need to see a doctor"],
|
@@ -95,9 +94,9 @@ class ModelPredictor:
|
|
95 |
else:
|
96 |
return "Mental health severity: Very Low"
|
97 |
|
98 |
-
|
99 |
-
# 2)
|
100 |
-
|
101 |
model_filenames = [
|
102 |
"YOWRCONC.pkl", "YOSEEDOC.pkl", "YO_MDEA5.pkl", "YOWRLSIN.pkl",
|
103 |
"YODPPROB.pkl", "YOWRPROB.pkl", "YODPR2WK.pkl", "YOWRDEPR.pkl",
|
@@ -107,60 +106,18 @@ model_filenames = [
|
|
107 |
model_path = "models/"
|
108 |
predictor = ModelPredictor(model_path, model_filenames)
|
109 |
|
110 |
-
|
111 |
-
# 3)
|
112 |
-
|
113 |
def validate_inputs(*args):
|
114 |
for arg in args:
|
115 |
if arg == '' or arg is None: # Assuming empty string or None as unselected
|
116 |
return False
|
117 |
return True
|
118 |
|
119 |
-
|
120 |
-
# 4)
|
121 |
-
|
122 |
-
# We'll define the forward mapping here. The reverse mapping is constructed below.
|
123 |
-
input_mapping = {
|
124 |
-
'YNURSMDE': {"Yes": 1, "No": 0},
|
125 |
-
'YMDEYR': {"Yes": 1, "No": 2},
|
126 |
-
'YSOCMDE': {"Yes": 1, "No": 0},
|
127 |
-
'YMDESUD5ANYO': {"SUD only, no MDE": 1, "MDE only, no SUD": 2, "SUD and MDE": 3, "Neither SUD or MDE": 4},
|
128 |
-
'YMSUD5YANY': {"Yes": 1, "No": 0},
|
129 |
-
'YUSUITHK': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
|
130 |
-
'YMDETXRX': {"Yes": 1, "No": 0},
|
131 |
-
'YUSUITHKYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
|
132 |
-
'YMDERSUD5ANY': {"Yes": 1, "No": 0},
|
133 |
-
'YUSUIPLNYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
|
134 |
-
'YCOUNMDE': {"Yes": 1, "No": 0},
|
135 |
-
'YPSY1MDE': {"Yes": 1, "No": 0},
|
136 |
-
'YHLTMDE': {"Yes": 1, "No": 0},
|
137 |
-
'YDOCMDE': {"Yes": 1, "No": 0},
|
138 |
-
'YPSY2MDE': {"Yes": 1, "No": 0},
|
139 |
-
'YMDEHARX': {"Yes": 1, "No": 0},
|
140 |
-
'LVLDIFMEM2': {"No Difficulty": 1, "Some difficulty": 2, "A lot of difficulty or cannot do at all": 3},
|
141 |
-
'MDEIMPY': {"Yes": 1, "No": 2},
|
142 |
-
'YMDEHPO': {"Yes": 1, "No": 0},
|
143 |
-
'YMIMS5YANY': {"Yes": 1, "No": 0},
|
144 |
-
'YMDEIMAD5YR': {"Yes": 1, "No": 0},
|
145 |
-
'YMIUD5YANY': {"Yes": 1, "No": 0},
|
146 |
-
'YMDEHPRX': {"Yes": 1, "No": 0},
|
147 |
-
'YMIMI5YANY': {"Yes": 1, "No": 0},
|
148 |
-
'YUSUIPLN': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
|
149 |
-
'YTXMDEYR': {"Yes": 1, "No": 0},
|
150 |
-
'YMDEAUD5YR': {"Yes": 1, "No": 0},
|
151 |
-
'YRXMDEYR': {"Yes": 1, "No": 0},
|
152 |
-
'YMDELT': {"Yes": 1, "No": 2}
|
153 |
-
}
|
154 |
-
|
155 |
-
# Build reverse mapping: { "YNURSMDE": {1: "Yes", 0: "No"}, ... } etc.
|
156 |
-
reverse_mapping = {}
|
157 |
-
for col, mapping_dict in input_mapping.items():
|
158 |
-
rev = {v: k for k, v in mapping_dict.items()} # invert dict
|
159 |
-
reverse_mapping[col] = rev
|
160 |
-
|
161 |
-
###############################################################################
|
162 |
-
# 5) Main Predict Function
|
163 |
-
###############################################################################
|
164 |
def predict(
|
165 |
YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
|
166 |
YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
|
@@ -168,17 +125,7 @@ def predict(
|
|
168 |
YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
|
169 |
YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
|
170 |
):
|
171 |
-
|
172 |
-
Core prediction function that:
|
173 |
-
1) Predicts with each model
|
174 |
-
2) Aggregates results
|
175 |
-
3) Produces an overall 'severity'
|
176 |
-
4) Returns detailed per-model predictions
|
177 |
-
5) Creates a distribution plot for ALL input features vs. a chosen label
|
178 |
-
6) Nearest neighbor logic (with disclaimers), mapping numeric -> user text
|
179 |
-
"""
|
180 |
-
|
181 |
-
# 1) Prepare user_input dataframe
|
182 |
user_input_data = {
|
183 |
'YNURSMDE': [int(YNURSMDE)],
|
184 |
'YMDEYR': [int(YMDEYR)],
|
@@ -212,20 +159,20 @@ def predict(
|
|
212 |
}
|
213 |
user_input = pd.DataFrame(user_input_data)
|
214 |
|
215 |
-
#
|
216 |
predictions = predictor.make_predictions(user_input)
|
217 |
|
218 |
-
#
|
219 |
majority_vote = predictor.get_majority_vote(predictions)
|
220 |
|
221 |
-
#
|
222 |
majority_vote_count = sum([1 for pred in np.concatenate(predictions) if pred == 1])
|
223 |
|
224 |
-
#
|
225 |
severity = predictor.evaluate_severity(majority_vote_count)
|
226 |
|
227 |
-
#
|
228 |
-
# We
|
229 |
results = {
|
230 |
"Concentration_and_Decision_Making": [],
|
231 |
"Sleep_and_Energy_Levels": [],
|
@@ -245,17 +192,18 @@ def predict(
|
|
245 |
"YOPB2WK"]
|
246 |
}
|
247 |
|
248 |
-
#
|
249 |
for i, pred in enumerate(predictions):
|
250 |
-
model_name =
|
251 |
pred_value = pred[0]
|
252 |
# Map the prediction value to a human-readable string
|
253 |
if model_name in predictor.prediction_map and pred_value in [0, 1]:
|
254 |
result_text = f"Model {model_name}: {predictor.prediction_map[model_name][pred_value]}"
|
255 |
else:
|
256 |
-
|
|
|
257 |
|
258 |
-
# Append to the appropriate group
|
259 |
found_group = False
|
260 |
for group_name, group_models in prediction_groups.items():
|
261 |
if model_name in group_models:
|
@@ -263,10 +211,10 @@ def predict(
|
|
263 |
found_group = True
|
264 |
break
|
265 |
if not found_group:
|
266 |
-
# If
|
267 |
pass
|
268 |
|
269 |
-
#
|
270 |
formatted_results = []
|
271 |
for group, preds in results.items():
|
272 |
if preds:
|
@@ -274,184 +222,366 @@ def predict(
|
|
274 |
formatted_results.append("\n".join(preds))
|
275 |
formatted_results.append("\n")
|
276 |
formatted_results = "\n".join(formatted_results).strip()
|
277 |
-
if
|
278 |
formatted_results = "No predictions made. Please check your inputs."
|
279 |
|
280 |
-
#
|
281 |
-
num_unknown =
|
282 |
-
if num_unknown > len(
|
283 |
severity += " (Unknown prediction count is high. Please consult with a human.)"
|
284 |
|
285 |
-
|
|
|
286 |
# A) Total Patient Count
|
287 |
-
############################################################################
|
288 |
total_patients = len(df)
|
289 |
total_patient_count_markdown = (
|
290 |
"### Total Patient Count\n"
|
291 |
-
f"There are **{total_patients}** total patients in the dataset.\n
|
292 |
-
"
|
293 |
)
|
294 |
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
value_name="FeatureValue"
|
314 |
-
)
|
315 |
-
# 3) Group by (FeatureName, FeatureValue, chosen_label) to get size
|
316 |
-
dist_data = melted.groupby(["FeatureName", "FeatureValue", chosen_label]).size().reset_index(name="count")
|
317 |
-
# 4) We'll try to map FeatureValue from numeric -> user-friendly text if possible
|
318 |
-
# We'll do it only if FeatureName is in reverse_mapping.
|
319 |
-
def map_value(row):
|
320 |
-
fn = row["FeatureName"]
|
321 |
-
fv = row["FeatureValue"]
|
322 |
-
if fn in reverse_mapping:
|
323 |
-
if fv in reverse_mapping[fn]:
|
324 |
-
return reverse_mapping[fn][fv] # e.g. 1->"Yes"
|
325 |
-
return fv # fallback
|
326 |
-
dist_data["FeatureValueText"] = dist_data.apply(map_value, axis=1)
|
327 |
-
# 5) Similarly, map chosen_label (0 or 1) to text if in predictor.prediction_map
|
328 |
-
if chosen_label in predictor.prediction_map:
|
329 |
-
def map_label(val):
|
330 |
-
if val in [0, 1]:
|
331 |
-
return predictor.prediction_map[chosen_label][val]
|
332 |
-
return f"Unknown label {val}"
|
333 |
-
dist_data["LabelText"] = dist_data[chosen_label].apply(map_label)
|
334 |
-
else:
|
335 |
-
dist_data["LabelText"] = dist_data[chosen_label].astype(str)
|
336 |
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
341 |
y="count",
|
342 |
-
color=
|
343 |
-
|
344 |
-
|
345 |
-
title=
|
346 |
-
|
|
|
|
|
|
|
347 |
)
|
348 |
-
|
349 |
-
# (Optional) Adjust layout or text angle if you have many categories
|
350 |
-
fig_distribution.update_xaxes(tickangle=45)
|
351 |
else:
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
#
|
357 |
-
|
358 |
-
#
|
359 |
-
#
|
360 |
-
|
361 |
-
|
362 |
-
#
|
363 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
364 |
user_series = user_input.iloc[0]
|
365 |
|
366 |
-
# 2) Compute distances
|
367 |
distances = []
|
368 |
-
for idx, row in
|
369 |
-
|
370 |
-
|
371 |
-
if row[col] != user_series[col]:
|
372 |
-
d += 1
|
373 |
-
distances.append(d)
|
374 |
|
375 |
df_with_dist = df.copy()
|
376 |
df_with_dist["distance"] = distances
|
377 |
|
378 |
-
#
|
379 |
K = 5
|
380 |
nearest_neighbors = df_with_dist.sort_values("distance", ascending=True).head(K)
|
381 |
|
382 |
-
#
|
383 |
-
#
|
384 |
-
#
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
else:
|
395 |
-
nn_label_0 = nn_label_1 = 0
|
396 |
-
label0_text = "Label=0"
|
397 |
-
label1_text = "Label=1"
|
398 |
-
|
399 |
-
# 5) Build an example table of those neighbors in user-friendly text
|
400 |
-
neighbor_text_rows = []
|
401 |
-
for idx, nn_row in nearest_neighbors.iterrows():
|
402 |
-
# For each feature, map numeric -> user text
|
403 |
-
row_str_parts = []
|
404 |
-
row_str_parts.append(f"distance={nn_row['distance']}")
|
405 |
-
for fcol in features_to_compare:
|
406 |
-
val = nn_row[fcol]
|
407 |
-
# try to map
|
408 |
-
if fcol in reverse_mapping and val in reverse_mapping[fcol]:
|
409 |
-
val_str = reverse_mapping[fcol][val]
|
410 |
else:
|
411 |
-
|
412 |
-
|
413 |
-
|
414 |
-
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
else:
|
419 |
-
|
420 |
-
row_str_parts.append(f"{chosen_label}={lbl_str}")
|
421 |
-
neighbor_text_rows.append(" | ".join(row_str_parts))
|
422 |
|
423 |
-
|
424 |
|
425 |
similar_patient_markdown = (
|
426 |
"### Nearest Neighbors (Simple Hamming Distance)\n"
|
427 |
-
"
|
428 |
-
"
|
|
|
429 |
"In a real application, you would refine which features are most relevant, how to encode them, "
|
430 |
"and how many neighbors to select.\n\n"
|
431 |
-
|
432 |
-
|
433 |
-
f"- **{nn_label_1}** had {label1_text}\n\n"
|
434 |
-
"Below is a breakdown of each neighbor's key features in user-friendly text:\n\n"
|
435 |
-
f"```\n{neighbor_text_block}\n```"
|
436 |
)
|
437 |
|
438 |
-
|
439 |
-
|
440 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
441 |
return (
|
442 |
-
formatted_results,
|
443 |
-
severity,
|
444 |
-
total_patient_count_markdown,
|
445 |
-
|
446 |
-
similar_patient_markdown,
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
)
|
451 |
|
452 |
-
|
453 |
-
#
|
454 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
455 |
def predict_with_text(
|
456 |
YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
|
457 |
YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
|
@@ -459,7 +589,7 @@ def predict_with_text(
|
|
459 |
YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
|
460 |
YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
|
461 |
):
|
462 |
-
# Validate
|
463 |
if not validate_inputs(
|
464 |
YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
|
465 |
YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
|
@@ -468,15 +598,17 @@ def predict_with_text(
|
|
468 |
YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
|
469 |
):
|
470 |
return (
|
471 |
-
"Please select all required fields.",
|
472 |
-
"Validation Error",
|
473 |
-
"No data",
|
474 |
-
None,
|
475 |
-
"No data",
|
476 |
-
None,
|
|
|
|
|
477 |
)
|
478 |
|
479 |
-
# Map
|
480 |
user_inputs = {
|
481 |
'YNURSMDE': input_mapping['YNURSMDE'][YNURSMDE],
|
482 |
'YMDEYR': input_mapping['YMDEYR'][YMDEYR],
|
@@ -508,68 +640,11 @@ def predict_with_text(
|
|
508 |
'YRXMDEYR': input_mapping['YRXMDEYR'][YRXMDEYR],
|
509 |
'YMDELT': input_mapping['YMDELT'][YMDELT]
|
510 |
}
|
511 |
-
|
512 |
-
# Pass our mapped values into the original 'predict' function
|
513 |
-
return predict(**user_inputs)
|
514 |
-
|
515 |
-
###############################################################################
|
516 |
-
# 7) Define and Launch Gradio Interface
|
517 |
-
###############################################################################
|
518 |
-
import sys
|
519 |
-
|
520 |
-
# We have 8 outputs (some are placeholders)
|
521 |
-
outputs = [
|
522 |
-
gr.Textbox(label="Prediction Results", lines=30),
|
523 |
-
gr.Textbox(label="Mental Health Severity", lines=4),
|
524 |
-
gr.Markdown(label="Total Patient Count"),
|
525 |
-
gr.Plot(label="Distribution of All Input Features vs. One Label"),
|
526 |
-
gr.Markdown(label="Nearest Neighbors Summary"),
|
527 |
-
gr.Plot(label="Placeholder Plot"),
|
528 |
-
gr.Plot(label="Placeholder Plot"),
|
529 |
-
gr.Plot(label="Placeholder Plot")
|
530 |
-
]
|
531 |
|
532 |
-
#
|
533 |
-
|
534 |
-
# Major Depressive Episode (MDE) questions
|
535 |
-
gr.Dropdown(list(input_mapping['YMDEYR'].keys()), label="YMDEYR: PAST YEAR MDE?"),
|
536 |
-
gr.Dropdown(list(input_mapping['YMDERSUD5ANY'].keys()), label="YMDERSUD5ANY: MDE OR SUBSTANCE USE DISORDER - ANY"),
|
537 |
-
gr.Dropdown(list(input_mapping['YMDEIMAD5YR'].keys()), label="YMDEIMAD5YR: MDE + ALCOHOL USE DISORDER?"),
|
538 |
-
gr.Dropdown(list(input_mapping['YMIMS5YANY'].keys()), label="YMIMS5YANY: MDE + SUBSTANCE USE DISORDER?"),
|
539 |
-
gr.Dropdown(list(input_mapping['YMDELT'].keys()), label="YMDELT: EVER HAD MDE LIFETIME?"),
|
540 |
-
gr.Dropdown(list(input_mapping['YMDEHARX'].keys()), label="YMDEHARX: SAW HEALTH PROF + MEDS FOR MDE"),
|
541 |
-
gr.Dropdown(list(input_mapping['YMDEHPRX'].keys()), label="YMDEHPRX: SAW HEALTH PROF OR MEDS FOR MDE"),
|
542 |
-
gr.Dropdown(list(input_mapping['YMDETXRX'].keys()), label="YMDETXRX: TREATMENT/COUNSELING FOR MDE"),
|
543 |
-
gr.Dropdown(list(input_mapping['YMDEHPO'].keys()), label="YMDEHPO: HEALTH PROF ONLY FOR MDE"),
|
544 |
-
gr.Dropdown(list(input_mapping['YMDEAUD5YR'].keys()), label="YMDEAUD5YR: MDE + ALCOHOL USE DISORDER"),
|
545 |
-
gr.Dropdown(list(input_mapping['YMIMI5YANY'].keys()), label="YMIMI5YANY: MDE + ILL DRUG USE DISORDER"),
|
546 |
-
gr.Dropdown(list(input_mapping['YMIUD5YANY'].keys()), label="YMIUD5YANY: MDE + ILL DRUG USE DISORDER"),
|
547 |
-
gr.Dropdown(list(input_mapping['YMDESUD5ANYO'].keys()), label="YMDESUD5ANYO: MDE vs. SUD vs. BOTH vs. NEITHER"),
|
548 |
-
|
549 |
-
# Consultations
|
550 |
-
gr.Dropdown(list(input_mapping['YNURSMDE'].keys()), label="YNURSMDE: NURSE / OT FOR MDE"),
|
551 |
-
gr.Dropdown(list(input_mapping['YSOCMDE'].keys()), label="YSOCMDE: SOCIAL WORKER FOR MDE"),
|
552 |
-
gr.Dropdown(list(input_mapping['YCOUNMDE'].keys()), label="YCOUNMDE: COUNSELOR FOR MDE"),
|
553 |
-
gr.Dropdown(list(input_mapping['YPSY1MDE'].keys()), label="YPSY1MDE: PSYCHOLOGIST FOR MDE"),
|
554 |
-
gr.Dropdown(list(input_mapping['YPSY2MDE'].keys()), label="YPSY2MDE: PSYCHIATRIST FOR MDE"),
|
555 |
-
gr.Dropdown(list(input_mapping['YHLTMDE'].keys()), label="YHLTMDE: HEALTH PROF FOR MDE"),
|
556 |
-
gr.Dropdown(list(input_mapping['YDOCMDE'].keys()), label="YDOCMDE: GP/FAMILY MD FOR MDE"),
|
557 |
-
gr.Dropdown(list(input_mapping['YTXMDEYR'].keys()), label="YTXMDEYR: DOCTOR/HEALTH PROF FOR MDE THIS YEAR"),
|
558 |
-
|
559 |
-
# Suicidal thoughts / plans
|
560 |
-
gr.Dropdown(list(input_mapping['YUSUITHKYR'].keys()), label="YUSUITHKYR: SERIOUSLY THOUGHT ABOUT KILLING SELF"),
|
561 |
-
gr.Dropdown(list(input_mapping['YUSUIPLNYR'].keys()), label="YUSUIPLNYR: MADE PLANS TO KILL SELF"),
|
562 |
-
gr.Dropdown(list(input_mapping['YUSUITHK'].keys()), label="YUSUITHK: THINK ABOUT KILLING SELF (12 MONTHS)"),
|
563 |
-
gr.Dropdown(list(input_mapping['YUSUIPLN'].keys()), label="YUSUIPLN: MADE PLANS TO KILL SELF (12 MONTHS)"),
|
564 |
-
|
565 |
-
# Impairment
|
566 |
-
gr.Dropdown(list(input_mapping['MDEIMPY'].keys()), label="MDEIMPY: MDE WITH SEVERE ROLE IMPAIRMENT?"),
|
567 |
-
gr.Dropdown(list(input_mapping['LVLDIFMEM2'].keys()), label="LVLDIFMEM2: DIFFICULTY REMEMBERING/CONCENTRATING"),
|
568 |
-
gr.Dropdown(list(input_mapping['YMSUD5YANY'].keys()), label="YMSUD5YANY: MDE + SUBSTANCE USE DISORDER?"),
|
569 |
-
gr.Dropdown(list(input_mapping['YRXMDEYR'].keys()), label="YRXMDEYR: USED MEDS FOR MDE IN PAST YEAR?")
|
570 |
-
]
|
571 |
|
572 |
-
#
|
573 |
custom_css = """
|
574 |
.gradio-container * {
|
575 |
color: #1B1212 !important;
|
@@ -587,13 +662,15 @@ custom_css = """
|
|
587 |
}
|
588 |
"""
|
589 |
|
590 |
-
|
|
|
|
|
591 |
interface = gr.Interface(
|
592 |
-
fn=predict_with_text,
|
593 |
-
inputs=inputs,
|
594 |
-
outputs=outputs,
|
595 |
-
title="Adolescents with Substance Use Mental Health Screening (NSDUH Data)",
|
596 |
-
css=custom_css
|
597 |
)
|
598 |
|
599 |
if __name__ == "__main__":
|
|
|
8 |
# Load the training CSV once (outside the functions so it is read only once).
|
9 |
df = pd.read_csv("X_train_Y_Train_merged_train.csv")
|
10 |
|
11 |
+
######################################
|
12 |
+
# 1) MODEL PREDICTOR CLASS
|
13 |
+
######################################
|
14 |
class ModelPredictor:
|
15 |
def __init__(self, model_path, model_filenames):
|
16 |
self.model_path = model_path
|
17 |
self.model_filenames = model_filenames
|
18 |
self.models = self.load_models()
|
19 |
+
# Mapping from label column to human-readable strings for 0/1
|
20 |
+
# (Adjust as needed for the columns you actually have.)
|
|
|
21 |
self.prediction_map = {
|
22 |
"YOWRCONC": ["Did not have difficulty concentrating", "Had difficulty concentrating"],
|
23 |
"YOSEEDOC": ["Did not feel the need to see a doctor", "Felt the need to see a doctor"],
|
|
|
94 |
else:
|
95 |
return "Mental health severity: Very Low"
|
96 |
|
97 |
+
######################################
|
98 |
+
# 2) MODEL & DATA
|
99 |
+
######################################
|
100 |
model_filenames = [
|
101 |
"YOWRCONC.pkl", "YOSEEDOC.pkl", "YO_MDEA5.pkl", "YOWRLSIN.pkl",
|
102 |
"YODPPROB.pkl", "YOWRPROB.pkl", "YODPR2WK.pkl", "YOWRDEPR.pkl",
|
|
|
106 |
model_path = "models/"
|
107 |
predictor = ModelPredictor(model_path, model_filenames)
|
108 |
|
109 |
+
######################################
|
110 |
+
# 3) INPUT VALIDATION
|
111 |
+
######################################
|
112 |
def validate_inputs(*args):
|
113 |
for arg in args:
|
114 |
if arg == '' or arg is None: # Assuming empty string or None as unselected
|
115 |
return False
|
116 |
return True
|
117 |
|
118 |
+
######################################
|
119 |
+
# 4) MAIN PREDICTION FUNCTION
|
120 |
+
######################################
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
def predict(
|
122 |
YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
|
123 |
YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
|
|
|
125 |
YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
|
126 |
YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
|
127 |
):
|
128 |
+
# Prepare user_input dataframe for prediction
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
user_input_data = {
|
130 |
'YNURSMDE': [int(YNURSMDE)],
|
131 |
'YMDEYR': [int(YMDEYR)],
|
|
|
159 |
}
|
160 |
user_input = pd.DataFrame(user_input_data)
|
161 |
|
162 |
+
# 1) Make predictions with each model
|
163 |
predictions = predictor.make_predictions(user_input)
|
164 |
|
165 |
+
# 2) Calculate majority vote (0 or 1) across all models
|
166 |
majority_vote = predictor.get_majority_vote(predictions)
|
167 |
|
168 |
+
# 3) Count how many 1's in all predictions combined
|
169 |
majority_vote_count = sum([1 for pred in np.concatenate(predictions) if pred == 1])
|
170 |
|
171 |
+
# 4) Evaluate severity
|
172 |
severity = predictor.evaluate_severity(majority_vote_count)
|
173 |
|
174 |
+
# 5) Prepare detailed results (group them)
|
175 |
+
# We keep the old grouping as an example, but you can adapt as needed.
|
176 |
results = {
|
177 |
"Concentration_and_Decision_Making": [],
|
178 |
"Sleep_and_Energy_Levels": [],
|
|
|
192 |
"YOPB2WK"]
|
193 |
}
|
194 |
|
195 |
+
# For textual results
|
196 |
for i, pred in enumerate(predictions):
|
197 |
+
model_name = model_filenames[i].split('.')[0]
|
198 |
pred_value = pred[0]
|
199 |
# Map the prediction value to a human-readable string
|
200 |
if model_name in predictor.prediction_map and pred_value in [0, 1]:
|
201 |
result_text = f"Model {model_name}: {predictor.prediction_map[model_name][pred_value]}"
|
202 |
else:
|
203 |
+
# Fallback
|
204 |
+
result_text = f"Model {model_name}: Prediction = {pred_value} (unmapped)"
|
205 |
|
206 |
+
# Append to the appropriate group if matched
|
207 |
found_group = False
|
208 |
for group_name, group_models in prediction_groups.items():
|
209 |
if model_name in group_models:
|
|
|
211 |
found_group = True
|
212 |
break
|
213 |
if not found_group:
|
214 |
+
# If it doesn't match any group, skip or handle differently
|
215 |
pass
|
216 |
|
217 |
+
# Format the grouped results
|
218 |
formatted_results = []
|
219 |
for group, preds in results.items():
|
220 |
if preds:
|
|
|
222 |
formatted_results.append("\n".join(preds))
|
223 |
formatted_results.append("\n")
|
224 |
formatted_results = "\n".join(formatted_results).strip()
|
225 |
+
if not formatted_results:
|
226 |
formatted_results = "No predictions made. Please check your inputs."
|
227 |
|
228 |
+
# If too many unknown predictions, add a note
|
229 |
+
num_unknown = len([p for group_preds in results.values() for p in group_preds if "(unmapped)" in p])
|
230 |
+
if num_unknown > len(model_filenames) / 2:
|
231 |
severity += " (Unknown prediction count is high. Please consult with a human.)"
|
232 |
|
233 |
+
# =============== ADDITIONAL FEATURES ===============
|
234 |
+
|
235 |
# A) Total Patient Count
|
|
|
236 |
total_patients = len(df)
|
237 |
total_patient_count_markdown = (
|
238 |
"### Total Patient Count\n"
|
239 |
+
f"There are **{total_patients}** total patients in the dataset.\n"
|
240 |
+
"All subsequent analyses refer to these patients."
|
241 |
)
|
242 |
|
243 |
+
# B) Bar Chart for input features (how many share same value as user_input)
|
244 |
+
input_counts = {}
|
245 |
+
for col in user_input_data.keys():
|
246 |
+
val = user_input_data[col][0]
|
247 |
+
same_val_count = len(df[df[col] == val])
|
248 |
+
input_counts[col] = same_val_count
|
249 |
+
bar_input_data = pd.DataFrame({
|
250 |
+
"Feature": list(input_counts.keys()),
|
251 |
+
"Count": list(input_counts.values())
|
252 |
+
})
|
253 |
+
fig_bar_input = px.bar(
|
254 |
+
bar_input_data,
|
255 |
+
x="Feature",
|
256 |
+
y="Count",
|
257 |
+
title="Number of Patients with the Same Value for Each Input Feature",
|
258 |
+
labels={"Feature": "Input Feature", "Count": "Number of Patients"}
|
259 |
+
)
|
260 |
+
fig_bar_input.update_layout(xaxis={'categoryorder':'total descending'})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
261 |
|
262 |
+
# C) Bar Chart for predicted labels (distribution in df)
|
263 |
+
label_counts = {}
|
264 |
+
for i, pred in enumerate(predictions):
|
265 |
+
model_name = model_filenames[i].split('.')[0]
|
266 |
+
pred_value = pred[0]
|
267 |
+
if pred_value in [0, 1]:
|
268 |
+
label_counts[model_name] = len(df[df[model_name] == pred_value])
|
269 |
+
if len(label_counts) > 0:
|
270 |
+
bar_label_data = pd.DataFrame({
|
271 |
+
"Model": list(label_counts.keys()),
|
272 |
+
"Count": list(label_counts.values())
|
273 |
+
})
|
274 |
+
fig_bar_labels = px.bar(
|
275 |
+
bar_label_data,
|
276 |
+
x="Model",
|
277 |
+
y="Count",
|
278 |
+
title="Number of Patients with the Same Predicted Label",
|
279 |
+
labels={"Model": "Predicted Column", "Count": "Patient Count"}
|
280 |
+
)
|
281 |
+
else:
|
282 |
+
# Fallback if no valid predictions
|
283 |
+
fig_bar_labels = px.bar(title="No valid predicted labels to display")
|
284 |
+
|
285 |
+
# D) Distribution Plot: All Input Features vs. All Predicted Labels
|
286 |
+
# This can create MANY subplots if you have many features & labels.
|
287 |
+
# We'll do a small demonstration with a subset of input features & model columns
|
288 |
+
# to avoid overwhelming the UI.
|
289 |
+
demonstration_features = list(user_input_data.keys())[:4] # first 4 features as a sample
|
290 |
+
demonstration_labels = [fn.split('.')[0] for fn in model_filenames[:3]] # first 3 labels as a sample
|
291 |
+
|
292 |
+
# We'll build a single figure with "facet_col" = label and "facet_row" = feature (small sample)
|
293 |
+
# The approach: for each (feature, label), group by (feature_value, label_value) -> count.
|
294 |
+
# Then we combine them into one big DataFrame with "feature" & "label" columns for Plotly facets.
|
295 |
+
dist_rows = []
|
296 |
+
for feat in demonstration_features:
|
297 |
+
if feat not in df.columns:
|
298 |
+
continue
|
299 |
+
for lbl in demonstration_labels:
|
300 |
+
if lbl not in df.columns:
|
301 |
+
continue
|
302 |
+
tmp_df = df.groupby([feat, lbl]).size().reset_index(name="count")
|
303 |
+
tmp_df["feature"] = feat
|
304 |
+
tmp_df["label"] = lbl
|
305 |
+
dist_rows.append(tmp_df)
|
306 |
+
if len(dist_rows) > 0:
|
307 |
+
big_dist_df = pd.concat(dist_rows, ignore_index=True)
|
308 |
+
# We can re-map numeric to user-friendly text for "feat" if desired, but each feature might have a different mapping.
|
309 |
+
# For now, we just show numeric codes. Real usage would do a reverse mapping if feasible.
|
310 |
+
|
311 |
+
# For the label (0,1), we can map to short strings if we want (like "Label0" / "Label1"), or a direct numeric.
|
312 |
+
fig_dist = px.bar(
|
313 |
+
big_dist_df,
|
314 |
+
x=big_dist_df.columns[0], # the feature's value is the 0-th col in groupby
|
315 |
y="count",
|
316 |
+
color=big_dist_df.columns[1], # the label's value is the 1st col in groupby
|
317 |
+
facet_row="feature",
|
318 |
+
facet_col="label",
|
319 |
+
title="Distribution of Sample Input Features vs. Sample Predicted Labels (Demo)",
|
320 |
+
labels={
|
321 |
+
big_dist_df.columns[0]: "Feature Value",
|
322 |
+
big_dist_df.columns[1]: "Label Value"
|
323 |
+
}
|
324 |
)
|
325 |
+
fig_dist.update_layout(height=800)
|
|
|
|
|
326 |
else:
|
327 |
+
fig_dist = px.bar(title="No distribution plot could be generated (check feature/label columns).")
|
328 |
+
|
329 |
+
# E) Nearest Neighbors: Hamming Distance, K=5, with disclaimers & user-friendly text
|
330 |
+
# "Nearest neighbor” methods for high-dimensional or purely categorical data can be non-trivial.
|
331 |
+
# This demo simply uses a Hamming distance over all input features and picks K=5 neighbors.
|
332 |
+
# In a real application, you would refine which features are most relevant, how to encode them,
|
333 |
+
# and how many neighbors to select.
|
334 |
+
# We also show how to revert numeric codes -> user-friendly text.
|
335 |
+
|
336 |
+
# 1. Invert the user-friendly text mapping (for inputs).
|
337 |
+
# We'll assume input_mapping is consistent. We build a reverse mapping for each column.
|
338 |
+
reverse_input_mapping = {}
|
339 |
+
# We'll build it after the code block below for each column.
|
340 |
+
|
341 |
+
# 2. Invert label mappings from predictor.prediction_map if needed
|
342 |
+
# For each label column, 0 => first string, 1 => second string
|
343 |
+
# We'll store them in a dict: reverse_label_mapping[label_col][0 or 1] => string
|
344 |
+
reverse_label_mapping = {}
|
345 |
+
for lbl, str_list in predictor.prediction_map.items():
|
346 |
+
# str_list[0] => for 0, str_list[1] => for 1
|
347 |
+
reverse_label_mapping[lbl] = {
|
348 |
+
0: str_list[0],
|
349 |
+
1: str_list[1]
|
350 |
+
}
|
351 |
+
|
352 |
+
# Build the reverse input mapping from the provided dictionary
|
353 |
+
# We'll define that dictionary below to ensure we can invert it:
|
354 |
+
input_mapping = {
|
355 |
+
'YNURSMDE': {"Yes": 1, "No": 0},
|
356 |
+
'YMDEYR': {"Yes": 1, "No": 2},
|
357 |
+
'YSOCMDE': {"Yes": 1, "No": 0},
|
358 |
+
'YMDESUD5ANYO': {"SUD only, no MDE": 1, "MDE only, no SUD": 2, "SUD and MDE": 3, "Neither SUD or MDE": 4},
|
359 |
+
'YMSUD5YANY': {"Yes": 1, "No": 0},
|
360 |
+
'YUSUITHK': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
|
361 |
+
'YMDETXRX': {"Yes": 1, "No": 0},
|
362 |
+
'YUSUITHKYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
|
363 |
+
'YMDERSUD5ANY': {"Yes": 1, "No": 0},
|
364 |
+
'YUSUIPLNYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
|
365 |
+
'YCOUNMDE': {"Yes": 1, "No": 0},
|
366 |
+
'YPSY1MDE': {"Yes": 1, "No": 0},
|
367 |
+
'YHLTMDE': {"Yes": 1, "No": 0},
|
368 |
+
'YDOCMDE': {"Yes": 1, "No": 0},
|
369 |
+
'YPSY2MDE': {"Yes": 1, "No": 0},
|
370 |
+
'YMDEHARX': {"Yes": 1, "No": 0},
|
371 |
+
'LVLDIFMEM2': {"No Difficulty": 1, "Some difficulty": 2, "A lot of difficulty or cannot do at all": 3},
|
372 |
+
'MDEIMPY': {"Yes": 1, "No": 2},
|
373 |
+
'YMDEHPO': {"Yes": 1, "No": 0},
|
374 |
+
'YMIMS5YANY': {"Yes": 1, "No": 0},
|
375 |
+
'YMDEIMAD5YR': {"Yes": 1, "No": 0},
|
376 |
+
'YMIUD5YANY': {"Yes": 1, "No": 0},
|
377 |
+
'YMDEHPRX': {"Yes": 1, "No": 0},
|
378 |
+
'YMIMI5YANY': {"Yes": 1, "No": 0},
|
379 |
+
'YUSUIPLN': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
|
380 |
+
'YTXMDEYR': {"Yes": 1, "No": 0},
|
381 |
+
'YMDEAUD5YR': {"Yes": 1, "No": 0},
|
382 |
+
'YRXMDEYR': {"Yes": 1, "No": 0},
|
383 |
+
'YMDELT': {"Yes": 1, "No": 2}
|
384 |
+
}
|
385 |
+
|
386 |
+
# Build the reverse mapping for each column
|
387 |
+
for col, fwd_map in input_mapping.items():
|
388 |
+
reverse_input_mapping[col] = {v: k for k, v in fwd_map.items()}
|
389 |
+
|
390 |
+
# 3. Calculate Hamming distance for each row
|
391 |
+
# We'll consider the columns in user_input for comparison
|
392 |
+
features_to_compare = list(user_input.columns)
|
393 |
+
subset_df = df[features_to_compare].copy()
|
394 |
user_series = user_input.iloc[0]
|
395 |
|
|
|
396 |
distances = []
|
397 |
+
for idx, row in subset_df.iterrows():
|
398 |
+
dist = sum(row[col] != user_series[col] for col in features_to_compare)
|
399 |
+
distances.append(dist)
|
|
|
|
|
|
|
400 |
|
401 |
df_with_dist = df.copy()
|
402 |
df_with_dist["distance"] = distances
|
403 |
|
404 |
+
# 4. Sort by distance ascending, pick top K=5
|
405 |
K = 5
|
406 |
nearest_neighbors = df_with_dist.sort_values("distance", ascending=True).head(K)
|
407 |
|
408 |
+
# 5. Summarize neighbor info in user-friendly text
|
409 |
+
# For demonstration, let's show a small table with each neighbor's values
|
410 |
+
# for the same features. We'll also show a label or two.
|
411 |
+
# We'll do this in Markdown format.
|
412 |
+
nn_rows = []
|
413 |
+
for idx, nr in nearest_neighbors.iterrows():
|
414 |
+
# Convert each feature to text if possible
|
415 |
+
row_text = []
|
416 |
+
for col in features_to_compare:
|
417 |
+
val_numeric = nr[col]
|
418 |
+
if col in reverse_input_mapping:
|
419 |
+
row_text.append(f"{col}={reverse_input_mapping[col].get(val_numeric, val_numeric)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
420 |
else:
|
421 |
+
row_text.append(f"{col}={val_numeric}")
|
422 |
+
# Let's also show YOWRCONC as an example label (if present)
|
423 |
+
if "YOWRCONC" in nearest_neighbors.columns:
|
424 |
+
label_val = nr["YOWRCONC"]
|
425 |
+
if "YOWRCONC" in reverse_label_mapping:
|
426 |
+
label_str = reverse_label_mapping["YOWRCONC"].get(label_val, label_val)
|
427 |
+
row_text.append(f"YOWRCONC={label_str}")
|
428 |
else:
|
429 |
+
row_text.append(f"YOWRCONC={label_val}")
|
|
|
|
|
430 |
|
431 |
+
nn_rows.append(f"- **Neighbor ID {idx}** (distance={nr['distance']}): " + ", ".join(row_text))
|
432 |
|
433 |
similar_patient_markdown = (
|
434 |
"### Nearest Neighbors (Simple Hamming Distance)\n"
|
435 |
+
f"We searched for the top **{K}** patients whose features most closely match your input.\n\n"
|
436 |
+
"> **Note**: “Nearest neighbor” methods for high-dimensional or purely categorical data can be non-trivial. "
|
437 |
+
"This demo simply uses a Hamming distance over all input features and picks K=5 neighbors. "
|
438 |
"In a real application, you would refine which features are most relevant, how to encode them, "
|
439 |
"and how many neighbors to select.\n\n"
|
440 |
+
"Below is a brief overview of each neighbor's input-feature values and one example label (`YOWRCONC`).\n\n"
|
441 |
+
+ "\n".join(nn_rows)
|
|
|
|
|
|
|
442 |
)
|
443 |
|
444 |
+
# F) Co-occurrence Plot from the previous example (kept for completeness)
|
445 |
+
if all(col in df.columns for col in ["YMDEYR", "YMDERSUD5ANY", "YOWRCONC"]):
|
446 |
+
co_occ_data = df.groupby(["YMDEYR", "YMDERSUD5ANY", "YOWRCONC"]).size().reset_index(name="count")
|
447 |
+
fig_co_occ = px.bar(
|
448 |
+
co_occ_data,
|
449 |
+
x="YMDEYR",
|
450 |
+
y="count",
|
451 |
+
color="YOWRCONC",
|
452 |
+
facet_col="YMDERSUD5ANY",
|
453 |
+
title="Co-Occurrence Plot: YMDEYR and YMDERSUD5ANY vs YOWRCONC"
|
454 |
+
)
|
455 |
+
else:
|
456 |
+
fig_co_occ = px.bar(title="Co-occurrence plot not available (check columns).")
|
457 |
+
|
458 |
+
# =======================
|
459 |
+
# RETURN EVERYTHING
|
460 |
+
# We have 8 outputs:
|
461 |
+
# 1) Prediction Results (Textbox)
|
462 |
+
# 2) Mental Health Severity (Textbox)
|
463 |
+
# 3) Total Patient Count (Markdown)
|
464 |
+
# 4) Distribution Plot (for multiple input features vs. multiple labels)
|
465 |
+
# 5) Nearest Neighbors Summary (Markdown)
|
466 |
+
# 6) Co-Occurrence Plot
|
467 |
+
# 7) Bar Chart for input features
|
468 |
+
# 8) Bar Chart for predicted labels
|
469 |
+
# =======================
|
470 |
return (
|
471 |
+
formatted_results,
|
472 |
+
severity,
|
473 |
+
total_patient_count_markdown,
|
474 |
+
fig_dist,
|
475 |
+
similar_patient_markdown,
|
476 |
+
fig_co_occ,
|
477 |
+
fig_bar_input,
|
478 |
+
fig_bar_labels
|
479 |
)
|
480 |
|
481 |
+
######################################
|
482 |
+
# 5) MAPPING user-friendly text => numeric
|
483 |
+
######################################
|
484 |
+
input_mapping = {
|
485 |
+
'YNURSMDE': {"Yes": 1, "No": 0},
|
486 |
+
'YMDEYR': {"Yes": 1, "No": 2},
|
487 |
+
'YSOCMDE': {"Yes": 1, "No": 0},
|
488 |
+
'YMDESUD5ANYO': {"SUD only, no MDE": 1, "MDE only, no SUD": 2, "SUD and MDE": 3, "Neither SUD or MDE": 4},
|
489 |
+
'YMSUD5YANY': {"Yes": 1, "No": 0},
|
490 |
+
'YUSUITHK': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
|
491 |
+
'YMDETXRX': {"Yes": 1, "No": 0},
|
492 |
+
'YUSUITHKYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
|
493 |
+
'YMDERSUD5ANY': {"Yes": 1, "No": 0},
|
494 |
+
'YUSUIPLNYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
|
495 |
+
'YCOUNMDE': {"Yes": 1, "No": 0},
|
496 |
+
'YPSY1MDE': {"Yes": 1, "No": 0},
|
497 |
+
'YHLTMDE': {"Yes": 1, "No": 0},
|
498 |
+
'YDOCMDE': {"Yes": 1, "No": 0},
|
499 |
+
'YPSY2MDE': {"Yes": 1, "No": 0},
|
500 |
+
'YMDEHARX': {"Yes": 1, "No": 0},
|
501 |
+
'LVLDIFMEM2': {"No Difficulty": 1, "Some difficulty": 2, "A lot of difficulty or cannot do at all": 3},
|
502 |
+
'MDEIMPY': {"Yes": 1, "No": 2},
|
503 |
+
'YMDEHPO': {"Yes": 1, "No": 0},
|
504 |
+
'YMIMS5YANY': {"Yes": 1, "No": 0},
|
505 |
+
'YMDEIMAD5YR': {"Yes": 1, "No": 0},
|
506 |
+
'YMIUD5YANY': {"Yes": 1, "No": 0},
|
507 |
+
'YMDEHPRX': {"Yes": 1, "No": 0},
|
508 |
+
'YMIMI5YANY': {"Yes": 1, "No": 0},
|
509 |
+
'YUSUIPLN': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
|
510 |
+
'YTXMDEYR': {"Yes": 1, "No": 0},
|
511 |
+
'YMDEAUD5YR': {"Yes": 1, "No": 0},
|
512 |
+
'YRXMDEYR': {"Yes": 1, "No": 0},
|
513 |
+
'YMDELT': {"Yes": 1, "No": 2}
|
514 |
+
}
|
515 |
+
|
516 |
+
######################################
|
517 |
+
# 6) GRADIO INTERFACE
|
518 |
+
######################################
|
519 |
+
# We have 8 outputs in total:
|
520 |
+
# 1) Prediction Results
|
521 |
+
# 2) Mental Health Severity
|
522 |
+
# 3) Total Patient Count
|
523 |
+
# 4) Distribution Plot
|
524 |
+
# 5) Nearest Neighbors
|
525 |
+
# 6) Co-Occurrence Plot
|
526 |
+
# 7) Bar Chart for input features
|
527 |
+
# 8) Bar Chart for predicted labels
|
528 |
+
|
529 |
+
import gradio as gr
|
530 |
+
|
531 |
+
# Define the inputs in the same order as function signature
|
532 |
+
inputs = [
|
533 |
+
gr.Dropdown(list(input_mapping['YMDEYR'].keys()), label="YMDEYR: PAST YEARS MAJOR DEPRESSIVE EPISODE"),
|
534 |
+
gr.Dropdown(list(input_mapping['YMDERSUD5ANY'].keys()), label="YMDERSUD5ANY: MDE OR SUBSTANCE USE DISORDER - ANY"),
|
535 |
+
gr.Dropdown(list(input_mapping['YMDEIMAD5YR'].keys()), label="YMDEIMAD5YR: MDE WITH SEV. IMP + ALCOHOL USE DISORDER"),
|
536 |
+
gr.Dropdown(list(input_mapping['YMIMS5YANY'].keys()), label="YMIMS5YANY: MDE W/ SEV. IMP + SUBSTANCE USE DISORDER"),
|
537 |
+
gr.Dropdown(list(input_mapping['YMDELT'].keys()), label="YMDELT: HAD MAJOR DEPRESSIVE EPISODE IN LIFETIME"),
|
538 |
+
gr.Dropdown(list(input_mapping['YMDEHARX'].keys()), label="YMDEHARX: SAW HEALTH PROF + MEDS FOR MDE"),
|
539 |
+
gr.Dropdown(list(input_mapping['YMDEHPRX'].keys()), label="YMDEHPRX: SAW HEALTH PROF OR MEDS FOR MDE"),
|
540 |
+
gr.Dropdown(list(input_mapping['YMDETXRX'].keys()), label="YMDETXRX: RECEIVED TREATMENT/COUNSELING FOR MDE"),
|
541 |
+
gr.Dropdown(list(input_mapping['YMDEHPO'].keys()), label="YMDEHPO: SAW HEALTH PROF ONLY FOR MDE"),
|
542 |
+
gr.Dropdown(list(input_mapping['YMDEAUD5YR'].keys()), label="YMDEAUD5YR: MDE + ALCOHOL USE DISORDER"),
|
543 |
+
gr.Dropdown(list(input_mapping['YMIMI5YANY'].keys()), label="YMIMI5YANY: MDE W/ ILL DRUG USE DISORDER"),
|
544 |
+
gr.Dropdown(list(input_mapping['YMIUD5YANY'].keys()), label="YMIUD5YANY: MDE + ILL DRUG USE DISORDER"),
|
545 |
+
gr.Dropdown(list(input_mapping['YMDESUD5ANYO'].keys()), label="YMDESUD5ANYO: MDE vs. SUD vs. BOTH vs. NEITHER"),
|
546 |
+
|
547 |
+
# Consultations
|
548 |
+
gr.Dropdown(list(input_mapping['YNURSMDE'].keys()), label="YNURSMDE: SAW/TALK TO NURSE/OT ABOUT MDE"),
|
549 |
+
gr.Dropdown(list(input_mapping['YSOCMDE'].keys()), label="YSOCMDE: SAW/TALK TO SOCIAL WORKER ABOUT MDE"),
|
550 |
+
gr.Dropdown(list(input_mapping['YCOUNMDE'].keys()), label="YCOUNMDE: SAW/TALK TO COUNSELOR ABOUT MDE"),
|
551 |
+
gr.Dropdown(list(input_mapping['YPSY1MDE'].keys()), label="YPSY1MDE: SAW/TALK TO PSYCHOLOGIST ABOUT MDE"),
|
552 |
+
gr.Dropdown(list(input_mapping['YPSY2MDE'].keys()), label="YPSY2MDE: SAW/TALK TO PSYCHIATRIST ABOUT MDE"),
|
553 |
+
gr.Dropdown(list(input_mapping['YHLTMDE'].keys()), label="YHLTMDE: SAW/TALK TO HEALTH PROFESSIONAL ABOUT MDE"),
|
554 |
+
gr.Dropdown(list(input_mapping['YDOCMDE'].keys()), label="YDOCMDE: SAW/TALK TO GP/FAMILY MD ABOUT MDE"),
|
555 |
+
gr.Dropdown(list(input_mapping['YTXMDEYR'].keys()), label="YTXMDEYR: SAW/TALK DOCTOR/HEALTH PROF FOR MDE"),
|
556 |
+
|
557 |
+
# Suicidal thoughts/plans
|
558 |
+
gr.Dropdown(list(input_mapping['YUSUITHKYR'].keys()), label="YUSUITHKYR: SERIOUSLY THOUGHT ABOUT KILLING SELF"),
|
559 |
+
gr.Dropdown(list(input_mapping['YUSUIPLNYR'].keys()), label="YUSUIPLNYR: MADE PLANS TO KILL SELF"),
|
560 |
+
gr.Dropdown(list(input_mapping['YUSUITHK'].keys()), label="YUSUITHK: THINK ABOUT KILLING SELF (12 MONTHS)"),
|
561 |
+
gr.Dropdown(list(input_mapping['YUSUIPLN'].keys()), label="YUSUIPLN: MADE PLANS TO KILL SELF (12 MONTHS)"),
|
562 |
+
|
563 |
+
# Impairments
|
564 |
+
gr.Dropdown(list(input_mapping['MDEIMPY'].keys()), label="MDEIMPY: MDE W/ SEVERE ROLE IMPAIRMENT"),
|
565 |
+
gr.Dropdown(list(input_mapping['LVLDIFMEM2'].keys()), label="LVLDIFMEM2: LEVEL OF DIFFICULTY REMEMBERING/CONCENTRATING"),
|
566 |
+
gr.Dropdown(list(input_mapping['YMSUD5YANY'].keys()), label="YMSUD5YANY: MDE + SUBSTANCE USE DISORDER - ANY"),
|
567 |
+
gr.Dropdown(list(input_mapping['YRXMDEYR'].keys()), label="YRXMDEYR: USED MEDS FOR MDE IN PAST YEAR"),
|
568 |
+
]
|
569 |
+
|
570 |
+
# The 8 outputs
|
571 |
+
outputs = [
|
572 |
+
gr.Textbox(label="Prediction Results", lines=30),
|
573 |
+
gr.Textbox(label="Mental Health Severity", lines=4),
|
574 |
+
gr.Markdown(label="Total Patient Count"),
|
575 |
+
gr.Plot(label="Distribution Plot (Sample of Features & Labels)"),
|
576 |
+
gr.Markdown(label="Nearest Neighbors Summary"),
|
577 |
+
gr.Plot(label="Co-Occurrence Plot"),
|
578 |
+
gr.Plot(label="Number of Patients per Input Feature"),
|
579 |
+
gr.Plot(label="Number of Patients with Predicted Labels")
|
580 |
+
]
|
581 |
+
|
582 |
+
######################################
|
583 |
+
# 7) WRAPPER FOR PREDICT
|
584 |
+
######################################
|
585 |
def predict_with_text(
|
586 |
YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
|
587 |
YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
|
|
|
589 |
YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
|
590 |
YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
|
591 |
):
|
592 |
+
# Validate user inputs
|
593 |
if not validate_inputs(
|
594 |
YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
|
595 |
YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
|
|
|
598 |
YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
|
599 |
):
|
600 |
return (
|
601 |
+
"Please select all required fields.",
|
602 |
+
"Validation Error",
|
603 |
+
"No data",
|
604 |
+
None,
|
605 |
+
"No data",
|
606 |
+
None,
|
607 |
+
None,
|
608 |
+
None
|
609 |
)
|
610 |
|
611 |
+
# Map user-friendly text to numeric
|
612 |
user_inputs = {
|
613 |
'YNURSMDE': input_mapping['YNURSMDE'][YNURSMDE],
|
614 |
'YMDEYR': input_mapping['YMDEYR'][YMDEYR],
|
|
|
640 |
'YRXMDEYR': input_mapping['YRXMDEYR'][YRXMDEYR],
|
641 |
'YMDELT': input_mapping['YMDELT'][YMDELT]
|
642 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
643 |
|
644 |
+
# Pass these mapped values into the core predict function
|
645 |
+
return predict(**user_inputs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
646 |
|
647 |
+
# Optional custom CSS
|
648 |
custom_css = """
|
649 |
.gradio-container * {
|
650 |
color: #1B1212 !important;
|
|
|
662 |
}
|
663 |
"""
|
664 |
|
665 |
+
######################################
|
666 |
+
# 8) LAUNCH
|
667 |
+
######################################
|
668 |
interface = gr.Interface(
|
669 |
+
fn=predict_with_text,
|
670 |
+
inputs=inputs,
|
671 |
+
outputs=outputs,
|
672 |
+
title="Adolescents with Substance Use Mental Health Screening (NSDUH Data)",
|
673 |
+
css=custom_css
|
674 |
)
|
675 |
|
676 |
if __name__ == "__main__":
|