pantdipendra committed on
Commit b782f65 · verified · 1 Parent(s): 51455ff

Update app.py

Files changed (1)
  1. app.py +237 -238
app.py CHANGED
@@ -1,5 +1,6 @@
- import gradio as gr
  import pickle
  import numpy as np
  import pandas as pd
  import plotly.express as px
@@ -7,15 +8,13 @@ import plotly.express as px
  # Load the training CSV once (outside the functions so it is read only once).
  df = pd.read_csv("X_train_Y_Train_merged_train.csv")

- ##############################################################################
- # MODEL PREDICTOR CLASS
- ##############################################################################
-
  class ModelPredictor:
  def __init__(self, model_path, model_filenames):
  self.model_path = model_path
  self.model_filenames = model_filenames
  self.models = self.load_models()
  self.prediction_map = {
  "YOWRCONC": ["Did not have difficulty concentrating", "Had difficulty concentrating"],
  "YOSEEDOC": ["Did not feel the need to see a doctor", "Felt the need to see a doctor"],
@@ -57,7 +56,10 @@ class ModelPredictor:
  return models

  def make_predictions(self, user_input):
- """Returns a list of numpy arrays, each array is [0] or [1]."""
  predictions = []
  for model in self.models:
  pred = model.predict(user_input)
@@ -68,13 +70,17 @@ class ModelPredictor:
  def get_majority_vote(self, predictions):
  """
  Flatten all predictions from all models, combine them into a single array,
- then find the majority class (0 or 1).
  """
  combined_predictions = np.concatenate(predictions)
  majority_vote = np.bincount(combined_predictions).argmax()
  return majority_vote

- # Severity interpretation (same as before)
  def evaluate_severity(self, majority_vote_count):
  if majority_vote_count >= 13:
  return "Mental health severity: Severe"
@@ -85,6 +91,7 @@ class ModelPredictor:
  else:
  return "Mental health severity: Very Low"

  model_filenames = [
  "YOWRCONC.pkl", "YOSEEDOC.pkl", "YO_MDEA5.pkl", "YOWRLSIN.pkl",
  "YODPPROB.pkl", "YOWRPROB.pkl", "YODPR2WK.pkl", "YOWRDEPR.pkl",
@@ -94,21 +101,12 @@ model_filenames = [
  model_path = "models/"
  predictor = ModelPredictor(model_path, model_filenames)

- ##############################################################################
- # INPUT VALIDATION
- ##############################################################################
-
  def validate_inputs(*args):
- """Return False if any argument is blank or None."""
  for arg in args:
- if arg == '' or arg is None:
  return False
  return True

- ##############################################################################
- # MAIN PREDICT FUNCTION
- ##############################################################################
-
  def predict(
  YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
  YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
@@ -116,6 +114,20 @@ def predict(
  YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
  YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
  ):
  # Prepare user_input dataframe for prediction
  user_input_data = {
  'YNURSMDE': [int(YNURSMDE)],
@@ -150,18 +162,21 @@ def predict(
  }
  user_input = pd.DataFrame(user_input_data)

- # 1) Make predictions for each of the 16 models
  predictions = predictor.make_predictions(user_input)
- # 2) Majority vote across all models
  majority_vote = predictor.get_majority_vote(predictions)
- # 3) Count how many 1's in all predictions
  majority_vote_count = sum([1 for pred in np.concatenate(predictions) if pred == 1])
  # 4) Evaluate severity
  severity = predictor.evaluate_severity(majority_vote_count)

- ############################################################################
- # (A) Summarize per-model predictions
- ############################################################################
  results = {
  "Concentration_and_Decision_Making": [],
  "Sleep_and_Energy_Levels": [],
@@ -180,221 +195,73 @@ def predict(
  "YODPR2WK", "YODSMMDE",
  "YOPB2WK"]
  }
-
  for i, pred in enumerate(predictions):
  model_name = model_filenames[i].split('.')[0] # e.g. 'YOWRCONC'
  pred_value = pred[0]
  if model_name in predictor.prediction_map and pred_value in [0, 1]:
  result_text = f"Model {model_name}: {predictor.prediction_map[model_name][pred_value]}"
  else:
- result_text = f"Model {model_name}: Unknown or out-of-range prediction {pred_value}"

  found_group = False
  for group_name, group_models in prediction_groups.items():
  if model_name in group_models:
  results[group_name].append(result_text)
  found_group = True
  break

  formatted_results = []
  for group, preds in results.items():
  if preds:
  formatted_results.append(f"Group {group.replace('_', ' ')}:")
  formatted_results.append("\n".join(preds))
- formatted_results.append("")
- if not formatted_results:
- formatted_results = ["No predictions made. Please check your inputs."]
-
- prediction_summary_text = "\n".join(formatted_results).strip()

- ############################################################################
- # (B) Show "Total Patient Count" (replacing old matched-vs-total)
- ############################################################################
  total_patients = len(df)
- total_patients_text = (
  "### Total Patient Count\n"
- f"This dataset contains **{total_patients}** patient records.\n\n"
- "In the next sections, we explore how the features and labels are distributed in these records."
- )
-
- ############################################################################
- # (C) CROSS-TABULATION & GROUPED BAR CHART (EXAMPLE)
- # We'll demonstrate with one feature (e.g., 'YMDEYR') vs. the actual label 'YOWRCONC'
- ############################################################################
- # Explanation:
- cross_tab_explanation = (
- "### Cross-Tabulation & Grouped Bar Chart\n"
- "This chart shows how often each category of a given feature (X-axis) co-occurs with each **actual label** (0 or 1). "
- "Interpreting this helps clinicians see which categories have a higher proportion of positive vs. negative outcomes. "
- "For instance, if 'Yes' in YMDEYR heavily corresponds to label=1, that suggests a stronger link between that feature and the mental health outcome."
- )
-
- if "YOWRCONC" in df.columns and "YMDEYR" in df.columns:
- # Make sure we actually have the columns needed
- ctab = pd.crosstab(df["YMDEYR"], df["YOWRCONC"])
- # ctab might have column names [0,1] for the label
- ctab.reset_index(inplace=True)
- # rename for clarity
- ctab.columns = ["YMDEYR_Value", "Label0_Count", "Label1_Count"]
-
- fig_crosstab = px.bar(
- ctab,
- x="YMDEYR_Value",
- y=["Label0_Count", "Label1_Count"],
- barmode="group",
- title="YMDEYR vs. YOWRCONC (Actual Label)",
- labels={
- "YMDEYR_Value": "YMDEYR Feature Categories",
- "value": "Count of Patients",
- "variable": "Label"
- }
- )
- else:
- # fallback if we don't have those columns
- fig_crosstab = px.bar(
- x=["Data Error"], y=[0],
- title="Could not generate cross-tab: 'YOWRCONC' or 'YMDEYR' not in df"
- )
-
- ############################################################################
- # (D) "SIMILAR PATIENT" / NEAREST-NEIGHBORS DEMO
- # We'll pick a small set of "key features", measure Hamming distance,
- # and find the top-K closest rows. Then we'll show how many had label=1.
- ############################################################################
- similar_explanation = (
- "### Similar Patients (Nearest Neighbors)\n"
- "Here we define a small set of key features and use a simple Hamming distance "
- "(count of mismatched categories) to find patients who are 'closest' to the current input. "
- "This helps clinicians see how similar patients were labeled or what interventions they needed."
  )

- # Example "key features" (choose whichever are most clinically relevant)
- key_features = ["YMDEYR", "YMDERSUD5ANY", "YMSUD5YANY", "LVLDIFMEM2"]
- if all(kf in df.columns for kf in key_features) and "YOWRCONC" in df.columns:
- # Compute distance for each row
- user_vector = [user_input_data[kf][0] for kf in key_features]
- distances = []
- for idx, row in df[key_features].iterrows():
- # Compare row to user_vector
- row_vector = row.values
- # Hamming distance = sum(row_vector[i] != user_vector[i])
- dist = sum(rv != uv for rv, uv in zip(row_vector, user_vector))
- distances.append(dist)
-
- # Add distances to a copy of df
- temp_df = df.copy()
- temp_df["HammingDist"] = distances
- # Sort ascending by distance, take top-K (e.g., 20)
- top_k = temp_df.nsmallest(20, "HammingDist")
- # Count how many have label=1 in top_k
- if "YOWRCONC" in top_k.columns:
- similar_label_1_count = (top_k["YOWRCONC"] == 1).sum()
- similar_label_0_count = (top_k["YOWRCONC"] == 0).sum()
- similar_text = (
- f"Out of the 20 most similar patients:\n"
- f"- {similar_label_1_count} had label=1\n"
- f"- {similar_label_0_count} had label=0\n"
- f"(Distances ranged from {top_k['HammingDist'].min()} to {top_k['HammingDist'].max()})."
- )
- else:
- similar_text = "Label column 'YOWRCONC' missing in dataset."
- else:
- similar_text = "Cannot compute nearest neighbors: some key features or label column are missing."
-
- ############################################################################
- # (E) CO-OCCURRENCE PLOT (TWO FEATURES) vs. LABEL
- ############################################################################
- cooccurrence_explanation = (
- "### Co-Occurrence of Two Features vs. Label\n"
- "This shows how two categorical features combine, and how many patients in each combination are labeled 0 or 1. "
- "Clinicians can spot if certain feature-combinations are particularly high-risk or high-incidence of label=1."
- )
-
- # Example: co-occurrence of 'YMDEYR' and 'YMDERSUD5ANY' vs. 'YOWRCONC'
- if all(col in df.columns for col in ["YMDEYR", "YMDERSUD5ANY", "YOWRCONC"]):
- co_tab = pd.crosstab([df["YMDEYR"], df["YMDERSUD5ANY"]], df["YOWRCONC"])
- co_tab.reset_index(inplace=True)
- # co_tab columns: ["YMDEYR", "YMDERSUD5ANY", "0", "1"]
- co_tab.columns = ["YMDEYR", "YMDERSUD5ANY", "Label0", "Label1"]
-
- # We'll create a stacked or grouped bar. Let's do grouped by label.
- # Construct a single column "Count" and a single column "Label" to let plotly group them
- data_list = []
- for i, row in co_tab.iterrows():
- data_list.append({
- "YMDEYR_Val": row["YMDEYR"],
- "YMDERSUD5ANY_Val": row["YMDERSUD5ANY"],
- "Label": "Label=0",
- "Count": row["Label0"]
- })
- data_list.append({
- "YMDEYR_Val": row["YMDEYR"],
- "YMDERSUD5ANY_Val": row["YMDERSUD5ANY"],
- "Label": "Label=1",
- "Count": row["Label1"]
- })
- df_co = pd.DataFrame(data_list)
-
- fig_cooccur = px.bar(
- df_co,
- x="YMDEYR_Val",
- y="Count",
- color="Label",
- facet_col="YMDERSUD5ANY_Val", # separate subplots by second feature
- barmode="group",
- title="Co-Occurrence: YMDEYR & YMDERSUD5ANY vs. YOWRCONC",
- labels={"YMDEYR_Val": "YMDEYR", "YMDERSUD5ANY_Val": "YMDERSUD5ANY"}
- )
- fig_cooccur.update_layout(
- legend_title_text="Actual Label",
- xaxis_title="YMDEYR Categories",
- yaxis_title="Number of Patients"
- )
- else:
- fig_cooccur = px.bar(
- x=["Data Error"], y=[0],
- title="Could not generate co-occurrence chart: missing columns"
- )
-
- #------------------------------------------------------------------------------
- # RETURN / RENDER
- #------------------------------------------------------------------------------
- # We have 6 outputs total (the code is set up for that).
- # We'll map them as follows:
- # 1) "Prediction Results" (Textbox)
- # 2) "Mental Health Severity" (Textbox)
- # 3) A Markdown that combines: total_patients_text + cross_tab_explanation + similar_explanation + cooccurrence_explanation + the nearest-neighbors result
- # 4) Cross-Tab Bar Chart
- # 5) "Number of Patients with the Same Value for Each Input Feature"
- # 6) "Number of Patients with Predicted Labels"
-
- # (i) Provide text results for the user’s predictions
- # (ii) Provide severity
-
- # Build the big markdown text for (3)
- big_markdown = (
- total_patients_text
- + "\n\n"
- + cross_tab_explanation
- + "\n\n"
- + f"**Crosstab Example**: See the bar chart below comparing 'YMDEYR' vs. actual label 'YOWRCONC'.\n\n"
- + similar_explanation
- + "\n\n"
- + similar_text
- + "\n\n"
- + cooccurrence_explanation
- + "\n\n"
- + "See the final chart below for how 'YMDEYR' & 'YMDERSUD5ANY' co-occur with label 'YOWRCONC'."
- )
-
- # (F) Bar Chart for each input feature
- # We'll keep the logic for counting how many in df have the same value for each feature
  input_counts = {}
- for col, val_list in user_input_data.items():
- val = val_list[0]
  same_val_count = len(df[df[col] == val])
  input_counts[col] = same_val_count

  bar_input_data = pd.DataFrame({
  "Feature": list(input_counts.keys()),
  "Count": list(input_counts.values())
@@ -408,13 +275,14 @@ def predict(
  )
  fig_bar_input.update_layout(xaxis={'categoryorder':'total descending'})

- # (G) Bar Chart for predicted labels
- # We'll skip "matched vs total" or "exact matching."
  label_counts = {}
  for i, pred in enumerate(predictions):
  model_name = model_filenames[i].split('.')[0]
  pred_value = pred[0]
- if pred_value in [0, 1] and model_name in df.columns:
  label_counts[model_name] = len(df[df[model_name] == pred_value])

  if len(label_counts) > 0:
@@ -426,12 +294,12 @@ def predict(
  bar_label_data,
  x="Model",
  y="Count",
- title="Number of Patients with the Same Predicted Label by Model",
  labels={"Model": "Predicted Column", "Count": "Number of Patients"}
  )
  fig_bar_labels.update_layout(xaxis={'categoryorder':'total descending'})
  else:
- # fallback
  bar_label_data = pd.DataFrame({"Model": [], "Count": []})
  fig_bar_labels = px.bar(
  bar_label_data,
@@ -440,20 +308,128 @@ def predict(
  title="No valid predicted labels to display"
  )

- # Finally return the updated outputs
- return (
- prediction_summary_text, # (1) Prediction Results
- severity, # (2) Mental Health Severity
- big_markdown, # (3) Our large Markdown with headings & explanations
- fig_crosstab, # (4) Cross-Tab Bar Chart
- fig_bar_input, # (5) Input Feature Bar Chart
- fig_bar_labels # (6) Predicted Labels Bar Chart
  )

- ##############################################################################
- # INPUT MAPPING & GRADIO INTERFACE
- ##############################################################################

  input_mapping = {
  'YNURSMDE': {"Yes": 1, "No": 0},
  'YMDEYR': {"Yes": 1, "No": 2},
@@ -486,8 +462,23 @@ input_mapping = {
  'YMDELT': {"Yes": 1, "No": 2}
  }

  # Define the "inputs" in the same order used in the function signature
  inputs = [
  gr.Dropdown(list(input_mapping['YMDEYR'].keys()), label="YMDEYR: PAST YEARS MAJOR DEPRESSIVE EPISODE"),
  gr.Dropdown(list(input_mapping['YMDERSUD5ANY'].keys()), label="YMDERSUD5ANY: MDE OR SUBSTANCE USE DISORDER - ANY"),
  gr.Dropdown(list(input_mapping['YMDEIMAD5YR'].keys()), label="YMDEIMAD5YR: MDE WITH SEV. IMP + ALCOHOL USE DISORDER"),
@@ -501,6 +492,8 @@ inputs = [
  gr.Dropdown(list(input_mapping['YMIMI5YANY'].keys()), label="YMIMI5YANY: MDE W/ ILL DRUG USE DISORDER"),
  gr.Dropdown(list(input_mapping['YMIUD5YANY'].keys()), label="YMIUD5YANY: MDE + ILL DRUG USE DISORDER"),
  gr.Dropdown(list(input_mapping['YMDESUD5ANYO'].keys()), label="YMDESUD5ANYO: MDE vs. SUD vs. BOTH vs. NEITHER"),
  gr.Dropdown(list(input_mapping['YNURSMDE'].keys()), label="YNURSMDE: SAW/TALK TO NURSE/OT ABOUT MDE"),
  gr.Dropdown(list(input_mapping['YSOCMDE'].keys()), label="YSOCMDE: SAW/TALK TO SOCIAL WORKER ABOUT MDE"),
  gr.Dropdown(list(input_mapping['YCOUNMDE'].keys()), label="YCOUNMDE: SAW/TALK TO COUNSELOR ABOUT MDE"),
@@ -509,22 +502,28 @@ inputs = [
  gr.Dropdown(list(input_mapping['YHLTMDE'].keys()), label="YHLTMDE: SAW/TALK TO HEALTH PROFESSIONAL ABOUT MDE"),
  gr.Dropdown(list(input_mapping['YDOCMDE'].keys()), label="YDOCMDE: SAW/TALK TO GP/FAMILY MD ABOUT MDE"),
  gr.Dropdown(list(input_mapping['YTXMDEYR'].keys()), label="YTXMDEYR: SAW/TALK DOCTOR/HEALTH PROF FOR MDE"),
  gr.Dropdown(list(input_mapping['YUSUITHKYR'].keys()), label="YUSUITHKYR: SERIOUSLY THOUGHT ABOUT KILLING SELF"),
  gr.Dropdown(list(input_mapping['YUSUIPLNYR'].keys()), label="YUSUIPLNYR: MADE PLANS TO KILL SELF"),
  gr.Dropdown(list(input_mapping['YUSUITHK'].keys()), label="YUSUITHK: THINK ABOUT KILLING SELF (12 MONTHS)"),
  gr.Dropdown(list(input_mapping['YUSUIPLN'].keys()), label="YUSUIPLN: MADE PLANS TO KILL SELF (12 MONTHS)"),
  gr.Dropdown(list(input_mapping['MDEIMPY'].keys()), label="MDEIMPY: MDE W/ SEVERE ROLE IMPAIRMENT"),
  gr.Dropdown(list(input_mapping['LVLDIFMEM2'].keys()), label="LVLDIFMEM2: LEVEL OF DIFFICULTY REMEMBERING/CONCENTRATING"),
  gr.Dropdown(list(input_mapping['YMSUD5YANY'].keys()), label="YMSUD5YANY: MDE + SUBSTANCE USE DISORDER - ANY"),
  gr.Dropdown(list(input_mapping['YRXMDEYR'].keys()), label="YRXMDEYR: USED MEDS FOR MDE IN PAST YEAR"),
  ]

- # We have 6 outputs now:
  outputs = [
  gr.Textbox(label="Prediction Results", lines=30),
  gr.Textbox(label="Mental Health Severity", lines=4),
- gr.Markdown(), # Combined heading & explanations for cross-tab, similar patients, co-occurrence
- gr.Plot(label="Cross-Tab (Feature vs. Actual Label)"),
  gr.Plot(label="Number of Patients per Input Feature"),
  gr.Plot(label="Number of Patients with Predicted Labels")
  ]
@@ -545,10 +544,14 @@ def predict_with_text(
  YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
  ):
  return (
- "Please select all required fields.", # Pred result
  "Validation Error", # Severity
- "", # Markdown
- None, None, None # Plots
  )

  # Map from user-friendly text to int
@@ -587,7 +590,6 @@ def predict_with_text(
  # Pass our mapped values into the original 'predict' function
  return predict(**user_inputs)

-
  # Custom CSS (optional)
  custom_css = """
  .gradio-container * {
@@ -606,10 +608,7 @@ custom_css = """
  }
  """

- ##############################################################################
- # LAUNCH INTERFACE
- ##############################################################################
-
  interface = gr.Interface(
  fn=predict_with_text,
  inputs=inputs,
 
 
  import pickle
+
+ import gradio as gr
  import numpy as np
  import pandas as pd
  import plotly.express as px

  # Load the training CSV once (outside the functions so it is read only once).
  df = pd.read_csv("X_train_Y_Train_merged_train.csv")

  class ModelPredictor:
  def __init__(self, model_path, model_filenames):
  self.model_path = model_path
  self.model_filenames = model_filenames
  self.models = self.load_models()
+ # For readability, you might want to keep only a few keys here if you want
+ # to demonstrate partial cross-tabs, etc.
  self.prediction_map = {
  "YOWRCONC": ["Did not have difficulty concentrating", "Had difficulty concentrating"],
  "YOSEEDOC": ["Did not feel the need to see a doctor", "Felt the need to see a doctor"],
 
  return models

  def make_predictions(self, user_input):
+ """
+ Returns a list of numpy arrays, each array is [0] or [1].
+ The i-th array corresponds to the i-th model in self.models.
+ """
  predictions = []
  for model in self.models:
  pred = model.predict(user_input)
 
  def get_majority_vote(self, predictions):
  """
  Flatten all predictions from all models, combine them into a single array,
+ then find the majority class (0 or 1) across all of them.
  """
  combined_predictions = np.concatenate(predictions)
  majority_vote = np.bincount(combined_predictions).argmax()
  return majority_vote

+ # Based on Equal Interval and Percentage-Based Method
+ # Severe: 13 to 16 votes (upper 25%)
+ # Moderate: 9 to 12 votes (upper-middle 25%)
+ # Low: 5 to 8 votes (lower-middle 25%)
+ # Very Low: 0 to 4 votes (lower 25%)
  def evaluate_severity(self, majority_vote_count):
  if majority_vote_count >= 13:
  return "Mental health severity: Severe"
 
  else:
  return "Mental health severity: Very Low"

+ # List of model filenames
  model_filenames = [
  "YOWRCONC.pkl", "YOSEEDOC.pkl", "YO_MDEA5.pkl", "YOWRLSIN.pkl",
  "YODPPROB.pkl", "YOWRPROB.pkl", "YODPR2WK.pkl", "YOWRDEPR.pkl",
 
  model_path = "models/"
  predictor = ModelPredictor(model_path, model_filenames)

  def validate_inputs(*args):
  for arg in args:
+ if arg == '' or arg is None: # Assuming empty string or None as unselected
  return False
  return True

  def predict(
  YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
  YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
 
  YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
  YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
  ):
+ """
+ Core prediction function that:
+ 1) Predicts with each model
+ 2) Aggregates results
+ 3) Produces an overall 'severity'
+ 4) Returns detailed per-model predictions
+ 5) Returns bar charts about how many in the dataset share the same inputs/predicted labels
+ 6) ***Now includes custom sections for:
+ - Total patient count (markdown)
+ - Cross-tab & grouped bar chart
+ - Similar Patient (Nearest Neighbors)
+ - Co-occurrence plot
+ """
+
  # Prepare user_input dataframe for prediction
  user_input_data = {
  'YNURSMDE': [int(YNURSMDE)],
 
  }
  user_input = pd.DataFrame(user_input_data)

+ # -----------------------
+ # 1) Make predictions
+ # -----------------------
  predictions = predictor.make_predictions(user_input)
+
+ # 2) Calculate majority vote (0 or 1) across all models
  majority_vote = predictor.get_majority_vote(predictions)
+
+ # 3) Count how many 1's in all predictions combined
  majority_vote_count = sum([1 for pred in np.concatenate(predictions) if pred == 1])
+
  # 4) Evaluate severity
  severity = predictor.evaluate_severity(majority_vote_count)

+ # 5) Prepare detailed results for each model group
  results = {
  "Concentration_and_Decision_Making": [],
  "Sleep_and_Energy_Levels": [],
 
  "YODPR2WK", "YODSMMDE",
  "YOPB2WK"]
  }
+
  for i, pred in enumerate(predictions):
  model_name = model_filenames[i].split('.')[0] # e.g. 'YOWRCONC'
  pred_value = pred[0]
+ # Map the prediction value to a human-readable string
  if model_name in predictor.prediction_map and pred_value in [0, 1]:
  result_text = f"Model {model_name}: {predictor.prediction_map[model_name][pred_value]}"
+ elif model_name in predictor.prediction_map:
+ # Out of known range => "Unknown"
+ result_text = f"Model {model_name}: Unknown prediction value {pred_value}"
  else:
+ result_text = f"Model {model_name}: Unknown model"

+ # Append to the appropriate group
  found_group = False
  for group_name, group_models in prediction_groups.items():
  if model_name in group_models:
  results[group_name].append(result_text)
  found_group = True
  break
+ if not found_group:
+ # If model doesn't match any group, skip or store it in a catch-all
+ pass

+ # 6) Nicely format the results
  formatted_results = []
  for group, preds in results.items():
  if preds:
  formatted_results.append(f"Group {group.replace('_', ' ')}:")
  formatted_results.append("\n".join(preds))
+ formatted_results.append("\n")

+ formatted_results = "\n".join(formatted_results).strip()
+
+ if len(formatted_results) == 0:
+ formatted_results = "No predictions made. Please check your inputs."
+
+ # Heuristic: if too many unknown predictions, append note
+ num_unknown = len([
+ pred for group, preds in results.items()
+ for pred in preds if "Unknown prediction value" in pred or "Unknown model" in pred
+ ])
+ if num_unknown > len(model_filenames) / 2:
+ severity += " (Unknown prediction count is high. Please consult with a human.)"
+
+ # ------------------------
+ # ADDITIONAL FEATURES
+ # ------------------------
+
+ # A) Total Patient Count (instead of the old "Pie" chart)
  total_patients = len(df)
+ total_patient_count_markdown = (
  "### Total Patient Count\n"
+ f"There are **{total_patients}** total patients in the dataset.\n\n"
+ "This count can help you understand the overall dataset size. "
+ "All subsequent analyses are relative to these patients."
  )

+ # B) Analyze Each Input Feature
+ # For each feature in user_input, compute how many patients have that same value.
  input_counts = {}
+ for col in user_input_data.keys():
+ val = user_input_data[col][0]
  same_val_count = len(df[df[col] == val])
  input_counts[col] = same_val_count

+ # Plot: Bar Chart for each input feature
  bar_input_data = pd.DataFrame({
  "Feature": list(input_counts.keys()),
  "Count": list(input_counts.values())
 
  )
  fig_bar_input.update_layout(xaxis={'categoryorder':'total descending'})

+ # C) Analyze Predicted Labels
+ # For each model's predicted label (0 or 1), count how many patients in the CSV
+ # have that label. We skip unknown if pred_value not in [0, 1].
  label_counts = {}
  for i, pred in enumerate(predictions):
  model_name = model_filenames[i].split('.')[0]
  pred_value = pred[0]
+ if pred_value in [0, 1]:
  label_counts[model_name] = len(df[df[model_name] == pred_value])

  if len(label_counts) > 0:
 
  bar_label_data,
  x="Model",
  y="Count",
+ title="Number of Patients with the Predicted Label (0 or 1) by Model",
  labels={"Model": "Predicted Column", "Count": "Number of Patients"}
  )
  fig_bar_labels.update_layout(xaxis={'categoryorder':'total descending'})
  else:
+ # If everything was unknown, produce an empty figure or a fallback message
  bar_label_data = pd.DataFrame({"Model": [], "Count": []})
  fig_bar_labels = px.bar(
  bar_label_data,
 
  title="No valid predicted labels to display"
  )

+ # D) Cross-Tabulation & Grouped Bar Chart
+ # Example: Show how a single input feature (YMDEYR) relates to one actual label (YOWRCONC).
+ # For demonstration only — in practice you might do this for multiple features/labels.
+ # NOTE: If the columns don't exist in the dataset (some code merges them differently),
+ # you might adapt accordingly.
+ if "YMDEYR" in df.columns and "YOWRCONC" in df.columns:
+ cross_tab_data = df.groupby(["YMDEYR", "YOWRCONC"]).size().reset_index(name="count")
+ fig_cross_tab = px.bar(
+ cross_tab_data,
+ x="YMDEYR",
+ y="count",
+ color="YOWRCONC",
+ barmode="group",
+ title="Cross-Tab: YMDEYR vs YOWRCONC (Grouped Bar Chart)",
+ labels={"YMDEYR": "Feature: YMDEYR", "YOWRCONC": "Label: YOWRCONC"}
+ )
+ else:
+ # Provide a fallback message if columns not found
+ fig_cross_tab = px.bar(title="YMDEYR or YOWRCONC not found in dataset. Cross-tab not available.")
+
+ # E) Similar Patient (Nearest Neighbors) via simple Hamming distance
+ # We'll pick K=5 neighbors. Then see how many had label=0 vs label=1 for
+ # one example label: YOWRCONC.
+ # (You can adapt to do multiple labels, but that can get lengthy.)
+ def hamming_distance(row, user_row):
+ dist = 0
+ for c in user_row.index:
+ if row[c] != user_row[c]:
+ dist += 1
+ return dist
+
+ # Create a single row for easy iteration
+ user_series = user_input.iloc[0]
+
+ # We'll compute distance for all rows in df on the same features
+ # that were used in the user_input.
+ # NOTE: In real usage, confirm these columns exist in df.
+ # If df lacks them or is encoded differently, you'd adapt.
+ features_to_compare = list(user_input.columns)
+ # For Hamming, ensure we pick only the columns present in df
+ features_to_compare = [f for f in features_to_compare if f in df.columns]
+
+ # Build a DataFrame we can safely compare
+ subset_df = df[features_to_compare].copy()
+
+ # Calculate distances
+ distances = []
+ for idx, row in subset_df.iterrows():
+ d = 0
+ for col in features_to_compare:
+ if row[col] != user_series[col]:
+ d += 1
+ distances.append(d)
+
+ # Attach distances
+ df_with_dist = df.copy()
+ df_with_dist["distance"] = distances
+
+ # Sort by distance ascending, pick top K=5
+ K = 5
+ nearest_neighbors = df_with_dist.sort_values("distance", ascending=True).head(K)
+
+ # For demonstration, let's show how many had YOWRCONC=0 vs. 1
+ nn_label_0 = nn_label_1 = 0
+ if "YOWRCONC" in nearest_neighbors.columns:
+ nn_label_0 = len(nearest_neighbors[nearest_neighbors["YOWRCONC"] == 0])
+ nn_label_1 = len(nearest_neighbors[nearest_neighbors["YOWRCONC"] == 1])
+
+ # Summarize in markdown
+ similar_patient_markdown = (
+ "### Nearest Neighbors (Simple Hamming Distance)\n"
+ f"We searched for the top **{K}** patients in the dataset whose categorical features "
+ "most closely match your input (Hamming distance).\n\n"
+ "**For the label `YOWRCONC`** among these neighbors:\n"
+ f"- {nn_label_0} had label=0\n"
+ f"- {nn_label_1} had label=1\n\n"
+ "(This is a simple illustration. In real practice, you'd refine which columns to use, "
+ "how to encode them, and how many neighbors to consider.)"
  )

+ # F) Co-Occurrence Plot
+ # Example: How two features (YMDEYR, YMDERSUD5ANY) combine with label (YOWRCONC).
+ # We'll produce a multi-way distribution using facet_col.
+ if all(col in df.columns for col in ["YMDEYR", "YMDERSUD5ANY", "YOWRCONC"]):
+ co_occ_data = df.groupby(["YMDEYR", "YMDERSUD5ANY", "YOWRCONC"]).size().reset_index(name="count")
+ fig_co_occ = px.bar(
+ co_occ_data,
+ x="YMDEYR",
+ y="count",
+ color="YOWRCONC",
+ facet_col="YMDERSUD5ANY",
+ title="Co-Occurrence Plot: YMDEYR and YMDERSUD5ANY vs YOWRCONC"
+ )
+ else:
+ fig_co_occ = px.bar(title="Co-occurrence plot not available (columns not found).")
+
+ # ------------------------
+ # Return everything
+ # ------------------------
+ # We now have 8 items to return:
+ # 1) Prediction Results (Textbox)
+ # 2) Mental Health Severity (Textbox)
+ # 3) Total Patient Count (Markdown)
+ # 4) Cross-Tab & Grouped Bar Chart (Plot)
+ # 5) Nearest Neighbors Summary (Markdown)
+ # 6) Co-Occurrence Plot (Plot)
+ # 7) Bar Chart for input features (Plot)
+ # 8) Bar Chart for predicted labels (Plot)
+ return (
+ formatted_results,
+ severity,
+ total_patient_count_markdown,
+ fig_cross_tab,
+ similar_patient_markdown,
+ fig_co_occ,
+ fig_bar_input,
+ fig_bar_labels
+ )

+ # -----------------------------------------------------------------------------
+ # MAPPING user-friendly text => numeric values
+ # -----------------------------------------------------------------------------
  input_mapping = {
  'YNURSMDE': {"Yes": 1, "No": 0},
  'YMDEYR': {"Yes": 1, "No": 2},
 
  'YMDELT': {"Yes": 1, "No": 2}
  }

+ # -----------------------------------------------------------------------------
+ # Create the Gradio interface
+ # -----------------------------------------------------------------------------
+ # We have 8 outputs now:
+ # 1) Prediction Results (Textbox)
+ # 2) Mental Health Severity (Textbox)
+ # 3) Total Patient Count (Markdown)
+ # 4) Cross-Tab & Grouped Bar Chart (Plot)
+ # 5) Nearest Neighbors Summary (Markdown)
+ # 6) Co-Occurrence Plot (Plot)
+ # 7) Bar Chart for input features (Plot)
+ # 8) Bar Chart for predicted labels (Plot)
+
  # Define the "inputs" in the same order used in the function signature
  inputs = [
+ ################# Ordered and grouped ##########################
+ # Questions related to Major Depressive Episode (MDE) and related impairments or disorders
  gr.Dropdown(list(input_mapping['YMDEYR'].keys()), label="YMDEYR: PAST YEARS MAJOR DEPRESSIVE EPISODE"),
  gr.Dropdown(list(input_mapping['YMDERSUD5ANY'].keys()), label="YMDERSUD5ANY: MDE OR SUBSTANCE USE DISORDER - ANY"),
  gr.Dropdown(list(input_mapping['YMDEIMAD5YR'].keys()), label="YMDEIMAD5YR: MDE WITH SEV. IMP + ALCOHOL USE DISORDER"),

  gr.Dropdown(list(input_mapping['YMIMI5YANY'].keys()), label="YMIMI5YANY: MDE W/ ILL DRUG USE DISORDER"),
  gr.Dropdown(list(input_mapping['YMIUD5YANY'].keys()), label="YMIUD5YANY: MDE + ILL DRUG USE DISORDER"),
  gr.Dropdown(list(input_mapping['YMDESUD5ANYO'].keys()), label="YMDESUD5ANYO: MDE vs. SUD vs. BOTH vs. NEITHER"),
+
+ # Questions related to consultations with professionals about MDE
  gr.Dropdown(list(input_mapping['YNURSMDE'].keys()), label="YNURSMDE: SAW/TALK TO NURSE/OT ABOUT MDE"),
  gr.Dropdown(list(input_mapping['YSOCMDE'].keys()), label="YSOCMDE: SAW/TALK TO SOCIAL WORKER ABOUT MDE"),
  gr.Dropdown(list(input_mapping['YCOUNMDE'].keys()), label="YCOUNMDE: SAW/TALK TO COUNSELOR ABOUT MDE"),

  gr.Dropdown(list(input_mapping['YHLTMDE'].keys()), label="YHLTMDE: SAW/TALK TO HEALTH PROFESSIONAL ABOUT MDE"),
  gr.Dropdown(list(input_mapping['YDOCMDE'].keys()), label="YDOCMDE: SAW/TALK TO GP/FAMILY MD ABOUT MDE"),
  gr.Dropdown(list(input_mapping['YTXMDEYR'].keys()), label="YTXMDEYR: SAW/TALK DOCTOR/HEALTH PROF FOR MDE"),
+
+ # Questions related to suicidal thoughts and plans
  gr.Dropdown(list(input_mapping['YUSUITHKYR'].keys()), label="YUSUITHKYR: SERIOUSLY THOUGHT ABOUT KILLING SELF"),
  gr.Dropdown(list(input_mapping['YUSUIPLNYR'].keys()), label="YUSUIPLNYR: MADE PLANS TO KILL SELF"),
  gr.Dropdown(list(input_mapping['YUSUITHK'].keys()), label="YUSUITHK: THINK ABOUT KILLING SELF (12 MONTHS)"),
  gr.Dropdown(list(input_mapping['YUSUIPLN'].keys()), label="YUSUIPLN: MADE PLANS TO KILL SELF (12 MONTHS)"),
+
+ # Questions related to impairment due to MDE
  gr.Dropdown(list(input_mapping['MDEIMPY'].keys()), label="MDEIMPY: MDE W/ SEVERE ROLE IMPAIRMENT"),
  gr.Dropdown(list(input_mapping['LVLDIFMEM2'].keys()), label="LVLDIFMEM2: LEVEL OF DIFFICULTY REMEMBERING/CONCENTRATING"),
  gr.Dropdown(list(input_mapping['YMSUD5YANY'].keys()), label="YMSUD5YANY: MDE + SUBSTANCE USE DISORDER - ANY"),
  gr.Dropdown(list(input_mapping['YRXMDEYR'].keys()), label="YRXMDEYR: USED MEDS FOR MDE IN PAST YEAR"),
  ]

+ # We now have 8 outputs in total:
  outputs = [
  gr.Textbox(label="Prediction Results", lines=30),
  gr.Textbox(label="Mental Health Severity", lines=4),
+ gr.Markdown(label="Total Patient Count"),
+ gr.Plot(label="Cross-Tab & Grouped Bar Chart"),
+ gr.Markdown(label="Nearest Neighbors Summary"),
+ gr.Plot(label="Co-Occurrence Plot"),
  gr.Plot(label="Number of Patients per Input Feature"),
  gr.Plot(label="Number of Patients with Predicted Labels")
  ]
 
  YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
  ):
  return (
+ "Please select all required fields.", # Prediction Results
  "Validation Error", # Severity
+ "No data", # Total Patient Count
+ None, # Cross-Tab figure
+ "No data", # Nearest Neighbors
+ None, # Co-Occurrence
+ None, # Input Features Bar
+ None # Predicted Labels Bar
  )

  # Map from user-friendly text to int
 
  # Pass our mapped values into the original 'predict' function
  return predict(**user_inputs)

  # Custom CSS (optional)
  custom_css = """
  .gradio-container * {
 
  }
  """

+ # Finally, launch the app with 8 outputs
  interface = gr.Interface(
  fn=predict_with_text,
  inputs=inputs,
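
For reference, below is a minimal, self-contained sketch (not part of the commit) of how the pieces added in this diff fit together: the per-model [0]/[1] predictions, the positive-vote count that evaluate_severity buckets (13+, 9+, 5+, else), and the Hamming-distance "similar patient" lookup. DummyModel and the toy DataFrames are hypothetical stand-ins for the pickled models in models/ and the training CSV; they are only meant to illustrate the flow.

import numpy as np
import pandas as pd

class DummyModel:
    """Hypothetical stand-in for one of the 16 pickled binary classifiers in models/."""
    def __init__(self, label):
        self.label = label

    def predict(self, X):
        # Each real model returns an array like [0] or [1] for a one-row input.
        return np.array([self.label])

models = [DummyModel(label=i % 2) for i in range(16)]            # toy ensemble
user_input = pd.DataFrame({"YMDEYR": [1], "YMDERSUD5ANY": [2]})  # toy one-row input

# 1) One [0]/[1] prediction per model, as in ModelPredictor.make_predictions
predictions = [m.predict(user_input) for m in models]

# 2) Count positive votes and bucket them, mirroring evaluate_severity
positive_votes = int(np.concatenate(predictions).sum())
if positive_votes >= 13:
    severity = "Severe"
elif positive_votes >= 9:
    severity = "Moderate"
elif positive_votes >= 5:
    severity = "Low"
else:
    severity = "Very Low"
print(positive_votes, severity)  # prints "8 Low" for this toy ensemble

# 3) Hamming-distance nearest neighbours over the shared categorical columns,
#    as in the "Similar Patient" section (toy df stands in for the training CSV)
df = pd.DataFrame({"YMDEYR": [1, 2, 1, 2], "YMDERSUD5ANY": [2, 2, 1, 1]})
shared = [c for c in user_input.columns if c in df.columns]
distances = (df[shared] != user_input.iloc[0][shared]).sum(axis=1)
nearest = df.assign(distance=distances).nsmallest(2, "distance")
print(nearest)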