pantdipendra committed on
Commit
16ca108
·
verified ·
1 Parent(s): 6749d1f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +216 -242
app.py CHANGED
@@ -5,11 +5,10 @@ import plotly.express as px
5
  import gradio as gr
6
 
7
  ######################################
8
- # 1) Load Data & Prepare
9
  ######################################
10
- df = pd.read_csv("X_train_Y_Train_merged_train.csv")
11
 
12
- # List of model filenames (adjust if needed)
13
  model_filenames = [
14
  "YOWRCONC.pkl", "YOSEEDOC.pkl", "YO_MDEA5.pkl", "YOWRLSIN.pkl",
15
  "YODPPROB.pkl", "YOWRPROB.pkl", "YODPR2WK.pkl", "YOWRDEPR.pkl",
@@ -18,70 +17,60 @@ model_filenames = [
18
  ]
19
  model_path = "models/"
20
 
21
-
22
- ######################################
23
- # 2) Model Predictor
24
- ######################################
25
  class ModelPredictor:
26
  def __init__(self, model_path, model_filenames):
27
  self.model_path = model_path
28
  self.model_filenames = model_filenames
29
  self.models = self.load_models()
30
- # Mapping from label column to human-readable strings for 0/1
31
  self.prediction_map = {
32
- "YOWRCONC": ["Did not have difficulty concentrating", "Had difficulty concentrating"],
33
- "YOSEEDOC": ["Did not feel the need to see a doctor", "Felt the need to see a doctor"],
34
- "YOWRHRS": ["Did not have trouble sleeping", "Had trouble sleeping"],
35
- "YO_MDEA5": ["Others did not notice restlessness/lethargy", "Others noticed restlessness/lethargy"],
36
- "YOWRCHR": ["Did not feel so sad", "Felt so sad nothing could cheer up"],
37
- "YOWRLSIN": ["Did not feel bored and lose interest", "Felt bored and lost interest"],
38
- "YODPPROB": ["No other problems for 2+ weeks", "Had other problems for 2+ weeks"],
39
- "YOWRPROB": ["Did not have the worst time ever feeling", "Had the worst time ever feeling"],
40
- "YODPR2WK": ["No periods of 2+ weeks feelings", "Had periods of 2+ weeks feelings"],
41
  "YOWRDEPR": ["Did not feel depressed mostly everyday", "Felt depressed mostly everyday"],
42
- "YODPDISC": ["Mood not depressed overall", "Mood depressed overall (discrepancy)"],
43
- "YOLOSEV": ["Did not lose interest in enjoyable things", "Lost interest in enjoyable things"],
44
  "YOWRDCSN": ["Was able to make decisions", "Was unable to make decisions"],
45
- "YODSMMDE": ["Never had depression for 2+ weeks", "Had depression for 2+ weeks"],
46
  "YO_MDEA3": ["No appetite/weight changes", "Had appetite/weight changes"],
47
  "YODPLSIN": ["Never bored/lost interest", "Felt bored/lost interest"],
48
  "YOWRELES": ["Did not eat less than usual", "Ate less than usual"],
49
  "YODSCEV": ["Fewer severe symptoms", "More severe symptoms"],
50
- "YOPB2WK": ["No uneasy feelings 2+ weeks", "Had uneasy feelings 2+ weeks"],
51
- "YO_MDEA2": ["No issues w/ physical/mental well-being", "Issues w/ physical/mental well-being"]
52
  }
53
 
54
  def load_models(self):
55
- models = []
56
- for filename in model_filenames:
57
- filepath = self.model_path + filename
58
- with open(filepath, 'rb') as file:
59
- model = pickle.load(file)
60
- models.append(model)
61
- return models
62
-
63
- def make_predictions(self, user_input):
64
  """
65
- Returns a list of numpy arrays, each array is [0] or [1].
66
- The i-th array corresponds to the i-th model in self.models.
67
  """
68
  predictions = []
69
  for model in self.models:
70
- pred = model.predict(user_input)
71
- predictions.append(pred.flatten())
72
  return predictions
73
 
74
  def get_majority_vote(self, predictions):
75
- """
76
- Flatten all predictions from all models, combine them,
77
- then find the majority class (0 or 1).
78
- """
79
  combined = np.concatenate(predictions)
80
- majority = np.bincount(combined).argmax()
81
- return majority
82
 
83
- # Simple threshold approach (0-4 => Very Low, 5-8 => Low, etc.)
84
- def evaluate_severity(self, majority_vote_count):
85
  if majority_vote_count >= 13:
86
  return "Mental Health Severity: Severe"
87
  elif majority_vote_count >= 9:
@@ -91,22 +80,52 @@ class ModelPredictor:
91
  else:
92
  return "Mental Health Severity: Very Low"
93
 
 
94
 
95
  ######################################
96
- # 3) Validate Inputs
97
  ######################################
98
  def validate_inputs(*args):
99
  for arg in args:
100
- if arg == '' or arg is None:
101
  return False
102
  return True
103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
  ######################################
106
- # 4) Core Prediction
107
  ######################################
108
- predictor = ModelPredictor(model_path, model_filenames)
109
-
110
  def predict(
111
  YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
112
  YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
@@ -114,7 +133,7 @@ def predict(
114
  YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
115
  YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
116
  ):
117
- # Validate
118
  if not validate_inputs(
119
  YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
120
  YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
@@ -133,53 +152,50 @@ def predict(
133
  None
134
  )
135
 
136
- # Build dataframe from user inputs
137
- user_input_data = {
138
- 'YNURSMDE': [int(YNURSMDE)],
139
- 'YMDEYR': [int(YMDEYR)],
140
- 'YSOCMDE': [int(YSOCMDE)],
141
- 'YMDESUD5ANYO': [int(YMDESUD5ANYO)],
142
- 'YMSUD5YANY': [int(YMSUD5YANY)],
143
- 'YUSUITHK': [int(YUSUITHK)],
144
- 'YMDETXRX': [int(YMDETXRX)],
145
- 'YUSUITHKYR': [int(YUSUITHKYR)],
146
- 'YMDERSUD5ANY': [int(YMDERSUD5ANY)],
147
- 'YUSUIPLNYR': [int(YUSUIPLNYR)],
148
- 'YCOUNMDE': [int(YCOUNMDE)],
149
- 'YPSY1MDE': [int(YPSY1MDE)],
150
- 'YHLTMDE': [int(YHLTMDE)],
151
- 'YDOCMDE': [int(YDOCMDE)],
152
- 'YPSY2MDE': [int(YPSY2MDE)],
153
- 'YMDEHARX': [int(YMDEHARX)],
154
- 'LVLDIFMEM2': [int(LVLDIFMEM2)],
155
- 'MDEIMPY': [int(MDEIMPY)],
156
- 'YMDEHPO': [int(YMDEHPO)],
157
- 'YMIMS5YANY': [int(YMIMS5YANY)],
158
- 'YMDEIMAD5YR': [int(YMDEIMAD5YR)],
159
- 'YMIUD5YANY': [int(YMIUD5YANY)],
160
- 'YMDEHPRX': [int(YMDEHPRX)],
161
- 'YMIMI5YANY': [int(YMIMI5YANY)],
162
- 'YUSUIPLN': [int(YUSUIPLN)],
163
- 'YTXMDEYR': [int(YTXMDEYR)],
164
- 'YMDEAUD5YR': [int(YMDEAUD5YR)],
165
- 'YRXMDEYR': [int(YRXMDEYR)],
166
- 'YMDELT': [int(YMDELT)]
167
  }
168
- user_input = pd.DataFrame(user_input_data)
169
 
170
- # 1) Predictions
171
- predictions = predictor.make_predictions(user_input)
172
-
173
- # 2) Majority vote
174
  majority_vote = predictor.get_majority_vote(predictions)
 
 
 
 
175
 
176
- # 3) Count of '1's
177
- num_ones = sum(np.concatenate(predictions) == 1)
178
-
179
- # 4) Severity
180
- severity = predictor.evaluate_severity(num_ones)
181
-
182
- # 5) Group textual results
183
  groups = {
184
  "Concentration_and_Decision_Making": ["YOWRCONC", "YOWRDCSN"],
185
  "Sleep_and_Energy_Levels": ["YOWRHRS", "YO_MDEA5", "YOWRELES", "YO_MDEA2"],
@@ -191,186 +207,145 @@ def predict(
191
  "YOPB2WK"]
192
  }
193
 
194
- grouped_text = {k: [] for k in groups}
195
  for i, arr in enumerate(predictions):
196
- col_name = model_filenames[i].split('.')[0]
197
- pred_val = arr[0]
198
- if col_name in predictor.prediction_map and pred_val in [0,1]:
199
- text_val = predictor.prediction_map[col_name][pred_val]
200
  else:
201
- text_val = f"Prediction={pred_val}"
202
-
203
- found_group = False
204
  for gname, gcols in groups.items():
205
- if col_name in gcols:
206
- grouped_text[gname].append(f"{col_name} => {text_val}")
207
- found_group = True
208
  break
209
- # If not found_group, we do nothing (skip or put in a "misc" group)
210
-
211
- final_str = []
212
- for gname, items in grouped_text.items():
213
- if items:
214
- final_str.append(f"**{gname.replace('_',' ')}**")
215
- final_str.append("\n".join(items))
216
- final_str.append("\n")
217
- final_str = "\n".join(final_str).strip()
218
- if not final_str:
219
- final_str = "No predictions made. Please check inputs."
220
-
221
- # Additional info
222
- total_patients = len(df)
223
- total_patient_markdown = (
224
- f"### Total Patient Count\nThere are **{total_patients}** patients in the dataset."
225
- )
226
-
227
- # A) Bar chart for input features
228
- same_val_counts = {}
229
- for col, val_list in user_input_data.items():
230
- val_ = val_list[0]
231
- same_val_counts[col] = len(df[df[col] == val_])
232
- bar_input_df = pd.DataFrame({"Feature": list(same_val_counts.keys()),
233
- "Count": list(same_val_counts.values())})
234
- fig_bar_input = px.bar(
235
- bar_input_df, x="Feature", y="Count",
236
- title="Number of Patients with Same Input Feature Values"
237
- )
238
- fig_bar_input.update_layout(width=800, height=500)
239
 
240
- # B) Bar chart for predicted labels
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  label_counts = {}
242
  for i, arr in enumerate(predictions):
243
- lbl_col = model_filenames[i].split('.')[0]
244
  pred_val = arr[0]
245
  if pred_val in [0,1]:
246
- label_counts[lbl_col] = len(df[df[lbl_col] == pred_val])
247
-
248
  if label_counts:
249
- bar_label_df = pd.DataFrame({"Label": list(label_counts.keys()),
250
- "Count": list(label_counts.values())})
251
- fig_bar_labels = px.bar(bar_label_df, x="Label", y="Count",
252
- title="Number of Patients with the Same Predicted Label")
253
- fig_bar_labels.update_layout(width=800, height=500)
254
  else:
255
- fig_bar_labels = px.bar(title="No valid predicted labels to display.")
256
- fig_bar_labels.update_layout(width=800, height=500)
257
-
258
- # C) Distribution Plot (small sample)
259
- subset_input_cols = list(user_input_data.keys())[:4] # first 4 input columns
260
- subset_labels = [fn.split('.')[0] for fn in model_filenames[:3]] # first 3 label columns
261
- dist_rows = []
262
- for feat in subset_input_cols:
263
- if feat not in df.columns:
264
  continue
265
- for label_col in subset_labels:
266
- if label_col not in df.columns:
267
  continue
268
- tmp = df.groupby([feat, label_col]).size().reset_index(name="count")
269
- tmp["feature"] = feat
270
- tmp["label"] = label_col
271
- dist_rows.append(tmp)
272
- if dist_rows:
273
- big_dist_df = pd.concat(dist_rows, ignore_index=True)
274
  fig_dist = px.bar(
275
- big_dist_df,
276
- x=big_dist_df.columns[0],
277
  y="count",
278
- color=big_dist_df.columns[1],
279
  facet_row="feature",
280
  facet_col="label",
281
- title="Distribution of Sample Input Features vs. Sample Predicted Labels"
282
  )
283
- fig_dist.update_layout(width=1000, height=700)
284
  else:
285
  fig_dist = px.bar(title="Distribution plot not generated.")
286
 
287
- # D) Nearest neighbors (placeholder or your own logic)
288
- nearest_neighbors_markdown = "Nearest neighbors omitted or placed here if needed..."
289
-
290
- # We won't produce a co-occurrence plot by default here, so set to None
291
- co_occurrence_placeholder = None
292
 
293
- # Return the 8 outputs
294
  return (
295
- final_str, # 1) Prediction Results
296
- severity, # 2) Mental Health Severity
297
- total_patient_markdown, # 3) Total Patient Count
298
- fig_dist, # 4) Distribution Plot
299
- nearest_neighbors_markdown, # 5) Nearest Neighbors
300
- co_occurrence_placeholder, # 6) Co-occurrence Plot placeholder
301
- fig_bar_input, # 7) Bar Chart for input features
302
- fig_bar_labels # 8) Bar Chart for predicted labels
303
  )
304
 
305
-
306
- ######################################
307
- # 5) Input Mapping
308
- ######################################
309
- input_mapping = {
310
- 'YNURSMDE': {"Yes": 1, "No": 0},
311
- 'YMDEYR': {"Yes": 1, "No": 2},
312
- 'YSOCMDE': {"Yes": 1, "No": 0},
313
- 'YMDESUD5ANYO': {"SUD only, no MDE": 1, "MDE only, no SUD": 2, "SUD and MDE": 3, "Neither SUD or MDE": 4},
314
- 'YMSUD5YANY': {"Yes": 1, "No": 0},
315
- 'YUSUITHK': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
316
- 'YMDETXRX': {"Yes": 1, "No": 0},
317
- 'YUSUITHKYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
318
- 'YMDERSUD5ANY': {"Yes": 1, "No": 0},
319
- 'YUSUIPLNYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
320
- 'YCOUNMDE': {"Yes": 1, "No": 0},
321
- 'YPSY1MDE': {"Yes": 1, "No": 0},
322
- 'YHLTMDE': {"Yes": 1, "No": 0},
323
- 'YDOCMDE': {"Yes": 1, "No": 0},
324
- 'YPSY2MDE': {"Yes": 1, "No": 0},
325
- 'YMDEHARX': {"Yes": 1, "No": 0},
326
- 'LVLDIFMEM2': {"No Difficulty": 1, "Some difficulty": 2, "A lot of difficulty or cannot do at all": 3},
327
- 'MDEIMPY': {"Yes": 1, "No": 2},
328
- 'YMDEHPO': {"Yes": 1, "No": 0},
329
- 'YMIMS5YANY': {"Yes": 1, "No": 0},
330
- 'YMDEIMAD5YR': {"Yes": 1, "No": 0},
331
- 'YMIUD5YANY': {"Yes": 1, "No": 0},
332
- 'YMDEHPRX': {"Yes": 1, "No": 0},
333
- 'YMIMI5YANY': {"Yes": 1, "No": 0},
334
- 'YUSUIPLN': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
335
- 'YTXMDEYR': {"Yes": 1, "No": 0},
336
- 'YMDEAUD5YR': {"Yes": 1, "No": 0},
337
- 'YRXMDEYR': {"Yes": 1, "No": 0},
338
- 'YMDELT': {"Yes": 1, "No": 2}
339
- }
340
-
341
-
342
  ######################################
343
- # 6) Co-Occurrence Function
344
  ######################################
345
  def co_occurrence_plot(feature1, feature2, label_col):
346
  """
347
- Generate a single co-occurrence bar chart grouping by [feature1, feature2, label_col].
348
  """
349
- if not feature1 or not feature2 or not label_col:
350
  return px.bar(title="Please select all three fields.")
351
  if feature1 not in df.columns or feature2 not in df.columns or label_col not in df.columns:
352
  return px.bar(title="Selected columns not found in the dataset.")
353
 
354
- grouped_df = df.groupby([feature1, feature2, label_col]).size().reset_index(name="count")
355
  fig = px.bar(
356
- grouped_df,
357
  x=feature1,
358
  y="count",
359
  color=label_col,
360
  facet_col=feature2,
361
- title=f"Co-Occurrence Plot: {feature1} & {feature2} vs. {label_col}"
362
  )
363
- fig.update_layout(width=1000, height=600)
364
  return fig
365
 
366
-
367
  ######################################
368
- # 7) Gradio Interface with Tabs
369
  ######################################
370
- with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
371
-
372
  with gr.Tab("Prediction"):
373
- # --------- INPUT FIELDS --------- #
374
  YMDEYR_dd = gr.Dropdown(list(input_mapping['YMDEYR'].keys()), label="YMDEYR")
375
  YMDERSUD5ANY_dd = gr.Dropdown(list(input_mapping['YMDERSUD5ANY'].keys()), label="YMDERSUD5ANY")
376
  YMDEIMAD5YR_dd = gr.Dropdown(list(input_mapping['YMDEIMAD5YR'].keys()), label="YMDEIMAD5YR")
@@ -395,7 +370,7 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
395
  YDOCMDE_dd = gr.Dropdown(list(input_mapping['YDOCMDE'].keys()), label="YDOCMDE")
396
  YTXMDEYR_dd = gr.Dropdown(list(input_mapping['YTXMDEYR'].keys()), label="YTXMDEYR")
397
 
398
- # Suicidal thoughts/plans
399
  YUSUITHKYR_dd = gr.Dropdown(list(input_mapping['YUSUITHKYR'].keys()), label="YUSUITHKYR")
400
  YUSUIPLNYR_dd = gr.Dropdown(list(input_mapping['YUSUIPLNYR'].keys()), label="YUSUIPLNYR")
401
  YUSUITHK_dd = gr.Dropdown(list(input_mapping['YUSUITHK'].keys()), label="YUSUITHK")
@@ -407,10 +382,10 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
407
  YMSUD5YANY_dd = gr.Dropdown(list(input_mapping['YMSUD5YANY'].keys()), label="YMSUD5YANY")
408
  YRXMDEYR_dd = gr.Dropdown(list(input_mapping['YRXMDEYR'].keys()), label="YRXMDEYR")
409
 
410
- # --------- PREDICT BUTTON (BEFORE OUTPUTS) --------- #
411
  predict_btn = gr.Button("Predict")
412
 
413
- # --------- OUTPUTS (IN THE SAME ORDER AS THE RETURN TUPLE) --------- #
414
  out_pred_res = gr.Textbox(label="Prediction Results", lines=8)
415
  out_sev = gr.Textbox(label="Mental Health Severity", lines=2)
416
  out_count = gr.Markdown(label="Total Patient Count")
@@ -420,7 +395,7 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
420
  out_bar_input = gr.Plot(label="Input Feature Counts")
421
  out_bar_labels = gr.Plot(label="Predicted Label Counts")
422
 
423
- # Link button to the function
424
  predict_btn.click(
425
  fn=predict,
426
  inputs=[
@@ -436,21 +411,20 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
436
  ]
437
  )
438
 
439
- # ------------- SECOND TAB (CO-OCCURRENCE) -------------
440
  with gr.Tab("Co-occurrence"):
441
- gr.Markdown("## Generate a Co-Occurrence Plot on Demand\nSelect two features and one label:")
442
  with gr.Row():
443
- feature1_dd = gr.Dropdown(sorted(df.columns), label="Feature 1")
444
- feature2_dd = gr.Dropdown(sorted(df.columns), label="Feature 2")
445
  label_dd = gr.Dropdown(sorted(df.columns), label="Label Column")
446
- out_co_occ_plot = gr.Plot(label="Co-occurrence Plot")
 
447
 
448
- co_occ_btn = gr.Button("Generate Plot")
449
- co_occ_btn.click(
450
  fn=co_occurrence_plot,
451
- inputs=[feature1_dd, feature2_dd, label_dd],
452
- outputs=out_co_occ_plot
453
  )
454
 
455
- # Optionally, you can customize your CSS or server launch parameters
456
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
5
  import gradio as gr
6
 
7
  ######################################
8
+ # 1) LOAD DATA & MODELS
9
  ######################################
10
+ df = pd.read_csv("X_train_Y_Train_merged_train.csv") # Make sure the CSV is present
11
 
 
12
  model_filenames = [
13
  "YOWRCONC.pkl", "YOSEEDOC.pkl", "YO_MDEA5.pkl", "YOWRLSIN.pkl",
14
  "YODPPROB.pkl", "YOWRPROB.pkl", "YODPR2WK.pkl", "YOWRDEPR.pkl",
 
17
  ]
18
  model_path = "models/"
19
 
 
 
 
 
20
  class ModelPredictor:
21
  def __init__(self, model_path, model_filenames):
22
  self.model_path = model_path
23
  self.model_filenames = model_filenames
24
  self.models = self.load_models()
25
+ # Mapping from each label column to a list: [meaning_of_0, meaning_of_1]
26
  self.prediction_map = {
27
+ "YOWRCONC": ["No difficulty concentrating", "Had difficulty concentrating"],
28
+ "YOSEEDOC": ["Did not feel need for doctor", "Felt need for doctor"],
29
+ "YOWRHRS": ["No trouble sleeping", "Had trouble sleeping"],
30
+ "YO_MDEA5": ["No restlessness/lethargy noted", "Others noticed restlessness/lethargy"],
31
+ "YOWRCHR": ["Did not feel so sad", "Felt so sad that nothing cheered up"],
32
+ "YOWRLSIN": ["No boredom/loss of interest", "Bored/lost interest in everything"],
33
+ "YODPPROB": ["No other 2+ week problems", "Had other 2+ week problems"],
34
+ "YOWRPROB": ["Did not have worst feeling ever", "Had worst time feeling"],
35
+ "YODPR2WK": ["No 2+ weeks of these feelings", "Had 2+ weeks of these feelings"],
36
  "YOWRDEPR": ["Did not feel depressed mostly everyday", "Felt depressed mostly everyday"],
37
+ "YODPDISC": ["Mood not depressed overall", "Mood depressed overall discrepancy"],
38
+ "YOLOSEV": ["No loss of interest in enjoyable things", "Lost interest in enjoyable things"],
39
  "YOWRDCSN": ["Was able to make decisions", "Was unable to make decisions"],
40
+ "YODSMMDE": ["No 2+ week depression episodes", "Had 2+ week depression episodes"],
41
  "YO_MDEA3": ["No appetite/weight changes", "Had appetite/weight changes"],
42
  "YODPLSIN": ["Never bored/lost interest", "Felt bored/lost interest"],
43
  "YOWRELES": ["Did not eat less than usual", "Ate less than usual"],
44
  "YODSCEV": ["Fewer severe symptoms", "More severe symptoms"],
45
+ "YOPB2WK": ["No uneasy feelings for 2+ weeks", "Had uneasy feelings for 2+ weeks"],
46
+ "YO_MDEA2": ["No daily well-being issues", "Daily well-being issues for 2+ weeks"]
47
  }
48
 
49
  def load_models(self):
50
+ loaded = []
51
+ for fname in self.model_filenames:
52
+ with open(self.model_path + fname, "rb") as f:
53
+ model = pickle.load(f)
54
+ loaded.append(model)
55
+ return loaded
56
+
57
+ def make_predictions(self, user_input: pd.DataFrame):
 
58
  """
59
+ Return list of arrays, each array is [0] or [1].
 
60
  """
61
  predictions = []
62
  for model in self.models:
63
+ out = model.predict(user_input)
64
+ predictions.append(out.flatten())
65
  return predictions
66
 
67
  def get_majority_vote(self, predictions):
 
 
 
 
68
  combined = np.concatenate(predictions)
69
+ # find 0 or 1 that is most frequent
70
+ return np.bincount(combined).argmax()
71
 
72
+ def evaluate_severity(self, majority_vote_count: int) -> str:
73
+ # Simple thresholds
74
  if majority_vote_count >= 13:
75
  return "Mental Health Severity: Severe"
76
  elif majority_vote_count >= 9:
 
80
  else:
81
  return "Mental Health Severity: Very Low"
82
 
83
+ predictor = ModelPredictor(model_path, model_filenames)
84
 
85
  ######################################
86
+ # 2) VALIDATION, INPUT MAPPING
87
  ######################################
88
  def validate_inputs(*args):
89
  for arg in args:
90
+ if not arg: # empty or None
91
  return False
92
  return True
93
 
94
+ input_mapping = {
95
+ 'YNURSMDE': {"Yes": 1, "No": 0},
96
+ 'YMDEYR': {"Yes": 1, "No": 2},
97
+ 'YSOCMDE': {"Yes": 1, "No": 0},
98
+ 'YMDESUD5ANYO': {"SUD only, no MDE": 1, "MDE only, no SUD": 2, "SUD and MDE": 3, "Neither SUD or MDE": 4},
99
+ 'YMSUD5YANY': {"Yes": 1, "No": 0},
100
+ 'YUSUITHK': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
101
+ 'YMDETXRX': {"Yes": 1, "No": 0},
102
+ 'YUSUITHKYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
103
+ 'YMDERSUD5ANY': {"Yes": 1, "No": 0},
104
+ 'YUSUIPLNYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
105
+ 'YCOUNMDE': {"Yes": 1, "No": 0},
106
+ 'YPSY1MDE': {"Yes": 1, "No": 0},
107
+ 'YHLTMDE': {"Yes": 1, "No": 0},
108
+ 'YDOCMDE': {"Yes": 1, "No": 0},
109
+ 'YPSY2MDE': {"Yes": 1, "No": 0},
110
+ 'YMDEHARX': {"Yes": 1, "No": 0},
111
+ 'LVLDIFMEM2': {"No Difficulty": 1, "Some difficulty": 2, "A lot of difficulty or cannot do at all": 3},
112
+ 'MDEIMPY': {"Yes": 1, "No": 2},
113
+ 'YMDEHPO': {"Yes": 1, "No": 0},
114
+ 'YMIMS5YANY': {"Yes": 1, "No": 0},
115
+ 'YMDEIMAD5YR': {"Yes": 1, "No": 0},
116
+ 'YMIUD5YANY': {"Yes": 1, "No": 0},
117
+ 'YMDEHPRX': {"Yes": 1, "No": 0},
118
+ 'YMIMI5YANY': {"Yes": 1, "No": 0},
119
+ 'YUSUIPLN': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
120
+ 'YTXMDEYR': {"Yes": 1, "No": 0},
121
+ 'YMDEAUD5YR': {"Yes": 1, "No": 0},
122
+ 'YRXMDEYR': {"Yes": 1, "No": 0},
123
+ 'YMDELT': {"Yes": 1, "No": 2}
124
+ }
125
 
126
  ######################################
127
+ # 3) PREDICT FUNCTION
128
  ######################################
 
 
129
  def predict(
130
  YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
131
  YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
 
133
  YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
134
  YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
135
  ):
136
+ # 1) Validate
137
  if not validate_inputs(
138
  YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
139
  YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
 
152
  None
153
  )
154
 
155
+ # 2) Map user-friendly -> numeric
156
+ user_input_dict = {
157
+ 'YNURSMDE': input_mapping['YNURSMDE'][YNURSMDE],
158
+ 'YMDEYR': input_mapping['YMDEYR'][YMDEYR],
159
+ 'YSOCMDE': input_mapping['YSOCMDE'][YSOCMDE],
160
+ 'YMDESUD5ANYO': input_mapping['YMDESUD5ANYO'][YMDESUD5ANYO],
161
+ 'YMSUD5YANY': input_mapping['YMSUD5YANY'][YMSUD5YANY],
162
+ 'YUSUITHK': input_mapping['YUSUITHK'][YUSUITHK],
163
+ 'YMDETXRX': input_mapping['YMDETXRX'][YMDETXRX],
164
+ 'YUSUITHKYR': input_mapping['YUSUITHKYR'][YUSUITHKYR],
165
+ 'YMDERSUD5ANY': input_mapping['YMDERSUD5ANY'][YMDERSUD5ANY],
166
+ 'YUSUIPLNYR': input_mapping['YUSUIPLNYR'][YUSUIPLNYR],
167
+ 'YCOUNMDE': input_mapping['YCOUNMDE'][YCOUNMDE],
168
+ 'YPSY1MDE': input_mapping['YPSY1MDE'][YPSY1MDE],
169
+ 'YHLTMDE': input_mapping['YHLTMDE'][YHLTMDE],
170
+ 'YDOCMDE': input_mapping['YDOCMDE'][YDOCMDE],
171
+ 'YPSY2MDE': input_mapping['YPSY2MDE'][YPSY2MDE],
172
+ 'YMDEHARX': input_mapping['YMDEHARX'][YMDEHARX],
173
+ 'LVLDIFMEM2': input_mapping['LVLDIFMEM2'][LVLDIFMEM2],
174
+ 'MDEIMPY': input_mapping['MDEIMPY'][MDEIMPY],
175
+ 'YMDEHPO': input_mapping['YMDEHPO'][YMDEHPO],
176
+ 'YMIMS5YANY': input_mapping['YMIMS5YANY'][YMIMS5YANY],
177
+ 'YMDEIMAD5YR': input_mapping['YMDEIMAD5YR'][YMDEIMAD5YR],
178
+ 'YMIUD5YANY': input_mapping['YMIUD5YANY'][YMIUD5YANY],
179
+ 'YMDEHPRX': input_mapping['YMDEHPRX'][YMDEHPRX],
180
+ 'YMIMI5YANY': input_mapping['YMIMI5YANY'][YMIMI5YANY],
181
+ 'YUSUIPLN': input_mapping['YUSUIPLN'][YUSUIPLN],
182
+ 'YTXMDEYR': input_mapping['YTXMDEYR'][YTXMDEYR],
183
+ 'YMDEAUD5YR': input_mapping['YMDEAUD5YR'][YMDEAUD5YR],
184
+ 'YRXMDEYR': input_mapping['YRXMDEYR'][YRXMDEYR],
185
+ 'YMDELT': input_mapping['YMDELT'][YMDELT]
186
  }
187
+ user_df = pd.DataFrame(user_input_dict, index=[0])
188
 
189
+ # 3) Make predictions
190
+ predictions = predictor.make_predictions(user_df)
191
+ # majority
 
192
  majority_vote = predictor.get_majority_vote(predictions)
193
+ # how many are '1'
194
+ count_ones = sum(np.concatenate(predictions) == 1)
195
+ # severity
196
+ severity_msg = predictor.evaluate_severity(count_ones)
197
 
198
+ # 4) Format textual results for each group (just as an example)
 
 
 
 
 
 
199
  groups = {
200
  "Concentration_and_Decision_Making": ["YOWRCONC", "YOWRDCSN"],
201
  "Sleep_and_Energy_Levels": ["YOWRHRS", "YO_MDEA5", "YOWRELES", "YO_MDEA2"],
 
207
  "YOPB2WK"]
208
  }
209
 
210
+ group_text = {g: [] for g in groups}
211
  for i, arr in enumerate(predictions):
212
+ label_col = model_filenames[i].split('.')[0] # e.g. 'YOWRCONC'
213
+ val = arr[0]
214
+ if label_col in predictor.prediction_map and val in [0,1]:
215
+ text_label = predictor.prediction_map[label_col][val]
216
  else:
217
+ text_label = f"Prediction={val}"
218
+ # see which group
219
+ found = False
220
  for gname, gcols in groups.items():
221
+ if label_col in gcols:
222
+ group_text[gname].append(f"{label_col} => {text_label}")
223
+ found = True
224
  break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
 
226
+ # build final results
227
+ final_str_parts = []
228
+ for gname, lines in group_text.items():
229
+ if lines:
230
+ final_str_parts.append(f"**{gname.replace('_',' ')}**")
231
+ final_str_parts.append("\n".join(lines))
232
+ final_str_parts.append("")
233
+ if not final_str_parts:
234
+ final_str = "No predictions made or no matching group columns."
235
+ else:
236
+ final_str = "\n".join(final_str_parts)
237
+
238
+ # 5) Additional features
239
+ # total patients
240
+ total_count = len(df)
241
+ total_count_md = f"### Total Patient Count\nWe have **{total_count}** patients in the dataset."
242
+
243
+ # bar chart for input features
244
+ input_counts = {}
245
+ for col, val_ in user_input_dict.items():
246
+ # only 1 item
247
+ v = val_
248
+ # how many have that value?
249
+ matched = len(df[df[col] == v])
250
+ input_counts[col] = matched
251
+ bar_in_df = pd.DataFrame({"Feature": list(input_counts.keys()),
252
+ "Count": list(input_counts.values())})
253
+ fig_in = px.bar(bar_in_df, x="Feature", y="Count",
254
+ title="Number of Patients with Same Input Feature Values")
255
+ fig_in.update_layout(width=700, height=400)
256
+
257
+ # bar chart for predicted labels
258
  label_counts = {}
259
  for i, arr in enumerate(predictions):
260
+ lblcol = model_filenames[i].split('.')[0]
261
  pred_val = arr[0]
262
  if pred_val in [0,1]:
263
+ # how many in df have that label?
264
+ label_counts[lblcol] = len(df[df[lblcol] == pred_val])
265
  if label_counts:
266
+ bar_lbl_df = pd.DataFrame({"Label": list(label_counts.keys()),
267
+ "Count": list(label_counts.values())})
268
+ fig_lbl = px.bar(bar_lbl_df, x="Label", y="Count",
269
+ title="Number of Patients with the Same Predicted Label")
270
+ fig_lbl.update_layout(width=700, height=400)
271
  else:
272
+ fig_lbl = px.bar(title="No valid predicted labels to display.")
273
+ fig_lbl.update_layout(width=700, height=400)
274
+
275
+ # distribution plot (just a small sample)
276
+ feat_sample = list(user_input_dict.keys())[:3]
277
+ label_sample = [mf.split('.')[0] for mf in model_filenames[:2]]
278
+ rows = []
279
+ for f_ in feat_sample:
280
+ if f_ not in df.columns:
281
  continue
282
+ for l_ in label_sample:
283
+ if l_ not in df.columns:
284
  continue
285
+ sub_g = df.groupby([f_, l_]).size().reset_index(name="count")
286
+ sub_g["feature"] = f_
287
+ sub_g["label"] = l_
288
+ rows.append(sub_g)
289
+ if rows:
290
+ big_df = pd.concat(rows, ignore_index=True)
291
  fig_dist = px.bar(
292
+ big_df,
293
+ x=big_df.columns[0], # feature value
294
  y="count",
295
+ color=big_df.columns[1], # label value
296
  facet_row="feature",
297
  facet_col="label",
298
+ title="Distribution (Sample Input Features vs Sample Labels)"
299
  )
300
+ fig_dist.update_layout(width=900, height=600)
301
  else:
302
  fig_dist = px.bar(title="Distribution plot not generated.")
303
 
304
+ # nearest neighbors or co-occ placeholder
305
+ nn_md = "Nearest neighbors / advanced metrics not implemented in this version."
306
+ co_occ_placeholder = None
 
 
307
 
 
308
  return (
309
+ final_str, # 1) Prediction Results
310
+ severity_msg, # 2) Mental Health Severity
311
+ total_count_md, # 3) Total Patient Count
312
+ fig_dist, # 4) Distribution Plot
313
+ nn_md, # 5) Nearest Neighbors (Markdown)
314
+ co_occ_placeholder, # 6) Co-occurrence Plot
315
+ fig_in, # 7) Bar Chart for input features
316
+ fig_lbl # 8) Bar Chart for predicted labels
317
  )
318
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
  ######################################
320
+ # 4) CO-OCCURRENCE FUNCTION
321
  ######################################
322
  def co_occurrence_plot(feature1, feature2, label_col):
323
  """
324
+ Create a bar chart for co-occurrence among feature1, feature2, and label_col.
325
  """
326
+ if (not feature1) or (not feature2) or (not label_col):
327
  return px.bar(title="Please select all three fields.")
328
  if feature1 not in df.columns or feature2 not in df.columns or label_col not in df.columns:
329
  return px.bar(title="Selected columns not found in the dataset.")
330
 
331
+ grouped = df.groupby([feature1, feature2, label_col]).size().reset_index(name="count")
332
  fig = px.bar(
333
+ grouped,
334
  x=feature1,
335
  y="count",
336
  color=label_col,
337
  facet_col=feature2,
338
+ title=f"Co-occurrence: {feature1}, {feature2} vs {label_col}"
339
  )
340
+ fig.update_layout(width=900, height=600)
341
  return fig
342
 
 
343
  ######################################
344
+ # 5) BUILD GRADIO UI
345
  ######################################
346
+ with gr.Blocks(css=".gradio-container {max-width: 1100px;}") as demo:
 
347
  with gr.Tab("Prediction"):
348
+ # Input fields in the same order as predict(...)
349
  YMDEYR_dd = gr.Dropdown(list(input_mapping['YMDEYR'].keys()), label="YMDEYR")
350
  YMDERSUD5ANY_dd = gr.Dropdown(list(input_mapping['YMDERSUD5ANY'].keys()), label="YMDERSUD5ANY")
351
  YMDEIMAD5YR_dd = gr.Dropdown(list(input_mapping['YMDEIMAD5YR'].keys()), label="YMDEIMAD5YR")
 
370
  YDOCMDE_dd = gr.Dropdown(list(input_mapping['YDOCMDE'].keys()), label="YDOCMDE")
371
  YTXMDEYR_dd = gr.Dropdown(list(input_mapping['YTXMDEYR'].keys()), label="YTXMDEYR")
372
 
373
+ # Suicidal
374
  YUSUITHKYR_dd = gr.Dropdown(list(input_mapping['YUSUITHKYR'].keys()), label="YUSUITHKYR")
375
  YUSUIPLNYR_dd = gr.Dropdown(list(input_mapping['YUSUIPLNYR'].keys()), label="YUSUIPLNYR")
376
  YUSUITHK_dd = gr.Dropdown(list(input_mapping['YUSUITHK'].keys()), label="YUSUITHK")
 
382
  YMSUD5YANY_dd = gr.Dropdown(list(input_mapping['YMSUD5YANY'].keys()), label="YMSUD5YANY")
383
  YRXMDEYR_dd = gr.Dropdown(list(input_mapping['YRXMDEYR'].keys()), label="YRXMDEYR")
384
 
385
+ # Button
386
  predict_btn = gr.Button("Predict")
387
 
388
+ # 8 outputs
389
  out_pred_res = gr.Textbox(label="Prediction Results", lines=8)
390
  out_sev = gr.Textbox(label="Mental Health Severity", lines=2)
391
  out_count = gr.Markdown(label="Total Patient Count")
 
395
  out_bar_input = gr.Plot(label="Input Feature Counts")
396
  out_bar_labels = gr.Plot(label="Predicted Label Counts")
397
 
398
+ # Connect
399
  predict_btn.click(
400
  fn=predict,
401
  inputs=[
 
411
  ]
412
  )
413
 
 
414
  with gr.Tab("Co-occurrence"):
415
+ gr.Markdown("## Co-Occurrence Plot\nSelect two features + one label to see a distribution.")
416
  with gr.Row():
417
+ feat1_dd = gr.Dropdown(sorted(df.columns), label="Feature 1")
418
+ feat2_dd = gr.Dropdown(sorted(df.columns), label="Feature 2")
419
  label_dd = gr.Dropdown(sorted(df.columns), label="Label Column")
420
+ generate_btn = gr.Button("Generate Plot")
421
+ co_occ_output = gr.Plot()
422
 
423
+ generate_btn.click(
 
424
  fn=co_occurrence_plot,
425
+ inputs=[feat1_dd, feat2_dd, label_dd],
426
+ outputs=co_occ_output
427
  )
428
 
429
+ # Launch
430
+ demo.launch()