pantdipendra committed on
Commit
3b96ce2
·
verified ·
1 Parent(s): 69090fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +283 -395
app.py CHANGED
@@ -1,14 +1,24 @@
1
  import pickle
2
- import gradio as gr
3
  import numpy as np
4
  import pandas as pd
5
  import plotly.express as px
 
6
 
7
- # Load the training CSV once.
 
 
8
  df = pd.read_csv("X_train_Y_Train_merged_train.csv")
9
 
 
 
 
 
 
 
 
 
10
  ######################################
11
- # 1) MODEL PREDICTOR CLASS
12
  ######################################
13
  class ModelPredictor:
14
  def __init__(self, model_path, model_filenames):
@@ -17,106 +27,83 @@ class ModelPredictor:
17
  self.models = self.load_models()
18
  # Mapping from label column to human-readable strings for 0/1
19
  self.prediction_map = {
20
- "YOWRCONC": ["No difficulty concentrating", "Had difficulty concentrating"],
21
- "YOSEEDOC": ["No need to see doctor", "Needed to see doctor"],
22
- "YOWRHRS": ["No trouble sleeping", "Had trouble sleeping"],
23
- "YO_MDEA5": ["Others didn't notice restlessness", "Others noticed restlessness"],
24
- "YOWRCHR": ["Not sad beyond cheering", "Felt so sad no one could cheer up"],
25
- "YOWRLSIN": ["Never felt bored/lost interest", "Felt bored/lost interest"],
26
  "YODPPROB": ["No other problems for 2+ weeks", "Had other problems for 2+ weeks"],
27
- "YOWRPROB": ["No worst time feeling", "Felt worst time ever"],
28
- "YODPR2WK": ["No depressed feelings for 2+ wks", "Depressed feelings for 2+ wks"],
29
- "YOWRDEPR": ["Not sad or depressed most days", "Sad or depressed most days"],
30
  "YODPDISC": ["Mood not depressed overall", "Mood depressed overall (discrepancy)"],
31
- "YOLOSEV": ["Did not lose interest in activities", "Lost interest in activities"],
32
- "YOWRDCSN": ["Could make decisions", "Could not make decisions"],
33
- "YODSMMDE": ["No 2+ week depression episodes", "Had 2+ week depression episodes"],
34
- "YO_MDEA3": ["No appetite/weight changes", "Yes appetite/weight changes"],
35
- "YODPLSIN": ["Never bored/lost interest", "Often bored/lost interest"],
36
- "YOWRELES": ["Did not eat less", "Ate less than usual"],
37
  "YODSCEV": ["Fewer severe symptoms", "More severe symptoms"],
38
- "YOPB2WK": ["No uneasy feelings daily 2+ wks", "Uneasy feelings daily 2+ wks"],
39
- "YO_MDEA2": ["No issues physical/mental daily", "Issues physical/mental daily 2+ wks"]
40
  }
41
 
42
  def load_models(self):
43
  models = []
44
- for fn in self.model_filenames:
45
- filepath = self.model_path + fn
46
- with open(filepath, "rb") as file:
47
- models.append(pickle.load(file))
 
48
  return models
49
 
50
  def make_predictions(self, user_input):
51
- """Return list of numpy arrays, each array either [0] or [1]."""
52
- preds = []
53
- for m in self.models:
54
- out = m.predict(user_input)
55
- preds.append(np.array(out).flatten())
56
- return preds
57
 
58
  def get_majority_vote(self, predictions):
59
- """Flatten all predictions and find 0 or 1 with majority."""
60
  combined = np.concatenate(predictions)
61
- return np.bincount(combined).argmax()
 
 
62
 
63
  def evaluate_severity(self, majority_vote_count):
64
- """Heuristic: Based on 16 total models, 0-4=Very Low, 5-8=Low, 9-12=Moderate, 13-16=Severe."""
65
  if majority_vote_count >= 13:
66
- return "Mental health severity: Severe"
67
  elif majority_vote_count >= 9:
68
- return "Mental health severity: Moderate"
69
  elif majority_vote_count >= 5:
70
- return "Mental health severity: Low"
71
  else:
72
- return "Mental health severity: Very Low"
73
-
74
- ######################################
75
- # 2) CONFIGURATIONS
76
- ######################################
77
- model_filenames = [
78
- "YOWRCONC.pkl", "YOSEEDOC.pkl", "YO_MDEA5.pkl", "YOWRLSIN.pkl",
79
- "YODPPROB.pkl", "YOWRPROB.pkl", "YODPR2WK.pkl", "YOWRDEPR.pkl",
80
- "YODPDISC.pkl", "YOLOSEV.pkl", "YOWRDCSN.pkl", "YODSMMDE.pkl",
81
- "YO_MDEA3.pkl", "YODPLSIN.pkl", "YOWRELES.pkl", "YOPB2WK.pkl"
82
- ]
83
- model_path = "models/"
84
- predictor = ModelPredictor(model_path, model_filenames)
85
 
86
  ######################################
87
- # 3) INPUT VALIDATION
88
  ######################################
89
  def validate_inputs(*args):
90
- # Just ensure all required (non-co-occurrence) fields are picked
91
  for arg in args:
92
  if arg == '' or arg is None:
93
  return False
94
  return True
95
 
96
  ######################################
97
- # 4) PREDICTION FUNCTION
98
  ######################################
 
 
99
  def predict(
100
- # Original required features
101
  YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
102
  YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
103
  YMDESUD5ANYO, YNURSMDE, YSOCMDE, YCOUNMDE, YPSY1MDE, YPSY2MDE,
104
  YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
105
- YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR,
106
- # **New** optional picks for co-occurrence
107
- co_occ_feature1, co_occ_feature2, co_occ_label
108
  ):
109
- """
110
- Main function that:
111
- - Predicts with the 16 models
112
- - Aggregates results
113
- - Produces severity
114
- - Returns distribution & bar charts
115
- - Finds K=2 Nearest Neighbors
116
- - Produces *one* co-occurrence plot based on user-chosen columns
117
- """
118
-
119
- # 1) Build user_input for models
120
  user_input_data = {
121
  'YNURSMDE': [int(YNURSMDE)],
122
  'YMDEYR': [int(YMDEYR)],
@@ -150,21 +137,21 @@ def predict(
150
  }
151
  user_input = pd.DataFrame(user_input_data)
152
 
153
- # 2) Model Predictions
154
  predictions = predictor.make_predictions(user_input)
 
 
155
  majority_vote = predictor.get_majority_vote(predictions)
156
- majority_vote_count = np.sum(np.concatenate(predictions) == 1)
157
- severity = predictor.evaluate_severity(majority_vote_count)
158
-
159
- # 3) Summarize textual results
160
- results_by_group = {
161
- "Concentration_and_Decision_Making": [],
162
- "Sleep_and_Energy_Levels": [],
163
- "Mood_and_Emotional_State": [],
164
- "Appetite_and_Weight_Changes": [],
165
- "Duration_and_Severity_of_Depression_Symptoms": []
166
- }
167
- group_map = {
168
  "Concentration_and_Decision_Making": ["YOWRCONC", "YOWRDCSN"],
169
  "Sleep_and_Energy_Levels": ["YOWRHRS", "YO_MDEA5", "YOWRELES", "YO_MDEA2"],
170
  "Mood_and_Emotional_State": ["YOWRCHR", "YOWRLSIN", "YOWRDEPR", "YODPDISC",
@@ -174,199 +161,139 @@ def predict(
174
  "YODPR2WK", "YODSMMDE",
175
  "YOPB2WK"]
176
  }
177
-
178
- # Convert each model's 0/1 to text
179
- grouped_output_lines = []
180
- for i, pred_array in enumerate(predictions):
181
- col_name = model_filenames[i].split(".")[0] # e.g., "YOWRCONC"
182
- val = pred_array[0]
183
- if col_name in predictor.prediction_map and val in [0, 1]:
184
- text = predictor.prediction_map[col_name][val]
185
- out_line = f"{col_name}: {text}"
186
  else:
187
- out_line = f"{col_name}: Prediction={val}"
188
-
189
- # Find group
190
- placed = False
191
- for g_key, g_cols in group_map.items():
192
- if col_name in g_cols:
193
- results_by_group[g_key].append(out_line)
194
- placed = True
195
  break
196
- if not placed:
197
- # If it didn't fall into any known group, skip or handle
198
  pass
199
 
200
- # Format into a single string
201
- for group_label, pred_lines in results_by_group.items():
202
- if pred_lines:
203
- grouped_output_lines.append(f"Group {group_label}:")
204
- grouped_output_lines.append("\n".join(pred_lines))
205
- grouped_output_lines.append("")
206
-
207
- if len(grouped_output_lines) == 0:
208
- final_result_text = "No predictions made. Check inputs."
209
- else:
210
- final_result_text = "\n".join(grouped_output_lines).strip()
211
-
212
- # 4) Additional Features
213
- # A) Total patient count
214
  total_patients = len(df)
215
- total_count_md = (
216
- "### Total Patient Count\n"
217
- f"**{total_patients}** total patients in the dataset."
218
- )
219
-
220
- # B) Bar chart of how many have same inputs
221
- input_counts = {}
222
- for c in user_input_data.keys():
223
- v = user_input_data[c][0]
224
- input_counts[c] = len(df[df[c] == v])
225
- df_input_counts = pd.DataFrame({"Feature": list(input_counts.keys()), "Count": list(input_counts.values())})
226
- fig_input_bar = px.bar(
227
- df_input_counts,
228
- x="Feature",
229
- y="Count",
230
- title="Number of Patients with the Same Value for Each Input Feature"
231
  )
232
- fig_input_bar.update_layout(xaxis={"categoryorder": "total descending"})
233
 
234
- # C) Bar chart for predicted labels
 
 
 
 
 
 
 
 
 
 
 
235
  label_counts = {}
236
- for i, pred_array in enumerate(predictions):
237
- col_name = model_filenames[i].split(".")[0]
238
- val = pred_array[0]
239
- if val in [0,1]:
240
- label_counts[col_name] = len(df[df[col_name] == val])
241
-
242
- if len(label_counts) > 0:
243
- df_label_counts = pd.DataFrame({
244
- "Label Column": list(label_counts.keys()),
245
- "Count": list(label_counts.values())
246
- })
247
- fig_label_bar = px.bar(
248
- df_label_counts,
249
- x="Label Column",
250
- y="Count",
251
- title="Number of Patients with the Same Predicted Label"
252
- )
253
  else:
254
- fig_label_bar = px.bar(title="No valid predicted labels to display")
255
-
256
- # D) Simple Distribution Plot (demo for first 3 labels & 4 inputs)
257
- # (Unchanged from prior approach; you can remove if you prefer.)
258
- sample_feats = list(user_input_data.keys())[:31]
259
- sample_labels = [fn.split(".")[0] for fn in model_filenames[:15]]
260
- dist_segments = []
261
- for feat in sample_feats:
 
262
  if feat not in df.columns:
263
  continue
264
- for lbl in sample_labels:
265
- if lbl not in df.columns:
266
  continue
267
- temp_g = df.groupby([feat,lbl]).size().reset_index(name="count")
268
- temp_g["feature"] = feat
269
- temp_g["label"] = lbl
270
- dist_segments.append(temp_g)
271
- if len(dist_segments) > 0:
272
- big_dist_df = pd.concat(dist_segments, ignore_index=True)
273
- fig_dist = px.bar(
274
- big_dist_df,
275
- x=big_dist_df.columns[0],
276
- y="count",
277
- color=big_dist_df.columns[1],
278
- facet_row="feature",
279
- facet_col="label",
280
- title="Sample Distribution Plot (first 4 features vs first 3 labels)"
281
- )
282
- fig_dist.update_layout(height=700)
283
- else:
284
- fig_dist = px.bar(title="No distribution plot generated (columns not found).")
285
-
286
- # E) Nearest Neighbors with K=2
287
- # We keep K=2, but for *all* label columns, we show their actual 0/1 or mapped text
288
- # (same approach as before).
289
- # ... [omitted here for brevity, or replicate your existing code for K=2 nearest neighbors] ...
290
- # We'll do a short version to keep focus on co-occ:
291
- # ---------------------------------------------------------------------
292
- # Build Hamming distance across user_input columns
293
- columns_for_distance = list(user_input.columns)
294
- sub_df = df[columns_for_distance].copy()
295
- user_row = user_input.iloc[0]
296
- distances = []
297
- for idx, row_ in sub_df.iterrows():
298
- dist_ = sum(row_[col] != user_row[col] for col in columns_for_distance)
299
- distances.append(dist_)
300
- df_dist = df.copy()
301
- df_dist["distance"] = distances
302
- # Sort ascending, pick K=2
303
- K = 2
304
- nearest_neighbors = df_dist.sort_values("distance", ascending=True).head(K)
305
-
306
- # Summarize in Markdown
307
- nn_md = ["### Nearest Neighbors (K=2)"]
308
- nn_md.append("(In a real application, you'd refine which features matter, how to encode them, etc.)\n")
309
- for irow in nearest_neighbors.itertuples():
310
- nn_md.append(f"- **Neighbor ID {irow.Index}**: distance={irow.distance}")
311
- nn_md_str = "\n".join(nn_md)
312
-
313
- # F) Co-occurrence Plot for user-chosen feature1, feature2, label
314
- # If the user picks "None" or doesn't pick valid columns, skip or fallback.
315
- if (co_occ_feature1 is not None and co_occ_feature1 != "None" and
316
- co_occ_feature2 is not None and co_occ_feature2 != "None" and
317
- co_occ_label is not None and co_occ_label != "None"):
318
- # Check if these columns are in df
319
- if (co_occ_feature1 in df.columns and
320
- co_occ_feature2 in df.columns and
321
- co_occ_label in df.columns):
322
- # Group by [co_occ_feature1, co_occ_feature2, co_occ_label]
323
- co_data = df.groupby([co_occ_feature1, co_occ_feature2, co_occ_label]).size().reset_index(name="count")
324
- fig_co_occ = px.bar(
325
- co_data,
326
- x=co_occ_feature1,
327
- y="count",
328
- color=co_occ_label,
329
- facet_col=co_occ_feature2,
330
- title=f"Co-occurrence: {co_occ_feature1} & {co_occ_feature2} vs {co_occ_label}"
331
- )
332
- else:
333
- fig_co_occ = px.bar(title="One or more selected columns not found in dataframe.")
334
  else:
335
- fig_co_occ = px.bar(title="No co-occurrence plot (choose two features + one label).")
336
 
337
- # Return all 8 outputs
 
 
 
 
 
 
 
 
338
  return (
339
- final_result_text, # (1) Predictions
340
- severity, # (2) Severity
341
- total_count_md, # (3) Total patient count
342
- fig_dist, # (4) Distribution Plot
343
- nn_md_str, # (5) Nearest Neighbors
344
- fig_co_occ, # (6) Co-occurrence
345
- fig_input_bar, # (7) Bar Chart (input features)
346
- fig_label_bar # (8) Bar Chart (labels)
347
  )
348
 
349
  ######################################
350
- # 5) MAPPING (user -> int)
351
  ######################################
352
  input_mapping = {
353
  'YNURSMDE': {"Yes": 1, "No": 0},
354
  'YMDEYR': {"Yes": 1, "No": 2},
355
  'YSOCMDE': {"Yes": 1, "No": 0},
356
- 'YMDESUD5ANYO': {"SUD only": 1, "MDE only": 2, "SUD & MDE": 3, "Neither": 4},
357
  'YMSUD5YANY': {"Yes": 1, "No": 0},
358
- 'YUSUITHK': {"Yes": 1, "No": 2, "Unsure": 3, "Don't want to answer": 4},
359
  'YMDETXRX': {"Yes": 1, "No": 0},
360
- 'YUSUITHKYR': {"Yes": 1, "No": 2, "Unsure": 3, "Don't want to answer": 4},
361
  'YMDERSUD5ANY': {"Yes": 1, "No": 0},
362
- 'YUSUIPLNYR': {"Yes": 1, "No": 2, "Unsure": 3, "Don't want to answer": 4},
363
  'YCOUNMDE': {"Yes": 1, "No": 0},
364
  'YPSY1MDE': {"Yes": 1, "No": 0},
365
  'YHLTMDE': {"Yes": 1, "No": 0},
366
  'YDOCMDE': {"Yes": 1, "No": 0},
367
  'YPSY2MDE': {"Yes": 1, "No": 0},
368
  'YMDEHARX': {"Yes": 1, "No": 0},
369
- 'LVLDIFMEM2': {"No Difficulty": 1, "Some Difficulty": 2, "A lot or cannot do": 3},
370
  'MDEIMPY': {"Yes": 1, "No": 2},
371
  'YMDEHPO': {"Yes": 1, "No": 0},
372
  'YMIMS5YANY': {"Yes": 1, "No": 0},
@@ -374,7 +301,7 @@ input_mapping = {
374
  'YMIUD5YANY': {"Yes": 1, "No": 0},
375
  'YMDEHPRX': {"Yes": 1, "No": 0},
376
  'YMIMI5YANY': {"Yes": 1, "No": 0},
377
- 'YUSUIPLN': {"Yes": 1, "No": 2, "Unsure": 3, "Don't want to answer": 4},
378
  'YTXMDEYR': {"Yes": 1, "No": 0},
379
  'YMDEAUD5YR': {"Yes": 1, "No": 0},
380
  'YRXMDEYR': {"Yes": 1, "No": 0},
@@ -382,166 +309,127 @@ input_mapping = {
382
  }
383
 
384
  ######################################
385
- # 6) THE GRADIO INTERFACE
386
  ######################################
387
- import gradio as gr
388
-
389
- # (A) The original required inputs
390
- original_inputs = [
391
- gr.Dropdown(list(input_mapping['YMDEYR'].keys()), label="YMDEYR: Past Year MDE?"),
392
- gr.Dropdown(list(input_mapping['YMDERSUD5ANY'].keys()), label="YMDERSUD5ANY: MDE or SUD - ANY?"),
393
- gr.Dropdown(list(input_mapping['YMDEIMAD5YR'].keys()), label="YMDEIMAD5YR: MDE + ALCOHOL?"),
394
- gr.Dropdown(list(input_mapping['YMIMS5YANY'].keys()), label="YMIMS5YANY: MDE + SUBSTANCE?"),
395
- gr.Dropdown(list(input_mapping['YMDELT'].keys()), label="YMDELT: MDE in Lifetime?"),
396
- gr.Dropdown(list(input_mapping['YMDEHARX'].keys()), label="YMDEHARX: Saw Health Prof + Meds?"),
397
- gr.Dropdown(list(input_mapping['YMDEHPRX'].keys()), label="YMDEHPRX: Saw Health Prof or Meds?"),
398
- gr.Dropdown(list(input_mapping['YMDETXRX'].keys()), label="YMDETXRX: Received Treatment?"),
399
- gr.Dropdown(list(input_mapping['YMDEHPO'].keys()), label="YMDEHPO: Saw Health Prof Only?"),
400
- gr.Dropdown(list(input_mapping['YMDEAUD5YR'].keys()), label="YMDEAUD5YR: MDE + Alcohol Use?"),
401
- gr.Dropdown(list(input_mapping['YMIMI5YANY'].keys()), label="YMIMI5YANY: MDE + ILL Drug Use?"),
402
- gr.Dropdown(list(input_mapping['YMIUD5YANY'].keys()), label="YMIUD5YANY: MDE + ILL Drug Use?"),
403
- gr.Dropdown(list(input_mapping['YMDESUD5ANYO'].keys()), label="YMDESUD5ANYO: MDE vs SUD vs BOTH vs NEITHER"),
404
-
405
- # Consultations
406
- gr.Dropdown(list(input_mapping['YNURSMDE'].keys()), label="YNURSMDE: Nurse/OT about MDE?"),
407
- gr.Dropdown(list(input_mapping['YSOCMDE'].keys()), label="YSOCMDE: Social Worker?"),
408
- gr.Dropdown(list(input_mapping['YCOUNMDE'].keys()), label="YCOUNMDE: Counselor?"),
409
- gr.Dropdown(list(input_mapping['YPSY1MDE'].keys()), label="YPSY1MDE: Psychologist?"),
410
- gr.Dropdown(list(input_mapping['YPSY2MDE'].keys()), label="YPSY2MDE: Psychiatrist?"),
411
- gr.Dropdown(list(input_mapping['YHLTMDE'].keys()), label="YHLTMDE: Health Prof?"),
412
- gr.Dropdown(list(input_mapping['YDOCMDE'].keys()), label="YDOCMDE: GP/Family MD?"),
413
- gr.Dropdown(list(input_mapping['YTXMDEYR'].keys()), label="YTXMDEYR: Doctor/Health Prof?"),
414
-
415
- # Suicidal
416
- gr.Dropdown(list(input_mapping['YUSUITHKYR'].keys()), label="YUSUITHKYR: Serious Suicide Thoughts?"),
417
- gr.Dropdown(list(input_mapping['YUSUIPLNYR'].keys()), label="YUSUIPLNYR: Made Plans?"),
418
- gr.Dropdown(list(input_mapping['YUSUITHK'].keys()), label="YUSUITHK: Suicide Thoughts (12 mo)?"),
419
- gr.Dropdown(list(input_mapping['YUSUIPLN'].keys()), label="YUSUIPLN: Made Plans (12 mo)?"),
420
-
421
- # Impairments
422
- gr.Dropdown(list(input_mapping['MDEIMPY'].keys()), label="MDEIMPY: Severe Role Impairment?"),
423
- gr.Dropdown(list(input_mapping['LVLDIFMEM2'].keys()), label="LVLDIFMEM2: Difficulty Remembering/Concentrating?"),
424
- gr.Dropdown(list(input_mapping['YMSUD5YANY'].keys()), label="YMSUD5YANY: MDE + Substance?"),
425
- gr.Dropdown(list(input_mapping['YRXMDEYR'].keys()), label="YRXMDEYR: Used Meds for MDE (12 mo)?"),
426
- ]
427
-
428
- # (B) The new co-occurrence inputs
429
- # We'll give them defaults of "None" to indicate no selection.
430
- all_cols = ["None"] + df.columns.tolist() # 'None' plus the actual columns from your df
431
- co_occ_feature1 = gr.Dropdown(all_cols, label="Co-Occ Feature 1", value="None")
432
- co_occ_feature2 = gr.Dropdown(all_cols, label="Co-Occ Feature 2", value="None")
433
- all_label_cols = ["None"] + list(predictor.prediction_map.keys()) # e.g., "YOWRCONC", "YOWRHRS", ...
434
- co_occ_label = gr.Dropdown(all_label_cols, label="Co-Occ Label", value="None")
435
-
436
- # Combine them into a single input list
437
- inputs = original_inputs + [co_occ_feature1, co_occ_feature2, co_occ_label]
438
-
439
- # 8 outputs as before
440
- outputs = [
441
- gr.Textbox(label="Prediction Results", lines=15),
442
- gr.Textbox(label="Mental Health Severity", lines=2),
443
- gr.Markdown(label="Total Patient Count"),
444
- gr.Plot(label="Distribution Plot (Sample)"),
445
- gr.Markdown(label="Nearest Neighbors (K=2)"),
446
- gr.Plot(label="Co-occurrence Plot"),
447
- gr.Plot(label="Same Value Bar (Inputs)"),
448
- gr.Plot(label="Predicted Label Bar")
449
- ]
450
 
451
  ######################################
452
- # 7) WRAPPER
453
  ######################################
454
- def predict_with_text(
455
- # match the function signature exactly (29 required + 3 for co-occ)
456
- YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
457
- YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
458
- YMDESUD5ANYO, YNURSMDE, YSOCMDE, YCOUNMDE, YPSY1MDE, YPSY2MDE,
459
- YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
460
- YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR,
461
- co_occ_feature1, co_occ_feature2, co_occ_label
462
- ):
463
- # Validate the original 29 fields
464
- valid = validate_inputs(
465
- YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
466
- YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
467
- YMDESUD5ANYO, YNURSMDE, YSOCMDE, YCOUNMDE, YPSY1MDE, YPSY2MDE,
468
- YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
469
- YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
470
- )
471
- if not valid:
472
- return (
473
- "Please select all required fields.",
474
- "Validation Error",
475
- "No data",
476
- None,
477
- "No data",
478
- None,
479
- None,
480
- None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
481
  )
482
-
483
- # Map to numeric
484
- user_inputs = {
485
- 'YNURSMDE': input_mapping['YNURSMDE'][YNURSMDE],
486
- 'YMDEYR': input_mapping['YMDEYR'][YMDEYR],
487
- 'YSOCMDE': input_mapping['YSOCMDE'][YSOCMDE],
488
- 'YMDESUD5ANYO': input_mapping['YMDESUD5ANYO'][YMDESUD5ANYO],
489
- 'YMSUD5YANY': input_mapping['YMSUD5YANY'][YMSUD5YANY],
490
- 'YUSUITHK': input_mapping['YUSUITHK'][YUSUITHK],
491
- 'YMDETXRX': input_mapping['YMDETXRX'][YMDETXRX],
492
- 'YUSUITHKYR': input_mapping['YUSUITHKYR'][YUSUITHKYR],
493
- 'YMDERSUD5ANY': input_mapping['YMDERSUD5ANY'][YMDERSUD5ANY],
494
- 'YUSUIPLNYR': input_mapping['YUSUIPLNYR'][YUSUIPLNYR],
495
- 'YCOUNMDE': input_mapping['YCOUNMDE'][YCOUNMDE],
496
- 'YPSY1MDE': input_mapping['YPSY1MDE'][YPSY1MDE],
497
- 'YHLTMDE': input_mapping['YHLTMDE'][YHLTMDE],
498
- 'YDOCMDE': input_mapping['YDOCMDE'][YDOCMDE],
499
- 'YPSY2MDE': input_mapping['YPSY2MDE'][YPSY2MDE],
500
- 'YMDEHARX': input_mapping['YMDEHARX'][YMDEHARX],
501
- 'LVLDIFMEM2': input_mapping['LVLDIFMEM2'][LVLDIFMEM2],
502
- 'MDEIMPY': input_mapping['MDEIMPY'][MDEIMPY],
503
- 'YMDEHPO': input_mapping['YMDEHPO'][YMDEHPO],
504
- 'YMIMS5YANY': input_mapping['YMIMS5YANY'][YMIMS5YANY],
505
- 'YMDEIMAD5YR': input_mapping['YMDEIMAD5YR'][YMDEIMAD5YR],
506
- 'YMIUD5YANY': input_mapping['YMIUD5YANY'][YMIUD5YANY],
507
- 'YMDEHPRX': input_mapping['YMDEHPRX'][YMDEHPRX],
508
- 'YMIMI5YANY': input_mapping['YMIMI5YANY'][YMIMI5YANY],
509
- 'YUSUIPLN': input_mapping['YUSUIPLN'][YUSUIPLN],
510
- 'YTXMDEYR': input_mapping['YTXMDEYR'][YTXMDEYR],
511
- 'YMDEAUD5YR': input_mapping['YMDEAUD5YR'][YMDEAUD5YR],
512
- 'YRXMDEYR': input_mapping['YRXMDEYR'][YRXMDEYR],
513
- 'YMDELT': input_mapping['YMDELT'][YMDELT]
514
- }
515
-
516
- # Call the core predict function with the co-occ choices as well
517
- return predict(
518
- **user_inputs,
519
- co_occ_feature1=co_occ_feature1,
520
- co_occ_feature2=co_occ_feature2,
521
- co_occ_label=co_occ_label
522
- )
523
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
524
 
 
525
  custom_css = """
526
- .gradio-container * {
527
- color: #1B1212 !important;
528
- }
 
 
529
  """
530
 
531
- interface = gr.Interface(
532
- fn=predict_with_text,
533
- inputs=inputs,
534
- outputs=outputs,
535
- title="Mental Health Screening (NSDUH) with Selective Co-Occurrence",
536
- css=custom_css,
537
- description="""
538
- **Instructions**:
539
- 1. Fill out all required fields regarding MDE/Substance Use/Consultations/Suicidal/Impairments.
540
- 2. (Optional) Choose 2 features and 1 label for the *Co-occurrence* plot.
541
- - If you do not select them (or leave them as "None"), that plot will be skipped.
542
- 3. Click "Submit" to get predictions, severity, distribution plots, nearest neighbors, and your custom co-occurrence chart.
543
- """
544
- )
545
-
546
- if __name__ == "__main__":
547
- interface.launch()
 
1
  import pickle
 
2
  import numpy as np
3
  import pandas as pd
4
  import plotly.express as px
5
+ import gradio as gr
6
 
7
+ ######################################
8
+ # 1) Load Data & Prepare
9
+ ######################################
10
  df = pd.read_csv("X_train_Y_Train_merged_train.csv")
11
 
# Pickled model files, one per predicted label column ("<label>.pkl").
# The order is significant: predictions are matched back to their label
# column by position in this list.
model_filenames = [
    "YOWRCONC.pkl",
    "YOSEEDOC.pkl",
    "YO_MDEA5.pkl",
    "YOWRLSIN.pkl",
    "YODPPROB.pkl",
    "YOWRPROB.pkl",
    "YODPR2WK.pkl",
    "YOWRDEPR.pkl",
    "YODPDISC.pkl",
    "YOLOSEV.pkl",
    "YOWRDCSN.pkl",
    "YODSMMDE.pkl",
    "YO_MDEA3.pkl",
    "YODPLSIN.pkl",
    "YOWRELES.pkl",
    "YOPB2WK.pkl",
]

# Path prefix prepended to each filename by plain string concatenation,
# hence the trailing slash.
model_path = "models/"
19
+
20
  ######################################
21
+ # 2) Model Predictor
22
  ######################################
23
  class ModelPredictor:
24
  def __init__(self, model_path, model_filenames):
 
27
  self.models = self.load_models()
28
  # Mapping from label column to human-readable strings for 0/1
29
  self.prediction_map = {
30
+ "YOWRCONC": ["Did not have difficulty concentrating", "Had difficulty concentrating"],
31
+ "YOSEEDOC": ["Did not feel the need to see a doctor", "Felt the need to see a doctor"],
32
+ "YOWRHRS": ["Did not have trouble sleeping", "Had trouble sleeping"],
33
+ "YO_MDEA5": ["Others did not notice restlessness/lethargy", "Others noticed restlessness/lethargy"],
34
+ "YOWRCHR": ["Did not feel so sad", "Felt so sad nothing could cheer up"],
35
+ "YOWRLSIN": ["Did not feel bored and lose interest", "Felt bored and lost interest"],
36
  "YODPPROB": ["No other problems for 2+ weeks", "Had other problems for 2+ weeks"],
37
+ "YOWRPROB": ["Did not have the worst time ever feeling", "Had the worst time ever feeling"],
38
+ "YODPR2WK": ["No periods of 2+ weeks feelings", "Had periods of 2+ weeks feelings"],
39
+ "YOWRDEPR": ["Did not feel depressed mostly everyday", "Felt depressed mostly everyday"],
40
  "YODPDISC": ["Mood not depressed overall", "Mood depressed overall (discrepancy)"],
41
+ "YOLOSEV": ["Did not lose interest in enjoyable things", "Lost interest in enjoyable things"],
42
+ "YOWRDCSN": ["Was able to make decisions", "Was unable to make decisions"],
43
+ "YODSMMDE": ["Never had depression for 2+ weeks", "Had depression for 2+ weeks"],
44
+ "YO_MDEA3": ["No appetite/weight changes", "Had appetite/weight changes"],
45
+ "YODPLSIN": ["Never bored/lost interest", "Felt bored/lost interest"],
46
+ "YOWRELES": ["Did not eat less than usual", "Ate less than usual"],
47
  "YODSCEV": ["Fewer severe symptoms", "More severe symptoms"],
48
+ "YOPB2WK": ["No uneasy feelings 2+ weeks", "Had uneasy feelings 2+ weeks"],
49
+ "YO_MDEA2": ["No issues w/ physical/mental well-being", "Issues w/ physical/mental well-being"]
50
  }
51
 
def load_models(self):
    """Load every pickled model named in ``self.model_filenames``.

    Each file path is built by appending the filename to
    ``self.model_path`` with plain string concatenation, so the prefix
    must already end with a path separator (e.g. ``"models/"``).

    Returns:
        list: The unpickled objects, in the same order as
        ``self.model_filenames``.

    Raises:
        OSError: If a model file cannot be opened.
    """
    models = []
    # Use the instance's own filename list, not the module-level
    # ``model_filenames`` global, so the constructor argument is honoured.
    for filename in self.model_filenames:
        filepath = self.model_path + filename
        # NOTE(review): unpickling executes code embedded in the file;
        # only load model files from a trusted location.
        with open(filepath, "rb") as file:
            models.append(pickle.load(file))
    return models
60
 
def make_predictions(self, user_input):
    """Run every loaded model on ``user_input``.

    Args:
        user_input: Feature data passed unchanged to each model's
            ``predict`` method.

    Returns:
        list[numpy.ndarray]: One flattened 1-D array per model, in the
        order of ``self.models``.
    """
    # np.asarray makes the flatten safe even when an estimator returns a
    # plain list (or nested list) instead of an ndarray.
    return [
        np.asarray(model.predict(user_input)).flatten()
        for model in self.models
    ]
68
 
def get_majority_vote(self, predictions):
    """Return the label that occurs most often across all predictions.

    Args:
        predictions: Non-empty sequence of 1-D arrays holding
            non-negative, integer-valued labels (typically 0 or 1).

    Returns:
        The most frequent label as a NumPy integer. On a tie the
        smallest label wins, because ``argmax`` returns the first
        maximum.

    Raises:
        ValueError: If ``predictions`` is empty, contains no values, or
            contains a negative label.
    """
    combined = np.concatenate(predictions)
    if combined.size == 0:
        raise ValueError("cannot take a majority vote over zero predictions")
    # bincount only accepts integer input; cast so that estimators which
    # emit float labels (0.0 / 1.0) are still counted.
    return np.bincount(combined.astype(np.int64)).argmax()
74
 
def evaluate_severity(self, majority_vote_count):
    """Translate a count of positive predictions into a severity message.

    The count is compared against fixed lower bounds, highest first:
    13 or more is "Severe", 9 or more is "Moderate", 5 or more is "Low",
    and anything below that is "Very Low".

    Args:
        majority_vote_count: Number of models that predicted 1.

    Returns:
        str: The severity message for the matching band.
    """
    bands = (
        (13, "Mental Health Severity: Severe"),
        (9, "Mental Health Severity: Moderate"),
        (5, "Mental Health Severity: Low"),
    )
    for lower_bound, message in bands:
        if majority_vote_count >= lower_bound:
            return message
    return "Mental Health Severity: Very Low"
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
  ######################################
87
+ # 3) Validate Inputs
88
  ######################################
def validate_inputs(*args):
    """Report whether every supplied value has been filled in.

    Returns:
        bool: ``False`` if any argument equals ``''`` or is ``None``;
        ``True`` otherwise (including when called with no arguments).
    """
    return not any(value == '' or value is None for value in args)
94
 
95
  ######################################
96
+ # 4) Core Prediction
97
  ######################################
98
+ predictor = ModelPredictor(model_path, model_filenames)
99
+
100
  def predict(
 
101
  YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
102
  YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
103
  YMDESUD5ANYO, YNURSMDE, YSOCMDE, YCOUNMDE, YPSY1MDE, YPSY2MDE,
104
  YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
105
+ YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
 
 
106
  ):
 
 
 
 
 
 
 
 
 
 
 
107
  user_input_data = {
108
  'YNURSMDE': [int(YNURSMDE)],
109
  'YMDEYR': [int(YMDEYR)],
 
137
  }
138
  user_input = pd.DataFrame(user_input_data)
139
 
140
+ # 1) Predict
141
  predictions = predictor.make_predictions(user_input)
142
+
143
+ # 2) Majority vote
144
  majority_vote = predictor.get_majority_vote(predictions)
145
+
146
+ # 3) Count how many are '1'
147
+ num_ones = sum(np.concatenate(predictions) == 1)
148
+
149
+ # 4) Severity
150
+ severity = predictor.evaluate_severity(num_ones)
151
+
152
+ # 5) Grouped textual results
153
+ # [Same grouping logic as before, or adapt as needed]
154
+ groups = {
 
 
155
  "Concentration_and_Decision_Making": ["YOWRCONC", "YOWRDCSN"],
156
  "Sleep_and_Energy_Levels": ["YOWRHRS", "YO_MDEA5", "YOWRELES", "YO_MDEA2"],
157
  "Mood_and_Emotional_State": ["YOWRCHR", "YOWRLSIN", "YOWRDEPR", "YODPDISC",
 
161
  "YODPR2WK", "YODSMMDE",
162
  "YOPB2WK"]
163
  }
164
+ grouped_text = {k: [] for k in groups}
165
+ for i, pred in enumerate(predictions):
166
+ col_name = model_filenames[i].split('.')[0]
167
+ pred_val = pred[0]
168
+ if col_name in predictor.prediction_map and pred_val in [0,1]:
169
+ text_val = predictor.prediction_map[col_name][pred_val]
 
 
 
170
  else:
171
+ text_val = f"Prediction={pred_val}"
172
+ # Find which group
173
+ assigned = False
174
+ for gname, gcols in groups.items():
175
+ if col_name in gcols:
176
+ grouped_text[gname].append(f"{col_name} => {text_val}")
177
+ assigned = True
 
178
  break
179
+ if not assigned:
180
+ # Or skip
181
  pass
182
 
183
+ final_str = []
184
+ for gname, items in grouped_text.items():
185
+ if items:
186
+ final_str.append(f"**{gname.replace('_',' ')}**")
187
+ final_str.append("\n".join(items))
188
+ final_str.append("\n")
189
+ final_str = "\n".join(final_str).strip()
190
+ if not final_str:
191
+ final_str = "No predictions made. Please check inputs."
192
+
193
+ # 6) Additional charts: total patients, distribution for input features, etc.
 
 
 
194
  total_patients = len(df)
195
+ total_patient_markdown = (
196
+ f"### Total Patient Count\nThere are **{total_patients}** patients in the dataset."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  )
 
198
 
199
+ # A) Bar chart for input features
200
+ same_val_counts = {}
201
+ for col, val_list in user_input_data.items():
202
+ val_ = val_list[0]
203
+ same_val_counts[col] = len(df[df[col] == val_])
204
+ bar_input_df = pd.DataFrame({"Feature": list(same_val_counts.keys()),
205
+ "Count": list(same_val_counts.values())})
206
+ fig_bar_input = px.bar(bar_input_df, x="Feature", y="Count",
207
+ title="Number of Patients with Same Input Feature Values")
208
+ fig_bar_input.update_layout(width=800, height=500)
209
+
210
+ # B) Bar chart for predicted labels
211
  label_counts = {}
212
+ all_preds_flat = np.concatenate(predictions)
213
+ for i, arr in enumerate(predictions):
214
+ lbl_col = model_filenames[i].split('.')[0]
215
+ pred_val = arr[0]
216
+ if pred_val in [0,1]:
217
+ label_counts[lbl_col] = len(df[df[lbl_col] == pred_val])
218
+ if label_counts:
219
+ bar_label_df = pd.DataFrame({"Label": list(label_counts.keys()),
220
+ "Count": list(label_counts.values())})
221
+ fig_bar_labels = px.bar(bar_label_df, x="Label", y="Count",
222
+ title="Number of Patients with the Same Predicted Label")
223
+ fig_bar_labels.update_layout(width=800, height=500)
 
 
 
 
 
224
  else:
225
+ fig_bar_labels = px.bar(title="No valid predicted labels to display.")
226
+ fig_bar_labels.update_layout(width=800, height=500)
227
+
228
+ # C) Distribution Plot (small sample)
229
+ # We'll pick the first 4 user_input columns & first 3 labels
230
+ subset_input_cols = list(user_input_data.keys())[:4]
231
+ subset_labels = [fn.split('.')[0] for fn in model_filenames[:3]]
232
+ dist_rows = []
233
+ for feat in subset_input_cols:
234
  if feat not in df.columns:
235
  continue
236
+ for label_col in subset_labels:
237
+ if label_col not in df.columns:
238
  continue
239
+ tmp = df.groupby([feat, label_col]).size().reset_index(name="count")
240
+ tmp["feature"] = feat
241
+ tmp["label"] = label_col
242
+ dist_rows.append(tmp)
243
+ if dist_rows:
244
+ big_dist_df = pd.concat(dist_rows, ignore_index=True)
245
+ fig_dist = px.bar(big_dist_df,
246
+ x=big_dist_df.columns[0],
247
+ y="count",
248
+ color=big_dist_df.columns[1],
249
+ facet_row="feature",
250
+ facet_col="label",
251
+ title="Distribution of Sample Input Features vs. Sample Predicted Labels")
252
+ fig_dist.update_layout(width=1000, height=700)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
  else:
254
+ fig_dist = px.bar(title="Distribution plot not generated.")
255
 
256
+ # D) Nearest Neighbors (K=2) [Optional as before]
257
+ # ... omitted for brevity if you want to keep from prior code ...
258
+ # or keep it.
259
+ # For now, let's produce an empty markdown
260
+ nearest_neighbors_markdown = "Nearest neighbors omitted here for brevity..."
261
+
262
+ # We won't produce a default co-occurrence plot here, since we do it in a separate tab.
263
+
264
+ # Return 8 items
265
  return (
266
+ final_str,
267
+ severity,
268
+ total_patient_markdown,
269
+ fig_dist,
270
+ nearest_neighbors_markdown,
271
+ None, # placeholder for a single co-occurrence plot
272
+ fig_bar_input,
273
+ fig_bar_labels
274
  )
275
 
276
  ######################################
277
+ # 5) Input Mapping
278
  ######################################
279
# Answer encodings shared by several questionnaire columns.  Each one maps the
# text offered in a dropdown to an integer code (presumably the coding used in
# the training CSV -- confirm against the data dictionary).
_YES_1_NO_0 = (("Yes", 1), ("No", 0))
_YES_1_NO_2 = (("Yes", 1), ("No", 2))
_YES_NO_UNSURE_REFUSED = (
    ("Yes", 1),
    ("No", 2),
    ("I'm not sure", 3),
    ("I don't want to answer", 4),
)
_SUD_MDE_STATUS = (
    ("SUD only, no MDE", 1),
    ("MDE only, no SUD", 2),
    ("SUD and MDE", 3),
    ("Neither SUD or MDE", 4),
)
_DIFFICULTY_LEVEL = (
    ("No Difficulty", 1),
    ("Some difficulty", 2),
    ("A lot of difficulty or cannot do at all", 3),
)

# Column name -> {dropdown text: integer code}.  Every column gets its own
# dict object, and the insertion order below is the order of the mapping.
# NOTE(review): the Prediction tab also looks up 'YMDEIMAD5YR' and 'YMDELT',
# which have no entry here -- confirm their encodings and add them.
input_mapping = {
    column: dict(encoding)
    for column, encoding in (
        ("YNURSMDE", _YES_1_NO_0),
        ("YMDEYR", _YES_1_NO_2),
        ("YSOCMDE", _YES_1_NO_0),
        ("YMDESUD5ANYO", _SUD_MDE_STATUS),
        ("YMSUD5YANY", _YES_1_NO_0),
        ("YUSUITHK", _YES_NO_UNSURE_REFUSED),
        ("YMDETXRX", _YES_1_NO_0),
        ("YUSUITHKYR", _YES_NO_UNSURE_REFUSED),
        ("YMDERSUD5ANY", _YES_1_NO_0),
        ("YUSUIPLNYR", _YES_NO_UNSURE_REFUSED),
        ("YCOUNMDE", _YES_1_NO_0),
        ("YPSY1MDE", _YES_1_NO_0),
        ("YHLTMDE", _YES_1_NO_0),
        ("YDOCMDE", _YES_1_NO_0),
        ("YPSY2MDE", _YES_1_NO_0),
        ("YMDEHARX", _YES_1_NO_0),
        ("LVLDIFMEM2", _DIFFICULTY_LEVEL),
        ("MDEIMPY", _YES_1_NO_2),
        ("YMDEHPO", _YES_1_NO_0),
        ("YMIMS5YANY", _YES_1_NO_0),
        ("YMIUD5YANY", _YES_1_NO_0),
        ("YMDEHPRX", _YES_1_NO_0),
        ("YMIMI5YANY", _YES_1_NO_0),
        ("YUSUIPLN", _YES_NO_UNSURE_REFUSED),
        ("YTXMDEYR", _YES_1_NO_0),
        ("YMDEAUD5YR", _YES_1_NO_0),
        ("YRXMDEYR", _YES_1_NO_0),
    )
}
310
 
311
  ######################################
312
+ # 6) Co-Occurrence Function (Separate)
313
  ######################################
314
def co_occurrence_plot(feature1, feature2, label_col):
    """
    Build a bar chart counting how often each combination of values of the
    three selected columns occurs in the module-level ``df``.

    ``feature1`` is drawn on the x-axis, ``feature2`` splits the chart into
    facet columns and ``label_col`` sets the bar colour.  The figure is given
    a fixed 1000x600 size so it is clearly visible.

    Args:
        feature1: Name of the column plotted on the x-axis.
        feature2: Name of the column used for the facet columns.
        label_col: Name of the column used to colour the bars.

    Returns:
        A Plotly figure.  If the selection is unusable (a field left empty, a
        column missing from ``df``, or the same column chosen more than once)
        an empty figure whose title explains the problem is returned instead.
    """
    selected = [feature1, feature2, label_col]
    if not all(selected):
        return px.bar(title="Please select all three fields.")
    if any(col not in df.columns for col in selected):
        return px.bar(title="Selected columns not found in the dataset.")
    # Grouping by a repeated column makes reset_index() fail ("cannot insert
    # <col>, already exists"), so reject duplicate selections up front.
    if len(set(selected)) < len(selected):
        return px.bar(title="Please select three different columns.")

    # One row per observed (feature1, feature2, label_col) combination.
    grouped_df = df.groupby(selected).size().reset_index(name="count")
    fig = px.bar(
        grouped_df,
        x=feature1,
        y="count",
        color=label_col,
        facet_col=feature2,
        title=f"Co-Occurrence Plot: {feature1} & {feature2} vs. {label_col}",
    )
    fig.update_layout(width=1000, height=600)
    return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
 
337
  ######################################
338
+ # 7) Gradio with Tabs
339
  ######################################
340
# Page-level CSS: cap the container width and centre it.  Defined before the
# Blocks context so that it is actually passed to gr.Blocks(css=...).
custom_css = """
.gradio-container {
    max-width: 1200px;
    margin-left: auto;
    margin-right: auto;
}
"""

# Columns collected on the Prediction tab.  The order matters: the dropdown
# values are passed to predict() positionally in exactly this order, so it
# must stay in sync with predict()'s signature.
_PREDICT_INPUT_COLUMNS = (
    # Inputs (same order as function signature)
    "YMDEYR", "YMDERSUD5ANY", "YMDEIMAD5YR", "YMIMS5YANY", "YMDELT",
    "YMDEHARX", "YMDEHPRX", "YMDETXRX", "YMDEHPO", "YMDEAUD5YR",
    "YMIMI5YANY", "YMIUD5YANY", "YMDESUD5ANYO",
    # Consultations
    "YNURSMDE", "YSOCMDE", "YCOUNMDE", "YPSY1MDE", "YPSY2MDE", "YHLTMDE",
    "YDOCMDE", "YTXMDEYR",
    # Suicidal thoughts/plans
    "YUSUITHKYR", "YUSUIPLNYR", "YUSUITHK", "YUSUIPLN",
    # Impairments
    "MDEIMPY", "LVLDIFMEM2", "YMSUD5YANY", "YRXMDEYR",
)

# Fail with one clear message if any column has no answer encoding, instead
# of a bare KeyError on whichever dropdown happens to be built first.
_missing_columns = [c for c in _PREDICT_INPUT_COLUMNS if c not in input_mapping]
if _missing_columns:
    raise KeyError(f"input_mapping has no entry for: {_missing_columns}")

with gr.Blocks(css=custom_css) as demo:

    with gr.Tab("Prediction"):
        # One dropdown per model input; the choices are the answer texts
        # defined for that column in input_mapping.
        input_dropdowns = [
            gr.Dropdown(list(input_mapping[column].keys()), label=column)
            for column in _PREDICT_INPUT_COLUMNS
        ]

        # 8 outputs, in the order of the tuple returned by predict()
        out_pred_res = gr.Textbox(label="Prediction Results", lines=8)
        out_sev = gr.Textbox(label="Mental Health Severity", lines=2)
        out_count = gr.Markdown(label="Total Patient Count")
        out_distplot = gr.Plot(label="Distribution Plot")
        out_nn = gr.Markdown(label="Nearest Neighbors Summary")
        out_cooc = gr.Plot(label="Co-occurrence Plot Placeholder")
        out_bar_input = gr.Plot(label="Input Feature Counts")
        out_bar_labels = gr.Plot(label="Predicted Label Counts")

        # Button
        predict_btn = gr.Button("Predict")

        # Link button to the function
        predict_btn.click(
            fn=predict,
            inputs=input_dropdowns,
            outputs=[
                out_pred_res, out_sev, out_count, out_distplot, out_nn,
                out_cooc, out_bar_input, out_bar_labels,
            ],
        )

    with gr.Tab("Co-occurrence"):
        gr.Markdown("## Generate a Co-Occurrence Plot on Demand\nSelect two features and one label:")
        # Every dataset column is offered in all three selectors.
        column_choices = sorted(df.columns)
        with gr.Row():
            feature1_dd = gr.Dropdown(column_choices, label="Feature 1")
            feature2_dd = gr.Dropdown(column_choices, label="Feature 2")
            label_dd = gr.Dropdown(column_choices, label="Label Column")
        out_co_occ_plot = gr.Plot(label="Co-occurrence Plot")

        co_occ_btn = gr.Button("Generate Plot")

        # Link to co_occurrence_plot function
        co_occ_btn.click(
            fn=co_occurrence_plot,
            inputs=[feature1_dd, feature2_dd, label_dd],
            outputs=out_co_occ_plot,
        )

# Launch: listen on all interfaces, port 7860
demo.launch(server_name="0.0.0.0", server_port=7860)