pantdipendra committed
Commit 1fd21ae · verified · 1 Parent(s): cf4c3a5

Update app.py

Files changed (1): app.py (+288, -418)
app.py CHANGED
@@ -1,11 +1,10 @@
1
  import pickle
2
-
3
  import gradio as gr
4
  import numpy as np
5
  import pandas as pd
6
  import plotly.express as px
7
 
8
- # Load the training CSV once (outside the functions so it is read only once).
9
  df = pd.read_csv("X_train_Y_Train_merged_train.csv")
10
 
11
  ######################################
@@ -17,74 +16,52 @@ class ModelPredictor:
17
  self.model_filenames = model_filenames
18
  self.models = self.load_models()
19
  # Mapping from label column to human-readable strings for 0/1
20
- # (Adjust as needed for the columns you actually have.)
21
  self.prediction_map = {
22
- "YOWRCONC": ["Did not have difficulty concentrating", "Had difficulty concentrating"],
23
- "YOSEEDOC": ["Did not feel the need to see a doctor", "Felt the need to see a doctor"],
24
- "YOWRHRS": ["Did not have trouble sleeping", "Had trouble sleeping"],
25
- "YO_MDEA5": ["Others did not notice restlessness/lethargy", "Others noticed restlessness/lethargy"],
26
- "YOWRCHR": ["Did not feel so sad that nothing could cheer up", "Felt so sad that nothing could cheer up"],
27
- "YOWRLSIN": ["Did not feel bored and lose interest in all enjoyable things",
28
- "Felt bored and lost interest in all enjoyable things"],
29
- "YODPPROB": ["Did not have other problems for 2+ weeks", "Had other problems for 2+ weeks"],
30
- "YOWRPROB": ["Did not have the worst time ever feeling", "Had the worst time ever feeling"],
31
- "YODPR2WK": ["Did not have periods where feelings lasted 2+ weeks",
32
- "Had periods where feelings lasted 2+ weeks"],
33
- "YOWRDEPR": ["Did not feel sad/depressed mostly everyday", "Felt sad/depressed mostly everyday"],
34
- "YODPDISC": ["Overall mood duration was not sad/depressed",
35
- "Overall mood duration was sad/depressed (discrepancy)"],
36
- "YOLOSEV": ["Did not lose interest in enjoyable things and activities",
37
- "Lost interest in enjoyable things and activities"],
38
- "YOWRDCSN": ["Was able to make decisions", "Was unable to make decisions"],
39
- "YODSMMDE": ["Never had depression symptoms lasting 2 weeks or longer",
40
- "Had depression symptoms lasting 2 weeks or longer"],
41
- "YO_MDEA3": ["Did not experience changes in appetite or weight",
42
- "Experienced changes in appetite or weight"],
43
- "YODPLSIN": ["Never lost interest and felt bored", "Lost interest and felt bored"],
44
- "YOWRELES": ["Did not eat less than usual", "Ate less than usual"],
45
- "YODSCEV": ["Had fewer severe symptoms of depression", "Had more severe symptoms of depression"],
46
- "YOPB2WK": ["Did not experience uneasy feelings lasting every day for 2+ weeks or longer",
47
- "Experienced uneasy feelings lasting every day for 2+ weeks or longer"],
48
- "YO_MDEA2": ["Did not have issues with physical and mental well-being every day for 2 weeks or longer",
49
- "Had issues with physical and mental well-being every day for 2 weeks or longer"]
50
  }
51
 
52
  def load_models(self):
53
  models = []
54
- for filename in self.model_filenames:
55
- filepath = self.model_path + filename
56
- with open(filepath, 'rb') as file:
57
- model = pickle.load(file)
58
- models.append(model)
59
  return models
60
 
61
  def make_predictions(self, user_input):
62
- """
63
- Returns a list of numpy arrays, each array is [0] or [1].
64
- The i-th array corresponds to the i-th model in self.models.
65
- """
66
- predictions = []
67
- for model in self.models:
68
- pred = model.predict(user_input)
69
- pred = np.array(pred).flatten()
70
- predictions.append(pred)
71
- return predictions
72
 
73
  def get_majority_vote(self, predictions):
74
- """
75
- Flatten all predictions from all models, combine them into a single array,
76
- then find the majority class (0 or 1) across all of them.
77
- """
78
- combined_predictions = np.concatenate(predictions)
79
- majority_vote = np.bincount(combined_predictions).argmax()
80
- return majority_vote
81
-
82
- # Based on Equal Interval and Percentage-Based Method
83
- # Severe: 13 to 16 votes (upper 25%)
84
- # Moderate: 9 to 12 votes (upper-middle 25%)
85
- # Low: 5 to 8 votes (lower-middle 25%)
86
- # Very Low: 0 to 4 votes (lower 25%)
87
  def evaluate_severity(self, majority_vote_count):
 
88
  if majority_vote_count >= 13:
89
  return "Mental health severity: Severe"
90
  elif majority_vote_count >= 9:
@@ -95,7 +72,7 @@ class ModelPredictor:
95
  return "Mental health severity: Very Low"
96
 
97
  ######################################
98
- # 2) MODEL & DATA
99
  ######################################
100
  model_filenames = [
101
  "YOWRCONC.pkl", "YOSEEDOC.pkl", "YO_MDEA5.pkl", "YOWRLSIN.pkl",
@@ -110,22 +87,36 @@ predictor = ModelPredictor(model_path, model_filenames)
110
  # 3) INPUT VALIDATION
111
  ######################################
112
  def validate_inputs(*args):
 
113
  for arg in args:
114
- if arg == '' or arg is None: # Assuming empty string or None as unselected
115
  return False
116
  return True
117
 
118
  ######################################
119
- # 4) MAIN PREDICTION FUNCTION
120
  ######################################
121
  def predict(
 
122
  YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
123
  YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
124
  YMDESUD5ANYO, YNURSMDE, YSOCMDE, YCOUNMDE, YPSY1MDE, YPSY2MDE,
125
  YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
126
- YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
 
 
127
  ):
128
- # Prepare user_input dataframe for prediction
129
  user_input_data = {
130
  'YNURSMDE': [int(YNURSMDE)],
131
  'YMDEYR': [int(YMDEYR)],
@@ -159,29 +150,21 @@ def predict(
159
  }
160
  user_input = pd.DataFrame(user_input_data)
161
 
162
- # 1) Make predictions with each model
163
  predictions = predictor.make_predictions(user_input)
164
-
165
- # 2) Calculate majority vote (0 or 1) across all models
166
  majority_vote = predictor.get_majority_vote(predictions)
167
-
168
- # 3) Count how many 1's in all predictions combined
169
- majority_vote_count = sum([1 for pred in np.concatenate(predictions) if pred == 1])
170
-
171
- # 4) Evaluate severity
172
  severity = predictor.evaluate_severity(majority_vote_count)
173
 
174
- # 5) Prepare detailed results (group them)
175
- # We keep the old grouping as an example, but you can adapt as needed.
176
- results = {
177
  "Concentration_and_Decision_Making": [],
178
  "Sleep_and_Energy_Levels": [],
179
  "Mood_and_Emotional_State": [],
180
  "Appetite_and_Weight_Changes": [],
181
  "Duration_and_Severity_of_Depression_Symptoms": []
182
  }
183
-
184
- prediction_groups = {
185
  "Concentration_and_Decision_Making": ["YOWRCONC", "YOWRDCSN"],
186
  "Sleep_and_Energy_Levels": ["YOWRHRS", "YO_MDEA5", "YOWRELES", "YO_MDEA2"],
187
  "Mood_and_Emotional_State": ["YOWRCHR", "YOWRLSIN", "YOWRDEPR", "YODPDISC",
@@ -192,313 +175,198 @@ def predict(
192
  "YOPB2WK"]
193
  }
194
 
195
- # For textual results
196
- for i, pred in enumerate(predictions):
197
- model_name = model_filenames[i].split('.')[0]
198
- pred_value = pred[0]
199
- # Map the prediction value to a human-readable string
200
- if model_name in predictor.prediction_map and pred_value in [0, 1]:
201
- result_text = f"Model {model_name}: {predictor.prediction_map[model_name][pred_value]}"
 
202
  else:
203
- # Fallback
204
- result_text = f"Model {model_name}: Prediction = {pred_value} (unmapped)"
205
-
206
- # Append to the appropriate group if matched
207
- found_group = False
208
- for group_name, group_models in prediction_groups.items():
209
- if model_name in group_models:
210
- results[group_name].append(result_text)
211
- found_group = True
212
  break
213
- if not found_group:
214
- # If it doesn't match any group, skip or handle differently
215
  pass
216
 
217
- # Format the grouped results
218
- formatted_results = []
219
- for group, preds in results.items():
220
- if preds:
221
- formatted_results.append(f"Group {group.replace('_', ' ')}:")
222
- formatted_results.append("\n".join(preds))
223
- formatted_results.append("\n")
224
- formatted_results = "\n".join(formatted_results).strip()
225
- if not formatted_results:
226
- formatted_results = "No predictions made. Please check your inputs."
227
-
228
- # If too many unknown predictions, add a note
229
- num_unknown = len([p for group_preds in results.values() for p in group_preds if "(unmapped)" in p])
230
- if num_unknown > len(model_filenames) / 2:
231
- severity += " (Unknown prediction count is high. Please consult with a human.)"
232
-
233
- # =============== ADDITIONAL FEATURES ===============
234
-
235
- # A) Total Patient Count
236
  total_patients = len(df)
237
- total_patient_count_markdown = (
238
  "### Total Patient Count\n"
239
- f"There are **{total_patients}** total patients in the dataset.\n"
240
- "All subsequent analyses refer to these patients."
241
  )
242
 
243
- # B) Bar Chart for input features (how many share same value as user_input)
244
  input_counts = {}
245
- for col in user_input_data.keys():
246
- val = user_input_data[col][0]
247
- same_val_count = len(df[df[col] == val])
248
- input_counts[col] = same_val_count
249
- bar_input_data = pd.DataFrame({
250
- "Feature": list(input_counts.keys()),
251
- "Count": list(input_counts.values())
252
- })
253
- fig_bar_input = px.bar(
254
- bar_input_data,
255
- x="Feature",
256
- y="Count",
257
- title="Number of Patients with the Same Value for Each Input Feature",
258
- labels={"Feature": "Input Feature", "Count": "Number of Patients"}
259
  )
260
- fig_bar_input.update_layout(xaxis={'categoryorder':'total descending'})
261
 
262
- # C) Bar Chart for predicted labels (distribution in df)
263
  label_counts = {}
264
- for i, pred in enumerate(predictions):
265
- model_name = model_filenames[i].split('.')[0]
266
- pred_value = pred[0]
267
- if pred_value in [0, 1]:
268
- label_counts[model_name] = len(df[df[model_name] == pred_value])
 
269
  if len(label_counts) > 0:
270
- bar_label_data = pd.DataFrame({
271
- "Model": list(label_counts.keys()),
272
  "Count": list(label_counts.values())
273
  })
274
- fig_bar_labels = px.bar(
275
- bar_label_data,
276
- x="Model",
277
  y="Count",
278
- title="Number of Patients with the Same Predicted Label",
279
- labels={"Model": "Predicted Column", "Count": "Patient Count"}
280
  )
281
  else:
282
- # Fallback if no valid predictions
283
- fig_bar_labels = px.bar(title="No valid predicted labels to display")
284
-
285
- # D) Distribution Plot: All Input Features vs. All Predicted Labels
286
- # This can create MANY subplots if you have many features & labels.
287
- # We'll do a small demonstration with a subset of input features & model columns
288
- # to avoid overwhelming the UI.
289
- demonstration_features = list(user_input_data.keys())[:4] # first 4 features as a sample
290
- demonstration_labels = [fn.split('.')[0] for fn in model_filenames[:3]] # first 3 labels as a sample
291
-
292
- # We'll build a single figure with "facet_col" = label and "facet_row" = feature (small sample)
293
- # The approach: for each (feature, label), group by (feature_value, label_value) -> count.
294
- # Then we combine them into one big DataFrame with "feature" & "label" columns for Plotly facets.
295
- dist_rows = []
296
- for feat in demonstration_features:
297
  if feat not in df.columns:
298
  continue
299
- for lbl in demonstration_labels:
300
  if lbl not in df.columns:
301
  continue
302
- tmp_df = df.groupby([feat, lbl]).size().reset_index(name="count")
303
- tmp_df["feature"] = feat
304
- tmp_df["label"] = lbl
305
- dist_rows.append(tmp_df)
306
- if len(dist_rows) > 0:
307
- big_dist_df = pd.concat(dist_rows, ignore_index=True)
308
- # We can re-map numeric to user-friendly text for "feat" if desired, but each feature might have a different mapping.
309
- # For now, we just show numeric codes. Real usage would do a reverse mapping if feasible.
310
-
311
- # For the label (0,1), we can map to short strings if we want (like "Label0" / "Label1"), or a direct numeric.
312
  fig_dist = px.bar(
313
  big_dist_df,
314
- x=big_dist_df.columns[0], # the feature's value is the 0-th col in groupby
315
  y="count",
316
- color=big_dist_df.columns[1], # the label's value is the 1st col in groupby
317
  facet_row="feature",
318
  facet_col="label",
319
- title="Distribution of Sample Input Features vs. Sample Predicted Labels (Demo)",
320
- labels={
321
- big_dist_df.columns[0]: "Feature Value",
322
- big_dist_df.columns[1]: "Label Value"
323
- }
324
  )
325
- fig_dist.update_layout(height=800)
326
  else:
327
- fig_dist = px.bar(title="No distribution plot could be generated (check feature/label columns).")
328
-
329
- # E) Nearest Neighbors: Hamming Distance, K=5, with disclaimers & user-friendly text
330
- # "Nearest neighbor” methods for high-dimensional or purely categorical data can be non-trivial.
331
- # This demo simply uses a Hamming distance over all input features and picks K=5 neighbors.
332
- # In a real application, you would refine which features are most relevant, how to encode them,
333
- # and how many neighbors to select.
334
- # We also show how to revert numeric codes -> user-friendly text.
335
-
336
- # 1. Invert the user-friendly text mapping (for inputs).
337
- # We'll assume input_mapping is consistent. We build a reverse mapping for each column.
338
- reverse_input_mapping = {}
339
- # We'll build it after the code block below for each column.
340
-
341
- # 2. Invert label mappings from predictor.prediction_map if needed
342
- # For each label column, 0 => first string, 1 => second string
343
- # We'll store them in a dict: reverse_label_mapping[label_col][0 or 1] => string
344
- reverse_label_mapping = {}
345
- for lbl, str_list in predictor.prediction_map.items():
346
- # str_list[0] => for 0, str_list[1] => for 1
347
- reverse_label_mapping[lbl] = {
348
- 0: str_list[0],
349
- 1: str_list[1]
350
- }
351
-
352
- # Build the reverse input mapping from the provided dictionary
353
- # We'll define that dictionary below to ensure we can invert it:
354
- input_mapping = {
355
- 'YNURSMDE': {"Yes": 1, "No": 0},
356
- 'YMDEYR': {"Yes": 1, "No": 2},
357
- 'YSOCMDE': {"Yes": 1, "No": 0},
358
- 'YMDESUD5ANYO': {"SUD only, no MDE": 1, "MDE only, no SUD": 2, "SUD and MDE": 3, "Neither SUD or MDE": 4},
359
- 'YMSUD5YANY': {"Yes": 1, "No": 0},
360
- 'YUSUITHK': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
361
- 'YMDETXRX': {"Yes": 1, "No": 0},
362
- 'YUSUITHKYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
363
- 'YMDERSUD5ANY': {"Yes": 1, "No": 0},
364
- 'YUSUIPLNYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
365
- 'YCOUNMDE': {"Yes": 1, "No": 0},
366
- 'YPSY1MDE': {"Yes": 1, "No": 0},
367
- 'YHLTMDE': {"Yes": 1, "No": 0},
368
- 'YDOCMDE': {"Yes": 1, "No": 0},
369
- 'YPSY2MDE': {"Yes": 1, "No": 0},
370
- 'YMDEHARX': {"Yes": 1, "No": 0},
371
- 'LVLDIFMEM2': {"No Difficulty": 1, "Some difficulty": 2, "A lot of difficulty or cannot do at all": 3},
372
- 'MDEIMPY': {"Yes": 1, "No": 2},
373
- 'YMDEHPO': {"Yes": 1, "No": 0},
374
- 'YMIMS5YANY': {"Yes": 1, "No": 0},
375
- 'YMDEIMAD5YR': {"Yes": 1, "No": 0},
376
- 'YMIUD5YANY': {"Yes": 1, "No": 0},
377
- 'YMDEHPRX': {"Yes": 1, "No": 0},
378
- 'YMIMI5YANY': {"Yes": 1, "No": 0},
379
- 'YUSUIPLN': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
380
- 'YTXMDEYR': {"Yes": 1, "No": 0},
381
- 'YMDEAUD5YR': {"Yes": 1, "No": 0},
382
- 'YRXMDEYR': {"Yes": 1, "No": 0},
383
- 'YMDELT': {"Yes": 1, "No": 2}
384
- }
385
-
386
- # Build the reverse mapping for each column
387
- for col, fwd_map in input_mapping.items():
388
- reverse_input_mapping[col] = {v: k for k, v in fwd_map.items()}
389
-
390
- # 3. Calculate Hamming distance for each row
391
- # We'll consider the columns in user_input for comparison
392
- features_to_compare = list(user_input.columns)
393
- subset_df = df[features_to_compare].copy()
394
- user_series = user_input.iloc[0]
395
-
396
  distances = []
397
- for idx, row in subset_df.iterrows():
398
- dist = sum(row[col] != user_series[col] for col in features_to_compare)
399
- distances.append(dist)
400
-
401
- df_with_dist = df.copy()
402
- df_with_dist["distance"] = distances
403
-
404
- # 4. Sort by distance ascending, pick top K=5
405
- K = 5
406
- nearest_neighbors = df_with_dist.sort_values("distance", ascending=True).head(K)
407
-
408
- # 5. Summarize neighbor info in user-friendly text
409
- # For demonstration, let's show a small table with each neighbor's values
410
- # for the same features. We'll also show a label or two.
411
- # We'll do this in Markdown format.
412
- nn_rows = []
413
- for idx, nr in nearest_neighbors.iterrows():
414
- # Convert each feature to text if possible
415
- row_text = []
416
- for col in features_to_compare:
417
- val_numeric = nr[col]
418
- if col in reverse_input_mapping:
419
- row_text.append(f"{col}={reverse_input_mapping[col].get(val_numeric, val_numeric)}")
420
- else:
421
- row_text.append(f"{col}={val_numeric}")
422
- # Let's also show YOWRCONC as an example label (if present)
423
- if "YOWRCONC" in nearest_neighbors.columns:
424
- label_val = nr["YOWRCONC"]
425
- if "YOWRCONC" in reverse_label_mapping:
426
- label_str = reverse_label_mapping["YOWRCONC"].get(label_val, label_val)
427
- row_text.append(f"YOWRCONC={label_str}")
428
- else:
429
- row_text.append(f"YOWRCONC={label_val}")
430
-
431
- nn_rows.append(f"- **Neighbor ID {idx}** (distance={nr['distance']}): " + ", ".join(row_text))
432
-
433
- similar_patient_markdown = (
434
- "### Nearest Neighbors (Simple Hamming Distance)\n"
435
- f"We searched for the top **{K}** patients whose features most closely match your input.\n\n"
436
- "> **Note**: “Nearest neighbor” methods for high-dimensional or purely categorical data can be non-trivial. "
437
- "This demo simply uses a Hamming distance over all input features and picks K=5 neighbors. "
438
- "In a real application, you would refine which features are most relevant, how to encode them, "
439
- "and how many neighbors to select.\n\n"
440
- "Below is a brief overview of each neighbor's input-feature values and one example label (`YOWRCONC`).\n\n"
441
- + "\n".join(nn_rows)
442
- )
443
-
444
- # F) Co-occurrence Plot from the previous example (kept for completeness)
445
- if all(col in df.columns for col in ["YMDEYR", "YMDERSUD5ANY", "YOWRCONC"]):
446
- co_occ_data = df.groupby(["YMDEYR", "YMDERSUD5ANY", "YOWRCONC"]).size().reset_index(name="count")
447
- fig_co_occ = px.bar(
448
- co_occ_data,
449
- x="YMDEYR",
450
- y="count",
451
- color="YOWRCONC",
452
- facet_col="YMDERSUD5ANY",
453
- title="Co-Occurrence Plot: YMDEYR and YMDERSUD5ANY vs YOWRCONC"
454
- )
455
  else:
456
- fig_co_occ = px.bar(title="Co-occurrence plot not available (check columns).")
457
-
458
- # =======================
459
- # RETURN EVERYTHING
460
- # We have 8 outputs:
461
- # 1) Prediction Results (Textbox)
462
- # 2) Mental Health Severity (Textbox)
463
- # 3) Total Patient Count (Markdown)
464
- # 4) Distribution Plot (for multiple input features vs. multiple labels)
465
- # 5) Nearest Neighbors Summary (Markdown)
466
- # 6) Co-Occurrence Plot
467
- # 7) Bar Chart for input features
468
- # 8) Bar Chart for predicted labels
469
- # =======================
470
  return (
471
- formatted_results,
472
- severity,
473
- total_patient_count_markdown,
474
- fig_dist,
475
- similar_patient_markdown,
476
- fig_co_occ,
477
- fig_bar_input,
478
- fig_bar_labels
479
  )
480
 
481
  ######################################
482
- # 5) MAPPING user-friendly text => numeric
483
  ######################################
484
  input_mapping = {
485
  'YNURSMDE': {"Yes": 1, "No": 0},
486
  'YMDEYR': {"Yes": 1, "No": 2},
487
  'YSOCMDE': {"Yes": 1, "No": 0},
488
- 'YMDESUD5ANYO': {"SUD only, no MDE": 1, "MDE only, no SUD": 2, "SUD and MDE": 3, "Neither SUD or MDE": 4},
489
  'YMSUD5YANY': {"Yes": 1, "No": 0},
490
- 'YUSUITHK': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
491
  'YMDETXRX': {"Yes": 1, "No": 0},
492
- 'YUSUITHKYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
493
  'YMDERSUD5ANY': {"Yes": 1, "No": 0},
494
- 'YUSUIPLNYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
495
  'YCOUNMDE': {"Yes": 1, "No": 0},
496
  'YPSY1MDE': {"Yes": 1, "No": 0},
497
  'YHLTMDE': {"Yes": 1, "No": 0},
498
  'YDOCMDE': {"Yes": 1, "No": 0},
499
  'YPSY2MDE': {"Yes": 1, "No": 0},
500
  'YMDEHARX': {"Yes": 1, "No": 0},
501
- 'LVLDIFMEM2': {"No Difficulty": 1, "Some difficulty": 2, "A lot of difficulty or cannot do at all": 3},
502
  'MDEIMPY': {"Yes": 1, "No": 2},
503
  'YMDEHPO': {"Yes": 1, "No": 0},
504
  'YMIMS5YANY': {"Yes": 1, "No": 0},
@@ -506,7 +374,7 @@ input_mapping = {
506
  'YMIUD5YANY': {"Yes": 1, "No": 0},
507
  'YMDEHPRX': {"Yes": 1, "No": 0},
508
  'YMIMI5YANY': {"Yes": 1, "No": 0},
509
- 'YUSUIPLN': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
510
  'YTXMDEYR': {"Yes": 1, "No": 0},
511
  'YMDEAUD5YR': {"Yes": 1, "No": 0},
512
  'YRXMDEYR': {"Yes": 1, "No": 0},
@@ -514,89 +382,93 @@ input_mapping = {
514
  }
515
 
516
  ######################################
517
- # 6) GRADIO INTERFACE
518
  ######################################
519
- # We have 8 outputs in total:
520
- # 1) Prediction Results
521
- # 2) Mental Health Severity
522
- # 3) Total Patient Count
523
- # 4) Distribution Plot
524
- # 5) Nearest Neighbors
525
- # 6) Co-Occurrence Plot
526
- # 7) Bar Chart for input features
527
- # 8) Bar Chart for predicted labels
528
-
529
  import gradio as gr
530
 
531
- # Define the inputs in the same order as function signature
532
- inputs = [
533
- gr.Dropdown(list(input_mapping['YMDEYR'].keys()), label="YMDEYR: PAST YEARS MAJOR DEPRESSIVE EPISODE"),
534
- gr.Dropdown(list(input_mapping['YMDERSUD5ANY'].keys()), label="YMDERSUD5ANY: MDE OR SUBSTANCE USE DISORDER - ANY"),
535
- gr.Dropdown(list(input_mapping['YMDEIMAD5YR'].keys()), label="YMDEIMAD5YR: MDE WITH SEV. IMP + ALCOHOL USE DISORDER"),
536
- gr.Dropdown(list(input_mapping['YMIMS5YANY'].keys()), label="YMIMS5YANY: MDE W/ SEV. IMP + SUBSTANCE USE DISORDER"),
537
- gr.Dropdown(list(input_mapping['YMDELT'].keys()), label="YMDELT: HAD MAJOR DEPRESSIVE EPISODE IN LIFETIME"),
538
- gr.Dropdown(list(input_mapping['YMDEHARX'].keys()), label="YMDEHARX: SAW HEALTH PROF + MEDS FOR MDE"),
539
- gr.Dropdown(list(input_mapping['YMDEHPRX'].keys()), label="YMDEHPRX: SAW HEALTH PROF OR MEDS FOR MDE"),
540
- gr.Dropdown(list(input_mapping['YMDETXRX'].keys()), label="YMDETXRX: RECEIVED TREATMENT/COUNSELING FOR MDE"),
541
- gr.Dropdown(list(input_mapping['YMDEHPO'].keys()), label="YMDEHPO: SAW HEALTH PROF ONLY FOR MDE"),
542
- gr.Dropdown(list(input_mapping['YMDEAUD5YR'].keys()), label="YMDEAUD5YR: MDE + ALCOHOL USE DISORDER"),
543
- gr.Dropdown(list(input_mapping['YMIMI5YANY'].keys()), label="YMIMI5YANY: MDE W/ ILL DRUG USE DISORDER"),
544
- gr.Dropdown(list(input_mapping['YMIUD5YANY'].keys()), label="YMIUD5YANY: MDE + ILL DRUG USE DISORDER"),
545
- gr.Dropdown(list(input_mapping['YMDESUD5ANYO'].keys()), label="YMDESUD5ANYO: MDE vs. SUD vs. BOTH vs. NEITHER"),
546
 
547
  # Consultations
548
- gr.Dropdown(list(input_mapping['YNURSMDE'].keys()), label="YNURSMDE: SAW/TALK TO NURSE/OT ABOUT MDE"),
549
- gr.Dropdown(list(input_mapping['YSOCMDE'].keys()), label="YSOCMDE: SAW/TALK TO SOCIAL WORKER ABOUT MDE"),
550
- gr.Dropdown(list(input_mapping['YCOUNMDE'].keys()), label="YCOUNMDE: SAW/TALK TO COUNSELOR ABOUT MDE"),
551
- gr.Dropdown(list(input_mapping['YPSY1MDE'].keys()), label="YPSY1MDE: SAW/TALK TO PSYCHOLOGIST ABOUT MDE"),
552
- gr.Dropdown(list(input_mapping['YPSY2MDE'].keys()), label="YPSY2MDE: SAW/TALK TO PSYCHIATRIST ABOUT MDE"),
553
- gr.Dropdown(list(input_mapping['YHLTMDE'].keys()), label="YHLTMDE: SAW/TALK TO HEALTH PROFESSIONAL ABOUT MDE"),
554
- gr.Dropdown(list(input_mapping['YDOCMDE'].keys()), label="YDOCMDE: SAW/TALK TO GP/FAMILY MD ABOUT MDE"),
555
- gr.Dropdown(list(input_mapping['YTXMDEYR'].keys()), label="YTXMDEYR: SAW/TALK DOCTOR/HEALTH PROF FOR MDE"),
556
-
557
- # Suicidal thoughts/plans
558
- gr.Dropdown(list(input_mapping['YUSUITHKYR'].keys()), label="YUSUITHKYR: SERIOUSLY THOUGHT ABOUT KILLING SELF"),
559
- gr.Dropdown(list(input_mapping['YUSUIPLNYR'].keys()), label="YUSUIPLNYR: MADE PLANS TO KILL SELF"),
560
- gr.Dropdown(list(input_mapping['YUSUITHK'].keys()), label="YUSUITHK: THINK ABOUT KILLING SELF (12 MONTHS)"),
561
- gr.Dropdown(list(input_mapping['YUSUIPLN'].keys()), label="YUSUIPLN: MADE PLANS TO KILL SELF (12 MONTHS)"),
562
 
563
  # Impairments
564
- gr.Dropdown(list(input_mapping['MDEIMPY'].keys()), label="MDEIMPY: MDE W/ SEVERE ROLE IMPAIRMENT"),
565
- gr.Dropdown(list(input_mapping['LVLDIFMEM2'].keys()), label="LVLDIFMEM2: LEVEL OF DIFFICULTY REMEMBERING/CONCENTRATING"),
566
- gr.Dropdown(list(input_mapping['YMSUD5YANY'].keys()), label="YMSUD5YANY: MDE + SUBSTANCE USE DISORDER - ANY"),
567
- gr.Dropdown(list(input_mapping['YRXMDEYR'].keys()), label="YRXMDEYR: USED MEDS FOR MDE IN PAST YEAR"),
568
  ]
569
 
570
- # The 8 outputs
 
 
 
 
 
 
 
 
 
 
 
571
  outputs = [
572
- gr.Textbox(label="Prediction Results", lines=30),
573
- gr.Textbox(label="Mental Health Severity", lines=4),
574
  gr.Markdown(label="Total Patient Count"),
575
- gr.Plot(label="Distribution Plot (Sample of Features & Labels)"),
576
- gr.Markdown(label="Nearest Neighbors Summary"),
577
- gr.Plot(label="Co-Occurrence Plot"),
578
- gr.Plot(label="Number of Patients per Input Feature"),
579
- gr.Plot(label="Number of Patients with Predicted Labels")
580
  ]
581
 
582
  ######################################
583
- # 7) WRAPPER FOR PREDICT
584
  ######################################
585
  def predict_with_text(
 
586
  YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
587
  YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
588
  YMDESUD5ANYO, YNURSMDE, YSOCMDE, YCOUNMDE, YPSY1MDE, YPSY2MDE,
589
  YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
590
- YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
 
591
  ):
592
- # Validate user inputs
593
- if not validate_inputs(
594
  YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
595
  YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
596
  YMDESUD5ANYO, YNURSMDE, YSOCMDE, YCOUNMDE, YPSY1MDE, YPSY2MDE,
597
  YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
598
  YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
599
- ):
 
600
  return (
601
  "Please select all required fields.",
602
  "Validation Error",
@@ -608,7 +480,7 @@ def predict_with_text(
608
  None
609
  )
610
 
611
- # Map user-friendly text to numeric
612
  user_inputs = {
613
  'YNURSMDE': input_mapping['YNURSMDE'][YNURSMDE],
614
  'YMDEYR': input_mapping['YMDEYR'][YMDEYR],
@@ -641,36 +513,34 @@ def predict_with_text(
641
  'YMDELT': input_mapping['YMDELT'][YMDELT]
642
  }
643
 
644
- # Pass these mapped values into the core predict function
645
- return predict(**user_inputs)
646
 
647
- # Optional custom CSS
648
  custom_css = """
649
- .gradio-container * {
650
- color: #1B1212 !important;
651
- }
652
- .gradio-container .form .form-group label {
653
- color: #1B1212 !important;
654
- }
655
- .gradio-container .output-textbox,
656
- .gradio-container .output-textbox textarea {
657
- color: #1B1212 !important;
658
- }
659
- .gradio-container .label,
660
- .gradio-container .input-label {
661
- color: #1B1212 !important;
662
- }
663
  """
664
 
665
- ######################################
666
- # 8) LAUNCH
667
- ######################################
668
  interface = gr.Interface(
669
- fn=predict_with_text,
670
  inputs=inputs,
671
- outputs=outputs,
672
- title="Adolescents with Substance Use Mental Health Screening (NSDUH Data)",
673
- css=custom_css
674
  )
675
 
676
  if __name__ == "__main__":
 
1
  import pickle
 
2
  import gradio as gr
3
  import numpy as np
4
  import pandas as pd
5
  import plotly.express as px
6
 
7
+ # Load the training CSV once.
8
  df = pd.read_csv("X_train_Y_Train_merged_train.csv")
9
 
10
  ######################################
 
16
  self.model_filenames = model_filenames
17
  self.models = self.load_models()
18
  # Mapping from label column to human-readable strings for 0/1
 
19
  self.prediction_map = {
20
+ "YOWRCONC": ["No difficulty concentrating", "Had difficulty concentrating"],
21
+ "YOSEEDOC": ["No need to see doctor", "Needed to see doctor"],
22
+ "YOWRHRS": ["No trouble sleeping", "Had trouble sleeping"],
23
+ "YO_MDEA5": ["Others didn't notice restlessness", "Others noticed restlessness"],
24
+ "YOWRCHR": ["Not sad beyond cheering", "Felt so sad no one could cheer up"],
25
+ "YOWRLSIN": ["Never felt bored/lost interest", "Felt bored/lost interest"],
26
+ "YODPPROB": ["No other problems for 2+ weeks", "Had other problems for 2+ weeks"],
27
+ "YOWRPROB": ["No worst time feeling", "Felt worst time ever"],
28
+ "YODPR2WK": ["No depressed feelings for 2+ wks", "Depressed feelings for 2+ wks"],
29
+ "YOWRDEPR": ["Not sad or depressed most days", "Sad or depressed most days"],
30
+ "YODPDISC": ["Mood not depressed overall", "Mood depressed overall (discrepancy)"],
31
+ "YOLOSEV": ["Did not lose interest in activities", "Lost interest in activities"],
32
+ "YOWRDCSN": ["Could make decisions", "Could not make decisions"],
33
+ "YODSMMDE": ["No 2+ week depression episodes", "Had 2+ week depression episodes"],
34
+ "YO_MDEA3": ["No appetite/weight changes", "Yes appetite/weight changes"],
35
+ "YODPLSIN": ["Never bored/lost interest", "Often bored/lost interest"],
36
+ "YOWRELES": ["Did not eat less", "Ate less than usual"],
37
+ "YODSCEV": ["Fewer severe symptoms", "More severe symptoms"],
38
+ "YOPB2WK": ["No uneasy feelings daily 2+ wks", "Uneasy feelings daily 2+ wks"],
39
+ "YO_MDEA2": ["No issues physical/mental daily", "Issues physical/mental daily 2+ wks"]
 
 
 
 
 
 
 
 
40
  }
41
 
42
  def load_models(self):
43
  models = []
44
+ for fn in self.model_filenames:
45
+ filepath = self.model_path + fn
46
+ with open(filepath, "rb") as file:
47
+ models.append(pickle.load(file))
 
48
  return models
49
 
50
  def make_predictions(self, user_input):
51
+ """Return list of numpy arrays, each array either [0] or [1]."""
52
+ preds = []
53
+ for m in self.models:
54
+ out = m.predict(user_input)
55
+ preds.append(np.array(out).flatten())
56
+ return preds
 
 
 
 
57
 
58
  def get_majority_vote(self, predictions):
59
+ """Flatten all predictions and find 0 or 1 with majority."""
60
+ combined = np.concatenate(predictions)
61
+ return np.bincount(combined).argmax()
62
+
63
  def evaluate_severity(self, majority_vote_count):
64
+ """Heuristic: Based on 16 total models, 0-4=Very Low, 5-8=Low, 9-12=Moderate, 13-16=Severe."""
65
  if majority_vote_count >= 13:
66
  return "Mental health severity: Severe"
67
  elif majority_vote_count >= 9:
 
72
  return "Mental health severity: Very Low"
73
 
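For reference, the voting pipeline above boils down to three NumPy calls. A minimal standalone sketch, assuming 16 binary models as configured below (the sample predictions are made up):

    import numpy as np

    # Hypothetical outputs from 16 single-row binary classifiers.
    predictions = [np.array([1]), np.array([1]), np.array([1])] + [np.array([0])] * 13

    combined = np.concatenate(predictions)           # shape (16,)
    majority_class = np.bincount(combined).argmax()  # most common class -> 0 here
    positive_votes = int((combined == 1).sum())      # 3 -> "Very Low" bucket (0-4)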
74
  ######################################
75
+ # 2) CONFIGURATIONS
76
  ######################################
77
  model_filenames = [
78
  "YOWRCONC.pkl", "YOSEEDOC.pkl", "YO_MDEA5.pkl", "YOWRLSIN.pkl",
 
87
  # 3) INPUT VALIDATION
88
  ######################################
89
  def validate_inputs(*args):
90
+ # Just ensure all required (non-co-occurrence) fields are picked
91
  for arg in args:
92
+ if arg == '' or arg is None:
93
  return False
94
  return True
95
 
96
  ######################################
97
+ # 4) PREDICTION FUNCTION
98
  ######################################
99
  def predict(
100
+ # Original required features
101
  YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
102
  YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
103
  YMDESUD5ANYO, YNURSMDE, YSOCMDE, YCOUNMDE, YPSY1MDE, YPSY2MDE,
104
  YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
105
+ YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR,
106
+ # **New** optional picks for co-occurrence
107
+ co_occ_feature1, co_occ_feature2, co_occ_label
108
  ):
109
+ """
110
+ Main function that:
111
+ - Predicts with the 16 models
112
+ - Aggregates results
113
+ - Produces severity
114
+ - Returns distribution & bar charts
115
+ - Finds K=2 Nearest Neighbors
116
+ - Produces *one* co-occurrence plot based on user-chosen columns
117
+ """
118
+
119
+ # 1) Build user_input for models
120
  user_input_data = {
121
  'YNURSMDE': [int(YNURSMDE)],
122
  'YMDEYR': [int(YMDEYR)],
 
150
  }
151
  user_input = pd.DataFrame(user_input_data)
152
 
153
+ # 2) Model Predictions
154
  predictions = predictor.make_predictions(user_input)
 
 
155
  majority_vote = predictor.get_majority_vote(predictions)
156
+ majority_vote_count = np.sum(np.concatenate(predictions) == 1)
157
  severity = predictor.evaluate_severity(majority_vote_count)
158
 
159
+ # 3) Summarize textual results
160
+ results_by_group = {
 
161
  "Concentration_and_Decision_Making": [],
162
  "Sleep_and_Energy_Levels": [],
163
  "Mood_and_Emotional_State": [],
164
  "Appetite_and_Weight_Changes": [],
165
  "Duration_and_Severity_of_Depression_Symptoms": []
166
  }
167
+ group_map = {
 
168
  "Concentration_and_Decision_Making": ["YOWRCONC", "YOWRDCSN"],
169
  "Sleep_and_Energy_Levels": ["YOWRHRS", "YO_MDEA5", "YOWRELES", "YO_MDEA2"],
170
  "Mood_and_Emotional_State": ["YOWRCHR", "YOWRLSIN", "YOWRDEPR", "YODPDISC",
 
175
  "YOPB2WK"]
176
  }
177
 
178
+ # Convert each model's 0/1 to text
179
+ grouped_output_lines = []
180
+ for i, pred_array in enumerate(predictions):
181
+ col_name = model_filenames[i].split(".")[0] # e.g., "YOWRCONC"
182
+ val = pred_array[0]
183
+ if col_name in predictor.prediction_map and val in [0, 1]:
184
+ text = predictor.prediction_map[col_name][val]
185
+ out_line = f"{col_name}: {text}"
186
  else:
187
+ out_line = f"{col_name}: Prediction={val}"
188
+
189
+ # Find group
190
+ placed = False
191
+ for g_key, g_cols in group_map.items():
192
+ if col_name in g_cols:
193
+ results_by_group[g_key].append(out_line)
194
+ placed = True
 
195
  break
196
+ if not placed:
197
+ # If it didn't fall into any known group, skip or handle
198
  pass
199
 
200
+ # Format into a single string
201
+ for group_label, pred_lines in results_by_group.items():
202
+ if pred_lines:
203
+ grouped_output_lines.append(f"Group {group_label}:")
204
+ grouped_output_lines.append("\n".join(pred_lines))
205
+ grouped_output_lines.append("")
206
+
207
+ if len(grouped_output_lines) == 0:
208
+ final_result_text = "No predictions made. Check inputs."
209
+ else:
210
+ final_result_text = "\n".join(grouped_output_lines).strip()
211
+
212
+ # 4) Additional Features
213
+ # A) Total patient count
214
  total_patients = len(df)
215
+ total_count_md = (
216
  "### Total Patient Count\n"
217
+ f"**{total_patients}** total patients in the dataset."
 
218
  )
219
 
220
+ # B) Bar chart of how many have same inputs
221
  input_counts = {}
222
+ for c in user_input_data.keys():
223
+ v = user_input_data[c][0]
224
+ input_counts[c] = len(df[df[c] == v])
225
+ df_input_counts = pd.DataFrame({"Feature": list(input_counts.keys()), "Count": list(input_counts.values())})
226
+ fig_input_bar = px.bar(
227
+ df_input_counts,
228
+ x="Feature",
229
+ y="Count",
230
+ title="Number of Patients with the Same Value for Each Input Feature"
 
 
 
 
 
231
  )
232
+ fig_input_bar.update_layout(xaxis={"categoryorder": "total descending"})
233
 
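As an aside, len(df[df[c] == v]) materializes a filtered copy just to measure it; a boolean sum gives the same count more cheaply. A small sketch (the column and value are illustrative only):

    # Equivalent row count without building the filtered DataFrame.
    same_count = int((df["YMDEYR"] == 1).sum())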
234
+ # C) Bar chart for predicted labels
235
  label_counts = {}
236
+ for i, pred_array in enumerate(predictions):
237
+ col_name = model_filenames[i].split(".")[0]
238
+ val = pred_array[0]
239
+ if val in [0,1]:
240
+ label_counts[col_name] = len(df[df[col_name] == val])
241
+
242
  if len(label_counts) > 0:
243
+ df_label_counts = pd.DataFrame({
244
+ "Label Column": list(label_counts.keys()),
245
  "Count": list(label_counts.values())
246
  })
247
+ fig_label_bar = px.bar(
248
+ df_label_counts,
249
+ x="Label Column",
250
  y="Count",
251
+ title="Number of Patients with the Same Predicted Label"
 
252
  )
253
  else:
254
+ fig_label_bar = px.bar(title="No valid predicted labels to display")
255
+
256
+ # D) Simple Distribution Plot (demo for first 3 labels & 4 inputs)
257
+ # (Unchanged from prior approach; you can remove if you prefer.)
258
+ sample_feats = list(user_input_data.keys())[:4]
259
+ sample_labels = [fn.split(".")[0] for fn in model_filenames[:3]]
260
+ dist_segments = []
261
+ for feat in sample_feats:
 
 
 
 
 
 
 
262
  if feat not in df.columns:
263
  continue
264
+ for lbl in sample_labels:
265
  if lbl not in df.columns:
266
  continue
267
+ temp_g = df.groupby([feat,lbl]).size().reset_index(name="count")
268
+ temp_g["feature"] = feat
269
+ temp_g["label"] = lbl
270
+ dist_segments.append(temp_g)
271
+ if len(dist_segments) > 0:
272
+ big_dist_df = pd.concat(dist_segments, ignore_index=True)
273
  fig_dist = px.bar(
274
  big_dist_df,
275
+ x=big_dist_df.columns[0],
276
  y="count",
277
+ color=big_dist_df.columns[1],
278
  facet_row="feature",
279
  facet_col="label",
280
+ title="Sample Distribution Plot (first 4 features vs first 3 labels)"
 
 
 
 
281
  )
282
+ fig_dist.update_layout(height=700)
283
  else:
284
+ fig_dist = px.bar(title="No distribution plot generated (columns not found).")
285
+
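The faceted chart above is driven by a "long" count table with one row per (feature value, label value) pair. A toy sketch of the same data shape, using invented column names:

    import pandas as pd
    import plotly.express as px

    toy = pd.DataFrame({"feat_a": [0, 0, 1, 1], "label_x": [0, 1, 1, 1]})
    counts = toy.groupby(["feat_a", "label_x"]).size().reset_index(name="count")
    counts["feature"] = "feat_a"   # facet_row key, as in the app code
    counts["label"] = "label_x"    # facet_col key
    fig = px.bar(counts, x="feat_a", y="count", color="label_x",
                 facet_row="feature", facet_col="label")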
286
+ # E) Nearest Neighbors with K=2
287
+ # We keep K=2, but for *all* label columns, we show their actual 0/1 or mapped text
288
+ # (same approach as before).
289
+ # ... [omitted here for brevity, or replicate your existing code for K=2 nearest neighbors] ...
290
+ # We'll do a short version to keep focus on co-occ:
291
+ # ---------------------------------------------------------------------
292
+ # Build Hamming distance across user_input columns
293
+ columns_for_distance = list(user_input.columns)
294
+ sub_df = df[columns_for_distance].copy()
295
+ user_row = user_input.iloc[0]
296
  distances = []
297
+ for idx, row_ in sub_df.iterrows():
298
+ dist_ = sum(row_[col] != user_row[col] for col in columns_for_distance)
299
+ distances.append(dist_)
300
+ df_dist = df.copy()
301
+ df_dist["distance"] = distances
302
+ # Sort ascending, pick K=2
303
+ K = 2
304
+ nearest_neighbors = df_dist.sort_values("distance", ascending=True).head(K)
305
+
306
+ # Summarize in Markdown
307
+ nn_md = ["### Nearest Neighbors (K=2)"]
308
+ nn_md.append("(In a real application, you'd refine which features matter, how to encode them, etc.)\n")
309
+ for irow in nearest_neighbors.itertuples():
310
+ nn_md.append(f"- **Neighbor ID {irow.Index}**: distance={irow.distance}")
311
+ nn_md_str = "\n".join(nn_md)
312
+
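The iterrows() loop above computes a Hamming distance one row at a time; pandas can do the same comparison in a single vectorized step. A sketch, assuming df, user_input and columns_for_distance as defined above:

    # Column-aligned comparison of every patient row against the user's row.
    mismatches = df[columns_for_distance].ne(user_input.iloc[0][columns_for_distance])
    distances = mismatches.sum(axis=1)                      # Hamming distance per row
    nearest = df.assign(distance=distances).nsmallest(2, "distance")  # K = 2, as below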
313
+ # F) Co-occurrence Plot for user-chosen feature1, feature2, label
314
+ # If the user picks "None" or doesn't pick valid columns, skip or fallback.
315
+ if (co_occ_feature1 is not None and co_occ_feature1 != "None" and
316
+ co_occ_feature2 is not None and co_occ_feature2 != "None" and
317
+ co_occ_label is not None and co_occ_label != "None"):
318
+ # Check if these columns are in df
319
+ if (co_occ_feature1 in df.columns and
320
+ co_occ_feature2 in df.columns and
321
+ co_occ_label in df.columns):
322
+ # Group by [co_occ_feature1, co_occ_feature2, co_occ_label]
323
+ co_data = df.groupby([co_occ_feature1, co_occ_feature2, co_occ_label]).size().reset_index(name="count")
324
+ fig_co_occ = px.bar(
325
+ co_data,
326
+ x=co_occ_feature1,
327
+ y="count",
328
+ color=co_occ_label,
329
+ facet_col=co_occ_feature2,
330
+ title=f"Co-occurrence: {co_occ_feature1} & {co_occ_feature2} vs {co_occ_label}"
331
+ )
332
+ else:
333
+ fig_co_occ = px.bar(title="One or more selected columns not found in dataframe.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
  else:
335
+ fig_co_occ = px.bar(title="No co-occurrence plot (choose two features + one label).")
336
+
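Stripped of the Gradio plumbing, the user-driven co-occurrence chart is a guarded three-column group-by. A minimal sketch with hypothetical dropdown picks:

    # Hypothetical selections; in the app these arrive via the three new dropdowns.
    f1, f2, lbl = "YMDEYR", "YMDERSUD5ANY", "YOWRCONC"
    if all(c in df.columns for c in (f1, f2, lbl)):
        co = df.groupby([f1, f2, lbl]).size().reset_index(name="count")
        fig = px.bar(co, x=f1, y="count", color=lbl, facet_col=f2)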
337
+ # Return all 8 outputs
338
  return (
339
+ final_result_text, # (1) Predictions
340
+ severity, # (2) Severity
341
+ total_count_md, # (3) Total patient count
342
+ fig_dist, # (4) Distribution Plot
343
+ nn_md_str, # (5) Nearest Neighbors
344
+ fig_co_occ, # (6) Co-occurrence
345
+ fig_input_bar, # (7) Bar Chart (input features)
346
+ fig_label_bar # (8) Bar Chart (labels)
347
  )
348
 
349
  ######################################
350
+ # 5) MAPPING (user -> int)
351
  ######################################
352
  input_mapping = {
353
  'YNURSMDE': {"Yes": 1, "No": 0},
354
  'YMDEYR': {"Yes": 1, "No": 2},
355
  'YSOCMDE': {"Yes": 1, "No": 0},
356
+ 'YMDESUD5ANYO': {"SUD only": 1, "MDE only": 2, "SUD & MDE": 3, "Neither": 4},
357
  'YMSUD5YANY': {"Yes": 1, "No": 0},
358
+ 'YUSUITHK': {"Yes": 1, "No": 2, "Unsure": 3, "Don't want to answer": 4},
359
  'YMDETXRX': {"Yes": 1, "No": 0},
360
+ 'YUSUITHKYR': {"Yes": 1, "No": 2, "Unsure": 3, "Don't want to answer": 4},
361
  'YMDERSUD5ANY': {"Yes": 1, "No": 0},
362
+ 'YUSUIPLNYR': {"Yes": 1, "No": 2, "Unsure": 3, "Don't want to answer": 4},
363
  'YCOUNMDE': {"Yes": 1, "No": 0},
364
  'YPSY1MDE': {"Yes": 1, "No": 0},
365
  'YHLTMDE': {"Yes": 1, "No": 0},
366
  'YDOCMDE': {"Yes": 1, "No": 0},
367
  'YPSY2MDE': {"Yes": 1, "No": 0},
368
  'YMDEHARX': {"Yes": 1, "No": 0},
369
+ 'LVLDIFMEM2': {"No Difficulty": 1, "Some Difficulty": 2, "A lot or cannot do": 3},
370
  'MDEIMPY': {"Yes": 1, "No": 2},
371
  'YMDEHPO': {"Yes": 1, "No": 0},
372
  'YMIMS5YANY': {"Yes": 1, "No": 0},
 
374
  'YMIUD5YANY': {"Yes": 1, "No": 0},
375
  'YMDEHPRX': {"Yes": 1, "No": 0},
376
  'YMIMI5YANY': {"Yes": 1, "No": 0},
377
+ 'YUSUIPLN': {"Yes": 1, "No": 2, "Unsure": 3, "Don't want to answer": 4},
378
  'YTXMDEYR': {"Yes": 1, "No": 0},
379
  'YMDEAUD5YR': {"Yes": 1, "No": 0},
380
  'YRXMDEYR': {"Yes": 1, "No": 0},
 
382
  }
383
 
384
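Each dropdown hands back the human-readable key, and the wrapper below looks it up in input_mapping to get the numeric NSDUH code. A small sketch of that lookup with made-up answers:

    # Example: turning two dropdown answers into model-ready integers.
    answers = {"YMDEYR": "Yes", "YUSUITHK": "Unsure"}
    numeric = {col: input_mapping[col][choice] for col, choice in answers.items()}
    # -> {"YMDEYR": 1, "YUSUITHK": 3}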
  ######################################
385
+ # 6) THE GRADIO INTERFACE
386
  ######################################
 
 
 
 
 
 
 
 
 
 
387
  import gradio as gr
388
 
389
+ # (A) The original required inputs
390
+ original_inputs = [
391
+ gr.Dropdown(list(input_mapping['YMDEYR'].keys()), label="YMDEYR: Past Year MDE?"),
392
+ gr.Dropdown(list(input_mapping['YMDERSUD5ANY'].keys()), label="YMDERSUD5ANY: MDE or SUD - ANY?"),
393
+ gr.Dropdown(list(input_mapping['YMDEIMAD5YR'].keys()), label="YMDEIMAD5YR: MDE + ALCOHOL?"),
394
+ gr.Dropdown(list(input_mapping['YMIMS5YANY'].keys()), label="YMIMS5YANY: MDE + SUBSTANCE?"),
395
+ gr.Dropdown(list(input_mapping['YMDELT'].keys()), label="YMDELT: MDE in Lifetime?"),
396
+ gr.Dropdown(list(input_mapping['YMDEHARX'].keys()), label="YMDEHARX: Saw Health Prof + Meds?"),
397
+ gr.Dropdown(list(input_mapping['YMDEHPRX'].keys()), label="YMDEHPRX: Saw Health Prof or Meds?"),
398
+ gr.Dropdown(list(input_mapping['YMDETXRX'].keys()), label="YMDETXRX: Received Treatment?"),
399
+ gr.Dropdown(list(input_mapping['YMDEHPO'].keys()), label="YMDEHPO: Saw Health Prof Only?"),
400
+ gr.Dropdown(list(input_mapping['YMDEAUD5YR'].keys()), label="YMDEAUD5YR: MDE + Alcohol Use?"),
401
+ gr.Dropdown(list(input_mapping['YMIMI5YANY'].keys()), label="YMIMI5YANY: MDE + ILL Drug Use?"),
402
+ gr.Dropdown(list(input_mapping['YMIUD5YANY'].keys()), label="YMIUD5YANY: MDE + ILL Drug Use?"),
403
+ gr.Dropdown(list(input_mapping['YMDESUD5ANYO'].keys()), label="YMDESUD5ANYO: MDE vs SUD vs BOTH vs NEITHER"),
404
 
405
  # Consultations
406
+ gr.Dropdown(list(input_mapping['YNURSMDE'].keys()), label="YNURSMDE: Nurse/OT about MDE?"),
407
+ gr.Dropdown(list(input_mapping['YSOCMDE'].keys()), label="YSOCMDE: Social Worker?"),
408
+ gr.Dropdown(list(input_mapping['YCOUNMDE'].keys()), label="YCOUNMDE: Counselor?"),
409
+ gr.Dropdown(list(input_mapping['YPSY1MDE'].keys()), label="YPSY1MDE: Psychologist?"),
410
+ gr.Dropdown(list(input_mapping['YPSY2MDE'].keys()), label="YPSY2MDE: Psychiatrist?"),
411
+ gr.Dropdown(list(input_mapping['YHLTMDE'].keys()), label="YHLTMDE: Health Prof?"),
412
+ gr.Dropdown(list(input_mapping['YDOCMDE'].keys()), label="YDOCMDE: GP/Family MD?"),
413
+ gr.Dropdown(list(input_mapping['YTXMDEYR'].keys()), label="YTXMDEYR: Doctor/Health Prof?"),
414
+
415
+ # Suicidal
416
+ gr.Dropdown(list(input_mapping['YUSUITHKYR'].keys()), label="YUSUITHKYR: Serious Suicide Thoughts?"),
417
+ gr.Dropdown(list(input_mapping['YUSUIPLNYR'].keys()), label="YUSUIPLNYR: Made Plans?"),
418
+ gr.Dropdown(list(input_mapping['YUSUITHK'].keys()), label="YUSUITHK: Suicide Thoughts (12 mo)?"),
419
+ gr.Dropdown(list(input_mapping['YUSUIPLN'].keys()), label="YUSUIPLN: Made Plans (12 mo)?"),
420
 
421
  # Impairments
422
+ gr.Dropdown(list(input_mapping['MDEIMPY'].keys()), label="MDEIMPY: Severe Role Impairment?"),
423
+ gr.Dropdown(list(input_mapping['LVLDIFMEM2'].keys()), label="LVLDIFMEM2: Difficulty Remembering/Concentrating?"),
424
+ gr.Dropdown(list(input_mapping['YMSUD5YANY'].keys()), label="YMSUD5YANY: MDE + Substance?"),
425
+ gr.Dropdown(list(input_mapping['YRXMDEYR'].keys()), label="YRXMDEYR: Used Meds for MDE (12 mo)?"),
426
  ]
427
 
428
+ # (B) The new co-occurrence inputs
429
+ # We'll give them defaults of "None" to indicate no selection.
430
+ all_cols = ["None"] + df.columns.tolist() # 'None' plus the actual columns from your df
431
+ co_occ_feature1 = gr.Dropdown(all_cols, label="Co-Occ Feature 1", value="None")
432
+ co_occ_feature2 = gr.Dropdown(all_cols, label="Co-Occ Feature 2", value="None")
433
+ all_label_cols = ["None"] + list(predictor.prediction_map.keys()) # e.g., "YOWRCONC", "YOWRHRS", ...
434
+ co_occ_label = gr.Dropdown(all_label_cols, label="Co-Occ Label", value="None")
435
+
436
+ # Combine them into a single input list
437
+ inputs = original_inputs + [co_occ_feature1, co_occ_feature2, co_occ_label]
438
+
439
+ # 8 outputs as before
440
  outputs = [
441
+ gr.Textbox(label="Prediction Results", lines=15),
442
+ gr.Textbox(label="Mental Health Severity", lines=2),
443
  gr.Markdown(label="Total Patient Count"),
444
+ gr.Plot(label="Distribution Plot (Sample)"),
445
+ gr.Markdown(label="Nearest Neighbors (K=2)"),
446
+ gr.Plot(label="Co-occurrence Plot"),
447
+ gr.Plot(label="Same Value Bar (Inputs)"),
448
+ gr.Plot(label="Predicted Label Bar")
449
  ]
450
 
451
  ######################################
452
+ # 7) WRAPPER
453
  ######################################
454
  def predict_with_text(
455
+ # match the function signature exactly (29 required + 3 for co-occ)
456
  YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
457
  YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
458
  YMDESUD5ANYO, YNURSMDE, YSOCMDE, YCOUNMDE, YPSY1MDE, YPSY2MDE,
459
  YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
460
+ YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR,
461
+ co_occ_feature1, co_occ_feature2, co_occ_label
462
  ):
463
+ # Validate the original 29 fields
464
+ valid = validate_inputs(
465
  YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
466
  YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
467
  YMDESUD5ANYO, YNURSMDE, YSOCMDE, YCOUNMDE, YPSY1MDE, YPSY2MDE,
468
  YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
469
  YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
470
+ )
471
+ if not valid:
472
  return (
473
  "Please select all required fields.",
474
  "Validation Error",
 
480
  None
481
  )
482
 
483
+ # Map to numeric
484
  user_inputs = {
485
  'YNURSMDE': input_mapping['YNURSMDE'][YNURSMDE],
486
  'YMDEYR': input_mapping['YMDEYR'][YMDEYR],
 
513
  'YMDELT': input_mapping['YMDELT'][YMDELT]
514
  }
515
 
516
+ # Call the core predict function with the co-occ choices as well
517
+ return predict(
518
+ **user_inputs,
519
+ co_occ_feature1=co_occ_feature1,
520
+ co_occ_feature2=co_occ_feature2,
521
+ co_occ_label=co_occ_label
522
+ )
523
+
524
 
 
525
  custom_css = """
526
+ .gradio-container * {
527
+ color: #1B1212 !important;
528
+ }
529
  """
530
 
 
 
 
531
  interface = gr.Interface(
532
+ fn=predict_with_text,
533
  inputs=inputs,
534
+ outputs=outputs,
535
+ title="Mental Health Screening (NSDUH) with Selective Co-Occurrence",
536
+ css=custom_css,
537
+ description="""
538
+ **Instructions**:
539
+ 1. Fill out all required fields regarding MDE/Substance Use/Consultations/Suicidal/Impairments.
540
+ 2. (Optional) Choose 2 features and 1 label for the *Co-occurrence* plot.
541
+ - If you do not select them (or leave them as "None"), that plot will be skipped.
542
+ 3. Click "Submit" to get predictions, severity, distribution plots, nearest neighbors, and your custom co-occurrence chart.
543
+ """
544
  )
545
 
546
  if __name__ == "__main__":
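For completeness: a gr.Interface built this way is normally served by calling launch() inside the main guard; the call below is an assumption for illustration, not part of this commit.

    interface.launch()  # assumed launch call; options such as share=True can be added as needed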