pantdipendra committed on
Commit
e9e83fc
·
verified ·
1 Parent(s): c458985

v2_separate tabs categories in UI

Browse files
Files changed (1) hide show
  1. app.py +236 -308
app.py CHANGED
@@ -27,53 +27,27 @@ class ModelPredictor:
27
  self.model_filenames = model_filenames
28
  self.models = self.load_models()
29
 
30
- # The map from each label column to the textual meaning for 0 or 1
31
- # (Some columns also mention '2' as positive, so adapt as needed).
32
  self.prediction_map = {
33
  "YOWRCONC": ["Did not have difficulty concentrating", "Had difficulty concentrating"],
34
  "YOSEEDOC": ["Did not feel the need to see a doctor", "Felt the need to see a doctor"],
35
  "YOWRHRS": ["Did not have trouble sleeping", "Had trouble sleeping"],
36
  "YO_MDEA5": ["Others did not notice restlessness/lethargy", "Others noticed restlessness/lethargy"],
37
  "YOWRCHR": ["Did not feel so sad nothing could cheer up", "Felt so sad that nothing could cheer up"],
38
- "YOWRLSIN": [
39
- "Did not feel bored / lose interest",
40
- "Felt bored / lost interest in enjoyable things"
41
- ],
42
  "YODPPROB": ["No other problems for 2+ weeks", "Had other problems for 2+ weeks"],
43
  "YOWRPROB": ["No 'worst time ever' feeling", "Had 'worst time ever' feeling"],
44
- "YODPR2WK": [
45
- "No periods with depressed feelings lasting 2+ weeks",
46
- "Had depressed feelings for 2+ weeks"
47
- ],
48
  "YOWRDEPR": ["Did not feel sad/depressed mostly everyday", "Felt sad/depressed mostly everyday"],
49
- "YODPDISC": [
50
- "Overall mood duration was not sad/depressed",
51
- "Overall mood duration was sad/depressed"
52
- ],
53
- "YOLOSEV": [
54
- "Did not lose interest in activities",
55
- "Lost interest in enjoyable things"
56
- ],
57
  "YOWRDCSN": ["Was able to make decisions", "Was unable to make decisions"],
58
- "YODSMMDE": [
59
- "Never had 2 weeks of depression symptoms",
60
- "Had 2+ weeks of depression symptoms"
61
- ],
62
- "YO_MDEA3": [
63
- "No changes in appetite/weight",
64
- "Had changes in appetite or weight"
65
- ],
66
- "YODPLSIN": ["Never lost interest / felt bored", "Lost interest / felt bored"],
67
  "YOWRELES": ["Did not eat less than usual", "Ate less than usual"],
68
  "YODSCEV": ["Fewer severe depression symptoms", "More severe depression symptoms"],
69
- "YOPB2WK": [
70
- "No uneasy feelings lasting every day for 2+ weeks",
71
- "Uneasy feelings lasting 2+ weeks"
72
- ],
73
- "YO_MDEA2": [
74
- "No physical/mental issues for 2+ weeks",
75
- "Had physical/mental issues for 2+ weeks"
76
- ]
77
  }
78
 
79
  def load_models(self):
@@ -85,10 +59,6 @@ class ModelPredictor:
85
  return loaded
86
 
87
  def make_predictions(self, user_input: pd.DataFrame):
88
- """
89
- Return a list of np.ndarrays, each of shape (1,) or (n_samples,),
90
- one for each model in self.models, in the same order as model_filenames.
91
- """
92
  predictions = []
93
  for model in self.models:
94
  out = model.predict(user_input)
@@ -96,19 +66,10 @@ class ModelPredictor:
96
  return predictions
97
 
98
  def get_majority_vote(self, predictions):
99
- """
100
- Flatten all predictions from each model into a single array
101
- and compute the most common value (mode).
102
- """
103
  combined = np.concatenate(predictions)
104
  return np.bincount(combined).argmax()
105
 
106
  def evaluate_severity(self, count_ones: int) -> str:
107
- """
108
- The user wanted a logic: if >=13 => Severe, >=9 => Moderate, >=5 => Low, else Very Low.
109
- Here 'count_ones' is how many '1's (or '2's) across all model predictions.
110
- Adjust logic if needed.
111
- """
112
  if count_ones >= 13:
113
  return "Mental Health Severity: Severe"
114
  elif count_ones >= 9:
@@ -123,174 +84,175 @@ predictor = ModelPredictor(model_path, model_filenames)
123
 
124
 
125
  ######################################
126
- # 3) HELPER: NEAREST NEIGHBORS
127
  ######################################
128
- def get_nearest_neighbors_info(user_input_df: pd.DataFrame, k=5):
129
- """
130
- Given a single-row user_input_df (the 25 numeric features),
131
- find the top-k nearest neighbors in df (using those same 25 columns).
132
- Then build a textual summary for clinicians.
133
-
134
- We assume df has the same numeric coding for these 25 features.
135
- """
136
- # 1) Ensure these columns exist in df
137
- user_cols = user_input_df.columns
138
- if not all(col in df.columns for col in user_cols):
139
- return "Cannot compute nearest neighbors. Some columns not found in df."
 
 
 
 
140
 
141
- # 2) We'll do a simple Euclidean distance
142
- # Subset df to these 25 columns
143
- sub_df = df[list(user_cols)].copy()
 
 
 
 
 
 
 
144
 
145
- # 3) Compute distance to the user input row
146
- # user_input_df has shape (1, 25). We'll broadcast to sub_df's shape
147
- # row by row. For performance, you might prefer scikit's NearestNeighbors,
148
- # but let's do a manual approach for clarity.
149
- diffs = sub_df - user_input_df.iloc[0] # shape (N,25)
150
- dists = (diffs**2).sum(axis=1)**0.5 # Euclidean
 
 
 
 
151
 
152
- # 4) Sort by distance, pick top k
153
- nn_indices = dists.nsmallest(k).index
154
- neighbors = df.loc[nn_indices]
155
 
156
- # 5) Build a textual summary
157
- # We will look at each label in predictor.prediction_map,
158
- # see if it is a column in df. If so, see how many are 1 vs 0 (or 2) among neighbors.
159
- # Then map numeric -> text from prediction_map if possible.
160
- summary_lines = [f"**Nearest Neighbors (k={k})**",
161
- f"Distances Range: {dists[nn_indices].min():.2f} to {dists[nn_indices].max():.2f}",
162
- ""]
163
- for label_col, label_map in predictor.prediction_map.items():
164
- if label_col not in neighbors.columns:
165
- continue # Not present in df
166
- # Values among neighbors
167
- vals = neighbors[label_col].value_counts().to_dict()
168
- # Example: {0: 3, 1: 2}, or {2: 4, 1: 1}, etc.
169
- line = f"{label_col} => "
170
- parts = []
171
- for val, count_ in vals.items():
172
- # If we have a mapping, use it
173
- if val in range(len(label_map)):
174
- meaning = label_map[val]
175
- parts.append(f"{count_} had {meaning}")
176
- else:
177
- parts.append(f"{count_} had numeric={val}")
178
- line += "; ".join(parts)
179
- summary_lines.append(line)
180
- summary_lines.append("")
181
- summary_text = "\n".join(summary_lines)
182
- return summary_text
183
 
184
 
185
- ######################################
186
- # 4) INPUT MAPPING
187
- ######################################
188
  def validate_inputs(*args):
189
  for arg in args:
190
  if not arg: # empty or None
191
  return False
192
  return True
193
 
194
- # Only keep the 25 features requested.
195
- input_mapping = {
196
- 'YMDESUD5ANYO': {"SUD only, no MDE": 1, "MDE only, no SUD": 2, "SUD and MDE": 3, "Neither SUD or MDE": 4},
197
- 'YMDEHPO': {"Yes": 1, "No": 0},
198
- 'YMDETXRX': {"Yes": 1, "No": 0},
199
- 'LVLDIFMEM2': {"No Difficulty": 1, "Some difficulty": 2, "A lot of difficulty or cannot do at all": 3},
200
- 'YMSUD5YANY': {"Yes": 1, "No": 0},
201
- 'YPSY2MDE': {"Yes": 1, "No": 0},
202
- 'YMDELT': {"Yes": 1, "No": 2},
203
- 'YDOCMDE': {"Yes": 1, "No": 0},
204
- 'YMIMI5YANY': {"Yes": 1, "No": 0},
205
- 'YMDEHARX': {"Yes": 1, "No": 0},
206
- 'MDEIMPY': {"Yes": 1, "No": 2},
207
- 'YRXMDEYR': {"Yes": 1, "No": 0},
208
- 'YMDERSUD5ANY': {"Yes": 1, "No": 0},
209
- 'YMIMS5YANY': {"Yes": 1, "No": 0},
210
- 'YMDEYR': {"Yes": 1, "No": 2},
211
- 'YHLTMDE': {"Yes": 1, "No": 0},
212
- 'YUSUIPLNYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
213
- 'YMDEHPRX': {"Yes": 1, "No": 0},
214
- 'YUSUIPLN': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
215
- 'YPSY1MDE': {"Yes": 1, "No": 0},
216
- 'YMIUD5YANY': {"Yes": 1, "No": 0},
217
- 'YUSUITHK': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
218
- 'YTXMDEYR': {"Yes": 1, "No": 0},
219
- 'YCOUNMDE': {"Yes": 1, "No": 0},
220
- 'YUSUITHKYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4}
221
- }
 
 
 
 
 
 
 
 
 
 
222
 
223
 
224
  ######################################
225
- # 5) PREDICT FUNCTION (Prediction Tab)
226
  ######################################
227
  def predict(
228
- # EXACT 25 features in this order:
229
- YMDESUD5ANYO, YMDEHPO, YMDETXRX, LVLDIFMEM2, YMSUD5YANY,
230
- YPSY2MDE, YMDELT, YDOCMDE, YMIMI5YANY, YMDEHARX,
231
- MDEIMPY, YRXMDEYR, YMDERSUD5ANY, YMIMS5YANY, YMDEYR,
232
- YHLTMDE, YUSUIPLNYR, YMDEHPRX, YUSUIPLN, YPSY1MDE,
233
- YMIUD5YANY, YUSUITHK, YTXMDEYR, YCOUNMDE, YUSUITHKYR
 
 
 
 
 
234
  ):
235
- # 1) Validate
236
  if not validate_inputs(
237
- YMDESUD5ANYO, YMDEHPO, YMDETXRX, LVLDIFMEM2, YMSUD5YANY,
238
- YPSY2MDE, YMDELT, YDOCMDE, YMIMI5YANY, YMDEHARX,
239
- MDEIMPY, YRXMDEYR, YMDERSUD5ANY, YMIMS5YANY, YMDEYR,
240
- YHLTMDE, YUSUIPLNYR, YMDEHPRX, YUSUIPLN, YPSY1MDE,
241
- YMIUD5YANY, YUSUITHK, YTXMDEYR, YCOUNMDE, YUSUITHKYR
 
242
  ):
243
  return (
244
  "Please select all required fields.", # 1) Prediction Results
245
- "Validation Error", # 2) Mental Health Severity
246
- "No data", # 3) Total Patient Count
247
- "No nearest neighbors info", # 4) Nearest Neighbors Summary
248
- None, # 5) Bar Chart (Input Feature)
249
- None # 6) Bar Chart (Predicted Labels)
250
  )
251
 
252
- # 2) Map user-friendly -> numeric
253
  user_input_dict = {
254
  'YMDESUD5ANYO': input_mapping['YMDESUD5ANYO'][YMDESUD5ANYO],
255
- 'YMDEHPO': input_mapping['YMDEHPO'][YMDEHPO],
256
- 'YMDETXRX': input_mapping['YMDETXRX'][YMDETXRX],
257
- 'LVLDIFMEM2': input_mapping['LVLDIFMEM2'][LVLDIFMEM2],
258
- 'YMSUD5YANY': input_mapping['YMSUD5YANY'][YMSUD5YANY],
259
- 'YPSY2MDE': input_mapping['YPSY2MDE'][YPSY2MDE],
260
  'YMDELT': input_mapping['YMDELT'][YMDELT],
261
- 'YDOCMDE': input_mapping['YDOCMDE'][YDOCMDE],
 
 
 
 
262
  'YMIMI5YANY': input_mapping['YMIMI5YANY'][YMIMI5YANY],
 
 
 
263
  'YMDEHARX': input_mapping['YMDEHARX'][YMDEHARX],
264
- 'MDEIMPY': input_mapping['MDEIMPY'][MDEIMPY],
265
  'YRXMDEYR': input_mapping['YRXMDEYR'][YRXMDEYR],
266
- 'YMDERSUD5ANY': input_mapping['YMDERSUD5ANY'][YMDERSUD5ANY],
267
- 'YMIMS5YANY': input_mapping['YMIMS5YANY'][YMIMS5YANY],
268
- 'YMDEYR': input_mapping['YMDEYR'][YMDEYR],
269
  'YHLTMDE': input_mapping['YHLTMDE'][YHLTMDE],
270
- 'YUSUIPLNYR': input_mapping['YUSUIPLNYR'][YUSUIPLNYR],
271
- 'YMDEHPRX': input_mapping['YMDEHPRX'][YMDEHPRX],
272
- 'YUSUIPLN': input_mapping['YUSUIPLN'][YUSUIPLN],
273
- 'YPSY1MDE': input_mapping['YPSY1MDE'][YPSY1MDE],
274
- 'YMIUD5YANY': input_mapping['YMIUD5YANY'][YMIUD5YANY],
275
- 'YUSUITHK': input_mapping['YUSUITHK'][YUSUITHK],
276
  'YTXMDEYR': input_mapping['YTXMDEYR'][YTXMDEYR],
 
 
 
277
  'YCOUNMDE': input_mapping['YCOUNMDE'][YCOUNMDE],
278
- 'YUSUITHKYR': input_mapping['YUSUITHKYR'][YUSUITHKYR]
 
 
 
 
 
 
 
279
  }
280
  user_df = pd.DataFrame(user_input_dict, index=[0])
281
 
282
- # 3) Make predictions
283
- predictions = predictor.make_predictions(user_df) # list of arrays
284
- # e.g. predictions[i][0] is the predicted label for model i
285
- # Flatten them for counting ones
286
  all_preds = np.concatenate(predictions)
287
- # In your logic, "1" might be a positive class, or "2" might be. Adapt if needed:
288
- # For now, let's assume "1" is the relevant "positive" count:
289
  count_ones = sum(all_preds == 1)
290
-
291
  severity_msg = predictor.evaluate_severity(count_ones)
292
 
293
- # 4) Format textual results grouped by domain
294
  groups = {
295
  "Concentration_and_Decision_Making": ["YOWRCONC", "YOWRDCSN"],
296
  "Sleep_and_Energy_Levels": ["YOWRHRS", "YO_MDEA5", "YOWRELES", "YO_MDEA2"],
@@ -302,20 +264,20 @@ def predict(
302
  "YODPPROB", "YOWRPROB", "YODPR2WK", "YODSMMDE", "YOPB2WK"
303
  ]
304
  }
305
- # Build text for each label in the order they appear in model_filenames
306
  group_text = {g: [] for g in groups}
 
307
  for i, arr in enumerate(predictions):
308
  label_col = model_filenames[i].split('.')[0] # e.g. "YOWRCONC"
309
  val = arr[0]
310
- # Map prediction to text if possible
311
  if label_col in predictor.prediction_map and val in range(len(predictor.prediction_map[label_col])):
312
  text_label = predictor.prediction_map[label_col][val]
313
  else:
314
  text_label = f"Prediction={val}"
315
 
316
- # Place it in whichever group
317
- for group_name, gcols in groups.items():
318
- if label_col in gcols:
319
  group_text[group_name].append(f"{label_col} => {text_label}")
320
  break
321
 
@@ -325,20 +287,20 @@ def predict(
325
  gtitle = gname.replace("_", " ")
326
  final_str_parts.append(f"**{gtitle}**")
327
  final_str_parts.append("\n".join(lines))
328
- final_str_parts.append("") # blank line
329
  if not final_str_parts:
330
  final_str = "No predictions made or no matching group columns."
331
  else:
332
  final_str = "\n".join(final_str_parts)
333
 
334
- # 5) Overall patient count
335
  total_count = len(df)
336
  total_count_md = f"We have **{total_count}** patients in the dataset."
337
 
338
- # 6) Nearest Neighbors summary
339
  nn_md = get_nearest_neighbors_info(user_df, k=5)
340
 
341
- # 7) Bar chart for input features
342
  input_counts = {}
343
  for col, val_ in user_input_dict.items():
344
  matched = len(df[df[col] == val_])
@@ -351,14 +313,12 @@ def predict(
351
  )
352
  fig_in.update_layout(width=1200, height=400)
353
 
354
- # 8) Bar chart for predicted labels
355
- # For each model’s label_col, see how many in df have the same predicted value
356
  label_counts = {}
357
  for i, arr in enumerate(predictions):
358
  lbl = model_filenames[i].split('.')[0]
359
  pred_val = arr[0]
360
  if lbl in df.columns:
361
- # How many in df have this same value
362
  label_counts[lbl] = len(df[df[lbl] == pred_val])
363
  if label_counts:
364
  bar_lbl_df = pd.DataFrame({
@@ -379,8 +339,8 @@ def predict(
379
  severity_msg, # 2) Mental Health Severity
380
  total_count_md, # 3) Total Patient Count
381
  nn_md, # 4) Nearest Neighbors Summary
382
- fig_in, # 5) Bar Chart for input features
383
- fig_lbl # 6) Bar Chart for predicted labels
384
  )
385
 
386
 
@@ -388,9 +348,6 @@ def predict(
388
  # 6) EXTRA TABS / FUNCTIONS
389
  ######################################
390
  def distribution_plot(feature_col, label_col):
391
- """
392
- Creates a bar chart grouping by [feature_col, label_col], showing counts.
393
- """
394
  if not feature_col or not label_col:
395
  return px.bar(title="Please select both Feature and Label.")
396
  if (feature_col not in df.columns) or (label_col not in df.columns):
@@ -409,9 +366,6 @@ def distribution_plot(feature_col, label_col):
409
 
410
 
411
  def co_occurrence_plot(feature1, feature2, label_col):
412
- """
413
- Similar approach but grouping by [feature1, feature2, label_col].
414
- """
415
  if not feature1 or not feature2 or not label_col:
416
  return px.bar(title="Please select all three fields.")
417
  if feature1 not in df.columns or feature2 not in df.columns or label_col not in df.columns:
@@ -437,127 +391,104 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
437
 
438
  # ======== TAB 1: PREDICTION ========
439
  with gr.Tab("Prediction"):
440
- inputs = [
441
- gr.Dropdown(
442
- list(input_mapping['YMDESUD5ANYO'].keys()),
443
- label="YMDESUD5ANYO: ONLY MDE, ONLY SUD, BOTH, OR NEITHER-ANY"
444
- ),
445
- gr.Dropdown(
446
- list(input_mapping['YMDEHPO'].keys()),
447
- label="YMDEHPO: Saw health prof only for MDE in past years?"
448
- ),
449
- gr.Dropdown(
450
- list(input_mapping['YMDETXRX'].keys()),
451
- label="YMDETXRX: Received treatment/counseling if saw doc/prof for MDE?"
452
- ),
453
- gr.Dropdown(
454
- list(input_mapping['LVLDIFMEM2'].keys()),
455
- label="LVLDIFMEM2: Difficulty remembering / concentrating?"
456
- ),
457
- gr.Dropdown(
458
- list(input_mapping['YMSUD5YANY'].keys()),
459
- label="YMSUD5YANY: Past-year MDE & substance use disorder?"
460
- ),
461
-
462
- gr.Dropdown(
463
- list(input_mapping['YPSY2MDE'].keys()),
464
- label="YPSY2MDE: Saw/talked to psychiatrist about MDE?"
465
- ),
466
- gr.Dropdown(
467
- list(input_mapping['YMDELT'].keys()),
468
- label="YMDELT: Had major depressive episode in lifetime?"
469
- ),
470
- gr.Dropdown(
471
- list(input_mapping['YDOCMDE'].keys()),
472
- label="YDOCMDE: Saw/talked to general practitioner/family MD about MDE?"
473
- ),
474
- gr.Dropdown(
475
- list(input_mapping['YMIMI5YANY'].keys()),
476
- label="YMIMI5YANY: Past-year MDE with severe impairment & illicit drug use?"
477
- ),
478
- gr.Dropdown(
479
- list(input_mapping['YMDEHARX'].keys()),
480
- label="YMDEHARX: Saw health professional & received medication for MDE?"
481
- ),
482
-
483
- gr.Dropdown(
484
- list(input_mapping['MDEIMPY'].keys()),
485
- label="MDEIMPY: MDE with severe role impairment?"
486
- ),
487
- gr.Dropdown(
488
- list(input_mapping['YRXMDEYR'].keys()),
489
- label="YRXMDEYR: Used received medication for MDE in past years?"
490
- ),
491
- gr.Dropdown(
492
- list(input_mapping['YMDERSUD5ANY'].keys()),
493
- label="YMDERSUD5ANY: MDE or substance use disorder - past year?"
494
- ),
495
- gr.Dropdown(
496
- list(input_mapping['YMIMS5YANY'].keys()),
497
- label="YMIMS5YANY: Past-year MDE + severe impairment + substance use?"
498
- ),
499
- gr.Dropdown(
500
- list(input_mapping['YMDEYR'].keys()),
501
- label="YMDEYR: Past-year major depressive episode?"
502
- ),
503
-
504
- gr.Dropdown(
505
- list(input_mapping['YHLTMDE'].keys()),
506
- label="YHLTMDE: Saw/talk to health professional about MDE in past year?"
507
- ),
508
- gr.Dropdown(
509
- list(input_mapping['YUSUIPLNYR'].keys()),
510
- label="YUSUIPLNYR: Made plans to kill self in past year?"
511
- ),
512
- gr.Dropdown(
513
- list(input_mapping['YMDEHPRX'].keys()),
514
- label="YMDEHPRX: Saw health prof or received med for MDE in past year?"
515
- ),
516
- gr.Dropdown(
517
- list(input_mapping['YUSUIPLN'].keys()),
518
- label="YUSUIPLN: Make plans to kill yourself in past 12 months?"
519
- ),
520
- gr.Dropdown(
521
- list(input_mapping['YPSY1MDE'].keys()),
522
- label="YPSY1MDE: Saw/talked to psychologist about MDE in past year?"
523
- ),
524
-
525
- gr.Dropdown(
526
- list(input_mapping['YMIUD5YANY'].keys()),
527
- label="YMIUD5YANY: Past-year MDE & illicit drug use disorder?"
528
- ),
529
- gr.Dropdown(
530
- list(input_mapping['YUSUITHK'].keys()),
531
- label="YUSUITHK: Youth seriously think about killing self in past 12 months?"
532
- ),
533
- gr.Dropdown(
534
- list(input_mapping['YTXMDEYR'].keys()),
535
- label="YTXMDEYR: Saw or talk to doc/health prof for MDE in past year?"
536
- ),
537
- gr.Dropdown(
538
- list(input_mapping['YCOUNMDE'].keys()),
539
- label="YCOUNMDE: Saw/talk to counselor about MDE in past year?"
540
- ),
541
- gr.Dropdown(
542
- list(input_mapping['YUSUITHKYR'].keys()),
543
- label="YUSUITHKYR: Seriously thought about killing self?"
544
- )
545
  ]
546
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
547
  predict_btn = gr.Button("Predict")
548
 
549
- # 6 outputs now
550
  out_pred_res = gr.Textbox(label="Prediction Results", lines=8)
551
  out_sev = gr.Textbox(label="Mental Health Severity", lines=2)
552
  out_count = gr.Markdown(label="Total Patient Count")
553
- out_nn = gr.Markdown(label="Nearest Neighbors Summary")
554
  out_bar_input= gr.Plot(label="Input Feature Counts")
555
  out_bar_label= gr.Plot(label="Predicted Label Counts")
556
 
557
- # Wire up the button
558
  predict_btn.click(
559
  fn=predict,
560
- inputs=inputs,
561
  outputs=[
562
  out_pred_res, # 1
563
  out_sev, # 2
@@ -571,10 +502,7 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
571
  # ======== TAB 2: Distribution Analysis ========
572
  with gr.Tab("Distribution Analysis"):
573
  gr.Markdown("## Distribution Plot\nSelect one feature and one label column to see bar counts.")
574
- # 1) We gather the 'input features' from input_mapping keys:
575
  list_of_features = sorted(input_mapping.keys())
576
-
577
- # 2) We gather the 'label columns' from predictor.prediction_map keys:
578
  list_of_labels = sorted(predictor.prediction_map.keys())
579
 
580
  feat_dd = gr.Dropdown(choices=list_of_features, label="Feature Column")
@@ -606,5 +534,5 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
606
  outputs=co_occ_output
607
  )
608
 
609
- # Finally, launch the Gradio interface
610
  demo.launch()
 
27
  self.model_filenames = model_filenames
28
  self.models = self.load_models()
29
 
 
 
30
  self.prediction_map = {
31
  "YOWRCONC": ["Did not have difficulty concentrating", "Had difficulty concentrating"],
32
  "YOSEEDOC": ["Did not feel the need to see a doctor", "Felt the need to see a doctor"],
33
  "YOWRHRS": ["Did not have trouble sleeping", "Had trouble sleeping"],
34
  "YO_MDEA5": ["Others did not notice restlessness/lethargy", "Others noticed restlessness/lethargy"],
35
  "YOWRCHR": ["Did not feel so sad nothing could cheer up", "Felt so sad that nothing could cheer up"],
36
+ "YOWRLSIN": ["Did not feel bored / lose interest", "Felt bored / lost interest"],
 
 
 
37
  "YODPPROB": ["No other problems for 2+ weeks", "Had other problems for 2+ weeks"],
38
  "YOWRPROB": ["No 'worst time ever' feeling", "Had 'worst time ever' feeling"],
39
+ "YODPR2WK": ["No periods with depressed feelings lasting 2+ weeks", "Had depressed feelings 2+ weeks"],
 
 
 
40
  "YOWRDEPR": ["Did not feel sad/depressed mostly everyday", "Felt sad/depressed mostly everyday"],
41
+ "YODPDISC": ["Overall mood not sad/depressed", "Overall mood was sad/depressed"],
42
+ "YOLOSEV": ["Did not lose interest", "Lost interest in enjoyable things"],
 
 
 
 
 
 
43
  "YOWRDCSN": ["Was able to make decisions", "Was unable to make decisions"],
44
+ "YODSMMDE": ["Never had 2 weeks depression symptoms", "Had 2+ weeks of depression symptoms"],
45
+ "YO_MDEA3": ["No changes in appetite/weight", "Had changes in appetite/weight"],
46
+ "YODPLSIN": ["Never lost interest / felt bored", "Lost interest/felt bored"],
 
 
 
 
 
 
47
  "YOWRELES": ["Did not eat less than usual", "Ate less than usual"],
48
  "YODSCEV": ["Fewer severe depression symptoms", "More severe depression symptoms"],
49
+ "YOPB2WK": ["No uneasy feelings lasting 2+ weeks", "Uneasy feelings lasting 2+ weeks"],
50
+ "YO_MDEA2": ["No physical/mental issues (2+ weeks)", "Had physical/mental issues (2+ weeks)"]
 
 
 
 
 
 
51
  }
52
 
53
  def load_models(self):
 
59
  return loaded
60
 
61
  def make_predictions(self, user_input: pd.DataFrame):
 
 
 
 
62
  predictions = []
63
  for model in self.models:
64
  out = model.predict(user_input)
 
66
  return predictions
67
 
68
  def get_majority_vote(self, predictions):
 
 
 
 
69
  combined = np.concatenate(predictions)
70
  return np.bincount(combined).argmax()
71
 
72
  def evaluate_severity(self, count_ones: int) -> str:
 
 
 
 
 
73
  if count_ones >= 13:
74
  return "Mental Health Severity: Severe"
75
  elif count_ones >= 9:
 
84
 
85
 
86
  ######################################
87
+ # 3) FEATURE CATEGORIES + MAPPING
88
  ######################################
89
+ categories_dict = {
90
+ "1. Depression & Substance Use Diagnosis": [
91
+ "YMDESUD5ANYO", "YMDELT", "YMDEYR", "YMDERSUD5ANY",
92
+ "YMSUD5YANY", "YMIUD5YANY", "YMIMS5YANY", "YMIMI5YANY"
93
+ ],
94
+ "2. Mental Health Treatment & Prof Consultation": [
95
+ "YMDEHPO", "YMDETXRX", "YMDEHARX", "YRXMDEYR", "YHLTMDE",
96
+ "YTXMDEYR", "YDOCMDE", "YPSY2MDE", "YPSY1MDE", "YCOUNMDE"
97
+ ],
98
+ "3. Functional & Cognitive Impairment": [
99
+ "MDEIMPY", "LVLDIFMEM2"
100
+ ],
101
+ "4. Suicidal Thoughts & Behaviors": [
102
+ "YUSUITHK", "YUSUITHKYR", "YUSUIPLNYR", "YUSUIPLN"
103
+ ]
104
+ }
105
 
106
+ # The numeric mappings for each of the 24 input features (YMDEHPRX from the earlier 25-feature mapping is not included here)
107
+ input_mapping = {
108
+ 'YMDESUD5ANYO': {"SUD only, no MDE": 1, "MDE only, no SUD": 2, "SUD and MDE": 3, "Neither SUD or MDE": 4},
109
+ 'YMDELT': {"Yes": 1, "No": 2},
110
+ 'YMDEYR': {"Yes": 1, "No": 2},
111
+ 'YMDERSUD5ANY': {"Yes": 1, "No": 0},
112
+ 'YMSUD5YANY': {"Yes": 1, "No": 0},
113
+ 'YMIUD5YANY': {"Yes": 1, "No": 0},
114
+ 'YMIMS5YANY': {"Yes": 1, "No": 0},
115
+ 'YMIMI5YANY': {"Yes": 1, "No": 0},
116
 
117
+ 'YMDEHPO': {"Yes": 1, "No": 0},
118
+ 'YMDETXRX': {"Yes": 1, "No": 0},
119
+ 'YMDEHARX': {"Yes": 1, "No": 0},
120
+ 'YRXMDEYR': {"Yes": 1, "No": 0},
121
+ 'YHLTMDE': {"Yes": 1, "No": 0},
122
+ 'YTXMDEYR': {"Yes": 1, "No": 0},
123
+ 'YDOCMDE': {"Yes": 1, "No": 0},
124
+ 'YPSY2MDE': {"Yes": 1, "No": 0},
125
+ 'YPSY1MDE': {"Yes": 1, "No": 0},
126
+ 'YCOUNMDE': {"Yes": 1, "No": 0},
127
 
128
+ 'MDEIMPY': {"Yes": 1, "No": 2},
129
+ 'LVLDIFMEM2': {"No Difficulty": 1, "Some difficulty": 2, "A lot of difficulty or cannot do at all": 3},
 
130
 
131
+ 'YUSUITHK': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
132
+ 'YUSUITHKYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
133
+ 'YUSUIPLNYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
134
+ 'YUSUIPLN': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
135
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
 
 
 
 
138
  def validate_inputs(*args):
139
  for arg in args:
140
  if not arg: # empty or None
141
  return False
142
  return True
143
 
144
+
145
+ ######################################
146
+ # 4) NEAREST NEIGHBORS (Grouped)
147
+ ######################################
148
+ def get_nearest_neighbors_info(user_input_df: pd.DataFrame, k=5):
149
+ # Ensure columns exist in df
150
+ user_cols = user_input_df.columns
151
+ if not all(col in df.columns for col in user_cols):
152
+ return "Cannot compute nearest neighbors. Some columns not found in df."
153
+
154
+ # Subset df
155
+ sub_df = df[list(user_cols)].copy()
156
+ diffs = sub_df - user_input_df.iloc[0]
157
+ dists = (diffs**2).sum(axis=1)**0.5
158
+ nn_indices = dists.nsmallest(k).index
159
+ neighbors = df.loc[nn_indices]
160
+
161
+ lines = [f"**Nearest Neighbors (k={k})**",
162
+ f"Distances Range: {dists[nn_indices].min():.2f} to {dists[nn_indices].max():.2f}",
163
+ ""]
164
+
165
+ # Group the features by our categories_dict
166
+ for cat_name, cat_feats in categories_dict.items():
167
+ lines.append(f"### {cat_name}")
168
+ for feat in cat_feats:
169
+ if feat not in neighbors.columns:
170
+ continue
171
+ # Count how many neighbors had each numeric value
172
+ val_counts = neighbors[feat].value_counts().to_dict()
173
+ # Build string like: "**YMDESUD5ANYO** => 3 had '1'; 2 had '2'"
174
+ parts = []
175
+ for val_, count_ in val_counts.items():
176
+ parts.append(f"{count_} had '{val_}'")
177
+ joined = "; ".join(parts)
178
+ lines.append(f"**{feat}** => {joined}")
179
+ lines.append("") # blank line
180
+
181
+ return "\n".join(lines)
182
 
183
 
184
  ######################################
185
+ # 5) PREDICT FUNCTION
186
  ######################################
187
  def predict(
188
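+ # 24 features, matching categories_dict ordering (NOTE: YMDEHPRX from the earlier 25-feature signature is no longer passed; confirm the models expect this column set and order).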
189
+ # We'll just list them in the dictionary order we want to show them:
190
+ YMDESUD5ANYO, YMDELT, YMDEYR, YMDERSUD5ANY,
191
+ YMSUD5YANY, YMIUD5YANY, YMIMS5YANY, YMIMI5YANY,
192
+
193
+ YMDEHPO, YMDETXRX, YMDEHARX, YRXMDEYR, YHLTMDE,
194
+ YTXMDEYR, YDOCMDE, YPSY2MDE, YPSY1MDE, YCOUNMDE,
195
+
196
+ MDEIMPY, LVLDIFMEM2,
197
+
198
+ YUSUITHK, YUSUITHKYR, YUSUIPLNYR, YUSUIPLN
199
  ):
 
200
  if not validate_inputs(
201
+ YMDESUD5ANYO, YMDELT, YMDEYR, YMDERSUD5ANY,
202
+ YMSUD5YANY, YMIUD5YANY, YMIMS5YANY, YMIMI5YANY,
203
+ YMDEHPO, YMDETXRX, YMDEHARX, YRXMDEYR, YHLTMDE,
204
+ YTXMDEYR, YDOCMDE, YPSY2MDE, YPSY1MDE, YCOUNMDE,
205
+ MDEIMPY, LVLDIFMEM2,
206
+ YUSUITHK, YUSUITHKYR, YUSUIPLNYR, YUSUIPLN
207
  ):
208
  return (
209
  "Please select all required fields.", # 1) Prediction Results
210
+ "Validation Error", # 2) Severity
211
+ "No data", # 3) Total Count
212
+ "No nearest neighbors info", # 4) NN Summary
213
+ None, # 5) Bar chart (Input)
214
+ None # 6) Bar chart (Labels)
215
  )
216
 
217
+ # 1) Map user-friendly -> numeric
218
  user_input_dict = {
219
  'YMDESUD5ANYO': input_mapping['YMDESUD5ANYO'][YMDESUD5ANYO],
 
 
 
 
 
220
  'YMDELT': input_mapping['YMDELT'][YMDELT],
221
+ 'YMDEYR': input_mapping['YMDEYR'][YMDEYR],
222
+ 'YMDERSUD5ANY': input_mapping['YMDERSUD5ANY'][YMDERSUD5ANY],
223
+ 'YMSUD5YANY': input_mapping['YMSUD5YANY'][YMSUD5YANY],
224
+ 'YMIUD5YANY': input_mapping['YMIUD5YANY'][YMIUD5YANY],
225
+ 'YMIMS5YANY': input_mapping['YMIMS5YANY'][YMIMS5YANY],
226
  'YMIMI5YANY': input_mapping['YMIMI5YANY'][YMIMI5YANY],
227
+
228
+ 'YMDEHPO': input_mapping['YMDEHPO'][YMDEHPO],
229
+ 'YMDETXRX': input_mapping['YMDETXRX'][YMDETXRX],
230
  'YMDEHARX': input_mapping['YMDEHARX'][YMDEHARX],
 
231
  'YRXMDEYR': input_mapping['YRXMDEYR'][YRXMDEYR],
 
 
 
232
  'YHLTMDE': input_mapping['YHLTMDE'][YHLTMDE],
 
 
 
 
 
 
233
  'YTXMDEYR': input_mapping['YTXMDEYR'][YTXMDEYR],
234
+ 'YDOCMDE': input_mapping['YDOCMDE'][YDOCMDE],
235
+ 'YPSY2MDE': input_mapping['YPSY2MDE'][YPSY2MDE],
236
+ 'YPSY1MDE': input_mapping['YPSY1MDE'][YPSY1MDE],
237
  'YCOUNMDE': input_mapping['YCOUNMDE'][YCOUNMDE],
238
+
239
+ 'MDEIMPY': input_mapping['MDEIMPY'][MDEIMPY],
240
+ 'LVLDIFMEM2': input_mapping['LVLDIFMEM2'][LVLDIFMEM2],
241
+
242
+ 'YUSUITHK': input_mapping['YUSUITHK'][YUSUITHK],
243
+ 'YUSUITHKYR': input_mapping['YUSUITHKYR'][YUSUITHKYR],
244
+ 'YUSUIPLNYR': input_mapping['YUSUIPLNYR'][YUSUIPLNYR],
245
+ 'YUSUIPLN': input_mapping['YUSUIPLN'][YUSUIPLN]
246
  }
247
  user_df = pd.DataFrame(user_input_dict, index=[0])
248
 
249
+ # 2) Predict
250
+ predictions = predictor.make_predictions(user_df)
 
 
251
  all_preds = np.concatenate(predictions)
 
 
252
  count_ones = sum(all_preds == 1)
 
253
  severity_msg = predictor.evaluate_severity(count_ones)
254
 
255
+ # 3) Grouped textual results
256
  groups = {
257
  "Concentration_and_Decision_Making": ["YOWRCONC", "YOWRDCSN"],
258
  "Sleep_and_Energy_Levels": ["YOWRHRS", "YO_MDEA5", "YOWRELES", "YO_MDEA2"],
 
264
  "YODPPROB", "YOWRPROB", "YODPR2WK", "YODSMMDE", "YOPB2WK"
265
  ]
266
  }
 
267
  group_text = {g: [] for g in groups}
268
+ # The model_filenames order determines which label is i
269
  for i, arr in enumerate(predictions):
270
  label_col = model_filenames[i].split('.')[0] # e.g. "YOWRCONC"
271
  val = arr[0]
272
+ # If we have a textual map, use it
273
  if label_col in predictor.prediction_map and val in range(len(predictor.prediction_map[label_col])):
274
  text_label = predictor.prediction_map[label_col][val]
275
  else:
276
  text_label = f"Prediction={val}"
277
 
278
+ # Put in whichever group
279
+ for group_name, cols_ in groups.items():
280
+ if label_col in cols_:
281
  group_text[group_name].append(f"{label_col} => {text_label}")
282
  break
283
 
 
287
  gtitle = gname.replace("_", " ")
288
  final_str_parts.append(f"**{gtitle}**")
289
  final_str_parts.append("\n".join(lines))
290
+ final_str_parts.append("")
291
  if not final_str_parts:
292
  final_str = "No predictions made or no matching group columns."
293
  else:
294
  final_str = "\n".join(final_str_parts)
295
 
296
+ # 4) Additional info
297
  total_count = len(df)
298
  total_count_md = f"We have **{total_count}** patients in the dataset."
299
 
300
+ # 5) Nearest Neighbors
301
  nn_md = get_nearest_neighbors_info(user_df, k=5)
302
 
303
+ # 6) Bar chart for input features
304
  input_counts = {}
305
  for col, val_ in user_input_dict.items():
306
  matched = len(df[df[col] == val_])
 
313
  )
314
  fig_in.update_layout(width=1200, height=400)
315
 
316
+ # 7) Bar chart for predicted labels
 
317
  label_counts = {}
318
  for i, arr in enumerate(predictions):
319
  lbl = model_filenames[i].split('.')[0]
320
  pred_val = arr[0]
321
  if lbl in df.columns:
 
322
  label_counts[lbl] = len(df[df[lbl] == pred_val])
323
  if label_counts:
324
  bar_lbl_df = pd.DataFrame({
 
339
  severity_msg, # 2) Mental Health Severity
340
  total_count_md, # 3) Total Patient Count
341
  nn_md, # 4) Nearest Neighbors Summary
342
+ fig_in, # 5) Bar Chart (input features)
343
+ fig_lbl # 6) Bar Chart (labels)
344
  )
345
 
346
 
 
348
  # 6) EXTRA TABS / FUNCTIONS
349
  ######################################
350
  def distribution_plot(feature_col, label_col):
 
 
 
351
  if not feature_col or not label_col:
352
  return px.bar(title="Please select both Feature and Label.")
353
  if (feature_col not in df.columns) or (label_col not in df.columns):
 
366
 
367
 
368
  def co_occurrence_plot(feature1, feature2, label_col):
 
 
 
369
  if not feature1 or not feature2 or not label_col:
370
  return px.bar(title="Please select all three fields.")
371
  if feature1 not in df.columns or feature2 not in df.columns or label_col not in df.columns:
 
391
 
392
  # ======== TAB 1: PREDICTION ========
393
  with gr.Tab("Prediction"):
394
+ gr.Markdown(
395
+ """
396
+ ### Please provide inputs in each of the four categories below.
397
+ *All fields are required.*
398
+ """
399
+ )
400
+
401
+ # For clarity, we define an ordered list of the features in the exact sequence
402
+ # matching our predict() function. We’ll group them under the same headings.
403
+ cat1_col_labels = [
404
+ ("YMDESUD5ANYO", "YMDESUD5ANYO: ONLY MDE, ONLY SUD, BOTH, OR NEITHER"),
405
+ ("YMDELT", "YMDELT: Had major depressive episode in lifetime"),
406
+ ("YMDEYR", "YMDEYR: Past-year major depressive episode"),
407
+ ("YMDERSUD5ANY", "YMDERSUD5ANY: MDE or substance use disorder - past year"),
408
+ ("YMSUD5YANY", "YMSUD5YANY: Past-year MDE & substance use disorder"),
409
+ ("YMIUD5YANY", "YMIUD5YANY: Past-year MDE & illicit drug use disorder"),
410
+ ("YMIMS5YANY", "YMIMS5YANY: Past-year MDE + severe impairment + substance use"),
411
+ ("YMIMI5YANY", "YMIMI5YANY: Past-year MDE with severe impairment & illicit drug use")
412
+ ]
413
+ cat2_col_labels = [
414
+ ("YMDEHPO", "YMDEHPO: Saw health prof only for MDE in past year"),
415
+ ("YMDETXRX", "YMDETXRX: Received treatment/counseling if saw doc/prof for MDE"),
416
+ ("YMDEHARX", "YMDEHARX: Saw health professional & received medication for MDE"),
417
+ ("YRXMDEYR", "YRXMDEYR: Used received medication for MDE in past years"),
418
+ ("YHLTMDE", "YHLTMDE: Saw/talked to health professional about MDE in past year"),
419
+ ("YTXMDEYR", "YTXMDEYR: Saw or talked to doc/health prof for MDE in past year"),
420
+ ("YDOCMDE", "YDOCMDE: Saw/talked to general practitioner/family MD about MDE"),
421
+ ("YPSY2MDE", "YPSY2MDE: Saw/talked to psychiatrist about MDE"),
422
+ ("YPSY1MDE", "YPSY1MDE: Saw/talked to psychologist about MDE"),
423
+ ("YCOUNMDE", "YCOUNMDE: Saw/talked to counselor about MDE")
424
+ ]
425
+ cat3_col_labels = [
426
+ ("MDEIMPY", "MDEIMPY: MDE with severe role impairment"),
427
+ ("LVLDIFMEM2", "LVLDIFMEM2: Difficulty remembering/concentrating")
428
+ ]
429
+ cat4_col_labels = [
430
+ ("YUSUITHK", "YUSUITHK: Youth seriously think about killing self in past 12 months"),
431
+ ("YUSUITHKYR", "YUSUITHKYR: Seriously thought about killing self"),
432
+ ("YUSUIPLNYR", "YUSUIPLNYR: Made plans to kill self in past year"),
433
+ ("YUSUIPLN", "YUSUIPLN: Made plans to kill yourself in past 12 months")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
434
  ]
435
 
436
+ # Category 1 block
437
+ gr.Markdown("#### 1. Depression & Substance Use Diagnosis")
438
+ cat1_inputs = []
439
+ for col, label_text in cat1_col_labels:
440
+ dd = gr.Dropdown(
441
+ choices=list(input_mapping[col].keys()),
442
+ label=label_text
443
+ )
444
+ cat1_inputs.append(dd)
445
+
446
+ # Category 2 block
447
+ gr.Markdown("#### 2. Mental Health Treatment & Professional Consultation")
448
+ cat2_inputs = []
449
+ for col, label_text in cat2_col_labels:
450
+ dd = gr.Dropdown(
451
+ choices=list(input_mapping[col].keys()),
452
+ label=label_text
453
+ )
454
+ cat2_inputs.append(dd)
455
+
456
+ # Category 3 block
457
+ gr.Markdown("#### 3. Functional & Cognitive Impairment")
458
+ cat3_inputs = []
459
+ for col, label_text in cat3_col_labels:
460
+ dd = gr.Dropdown(
461
+ choices=list(input_mapping[col].keys()),
462
+ label=label_text
463
+ )
464
+ cat3_inputs.append(dd)
465
+
466
+ # Category 4 block
467
+ gr.Markdown("#### 4. Suicidal Thoughts & Behaviors")
468
+ cat4_inputs = []
469
+ for col, label_text in cat4_col_labels:
470
+ dd = gr.Dropdown(
471
+ choices=list(input_mapping[col].keys()),
472
+ label=label_text
473
+ )
474
+ cat4_inputs.append(dd)
475
+
476
+ # The overall input list must match the order in `predict()`
477
+ all_inputs = cat1_inputs + cat2_inputs + cat3_inputs + cat4_inputs
478
+
479
  predict_btn = gr.Button("Predict")
480
 
481
+ # 6 outputs
482
  out_pred_res = gr.Textbox(label="Prediction Results", lines=8)
483
  out_sev = gr.Textbox(label="Mental Health Severity", lines=2)
484
  out_count = gr.Markdown(label="Total Patient Count")
485
+ out_nn = gr.Markdown(label="Nearest Neighbors Summary (Grouped by Category)")
486
  out_bar_input= gr.Plot(label="Input Feature Counts")
487
  out_bar_label= gr.Plot(label="Predicted Label Counts")
488
 
 
489
  predict_btn.click(
490
  fn=predict,
491
+ inputs=all_inputs,
492
  outputs=[
493
  out_pred_res, # 1
494
  out_sev, # 2
 
502
  # ======== TAB 2: Distribution Analysis ========
503
  with gr.Tab("Distribution Analysis"):
504
  gr.Markdown("## Distribution Plot\nSelect one feature and one label column to see bar counts.")
 
505
  list_of_features = sorted(input_mapping.keys())
 
 
506
  list_of_labels = sorted(predictor.prediction_map.keys())
507
 
508
  feat_dd = gr.Dropdown(choices=list_of_features, label="Feature Column")
 
534
  outputs=co_occ_output
535
  )
536
 
537
+ # Finally, launch
538
  demo.launch()