pantdipendra commited on
Commit
cf4c3a5
·
verified ·
1 Parent(s): 67c356a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +372 -295
app.py CHANGED
@@ -8,17 +8,16 @@ import plotly.express as px
8
  # Load the training CSV once (outside the functions so it is read only once).
9
  df = pd.read_csv("X_train_Y_Train_merged_train.csv")
10
 
11
- ###############################################################################
12
- # 1) Model Predictor class
13
- ###############################################################################
14
  class ModelPredictor:
15
  def __init__(self, model_path, model_filenames):
16
  self.model_path = model_path
17
  self.model_filenames = model_filenames
18
  self.models = self.load_models()
19
-
20
- # For each model name, define the mapping from 0->..., 1->...
21
- # If you have more labels, expand this dictionary accordingly.
22
  self.prediction_map = {
23
  "YOWRCONC": ["Did not have difficulty concentrating", "Had difficulty concentrating"],
24
  "YOSEEDOC": ["Did not feel the need to see a doctor", "Felt the need to see a doctor"],
@@ -95,9 +94,9 @@ class ModelPredictor:
95
  else:
96
  return "Mental health severity: Very Low"
97
 
98
- ###############################################################################
99
- # 2) Model Filenames & Predictor
100
- ###############################################################################
101
  model_filenames = [
102
  "YOWRCONC.pkl", "YOSEEDOC.pkl", "YO_MDEA5.pkl", "YOWRLSIN.pkl",
103
  "YODPPROB.pkl", "YOWRPROB.pkl", "YODPR2WK.pkl", "YOWRDEPR.pkl",
@@ -107,60 +106,18 @@ model_filenames = [
107
  model_path = "models/"
108
  predictor = ModelPredictor(model_path, model_filenames)
109
 
110
- ###############################################################################
111
- # 3) Validate Inputs
112
- ###############################################################################
113
  def validate_inputs(*args):
114
  for arg in args:
115
  if arg == '' or arg is None: # Assuming empty string or None as unselected
116
  return False
117
  return True
118
 
119
- ###############################################################################
120
- # 4) Reverse Lookup (numeric -> user-friendly text) for input columns
121
- ###############################################################################
122
- # We'll define the forward mapping here. The reverse mapping is constructed below.
123
- input_mapping = {
124
- 'YNURSMDE': {"Yes": 1, "No": 0},
125
- 'YMDEYR': {"Yes": 1, "No": 2},
126
- 'YSOCMDE': {"Yes": 1, "No": 0},
127
- 'YMDESUD5ANYO': {"SUD only, no MDE": 1, "MDE only, no SUD": 2, "SUD and MDE": 3, "Neither SUD or MDE": 4},
128
- 'YMSUD5YANY': {"Yes": 1, "No": 0},
129
- 'YUSUITHK': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
130
- 'YMDETXRX': {"Yes": 1, "No": 0},
131
- 'YUSUITHKYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
132
- 'YMDERSUD5ANY': {"Yes": 1, "No": 0},
133
- 'YUSUIPLNYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
134
- 'YCOUNMDE': {"Yes": 1, "No": 0},
135
- 'YPSY1MDE': {"Yes": 1, "No": 0},
136
- 'YHLTMDE': {"Yes": 1, "No": 0},
137
- 'YDOCMDE': {"Yes": 1, "No": 0},
138
- 'YPSY2MDE': {"Yes": 1, "No": 0},
139
- 'YMDEHARX': {"Yes": 1, "No": 0},
140
- 'LVLDIFMEM2': {"No Difficulty": 1, "Some difficulty": 2, "A lot of difficulty or cannot do at all": 3},
141
- 'MDEIMPY': {"Yes": 1, "No": 2},
142
- 'YMDEHPO': {"Yes": 1, "No": 0},
143
- 'YMIMS5YANY': {"Yes": 1, "No": 0},
144
- 'YMDEIMAD5YR': {"Yes": 1, "No": 0},
145
- 'YMIUD5YANY': {"Yes": 1, "No": 0},
146
- 'YMDEHPRX': {"Yes": 1, "No": 0},
147
- 'YMIMI5YANY': {"Yes": 1, "No": 0},
148
- 'YUSUIPLN': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
149
- 'YTXMDEYR': {"Yes": 1, "No": 0},
150
- 'YMDEAUD5YR': {"Yes": 1, "No": 0},
151
- 'YRXMDEYR': {"Yes": 1, "No": 0},
152
- 'YMDELT': {"Yes": 1, "No": 2}
153
- }
154
-
155
- # Build reverse mapping: { "YNURSMDE": {1: "Yes", 0: "No"}, ... } etc.
156
- reverse_mapping = {}
157
- for col, mapping_dict in input_mapping.items():
158
- rev = {v: k for k, v in mapping_dict.items()} # invert dict
159
- reverse_mapping[col] = rev
160
-
161
- ###############################################################################
162
- # 5) Main Predict Function
163
- ###############################################################################
164
  def predict(
165
  YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
166
  YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
@@ -168,17 +125,7 @@ def predict(
168
  YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
169
  YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
170
  ):
171
- """
172
- Core prediction function that:
173
- 1) Predicts with each model
174
- 2) Aggregates results
175
- 3) Produces an overall 'severity'
176
- 4) Returns detailed per-model predictions
177
- 5) Creates a distribution plot for ALL input features vs. a chosen label
178
- 6) Nearest neighbor logic (with disclaimers), mapping numeric -> user text
179
- """
180
-
181
- # 1) Prepare user_input dataframe
182
  user_input_data = {
183
  'YNURSMDE': [int(YNURSMDE)],
184
  'YMDEYR': [int(YMDEYR)],
@@ -212,20 +159,20 @@ def predict(
212
  }
213
  user_input = pd.DataFrame(user_input_data)
214
 
215
- # 2) Make predictions
216
  predictions = predictor.make_predictions(user_input)
217
 
218
- # 3) Calculate majority vote (0 or 1) across all models
219
  majority_vote = predictor.get_majority_vote(predictions)
220
 
221
- # 4) Count how many 1's in all predictions combined
222
  majority_vote_count = sum([1 for pred in np.concatenate(predictions) if pred == 1])
223
 
224
- # 5) Evaluate severity
225
  severity = predictor.evaluate_severity(majority_vote_count)
226
 
227
- # 6) Prepare per-model predictions
228
- # We'll group them just like before
229
  results = {
230
  "Concentration_and_Decision_Making": [],
231
  "Sleep_and_Energy_Levels": [],
@@ -245,17 +192,18 @@ def predict(
245
  "YOPB2WK"]
246
  }
247
 
248
- # We'll keep a record of which model => which predicted label
249
  for i, pred in enumerate(predictions):
250
- model_name = predictor.model_filenames[i].split('.')[0]
251
  pred_value = pred[0]
252
  # Map the prediction value to a human-readable string
253
  if model_name in predictor.prediction_map and pred_value in [0, 1]:
254
  result_text = f"Model {model_name}: {predictor.prediction_map[model_name][pred_value]}"
255
  else:
256
- result_text = f"Model {model_name}: Unknown or out-of-range"
 
257
 
258
- # Append to the appropriate group
259
  found_group = False
260
  for group_name, group_models in prediction_groups.items():
261
  if model_name in group_models:
@@ -263,10 +211,10 @@ def predict(
263
  found_group = True
264
  break
265
  if not found_group:
266
- # If no group matches, skip or store in "Other"
267
  pass
268
 
269
- # 7) Nicely format the results
270
  formatted_results = []
271
  for group, preds in results.items():
272
  if preds:
@@ -274,184 +222,366 @@ def predict(
274
  formatted_results.append("\n".join(preds))
275
  formatted_results.append("\n")
276
  formatted_results = "\n".join(formatted_results).strip()
277
- if len(formatted_results) == 0:
278
  formatted_results = "No predictions made. Please check your inputs."
279
 
280
- # 8) Additional disclaimers if there's a large fraction of unknown
281
- num_unknown = sum(1 for group, preds in results.items() if any("Unknown or out-of-range" in p for p in preds))
282
- if num_unknown > len(predictor.model_filenames) / 2:
283
  severity += " (Unknown prediction count is high. Please consult with a human.)"
284
 
285
- ############################################################################
 
286
  # A) Total Patient Count
287
- ############################################################################
288
  total_patients = len(df)
289
  total_patient_count_markdown = (
290
  "### Total Patient Count\n"
291
- f"There are **{total_patients}** total patients in the dataset.\n\n"
292
- "This number helps you understand the size of the dataset used."
293
  )
294
 
295
- ############################################################################
296
- # B) Distribution Plot: All Input Features vs. a single predicted label
297
- ############################################################################
298
- # For demonstration, let's pick "YOWRCONC" if it exists in df:
299
- # We'll melt the dataset so that each input feature is in a "FeatureName" column,
300
- # and each distinct category is in "FeatureValue". We'll group by those + label to get counts.
301
- chosen_label = "YOWRCONC"
302
- if chosen_label in df.columns:
303
- # 1) Narrow down to the columns of interest
304
- # We'll only use the input features that exist in df
305
- input_cols_in_df = [c for c in user_input_data.keys() if c in df.columns]
306
- # 2) We'll create a "melted" version of these input features
307
- # i.e., row per (patient_id, FeatureName, FeatureValue)
308
- sub_df = df[input_cols_in_df + [chosen_label]].copy()
309
- # Melt them
310
- melted = sub_df.melt(
311
- id_vars=[chosen_label],
312
- var_name="FeatureName",
313
- value_name="FeatureValue"
314
- )
315
- # 3) Group by (FeatureName, FeatureValue, chosen_label) to get size
316
- dist_data = melted.groupby(["FeatureName", "FeatureValue", chosen_label]).size().reset_index(name="count")
317
- # 4) We'll try to map FeatureValue from numeric -> user-friendly text if possible
318
- # We'll do it only if FeatureName is in reverse_mapping.
319
- def map_value(row):
320
- fn = row["FeatureName"]
321
- fv = row["FeatureValue"]
322
- if fn in reverse_mapping:
323
- if fv in reverse_mapping[fn]:
324
- return reverse_mapping[fn][fv] # e.g. 1->"Yes"
325
- return fv # fallback
326
- dist_data["FeatureValueText"] = dist_data.apply(map_value, axis=1)
327
- # 5) Similarly, map chosen_label (0 or 1) to text if in predictor.prediction_map
328
- if chosen_label in predictor.prediction_map:
329
- def map_label(val):
330
- if val in [0, 1]:
331
- return predictor.prediction_map[chosen_label][val]
332
- return f"Unknown label {val}"
333
- dist_data["LabelText"] = dist_data[chosen_label].apply(map_label)
334
- else:
335
- dist_data["LabelText"] = dist_data[chosen_label].astype(str)
336
 
337
- # 6) Now produce a bar chart with facet_col = FeatureName
338
- fig_distribution = px.bar(
339
- dist_data,
340
- x="FeatureValueText",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
  y="count",
342
- color="LabelText",
343
- facet_col="FeatureName",
344
- facet_col_wrap=4, # how many facets per row
345
- title=f"Distribution of All Input Features vs. {chosen_label}",
346
- height=800
 
 
 
347
  )
348
- fig_distribution.update_layout(legend=dict(title=chosen_label))
349
- # (Optional) Adjust layout or text angle if you have many categories
350
- fig_distribution.update_xaxes(tickangle=45)
351
  else:
352
- # Fallback
353
- fig_distribution = px.bar(title=f"Label {chosen_label} not found in dataset. Distribution not available.")
354
-
355
- ############################################################################
356
- # C) Nearest Neighbors (Hamming Distance) with disclaimers & user-friendly text
357
- ############################################################################
358
- # "Nearest neighbor” methods for high-dimensional or purely categorical data can be non-trivial.
359
- # This demo uses a Hamming distance over all input features, picks K=5.
360
- # In real practice, you'd refine which features to use, how to encode them, etc.
361
-
362
- # 1) Build a DataFrame to compare with the user_input
363
- features_to_compare = [col for col in user_input_data if col in df.columns]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364
  user_series = user_input.iloc[0]
365
 
366
- # 2) Compute distances
367
  distances = []
368
- for idx, row in df[features_to_compare].iterrows():
369
- d = 0
370
- for col in features_to_compare:
371
- if row[col] != user_series[col]:
372
- d += 1
373
- distances.append(d)
374
 
375
  df_with_dist = df.copy()
376
  df_with_dist["distance"] = distances
377
 
378
- # 3) Sort and pick top K=5
379
  K = 5
380
  nearest_neighbors = df_with_dist.sort_values("distance", ascending=True).head(K)
381
 
382
- # 4) Show how many had the chosen_label=0 vs 1, but also map them
383
- # We'll also demonstrate showing user-friendly text for each neighbor's feature values.
384
- # However, if you have large K or many features, this can be big.
385
- if chosen_label in nearest_neighbors.columns:
386
- nn_label_0 = len(nearest_neighbors[nearest_neighbors[chosen_label] == 0])
387
- nn_label_1 = len(nearest_neighbors[nearest_neighbors[chosen_label] == 1])
388
- if chosen_label in predictor.prediction_map:
389
- label0_text = predictor.prediction_map[chosen_label][0]
390
- label1_text = predictor.prediction_map[chosen_label][1]
391
- else:
392
- label0_text = "Label=0"
393
- label1_text = "Label=1"
394
- else:
395
- nn_label_0 = nn_label_1 = 0
396
- label0_text = "Label=0"
397
- label1_text = "Label=1"
398
-
399
- # 5) Build an example table of those neighbors in user-friendly text
400
- neighbor_text_rows = []
401
- for idx, nn_row in nearest_neighbors.iterrows():
402
- # For each feature, map numeric -> user text
403
- row_str_parts = []
404
- row_str_parts.append(f"distance={nn_row['distance']}")
405
- for fcol in features_to_compare:
406
- val = nn_row[fcol]
407
- # try to map
408
- if fcol in reverse_mapping and val in reverse_mapping[fcol]:
409
- val_str = reverse_mapping[fcol][val]
410
  else:
411
- val_str = str(val)
412
- row_str_parts.append(f"{fcol}={val_str}")
413
- # For the label
414
- if chosen_label in nn_row:
415
- lbl_val = nn_row[chosen_label]
416
- if chosen_label in predictor.prediction_map and lbl_val in [0, 1]:
417
- lbl_str = predictor.prediction_map[chosen_label][lbl_val]
418
  else:
419
- lbl_str = str(lbl_val)
420
- row_str_parts.append(f"{chosen_label}={lbl_str}")
421
- neighbor_text_rows.append(" | ".join(row_str_parts))
422
 
423
- neighbor_text_block = "\n".join(neighbor_text_rows)
424
 
425
  similar_patient_markdown = (
426
  "### Nearest Neighbors (Simple Hamming Distance)\n"
427
- "“Nearest neighbor” methods for high-dimensional or purely categorical data can be non-trivial. "
428
- "This demo simply uses a Hamming distance over all input features and picks **K=5** neighbors.\n\n"
 
429
  "In a real application, you would refine which features are most relevant, how to encode them, "
430
  "and how many neighbors to select.\n\n"
431
- f"Among these **{K}** nearest neighbors:\n"
432
- f"- **{nn_label_0}** had {label0_text}\n"
433
- f"- **{nn_label_1}** had {label1_text}\n\n"
434
- "Below is a breakdown of each neighbor's key features in user-friendly text:\n\n"
435
- f"```\n{neighbor_text_block}\n```"
436
  )
437
 
438
- ############################################################################
439
- # Return 8 outputs
440
- ############################################################################
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
441
  return (
442
- formatted_results, # 1) Prediction results (Textbox)
443
- severity, # 2) Mental Health Severity (Textbox)
444
- total_patient_count_markdown, # 3) Total Patient Count (Markdown)
445
- fig_distribution, # 4) Distribution Plot (Plot)
446
- similar_patient_markdown, # 5) Nearest Neighbor Summary (Markdown)
447
- None, # 6) Placeholder if you need more plots
448
- None, # 7) Another placeholder
449
- None # 8) Another placeholder
450
  )
451
 
452
- ###############################################################################
453
- # 6) Gradio Interface: We'll keep 8 outputs, but only use some in this demo
454
- ###############################################################################
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
455
  def predict_with_text(
456
  YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
457
  YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
@@ -459,7 +589,7 @@ def predict_with_text(
459
  YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
460
  YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
461
  ):
462
- # Validate that all required inputs are selected
463
  if not validate_inputs(
464
  YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
465
  YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
@@ -468,15 +598,17 @@ def predict_with_text(
468
  YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
469
  ):
470
  return (
471
- "Please select all required fields.", # Prediction Results
472
- "Validation Error", # Severity
473
- "No data", # Total Patient Count
474
- None, # Distribution Plot
475
- "No data", # Nearest Neighbors
476
- None, None, None # Placeholders
 
 
477
  )
478
 
479
- # Map from user-friendly text to int
480
  user_inputs = {
481
  'YNURSMDE': input_mapping['YNURSMDE'][YNURSMDE],
482
  'YMDEYR': input_mapping['YMDEYR'][YMDEYR],
@@ -508,68 +640,11 @@ def predict_with_text(
508
  'YRXMDEYR': input_mapping['YRXMDEYR'][YRXMDEYR],
509
  'YMDELT': input_mapping['YMDELT'][YMDELT]
510
  }
511
-
512
- # Pass our mapped values into the original 'predict' function
513
- return predict(**user_inputs)
514
-
515
- ###############################################################################
516
- # 7) Define and Launch Gradio Interface
517
- ###############################################################################
518
- import sys
519
-
520
- # We have 8 outputs (some are placeholders)
521
- outputs = [
522
- gr.Textbox(label="Prediction Results", lines=30),
523
- gr.Textbox(label="Mental Health Severity", lines=4),
524
- gr.Markdown(label="Total Patient Count"),
525
- gr.Plot(label="Distribution of All Input Features vs. One Label"),
526
- gr.Markdown(label="Nearest Neighbors Summary"),
527
- gr.Plot(label="Placeholder Plot"),
528
- gr.Plot(label="Placeholder Plot"),
529
- gr.Plot(label="Placeholder Plot")
530
- ]
531
 
532
- # Define the inputs
533
- inputs = [
534
- # Major Depressive Episode (MDE) questions
535
- gr.Dropdown(list(input_mapping['YMDEYR'].keys()), label="YMDEYR: PAST YEAR MDE?"),
536
- gr.Dropdown(list(input_mapping['YMDERSUD5ANY'].keys()), label="YMDERSUD5ANY: MDE OR SUBSTANCE USE DISORDER - ANY"),
537
- gr.Dropdown(list(input_mapping['YMDEIMAD5YR'].keys()), label="YMDEIMAD5YR: MDE + ALCOHOL USE DISORDER?"),
538
- gr.Dropdown(list(input_mapping['YMIMS5YANY'].keys()), label="YMIMS5YANY: MDE + SUBSTANCE USE DISORDER?"),
539
- gr.Dropdown(list(input_mapping['YMDELT'].keys()), label="YMDELT: EVER HAD MDE LIFETIME?"),
540
- gr.Dropdown(list(input_mapping['YMDEHARX'].keys()), label="YMDEHARX: SAW HEALTH PROF + MEDS FOR MDE"),
541
- gr.Dropdown(list(input_mapping['YMDEHPRX'].keys()), label="YMDEHPRX: SAW HEALTH PROF OR MEDS FOR MDE"),
542
- gr.Dropdown(list(input_mapping['YMDETXRX'].keys()), label="YMDETXRX: TREATMENT/COUNSELING FOR MDE"),
543
- gr.Dropdown(list(input_mapping['YMDEHPO'].keys()), label="YMDEHPO: HEALTH PROF ONLY FOR MDE"),
544
- gr.Dropdown(list(input_mapping['YMDEAUD5YR'].keys()), label="YMDEAUD5YR: MDE + ALCOHOL USE DISORDER"),
545
- gr.Dropdown(list(input_mapping['YMIMI5YANY'].keys()), label="YMIMI5YANY: MDE + ILL DRUG USE DISORDER"),
546
- gr.Dropdown(list(input_mapping['YMIUD5YANY'].keys()), label="YMIUD5YANY: MDE + ILL DRUG USE DISORDER"),
547
- gr.Dropdown(list(input_mapping['YMDESUD5ANYO'].keys()), label="YMDESUD5ANYO: MDE vs. SUD vs. BOTH vs. NEITHER"),
548
-
549
- # Consultations
550
- gr.Dropdown(list(input_mapping['YNURSMDE'].keys()), label="YNURSMDE: NURSE / OT FOR MDE"),
551
- gr.Dropdown(list(input_mapping['YSOCMDE'].keys()), label="YSOCMDE: SOCIAL WORKER FOR MDE"),
552
- gr.Dropdown(list(input_mapping['YCOUNMDE'].keys()), label="YCOUNMDE: COUNSELOR FOR MDE"),
553
- gr.Dropdown(list(input_mapping['YPSY1MDE'].keys()), label="YPSY1MDE: PSYCHOLOGIST FOR MDE"),
554
- gr.Dropdown(list(input_mapping['YPSY2MDE'].keys()), label="YPSY2MDE: PSYCHIATRIST FOR MDE"),
555
- gr.Dropdown(list(input_mapping['YHLTMDE'].keys()), label="YHLTMDE: HEALTH PROF FOR MDE"),
556
- gr.Dropdown(list(input_mapping['YDOCMDE'].keys()), label="YDOCMDE: GP/FAMILY MD FOR MDE"),
557
- gr.Dropdown(list(input_mapping['YTXMDEYR'].keys()), label="YTXMDEYR: DOCTOR/HEALTH PROF FOR MDE THIS YEAR"),
558
-
559
- # Suicidal thoughts / plans
560
- gr.Dropdown(list(input_mapping['YUSUITHKYR'].keys()), label="YUSUITHKYR: SERIOUSLY THOUGHT ABOUT KILLING SELF"),
561
- gr.Dropdown(list(input_mapping['YUSUIPLNYR'].keys()), label="YUSUIPLNYR: MADE PLANS TO KILL SELF"),
562
- gr.Dropdown(list(input_mapping['YUSUITHK'].keys()), label="YUSUITHK: THINK ABOUT KILLING SELF (12 MONTHS)"),
563
- gr.Dropdown(list(input_mapping['YUSUIPLN'].keys()), label="YUSUIPLN: MADE PLANS TO KILL SELF (12 MONTHS)"),
564
-
565
- # Impairment
566
- gr.Dropdown(list(input_mapping['MDEIMPY'].keys()), label="MDEIMPY: MDE WITH SEVERE ROLE IMPAIRMENT?"),
567
- gr.Dropdown(list(input_mapping['LVLDIFMEM2'].keys()), label="LVLDIFMEM2: DIFFICULTY REMEMBERING/CONCENTRATING"),
568
- gr.Dropdown(list(input_mapping['YMSUD5YANY'].keys()), label="YMSUD5YANY: MDE + SUBSTANCE USE DISORDER?"),
569
- gr.Dropdown(list(input_mapping['YRXMDEYR'].keys()), label="YRXMDEYR: USED MEDS FOR MDE IN PAST YEAR?")
570
- ]
571
 
572
- # Custom CSS (optional)
573
  custom_css = """
574
  .gradio-container * {
575
  color: #1B1212 !important;
@@ -587,13 +662,15 @@ custom_css = """
587
  }
588
  """
589
 
590
- # Build the interface
 
 
591
  interface = gr.Interface(
592
- fn=predict_with_text,
593
- inputs=inputs,
594
- outputs=outputs,
595
- title="Adolescents with Substance Use Mental Health Screening (NSDUH Data)",
596
- css=custom_css,
597
  )
598
 
599
  if __name__ == "__main__":
 
8
  # Load the training CSV once (outside the functions so it is read only once).
9
  df = pd.read_csv("X_train_Y_Train_merged_train.csv")
10
 
11
+ ######################################
12
+ # 1) MODEL PREDICTOR CLASS
13
+ ######################################
14
  class ModelPredictor:
15
  def __init__(self, model_path, model_filenames):
16
  self.model_path = model_path
17
  self.model_filenames = model_filenames
18
  self.models = self.load_models()
19
+ # Mapping from label column to human-readable strings for 0/1
20
+ # (Adjust as needed for the columns you actually have.)
 
21
  self.prediction_map = {
22
  "YOWRCONC": ["Did not have difficulty concentrating", "Had difficulty concentrating"],
23
  "YOSEEDOC": ["Did not feel the need to see a doctor", "Felt the need to see a doctor"],
 
94
  else:
95
  return "Mental health severity: Very Low"
96
 
97
+ ######################################
98
+ # 2) MODEL & DATA
99
+ ######################################
100
  model_filenames = [
101
  "YOWRCONC.pkl", "YOSEEDOC.pkl", "YO_MDEA5.pkl", "YOWRLSIN.pkl",
102
  "YODPPROB.pkl", "YOWRPROB.pkl", "YODPR2WK.pkl", "YOWRDEPR.pkl",
 
106
  model_path = "models/"
107
  predictor = ModelPredictor(model_path, model_filenames)
108
 
109
+ ######################################
110
+ # 3) INPUT VALIDATION
111
+ ######################################
112
  def validate_inputs(*args):
113
  for arg in args:
114
  if arg == '' or arg is None: # Assuming empty string or None as unselected
115
  return False
116
  return True
117
 
118
+ ######################################
119
+ # 4) MAIN PREDICTION FUNCTION
120
+ ######################################
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  def predict(
122
  YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
123
  YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
 
125
  YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
126
  YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
127
  ):
128
+ # Prepare user_input dataframe for prediction
 
 
 
 
 
 
 
 
 
 
129
  user_input_data = {
130
  'YNURSMDE': [int(YNURSMDE)],
131
  'YMDEYR': [int(YMDEYR)],
 
159
  }
160
  user_input = pd.DataFrame(user_input_data)
161
 
162
+ # 1) Make predictions with each model
163
  predictions = predictor.make_predictions(user_input)
164
 
165
+ # 2) Calculate majority vote (0 or 1) across all models
166
  majority_vote = predictor.get_majority_vote(predictions)
167
 
168
+ # 3) Count how many 1's in all predictions combined
169
  majority_vote_count = sum([1 for pred in np.concatenate(predictions) if pred == 1])
170
 
171
+ # 4) Evaluate severity
172
  severity = predictor.evaluate_severity(majority_vote_count)
173
 
174
+ # 5) Prepare detailed results (group them)
175
+ # We keep the old grouping as an example, but you can adapt as needed.
176
  results = {
177
  "Concentration_and_Decision_Making": [],
178
  "Sleep_and_Energy_Levels": [],
 
192
  "YOPB2WK"]
193
  }
194
 
195
+ # For textual results
196
  for i, pred in enumerate(predictions):
197
+ model_name = model_filenames[i].split('.')[0]
198
  pred_value = pred[0]
199
  # Map the prediction value to a human-readable string
200
  if model_name in predictor.prediction_map and pred_value in [0, 1]:
201
  result_text = f"Model {model_name}: {predictor.prediction_map[model_name][pred_value]}"
202
  else:
203
+ # Fallback
204
+ result_text = f"Model {model_name}: Prediction = {pred_value} (unmapped)"
205
 
206
+ # Append to the appropriate group if matched
207
  found_group = False
208
  for group_name, group_models in prediction_groups.items():
209
  if model_name in group_models:
 
211
  found_group = True
212
  break
213
  if not found_group:
214
+ # If it doesn't match any group, skip or handle differently
215
  pass
216
 
217
+ # Format the grouped results
218
  formatted_results = []
219
  for group, preds in results.items():
220
  if preds:
 
222
  formatted_results.append("\n".join(preds))
223
  formatted_results.append("\n")
224
  formatted_results = "\n".join(formatted_results).strip()
225
+ if not formatted_results:
226
  formatted_results = "No predictions made. Please check your inputs."
227
 
228
+ # If too many unknown predictions, add a note
229
+ num_unknown = len([p for group_preds in results.values() for p in group_preds if "(unmapped)" in p])
230
+ if num_unknown > len(model_filenames) / 2:
231
  severity += " (Unknown prediction count is high. Please consult with a human.)"
232
 
233
+ # =============== ADDITIONAL FEATURES ===============
234
+
235
  # A) Total Patient Count
 
236
  total_patients = len(df)
237
  total_patient_count_markdown = (
238
  "### Total Patient Count\n"
239
+ f"There are **{total_patients}** total patients in the dataset.\n"
240
+ "All subsequent analyses refer to these patients."
241
  )
242
 
243
+ # B) Bar Chart for input features (how many share same value as user_input)
244
+ input_counts = {}
245
+ for col in user_input_data.keys():
246
+ val = user_input_data[col][0]
247
+ same_val_count = len(df[df[col] == val])
248
+ input_counts[col] = same_val_count
249
+ bar_input_data = pd.DataFrame({
250
+ "Feature": list(input_counts.keys()),
251
+ "Count": list(input_counts.values())
252
+ })
253
+ fig_bar_input = px.bar(
254
+ bar_input_data,
255
+ x="Feature",
256
+ y="Count",
257
+ title="Number of Patients with the Same Value for Each Input Feature",
258
+ labels={"Feature": "Input Feature", "Count": "Number of Patients"}
259
+ )
260
+ fig_bar_input.update_layout(xaxis={'categoryorder':'total descending'})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
 
262
+ # C) Bar Chart for predicted labels (distribution in df)
263
+ label_counts = {}
264
+ for i, pred in enumerate(predictions):
265
+ model_name = model_filenames[i].split('.')[0]
266
+ pred_value = pred[0]
267
+ if pred_value in [0, 1]:
268
+ label_counts[model_name] = len(df[df[model_name] == pred_value])
269
+ if len(label_counts) > 0:
270
+ bar_label_data = pd.DataFrame({
271
+ "Model": list(label_counts.keys()),
272
+ "Count": list(label_counts.values())
273
+ })
274
+ fig_bar_labels = px.bar(
275
+ bar_label_data,
276
+ x="Model",
277
+ y="Count",
278
+ title="Number of Patients with the Same Predicted Label",
279
+ labels={"Model": "Predicted Column", "Count": "Patient Count"}
280
+ )
281
+ else:
282
+ # Fallback if no valid predictions
283
+ fig_bar_labels = px.bar(title="No valid predicted labels to display")
284
+
285
+ # D) Distribution Plot: All Input Features vs. All Predicted Labels
286
+ # This can create MANY subplots if you have many features & labels.
287
+ # We'll do a small demonstration with a subset of input features & model columns
288
+ # to avoid overwhelming the UI.
289
+ demonstration_features = list(user_input_data.keys())[:4] # first 4 features as a sample
290
+ demonstration_labels = [fn.split('.')[0] for fn in model_filenames[:3]] # first 3 labels as a sample
291
+
292
+ # We'll build a single figure with "facet_col" = label and "facet_row" = feature (small sample)
293
+ # The approach: for each (feature, label), group by (feature_value, label_value) -> count.
294
+ # Then we combine them into one big DataFrame with "feature" & "label" columns for Plotly facets.
295
+ dist_rows = []
296
+ for feat in demonstration_features:
297
+ if feat not in df.columns:
298
+ continue
299
+ for lbl in demonstration_labels:
300
+ if lbl not in df.columns:
301
+ continue
302
+ tmp_df = df.groupby([feat, lbl]).size().reset_index(name="count")
303
+ tmp_df["feature"] = feat
304
+ tmp_df["label"] = lbl
305
+ dist_rows.append(tmp_df)
306
+ if len(dist_rows) > 0:
307
+ big_dist_df = pd.concat(dist_rows, ignore_index=True)
308
+ # We can re-map numeric to user-friendly text for "feat" if desired, but each feature might have a different mapping.
309
+ # For now, we just show numeric codes. Real usage would do a reverse mapping if feasible.
310
+
311
+ # For the label (0,1), we can map to short strings if we want (like "Label0" / "Label1"), or a direct numeric.
312
+ fig_dist = px.bar(
313
+ big_dist_df,
314
+ x=big_dist_df.columns[0], # the feature's value is the 0-th col in groupby
315
  y="count",
316
+ color=big_dist_df.columns[1], # the label's value is the 1st col in groupby
317
+ facet_row="feature",
318
+ facet_col="label",
319
+ title="Distribution of Sample Input Features vs. Sample Predicted Labels (Demo)",
320
+ labels={
321
+ big_dist_df.columns[0]: "Feature Value",
322
+ big_dist_df.columns[1]: "Label Value"
323
+ }
324
  )
325
+ fig_dist.update_layout(height=800)
 
 
326
  else:
327
+ fig_dist = px.bar(title="No distribution plot could be generated (check feature/label columns).")
328
+
329
+ # E) Nearest Neighbors: Hamming Distance, K=5, with disclaimers & user-friendly text
330
+ # "Nearest neighbor” methods for high-dimensional or purely categorical data can be non-trivial.
331
+ # This demo simply uses a Hamming distance over all input features and picks K=5 neighbors.
332
+ # In a real application, you would refine which features are most relevant, how to encode them,
333
+ # and how many neighbors to select.
334
+ # We also show how to revert numeric codes -> user-friendly text.
335
+
336
+ # 1. Invert the user-friendly text mapping (for inputs).
337
+ # We'll assume input_mapping is consistent. We build a reverse mapping for each column.
338
+ reverse_input_mapping = {}
339
+ # We'll build it after the code block below for each column.
340
+
341
+ # 2. Invert label mappings from predictor.prediction_map if needed
342
+ # For each label column, 0 => first string, 1 => second string
343
+ # We'll store them in a dict: reverse_label_mapping[label_col][0 or 1] => string
344
+ reverse_label_mapping = {}
345
+ for lbl, str_list in predictor.prediction_map.items():
346
+ # str_list[0] => for 0, str_list[1] => for 1
347
+ reverse_label_mapping[lbl] = {
348
+ 0: str_list[0],
349
+ 1: str_list[1]
350
+ }
351
+
352
+ # Build the reverse input mapping from the provided dictionary
353
+ # We'll define that dictionary below to ensure we can invert it:
354
+ input_mapping = {
355
+ 'YNURSMDE': {"Yes": 1, "No": 0},
356
+ 'YMDEYR': {"Yes": 1, "No": 2},
357
+ 'YSOCMDE': {"Yes": 1, "No": 0},
358
+ 'YMDESUD5ANYO': {"SUD only, no MDE": 1, "MDE only, no SUD": 2, "SUD and MDE": 3, "Neither SUD or MDE": 4},
359
+ 'YMSUD5YANY': {"Yes": 1, "No": 0},
360
+ 'YUSUITHK': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
361
+ 'YMDETXRX': {"Yes": 1, "No": 0},
362
+ 'YUSUITHKYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
363
+ 'YMDERSUD5ANY': {"Yes": 1, "No": 0},
364
+ 'YUSUIPLNYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
365
+ 'YCOUNMDE': {"Yes": 1, "No": 0},
366
+ 'YPSY1MDE': {"Yes": 1, "No": 0},
367
+ 'YHLTMDE': {"Yes": 1, "No": 0},
368
+ 'YDOCMDE': {"Yes": 1, "No": 0},
369
+ 'YPSY2MDE': {"Yes": 1, "No": 0},
370
+ 'YMDEHARX': {"Yes": 1, "No": 0},
371
+ 'LVLDIFMEM2': {"No Difficulty": 1, "Some difficulty": 2, "A lot of difficulty or cannot do at all": 3},
372
+ 'MDEIMPY': {"Yes": 1, "No": 2},
373
+ 'YMDEHPO': {"Yes": 1, "No": 0},
374
+ 'YMIMS5YANY': {"Yes": 1, "No": 0},
375
+ 'YMDEIMAD5YR': {"Yes": 1, "No": 0},
376
+ 'YMIUD5YANY': {"Yes": 1, "No": 0},
377
+ 'YMDEHPRX': {"Yes": 1, "No": 0},
378
+ 'YMIMI5YANY': {"Yes": 1, "No": 0},
379
+ 'YUSUIPLN': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
380
+ 'YTXMDEYR': {"Yes": 1, "No": 0},
381
+ 'YMDEAUD5YR': {"Yes": 1, "No": 0},
382
+ 'YRXMDEYR': {"Yes": 1, "No": 0},
383
+ 'YMDELT': {"Yes": 1, "No": 2}
384
+ }
385
+
386
+ # Build the reverse mapping for each column
387
+ for col, fwd_map in input_mapping.items():
388
+ reverse_input_mapping[col] = {v: k for k, v in fwd_map.items()}
389
+
390
+ # 3. Calculate Hamming distance for each row
391
+ # We'll consider the columns in user_input for comparison
392
+ features_to_compare = list(user_input.columns)
393
+ subset_df = df[features_to_compare].copy()
394
  user_series = user_input.iloc[0]
395
 
 
396
  distances = []
397
+ for idx, row in subset_df.iterrows():
398
+ dist = sum(row[col] != user_series[col] for col in features_to_compare)
399
+ distances.append(dist)
 
 
 
400
 
401
  df_with_dist = df.copy()
402
  df_with_dist["distance"] = distances
403
 
404
+ # 4. Sort by distance ascending, pick top K=5
405
  K = 5
406
  nearest_neighbors = df_with_dist.sort_values("distance", ascending=True).head(K)
407
 
408
+ # 5. Summarize neighbor info in user-friendly text
409
+ # For demonstration, let's show a small table with each neighbor's values
410
+ # for the same features. We'll also show a label or two.
411
+ # We'll do this in Markdown format.
412
+ nn_rows = []
413
+ for idx, nr in nearest_neighbors.iterrows():
414
+ # Convert each feature to text if possible
415
+ row_text = []
416
+ for col in features_to_compare:
417
+ val_numeric = nr[col]
418
+ if col in reverse_input_mapping:
419
+ row_text.append(f"{col}={reverse_input_mapping[col].get(val_numeric, val_numeric)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
420
  else:
421
+ row_text.append(f"{col}={val_numeric}")
422
+ # Let's also show YOWRCONC as an example label (if present)
423
+ if "YOWRCONC" in nearest_neighbors.columns:
424
+ label_val = nr["YOWRCONC"]
425
+ if "YOWRCONC" in reverse_label_mapping:
426
+ label_str = reverse_label_mapping["YOWRCONC"].get(label_val, label_val)
427
+ row_text.append(f"YOWRCONC={label_str}")
428
  else:
429
+ row_text.append(f"YOWRCONC={label_val}")
 
 
430
 
431
+ nn_rows.append(f"- **Neighbor ID {idx}** (distance={nr['distance']}): " + ", ".join(row_text))
432
 
433
  similar_patient_markdown = (
434
  "### Nearest Neighbors (Simple Hamming Distance)\n"
435
+ f"We searched for the top **{K}** patients whose features most closely match your input.\n\n"
436
+ "> **Note**: “Nearest neighbor” methods for high-dimensional or purely categorical data can be non-trivial. "
437
+ "This demo simply uses a Hamming distance over all input features and picks K=5 neighbors. "
438
  "In a real application, you would refine which features are most relevant, how to encode them, "
439
  "and how many neighbors to select.\n\n"
440
+ "Below is a brief overview of each neighbor's input-feature values and one example label (`YOWRCONC`).\n\n"
441
+ + "\n".join(nn_rows)
 
 
 
442
  )
443
 
444
+ # F) Co-occurrence Plot from the previous example (kept for completeness)
445
+ if all(col in df.columns for col in ["YMDEYR", "YMDERSUD5ANY", "YOWRCONC"]):
446
+ co_occ_data = df.groupby(["YMDEYR", "YMDERSUD5ANY", "YOWRCONC"]).size().reset_index(name="count")
447
+ fig_co_occ = px.bar(
448
+ co_occ_data,
449
+ x="YMDEYR",
450
+ y="count",
451
+ color="YOWRCONC",
452
+ facet_col="YMDERSUD5ANY",
453
+ title="Co-Occurrence Plot: YMDEYR and YMDERSUD5ANY vs YOWRCONC"
454
+ )
455
+ else:
456
+ fig_co_occ = px.bar(title="Co-occurrence plot not available (check columns).")
457
+
458
+ # =======================
459
+ # RETURN EVERYTHING
460
+ # We have 8 outputs:
461
+ # 1) Prediction Results (Textbox)
462
+ # 2) Mental Health Severity (Textbox)
463
+ # 3) Total Patient Count (Markdown)
464
+ # 4) Distribution Plot (for multiple input features vs. multiple labels)
465
+ # 5) Nearest Neighbors Summary (Markdown)
466
+ # 6) Co-Occurrence Plot
467
+ # 7) Bar Chart for input features
468
+ # 8) Bar Chart for predicted labels
469
+ # =======================
470
  return (
471
+ formatted_results,
472
+ severity,
473
+ total_patient_count_markdown,
474
+ fig_dist,
475
+ similar_patient_markdown,
476
+ fig_co_occ,
477
+ fig_bar_input,
478
+ fig_bar_labels
479
  )
480
 
481
+ ######################################
482
+ # 5) MAPPING user-friendly text => numeric
483
+ ######################################
484
+ input_mapping = {
485
+ 'YNURSMDE': {"Yes": 1, "No": 0},
486
+ 'YMDEYR': {"Yes": 1, "No": 2},
487
+ 'YSOCMDE': {"Yes": 1, "No": 0},
488
+ 'YMDESUD5ANYO': {"SUD only, no MDE": 1, "MDE only, no SUD": 2, "SUD and MDE": 3, "Neither SUD or MDE": 4},
489
+ 'YMSUD5YANY': {"Yes": 1, "No": 0},
490
+ 'YUSUITHK': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
491
+ 'YMDETXRX': {"Yes": 1, "No": 0},
492
+ 'YUSUITHKYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
493
+ 'YMDERSUD5ANY': {"Yes": 1, "No": 0},
494
+ 'YUSUIPLNYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
495
+ 'YCOUNMDE': {"Yes": 1, "No": 0},
496
+ 'YPSY1MDE': {"Yes": 1, "No": 0},
497
+ 'YHLTMDE': {"Yes": 1, "No": 0},
498
+ 'YDOCMDE': {"Yes": 1, "No": 0},
499
+ 'YPSY2MDE': {"Yes": 1, "No": 0},
500
+ 'YMDEHARX': {"Yes": 1, "No": 0},
501
+ 'LVLDIFMEM2': {"No Difficulty": 1, "Some difficulty": 2, "A lot of difficulty or cannot do at all": 3},
502
+ 'MDEIMPY': {"Yes": 1, "No": 2},
503
+ 'YMDEHPO': {"Yes": 1, "No": 0},
504
+ 'YMIMS5YANY': {"Yes": 1, "No": 0},
505
+ 'YMDEIMAD5YR': {"Yes": 1, "No": 0},
506
+ 'YMIUD5YANY': {"Yes": 1, "No": 0},
507
+ 'YMDEHPRX': {"Yes": 1, "No": 0},
508
+ 'YMIMI5YANY': {"Yes": 1, "No": 0},
509
+ 'YUSUIPLN': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
510
+ 'YTXMDEYR': {"Yes": 1, "No": 0},
511
+ 'YMDEAUD5YR': {"Yes": 1, "No": 0},
512
+ 'YRXMDEYR': {"Yes": 1, "No": 0},
513
+ 'YMDELT': {"Yes": 1, "No": 2}
514
+ }
515
+
516
+ ######################################
517
+ # 6) GRADIO INTERFACE
518
+ ######################################
519
+ # We have 8 outputs in total:
520
+ # 1) Prediction Results
521
+ # 2) Mental Health Severity
522
+ # 3) Total Patient Count
523
+ # 4) Distribution Plot
524
+ # 5) Nearest Neighbors
525
+ # 6) Co-Occurrence Plot
526
+ # 7) Bar Chart for input features
527
+ # 8) Bar Chart for predicted labels
528
+
529
+ import gradio as gr
530
+
531
+ # Define the inputs in the same order as function signature
532
+ inputs = [
533
+ gr.Dropdown(list(input_mapping['YMDEYR'].keys()), label="YMDEYR: PAST YEARS MAJOR DEPRESSIVE EPISODE"),
534
+ gr.Dropdown(list(input_mapping['YMDERSUD5ANY'].keys()), label="YMDERSUD5ANY: MDE OR SUBSTANCE USE DISORDER - ANY"),
535
+ gr.Dropdown(list(input_mapping['YMDEIMAD5YR'].keys()), label="YMDEIMAD5YR: MDE WITH SEV. IMP + ALCOHOL USE DISORDER"),
536
+ gr.Dropdown(list(input_mapping['YMIMS5YANY'].keys()), label="YMIMS5YANY: MDE W/ SEV. IMP + SUBSTANCE USE DISORDER"),
537
+ gr.Dropdown(list(input_mapping['YMDELT'].keys()), label="YMDELT: HAD MAJOR DEPRESSIVE EPISODE IN LIFETIME"),
538
+ gr.Dropdown(list(input_mapping['YMDEHARX'].keys()), label="YMDEHARX: SAW HEALTH PROF + MEDS FOR MDE"),
539
+ gr.Dropdown(list(input_mapping['YMDEHPRX'].keys()), label="YMDEHPRX: SAW HEALTH PROF OR MEDS FOR MDE"),
540
+ gr.Dropdown(list(input_mapping['YMDETXRX'].keys()), label="YMDETXRX: RECEIVED TREATMENT/COUNSELING FOR MDE"),
541
+ gr.Dropdown(list(input_mapping['YMDEHPO'].keys()), label="YMDEHPO: SAW HEALTH PROF ONLY FOR MDE"),
542
+ gr.Dropdown(list(input_mapping['YMDEAUD5YR'].keys()), label="YMDEAUD5YR: MDE + ALCOHOL USE DISORDER"),
543
+ gr.Dropdown(list(input_mapping['YMIMI5YANY'].keys()), label="YMIMI5YANY: MDE W/ ILL DRUG USE DISORDER"),
544
+ gr.Dropdown(list(input_mapping['YMIUD5YANY'].keys()), label="YMIUD5YANY: MDE + ILL DRUG USE DISORDER"),
545
+ gr.Dropdown(list(input_mapping['YMDESUD5ANYO'].keys()), label="YMDESUD5ANYO: MDE vs. SUD vs. BOTH vs. NEITHER"),
546
+
547
+ # Consultations
548
+ gr.Dropdown(list(input_mapping['YNURSMDE'].keys()), label="YNURSMDE: SAW/TALK TO NURSE/OT ABOUT MDE"),
549
+ gr.Dropdown(list(input_mapping['YSOCMDE'].keys()), label="YSOCMDE: SAW/TALK TO SOCIAL WORKER ABOUT MDE"),
550
+ gr.Dropdown(list(input_mapping['YCOUNMDE'].keys()), label="YCOUNMDE: SAW/TALK TO COUNSELOR ABOUT MDE"),
551
+ gr.Dropdown(list(input_mapping['YPSY1MDE'].keys()), label="YPSY1MDE: SAW/TALK TO PSYCHOLOGIST ABOUT MDE"),
552
+ gr.Dropdown(list(input_mapping['YPSY2MDE'].keys()), label="YPSY2MDE: SAW/TALK TO PSYCHIATRIST ABOUT MDE"),
553
+ gr.Dropdown(list(input_mapping['YHLTMDE'].keys()), label="YHLTMDE: SAW/TALK TO HEALTH PROFESSIONAL ABOUT MDE"),
554
+ gr.Dropdown(list(input_mapping['YDOCMDE'].keys()), label="YDOCMDE: SAW/TALK TO GP/FAMILY MD ABOUT MDE"),
555
+ gr.Dropdown(list(input_mapping['YTXMDEYR'].keys()), label="YTXMDEYR: SAW/TALK DOCTOR/HEALTH PROF FOR MDE"),
556
+
557
+ # Suicidal thoughts/plans
558
+ gr.Dropdown(list(input_mapping['YUSUITHKYR'].keys()), label="YUSUITHKYR: SERIOUSLY THOUGHT ABOUT KILLING SELF"),
559
+ gr.Dropdown(list(input_mapping['YUSUIPLNYR'].keys()), label="YUSUIPLNYR: MADE PLANS TO KILL SELF"),
560
+ gr.Dropdown(list(input_mapping['YUSUITHK'].keys()), label="YUSUITHK: THINK ABOUT KILLING SELF (12 MONTHS)"),
561
+ gr.Dropdown(list(input_mapping['YUSUIPLN'].keys()), label="YUSUIPLN: MADE PLANS TO KILL SELF (12 MONTHS)"),
562
+
563
+ # Impairments
564
+ gr.Dropdown(list(input_mapping['MDEIMPY'].keys()), label="MDEIMPY: MDE W/ SEVERE ROLE IMPAIRMENT"),
565
+ gr.Dropdown(list(input_mapping['LVLDIFMEM2'].keys()), label="LVLDIFMEM2: LEVEL OF DIFFICULTY REMEMBERING/CONCENTRATING"),
566
+ gr.Dropdown(list(input_mapping['YMSUD5YANY'].keys()), label="YMSUD5YANY: MDE + SUBSTANCE USE DISORDER - ANY"),
567
+ gr.Dropdown(list(input_mapping['YRXMDEYR'].keys()), label="YRXMDEYR: USED MEDS FOR MDE IN PAST YEAR"),
568
+ ]
569
+
570
+ # The 8 outputs
571
+ outputs = [
572
+ gr.Textbox(label="Prediction Results", lines=30),
573
+ gr.Textbox(label="Mental Health Severity", lines=4),
574
+ gr.Markdown(label="Total Patient Count"),
575
+ gr.Plot(label="Distribution Plot (Sample of Features & Labels)"),
576
+ gr.Markdown(label="Nearest Neighbors Summary"),
577
+ gr.Plot(label="Co-Occurrence Plot"),
578
+ gr.Plot(label="Number of Patients per Input Feature"),
579
+ gr.Plot(label="Number of Patients with Predicted Labels")
580
+ ]
581
+
582
+ ######################################
583
+ # 7) WRAPPER FOR PREDICT
584
+ ######################################
585
  def predict_with_text(
586
  YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
587
  YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
 
589
  YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
590
  YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
591
  ):
592
+ # Validate user inputs
593
  if not validate_inputs(
594
  YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
595
  YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
 
598
  YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
599
  ):
600
  return (
601
+ "Please select all required fields.",
602
+ "Validation Error",
603
+ "No data",
604
+ None,
605
+ "No data",
606
+ None,
607
+ None,
608
+ None
609
  )
610
 
611
+ # Map user-friendly text to numeric
612
  user_inputs = {
613
  'YNURSMDE': input_mapping['YNURSMDE'][YNURSMDE],
614
  'YMDEYR': input_mapping['YMDEYR'][YMDEYR],
 
640
  'YRXMDEYR': input_mapping['YRXMDEYR'][YRXMDEYR],
641
  'YMDELT': input_mapping['YMDELT'][YMDELT]
642
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
643
 
644
+ # Pass these mapped values into the core predict function
645
+ return predict(**user_inputs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
646
 
647
+ # Optional custom CSS
648
  custom_css = """
649
  .gradio-container * {
650
  color: #1B1212 !important;
 
662
  }
663
  """
664
 
665
+ ######################################
666
+ # 8) LAUNCH
667
+ ######################################
668
  interface = gr.Interface(
669
+ fn=predict_with_text,
670
+ inputs=inputs,
671
+ outputs=outputs,
672
+ title="Adolescents with Substance Use Mental Health Screening (NSDUH Data)",
673
+ css=custom_css
674
  )
675
 
676
  if __name__ == "__main__":