pantdipendra commited on
Commit
ebac442
·
verified ·
1 Parent(s): 6b501f6
Files changed (1) hide show
  1. app.py +140 -83
app.py CHANGED
@@ -7,8 +7,14 @@ import plotly.express as px
7
  ######################################
8
  # 1) LOAD DATA & MODELS
9
  ######################################
 
10
  df = pd.read_csv("X_train_test_combined_dataset_Filtered_dataset.csv")
11
 
 
 
 
 
 
12
  model_filenames = [
13
  "YOWRCONC.pkl", "YOSEEDOC.pkl", "YO_MDEA5.pkl", "YOWRLSIN.pkl",
14
  "YODPPROB.pkl", "YOWRPROB.pkl", "YODPR2WK.pkl", "YOWRDEPR.pkl",
@@ -17,7 +23,6 @@ model_filenames = [
17
  ]
18
  model_path = "models/"
19
 
20
-
21
  ######################################
22
  # 2) MODEL PREDICTOR
23
  ######################################
@@ -38,7 +43,7 @@ class ModelPredictor:
38
  "YODPPROB": ["No other problems for 2+ weeks", "Had other problems for 2+ weeks"],
39
  "YOWRPROB": ["No 'worst time ever' feeling", "Had 'worst time ever' feeling"],
40
  "YODPR2WK": ["No depressed feelings for 2+ wks", "Had depressed feelings for 2+ wks"],
41
- "YOWRDEPR": ["Did NOT feel sad/depressed daily", "Felt sad/depressed mostly everyday"],
42
  "YODPDISC": ["Overall mood not sad/depressed", "Overall mood was sad/depressed"],
43
  "YOLOSEV": ["Did NOT lose interest in things", "Lost interest in enjoyable things"],
44
  "YOWRDCSN": ["Was able to make decisions", "Was unable to make decisions"],
@@ -54,9 +59,14 @@ class ModelPredictor:
54
  def load_models(self):
55
  loaded = []
56
  for fname in self.model_filenames:
57
- with open(self.model_path + fname, "rb") as f:
58
- model = pickle.load(f)
59
- loaded.append(model)
 
 
 
 
 
60
  return loaded
61
 
62
  def make_predictions(self, user_input: pd.DataFrame):
@@ -91,17 +101,14 @@ class ModelPredictor:
91
  else:
92
  return "Mental Health Severity: Very Low"
93
 
94
-
95
  predictor = ModelPredictor(model_path, model_filenames)
96
 
97
-
98
  ######################################
99
  # 3) FEATURE CATEGORIES + MAPPING
100
  ######################################
101
- # Replaced 'YMDESUD5ANYO' with 'YMDESUD5ANY' to match your CSV
102
  categories_dict = {
103
  "1. Depression & Substance Use Diagnosis": [
104
- "YMDESUD5ANY", "YMDELT", "YMDEYR", "YMDERSUD5ANY",
105
  "YMSUD5YANY", "YMIUD5YANY", "YMIMS5YANY", "YMIMI5YANY"
106
  ],
107
  "2. Mental Health Treatment & Prof Consultation": [
@@ -116,9 +123,13 @@ categories_dict = {
116
  ]
117
  }
118
 
119
- # Again, replaced 'YMDESUD5ANYO' with 'YMDESUD5ANY'
120
  input_mapping = {
121
- 'YMDESUD5ANY': {"SUD only, no MDE": 1, "MDE only, no SUD": 2, "SUD and MDE": 3, "Neither SUD or MDE": 4},
 
 
 
 
 
122
  'YMDELT': {"Yes": 1, "No": 2},
123
  'YMDEYR': {"Yes": 1, "No": 2},
124
  'YMDERSUD5ANY': {"Yes": 1, "No": 0},
@@ -140,7 +151,11 @@ input_mapping = {
140
  'YCOUNMDE': {"Yes": 1, "No": 0},
141
 
142
  'MDEIMPY': {"Yes": 1, "No": 2},
143
- 'LVLDIFMEM2': {"No Difficulty": 1, "Some difficulty": 2, "A lot of difficulty or cannot do at all": 3},
 
 
 
 
144
 
145
  'YUSUITHK': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
146
  'YUSUITHKYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
@@ -148,10 +163,9 @@ input_mapping = {
148
  'YUSUIPLN': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4}
149
  }
150
 
151
-
152
  def validate_inputs(*args):
153
  for arg in args:
154
- if not arg: # empty or None
155
  return False
156
  return True
157
 
@@ -209,13 +223,12 @@ def get_nearest_neighbors_info(user_input_df: pd.DataFrame, k=5):
209
  lines.append("")
210
  return "\n".join(lines)
211
 
212
-
213
  ######################################
214
  # 5) PREDICT FUNCTION
215
  ######################################
216
  def predict(
217
  # Category 1 (8):
218
- YMDESUD5ANY, YMDELT, YMDEYR, YMDERSUD5ANY,
219
  YMSUD5YANY, YMIUD5YANY, YMIMS5YANY, YMIMI5YANY,
220
  # Category 2 (11):
221
  YMDEHPO, YMDETXRX, YMDEHARX, YMDEHPRX, YRXMDEYR,
@@ -227,7 +240,7 @@ def predict(
227
  ):
228
  # 1) Validate
229
  if not validate_inputs(
230
- YMDESUD5ANY, YMDELT, YMDEYR, YMDERSUD5ANY,
231
  YMSUD5YANY, YMIUD5YANY, YMIMS5YANY, YMIMI5YANY,
232
  YMDEHPO, YMDETXRX, YMDEHARX, YMDEHPRX, YRXMDEYR,
233
  YHLTMDE, YTXMDEYR, YDOCMDE, YPSY2MDE, YPSY1MDE, YCOUNMDE,
@@ -235,49 +248,71 @@ def predict(
235
  YUSUITHK, YUSUITHKYR, YUSUIPLNYR, YUSUIPLN
236
  ):
237
  return (
238
- "Please select all required fields.",
239
- "Validation Error",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
  "No data",
241
  "No nearest neighbors info",
242
  None,
243
  None
244
  )
245
 
246
- # 2) Convert text -> numeric
247
- user_input_dict = {
248
- 'YMDESUD5ANY': input_mapping['YMDESUD5ANY'][YMDESUD5ANY],
249
- 'YMDELT': input_mapping['YMDELT'][YMDELT],
250
- 'YMDEYR': input_mapping['YMDEYR'][YMDEYR],
251
- 'YMDERSUD5ANY': input_mapping['YMDERSUD5ANY'][YMDERSUD5ANY],
252
- 'YMSUD5YANY': input_mapping['YMSUD5YANY'][YMSUD5YANY],
253
- 'YMIUD5YANY': input_mapping['YMIUD5YANY'][YMIUD5YANY],
254
- 'YMIMS5YANY': input_mapping['YMIMS5YANY'][YMIMS5YANY],
255
- 'YMIMI5YANY': input_mapping['YMIMI5YANY'][YMIMI5YANY],
256
-
257
- 'YMDEHPO': input_mapping['YMDEHPO'][YMDEHPO],
258
- 'YMDETXRX': input_mapping['YMDETXRX'][YMDETXRX],
259
- 'YMDEHARX': input_mapping['YMDEHARX'][YMDEHARX],
260
- 'YMDEHPRX': input_mapping['YMDEHPRX'][YMDEHPRX],
261
- 'YRXMDEYR': input_mapping['YRXMDEYR'][YRXMDEYR],
262
- 'YHLTMDE': input_mapping['YHLTMDE'][YHLTMDE],
263
- 'YTXMDEYR': input_mapping['YTXMDEYR'][YTXMDEYR],
264
- 'YDOCMDE': input_mapping['YDOCMDE'][YDOCMDE],
265
- 'YPSY2MDE': input_mapping['YPSY2MDE'][YPSY2MDE],
266
- 'YPSY1MDE': input_mapping['YPSY1MDE'][YPSY1MDE],
267
- 'YCOUNMDE': input_mapping['YCOUNMDE'][YCOUNMDE],
268
-
269
- 'MDEIMPY': input_mapping['MDEIMPY'][MDEIMPY],
270
- 'LVLDIFMEM2': input_mapping['LVLDIFMEM2'][LVLDIFMEM2],
271
-
272
- 'YUSUITHK': input_mapping['YUSUITHK'][YUSUITHK],
273
- 'YUSUITHKYR': input_mapping['YUSUITHKYR'][YUSUITHKYR],
274
- 'YUSUIPLNYR': input_mapping['YUSUIPLNYR'][YUSUIPLNYR],
275
- 'YUSUIPLN': input_mapping['YUSUIPLN'][YUSUIPLN]
276
- }
277
  user_df = pd.DataFrame(user_input_dict, index=[0])
278
 
279
  # 3) Make predictions
280
- preds, probs = predictor.make_predictions(user_df)
 
 
 
 
 
 
 
 
 
 
281
 
282
  # Flatten predictions for severity count
283
  all_preds = np.concatenate(preds)
@@ -295,13 +330,13 @@ def predict(
295
 
296
  # Group them by domain
297
  domain_groups = {
298
- "Concentration_and_Decision_Making": ["YOWRCONC", "YOWRDCSN"],
299
- "Sleep_and_Energy_Levels": ["YOWRHRS", "YO_MDEA5", "YOWRELES", "YO_MDEA2"],
300
- "Mood_and_Emotional_State": [
301
  "YOWRCHR", "YOWRLSIN", "YOWRDEPR", "YODPDISC", "YOLOSEV", "YODPLSIN", "YODSCEV"
302
  ],
303
- "Appetite_and_Weight_Changes": ["YO_MDEA3", "YOWRELES"],
304
- "Duration_and_Severity_of_Depression_Symptoms": [
305
  "YODPPROB", "YOWRPROB", "YODPR2WK", "YODSMMDE", "YOPB2WK"
306
  ]
307
  }
@@ -320,14 +355,13 @@ def predict(
320
  if not np.isnan(prob_val):
321
  text_prob = f"(Prob= {prob_val:.2f})"
322
  else:
323
- text_prob = "(No prob available)"
324
 
325
  group_lines.append(f"{lbl} => {text_pred} {text_prob}")
326
  if group_lines:
327
- gtitle = gname.replace("_", " ")
328
- final_str_parts.append(f"**{gtitle}**")
329
  final_str_parts.append("\n".join(group_lines))
330
- final_str_parts.append("")
331
 
332
  if final_str_parts:
333
  final_str = "\n".join(final_str_parts)
@@ -345,8 +379,10 @@ def predict(
345
  for col, val_ in user_input_dict.items():
346
  matched = len(df[df[col] == val_])
347
  input_counts[col] = matched
348
- bar_in_df = pd.DataFrame({"Feature": list(input_counts.keys()),
349
- "Count": list(input_counts.values())})
 
 
350
  fig_in = px.bar(
351
  bar_in_df, x="Feature", y="Count",
352
  title="Number of Patients with the Same Input Feature Values"
@@ -376,12 +412,11 @@ def predict(
376
  final_str, # 1) Prediction Results
377
  severity_msg, # 2) Mental Health Severity
378
  total_count_md, # 3) Total Patient Count
379
- nn_md, # 4) Nearest Neighbors
380
  fig_in, # 5) Bar Chart (input features)
381
  fig_lbl # 6) Bar Chart (labels)
382
  )
383
 
384
-
385
  ######################################
386
  # 6) UNIFIED DISTRIBUTION/CO-OCCURRENCE
387
  ######################################
@@ -399,8 +434,13 @@ def combined_plot(feature_list, label_col):
399
  if f_ not in df.columns or label_col not in df.columns:
400
  return px.bar(title="Selected columns not found in the dataset.")
401
  grouped = df.groupby([f_, label_col]).size().reset_index(name="count")
402
- fig = px.bar(grouped, x=f_, y="count", color=label_col,
403
- title=f"Distribution of {f_} vs {label_col}")
 
 
 
 
 
404
  fig.update_layout(width=1200, height=600)
405
  return fig
406
 
@@ -410,8 +450,12 @@ def combined_plot(feature_list, label_col):
410
  return px.bar(title="Selected columns not found in the dataset.")
411
  grouped = df.groupby([f1, f2, label_col]).size().reset_index(name="count")
412
  fig = px.bar(
413
- grouped, x=f1, y="count", color=label_col,
414
- facet_col=f2, title=f"Co-occurrence: {f1}, {f2} vs {label_col}"
 
 
 
 
415
  )
416
  fig.update_layout(width=1200, height=600)
417
  return fig
@@ -419,20 +463,19 @@ def combined_plot(feature_list, label_col):
419
  else:
420
  return px.bar(title="Please select exactly 1 or 2 features.")
421
 
422
-
423
  ######################################
424
  # 7) BUILD GRADIO UI
425
  ######################################
426
  with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
427
 
428
- # TAB 1: Prediction
429
  with gr.Tab("Prediction"):
430
  gr.Markdown("### Please provide inputs in each of the four categories below. All fields are required.")
431
 
432
- # Category 1
433
  gr.Markdown("#### 1. Depression & Substance Use Diagnosis")
434
  cat1_col_labels = [
435
- ("YMDESUD5ANY", "YMDESUD5ANY: ONLY MDE, ONLY SUD, BOTH, OR NEITHER"),
436
  ("YMDELT", "YMDELT: Had major depressive episode in lifetime"),
437
  ("YMDEYR", "YMDEYR: Past-year major depressive episode"),
438
  ("YMDERSUD5ANY", "YMDERSUD5ANY: MDE or SUD in past year?"),
@@ -444,10 +487,13 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
444
  cat1_inputs = []
445
  for col, label_text in cat1_col_labels:
446
  cat1_inputs.append(
447
- gr.Dropdown(choices=list(input_mapping[col].keys()), label=label_text)
 
 
 
448
  )
449
 
450
- # Category 2
451
  gr.Markdown("#### 2. Mental Health Treatment & Professional Consultation")
452
  cat2_col_labels = [
453
  ("YMDEHPO", "YMDEHPO: Saw health prof only for MDE"),
@@ -465,10 +511,13 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
465
  cat2_inputs = []
466
  for col, label_text in cat2_col_labels:
467
  cat2_inputs.append(
468
- gr.Dropdown(choices=list(input_mapping[col].keys()), label=label_text)
 
 
 
469
  )
470
 
471
- # Category 3
472
  gr.Markdown("#### 3. Functional & Cognitive Impairment")
473
  cat3_col_labels = [
474
  ("MDEIMPY", "MDEIMPY: MDE with severe role impairment?"),
@@ -477,10 +526,13 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
477
  cat3_inputs = []
478
  for col, label_text in cat3_col_labels:
479
  cat3_inputs.append(
480
- gr.Dropdown(choices=list(input_mapping[col].keys()), label=label_text)
 
 
 
481
  )
482
 
483
- # Category 4
484
  gr.Markdown("#### 4. Suicidal Thoughts & Behaviors")
485
  cat4_col_labels = [
486
  ("YUSUITHK", "YUSUITHK: Thought of killing self (past 12 months)?"),
@@ -491,12 +543,16 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
491
  cat4_inputs = []
492
  for col, label_text in cat4_col_labels:
493
  cat4_inputs.append(
494
- gr.Dropdown(choices=list(input_mapping[col].keys()), label=label_text)
 
 
 
495
  )
496
 
497
- # Combine in the same order
498
  all_inputs = cat1_inputs + cat2_inputs + cat3_inputs + cat4_inputs
499
 
 
500
  predict_btn = gr.Button("Predict")
501
 
502
  out_pred_res = gr.Textbox(label="Prediction Results (with Probability)", lines=8)
@@ -506,6 +562,7 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
506
  out_bar_input= gr.Plot(label="Input Feature Counts")
507
  out_bar_label= gr.Plot(label="Predicted Label Counts")
508
 
 
509
  predict_btn.click(
510
  fn=predict,
511
  inputs=all_inputs,
@@ -522,8 +579,8 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
522
  # ======== TAB 2: Unified Distribution/Co-occurrence ========
523
  with gr.Tab("Distribution/Co-occurrence"):
524
  gr.Markdown("### Select 1 or 2 features + 1 label to see a bar chart.")
525
- # Possibly you want only columns from input_mapping or from df
526
- # We'll let user pick from df.columns:
527
  list_of_features = sorted(df.columns)
528
  list_of_labels = sorted(predictor.prediction_map.keys())
529
 
@@ -545,5 +602,5 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
545
  outputs=combined_output
546
  )
547
 
548
- # Finally, launch
549
  demo.launch()
 
7
  ######################################
8
  # 1) LOAD DATA & MODELS
9
  ######################################
10
+ # Load your dataset
11
  df = pd.read_csv("X_train_test_combined_dataset_Filtered_dataset.csv")
12
 
13
+ # Ensure 'YMDESUD5ANYO' exists in your DataFrame
14
+ if 'YMDESUD5ANYO' not in df.columns:
15
+ raise ValueError("The column 'YMDESUD5ANYO' is missing from the dataset. Please check your CSV file.")
16
+
17
+ # List of model filenames
18
  model_filenames = [
19
  "YOWRCONC.pkl", "YOSEEDOC.pkl", "YO_MDEA5.pkl", "YOWRLSIN.pkl",
20
  "YODPPROB.pkl", "YOWRPROB.pkl", "YODPR2WK.pkl", "YOWRDEPR.pkl",
 
23
  ]
24
  model_path = "models/"
25
 
 
26
  ######################################
27
  # 2) MODEL PREDICTOR
28
  ######################################
 
43
  "YODPPROB": ["No other problems for 2+ weeks", "Had other problems for 2+ weeks"],
44
  "YOWRPROB": ["No 'worst time ever' feeling", "Had 'worst time ever' feeling"],
45
  "YODPR2WK": ["No depressed feelings for 2+ wks", "Had depressed feelings for 2+ wks"],
46
+ "YOWRDEPR": ["Did NOT feel sad/depressed daily", "Felt sad/depressed mostly everyday"],
47
  "YODPDISC": ["Overall mood not sad/depressed", "Overall mood was sad/depressed"],
48
  "YOLOSEV": ["Did NOT lose interest in things", "Lost interest in enjoyable things"],
49
  "YOWRDCSN": ["Was able to make decisions", "Was unable to make decisions"],
 
59
  def load_models(self):
60
  loaded = []
61
  for fname in self.model_filenames:
62
+ try:
63
+ with open(self.model_path + fname, "rb") as f:
64
+ model = pickle.load(f)
65
+ loaded.append(model)
66
+ except FileNotFoundError:
67
+ raise FileNotFoundError(f"Model file '{fname}' not found in path '{self.model_path}'.")
68
+ except Exception as e:
69
+ raise Exception(f"Error loading model '{fname}': {e}")
70
  return loaded
71
 
72
  def make_predictions(self, user_input: pd.DataFrame):
 
101
  else:
102
  return "Mental Health Severity: Very Low"
103
 
 
104
  predictor = ModelPredictor(model_path, model_filenames)
105
 
 
106
  ######################################
107
  # 3) FEATURE CATEGORIES + MAPPING
108
  ######################################
 
109
  categories_dict = {
110
  "1. Depression & Substance Use Diagnosis": [
111
+ "YMDESUD5ANYO", "YMDELT", "YMDEYR", "YMDERSUD5ANY",
112
  "YMSUD5YANY", "YMIUD5YANY", "YMIMS5YANY", "YMIMI5YANY"
113
  ],
114
  "2. Mental Health Treatment & Prof Consultation": [
 
123
  ]
124
  }
125
 
 
126
  input_mapping = {
127
+ 'YMDESUD5ANYO': {
128
+ "SUD only, no MDE": 1,
129
+ "MDE only, no SUD": 2,
130
+ "SUD and MDE": 3,
131
+ "Neither SUD or MDE": 4
132
+ },
133
  'YMDELT': {"Yes": 1, "No": 2},
134
  'YMDEYR': {"Yes": 1, "No": 2},
135
  'YMDERSUD5ANY': {"Yes": 1, "No": 0},
 
151
  'YCOUNMDE': {"Yes": 1, "No": 0},
152
 
153
  'MDEIMPY': {"Yes": 1, "No": 2},
154
+ 'LVLDIFMEM2': {
155
+ "No Difficulty": 1,
156
+ "Some difficulty": 2,
157
+ "A lot of difficulty or cannot do at all": 3
158
+ },
159
 
160
  'YUSUITHK': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
161
  'YUSUITHKYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
 
163
  'YUSUIPLN': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4}
164
  }
165
 
 
166
  def validate_inputs(*args):
167
  for arg in args:
168
+ if arg is None or arg == "":
169
  return False
170
  return True
171
 
 
223
  lines.append("")
224
  return "\n".join(lines)
225
 
 
226
  ######################################
227
  # 5) PREDICT FUNCTION
228
  ######################################
229
  def predict(
230
  # Category 1 (8):
231
+ YMDESUD5ANYO, YMDELT, YMDEYR, YMDERSUD5ANY,
232
  YMSUD5YANY, YMIUD5YANY, YMIMS5YANY, YMIMI5YANY,
233
  # Category 2 (11):
234
  YMDEHPO, YMDETXRX, YMDEHARX, YMDEHPRX, YRXMDEYR,
 
240
  ):
241
  # 1) Validate
242
  if not validate_inputs(
243
+ YMDESUD5ANYO, YMDELT, YMDEYR, YMDERSUD5ANY,
244
  YMSUD5YANY, YMIUD5YANY, YMIMS5YANY, YMIMI5YANY,
245
  YMDEHPO, YMDETXRX, YMDEHARX, YMDEHPRX, YRXMDEYR,
246
  YHLTMDE, YTXMDEYR, YDOCMDE, YPSY2MDE, YPSY1MDE, YCOUNMDE,
 
248
  YUSUITHK, YUSUITHKYR, YUSUIPLNYR, YUSUIPLN
249
  ):
250
  return (
251
+ "Please select all required fields.", # 1) Prediction Results
252
+ "Validation Error", # 2) Severity
253
+ "No data", # 3) Total Count
254
+ "No nearest neighbors info", # 4) NN Summary
255
+ None, # 5) Bar chart (Input)
256
+ None # 6) Bar chart (Labels)
257
+ )
258
+
259
+ # 2) Convert text -> numeric
260
+ try:
261
+ user_input_dict = {
262
+ 'YMDESUD5ANYO': input_mapping['YMDESUD5ANYO'][YMDESUD5ANYO],
263
+ 'YMDELT': input_mapping['YMDELT'][YMDELT],
264
+ 'YMDEYR': input_mapping['YMDEYR'][YMDEYR],
265
+ 'YMDERSUD5ANY': input_mapping['YMDERSUD5ANY'][YMDERSUD5ANY],
266
+ 'YMSUD5YANY': input_mapping['YMSUD5YANY'][YMSUD5YANY],
267
+ 'YMIUD5YANY': input_mapping['YMIUD5YANY'][YMIUD5YANY],
268
+ 'YMIMS5YANY': input_mapping['YMIMS5YANY'][YMIMS5YANY],
269
+ 'YMIMI5YANY': input_mapping['YMIMI5YANY'][YMIMI5YANY],
270
+
271
+ 'YMDEHPO': input_mapping['YMDEHPO'][YMDEHPO],
272
+ 'YMDETXRX': input_mapping['YMDETXRX'][YMDETXRX],
273
+ 'YMDEHARX': input_mapping['YMDEHARX'][YMDEHARX],
274
+ 'YMDEHPRX': input_mapping['YMDEHPRX'][YMDEHPRX],
275
+ 'YRXMDEYR': input_mapping['YRXMDEYR'][YRXMDEYR],
276
+ 'YHLTMDE': input_mapping['YHLTMDE'][YHLTMDE],
277
+ 'YTXMDEYR': input_mapping['YTXMDEYR'][YTXMDEYR],
278
+ 'YDOCMDE': input_mapping['YDOCMDE'][YDOCMDE],
279
+ 'YPSY2MDE': input_mapping['YPSY2MDE'][YPSY2MDE],
280
+ 'YPSY1MDE': input_mapping['YPSY1MDE'][YPSY1MDE],
281
+ 'YCOUNMDE': input_mapping['YCOUNMDE'][YCOUNMDE],
282
+
283
+ 'MDEIMPY': input_mapping['MDEIMPY'][MDEIMPY],
284
+ 'LVLDIFMEM2': input_mapping['LVLDIFMEM2'][LVLDIFMEM2],
285
+
286
+ 'YUSUITHK': input_mapping['YUSUITHK'][YUSUITHK],
287
+ 'YUSUITHKYR': input_mapping['YUSUITHKYR'][YUSUITHKYR],
288
+ 'YUSUIPLNYR': input_mapping['YUSUIPLNYR'][YUSUIPLNYR],
289
+ 'YUSUIPLN': input_mapping['YUSUIPLN'][YUSUIPLN]
290
+ }
291
+ except KeyError as e:
292
+ missing_key = e.args[0]
293
+ return (
294
+ f"Input mapping missing for key: {missing_key}. Please check your `input_mapping` dictionary.",
295
+ "Mapping Error",
296
  "No data",
297
  "No nearest neighbors info",
298
  None,
299
  None
300
  )
301
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
302
  user_df = pd.DataFrame(user_input_dict, index=[0])
303
 
304
  # 3) Make predictions
305
+ try:
306
+ preds, probs = predictor.make_predictions(user_df)
307
+ except Exception as e:
308
+ return (
309
+ f"Error during prediction: {e}",
310
+ "Prediction Error",
311
+ "No data",
312
+ "No nearest neighbors info",
313
+ None,
314
+ None
315
+ )
316
 
317
  # Flatten predictions for severity count
318
  all_preds = np.concatenate(preds)
 
330
 
331
  # Group them by domain
332
  domain_groups = {
333
+ "Concentration and Decision Making": ["YOWRCONC", "YOWRDCSN"],
334
+ "Sleep and Energy Levels": ["YOWRHRS", "YO_MDEA5", "YOWRELES", "YO_MDEA2"],
335
+ "Mood and Emotional State": [
336
  "YOWRCHR", "YOWRLSIN", "YOWRDEPR", "YODPDISC", "YOLOSEV", "YODPLSIN", "YODSCEV"
337
  ],
338
+ "Appetite and Weight Changes": ["YO_MDEA3", "YOWRELES"],
339
+ "Duration and Severity of Depression Symptoms": [
340
  "YODPPROB", "YOWRPROB", "YODPR2WK", "YODSMMDE", "YOPB2WK"
341
  ]
342
  }
 
355
  if not np.isnan(prob_val):
356
  text_prob = f"(Prob= {prob_val:.2f})"
357
  else:
358
+ text_prob = "(No probability available)"
359
 
360
  group_lines.append(f"{lbl} => {text_pred} {text_prob}")
361
  if group_lines:
362
+ final_str_parts.append(f"**{gname}**")
 
363
  final_str_parts.append("\n".join(group_lines))
364
+ final_str_parts.append("") # Add an empty line for spacing
365
 
366
  if final_str_parts:
367
  final_str = "\n".join(final_str_parts)
 
379
  for col, val_ in user_input_dict.items():
380
  matched = len(df[df[col] == val_])
381
  input_counts[col] = matched
382
+ bar_in_df = pd.DataFrame({
383
+ "Feature": list(input_counts.keys()),
384
+ "Count": list(input_counts.values())
385
+ })
386
  fig_in = px.bar(
387
  bar_in_df, x="Feature", y="Count",
388
  title="Number of Patients with the Same Input Feature Values"
 
412
  final_str, # 1) Prediction Results
413
  severity_msg, # 2) Mental Health Severity
414
  total_count_md, # 3) Total Patient Count
415
+ nn_md, # 4) Nearest Neighbors Summary
416
  fig_in, # 5) Bar Chart (input features)
417
  fig_lbl # 6) Bar Chart (labels)
418
  )
419
 
 
420
  ######################################
421
  # 6) UNIFIED DISTRIBUTION/CO-OCCURRENCE
422
  ######################################
 
434
  if f_ not in df.columns or label_col not in df.columns:
435
  return px.bar(title="Selected columns not found in the dataset.")
436
  grouped = df.groupby([f_, label_col]).size().reset_index(name="count")
437
+ fig = px.bar(
438
+ grouped,
439
+ x=f_,
440
+ y="count",
441
+ color=label_col,
442
+ title=f"Distribution of {f_} vs {label_col}"
443
+ )
444
  fig.update_layout(width=1200, height=600)
445
  return fig
446
 
 
450
  return px.bar(title="Selected columns not found in the dataset.")
451
  grouped = df.groupby([f1, f2, label_col]).size().reset_index(name="count")
452
  fig = px.bar(
453
+ grouped,
454
+ x=f1,
455
+ y="count",
456
+ color=label_col,
457
+ facet_col=f2,
458
+ title=f"Co-occurrence: {f1}, {f2} vs {label_col}"
459
  )
460
  fig.update_layout(width=1200, height=600)
461
  return fig
 
463
  else:
464
  return px.bar(title="Please select exactly 1 or 2 features.")
465
 
 
466
  ######################################
467
  # 7) BUILD GRADIO UI
468
  ######################################
469
  with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
470
 
471
+ # ======== TAB 1: Prediction ========
472
  with gr.Tab("Prediction"):
473
  gr.Markdown("### Please provide inputs in each of the four categories below. All fields are required.")
474
 
475
+ # Category 1: Depression & Substance Use Diagnosis (8 features)
476
  gr.Markdown("#### 1. Depression & Substance Use Diagnosis")
477
  cat1_col_labels = [
478
+ ("YMDESUD5ANYO", "YMDESUD5ANYO: ONLY MDE, ONLY SUD, BOTH, OR NEITHER"),
479
  ("YMDELT", "YMDELT: Had major depressive episode in lifetime"),
480
  ("YMDEYR", "YMDEYR: Past-year major depressive episode"),
481
  ("YMDERSUD5ANY", "YMDERSUD5ANY: MDE or SUD in past year?"),
 
487
  cat1_inputs = []
488
  for col, label_text in cat1_col_labels:
489
  cat1_inputs.append(
490
+ gr.Dropdown(
491
+ choices=list(input_mapping[col].keys()),
492
+ label=label_text
493
+ )
494
  )
495
 
496
+ # Category 2: Mental Health Treatment & Professional Consultation (11 features)
497
  gr.Markdown("#### 2. Mental Health Treatment & Professional Consultation")
498
  cat2_col_labels = [
499
  ("YMDEHPO", "YMDEHPO: Saw health prof only for MDE"),
 
511
  cat2_inputs = []
512
  for col, label_text in cat2_col_labels:
513
  cat2_inputs.append(
514
+ gr.Dropdown(
515
+ choices=list(input_mapping[col].keys()),
516
+ label=label_text
517
+ )
518
  )
519
 
520
+ # Category 3: Functional & Cognitive Impairment (2 features)
521
  gr.Markdown("#### 3. Functional & Cognitive Impairment")
522
  cat3_col_labels = [
523
  ("MDEIMPY", "MDEIMPY: MDE with severe role impairment?"),
 
526
  cat3_inputs = []
527
  for col, label_text in cat3_col_labels:
528
  cat3_inputs.append(
529
+ gr.Dropdown(
530
+ choices=list(input_mapping[col].keys()),
531
+ label=label_text
532
+ )
533
  )
534
 
535
+ # Category 4: Suicidal Thoughts & Behaviors (4 features)
536
  gr.Markdown("#### 4. Suicidal Thoughts & Behaviors")
537
  cat4_col_labels = [
538
  ("YUSUITHK", "YUSUITHK: Thought of killing self (past 12 months)?"),
 
543
  cat4_inputs = []
544
  for col, label_text in cat4_col_labels:
545
  cat4_inputs.append(
546
+ gr.Dropdown(
547
+ choices=list(input_mapping[col].keys()),
548
+ label=label_text
549
+ )
550
  )
551
 
552
+ # Combine all inputs in the correct order
553
  all_inputs = cat1_inputs + cat2_inputs + cat3_inputs + cat4_inputs
554
 
555
+ # Output components
556
  predict_btn = gr.Button("Predict")
557
 
558
  out_pred_res = gr.Textbox(label="Prediction Results (with Probability)", lines=8)
 
562
  out_bar_input= gr.Plot(label="Input Feature Counts")
563
  out_bar_label= gr.Plot(label="Predicted Label Counts")
564
 
565
+ # Connect the predict button to the predict function
566
  predict_btn.click(
567
  fn=predict,
568
  inputs=all_inputs,
 
579
  # ======== TAB 2: Unified Distribution/Co-occurrence ========
580
  with gr.Tab("Distribution/Co-occurrence"):
581
  gr.Markdown("### Select 1 or 2 features + 1 label to see a bar chart.")
582
+
583
+ # Features can be selected from the dataset's columns
584
  list_of_features = sorted(df.columns)
585
  list_of_labels = sorted(predictor.prediction_map.keys())
586
 
 
602
  outputs=combined_output
603
  )
604
 
605
+ # Finally, launch the Gradio app
606
  demo.launch()