File size: 33,922 Bytes
25a63b5
d1b265f
f9a8132
 
f9f7e1c
 
3b96ce2
16ca108
3b96ce2
ebac442
ef45228
f9a8132
ebac442
 
 
 
 
3b96ce2
 
 
 
 
 
 
 
e84fe7d
 
 
f9a8132
 
 
 
 
c458985
a5bb04e
 
 
230ba50
2f5f1fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42d91b3
f9a8132
 
16ca108
 
ebac442
 
 
 
 
 
 
 
16ca108
 
 
975c60a
6b501f6
a5bb04e
 
 
 
975c60a
2e504c0
 
3b96ce2
a5bb04e
2e504c0
 
a5bb04e
2e504c0
a5bb04e
 
2e504c0
 
 
1fd21ae
96f0310
975c60a
96f0310
 
975c60a
c99193b
33c5960
c99193b
33c5960
c99193b
33c5960
f9a8132
c99193b
 
96f0310
f9a8132
16ca108
6749d1f
cf4c3a5
6b501f6
c458985
e9e83fc
 
ebac442
e9e83fc
 
 
975c60a
 
e9e83fc
 
 
 
 
 
 
 
c458985
54cc11a
e9e83fc
ebac442
 
 
 
 
 
6b501f6
 
a4ec388
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b501f6
 
ebac442
 
 
 
 
6b501f6
 
 
 
 
e9e83fc
c458985
c38370d
6c60301
ebac442
6c60301
 
437bcb0
e9e83fc
2e504c0
e9e83fc
 
 
 
 
 
a5bb04e
e9e83fc
 
 
 
 
2e504c0
 
 
 
 
 
54cc11a
2e504c0
a5bb04e
2e504c0
54cc11a
2e504c0
 
 
 
 
 
a5bb04e
 
2e504c0
a5bb04e
2e504c0
 
 
6b501f6
2e504c0
e9e83fc
6749d1f
cf4c3a5
e9e83fc
cf4c3a5
f9f7e1c
6b501f6
ebac442
e9e83fc
6b501f6
975c60a
 
6b501f6
e9e83fc
6b501f6
e9e83fc
f9f7e1c
2e504c0
6749d1f
ebac442
e9e83fc
975c60a
 
e9e83fc
 
6749d1f
 
ebac442
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e504c0
 
 
 
6749d1f
 
16ca108
f9f7e1c
2e504c0
ebac442
 
 
 
 
 
 
 
 
 
 
2e504c0
a5bb04e
6b501f6
ef45228
96f0310
 
 
d0aa35f
 
89e8b14
 
 
 
 
 
 
 
d0aa35f
89e8b14
d0aa35f
89e8b14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ef45228
6b501f6
 
 
 
a5bb04e
54cc11a
 
 
 
6b501f6
 
 
ebac442
e29d4c6
ebac442
87dd6c1
e84fe7d
ebac442
 
e84fe7d
 
46e9809
975c60a
16ca108
6b501f6
 
2e504c0
54cc11a
 
 
 
 
 
 
ef45228
54cc11a
6db7f5e
ef45228
54cc11a
 
6db7f5e
ef45228
54cc11a
 
 
 
 
 
 
 
 
 
 
 
 
2e504c0
6b501f6
ebac442
6b501f6
a5bb04e
6b501f6
2e504c0
16ca108
2e504c0
 
16ca108
2e504c0
07f9caa
 
 
2e504c0
c458985
16ca108
2e504c0
16ca108
 
e84fe7d
16ca108
ebac442
 
 
 
c458985
 
 
 
d79221a
16ca108
0af7f5a
 
2e504c0
 
0af7f5a
 
 
 
ef45228
 
0af7f5a
d0aa35f
 
 
 
 
 
 
e84fe7d
ef45228
 
 
d0aa35f
 
 
 
 
0af7f5a
e84fe7d
685722d
cf4c3a5
16ca108
d79221a
16ca108
0b36f6e
 
54cc11a
0b36f6e
 
 
 
 
f9f7e1c
cf4c3a5
2e504c0
cf4c3a5
2e504c0
 
6b501f6
2e504c0
6b501f6
f92effe
 
2e504c0
 
6b501f6
2e504c0
f92effe
 
a5bb04e
f92effe
 
 
 
 
a5bb04e
f92effe
a5bb04e
 
f92effe
2e504c0
 
f92effe
a5bb04e
f92effe
ebac442
f92effe
 
 
ebac442
a5bb04e
ebac442
2e504c0
 
 
 
6b501f6
f92effe
a5bb04e
f92effe
2e504c0
f92effe
 
 
ebac442
f92effe
a5bb04e
2e504c0
 
 
cf4c3a5
2e504c0
 
d1b265f
cf4c3a5
c458985
cf4c3a5
e84fe7d
ef45228
 
 
 
 
e84fe7d
ebac442
3b96ce2
a5bb04e
e9e83fc
ebac442
975c60a
e9e83fc
ebac442
e9e83fc
 
975c60a
e9e83fc
 
685722d
 
e9e83fc
 
 
2e504c0
ebac442
 
 
 
e9e83fc
 
ebac442
e9e83fc
975c60a
2e504c0
975c60a
6b501f6
 
975c60a
6b501f6
2e504c0
975c60a
2e504c0
 
 
975c60a
e9e83fc
 
2e504c0
ebac442
 
 
 
e9e83fc
 
ebac442
e9e83fc
975c60a
 
 
 
e9e83fc
 
2e504c0
ebac442
 
 
 
e9e83fc
 
ebac442
e9e83fc
975c60a
2e504c0
975c60a
b56bd78
975c60a
 
e9e83fc
 
2e504c0
ebac442
 
 
 
e9e83fc
 
a5bb04e
e9e83fc
 
a5bb04e
6749d1f
6b501f6
a5bb04e
33c5960
a5bb04e
 
 
 
3b96ce2
a5bb04e
3b96ce2
 
e9e83fc
3b96ce2
6b501f6
 
 
 
 
 
3b96ce2
33c8e0b
1fd21ae
2e504c0
 
6b501f6
f92effe
726e8be
a5bb04e
f92effe
2e504c0
 
 
c458985
6b501f6
 
a5bb04e
6b501f6
f92effe
2e504c0
 
f92effe
2e504c0
 
 
 
3b96ce2
67c356a
46cf3df
 
 
9e09423
 
 
 
91636d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e09423
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
088513d
 
9e09423
46cf3df
 
 
 
72807ec
 
 
46cf3df
ecff5d3
4f4c821
ecff5d3
a5bb04e
87dd6c1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
import pickle
import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px

######################################
# 1) LOAD DATA & MODELS
######################################
# Load the survey dataset used for nearest-neighbor lookups and the
# "patients with the same value" bar charts below.
# NOTE(review): read at import time — the app fails to start if the CSV is absent.
df = pd.read_csv("X_train_test_combined_dataset_Filtered_dataset.csv")

# Fail fast if the key diagnosis column is missing (it is required by the UI
# inputs and by the nearest-neighbor computation).
if 'YMDESUD5ANYO' not in df.columns:
    raise ValueError("The column 'YMDESUD5ANYO' is missing from the dataset. Please check your CSV file.")

# One pickled classifier per outcome label; the filename stem (e.g. "YOWRCONC")
# doubles as the label column name throughout the app.
model_filenames = [
    "YOWRCONC.pkl", "YOSEEDOC.pkl", "YO_MDEA5.pkl", "YOWRLSIN.pkl",
    "YODPPROB.pkl", "YOWRPROB.pkl", "YODPR2WK.pkl", "YOWRDEPR.pkl",
    "YODPDISC.pkl", "YOLOSEV.pkl", "YOWRDCSN.pkl", "YODSMMDE.pkl",
    "YO_MDEA3.pkl", "YODPLSIN.pkl", "YOWRELES.pkl", "YOPB2WK.pkl"
]
model_path = "models/"

######################################
# 2) MODEL PREDICTOR
######################################
class ModelPredictor:
    """Loads the 16 per-label classifiers and translates their 1/2 outputs.

    Each model predicts one survey label with code 1 (symptom present /
    "bad" status) or code 2 (symptom absent / "ok" status); `prediction_map`
    holds the human-readable text for both codes of every label.
    """

    def __init__(self, model_path, model_filenames):
        self.model_path = model_path
        self.model_filenames = model_filenames
        self.models = self.load_models()

        # Text shown in the UI for each label code:
        #   1 = symptom present, 2 = symptom absent.
        self.prediction_map = {
            "YOWRCONC": {2: "Did NOT have difficulty concentrating", 1: "Had difficulty concentrating"},
            "YOSEEDOC": {2: "Did NOT feel the need to see a doctor", 1: "Felt the need to see a doctor"},
            "YO_MDEA5": {2: "No restlessness/lethargy noticed",     1: "Others noticed restlessness/lethargy"},
            "YOWRLSIN": {2: "Did NOT feel bored/lose interest",     1: "Felt bored/lost interest"},
            "YODPPROB": {2: "No other problems for 2+ weeks",       1: "Had other problems for 2+ weeks"},
            "YOWRPROB": {2: "No 'worst time ever' feeling",         1: "Had 'worst time ever' feeling"},
            "YODPR2WK": {2: "No depressed feelings for 2+ wks",     1: "Had depressed feelings for 2+ wks"},
            "YOWRDEPR": {2: "Did NOT feel sad/depressed daily",     1: "Felt sad/depressed mostly everyday"},
            "YODPDISC": {2: "Overall mood not sad/depressed",       1: "Overall mood was sad/depressed"},
            "YOLOSEV":  {2: "Did NOT lose interest in things",      1: "Lost interest in enjoyable things"},
            "YOWRDCSN": {2: "Was able to make decisions",           1: "Was unable to make decisions"},
            "YODSMMDE": {2: "No 2+ wks depression symptoms",        1: "Had 2+ wks depression symptoms"},
            "YO_MDEA3": {2: "No appetite/weight changes",           1: "Had changes in appetite/weight"},
            "YODPLSIN": {2: "Never lost interest/felt bored",       1: "Lost interest/felt bored"},
            "YOWRELES": {2: "Did NOT eat less than usual",          1: "Ate less than usual"},
            "YOPB2WK":  {2: "No uneasy feelings 2+ weeks",          1: "Uneasy feelings 2+ weeks"}
        }

    def load_models(self):
        """Unpickle every model file, raising a descriptive error on failure."""
        models = []
        for fname in self.model_filenames:
            try:
                with open(self.model_path + fname, "rb") as fh:
                    models.append(pickle.load(fh))
            except FileNotFoundError:
                raise FileNotFoundError(f"Model file '{fname}' not found in path '{self.model_path}'.")
            except Exception as e:
                raise Exception(f"Error loading model '{fname}': {e}")
        return models

    def make_predictions(self, user_input: pd.DataFrame):
        """Run every model on `user_input` (a single-row DataFrame).

        Returns two parallel lists, one entry per model:
          - flattened prediction arrays (expected values: 1 or 2), and
          - probability arrays for the "2" class (column 1 of
            predict_proba), or NaN arrays for models without predict_proba.

        IMPORTANT: assumes each model's classes are [1, 2]; models trained
        on [0, 1] would need a transform or retraining.
        """
        predictions = []
        probabilities = []
        for model in self.models:
            predictions.append(model.predict(user_input).flatten())
            if hasattr(model, "predict_proba"):
                # Second column corresponds to class "2" under the [1, 2] assumption.
                probabilities.append(model.predict_proba(user_input)[:, 1])
            else:
                probabilities.append(np.full(len(user_input), np.nan))
        return predictions, probabilities

    def evaluate_severity(self, count_ones: int) -> str:
        """Map the number of labels predicted as 1 to a severity bucket.

        More 1's (symptoms present) means a more severe condition; counts
        outside 0..16 are out of range for the 16 models.
        """
        if not 0 <= count_ones <= 16:
            return "Cannot tell the status"
        if count_ones <= 5:
            return "Low"
        if count_ones <= 10:
            return "Moderate"
        return "Severe"

    

# Single shared predictor instance used by all Gradio callbacks below.
predictor = ModelPredictor(model_path, model_filenames)

######################################
# 3) FEATURE CATEGORIES + MAPPING
######################################
# UI layout: which input feature belongs to which accordion/section.
categories_dict = {
    "1. Depression & Substance Use Diagnosis": [
        "YMDESUD5ANYO", "YMDELT", "YMDEYR", "YMDERSUD5ANY",
        "YMSUD5YANY", "YMIUD5YANY", "YMIMS5YANY", "YMIMI5YANY"
    ],
    "2. Mental Health Treatment & Prof Consultation": [
        "YMDEHPO", "YMDETXRX", "YMDEHARX", "YMDEHPRX", "YRXMDEYR", 
        "YHLTMDE", "YTXMDEYR", "YDOCMDE", "YPSY2MDE", "YPSY1MDE", "YCOUNMDE"
    ],
    "3. Functional & Cognitive Impairment": [
        "MDEIMPY", "LVLDIFMEM2"
    ],
    "4. Suicidal Thoughts & Behaviors": [
        "YUSUITHK", "YUSUITHKYR", "YUSUIPLNYR", "YUSUIPLN"
    ]
}

# NOTE: input_mapping below for capturing user choices => numeric codes. 
# NOTE(review): "No" maps to 2 for some columns but 0 for others — presumably
# this mirrors the original survey coding of each column; verify against the CSV.
input_mapping = {
    'YMDESUD5ANYO': {
        "SUD only, no MDE": 1, 
        "MDE only, no SUD": 2, 
        "SUD and MDE": 3, 
        "Neither SUD or MDE": 4
    },
    'YMDELT':       {"Yes": 1, "No": 2},
    'YMDEYR':       {"Yes": 1, "No": 2},
    'YMDERSUD5ANY': {"Yes": 1, "No": 0},
    'YMSUD5YANY':   {"Yes": 1, "No": 0},
    'YMIUD5YANY':   {"Yes": 1, "No": 0},
    'YMIMS5YANY':   {"Yes": 1, "No": 0},
    'YMIMI5YANY':   {"Yes": 1, "No": 0},

    'YMDEHPO':      {"Yes": 1, "No": 0},
    'YMDETXRX':     {"Yes": 1, "No": 0},
    'YMDEHARX':     {"Yes": 1, "No": 0},
    'YMDEHPRX':     {"Yes": 1, "No": 0},
    'YRXMDEYR':     {"Yes": 1, "No": 0},
    'YHLTMDE':      {"Yes": 1, "No": 0},
    'YTXMDEYR':     {"Yes": 1, "No": 0},
    'YDOCMDE':      {"Yes": 1, "No": 0},
    'YPSY2MDE':     {"Yes": 1, "No": 0},
    'YPSY1MDE':     {"Yes": 1, "No": 0},
    'YCOUNMDE':     {"Yes": 1, "No": 0},

    'MDEIMPY':      {"Yes": 1, "No": 2},
    'LVLDIFMEM2':   {
        "No Difficulty": 1, 
        "Some difficulty": 2, 
        "A lot of difficulty or cannot do at all": 3
    },

    'YUSUITHK':     {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
    'YUSUITHKYR':   {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
    'YUSUIPLNYR':   {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
    'YUSUIPLN':     {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4}
}

def validate_inputs(*args):
    """Return True only when every supplied value is neither None nor ""."""
    return all(value not in (None, "") for value in args)

######################################
# 4) NEAREST NEIGHBORS
######################################
def get_nearest_neighbors_info(user_input_df: pd.DataFrame, k=5):
    """Summarize the k dataset rows closest to the user's input.

    Distance is plain Euclidean distance over the input's numeric feature
    columns; the summary reports the distance range and, for each label
    column, how the neighbors' label values are distributed (codes 1/2
    rendered via predictor.prediction_map).
    Returns a markdown string.
    """
    feature_cols = user_input_df.columns
    if any(col not in df.columns for col in feature_cols):
        return "Cannot compute nearest neighbors. Some columns not found in df."

    # Row-wise Euclidean distance from every dataset row to the single input row.
    deltas = df[feature_cols].copy() - user_input_df.iloc[0]
    distances = (deltas ** 2).sum(axis=1) ** 0.5
    nearest_idx = distances.nsmallest(k).index
    neighbors = df.loc[nearest_idx]

    lines = [
        f"**Nearest Neighbors (k={k})**",
        f"Distances range: {distances[nearest_idx].min():.2f} to {distances[nearest_idx].max():.2f}",
        ""
    ]

    # (Removed user-input numeric->text section per request.)

    # Distribution of each known label column among the neighbors.
    lines.append("**Label Distribution Among Neighbors**")
    for label in predictor.prediction_map:
        if label not in neighbors.columns:
            continue
        pieces = []
        for value, n_rows in neighbors[label].value_counts().to_dict().items():
            # Only codes 1/2 have mapped text; anything else stays numeric.
            if label in predictor.prediction_map and value in [1, 2]:
                mapped = predictor.prediction_map[label][value]
                pieces.append(f"{n_rows} had '{mapped}' (value={value})")
            else:
                pieces.append(f"{n_rows} had numeric={value}")
        lines.append(f"- {label}: " + "; ".join(pieces))

    lines.append("")
    return "\n".join(lines)

######################################
# 5) PREDICT FUNCTION
######################################
def predict(
    # Category 1 (8):
    YMDESUD5ANYO, YMDELT, YMDEYR, YMDERSUD5ANY,
    YMSUD5YANY, YMIUD5YANY, YMIMS5YANY, YMIMI5YANY,
    # Category 2 (11):
    YMDEHPO, YMDETXRX, YMDEHARX, YMDEHPRX, YRXMDEYR,
    YHLTMDE, YTXMDEYR, YDOCMDE, YPSY2MDE, YPSY1MDE, YCOUNMDE,
    # Category 3 (2):
    MDEIMPY, LVLDIFMEM2,
    # Category 4 (4):
    YUSUITHK, YUSUITHKYR, YUSUIPLNYR, YUSUIPLN
):
    """Main Gradio callback: run all 16 label models on one respondent.

    Each parameter is the text option chosen in the UI for one survey column;
    the texts are converted to numeric codes via `input_mapping`.

    Returns a 6-tuple matching the Gradio outputs:
      1. markdown of per-label predictions grouped by symptom domain,
      2. severity message (Low/Moderate/Severe plus avg ± std probabilities),
      3. total-count markdown (currently blank on purpose),
      4. nearest-neighbor summary markdown,
      5. bar chart of dataset rows sharing each input value,
      6. bar chart of dataset rows sharing each predicted label value.
    On validation/mapping/prediction errors the same positions carry
    placeholder strings and None charts.
    """
    # 1) Reject the request unless every dropdown has a selection.
    if not validate_inputs(
        YMDESUD5ANYO, YMDELT, YMDEYR, YMDERSUD5ANY,
        YMSUD5YANY, YMIUD5YANY, YMIMS5YANY, YMIMI5YANY,
        YMDEHPO, YMDETXRX, YMDEHARX, YMDEHPRX, YRXMDEYR,
        YHLTMDE, YTXMDEYR, YDOCMDE, YPSY2MDE, YPSY1MDE, YCOUNMDE,
        MDEIMPY, LVLDIFMEM2,
        YUSUITHK, YUSUITHKYR, YUSUIPLNYR, YUSUIPLN
    ):
        return (
            "Please select all required fields.",  # 1) Prediction Results
            "Validation Error",                    # 2) Severity
            "No data",                             # 3) Total Count
            "No nearest neighbors info",           # 4) NN Summary
            None,                                  # 5) Bar chart (Input)
            None                                   # 6) Bar chart (Labels)
        )

    # 2) Convert text answers to the numeric codes the models expect.
    # Keys are in the models' expected column order.
    raw_answers = {
        'YMDESUD5ANYO': YMDESUD5ANYO,
        'YMDELT':       YMDELT,
        'YMDEYR':       YMDEYR,
        'YMDERSUD5ANY': YMDERSUD5ANY,
        'YMSUD5YANY':   YMSUD5YANY,
        'YMIUD5YANY':   YMIUD5YANY,
        'YMIMS5YANY':   YMIMS5YANY,
        'YMIMI5YANY':   YMIMI5YANY,
        'YMDEHPO':      YMDEHPO,
        'YMDETXRX':     YMDETXRX,
        'YMDEHARX':     YMDEHARX,
        'YMDEHPRX':     YMDEHPRX,
        'YRXMDEYR':     YRXMDEYR,
        'YHLTMDE':      YHLTMDE,
        'YTXMDEYR':     YTXMDEYR,
        'YDOCMDE':      YDOCMDE,
        'YPSY2MDE':     YPSY2MDE,
        'YPSY1MDE':     YPSY1MDE,
        'YCOUNMDE':     YCOUNMDE,
        'MDEIMPY':      MDEIMPY,
        'LVLDIFMEM2':   LVLDIFMEM2,
        'YUSUITHK':     YUSUITHK,
        'YUSUITHKYR':   YUSUITHKYR,
        'YUSUIPLNYR':   YUSUIPLNYR,
        'YUSUIPLN':     YUSUIPLN
    }
    try:
        user_input_dict = {
            col: input_mapping[col][answer] for col, answer in raw_answers.items()
        }
    except KeyError as e:
        # The missing key is either an unknown column or an unmapped answer text.
        missing_key = e.args[0]
        return (
            f"Input mapping missing for key: {missing_key}. Please check your `input_mapping` dictionary.",
            "Mapping Error",
            "No data",
            "No nearest neighbors info",
            None,
            None
        )

    user_df = pd.DataFrame(user_input_dict, index=[0])

    # 3) Run every model; each predicts 1 (symptom present) or 2 (absent).
    try:
        preds, probs = predictor.make_predictions(user_df)
    except Exception as e:
        return (
            f"Error during prediction: {e}",
            "Prediction Error",
            "No data",
            "No nearest neighbors info",
            None,
            None
        )

    # Severity is driven by how many of the 16 labels came back as "1".
    all_preds = np.concatenate(preds)
    count_ones = int(np.sum(all_preds == 1))
    severity_base = predictor.evaluate_severity(count_ones)

    # Average ± std of the per-model probabilities. Models without
    # predict_proba contributed NaN and are excluded; class 2 = "Ok",
    # class 1 = "Bad" (its probability is 1 - P(class 2)).
    filtered_probs_2 = [prob[0] for prob in probs if not np.isnan(prob[0])]
    filtered_probs_1 = [1 - p for p in filtered_probs_2]

    if filtered_probs_2:
        avg_prob_2 = np.mean(filtered_probs_2)
        avg_prob_1 = np.mean(filtered_probs_1)
        std_dev_prob_2 = np.std(filtered_probs_2)
        std_dev_prob_1 = np.std(filtered_probs_1)
    else:
        # No model exposed probabilities: keep the "nan" text without
        # triggering numpy's empty-slice RuntimeWarning.
        avg_prob_2 = avg_prob_1 = std_dev_prob_2 = std_dev_prob_1 = float("nan")

    severity_msg = (
        f"{severity_base} "
        f"(Avg Prob (Bad Mental Status)={avg_prob_1:.2f} ± {std_dev_prob_1:.2f}, "
        f"Avg Prob (Ok Mental Status)={avg_prob_2:.2f} ± {std_dev_prob_2:.2f})"
    )

    # 4) Per-label (prediction, P(class 2)) keyed by the label column name
    # (the model filename stem).
    label_prediction_info = {}
    for i, fname in enumerate(model_filenames):
        lbl_col = fname.split('.')[0]
        label_prediction_info[lbl_col] = (preds[i][0], probs[i][0])

    # Group labels by symptom domain for the markdown summary.
    domain_groups = {
        "Concentration and Decision Making": ["YOWRCONC", "YOWRDCSN"],
        "Sleep and Energy Levels": ["YO_MDEA5", "YOSEEDOC"],
        "Mood and Emotional State": [
            "YOWRLSIN", "YOWRDEPR", "YODPDISC", "YOLOSEV", "YODPLSIN"
        ],
        "Appetite and Weight Changes": ["YO_MDEA3", "YOWRELES"],
        "Duration and Severity of Depression Symptoms": [
            "YODPPROB", "YOWRPROB", "YODPR2WK", "YODSMMDE", "YOPB2WK"
        ]
    }

    final_str_parts = []
    for gname, lbls in domain_groups.items():
        group_lines = []
        for lbl in lbls:
            if lbl not in label_prediction_info:
                continue
            pred_val, prob_val_for_2 = label_prediction_info[lbl]
            # Probability text for the *predicted* class.
            if np.isnan(prob_val_for_2):
                text_prob = "(No probability available)"
            elif pred_val == 2:
                text_prob = f"(Prob= {prob_val_for_2:.2f} for predicted class = Ok Mental Status)"
            else:
                text_prob = f"(Prob= {1 - prob_val_for_2:.2f} for predicted class = Bad Mental Status)"

            # Map codes 1/2 to readable text where a mapping exists.
            if lbl in predictor.prediction_map and pred_val in [1, 2]:
                text_pred = predictor.prediction_map[lbl][pred_val]
            else:
                text_pred = f"Prediction={pred_val}"

            # ✅ = "Ok" (code 2), ❌ = "Bad" (code 1).
            icon = "✅" if pred_val == 2 else "❌"
            group_lines.append(f"{lbl} => {icon} {text_pred} {text_prob}")

        if group_lines:
            final_str_parts.append(f"**{gname}**")
            final_str_parts.append("\n".join(group_lines))
            final_str_parts.append("")  # blank spacer between domains

    if final_str_parts:
        final_str = "\n".join(final_str_parts)
    else:
        final_str = "No predictions made or no matching group columns."

    # 5) Total-count panel intentionally blank (dataset-size text removed).
    total_count_md = ""

    # 6) Nearest neighbors in the dataset.
    nn_md = get_nearest_neighbors_info(user_df, k=5)

    # 7) Bar chart: dataset rows sharing each submitted input value.
    input_counts = {
        col: int((df[col] == val_).sum()) for col, val_ in user_input_dict.items()
    }
    bar_in_df = pd.DataFrame({
        "Feature": list(input_counts.keys()),
        "Count": list(input_counts.values())
    })
    fig_in = px.bar(
        bar_in_df, x="Feature", y="Count",
        title="Number of Patients with the Same Input Feature Values"
    )
    fig_in.update_layout(width=1200, height=400)

    # 8) Bar chart: dataset rows sharing each predicted label value.
    label_counts = {
        lbl_col: int((df[lbl_col] == pred_val).sum())
        for lbl_col, (pred_val, _) in label_prediction_info.items()
        if lbl_col in df.columns
    }
    if label_counts:
        bar_lbl_df = pd.DataFrame({
            "Label": list(label_counts.keys()),
            "Count": list(label_counts.values()),
            "Pred_Val": [label_prediction_info[lbl_col][0] for lbl_col in label_counts.keys()]
        })
        # Legend/color: 2 => "Ok Mental Status" (green), 1 => "Bad Mental Status" (red).
        bar_lbl_df["Mental Status"] = bar_lbl_df["Pred_Val"].apply(
            lambda x: "Ok Mental Status" if x == 2 else "Bad Mental Status"
        )

        fig_lbl = px.bar(
            bar_lbl_df,
            x="Label",
            y="Count",
            color="Mental Status",
            color_discrete_map={
                "Ok Mental Status": "green",
                "Bad Mental Status": "red"
            },
            title="Number of Patients with the Same Predicted Label"
        )
        fig_lbl.update_layout(width=1200, height=400)
    else:
        fig_lbl = px.bar(title="No valid predicted labels to display.")
        fig_lbl.update_layout(width=1200, height=400)

    return (
        final_str,         # 1) Prediction Results
        severity_msg,      # 2) Mental Health Severity
        total_count_md,    # 3) Total Patient Count
        nn_md,             # 4) Nearest Neighbors Summary
        fig_in,            # 5) Bar Chart (input features)
        fig_lbl            # 6) Bar Chart (labels)
    )

######################################
# 6) UNIFIED DISTRIBUTION/CO-OCCURRENCE
######################################
def combined_plot(feature_list, label_col):
    """
    Build a bar chart relating 1 or 2 input features to a label column.

    - 1 feature  => distribution plot (feature on x, label as color).
    - 2 features => co-occurrence plot (second feature as facet columns).
    - anything else => an empty figure whose title explains the problem.

    Numeric codes are mapped to readable text via 'input_mapping' (features)
    and 'predictor.prediction_map' (label) before grouping, so the plots show
    human-readable category names.

    Parameters
    ----------
    feature_list : list[str] | None
        Feature column names chosen in the UI (may be None/empty).
    label_col : str
        Label column name (e.g. YOWRCONC); must be provided.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    if not label_col:
        return px.bar(title="Please select a label column.")

    # Gradio may hand us None before the user touches the CheckboxGroup.
    selected = list(feature_list or [])
    if len(selected) not in (1, 2):
        return px.bar(title="Please select exactly 1 or 2 features.")

    needed_cols = selected + [label_col]
    if any(col not in df.columns for col in needed_cols):
        return px.bar(title="Selected columns not found in dataset.")

    # Copy only the columns we actually plot — avoids remapping the whole
    # dataframe on every button click.
    plot_df = df[needed_cols].copy()

    # Numeric feature codes -> text; unmapped values keep their original code.
    for col in selected:
        text_to_num = input_mapping.get(col)
        if text_to_num:
            num_to_text = {num: text for text, num in text_to_num.items()}
            plot_df[col] = plot_df[col].map(num_to_text).fillna(plot_df[col])

    # Label 1/2 -> text, when the label has a known prediction mapping.
    if label_col in predictor.prediction_map:
        label_map = predictor.prediction_map[label_col]
        plot_df[label_col] = plot_df[label_col].map(label_map).fillna(plot_df[label_col])

    # Shared aggregation for both the 1- and 2-feature cases.
    grouped = plot_df.groupby(needed_cols).size().reset_index(name="count")

    if len(selected) == 1:
        f_ = selected[0]
        fig = px.bar(
            grouped,
            x=f_,
            y="count",
            color=label_col,
            title=f"Distribution of {f_} vs {label_col} (Mapped)"
        )
    else:
        f1, f2 = selected
        fig = px.bar(
            grouped,
            x=f1,
            y="count",
            color=label_col,
            facet_col=f2,
            title=f"Co-occurrence: {f1}, {f2} vs {label_col} (Mapped)"
        )

    fig.update_layout(width=1200, height=600)
    return fig

######################################
# 7) BUILD GRADIO UI
######################################
# Top-level UI definition: three content tabs plus a glossary, then launch.
with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
    # Disclaimer shown above all tabs.
    gr.Markdown(
        "#### **Disclaimer**: This is a prototype aiming to apply data-driven AI for mental assessment. "
        "It is advised to seek consultation and assessment from a real clinician whenever needed."
    )

    # ======== TAB 1: Prediction ========
    # Four categories of dropdown inputs; each dropdown's choices are the
    # human-readable keys of input_mapping[col] (mapped to codes at predict time).
    with gr.Tab("Prediction"):
        gr.Markdown("### Please provide inputs in each category below. All fields are required.")

        # Category 1: Depression & Substance Use Diagnosis (8 features)
        gr.Markdown("#### 1. Depression & Substance Use Diagnosis")
        cat1_col_labels = [
            ("YMDESUD5ANYO", "YMDESUD5ANYO: ONLY MDE, ONLY SUD, BOTH, OR NEITHER"),
            ("YMDELT",       "YMDELT: Had major depressive episode in lifetime"),
            ("YMDEYR",       "YMDEYR: Past-year major depressive episode"),
            ("YMDERSUD5ANY", "YMDERSUD5ANY: MDE or SUD in past year?"),
            ("YMSUD5YANY",   "YMSUD5YANY: Past-year MDE & substance use disorder"),
            ("YMIUD5YANY",   "YMIUD5YANY: Past-year MDE & illicit drug use disorder"),
            ("YMIMS5YANY",   "YMIMS5YANY: Past-year MDE + severe impairment + substance use"),
            ("YMIMI5YANY",   "YMIMI5YANY: Past-year MDE w/ severe impairment & illicit drug use")
        ]
        cat1_inputs = []
        for col, label_text in cat1_col_labels:
            cat1_inputs.append(
                gr.Dropdown(
                    choices=list(input_mapping[col].keys()),
                    label=label_text
                )
            )

        # Category 2: Mental Health Treatment & Professional Consultation (11 features)
        gr.Markdown("#### 2. Mental Health Treatment & Professional Consultation")
        cat2_col_labels = [
            ("YMDEHPO",   "YMDEHPO: Saw health prof only for MDE"),
            ("YMDETXRX",  "YMDETXRX: Received treatment/counseling if saw doc/prof for MDE"),
            ("YMDEHARX",  "YMDEHARX: Saw health prof & medication for MDE"),
            ("YMDEHPRX",  "YMDEHPRX: Saw health prof or med for MDE in past year?"),
            ("YRXMDEYR",  "YRXMDEYR: Used medication for MDE in past years"),
            ("YHLTMDE",   "YHLTMDE: Saw/talked to health prof about MDE"),
            ("YTXMDEYR",  "YTXMDEYR: Saw/talked to doc/prof for MDE in past year"),
            ("YDOCMDE",   "YDOCMDE: Saw/talked to general practitioner/family MD"),
            ("YPSY2MDE",  "YPSY2MDE: Saw/talked to psychiatrist"),
            ("YPSY1MDE",  "YPSY1MDE: Saw/talked to psychologist"),
            ("YCOUNMDE",  "YCOUNMDE: Saw/talked to counselor")
        ]
        cat2_inputs = []
        for col, label_text in cat2_col_labels:
            cat2_inputs.append(
                gr.Dropdown(
                    choices=list(input_mapping[col].keys()),
                    label=label_text
                )
            )

        # Category 3: Functional & Cognitive Impairment (2 features)
        gr.Markdown("#### 3. Functional & Cognitive Impairment")
        cat3_col_labels = [
            ("MDEIMPY",    "MDEIMPY: MDE with severe role impairment?"),
            ("LVLDIFMEM2", "LVLDIFMEM2: Difficulty remembering/concentrating")
        ]
        cat3_inputs = []
        for col, label_text in cat3_col_labels:
            cat3_inputs.append(
                gr.Dropdown(
                    choices=list(input_mapping[col].keys()),
                    label=label_text
                )
            )

        # Category 4: Suicidal Thoughts & Behaviors (4 features)
        gr.Markdown("#### 4. Suicidal Thoughts & Behaviors")
        cat4_col_labels = [
            ("YUSUITHK",   "YUSUITHK: Thought of killing self (past 12 months)?"),
            ("YUSUITHKYR", "YUSUITHKYR: Seriously thought about killing self?"),
            ("YUSUIPLNYR", "YUSUIPLNYR: Made plans to kill self in past years?"),
            ("YUSUIPLN",   "YUSUIPLN: Made plans to kill yourself in past 12 months?")
        ]
        cat4_inputs = []
        for col, label_text in cat4_col_labels:
            cat4_inputs.append(
                gr.Dropdown(
                    choices=list(input_mapping[col].keys()),
                    label=label_text
                )
            )

        # Combined input list: order must match the parameter order expected
        # by the predict() callback.
        all_inputs = cat1_inputs + cat2_inputs + cat3_inputs + cat4_inputs

        # Outputs — one component per value returned by predict().
        predict_btn = gr.Button("Predict")

        out_pred_res   = gr.Textbox(label="Prediction Results (with Probability)", lines=8)
        out_sev        = gr.Textbox(label="Suggested Mental Health Severity", lines=2)
        out_count      = gr.Markdown(label="Total Patient Count")
        out_nn         = gr.Markdown(label="Nearest Neighbors Summary")
        out_bar_input  = gr.Plot(label="Input Feature Counts")
        out_bar_label  = gr.Plot(label="Predicted Label Counts")

        # Wire the predict button to the predict() callback.
        predict_btn.click(
            fn=predict,
            inputs=all_inputs,
            outputs=[
                out_pred_res,
                out_sev,
                out_count,
                out_nn,
                out_bar_input,
                out_bar_label
            ]
        )

    # ======== TAB 2: Unified Distribution/Co-occurrence ========
    with gr.Tab("Distribution/Co-occurrence"):
        gr.Markdown("### Select 1 or 2 features + 1 label to see a bar chart.")

        list_of_features = sorted(input_mapping.keys())
        list_of_labels   = sorted(predictor.prediction_map.keys())

        selected_features = gr.CheckboxGroup(
            choices=list_of_features,
            label="Select 1 or 2 features"
        )
        label_dd = gr.Dropdown(
            choices=list_of_labels,
            label="Label Column (e.g. YOWRCONC, YOSEEDOC, etc.)"
        )

        generate_combined_btn = gr.Button("Generate Plot")
        combined_output = gr.Plot()

        generate_combined_btn.click(
            fn=combined_plot,
            inputs=[selected_features, label_dd],
            outputs=combined_output
        )


    # Static glossary markdown rendered in the "Feature and label description" tab.
    feature_label_table_glossary_content = """
        ## Glossary for the Input Features and Target Labels
        
        The compiled glossary for each of the identified input features and output features is listed in Table 1 and Table 2 below.
        
        **Table 1: Target Labels**
        
        | S.N | Target Label | Description |
        |----|-------------|-------------|
        | 1  | YOWRCONC | On most days, did you have a lot more trouble than usual keeping your mind on things? |
        | 2  | YOSEEDOC | At any time in the past 12 months, did you see or talk to a medical doctor or other professional about your feelings? |
        | 3  | YO_MDEA5 | Others Noticed That the Respondent Was Restless or Lethargic |
        | 4  | YOWRLSIN | During that worst period of time, did you become bored with almost everything like school, work, hobbies, and things you like to do for fun? |
        | 5  | YODPPROB | Did you ever have any of the problems (sleep, eating, energy, etc.) for two weeks or longer? |
        | 6  | YOWRPROB | Can you think of the worst time when you felt for two weeks or longer and also had these other problems? |
        | 7  | YODPR2WK | Did you ever have a period of time that lasted most of the day, almost every day, for two weeks or longer? |
        | 8  | YOWRDEPR | During that time, did you feel sad, empty, or depressed for most of the day nearly every day? |
        | 9  | YODPDISC | Did you ever feel discouraged about how things were going in your life? |
        | 10 | YOLOSEV | Have you ever had a period when you lost interest and became bored with most things? |
        | 11 | YOWRDCSN | Were you unable to make up your mind about things? |
        | 12 | YODSMMDE | Score of Symptom Indicators 1 Through 9 |
        | 13 | YO_MDEA3 | Changes in Appetite or Weight |
        | 14 | YODPLSIN | Did you ever lose interest and become really bored with most things? |
        | 15 | YOWRELES | Did you eat much less than usual almost every day during that time? |
        | 16 | YOPB2WK | In the past 12 months, did you have a period of time when you felt for two weeks or longer? |

        **Table 2: Input Features**
        
        | S.N | Input Feature | Description |
        |----|---------------|-------------|
        | 1  | YMDEYR | Youth: Past Year Major Depressive Episode (MDE) |
        | 2  | YMDERSUD5ANY | Youth: Major Depressive Episode or Substance Use Disorder - Past Year - DSM-5 - Any |
        | 3  | YMIMS5YANY | Youth: Past Year Major Depressive Episode with Severe Impairment and Substance Use Disorder - DSM-5 - Any |
        | 4  | YMDELT | Youth: Lifetime Major Depressive Episode (MDE) |
        | 5  | YMDEHARX | Youth: Saw Health Professional and Prescribed Medication for Major Depressive Episode in Past Year |
        | 6  | YMDEHPRX | Youth: Saw Health Professional or Prescribed Medication for Major Depressive Episode in Past Year |
        | 7  | YMDETXRX | Youth: Received Treatment/Counseling or Prescribed Medication for Major Depressive Episode in Past Year |
        | 8  | YMDEHPO | Youth: Saw Health Professional Only for Major Depressive Episode in Past Year |
        | 9  | YMIMI5YANY | Youth: Past Year Major Depressive Episode with Severe Impairment and Illicit Drug Use Disorder - DSM-5 - Any |
        | 10 | YMIUD5YANY | Youth: Past Year Major Depressive Episode and Illicit Drug Use Disorder - DSM-5 - Any |
        | 11 | YMDESUD5ANYO | Youth: Only Major Depressive Episode, Only Substance Use Disorder, Both, or Neither - Past Year - DSM-5 - Any |
        | 12 | YCOUNMDE | Youth: Saw/Talked to Counselor About Major Depressive Episode in Past Year |
        | 13 | YPSY1MDE | Youth: Saw/Talked to Psychologist About Major Depressive Episode in Past Year |
        | 14 | YPSY2MDE | Youth: Saw/Talked to Psychiatrist About Major Depressive Episode in Past Year |
        | 15 | YHLTMDE | Youth: Saw/Talked to Health Professional About Major Depressive Episode in Past Year |
        | 16 | YDOCMDE | Youth: Saw/Talked to General Practitioner/Family Doctor About Major Depressive Episode in Past Year |
        | 17 | YTXMDEYR | Youth: Saw or Talked to Doctor/Professional for Major Depressive Episode in Past Year |
        | 18 | YUSUITHKYR | Youth: Seriously Thought About Killing Self in Past Year |
        | 19 | YUSUIPLNYR | Youth: Made Plans to Kill Self in Past Year |
        | 20 | YUSUITHK | Youth: Seriously Thought About Killing Self in Past 12 Months |
        | 21 | YUSUIPLN | Youth: Made Plans to Kill Yourself in Past 12 Months |
        | 22 | MDEIMPY | Youth: Major Depressive Episode with Severe Role Impairment |
        | 23 | LVLDIFMEM2 | Level of Difficulty Remembering or Concentrating |
        | 24 | YMSUD5YANY | Youth: Past Year Major Depressive Episode and Substance Use Disorder - DSM-5 - Any |
        | 25 | YRXMDEYR | Youth: Used Prescription Medication for Major Depressive Episode in Past Year |
        
        More information can be found at:
        - [NSDUH 2021 Codebook](https://www.datafiles.samhsa.gov/sites/default/files/field-uploads-protected/studies/NSDUH-2021/NSDUH-2021-datasets/NSDUH-2021-DS0001/NSDUH-2021-DS0001-info/NSDUH-2021-DS0001-info-codebook.pdf)
        - [NSDUH 2022 Codebook](https://www.datafiles.samhsa.gov/sites/default/files/field-uploads-protected/studies/NSDUH-2022/NSDUH-2022-datasets/NSDUH-2022-DS0001/NSDUH-2022-DS0001-info/NSDUH-2022-DS0001-info-codebook.pdf)
        """

    # NOTE(review): glossary_display is never referenced anywhere in this file
    # (the tab below renders the string directly); kept to avoid changing the
    # module-level interface — consider removing.
    def glossary_display():
        return feature_label_table_glossary_content


    # ======== TAB 3: Glossary ========
    with gr.Tab("Feature and label description"):
        gr.Markdown(feature_label_table_glossary_content)

    # ======== TAB 4: Summary statistics image (hosted on Hugging Face) ========
    with gr.Tab("Summary Statistics"):
        gr.Markdown("![Summary Statistics Table](https://huggingface.co/spaces/pantdipendra/AdolescentsMentalHealthPrediction/resolve/main/Table111.jpg)")

# Launch the Gradio app
demo.launch()