File size: 25,917 Bytes
25a63b5
d1b265f
f9a8132
 
f9f7e1c
 
3b96ce2
16ca108
3b96ce2
ebac442
d1b265f
f9a8132
ebac442
 
 
 
 
3b96ce2
 
 
 
 
 
 
 
e84fe7d
 
 
f9a8132
 
 
 
 
c458985
6b501f6
230ba50
2e504c0
 
6b501f6
2e504c0
975c60a
 
2e504c0
f92effe
975c60a
2e504c0
975c60a
2e504c0
975c60a
 
2e504c0
87dd6c1
42d91b3
f9a8132
 
16ca108
 
ebac442
 
 
 
 
 
 
 
16ca108
 
 
975c60a
6b501f6
 
 
975c60a
2e504c0
 
3b96ce2
2e504c0
 
 
 
6b501f6
2e504c0
 
 
 
1fd21ae
c458985
975c60a
6b501f6
975c60a
c458985
3b96ce2
c458985
3b96ce2
c458985
3b96ce2
f9a8132
3b96ce2
f9a8132
16ca108
6749d1f
cf4c3a5
6b501f6
c458985
e9e83fc
 
ebac442
e9e83fc
 
 
975c60a
 
e9e83fc
 
 
 
 
 
 
 
c458985
e9e83fc
ebac442
 
 
 
 
 
6b501f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ebac442
 
 
 
 
6b501f6
 
 
 
 
e9e83fc
c458985
c38370d
6c60301
ebac442
6c60301
 
437bcb0
e9e83fc
2e504c0
e9e83fc
 
 
 
 
 
 
 
 
 
 
 
2e504c0
 
 
 
 
 
6b501f6
2e504c0
6b501f6
2e504c0
 
 
 
 
 
 
6b501f6
2e504c0
 
 
 
6b501f6
 
2e504c0
 
 
 
 
 
 
 
 
 
 
 
 
6b501f6
2e504c0
e9e83fc
6749d1f
cf4c3a5
e9e83fc
cf4c3a5
f9f7e1c
6b501f6
ebac442
e9e83fc
6b501f6
975c60a
 
6b501f6
e9e83fc
6b501f6
e9e83fc
f9f7e1c
2e504c0
6749d1f
ebac442
e9e83fc
975c60a
 
e9e83fc
 
6749d1f
 
ebac442
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e504c0
 
 
 
6749d1f
 
16ca108
f9f7e1c
2e504c0
ebac442
 
 
 
 
 
 
 
 
 
 
2e504c0
6b501f6
 
2e504c0
16ca108
3b96ce2
6b501f6
 
 
 
 
 
 
 
 
 
 
ebac442
87dd6c1
ebac442
87dd6c1
e84fe7d
ebac442
 
e84fe7d
 
46e9809
975c60a
16ca108
6b501f6
 
2e504c0
 
 
 
 
 
 
6b501f6
2e504c0
 
 
ebac442
2e504c0
6b501f6
 
ebac442
6b501f6
ebac442
6b501f6
2e504c0
16ca108
2e504c0
 
16ca108
2e504c0
6b501f6
c458985
2e504c0
c458985
16ca108
2e504c0
16ca108
 
e84fe7d
16ca108
ebac442
 
 
 
c458985
 
 
 
d79221a
16ca108
0b36f6e
642143a
2e504c0
 
642143a
 
0b36f6e
642143a
 
 
0b36f6e
642143a
 
 
 
 
 
 
 
 
 
0b36f6e
642143a
 
e84fe7d
642143a
 
 
 
 
 
e84fe7d
685722d
cf4c3a5
16ca108
d79221a
16ca108
0b36f6e
 
 
 
 
 
 
 
f9f7e1c
cf4c3a5
2e504c0
cf4c3a5
2e504c0
 
6b501f6
2e504c0
6b501f6
f92effe
 
2e504c0
 
6b501f6
2e504c0
f92effe
 
 
 
 
 
685722d
f92effe
 
 
685722d
f92effe
 
 
 
 
 
2e504c0
 
f92effe
2e504c0
f92effe
ebac442
f92effe
 
 
ebac442
f92effe
ebac442
2e504c0
 
 
 
6b501f6
f92effe
2e504c0
f92effe
2e504c0
f92effe
 
 
ebac442
f92effe
 
2e504c0
 
 
cf4c3a5
2e504c0
 
d1b265f
cf4c3a5
c458985
cf4c3a5
e84fe7d
 
ebac442
3b96ce2
6b501f6
e9e83fc
ebac442
975c60a
e9e83fc
ebac442
e9e83fc
 
975c60a
e9e83fc
 
685722d
 
e9e83fc
 
 
2e504c0
ebac442
 
 
 
e9e83fc
 
ebac442
e9e83fc
975c60a
2e504c0
975c60a
6b501f6
 
975c60a
6b501f6
2e504c0
975c60a
2e504c0
 
 
975c60a
e9e83fc
 
2e504c0
ebac442
 
 
 
e9e83fc
 
ebac442
e9e83fc
975c60a
 
 
 
e9e83fc
 
2e504c0
ebac442
 
 
 
e9e83fc
 
ebac442
e9e83fc
975c60a
2e504c0
975c60a
 
 
 
e9e83fc
 
2e504c0
ebac442
 
 
 
e9e83fc
 
ebac442
e9e83fc
 
ebac442
6749d1f
6b501f6
2e504c0
c458985
 
2e504c0
c458985
685722d
3b96ce2
ebac442
3b96ce2
 
e9e83fc
3b96ce2
6b501f6
 
 
 
 
 
3b96ce2
33c8e0b
1fd21ae
2e504c0
 
6b501f6
f92effe
685722d
726e8be
f92effe
4848b3d
f92effe
2e504c0
 
 
c458985
6b501f6
 
 
 
f92effe
2e504c0
 
f92effe
2e504c0
 
 
 
3b96ce2
67c356a
ebac442
87dd6c1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
import os
import pickle

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px

######################################
# 1) LOAD DATA & MODELS
######################################
# Load your dataset
df = pd.read_csv("X_train_test_combined_dataset_Filtered_dataset.csv")

# Abort start-up early when the diagnosis column is absent from the CSV.
if 'YMDESUD5ANYO' not in df.columns:
    raise ValueError(
        "The column 'YMDESUD5ANYO' is missing from the dataset. Please check your CSV file."
    )

# One pickled model per label; the file stem is the label/column name.
model_filenames = [
    f"{label}.pkl"
    for label in (
        "YOWRCONC", "YOSEEDOC", "YO_MDEA5", "YOWRLSIN",
        "YODPPROB", "YOWRPROB", "YODPR2WK", "YOWRDEPR",
        "YODPDISC", "YOLOSEV", "YOWRDCSN", "YODSMMDE",
        "YO_MDEA3", "YODPLSIN", "YOWRELES", "YOPB2WK",
    )
]
# Directory (relative to the working directory) that holds the model files.
model_path = "models/"

######################################
# 2) MODEL PREDICTOR
######################################
class ModelPredictor:
    """Load one pickled model per label and run all of them on the same input.

    Attributes:
        model_path: Directory that contains the pickled model files.
        model_filenames: File names to load, in the order the models are run.
        models: Unpickled model objects, parallel to ``model_filenames``.
        prediction_map: Label name -> ``[text for 0, text for 1]``.
    """

    def __init__(self, model_path, model_filenames):
        """Remember the model location and eagerly load every model.

        Raises:
            FileNotFoundError: If a listed model file does not exist.
            RuntimeError: If a model file exists but cannot be unpickled.
        """
        self.model_path = model_path
        self.model_filenames = model_filenames
        self.models = self.load_models()

        # Mapping each label (column) to textual meaning for 0/1
        self.prediction_map = {
            "YOWRCONC": ["Did NOT have difficulty concentrating", "Had difficulty concentrating"],
            "YOSEEDOC": ["Did NOT feel the need to see a doctor",  "Felt the need to see a doctor"],
            "YO_MDEA5": ["No restlessness/lethargy noticed",       "Others noticed restlessness/lethargy"],
            "YOWRLSIN": ["Did NOT feel bored/lose interest",       "Felt bored/lost interest"],
            "YODPPROB": ["No other problems for 2+ weeks",         "Had other problems for 2+ weeks"],
            "YOWRPROB": ["No 'worst time ever' feeling",           "Had 'worst time ever' feeling"],
            "YODPR2WK": ["No depressed feelings for 2+ wks",       "Had depressed feelings for 2+ wks"],
            "YOWRDEPR": ["Did NOT feel sad/depressed daily",       "Felt sad/depressed mostly everyday"],
            "YODPDISC": ["Overall mood not sad/depressed",         "Overall mood was sad/depressed"],
            "YOLOSEV":  ["Did NOT lose interest in things",        "Lost interest in enjoyable things"],
            "YOWRDCSN": ["Was able to make decisions",             "Was unable to make decisions"],
            "YODSMMDE": ["No 2+ wks depression symptoms",          "Had 2+ wks depression symptoms"],
            "YO_MDEA3": ["No appetite/weight changes",             "Had changes in appetite/weight"],
            "YODPLSIN": ["Never lost interest/felt bored",         "Lost interest/felt bored"],
            "YOWRELES": ["Did NOT eat less than usual",            "Ate less than usual"],
            "YOPB2WK":  ["No uneasy feelings 2+ weeks",            "Uneasy feelings 2+ weeks"]
        }

    def load_models(self):
        """Unpickle every file in ``model_filenames`` from ``model_path``.

        Returns:
            The loaded model objects, in ``model_filenames`` order.

        Raises:
            FileNotFoundError: If a model file is missing.
            RuntimeError: If a model file cannot be read or unpickled.
        """
        loaded = []
        for fname in self.model_filenames:
            # Join instead of concatenating so that a directory given without
            # a trailing separator ("models") resolves like "models/".
            full_path = os.path.join(self.model_path, fname)
            try:
                with open(full_path, "rb") as f:
                    # NOTE: unpickling executes code embedded in the file;
                    # only load model files from a trusted source.
                    model = pickle.load(f)
            except FileNotFoundError as err:
                raise FileNotFoundError(
                    f"Model file '{fname}' not found in path '{self.model_path}'."
                ) from err
            except Exception as err:
                raise RuntimeError(f"Error loading model '{fname}': {err}") from err
            loaded.append(model)
        return loaded

    def make_predictions(self, user_input: pd.DataFrame):
        """Run every loaded model on ``user_input``.

        Returns:
            A ``(preds, probs)`` pair of lists, one entry per model:
              - ``preds[i]``: flattened output of ``model.predict``.
              - ``probs[i]``: column 1 of ``model.predict_proba`` (taken as
                the probability of label 1), or an all-NaN array when the
                model has no ``predict_proba``.
        """
        preds = []
        probs = []
        for model in self.models:
            y_pred = model.predict(user_input)
            preds.append(y_pred.flatten())

            if hasattr(model, "predict_proba"):
                y_prob = model.predict_proba(user_input)[:, 1]  # Probability that label=1
                probs.append(y_prob)
            else:
                probs.append(np.full(len(user_input), np.nan))
        return preds, probs

    def evaluate_severity(self, count_ones: int) -> str:
        """Map the number of positive (``1``) predictions to a severity message.

        Thresholds: 13+ Severe, 9-12 Moderate, 5-8 Low, otherwise Very Low.
        """
        if count_ones >= 13:
            return "Mental Health Severity: Severe"
        elif count_ones >= 9:
            return "Mental Health Severity: Moderate"
        elif count_ones >= 5:
            return "Mental Health Severity: Low"
        else:
            return "Mental Health Severity: Very Low"

# Shared predictor used by the callbacks below. Constructing it loads every
# pickled model from disk, so a missing or unreadable model file raises here
# at start-up.
predictor = ModelPredictor(model_path, model_filenames)

######################################
# 3) FEATURE CATEGORIES + MAPPING
######################################
# Feature columns grouped under the four section titles.
categories_dict = {
    "1. Depression & Substance Use Diagnosis": [
        "YMDESUD5ANYO",
        "YMDELT",
        "YMDEYR",
        "YMDERSUD5ANY",
        "YMSUD5YANY",
        "YMIUD5YANY",
        "YMIMS5YANY",
        "YMIMI5YANY",
    ],
    "2. Mental Health Treatment & Prof Consultation": [
        "YMDEHPO",
        "YMDETXRX",
        "YMDEHARX",
        "YMDEHPRX",
        "YRXMDEYR",
        "YHLTMDE",
        "YTXMDEYR",
        "YDOCMDE",
        "YPSY2MDE",
        "YPSY1MDE",
        "YCOUNMDE",
    ],
    "3. Functional & Cognitive Impairment": [
        "MDEIMPY",
        "LVLDIFMEM2",
    ],
    "4. Suicidal Thoughts & Behaviors": [
        "YUSUITHK",
        "YUSUITHKYR",
        "YUSUIPLNYR",
        "YUSUIPLN",
    ],
}

# Feature column -> {answer text: numeric code}. Every column gets its own
# dict object; insertion order of the answers is the order they are offered.
input_mapping = {
    'YMDESUD5ANYO': {
        "SUD only, no MDE": 1,
        "MDE only, no SUD": 2,
        "SUD and MDE": 3,
        "Neither SUD or MDE": 4,
    },
    # Yes/No columns where "No" is coded 2.
    **{col: {"Yes": 1, "No": 2} for col in ('YMDELT', 'YMDEYR')},
    # Yes/No columns where "No" is coded 0.
    **{
        col: {"Yes": 1, "No": 0}
        for col in (
            'YMDERSUD5ANY', 'YMSUD5YANY', 'YMIUD5YANY', 'YMIMS5YANY', 'YMIMI5YANY',
            'YMDEHPO', 'YMDETXRX', 'YMDEHARX', 'YMDEHPRX', 'YRXMDEYR', 'YHLTMDE',
            'YTXMDEYR', 'YDOCMDE', 'YPSY2MDE', 'YPSY1MDE', 'YCOUNMDE',
        )
    },
    'MDEIMPY': {"Yes": 1, "No": 2},
    'LVLDIFMEM2': {
        "No Difficulty": 1,
        "Some difficulty": 2,
        "A lot of difficulty or cannot do at all": 3,
    },
    # Four-way answers: Yes / No / unsure / declined.
    **{
        col: {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4}
        for col in ('YUSUITHK', 'YUSUITHKYR', 'YUSUIPLNYR', 'YUSUIPLN')
    },
}

def validate_inputs(*args):
    """Return True when every positional value has been filled in.

    A value counts as missing when it is ``None`` or equal to the empty
    string; any missing value makes the whole check fail.
    """
    has_blank = any(value is None or value == "" for value in args)
    return not has_blank

######################################
# 4) NEAREST NEIGHBORS
######################################
def get_nearest_neighbors_info(user_input_df: pd.DataFrame, k=5):
    """Describe the ``k`` dataset rows closest to the user's answers.

    Distance is the Euclidean distance between the first row of
    ``user_input_df`` and every row of the module-level ``df``, computed over
    the columns of ``user_input_df``.

    Args:
        user_input_df: Frame whose first row holds the numeric answer codes.
        k: Number of neighbours to report.

    Returns:
        A Markdown string with the distance range, the user's answers
        translated back to text through ``input_mapping``, and the label
        value counts among the neighbours. If any input column is absent from
        ``df``, a short explanatory message is returned instead.
    """
    user_cols = list(user_input_df.columns)
    if not all(col in df.columns for col in user_cols):
        return "Cannot compute nearest neighbors. Some columns not found in df."

    user_row = user_input_df.iloc[0]
    # The subtraction already yields a new frame, so no defensive copy is needed.
    diffs = df[user_cols] - user_row
    dists = (diffs ** 2).sum(axis=1) ** 0.5
    nn_indices = dists.nsmallest(k).index
    neighbors = df.loc[nn_indices]

    lines = [
        f"**Nearest Neighbors (k={k})**",
        f"Distances range: {dists[nn_indices].min():.2f} to {dists[nn_indices].max():.2f}",
        ""
    ]

    # A) Show user input in numeric->text form
    lines.append("**User Input (numeric -> text)**")
    for col in user_cols:
        val_numeric = user_row[col]
        # First answer text whose code equals the user's value, if any.
        text_val = next(
            (
                txt_key
                for txt_key, num_val in input_mapping.get(col, {}).items()
                if val_numeric == num_val
            ),
            None,
        )
        if not text_val:
            text_val = f"{val_numeric} (no mapping found)"
        lines.append(f"- {col} = {val_numeric} => '{text_val}'")
    lines.append("")

    # B) Show label columns among neighbors
    lines.append("**Label Distribution Among Neighbors**")
    for lbl, texts in predictor.prediction_map.items():
        if lbl not in neighbors.columns:
            continue
        parts = []
        for val_, count_ in neighbors[lbl].value_counts().items():
            if val_ in (0, 1):
                # Float label columns (e.g. ones containing NaN) produce
                # 0.0/1.0 keys, which cannot index a list directly.
                parts.append(f"{count_} had '{texts[int(val_)]}'")
            else:
                parts.append(f"{count_} had numeric={val_}")
        lines.append(f"- {lbl}: " + "; ".join(parts))

    lines.append("")
    return "\n".join(lines)

######################################
# 5) PREDICT FUNCTION
######################################
# Column names for the 25 answers, in the same order as predict()'s positional
# parameters; this is also the column order of the frame handed to the models.
_PREDICT_FEATURES = (
    # Category 1 (8):
    "YMDESUD5ANYO", "YMDELT", "YMDEYR", "YMDERSUD5ANY",
    "YMSUD5YANY", "YMIUD5YANY", "YMIMS5YANY", "YMIMI5YANY",
    # Category 2 (11):
    "YMDEHPO", "YMDETXRX", "YMDEHARX", "YMDEHPRX", "YRXMDEYR",
    "YHLTMDE", "YTXMDEYR", "YDOCMDE", "YPSY2MDE", "YPSY1MDE", "YCOUNMDE",
    # Category 3 (2):
    "MDEIMPY", "LVLDIFMEM2",
    # Category 4 (4):
    "YUSUITHK", "YUSUITHKYR", "YUSUIPLNYR", "YUSUIPLN",
)

# Display groups for the per-label summary. Each label appears in one group
# only (YOWRELES, "ate less than usual", is kept under appetite) so no
# prediction is printed twice.
_DOMAIN_GROUPS = {
    "Concentration and Decision Making": ("YOWRCONC", "YOWRDCSN"),
    "Sleep and Energy Levels": ("YO_MDEA5",),
    "Mood and Emotional State": (
        "YOWRLSIN", "YOWRDEPR", "YODPDISC", "YOLOSEV", "YODPLSIN"
    ),
    "Appetite and Weight Changes": ("YO_MDEA3", "YOWRELES"),
    "Duration and Severity of Depression Symptoms": (
        "YODPPROB", "YOWRPROB", "YODPR2WK", "YODSMMDE", "YOPB2WK"
    ),
}

# Heading for predicted labels that belong to none of the groups above.
_UNGROUPED_TITLE = "Other Indicators"


def _error_outputs(message, status):
    """Return the six-slot output tuple used when predict() cannot finish."""
    return (message, status, "No data", "No nearest neighbors info", None, None)


def _summarize_predictions(label_prediction_info):
    """Format every ``label -> (prediction, probability)`` pair, grouped by domain."""
    groups = dict(_DOMAIN_GROUPS)
    assigned = {lbl for lbls in groups.values() for lbl in lbls}
    # Labels without a group (e.g. YOSEEDOC) still count toward severity, so
    # they are listed under a catch-all heading instead of being dropped.
    leftovers = tuple(lbl for lbl in label_prediction_info if lbl not in assigned)
    if leftovers:
        groups[_UNGROUPED_TITLE] = leftovers

    parts = []
    for gname, lbls in groups.items():
        group_lines = []
        for lbl in lbls:
            if lbl not in label_prediction_info:
                continue
            pred_val, prob_val = label_prediction_info[lbl]
            texts = predictor.prediction_map.get(lbl)
            if texts is not None and pred_val in (0, 1):
                # int() so a float prediction (1.0) can index the text list.
                text_pred = texts[int(pred_val)]
            else:
                text_pred = f"Prediction={pred_val}"

            if np.isnan(prob_val):
                text_prob = "(No probability available)"
            else:
                text_prob = f"(Prob= {prob_val:.2f})"

            group_lines.append(f"{lbl} => {text_pred} {text_prob}")
        if group_lines:
            # Heading, the group's lines, then an empty line for spacing.
            parts.extend([f"**{gname}**", "\n".join(group_lines), ""])

    if not parts:
        return "No predictions made or no matching group columns."
    return "\n".join(parts)


def _input_feature_figure(user_input_dict):
    """Bar chart: per input feature, how many dataset rows share the user's value."""
    counts = {
        # A feature missing from the dataset is reported as 0 matches
        # rather than raising KeyError.
        col: int((df[col] == val_).sum()) if col in df.columns else 0
        for col, val_ in user_input_dict.items()
    }
    bar_in_df = pd.DataFrame({
        "Feature": list(counts.keys()),
        "Count": list(counts.values())
    })
    fig_in = px.bar(
        bar_in_df, x="Feature", y="Count",
        title="Number of Patients with the Same Input Feature Values"
    )
    fig_in.update_layout(width=1200, height=400)
    return fig_in


def _label_figure(label_prediction_info):
    """Grouped bar chart: dataset rows with the predicted vs. the opposite label value."""
    rows = []
    for lbl_col, (pred_val, _) in label_prediction_info.items():
        if lbl_col not in df.columns:
            continue
        # The "other" class of a binary label (0 <-> 1).
        other_val = 1 - pred_val
        rows.append({
            "Label": lbl_col,
            "Class": f"Predicted_{pred_val}",
            "Count": int((df[lbl_col] == pred_val).sum())
        })
        rows.append({
            "Label": lbl_col,
            "Class": f"Opposite_{other_val}",
            "Count": int((df[lbl_col] == other_val).sum())
        })

    if rows:
        fig_lbl = px.bar(
            pd.DataFrame(rows),
            x="Label",
            y="Count",
            color="Class",
            barmode="group",
            title="Number of Patients with the Predicted vs. Opposite Label"
        )
    else:
        fig_lbl = px.bar(title="No valid predicted labels to display.")
    fig_lbl.update_layout(width=1200, height=400)
    return fig_lbl


def predict(
    # Category 1 (8):
    YMDESUD5ANYO, YMDELT, YMDEYR, YMDERSUD5ANY,
    YMSUD5YANY, YMIUD5YANY, YMIMS5YANY, YMIMI5YANY,
    # Category 2 (11):
    YMDEHPO, YMDETXRX, YMDEHARX, YMDEHPRX, YRXMDEYR,
    YHLTMDE, YTXMDEYR, YDOCMDE, YPSY2MDE, YPSY1MDE, YCOUNMDE,
    # Category 3 (2):
    MDEIMPY, LVLDIFMEM2,
    # Category 4 (4):
    YUSUITHK, YUSUITHKYR, YUSUIPLNYR, YUSUIPLN
):
    """Run all models on one set of answers and build the six UI outputs.

    Each argument is the answer text chosen for the feature of the same name;
    it is converted to its numeric code through ``input_mapping``.

    Returns:
        A 6-tuple:
          1. Prediction results text (grouped by domain, with probabilities).
          2. Severity message derived from the number of positive predictions.
          3. Markdown with the total number of rows in the dataset.
          4. Markdown nearest-neighbour summary.
          5. Bar chart of dataset rows matching each input value.
          6. Bar chart of dataset rows with the predicted vs. opposite label.
        On a validation, mapping or prediction failure the first two slots
        carry the error text and status, and both charts are ``None``.
    """
    raw_answers = (
        YMDESUD5ANYO, YMDELT, YMDEYR, YMDERSUD5ANY,
        YMSUD5YANY, YMIUD5YANY, YMIMS5YANY, YMIMI5YANY,
        YMDEHPO, YMDETXRX, YMDEHARX, YMDEHPRX, YRXMDEYR,
        YHLTMDE, YTXMDEYR, YDOCMDE, YPSY2MDE, YPSY1MDE, YCOUNMDE,
        MDEIMPY, LVLDIFMEM2,
        YUSUITHK, YUSUITHKYR, YUSUIPLNYR, YUSUIPLN,
    )

    # 1) Validate
    if not validate_inputs(*raw_answers):
        return _error_outputs("Please select all required fields.", "Validation Error")

    # 2) Convert text -> numeric
    try:
        user_input_dict = {
            col: input_mapping[col][answer]
            for col, answer in zip(_PREDICT_FEATURES, raw_answers)
        }
    except KeyError as e:
        missing_key = e.args[0]
        return _error_outputs(
            f"Input mapping missing for key: {missing_key}. Please check your `input_mapping` dictionary.",
            "Mapping Error",
        )

    user_df = pd.DataFrame(user_input_dict, index=[0])

    # 3) Make predictions
    try:
        preds, probs = predictor.make_predictions(user_df)
    except Exception as e:
        # UI boundary: surface any model failure as text instead of crashing.
        return _error_outputs(f"Error during prediction: {e}", "Prediction Error")

    # Severity is based on the number of positive (1) predictions.
    # np.concatenate rejects an empty list, so guard the no-model case.
    all_preds = np.concatenate(preds) if preds else np.array([])
    count_ones = int(np.sum(all_preds == 1))
    severity_msg = predictor.evaluate_severity(count_ones)

    # 4) label -> (first-row prediction, first-row probability); the label
    # name is the model file name without its extension.
    label_prediction_info = {
        fname.split('.')[0]: (pred[0], prob[0])
        for fname, pred, prob in zip(model_filenames, preds, probs)
    }
    final_str = _summarize_predictions(label_prediction_info)

    # 5) Additional info
    total_count_md = f"We have **{len(df)}** patients in the dataset."

    # 6) Nearest Neighbors
    nn_md = get_nearest_neighbors_info(user_df, k=5)

    # 7) + 8) Bar charts
    fig_in = _input_feature_figure(user_input_dict)
    fig_lbl = _label_figure(label_prediction_info)

    return (
        final_str,         # 1) Prediction Results
        severity_msg,      # 2) Mental Health Severity
        total_count_md,    # 3) Total Patient Count
        nn_md,             # 4) Nearest Neighbors Summary
        fig_in,            # 5) Bar Chart (input features)
        fig_lbl            # 6) Bar Chart (labels)
    )

######################################
# 6) UNIFIED DISTRIBUTION/CO-OCCURRENCE
######################################
def combined_plot(feature_list, label_col):
    """Plot how one or two input features relate to a label column.

    One feature gives a distribution bar chart coloured by the label; two
    features give the same chart faceted by the second feature. Numeric
    codes are shown as text using ``input_mapping`` (features) and
    ``predictor.prediction_map`` (label); values without a text mapping are
    left unchanged.

    Args:
        feature_list: Selected feature column names; ``None`` is treated as
            an empty selection.
        label_col: Selected label column name.

    Returns:
        A Plotly figure. When the selection is invalid, an empty bar chart
        whose title explains the problem.
    """
    if not label_col:
        return px.bar(title="Please select a label column.")

    # Treat "nothing selected" (None) like an empty list instead of
    # failing on len(None).
    features = list(feature_list or [])
    if len(features) not in (1, 2):
        return px.bar(title="Please select exactly 1 or 2 features.")

    selected = features + [label_col]
    if any(col not in df.columns for col in selected):
        return px.bar(title="Selected columns not found in the dataset.")

    # Copy only the columns that are plotted rather than the whole dataset;
    # dict.fromkeys removes duplicate names while keeping their order.
    plot_df = df[list(dict.fromkeys(selected))].copy()

    # A) Convert numeric codes -> text for the selected columns
    for col in selected:
        text_to_num_dict = input_mapping.get(col)
        if text_to_num_dict:
            # Reverse mapping: "Yes"->1 becomes 1->"Yes"
            num_to_text = {v: k for k, v in text_to_num_dict.items()}
            plot_df[col] = plot_df[col].map(num_to_text).fillna(plot_df[col])

    # B) Convert label 0/1 to text when the label has a text mapping
    if label_col in predictor.prediction_map:
        zero_text, one_text = predictor.prediction_map[label_col]
        label_map = {0: zero_text, 1: one_text}
        plot_df[label_col] = plot_df[label_col].map(label_map).fillna(plot_df[label_col])

    grouped = plot_df.groupby(selected).size().reset_index(name="count")

    if len(features) == 1:
        f_ = features[0]
        fig = px.bar(
            grouped,
            x=f_,
            y="count",
            color=label_col,
            title=f"Distribution of {f_} vs {label_col} (Text Mapped)"
        )
    else:
        f1, f2 = features
        fig = px.bar(
            grouped,
            x=f1,
            y="count",
            color=label_col,
            facet_col=f2,
            title=f"Co-occurrence: {f1}, {f2} vs {label_col} (Text Mapped)"
        )

    fig.update_layout(width=1200, height=600)
    return fig

######################################
# 7) BUILD GRADIO UI
######################################
with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:

    # ---------------- Tab 1: prediction form ----------------
    with gr.Tab("Prediction"):
        gr.Markdown("### Please provide inputs in each of the four categories below. All fields are required.")

        # Section 1 - Depression & Substance Use Diagnosis (8 dropdowns)
        gr.Markdown("#### 1. Depression & Substance Use Diagnosis")
        cat1_col_labels = [
            ("YMDESUD5ANYO", "YMDESUD5ANYO: ONLY MDE, ONLY SUD, BOTH, OR NEITHER"),
            ("YMDELT",       "YMDELT: Had major depressive episode in lifetime"),
            ("YMDEYR",       "YMDEYR: Past-year major depressive episode"),
            ("YMDERSUD5ANY", "YMDERSUD5ANY: MDE or SUD in past year?"),
            ("YMSUD5YANY",   "YMSUD5YANY: Past-year MDE & substance use disorder"),
            ("YMIUD5YANY",   "YMIUD5YANY: Past-year MDE & illicit drug use disorder"),
            ("YMIMS5YANY",   "YMIMS5YANY: Past-year MDE + severe impairment + substance use"),
            ("YMIMI5YANY",   "YMIMI5YANY: Past-year MDE w/ severe impairment & illicit drug use"),
        ]
        # Each dropdown offers the answer texts of its column, in mapping order.
        cat1_inputs = [
            gr.Dropdown(choices=list(input_mapping[col]), label=label_text)
            for col, label_text in cat1_col_labels
        ]

        # Section 2 - Mental Health Treatment & Professional Consultation (11 dropdowns)
        gr.Markdown("#### 2. Mental Health Treatment & Professional Consultation")
        cat2_col_labels = [
            ("YMDEHPO",   "YMDEHPO: Saw health prof only for MDE"),
            ("YMDETXRX",  "YMDETXRX: Received treatment/counseling if saw doc/prof for MDE"),
            ("YMDEHARX",  "YMDEHARX: Saw health prof & medication for MDE"),
            ("YMDEHPRX",  "YMDEHPRX: Saw health prof or med for MDE in past year?"),
            ("YRXMDEYR",  "YRXMDEYR: Used medication for MDE in past years"),
            ("YHLTMDE",   "YHLTMDE: Saw/talked to health prof about MDE"),
            ("YTXMDEYR",  "YTXMDEYR: Saw/talked to doc/prof for MDE in past year"),
            ("YDOCMDE",   "YDOCMDE: Saw/talked to general practitioner/family MD"),
            ("YPSY2MDE",  "YPSY2MDE: Saw/talked to psychiatrist"),
            ("YPSY1MDE",  "YPSY1MDE: Saw/talked to psychologist"),
            ("YCOUNMDE",  "YCOUNMDE: Saw/talked to counselor"),
        ]
        cat2_inputs = [
            gr.Dropdown(choices=list(input_mapping[col]), label=label_text)
            for col, label_text in cat2_col_labels
        ]

        # Section 3 - Functional & Cognitive Impairment (2 dropdowns)
        gr.Markdown("#### 3. Functional & Cognitive Impairment")
        cat3_col_labels = [
            ("MDEIMPY",    "MDEIMPY: MDE with severe role impairment?"),
            ("LVLDIFMEM2", "LVLDIFMEM2: Difficulty remembering/concentrating"),
        ]
        cat3_inputs = [
            gr.Dropdown(choices=list(input_mapping[col]), label=label_text)
            for col, label_text in cat3_col_labels
        ]

        # Section 4 - Suicidal Thoughts & Behaviors (4 dropdowns)
        gr.Markdown("#### 4. Suicidal Thoughts & Behaviors")
        cat4_col_labels = [
            ("YUSUITHK",   "YUSUITHK: Thought of killing self (past 12 months)?"),
            ("YUSUITHKYR", "YUSUITHKYR: Seriously thought about killing self?"),
            ("YUSUIPLNYR", "YUSUIPLNYR: Made plans to kill self in past year?"),
            ("YUSUIPLN",   "YUSUIPLN: Made plans to kill yourself in past 12 months?"),
        ]
        cat4_inputs = [
            gr.Dropdown(choices=list(input_mapping[col]), label=label_text)
            for col, label_text in cat4_col_labels
        ]

        # Dropdowns in the positional order expected by predict().
        all_inputs = [*cat1_inputs, *cat2_inputs, *cat3_inputs, *cat4_inputs]

        predict_btn = gr.Button("Predict")

        # Output widgets, created in the order they appear on the page.
        out_pred_res = gr.Textbox(label="Prediction Results (with Probability)", lines=8)
        out_sev = gr.Textbox(label="Mental Health Severity", lines=2)
        out_count = gr.Markdown(label="Total Patient Count")
        out_nn = gr.Markdown(label="Nearest Neighbors Summary")
        out_bar_input = gr.Plot(label="Input Feature Counts")
        out_bar_label = gr.Plot(label="Predicted Label Counts")

        # Same order as the tuple returned by predict().
        prediction_outputs = [
            out_pred_res,
            out_sev,
            out_count,
            out_nn,
            out_bar_input,
            out_bar_label,
        ]
        predict_btn.click(fn=predict, inputs=all_inputs, outputs=prediction_outputs)

    # ---------------- Tab 2: distribution / co-occurrence plots ----------------
    with gr.Tab("Distribution/Co-occurrence"):
        gr.Markdown("### Select 1 or 2 features + 1 label to see a bar chart.")

        # Selectable features are the mapped input columns; selectable labels
        # are the columns the predictor has text for. Both sorted by name.
        list_of_features = sorted(input_mapping)
        list_of_labels = sorted(predictor.prediction_map)

        selected_features = gr.CheckboxGroup(
            choices=list_of_features,
            label="Select 1 or 2 features",
        )
        label_dd = gr.Dropdown(
            choices=list_of_labels,
            label="Label Column (e.g., YOWRCONC, YOSEEDOC, etc.)",
        )

        generate_combined_btn = gr.Button("Generate Plot")
        combined_output = gr.Plot()

        generate_combined_btn.click(
            fn=combined_plot,
            inputs=[selected_features, label_dd],
            outputs=combined_output,
        )

# Start the web app once the whole interface has been defined.
demo.launch()