Update app.py
Browse files
app.py
CHANGED
@@ -8,13 +8,17 @@ import plotly.express as px
|
|
8 |
# Load the training CSV once (outside the functions so it is read only once).
|
9 |
df = pd.read_csv("X_train_Y_Train_merged_train.csv")
|
10 |
|
|
|
|
|
|
|
11 |
class ModelPredictor:
|
12 |
def __init__(self, model_path, model_filenames):
|
13 |
self.model_path = model_path
|
14 |
self.model_filenames = model_filenames
|
15 |
self.models = self.load_models()
|
16 |
-
|
17 |
-
#
|
|
|
18 |
self.prediction_map = {
|
19 |
"YOWRCONC": ["Did not have difficulty concentrating", "Had difficulty concentrating"],
|
20 |
"YOSEEDOC": ["Did not feel the need to see a doctor", "Felt the need to see a doctor"],
|
@@ -91,7 +95,9 @@ class ModelPredictor:
|
|
91 |
else:
|
92 |
return "Mental health severity: Very Low"
|
93 |
|
94 |
-
|
|
|
|
|
95 |
model_filenames = [
|
96 |
"YOWRCONC.pkl", "YOSEEDOC.pkl", "YO_MDEA5.pkl", "YOWRLSIN.pkl",
|
97 |
"YODPPROB.pkl", "YOWRPROB.pkl", "YODPR2WK.pkl", "YOWRDEPR.pkl",
|
@@ -101,12 +107,60 @@ model_filenames = [
|
|
101 |
model_path = "models/"
|
102 |
predictor = ModelPredictor(model_path, model_filenames)
|
103 |
|
|
|
|
|
|
|
104 |
def validate_inputs(*args):
    """Return True only when every argument has been filled in.

    An argument counts as unselected when it equals the empty string or is
    ``None``; any other value (including ``0`` and ``False``) is accepted.
    Called with no arguments, the result is ``True``.
    """
    # Same per-argument test as before (`== ''` first, then `is None`),
    # expressed with all() instead of a manual loop with early return.
    return all(not (arg == '' or arg is None) for arg in args)
|
109 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
def predict(
|
111 |
YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
|
112 |
YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
|
@@ -120,15 +174,11 @@ def predict(
|
|
120 |
2) Aggregates results
|
121 |
3) Produces an overall 'severity'
|
122 |
4) Returns detailed per-model predictions
|
123 |
-
5)
|
124 |
-
6)
|
125 |
-
- Total patient count (markdown)
|
126 |
-
- Cross-tab & grouped bar chart
|
127 |
-
- Similar Patient (Nearest Neighbors)
|
128 |
-
- Co-occurrence plot
|
129 |
"""
|
130 |
|
131 |
-
# Prepare user_input dataframe
|
132 |
user_input_data = {
|
133 |
'YNURSMDE': [int(YNURSMDE)],
|
134 |
'YMDEYR': [int(YMDEYR)],
|
@@ -162,21 +212,20 @@ def predict(
|
|
162 |
}
|
163 |
user_input = pd.DataFrame(user_input_data)
|
164 |
|
165 |
-
#
|
166 |
-
# 1) Make predictions
|
167 |
-
# -----------------------
|
168 |
predictions = predictor.make_predictions(user_input)
|
169 |
|
170 |
-
#
|
171 |
majority_vote = predictor.get_majority_vote(predictions)
|
172 |
|
173 |
-
#
|
174 |
majority_vote_count = sum([1 for pred in np.concatenate(predictions) if pred == 1])
|
175 |
|
176 |
-
#
|
177 |
severity = predictor.evaluate_severity(majority_vote_count)
|
178 |
|
179 |
-
#
|
|
|
180 |
results = {
|
181 |
"Concentration_and_Decision_Making": [],
|
182 |
"Sleep_and_Energy_Levels": [],
|
@@ -196,17 +245,15 @@ def predict(
|
|
196 |
"YOPB2WK"]
|
197 |
}
|
198 |
|
|
|
199 |
for i, pred in enumerate(predictions):
|
200 |
-
model_name = model_filenames[i].split('.')[0]
|
201 |
pred_value = pred[0]
|
202 |
# Map the prediction value to a human-readable string
|
203 |
if model_name in predictor.prediction_map and pred_value in [0, 1]:
|
204 |
result_text = f"Model {model_name}: {predictor.prediction_map[model_name][pred_value]}"
|
205 |
-
elif model_name in predictor.prediction_map:
|
206 |
-
# Out of known range => "Unknown"
|
207 |
-
result_text = f"Model {model_name}: Unknown prediction value {pred_value}"
|
208 |
else:
|
209 |
-
result_text = f"Model {model_name}: Unknown
|
210 |
|
211 |
# Append to the appropriate group
|
212 |
found_group = False
|
@@ -216,318 +263,195 @@ def predict(
|
|
216 |
found_group = True
|
217 |
break
|
218 |
if not found_group:
|
219 |
-
# If
|
220 |
pass
|
221 |
|
222 |
-
#
|
223 |
formatted_results = []
|
224 |
for group, preds in results.items():
|
225 |
if preds:
|
226 |
formatted_results.append(f"Group {group.replace('_', ' ')}:")
|
227 |
formatted_results.append("\n".join(preds))
|
228 |
formatted_results.append("\n")
|
229 |
-
|
230 |
formatted_results = "\n".join(formatted_results).strip()
|
231 |
-
|
232 |
if len(formatted_results) == 0:
|
233 |
formatted_results = "No predictions made. Please check your inputs."
|
234 |
-
|
235 |
-
# Heuristic: if too many unknown predictions, append note
|
236 |
-
num_unknown = len([
|
237 |
-
pred for group, preds in results.items()
|
238 |
-
for pred in preds if "Unknown prediction value" in pred or "Unknown model" in pred
|
239 |
-
])
|
240 |
-
if num_unknown > len(model_filenames) / 2:
|
241 |
-
severity += " (Unknown prediction count is high. Please consult with a human.)"
|
242 |
|
243 |
-
#
|
244 |
-
|
245 |
-
|
|
|
246 |
|
247 |
-
|
|
|
|
|
248 |
total_patients = len(df)
|
249 |
total_patient_count_markdown = (
|
250 |
"### Total Patient Count\n"
|
251 |
f"There are **{total_patients}** total patients in the dataset.\n\n"
|
252 |
-
"This
|
253 |
-
"All subsequent analyses are relative to these patients."
|
254 |
-
)
|
255 |
-
|
256 |
-
# B) Analyze Each Input Feature
|
257 |
-
# For each feature in user_input, compute how many patients have that same value.
|
258 |
-
input_counts = {}
|
259 |
-
for col in user_input_data.keys():
|
260 |
-
val = user_input_data[col][0]
|
261 |
-
same_val_count = len(df[df[col] == val])
|
262 |
-
input_counts[col] = same_val_count
|
263 |
-
|
264 |
-
# Plot: Bar Chart for each input feature
|
265 |
-
bar_input_data = pd.DataFrame({
|
266 |
-
"Feature": list(input_counts.keys()),
|
267 |
-
"Count": list(input_counts.values())
|
268 |
-
})
|
269 |
-
fig_bar_input = px.bar(
|
270 |
-
bar_input_data,
|
271 |
-
x="Feature",
|
272 |
-
y="Count",
|
273 |
-
title="Number of Patients with the Same Value for Each Input Feature",
|
274 |
-
labels={"Feature": "Input Feature", "Count": "Number of Patients"}
|
275 |
)
|
276 |
-
fig_bar_input.update_layout(xaxis={'categoryorder':'total descending'})
|
277 |
|
278 |
-
|
279 |
-
#
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
title="Number of Patients with the Predicted Label (0 or 1) by Model",
|
298 |
-
labels={"Model": "Predicted Column", "Count": "Number of Patients"}
|
299 |
-
)
|
300 |
-
fig_bar_labels.update_layout(xaxis={'categoryorder':'total descending'})
|
301 |
-
else:
|
302 |
-
# If everything was unknown, produce an empty figure or a fallback message
|
303 |
-
bar_label_data = pd.DataFrame({"Model": [], "Count": []})
|
304 |
-
fig_bar_labels = px.bar(
|
305 |
-
bar_label_data,
|
306 |
-
x="Model",
|
307 |
-
y="Count",
|
308 |
-
title="No valid predicted labels to display"
|
309 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
310 |
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
# you might adapt accordingly.
|
316 |
-
if "YMDEYR" in df.columns and "YOWRCONC" in df.columns:
|
317 |
-
cross_tab_data = df.groupby(["YMDEYR", "YOWRCONC"]).size().reset_index(name="count")
|
318 |
-
fig_cross_tab = px.bar(
|
319 |
-
cross_tab_data,
|
320 |
-
x="YMDEYR",
|
321 |
y="count",
|
322 |
-
color="
|
323 |
-
|
324 |
-
|
325 |
-
|
|
|
326 |
)
|
|
|
|
|
|
|
327 |
else:
|
328 |
-
#
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
#
|
333 |
-
|
334 |
-
#
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
return dist
|
341 |
-
|
342 |
-
# Create a single row for easy iteration
|
343 |
user_series = user_input.iloc[0]
|
344 |
|
345 |
-
#
|
346 |
-
# that were used in the user_input.
|
347 |
-
# NOTE: In real usage, confirm these columns exist in df.
|
348 |
-
# If df lacks them or is encoded differently, you'd adapt.
|
349 |
-
features_to_compare = list(user_input.columns)
|
350 |
-
# For Hamming, ensure we pick only the columns present in df
|
351 |
-
features_to_compare = [f for f in features_to_compare if f in df.columns]
|
352 |
-
|
353 |
-
# Build a DataFrame we can safely compare
|
354 |
-
subset_df = df[features_to_compare].copy()
|
355 |
-
|
356 |
-
# Calculate distances
|
357 |
distances = []
|
358 |
-
for idx, row in
|
359 |
d = 0
|
360 |
for col in features_to_compare:
|
361 |
if row[col] != user_series[col]:
|
362 |
d += 1
|
363 |
distances.append(d)
|
364 |
|
365 |
-
# Attach distances
|
366 |
df_with_dist = df.copy()
|
367 |
df_with_dist["distance"] = distances
|
368 |
|
369 |
-
# Sort
|
370 |
K = 5
|
371 |
nearest_neighbors = df_with_dist.sort_values("distance", ascending=True).head(K)
|
372 |
|
373 |
-
#
|
374 |
-
|
375 |
-
if
|
376 |
-
|
377 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
378 |
|
379 |
-
# Summarize in markdown
|
380 |
similar_patient_markdown = (
|
381 |
"### Nearest Neighbors (Simple Hamming Distance)\n"
|
382 |
-
|
383 |
-
"
|
384 |
-
"
|
385 |
-
|
386 |
-
f"
|
387 |
-
"
|
388 |
-
"
|
|
|
|
|
389 |
)
|
390 |
|
391 |
-
|
392 |
-
#
|
393 |
-
|
394 |
-
if all(col in df.columns for col in ["YMDEYR", "YMDERSUD5ANY", "YOWRCONC"]):
|
395 |
-
co_occ_data = df.groupby(["YMDEYR", "YMDERSUD5ANY", "YOWRCONC"]).size().reset_index(name="count")
|
396 |
-
fig_co_occ = px.bar(
|
397 |
-
co_occ_data,
|
398 |
-
x="YMDEYR",
|
399 |
-
y="count",
|
400 |
-
color="YOWRCONC",
|
401 |
-
facet_col="YMDERSUD5ANY",
|
402 |
-
title="Co-Occurrence Plot: YMDEYR and YMDERSUD5ANY vs YOWRCONC"
|
403 |
-
)
|
404 |
-
else:
|
405 |
-
fig_co_occ = px.bar(title="Co-occurrence plot not available (columns not found).")
|
406 |
-
|
407 |
-
# ------------------------
|
408 |
-
# Return everything
|
409 |
-
# ------------------------
|
410 |
-
# We now have 8 items to return:
|
411 |
-
# 1) Prediction Results (Textbox)
|
412 |
-
# 2) Mental Health Severity (Textbox)
|
413 |
-
# 3) Total Patient Count (Markdown)
|
414 |
-
# 4) Cross-Tab & Grouped Bar Chart (Plot)
|
415 |
-
# 5) Nearest Neighbors Summary (Markdown)
|
416 |
-
# 6) Co-Occurrence Plot (Plot)
|
417 |
-
# 7) Bar Chart for input features (Plot)
|
418 |
-
# 8) Bar Chart for predicted labels (Plot)
|
419 |
return (
|
420 |
-
formatted_results,
|
421 |
-
severity,
|
422 |
-
total_patient_count_markdown,
|
423 |
-
|
424 |
-
similar_patient_markdown,
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
)
|
429 |
|
430 |
-
|
431 |
-
#
|
432 |
-
|
433 |
-
input_mapping = {
|
434 |
-
'YNURSMDE': {"Yes": 1, "No": 0},
|
435 |
-
'YMDEYR': {"Yes": 1, "No": 2},
|
436 |
-
'YSOCMDE': {"Yes": 1, "No": 0},
|
437 |
-
'YMDESUD5ANYO': {"SUD only, no MDE": 1, "MDE only, no SUD": 2, "SUD and MDE": 3, "Neither SUD or MDE": 4},
|
438 |
-
'YMSUD5YANY': {"Yes": 1, "No": 0},
|
439 |
-
'YUSUITHK': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
|
440 |
-
'YMDETXRX': {"Yes": 1, "No": 0},
|
441 |
-
'YUSUITHKYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
|
442 |
-
'YMDERSUD5ANY': {"Yes": 1, "No": 0},
|
443 |
-
'YUSUIPLNYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
|
444 |
-
'YCOUNMDE': {"Yes": 1, "No": 0},
|
445 |
-
'YPSY1MDE': {"Yes": 1, "No": 0},
|
446 |
-
'YHLTMDE': {"Yes": 1, "No": 0},
|
447 |
-
'YDOCMDE': {"Yes": 1, "No": 0},
|
448 |
-
'YPSY2MDE': {"Yes": 1, "No": 0},
|
449 |
-
'YMDEHARX': {"Yes": 1, "No": 0},
|
450 |
-
'LVLDIFMEM2': {"No Difficulty": 1, "Some difficulty": 2, "A lot of difficulty or cannot do at all": 3},
|
451 |
-
'MDEIMPY': {"Yes": 1, "No": 2},
|
452 |
-
'YMDEHPO': {"Yes": 1, "No": 0},
|
453 |
-
'YMIMS5YANY': {"Yes": 1, "No": 0},
|
454 |
-
'YMDEIMAD5YR': {"Yes": 1, "No": 0},
|
455 |
-
'YMIUD5YANY': {"Yes": 1, "No": 0},
|
456 |
-
'YMDEHPRX': {"Yes": 1, "No": 0},
|
457 |
-
'YMIMI5YANY': {"Yes": 1, "No": 0},
|
458 |
-
'YUSUIPLN': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
|
459 |
-
'YTXMDEYR': {"Yes": 1, "No": 0},
|
460 |
-
'YMDEAUD5YR': {"Yes": 1, "No": 0},
|
461 |
-
'YRXMDEYR': {"Yes": 1, "No": 0},
|
462 |
-
'YMDELT': {"Yes": 1, "No": 2}
|
463 |
-
}
|
464 |
-
|
465 |
-
# -----------------------------------------------------------------------------
|
466 |
-
# Create the Gradio interface
|
467 |
-
# -----------------------------------------------------------------------------
|
468 |
-
# We have 8 outputs now:
|
469 |
-
# 1) Prediction Results (Textbox)
|
470 |
-
# 2) Mental Health Severity (Textbox)
|
471 |
-
# 3) Total Patient Count (Markdown)
|
472 |
-
# 4) Cross-Tab & Grouped Bar Chart (Plot)
|
473 |
-
# 5) Nearest Neighbors Summary (Markdown)
|
474 |
-
# 6) Co-Occurrence Plot (Plot)
|
475 |
-
# 7) Bar Chart for input features (Plot)
|
476 |
-
# 8) Bar Chart for predicted labels (Plot)
|
477 |
-
|
478 |
-
# Define the "inputs" in the same order used in the function signature
|
479 |
-
inputs = [
|
480 |
-
################# Ordered and grouped ##########################
|
481 |
-
# Questions related to Major Depressive Episode (MDE) and related impairments or disorders
|
482 |
-
gr.Dropdown(list(input_mapping['YMDEYR'].keys()), label="YMDEYR: PAST YEARS MAJOR DEPRESSIVE EPISODE"),
|
483 |
-
gr.Dropdown(list(input_mapping['YMDERSUD5ANY'].keys()), label="YMDERSUD5ANY: MDE OR SUBSTANCE USE DISORDER - ANY"),
|
484 |
-
gr.Dropdown(list(input_mapping['YMDEIMAD5YR'].keys()), label="YMDEIMAD5YR: MDE WITH SEV. IMP + ALCOHOL USE DISORDER"),
|
485 |
-
gr.Dropdown(list(input_mapping['YMIMS5YANY'].keys()), label="YMIMS5YANY: MDE W/ SEV. IMP + SUBSTANCE USE DISORDER"),
|
486 |
-
gr.Dropdown(list(input_mapping['YMDELT'].keys()), label="YMDELT: HAD MAJOR DEPRESSIVE EPISODE IN LIFETIME"),
|
487 |
-
gr.Dropdown(list(input_mapping['YMDEHARX'].keys()), label="YMDEHARX: SAW HEALTH PROF + MEDS FOR MDE"),
|
488 |
-
gr.Dropdown(list(input_mapping['YMDEHPRX'].keys()), label="YMDEHPRX: SAW HEALTH PROF OR MEDS FOR MDE"),
|
489 |
-
gr.Dropdown(list(input_mapping['YMDETXRX'].keys()), label="YMDETXRX: RECEIVED TREATMENT/COUNSELING FOR MDE"),
|
490 |
-
gr.Dropdown(list(input_mapping['YMDEHPO'].keys()), label="YMDEHPO: SAW HEALTH PROF ONLY FOR MDE"),
|
491 |
-
gr.Dropdown(list(input_mapping['YMDEAUD5YR'].keys()), label="YMDEAUD5YR: MDE + ALCOHOL USE DISORDER"),
|
492 |
-
gr.Dropdown(list(input_mapping['YMIMI5YANY'].keys()), label="YMIMI5YANY: MDE W/ ILL DRUG USE DISORDER"),
|
493 |
-
gr.Dropdown(list(input_mapping['YMIUD5YANY'].keys()), label="YMIUD5YANY: MDE + ILL DRUG USE DISORDER"),
|
494 |
-
gr.Dropdown(list(input_mapping['YMDESUD5ANYO'].keys()), label="YMDESUD5ANYO: MDE vs. SUD vs. BOTH vs. NEITHER"),
|
495 |
-
|
496 |
-
# Questions related to consultations with professionals about MDE
|
497 |
-
gr.Dropdown(list(input_mapping['YNURSMDE'].keys()), label="YNURSMDE: SAW/TALK TO NURSE/OT ABOUT MDE"),
|
498 |
-
gr.Dropdown(list(input_mapping['YSOCMDE'].keys()), label="YSOCMDE: SAW/TALK TO SOCIAL WORKER ABOUT MDE"),
|
499 |
-
gr.Dropdown(list(input_mapping['YCOUNMDE'].keys()), label="YCOUNMDE: SAW/TALK TO COUNSELOR ABOUT MDE"),
|
500 |
-
gr.Dropdown(list(input_mapping['YPSY1MDE'].keys()), label="YPSY1MDE: SAW/TALK TO PSYCHOLOGIST ABOUT MDE"),
|
501 |
-
gr.Dropdown(list(input_mapping['YPSY2MDE'].keys()), label="YPSY2MDE: SAW/TALK TO PSYCHIATRIST ABOUT MDE"),
|
502 |
-
gr.Dropdown(list(input_mapping['YHLTMDE'].keys()), label="YHLTMDE: SAW/TALK TO HEALTH PROFESSIONAL ABOUT MDE"),
|
503 |
-
gr.Dropdown(list(input_mapping['YDOCMDE'].keys()), label="YDOCMDE: SAW/TALK TO GP/FAMILY MD ABOUT MDE"),
|
504 |
-
gr.Dropdown(list(input_mapping['YTXMDEYR'].keys()), label="YTXMDEYR: SAW/TALK DOCTOR/HEALTH PROF FOR MDE"),
|
505 |
-
|
506 |
-
# Questions related to suicidal thoughts and plans
|
507 |
-
gr.Dropdown(list(input_mapping['YUSUITHKYR'].keys()), label="YUSUITHKYR: SERIOUSLY THOUGHT ABOUT KILLING SELF"),
|
508 |
-
gr.Dropdown(list(input_mapping['YUSUIPLNYR'].keys()), label="YUSUIPLNYR: MADE PLANS TO KILL SELF"),
|
509 |
-
gr.Dropdown(list(input_mapping['YUSUITHK'].keys()), label="YUSUITHK: THINK ABOUT KILLING SELF (12 MONTHS)"),
|
510 |
-
gr.Dropdown(list(input_mapping['YUSUIPLN'].keys()), label="YUSUIPLN: MADE PLANS TO KILL SELF (12 MONTHS)"),
|
511 |
-
|
512 |
-
# Questions related to impairment due to MDE
|
513 |
-
gr.Dropdown(list(input_mapping['MDEIMPY'].keys()), label="MDEIMPY: MDE W/ SEVERE ROLE IMPAIRMENT"),
|
514 |
-
gr.Dropdown(list(input_mapping['LVLDIFMEM2'].keys()), label="LVLDIFMEM2: LEVEL OF DIFFICULTY REMEMBERING/CONCENTRATING"),
|
515 |
-
gr.Dropdown(list(input_mapping['YMSUD5YANY'].keys()), label="YMSUD5YANY: MDE + SUBSTANCE USE DISORDER - ANY"),
|
516 |
-
gr.Dropdown(list(input_mapping['YRXMDEYR'].keys()), label="YRXMDEYR: USED MEDS FOR MDE IN PAST YEAR"),
|
517 |
-
]
|
518 |
-
|
519 |
-
# We now have 8 outputs in total:
|
520 |
-
outputs = [
|
521 |
-
gr.Textbox(label="Prediction Results", lines=30),
|
522 |
-
gr.Textbox(label="Mental Health Severity", lines=4),
|
523 |
-
gr.Markdown(label="Total Patient Count"),
|
524 |
-
gr.Plot(label="Cross-Tab & Grouped Bar Chart"),
|
525 |
-
gr.Markdown(label="Nearest Neighbors Summary"),
|
526 |
-
gr.Plot(label="Co-Occurrence Plot"),
|
527 |
-
gr.Plot(label="Number of Patients per Input Feature"),
|
528 |
-
gr.Plot(label="Number of Patients with Predicted Labels")
|
529 |
-
]
|
530 |
-
|
531 |
def predict_with_text(
|
532 |
YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
|
533 |
YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
|
@@ -547,11 +471,9 @@ def predict_with_text(
|
|
547 |
"Please select all required fields.", # Prediction Results
|
548 |
"Validation Error", # Severity
|
549 |
"No data", # Total Patient Count
|
550 |
-
None, #
|
551 |
"No data", # Nearest Neighbors
|
552 |
-
None,
|
553 |
-
None, # Input Features Bar
|
554 |
-
None # Predicted Labels Bar
|
555 |
)
|
556 |
|
557 |
# Map from user-friendly text to int
|
@@ -590,6 +512,63 @@ def predict_with_text(
|
|
590 |
# Pass our mapped values into the original 'predict' function
|
591 |
return predict(**user_inputs)
|
592 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
593 |
# Custom CSS (optional)
|
594 |
custom_css = """
|
595 |
.gradio-container * {
|
@@ -608,13 +587,13 @@ custom_css = """
|
|
608 |
}
|
609 |
"""
|
610 |
|
611 |
-
#
|
612 |
interface = gr.Interface(
|
613 |
-
fn=predict_with_text,
|
614 |
-
inputs=inputs,
|
615 |
-
outputs=outputs,
|
616 |
-
title="Adolescents with Substance Use Mental Health Screening (NSDUH Data)",
|
617 |
-
css=custom_css
|
618 |
)
|
619 |
|
620 |
if __name__ == "__main__":
|
|
|
8 |
# Load the training CSV once (outside the functions so it is read only once).
|
9 |
df = pd.read_csv("X_train_Y_Train_merged_train.csv")
|
10 |
|
11 |
+
###############################################################################
|
12 |
+
# 1) Model Predictor class
|
13 |
+
###############################################################################
|
14 |
class ModelPredictor:
|
15 |
def __init__(self, model_path, model_filenames):
|
16 |
self.model_path = model_path
|
17 |
self.model_filenames = model_filenames
|
18 |
self.models = self.load_models()
|
19 |
+
|
20 |
+
# For each model name, define the mapping from 0->..., 1->...
|
21 |
+
# If you have more labels, expand this dictionary accordingly.
|
22 |
self.prediction_map = {
|
23 |
"YOWRCONC": ["Did not have difficulty concentrating", "Had difficulty concentrating"],
|
24 |
"YOSEEDOC": ["Did not feel the need to see a doctor", "Felt the need to see a doctor"],
|
|
|
95 |
else:
|
96 |
return "Mental health severity: Very Low"
|
97 |
|
98 |
+
###############################################################################
|
99 |
+
# 2) Model Filenames & Predictor
|
100 |
+
###############################################################################
|
101 |
model_filenames = [
|
102 |
"YOWRCONC.pkl", "YOSEEDOC.pkl", "YO_MDEA5.pkl", "YOWRLSIN.pkl",
|
103 |
"YODPPROB.pkl", "YOWRPROB.pkl", "YODPR2WK.pkl", "YOWRDEPR.pkl",
|
|
|
107 |
model_path = "models/"
|
108 |
predictor = ModelPredictor(model_path, model_filenames)
|
109 |
|
110 |
+
###############################################################################
|
111 |
+
# 3) Validate Inputs
|
112 |
+
###############################################################################
|
113 |
def validate_inputs(*args):
    """Return True only when every argument has been filled in.

    An argument counts as unselected when it equals the empty string or is
    ``None``; any other value (including ``0`` and ``False``) is accepted.
    Called with no arguments, the result is ``True``.
    """
    # Same per-argument test as before (`== ''` first, then `is None`),
    # expressed with all() instead of a manual loop with early return.
    return all(not (arg == '' or arg is None) for arg in args)
|
118 |
|
119 |
+
###############################################################################
|
120 |
+
# 4) Reverse Lookup (numeric -> user-friendly text) for input columns
|
121 |
+
###############################################################################
|
122 |
+
# We'll define the forward mapping here. The reverse mapping is constructed below.
|
123 |
+
# Forward mapping: input column name -> {user-facing text: numeric code}.
# Codes are defined per column (for example "No" is 0 for some columns and 2
# for others), so a value must always be looked up under its own column.
input_mapping = {
    'YNURSMDE': {"Yes": 1, "No": 0},
    'YMDEYR': {"Yes": 1, "No": 2},
    'YSOCMDE': {"Yes": 1, "No": 0},
    'YMDESUD5ANYO': {"SUD only, no MDE": 1, "MDE only, no SUD": 2, "SUD and MDE": 3, "Neither SUD or MDE": 4},
    'YMSUD5YANY': {"Yes": 1, "No": 0},
    'YUSUITHK': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
    'YMDETXRX': {"Yes": 1, "No": 0},
    'YUSUITHKYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
    'YMDERSUD5ANY': {"Yes": 1, "No": 0},
    'YUSUIPLNYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
    'YCOUNMDE': {"Yes": 1, "No": 0},
    'YPSY1MDE': {"Yes": 1, "No": 0},
    'YHLTMDE': {"Yes": 1, "No": 0},
    'YDOCMDE': {"Yes": 1, "No": 0},
    'YPSY2MDE': {"Yes": 1, "No": 0},
    'YMDEHARX': {"Yes": 1, "No": 0},
    'LVLDIFMEM2': {"No Difficulty": 1, "Some difficulty": 2, "A lot of difficulty or cannot do at all": 3},
    'MDEIMPY': {"Yes": 1, "No": 2},
    'YMDEHPO': {"Yes": 1, "No": 0},
    'YMIMS5YANY': {"Yes": 1, "No": 0},
    'YMDEIMAD5YR': {"Yes": 1, "No": 0},
    'YMIUD5YANY': {"Yes": 1, "No": 0},
    'YMDEHPRX': {"Yes": 1, "No": 0},
    'YMIMI5YANY': {"Yes": 1, "No": 0},
    'YUSUIPLN': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
    'YTXMDEYR': {"Yes": 1, "No": 0},
    'YMDEAUD5YR': {"Yes": 1, "No": 0},
    'YRXMDEYR': {"Yes": 1, "No": 0},
    'YMDELT': {"Yes": 1, "No": 2}
}

# Build reverse mapping: { "YNURSMDE": {1: "Yes", 0: "No"}, ... } etc.
# Each inner dict is inverted on its own, so a code is only ever translated
# back to text within the column it belongs to.
reverse_mapping = {}
for col, mapping_dict in input_mapping.items():
    rev = {v: k for k, v in mapping_dict.items()}  # invert dict
    reverse_mapping[col] = rev
|
160 |
+
|
161 |
+
###############################################################################
|
162 |
+
# 5) Main Predict Function
|
163 |
+
###############################################################################
|
164 |
def predict(
|
165 |
YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
|
166 |
YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
|
|
|
174 |
2) Aggregates results
|
175 |
3) Produces an overall 'severity'
|
176 |
4) Returns detailed per-model predictions
|
177 |
+
5) Creates a distribution plot for ALL input features vs. a chosen label
|
178 |
+
6) Nearest neighbor logic (with disclaimers), mapping numeric -> user text
|
|
|
|
|
|
|
|
|
179 |
"""
|
180 |
|
181 |
+
# 1) Prepare user_input dataframe
|
182 |
user_input_data = {
|
183 |
'YNURSMDE': [int(YNURSMDE)],
|
184 |
'YMDEYR': [int(YMDEYR)],
|
|
|
212 |
}
|
213 |
user_input = pd.DataFrame(user_input_data)
|
214 |
|
215 |
+
# 2) Make predictions
|
|
|
|
|
216 |
predictions = predictor.make_predictions(user_input)
|
217 |
|
218 |
+
# 3) Calculate majority vote (0 or 1) across all models
|
219 |
majority_vote = predictor.get_majority_vote(predictions)
|
220 |
|
221 |
+
# 4) Count how many 1's in all predictions combined
|
222 |
majority_vote_count = sum([1 for pred in np.concatenate(predictions) if pred == 1])
|
223 |
|
224 |
+
# 5) Evaluate severity
|
225 |
severity = predictor.evaluate_severity(majority_vote_count)
|
226 |
|
227 |
+
# 6) Prepare per-model predictions
|
228 |
+
# We'll group them just like before
|
229 |
results = {
|
230 |
"Concentration_and_Decision_Making": [],
|
231 |
"Sleep_and_Energy_Levels": [],
|
|
|
245 |
"YOPB2WK"]
|
246 |
}
|
247 |
|
248 |
+
# We'll keep a record of which model => which predicted label
|
249 |
for i, pred in enumerate(predictions):
|
250 |
+
model_name = predictor.model_filenames[i].split('.')[0]
|
251 |
pred_value = pred[0]
|
252 |
# Map the prediction value to a human-readable string
|
253 |
if model_name in predictor.prediction_map and pred_value in [0, 1]:
|
254 |
result_text = f"Model {model_name}: {predictor.prediction_map[model_name][pred_value]}"
|
|
|
|
|
|
|
255 |
else:
|
256 |
+
result_text = f"Model {model_name}: Unknown or out-of-range"
|
257 |
|
258 |
# Append to the appropriate group
|
259 |
found_group = False
|
|
|
263 |
found_group = True
|
264 |
break
|
265 |
if not found_group:
|
266 |
+
# If no group matches, skip or store in "Other"
|
267 |
pass
|
268 |
|
269 |
+
# 7) Nicely format the results
|
270 |
formatted_results = []
|
271 |
for group, preds in results.items():
|
272 |
if preds:
|
273 |
formatted_results.append(f"Group {group.replace('_', ' ')}:")
|
274 |
formatted_results.append("\n".join(preds))
|
275 |
formatted_results.append("\n")
|
|
|
276 |
formatted_results = "\n".join(formatted_results).strip()
|
|
|
277 |
if len(formatted_results) == 0:
|
278 |
formatted_results = "No predictions made. Please check your inputs."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
279 |
|
280 |
+
# 8) Additional disclaimers if there's a large fraction of unknown
|
281 |
+
# Count individual predictions (not groups) whose text marks them as unknown.
# The threshold below is half the number of models, so the figure compared
# against it has to be a per-prediction count to be meaningful.
num_unknown = sum(
    1
    for preds in results.values()
    for p in preds
    if "Unknown or out-of-range" in p
)
if num_unknown > len(predictor.model_filenames) / 2:
    severity += " (Unknown prediction count is high. Please consult with a human.)"
|
284 |
|
285 |
+
############################################################################
|
286 |
+
# A) Total Patient Count
|
287 |
+
############################################################################
|
288 |
total_patients = len(df)
|
289 |
total_patient_count_markdown = (
|
290 |
"### Total Patient Count\n"
|
291 |
f"There are **{total_patients}** total patients in the dataset.\n\n"
|
292 |
+
"This number helps you understand the size of the dataset used."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
293 |
)
|
|
|
294 |
|
295 |
+
############################################################################
|
296 |
+
# B) Distribution Plot: All Input Features vs. a single predicted label
|
297 |
+
############################################################################
|
298 |
+
# For demonstration, let's pick "YOWRCONC" if it exists in df:
|
299 |
+
# We'll melt the dataset so that each input feature is in a "FeatureName" column,
|
300 |
+
# and each distinct category is in "FeatureValue". We'll group by those + label to get counts.
|
301 |
+
chosen_label = "YOWRCONC"
|
302 |
+
if chosen_label in df.columns:
|
303 |
+
# 1) Narrow down to the columns of interest
|
304 |
+
# We'll only use the input features that exist in df
|
305 |
+
input_cols_in_df = [c for c in user_input_data.keys() if c in df.columns]
|
306 |
+
# 2) We'll create a "melted" version of these input features
|
307 |
+
# i.e., row per (patient_id, FeatureName, FeatureValue)
|
308 |
+
sub_df = df[input_cols_in_df + [chosen_label]].copy()
|
309 |
+
# Melt them
|
310 |
+
melted = sub_df.melt(
|
311 |
+
id_vars=[chosen_label],
|
312 |
+
var_name="FeatureName",
|
313 |
+
value_name="FeatureValue"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
314 |
)
|
315 |
+
# 3) Group by (FeatureName, FeatureValue, chosen_label) to get size
|
316 |
+
dist_data = melted.groupby(["FeatureName", "FeatureValue", chosen_label]).size().reset_index(name="count")
|
317 |
+
# 4) We'll try to map FeatureValue from numeric -> user-friendly text if possible
|
318 |
+
# We'll do it only if FeatureName is in reverse_mapping.
|
319 |
+
def map_value(row):
    """Translate one row's numeric FeatureValue into its user-facing text.

    ``row`` must provide the keys "FeatureName" and "FeatureValue". When the
    feature has an entry in ``reverse_mapping`` and the value is one of that
    feature's known codes, the matching text is returned; otherwise the raw
    value is returned unchanged as a fallback.
    """
    fn = row["FeatureName"]
    fv = row["FeatureValue"]
    # Two chained lookups replace the nested membership tests; the second
    # default is the raw value itself, e.g. 1 -> "Yes", unknown code -> code.
    return reverse_mapping.get(fn, {}).get(fv, fv)
|
326 |
+
dist_data["FeatureValueText"] = dist_data.apply(map_value, axis=1)
|
327 |
+
# 5) Similarly, map chosen_label (0 or 1) to text if in predictor.prediction_map
|
328 |
+
if chosen_label in predictor.prediction_map:
|
329 |
+
def map_label(val):
    """Translate a 0/1 value of the chosen label column into display text.

    Values equal to 0 or 1 select the corresponding entry of
    ``predictor.prediction_map[chosen_label]``; anything else is reported
    as ``"Unknown label <val>"``.
    """
    if val in (0, 1):
        # Coerce to int before indexing: a float such as 1.0 (e.g. from a
        # column pandas upcast to float64 because it contains missing
        # values) passes the membership test above but is not a valid
        # list index and would raise TypeError.
        return predictor.prediction_map[chosen_label][int(val)]
    return f"Unknown label {val}"
|
333 |
+
dist_data["LabelText"] = dist_data[chosen_label].apply(map_label)
|
334 |
+
else:
|
335 |
+
dist_data["LabelText"] = dist_data[chosen_label].astype(str)
|
336 |
|
337 |
+
# 6) Now produce a bar chart with facet_col = FeatureName
|
338 |
+
fig_distribution = px.bar(
|
339 |
+
dist_data,
|
340 |
+
x="FeatureValueText",
|
|
|
|
|
|
|
|
|
|
|
|
|
341 |
y="count",
|
342 |
+
color="LabelText",
|
343 |
+
facet_col="FeatureName",
|
344 |
+
facet_col_wrap=4, # how many facets per row
|
345 |
+
title=f"Distribution of All Input Features vs. {chosen_label}",
|
346 |
+
height=800
|
347 |
)
|
348 |
+
fig_distribution.update_layout(legend=dict(title=chosen_label))
|
349 |
+
# (Optional) Adjust layout or text angle if you have many categories
|
350 |
+
fig_distribution.update_xaxes(tickangle=45)
|
351 |
else:
|
352 |
+
# Fallback
|
353 |
+
fig_distribution = px.bar(title=f"Label {chosen_label} not found in dataset. Distribution not available.")
|
354 |
+
|
355 |
+
############################################################################
|
356 |
+
# C) Nearest Neighbors (Hamming Distance) with disclaimers & user-friendly text
|
357 |
+
############################################################################
|
358 |
+
# "Nearest neighbor" methods for high-dimensional or purely categorical data can be non-trivial.
|
359 |
+
# This demo uses a Hamming distance over all input features, picks K=5.
|
360 |
+
# In real practice, you'd refine which features to use, how to encode them, etc.
|
361 |
+
|
362 |
+
# 1) Build a DataFrame to compare with the user_input
|
363 |
+
features_to_compare = [col for col in user_input_data if col in df.columns]
|
|
|
|
|
|
|
364 |
user_series = user_input.iloc[0]
|
365 |
|
366 |
+
# 2) Compute distances
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
367 |
distances = []
|
368 |
+
for idx, row in df[features_to_compare].iterrows():
|
369 |
d = 0
|
370 |
for col in features_to_compare:
|
371 |
if row[col] != user_series[col]:
|
372 |
d += 1
|
373 |
distances.append(d)
|
374 |
|
|
|
375 |
df_with_dist = df.copy()
|
376 |
df_with_dist["distance"] = distances
|
377 |
|
378 |
+
# 3) Sort and pick top K=5
|
379 |
K = 5
|
380 |
nearest_neighbors = df_with_dist.sort_values("distance", ascending=True).head(K)
|
381 |
|
382 |
+
# 4) Show how many had the chosen_label=0 vs 1, but also map them
|
383 |
+
# We'll also demonstrate showing user-friendly text for each neighbor's feature values.
|
384 |
+
# However, if you have large K or many features, this can be big.
|
385 |
+
if chosen_label in nearest_neighbors.columns:
|
386 |
+
nn_label_0 = len(nearest_neighbors[nearest_neighbors[chosen_label] == 0])
|
387 |
+
nn_label_1 = len(nearest_neighbors[nearest_neighbors[chosen_label] == 1])
|
388 |
+
if chosen_label in predictor.prediction_map:
|
389 |
+
label0_text = predictor.prediction_map[chosen_label][0]
|
390 |
+
label1_text = predictor.prediction_map[chosen_label][1]
|
391 |
+
else:
|
392 |
+
label0_text = "Label=0"
|
393 |
+
label1_text = "Label=1"
|
394 |
+
else:
|
395 |
+
nn_label_0 = nn_label_1 = 0
|
396 |
+
label0_text = "Label=0"
|
397 |
+
label1_text = "Label=1"
|
398 |
+
|
399 |
+
# 5) Build an example table of those neighbors in user-friendly text
|
400 |
+
neighbor_text_rows = []
|
401 |
+
for idx, nn_row in nearest_neighbors.iterrows():
|
402 |
+
# For each feature, map numeric -> user text
|
403 |
+
row_str_parts = []
|
404 |
+
row_str_parts.append(f"distance={nn_row['distance']}")
|
405 |
+
for fcol in features_to_compare:
|
406 |
+
val = nn_row[fcol]
|
407 |
+
# try to map
|
408 |
+
if fcol in reverse_mapping and val in reverse_mapping[fcol]:
|
409 |
+
val_str = reverse_mapping[fcol][val]
|
410 |
+
else:
|
411 |
+
val_str = str(val)
|
412 |
+
row_str_parts.append(f"{fcol}={val_str}")
|
413 |
+
# For the label
|
414 |
+
if chosen_label in nn_row:
|
415 |
+
lbl_val = nn_row[chosen_label]
|
416 |
+
if chosen_label in predictor.prediction_map and lbl_val in [0, 1]:
|
417 |
+
lbl_str = predictor.prediction_map[chosen_label][lbl_val]
|
418 |
+
else:
|
419 |
+
lbl_str = str(lbl_val)
|
420 |
+
row_str_parts.append(f"{chosen_label}={lbl_str}")
|
421 |
+
neighbor_text_rows.append(" | ".join(row_str_parts))
|
422 |
+
|
423 |
+
neighbor_text_block = "\n".join(neighbor_text_rows)
|
424 |
|
|
|
425 |
similar_patient_markdown = (
|
426 |
"### Nearest Neighbors (Simple Hamming Distance)\n"
|
427 |
+
"“Nearest neighbor” methods for high-dimensional or purely categorical data can be non-trivial. "
|
428 |
+
"This demo simply uses a Hamming distance over all input features and picks **K=5** neighbors.\n\n"
|
429 |
+
"In a real application, you would refine which features are most relevant, how to encode them, "
|
430 |
+
"and how many neighbors to select.\n\n"
|
431 |
+
f"Among these **{K}** nearest neighbors:\n"
|
432 |
+
f"- **{nn_label_0}** had {label0_text}\n"
|
433 |
+
f"- **{nn_label_1}** had {label1_text}\n\n"
|
434 |
+
"Below is a breakdown of each neighbor's key features in user-friendly text:\n\n"
|
435 |
+
f"```\n{neighbor_text_block}\n```"
|
436 |
)
|
437 |
|
438 |
+
############################################################################
|
439 |
+
# Return 8 outputs
|
440 |
+
############################################################################
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
441 |
return (
|
442 |
+
formatted_results, # 1) Prediction results (Textbox)
|
443 |
+
severity, # 2) Mental Health Severity (Textbox)
|
444 |
+
total_patient_count_markdown, # 3) Total Patient Count (Markdown)
|
445 |
+
fig_distribution, # 4) Distribution Plot (Plot)
|
446 |
+
similar_patient_markdown, # 5) Nearest Neighbor Summary (Markdown)
|
447 |
+
None, # 6) Placeholder if you need more plots
|
448 |
+
None, # 7) Another placeholder
|
449 |
+
None # 8) Another placeholder
|
450 |
)
|
451 |
|
452 |
+
###############################################################################
|
453 |
+
# 6) Gradio Interface: We'll keep 8 outputs, but only use some in this demo
|
454 |
+
###############################################################################
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
455 |
def predict_with_text(
|
456 |
YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
|
457 |
YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
|
|
|
471 |
"Please select all required fields.", # Prediction Results
|
472 |
"Validation Error", # Severity
|
473 |
"No data", # Total Patient Count
|
474 |
+
None, # Distribution Plot
|
475 |
"No data", # Nearest Neighbors
|
476 |
+
None, None, None # Placeholders
|
|
|
|
|
477 |
)
|
478 |
|
479 |
# Map from user-friendly text to int
|
|
|
512 |
# Pass our mapped values into the original 'predict' function
|
513 |
return predict(**user_inputs)
|
514 |
|
515 |
+
###############################################################################
|
516 |
+
# 7) Define and Launch Gradio Interface
|
517 |
+
###############################################################################
|
518 |
+
import sys
|
519 |
+
|
520 |
+
# Output components, in the same order as the 8-tuple returned by the
# prediction callback: two text boxes, a Markdown panel, the distribution
# plot, the nearest-neighbors Markdown summary, and three spare plot slots.
outputs = [
    gr.Textbox(label="Prediction Results", lines=30),
    gr.Textbox(label="Mental Health Severity", lines=4),
    gr.Markdown(label="Total Patient Count"),
    gr.Plot(label="Distribution of All Input Features vs. One Label"),
    gr.Markdown(label="Nearest Neighbors Summary"),
    # Placeholders: the callback currently returns None for these three.
    *(gr.Plot(label="Placeholder Plot") for _ in range(3)),
]
|
531 |
+
|
532 |
+
# (column code, dropdown label) for every screening question.
# Order is significant: gr.Interface hands the selected values to the
# callback positionally, so this sequence must follow its parameter order.
_INPUT_FIELDS = (
    # Major Depressive Episode (MDE) questions
    ("YMDEYR", "YMDEYR: PAST YEAR MDE?"),
    ("YMDERSUD5ANY", "YMDERSUD5ANY: MDE OR SUBSTANCE USE DISORDER - ANY"),
    ("YMDEIMAD5YR", "YMDEIMAD5YR: MDE + ALCOHOL USE DISORDER?"),
    ("YMIMS5YANY", "YMIMS5YANY: MDE + SUBSTANCE USE DISORDER?"),
    ("YMDELT", "YMDELT: EVER HAD MDE LIFETIME?"),
    ("YMDEHARX", "YMDEHARX: SAW HEALTH PROF + MEDS FOR MDE"),
    ("YMDEHPRX", "YMDEHPRX: SAW HEALTH PROF OR MEDS FOR MDE"),
    ("YMDETXRX", "YMDETXRX: TREATMENT/COUNSELING FOR MDE"),
    ("YMDEHPO", "YMDEHPO: HEALTH PROF ONLY FOR MDE"),
    ("YMDEAUD5YR", "YMDEAUD5YR: MDE + ALCOHOL USE DISORDER"),
    ("YMIMI5YANY", "YMIMI5YANY: MDE + ILL DRUG USE DISORDER"),
    ("YMIUD5YANY", "YMIUD5YANY: MDE + ILL DRUG USE DISORDER"),
    ("YMDESUD5ANYO", "YMDESUD5ANYO: MDE vs. SUD vs. BOTH vs. NEITHER"),
    # Consultations
    ("YNURSMDE", "YNURSMDE: NURSE / OT FOR MDE"),
    ("YSOCMDE", "YSOCMDE: SOCIAL WORKER FOR MDE"),
    ("YCOUNMDE", "YCOUNMDE: COUNSELOR FOR MDE"),
    ("YPSY1MDE", "YPSY1MDE: PSYCHOLOGIST FOR MDE"),
    ("YPSY2MDE", "YPSY2MDE: PSYCHIATRIST FOR MDE"),
    ("YHLTMDE", "YHLTMDE: HEALTH PROF FOR MDE"),
    ("YDOCMDE", "YDOCMDE: GP/FAMILY MD FOR MDE"),
    ("YTXMDEYR", "YTXMDEYR: DOCTOR/HEALTH PROF FOR MDE THIS YEAR"),
    # Suicidal thoughts / plans
    ("YUSUITHKYR", "YUSUITHKYR: SERIOUSLY THOUGHT ABOUT KILLING SELF"),
    ("YUSUIPLNYR", "YUSUIPLNYR: MADE PLANS TO KILL SELF"),
    ("YUSUITHK", "YUSUITHK: THINK ABOUT KILLING SELF (12 MONTHS)"),
    ("YUSUIPLN", "YUSUIPLN: MADE PLANS TO KILL SELF (12 MONTHS)"),
    # Impairment
    ("MDEIMPY", "MDEIMPY: MDE WITH SEVERE ROLE IMPAIRMENT?"),
    ("LVLDIFMEM2", "LVLDIFMEM2: DIFFICULTY REMEMBERING/CONCENTRATING"),
    ("YMSUD5YANY", "YMSUD5YANY: MDE + SUBSTANCE USE DISORDER?"),
    ("YRXMDEYR", "YRXMDEYR: USED MEDS FOR MDE IN PAST YEAR?"),
)

# One dropdown per question; its choices are the keys of that column's
# entry in input_mapping.
inputs = [
    gr.Dropdown(list(input_mapping[code]), label=label)
    for code, label in _INPUT_FIELDS
]
|
571 |
+
|
572 |
# Custom CSS (optional)
|
573 |
custom_css = """
|
574 |
.gradio-container * {
|
|
|
587 |
}
|
588 |
"""
|
589 |
|
590 |
+
# Assemble the Gradio app: predict_with_text is the callback wired between
# the dropdown inputs and the output components; custom_css styles the page.
interface = gr.Interface(
    predict_with_text,  # fn
    inputs,
    outputs,
    title="Adolescents with Substance Use Mental Health Screening (NSDUH Data)",
    css=custom_css,
)
|
598 |
|
599 |
if __name__ == "__main__":
|