pantdipendra
committed on
v4
Browse files
app.py
CHANGED
@@ -7,8 +7,14 @@ import plotly.express as px
|
|
7 |
######################################
|
8 |
# 1) LOAD DATA & MODELS
|
9 |
######################################
|
|
|
10 |
df = pd.read_csv("X_train_test_combined_dataset_Filtered_dataset.csv")
|
11 |
|
|
|
|
|
|
|
|
|
|
|
12 |
model_filenames = [
|
13 |
"YOWRCONC.pkl", "YOSEEDOC.pkl", "YO_MDEA5.pkl", "YOWRLSIN.pkl",
|
14 |
"YODPPROB.pkl", "YOWRPROB.pkl", "YODPR2WK.pkl", "YOWRDEPR.pkl",
|
@@ -17,7 +23,6 @@ model_filenames = [
|
|
17 |
]
|
18 |
model_path = "models/"
|
19 |
|
20 |
-
|
21 |
######################################
|
22 |
# 2) MODEL PREDICTOR
|
23 |
######################################
|
@@ -38,7 +43,7 @@ class ModelPredictor:
|
|
38 |
"YODPPROB": ["No other problems for 2+ weeks", "Had other problems for 2+ weeks"],
|
39 |
"YOWRPROB": ["No 'worst time ever' feeling", "Had 'worst time ever' feeling"],
|
40 |
"YODPR2WK": ["No depressed feelings for 2+ wks", "Had depressed feelings for 2+ wks"],
|
41 |
-
"YOWRDEPR": ["Did NOT feel sad/depressed daily",
|
42 |
"YODPDISC": ["Overall mood not sad/depressed", "Overall mood was sad/depressed"],
|
43 |
"YOLOSEV": ["Did NOT lose interest in things", "Lost interest in enjoyable things"],
|
44 |
"YOWRDCSN": ["Was able to make decisions", "Was unable to make decisions"],
|
@@ -54,9 +59,14 @@ class ModelPredictor:
|
|
54 |
def load_models(self):
|
55 |
loaded = []
|
56 |
for fname in self.model_filenames:
|
57 |
-
|
58 |
-
|
59 |
-
|
|
|
|
|
|
|
|
|
|
|
60 |
return loaded
|
61 |
|
62 |
def make_predictions(self, user_input: pd.DataFrame):
|
@@ -91,17 +101,14 @@ class ModelPredictor:
|
|
91 |
else:
|
92 |
return "Mental Health Severity: Very Low"
|
93 |
|
94 |
-
|
95 |
predictor = ModelPredictor(model_path, model_filenames)
|
96 |
|
97 |
-
|
98 |
######################################
|
99 |
# 3) FEATURE CATEGORIES + MAPPING
|
100 |
######################################
|
101 |
-
# Replaced 'YMDESUD5ANYO' with 'YMDESUD5ANY' to match your CSV
|
102 |
categories_dict = {
|
103 |
"1. Depression & Substance Use Diagnosis": [
|
104 |
-
"
|
105 |
"YMSUD5YANY", "YMIUD5YANY", "YMIMS5YANY", "YMIMI5YANY"
|
106 |
],
|
107 |
"2. Mental Health Treatment & Prof Consultation": [
|
@@ -116,9 +123,13 @@ categories_dict = {
|
|
116 |
]
|
117 |
}
|
118 |
|
119 |
-
# Again, replaced 'YMDESUD5ANYO' with 'YMDESUD5ANY'
|
120 |
input_mapping = {
|
121 |
-
'
|
|
|
|
|
|
|
|
|
|
|
122 |
'YMDELT': {"Yes": 1, "No": 2},
|
123 |
'YMDEYR': {"Yes": 1, "No": 2},
|
124 |
'YMDERSUD5ANY': {"Yes": 1, "No": 0},
|
@@ -140,7 +151,11 @@ input_mapping = {
|
|
140 |
'YCOUNMDE': {"Yes": 1, "No": 0},
|
141 |
|
142 |
'MDEIMPY': {"Yes": 1, "No": 2},
|
143 |
-
'LVLDIFMEM2': {
|
|
|
|
|
|
|
|
|
144 |
|
145 |
'YUSUITHK': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
|
146 |
'YUSUITHKYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
|
@@ -148,10 +163,9 @@ input_mapping = {
|
|
148 |
'YUSUIPLN': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4}
|
149 |
}
|
150 |
|
151 |
-
|
152 |
def validate_inputs(*args):
|
153 |
for arg in args:
|
154 |
-
if
|
155 |
return False
|
156 |
return True
|
157 |
|
@@ -209,13 +223,12 @@ def get_nearest_neighbors_info(user_input_df: pd.DataFrame, k=5):
|
|
209 |
lines.append("")
|
210 |
return "\n".join(lines)
|
211 |
|
212 |
-
|
213 |
######################################
|
214 |
# 5) PREDICT FUNCTION
|
215 |
######################################
|
216 |
def predict(
|
217 |
# Category 1 (8):
|
218 |
-
|
219 |
YMSUD5YANY, YMIUD5YANY, YMIMS5YANY, YMIMI5YANY,
|
220 |
# Category 2 (11):
|
221 |
YMDEHPO, YMDETXRX, YMDEHARX, YMDEHPRX, YRXMDEYR,
|
@@ -227,7 +240,7 @@ def predict(
|
|
227 |
):
|
228 |
# 1) Validate
|
229 |
if not validate_inputs(
|
230 |
-
|
231 |
YMSUD5YANY, YMIUD5YANY, YMIMS5YANY, YMIMI5YANY,
|
232 |
YMDEHPO, YMDETXRX, YMDEHARX, YMDEHPRX, YRXMDEYR,
|
233 |
YHLTMDE, YTXMDEYR, YDOCMDE, YPSY2MDE, YPSY1MDE, YCOUNMDE,
|
@@ -235,49 +248,71 @@ def predict(
|
|
235 |
YUSUITHK, YUSUITHKYR, YUSUIPLNYR, YUSUIPLN
|
236 |
):
|
237 |
return (
|
238 |
-
"Please select all required fields.",
|
239 |
-
"Validation Error",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
240 |
"No data",
|
241 |
"No nearest neighbors info",
|
242 |
None,
|
243 |
None
|
244 |
)
|
245 |
|
246 |
-
# 2) Convert text -> numeric
|
247 |
-
user_input_dict = {
|
248 |
-
'YMDESUD5ANY': input_mapping['YMDESUD5ANY'][YMDESUD5ANY],
|
249 |
-
'YMDELT': input_mapping['YMDELT'][YMDELT],
|
250 |
-
'YMDEYR': input_mapping['YMDEYR'][YMDEYR],
|
251 |
-
'YMDERSUD5ANY': input_mapping['YMDERSUD5ANY'][YMDERSUD5ANY],
|
252 |
-
'YMSUD5YANY': input_mapping['YMSUD5YANY'][YMSUD5YANY],
|
253 |
-
'YMIUD5YANY': input_mapping['YMIUD5YANY'][YMIUD5YANY],
|
254 |
-
'YMIMS5YANY': input_mapping['YMIMS5YANY'][YMIMS5YANY],
|
255 |
-
'YMIMI5YANY': input_mapping['YMIMI5YANY'][YMIMI5YANY],
|
256 |
-
|
257 |
-
'YMDEHPO': input_mapping['YMDEHPO'][YMDEHPO],
|
258 |
-
'YMDETXRX': input_mapping['YMDETXRX'][YMDETXRX],
|
259 |
-
'YMDEHARX': input_mapping['YMDEHARX'][YMDEHARX],
|
260 |
-
'YMDEHPRX': input_mapping['YMDEHPRX'][YMDEHPRX],
|
261 |
-
'YRXMDEYR': input_mapping['YRXMDEYR'][YRXMDEYR],
|
262 |
-
'YHLTMDE': input_mapping['YHLTMDE'][YHLTMDE],
|
263 |
-
'YTXMDEYR': input_mapping['YTXMDEYR'][YTXMDEYR],
|
264 |
-
'YDOCMDE': input_mapping['YDOCMDE'][YDOCMDE],
|
265 |
-
'YPSY2MDE': input_mapping['YPSY2MDE'][YPSY2MDE],
|
266 |
-
'YPSY1MDE': input_mapping['YPSY1MDE'][YPSY1MDE],
|
267 |
-
'YCOUNMDE': input_mapping['YCOUNMDE'][YCOUNMDE],
|
268 |
-
|
269 |
-
'MDEIMPY': input_mapping['MDEIMPY'][MDEIMPY],
|
270 |
-
'LVLDIFMEM2': input_mapping['LVLDIFMEM2'][LVLDIFMEM2],
|
271 |
-
|
272 |
-
'YUSUITHK': input_mapping['YUSUITHK'][YUSUITHK],
|
273 |
-
'YUSUITHKYR': input_mapping['YUSUITHKYR'][YUSUITHKYR],
|
274 |
-
'YUSUIPLNYR': input_mapping['YUSUIPLNYR'][YUSUIPLNYR],
|
275 |
-
'YUSUIPLN': input_mapping['YUSUIPLN'][YUSUIPLN]
|
276 |
-
}
|
277 |
user_df = pd.DataFrame(user_input_dict, index=[0])
|
278 |
|
279 |
# 3) Make predictions
|
280 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
281 |
|
282 |
# Flatten predictions for severity count
|
283 |
all_preds = np.concatenate(preds)
|
@@ -295,13 +330,13 @@ def predict(
|
|
295 |
|
296 |
# Group them by domain
|
297 |
domain_groups = {
|
298 |
-
"
|
299 |
-
"
|
300 |
-
"
|
301 |
"YOWRCHR", "YOWRLSIN", "YOWRDEPR", "YODPDISC", "YOLOSEV", "YODPLSIN", "YODSCEV"
|
302 |
],
|
303 |
-
"
|
304 |
-
"
|
305 |
"YODPPROB", "YOWRPROB", "YODPR2WK", "YODSMMDE", "YOPB2WK"
|
306 |
]
|
307 |
}
|
@@ -320,14 +355,13 @@ def predict(
|
|
320 |
if not np.isnan(prob_val):
|
321 |
text_prob = f"(Prob= {prob_val:.2f})"
|
322 |
else:
|
323 |
-
text_prob = "(No
|
324 |
|
325 |
group_lines.append(f"{lbl} => {text_pred} {text_prob}")
|
326 |
if group_lines:
|
327 |
-
|
328 |
-
final_str_parts.append(f"**{gtitle}**")
|
329 |
final_str_parts.append("\n".join(group_lines))
|
330 |
-
final_str_parts.append("")
|
331 |
|
332 |
if final_str_parts:
|
333 |
final_str = "\n".join(final_str_parts)
|
@@ -345,8 +379,10 @@ def predict(
|
|
345 |
for col, val_ in user_input_dict.items():
|
346 |
matched = len(df[df[col] == val_])
|
347 |
input_counts[col] = matched
|
348 |
-
bar_in_df = pd.DataFrame({
|
349 |
-
|
|
|
|
|
350 |
fig_in = px.bar(
|
351 |
bar_in_df, x="Feature", y="Count",
|
352 |
title="Number of Patients with the Same Input Feature Values"
|
@@ -376,12 +412,11 @@ def predict(
|
|
376 |
final_str, # 1) Prediction Results
|
377 |
severity_msg, # 2) Mental Health Severity
|
378 |
total_count_md, # 3) Total Patient Count
|
379 |
-
nn_md, # 4) Nearest Neighbors
|
380 |
fig_in, # 5) Bar Chart (input features)
|
381 |
fig_lbl # 6) Bar Chart (labels)
|
382 |
)
|
383 |
|
384 |
-
|
385 |
######################################
|
386 |
# 6) UNIFIED DISTRIBUTION/CO-OCCURRENCE
|
387 |
######################################
|
@@ -399,8 +434,13 @@ def combined_plot(feature_list, label_col):
|
|
399 |
if f_ not in df.columns or label_col not in df.columns:
|
400 |
return px.bar(title="Selected columns not found in the dataset.")
|
401 |
grouped = df.groupby([f_, label_col]).size().reset_index(name="count")
|
402 |
-
fig = px.bar(
|
403 |
-
|
|
|
|
|
|
|
|
|
|
|
404 |
fig.update_layout(width=1200, height=600)
|
405 |
return fig
|
406 |
|
@@ -410,8 +450,12 @@ def combined_plot(feature_list, label_col):
|
|
410 |
return px.bar(title="Selected columns not found in the dataset.")
|
411 |
grouped = df.groupby([f1, f2, label_col]).size().reset_index(name="count")
|
412 |
fig = px.bar(
|
413 |
-
grouped,
|
414 |
-
|
|
|
|
|
|
|
|
|
415 |
)
|
416 |
fig.update_layout(width=1200, height=600)
|
417 |
return fig
|
@@ -419,20 +463,19 @@ def combined_plot(feature_list, label_col):
|
|
419 |
else:
|
420 |
return px.bar(title="Please select exactly 1 or 2 features.")
|
421 |
|
422 |
-
|
423 |
######################################
|
424 |
# 7) BUILD GRADIO UI
|
425 |
######################################
|
426 |
with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
|
427 |
|
428 |
-
# TAB 1: Prediction
|
429 |
with gr.Tab("Prediction"):
|
430 |
gr.Markdown("### Please provide inputs in each of the four categories below. All fields are required.")
|
431 |
|
432 |
-
# Category 1
|
433 |
gr.Markdown("#### 1. Depression & Substance Use Diagnosis")
|
434 |
cat1_col_labels = [
|
435 |
-
("
|
436 |
("YMDELT", "YMDELT: Had major depressive episode in lifetime"),
|
437 |
("YMDEYR", "YMDEYR: Past-year major depressive episode"),
|
438 |
("YMDERSUD5ANY", "YMDERSUD5ANY: MDE or SUD in past year?"),
|
@@ -444,10 +487,13 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
|
|
444 |
cat1_inputs = []
|
445 |
for col, label_text in cat1_col_labels:
|
446 |
cat1_inputs.append(
|
447 |
-
gr.Dropdown(
|
|
|
|
|
|
|
448 |
)
|
449 |
|
450 |
-
# Category 2
|
451 |
gr.Markdown("#### 2. Mental Health Treatment & Professional Consultation")
|
452 |
cat2_col_labels = [
|
453 |
("YMDEHPO", "YMDEHPO: Saw health prof only for MDE"),
|
@@ -465,10 +511,13 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
|
|
465 |
cat2_inputs = []
|
466 |
for col, label_text in cat2_col_labels:
|
467 |
cat2_inputs.append(
|
468 |
-
gr.Dropdown(
|
|
|
|
|
|
|
469 |
)
|
470 |
|
471 |
-
# Category 3
|
472 |
gr.Markdown("#### 3. Functional & Cognitive Impairment")
|
473 |
cat3_col_labels = [
|
474 |
("MDEIMPY", "MDEIMPY: MDE with severe role impairment?"),
|
@@ -477,10 +526,13 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
|
|
477 |
cat3_inputs = []
|
478 |
for col, label_text in cat3_col_labels:
|
479 |
cat3_inputs.append(
|
480 |
-
gr.Dropdown(
|
|
|
|
|
|
|
481 |
)
|
482 |
|
483 |
-
# Category 4
|
484 |
gr.Markdown("#### 4. Suicidal Thoughts & Behaviors")
|
485 |
cat4_col_labels = [
|
486 |
("YUSUITHK", "YUSUITHK: Thought of killing self (past 12 months)?"),
|
@@ -491,12 +543,16 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
|
|
491 |
cat4_inputs = []
|
492 |
for col, label_text in cat4_col_labels:
|
493 |
cat4_inputs.append(
|
494 |
-
gr.Dropdown(
|
|
|
|
|
|
|
495 |
)
|
496 |
|
497 |
-
# Combine in the
|
498 |
all_inputs = cat1_inputs + cat2_inputs + cat3_inputs + cat4_inputs
|
499 |
|
|
|
500 |
predict_btn = gr.Button("Predict")
|
501 |
|
502 |
out_pred_res = gr.Textbox(label="Prediction Results (with Probability)", lines=8)
|
@@ -506,6 +562,7 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
|
|
506 |
out_bar_input= gr.Plot(label="Input Feature Counts")
|
507 |
out_bar_label= gr.Plot(label="Predicted Label Counts")
|
508 |
|
|
|
509 |
predict_btn.click(
|
510 |
fn=predict,
|
511 |
inputs=all_inputs,
|
@@ -522,8 +579,8 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
|
|
522 |
# ======== TAB 2: Unified Distribution/Co-occurrence ========
|
523 |
with gr.Tab("Distribution/Co-occurrence"):
|
524 |
gr.Markdown("### Select 1 or 2 features + 1 label to see a bar chart.")
|
525 |
-
|
526 |
-
#
|
527 |
list_of_features = sorted(df.columns)
|
528 |
list_of_labels = sorted(predictor.prediction_map.keys())
|
529 |
|
@@ -545,5 +602,5 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
|
|
545 |
outputs=combined_output
|
546 |
)
|
547 |
|
548 |
-
# Finally, launch
|
549 |
demo.launch()
|
|
|
7 |
######################################
|
8 |
# 1) LOAD DATA & MODELS
|
9 |
######################################
|
10 |
+
# Load your dataset
|
11 |
df = pd.read_csv("X_train_test_combined_dataset_Filtered_dataset.csv")
|
12 |
|
13 |
+
# Ensure 'YMDESUD5ANYO' exists in your DataFrame
|
14 |
+
if 'YMDESUD5ANYO' not in df.columns:
|
15 |
+
raise ValueError("The column 'YMDESUD5ANYO' is missing from the dataset. Please check your CSV file.")
|
16 |
+
|
17 |
+
# List of model filenames
|
18 |
model_filenames = [
|
19 |
"YOWRCONC.pkl", "YOSEEDOC.pkl", "YO_MDEA5.pkl", "YOWRLSIN.pkl",
|
20 |
"YODPPROB.pkl", "YOWRPROB.pkl", "YODPR2WK.pkl", "YOWRDEPR.pkl",
|
|
|
23 |
]
|
24 |
model_path = "models/"
|
25 |
|
|
|
26 |
######################################
|
27 |
# 2) MODEL PREDICTOR
|
28 |
######################################
|
|
|
43 |
"YODPPROB": ["No other problems for 2+ weeks", "Had other problems for 2+ weeks"],
|
44 |
"YOWRPROB": ["No 'worst time ever' feeling", "Had 'worst time ever' feeling"],
|
45 |
"YODPR2WK": ["No depressed feelings for 2+ wks", "Had depressed feelings for 2+ wks"],
|
46 |
+
"YOWRDEPR": ["Did NOT feel sad/depressed daily", "Felt sad/depressed mostly everyday"],
|
47 |
"YODPDISC": ["Overall mood not sad/depressed", "Overall mood was sad/depressed"],
|
48 |
"YOLOSEV": ["Did NOT lose interest in things", "Lost interest in enjoyable things"],
|
49 |
"YOWRDCSN": ["Was able to make decisions", "Was unable to make decisions"],
|
|
|
59 |
def load_models(self):
    """Unpickle every model listed in ``self.model_filenames``.

    Each filename is resolved relative to ``self.model_path`` and the files
    are loaded in list order, so the returned list lines up index-for-index
    with ``self.model_filenames``.

    Returns:
        list: The unpickled objects, one per filename.

    Raises:
        FileNotFoundError: If a listed model file does not exist.
        Exception: If a file exists but cannot be opened or unpickled; the
            underlying error is attached as ``__cause__``.
    """
    # Local import keeps this block self-contained; the file's top-level
    # import list is not guaranteed to include ``os``.
    import os

    loaded = []
    for fname in self.model_filenames:
        # os.path.join works whether or not model_path ends with a separator,
        # unlike plain string concatenation.
        full_path = os.path.join(self.model_path, fname)
        try:
            # NOTE(review): pickle.load can execute arbitrary code embedded in
            # the file. Only load model files from a trusted source.
            with open(full_path, "rb") as f:
                model = pickle.load(f)
            loaded.append(model)
        except FileNotFoundError as err:
            # Re-raise with a friendlier message but keep the original cause.
            raise FileNotFoundError(
                f"Model file '{fname}' not found in path '{self.model_path}'."
            ) from err
        except Exception as err:
            raise Exception(f"Error loading model '{fname}': {err}") from err
    return loaded
|
71 |
|
72 |
def make_predictions(self, user_input: pd.DataFrame):
|
|
|
101 |
else:
|
102 |
return "Mental Health Severity: Very Low"
|
103 |
|
|
|
104 |
predictor = ModelPredictor(model_path, model_filenames)
|
105 |
|
|
|
106 |
######################################
|
107 |
# 3) FEATURE CATEGORIES + MAPPING
|
108 |
######################################
|
|
|
109 |
categories_dict = {
|
110 |
"1. Depression & Substance Use Diagnosis": [
|
111 |
+
"YMDESUD5ANYO", "YMDELT", "YMDEYR", "YMDERSUD5ANY",
|
112 |
"YMSUD5YANY", "YMIUD5YANY", "YMIMS5YANY", "YMIMI5YANY"
|
113 |
],
|
114 |
"2. Mental Health Treatment & Prof Consultation": [
|
|
|
123 |
]
|
124 |
}
|
125 |
|
|
|
126 |
input_mapping = {
|
127 |
+
'YMDESUD5ANYO': {
|
128 |
+
"SUD only, no MDE": 1,
|
129 |
+
"MDE only, no SUD": 2,
|
130 |
+
"SUD and MDE": 3,
|
131 |
+
"Neither SUD or MDE": 4
|
132 |
+
},
|
133 |
'YMDELT': {"Yes": 1, "No": 2},
|
134 |
'YMDEYR': {"Yes": 1, "No": 2},
|
135 |
'YMDERSUD5ANY': {"Yes": 1, "No": 0},
|
|
|
151 |
'YCOUNMDE': {"Yes": 1, "No": 0},
|
152 |
|
153 |
'MDEIMPY': {"Yes": 1, "No": 2},
|
154 |
+
'LVLDIFMEM2': {
|
155 |
+
"No Difficulty": 1,
|
156 |
+
"Some difficulty": 2,
|
157 |
+
"A lot of difficulty or cannot do at all": 3
|
158 |
+
},
|
159 |
|
160 |
'YUSUITHK': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
|
161 |
'YUSUITHKYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
|
|
|
163 |
'YUSUIPLN': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4}
|
164 |
}
|
165 |
|
|
|
166 |
def validate_inputs(*args):
    """Report whether every supplied form value has been filled in.

    A value counts as missing when it is ``None`` or equal to the empty
    string; any other value (including ``0`` or ``False``) is accepted.

    Args:
        *args: The raw values to check.

    Returns:
        bool: ``True`` when no argument is missing (also when called with no
        arguments), ``False`` as soon as one missing value is found.
    """
    # any() short-circuits on the first missing value, like the explicit loop.
    return not any(arg is None or arg == "" for arg in args)
|
171 |
|
|
|
223 |
lines.append("")
|
224 |
return "\n".join(lines)
|
225 |
|
|
|
226 |
######################################
|
227 |
# 5) PREDICT FUNCTION
|
228 |
######################################
|
229 |
def predict(
|
230 |
# Category 1 (8):
|
231 |
+
YMDESUD5ANYO, YMDELT, YMDEYR, YMDERSUD5ANY,
|
232 |
YMSUD5YANY, YMIUD5YANY, YMIMS5YANY, YMIMI5YANY,
|
233 |
# Category 2 (11):
|
234 |
YMDEHPO, YMDETXRX, YMDEHARX, YMDEHPRX, YRXMDEYR,
|
|
|
240 |
):
|
241 |
# 1) Validate
|
242 |
if not validate_inputs(
|
243 |
+
YMDESUD5ANYO, YMDELT, YMDEYR, YMDERSUD5ANY,
|
244 |
YMSUD5YANY, YMIUD5YANY, YMIMS5YANY, YMIMI5YANY,
|
245 |
YMDEHPO, YMDETXRX, YMDEHARX, YMDEHPRX, YRXMDEYR,
|
246 |
YHLTMDE, YTXMDEYR, YDOCMDE, YPSY2MDE, YPSY1MDE, YCOUNMDE,
|
|
|
248 |
YUSUITHK, YUSUITHKYR, YUSUIPLNYR, YUSUIPLN
|
249 |
):
|
250 |
return (
|
251 |
+
"Please select all required fields.", # 1) Prediction Results
|
252 |
+
"Validation Error", # 2) Severity
|
253 |
+
"No data", # 3) Total Count
|
254 |
+
"No nearest neighbors info", # 4) NN Summary
|
255 |
+
None, # 5) Bar chart (Input)
|
256 |
+
None # 6) Bar chart (Labels)
|
257 |
+
)
|
258 |
+
|
259 |
+
# 2) Convert text -> numeric
|
260 |
+
try:
|
261 |
+
user_input_dict = {
|
262 |
+
'YMDESUD5ANYO': input_mapping['YMDESUD5ANYO'][YMDESUD5ANYO],
|
263 |
+
'YMDELT': input_mapping['YMDELT'][YMDELT],
|
264 |
+
'YMDEYR': input_mapping['YMDEYR'][YMDEYR],
|
265 |
+
'YMDERSUD5ANY': input_mapping['YMDERSUD5ANY'][YMDERSUD5ANY],
|
266 |
+
'YMSUD5YANY': input_mapping['YMSUD5YANY'][YMSUD5YANY],
|
267 |
+
'YMIUD5YANY': input_mapping['YMIUD5YANY'][YMIUD5YANY],
|
268 |
+
'YMIMS5YANY': input_mapping['YMIMS5YANY'][YMIMS5YANY],
|
269 |
+
'YMIMI5YANY': input_mapping['YMIMI5YANY'][YMIMI5YANY],
|
270 |
+
|
271 |
+
'YMDEHPO': input_mapping['YMDEHPO'][YMDEHPO],
|
272 |
+
'YMDETXRX': input_mapping['YMDETXRX'][YMDETXRX],
|
273 |
+
'YMDEHARX': input_mapping['YMDEHARX'][YMDEHARX],
|
274 |
+
'YMDEHPRX': input_mapping['YMDEHPRX'][YMDEHPRX],
|
275 |
+
'YRXMDEYR': input_mapping['YRXMDEYR'][YRXMDEYR],
|
276 |
+
'YHLTMDE': input_mapping['YHLTMDE'][YHLTMDE],
|
277 |
+
'YTXMDEYR': input_mapping['YTXMDEYR'][YTXMDEYR],
|
278 |
+
'YDOCMDE': input_mapping['YDOCMDE'][YDOCMDE],
|
279 |
+
'YPSY2MDE': input_mapping['YPSY2MDE'][YPSY2MDE],
|
280 |
+
'YPSY1MDE': input_mapping['YPSY1MDE'][YPSY1MDE],
|
281 |
+
'YCOUNMDE': input_mapping['YCOUNMDE'][YCOUNMDE],
|
282 |
+
|
283 |
+
'MDEIMPY': input_mapping['MDEIMPY'][MDEIMPY],
|
284 |
+
'LVLDIFMEM2': input_mapping['LVLDIFMEM2'][LVLDIFMEM2],
|
285 |
+
|
286 |
+
'YUSUITHK': input_mapping['YUSUITHK'][YUSUITHK],
|
287 |
+
'YUSUITHKYR': input_mapping['YUSUITHKYR'][YUSUITHKYR],
|
288 |
+
'YUSUIPLNYR': input_mapping['YUSUIPLNYR'][YUSUIPLNYR],
|
289 |
+
'YUSUIPLN': input_mapping['YUSUIPLN'][YUSUIPLN]
|
290 |
+
}
|
291 |
+
except KeyError as e:
|
292 |
+
missing_key = e.args[0]
|
293 |
+
return (
|
294 |
+
f"Input mapping missing for key: {missing_key}. Please check your `input_mapping` dictionary.",
|
295 |
+
"Mapping Error",
|
296 |
"No data",
|
297 |
"No nearest neighbors info",
|
298 |
None,
|
299 |
None
|
300 |
)
|
301 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
302 |
user_df = pd.DataFrame(user_input_dict, index=[0])
|
303 |
|
304 |
# 3) Make predictions
|
305 |
+
try:
|
306 |
+
preds, probs = predictor.make_predictions(user_df)
|
307 |
+
except Exception as e:
|
308 |
+
return (
|
309 |
+
f"Error during prediction: {e}",
|
310 |
+
"Prediction Error",
|
311 |
+
"No data",
|
312 |
+
"No nearest neighbors info",
|
313 |
+
None,
|
314 |
+
None
|
315 |
+
)
|
316 |
|
317 |
# Flatten predictions for severity count
|
318 |
all_preds = np.concatenate(preds)
|
|
|
330 |
|
331 |
# Group them by domain
|
332 |
domain_groups = {
|
333 |
+
"Concentration and Decision Making": ["YOWRCONC", "YOWRDCSN"],
|
334 |
+
"Sleep and Energy Levels": ["YOWRHRS", "YO_MDEA5", "YOWRELES", "YO_MDEA2"],
|
335 |
+
"Mood and Emotional State": [
|
336 |
"YOWRCHR", "YOWRLSIN", "YOWRDEPR", "YODPDISC", "YOLOSEV", "YODPLSIN", "YODSCEV"
|
337 |
],
|
338 |
+
"Appetite and Weight Changes": ["YO_MDEA3", "YOWRELES"],
|
339 |
+
"Duration and Severity of Depression Symptoms": [
|
340 |
"YODPPROB", "YOWRPROB", "YODPR2WK", "YODSMMDE", "YOPB2WK"
|
341 |
]
|
342 |
}
|
|
|
355 |
if not np.isnan(prob_val):
|
356 |
text_prob = f"(Prob= {prob_val:.2f})"
|
357 |
else:
|
358 |
+
text_prob = "(No probability available)"
|
359 |
|
360 |
group_lines.append(f"{lbl} => {text_pred} {text_prob}")
|
361 |
if group_lines:
|
362 |
+
final_str_parts.append(f"**{gname}**")
|
|
|
363 |
final_str_parts.append("\n".join(group_lines))
|
364 |
+
final_str_parts.append("") # Add an empty line for spacing
|
365 |
|
366 |
if final_str_parts:
|
367 |
final_str = "\n".join(final_str_parts)
|
|
|
379 |
for col, val_ in user_input_dict.items():
|
380 |
matched = len(df[df[col] == val_])
|
381 |
input_counts[col] = matched
|
382 |
+
bar_in_df = pd.DataFrame({
|
383 |
+
"Feature": list(input_counts.keys()),
|
384 |
+
"Count": list(input_counts.values())
|
385 |
+
})
|
386 |
fig_in = px.bar(
|
387 |
bar_in_df, x="Feature", y="Count",
|
388 |
title="Number of Patients with the Same Input Feature Values"
|
|
|
412 |
final_str, # 1) Prediction Results
|
413 |
severity_msg, # 2) Mental Health Severity
|
414 |
total_count_md, # 3) Total Patient Count
|
415 |
+
nn_md, # 4) Nearest Neighbors Summary
|
416 |
fig_in, # 5) Bar Chart (input features)
|
417 |
fig_lbl # 6) Bar Chart (labels)
|
418 |
)
|
419 |
|
|
|
420 |
######################################
|
421 |
# 6) UNIFIED DISTRIBUTION/CO-OCCURRENCE
|
422 |
######################################
|
|
|
434 |
if f_ not in df.columns or label_col not in df.columns:
|
435 |
return px.bar(title="Selected columns not found in the dataset.")
|
436 |
grouped = df.groupby([f_, label_col]).size().reset_index(name="count")
|
437 |
+
fig = px.bar(
|
438 |
+
grouped,
|
439 |
+
x=f_,
|
440 |
+
y="count",
|
441 |
+
color=label_col,
|
442 |
+
title=f"Distribution of {f_} vs {label_col}"
|
443 |
+
)
|
444 |
fig.update_layout(width=1200, height=600)
|
445 |
return fig
|
446 |
|
|
|
450 |
return px.bar(title="Selected columns not found in the dataset.")
|
451 |
grouped = df.groupby([f1, f2, label_col]).size().reset_index(name="count")
|
452 |
fig = px.bar(
|
453 |
+
grouped,
|
454 |
+
x=f1,
|
455 |
+
y="count",
|
456 |
+
color=label_col,
|
457 |
+
facet_col=f2,
|
458 |
+
title=f"Co-occurrence: {f1}, {f2} vs {label_col}"
|
459 |
)
|
460 |
fig.update_layout(width=1200, height=600)
|
461 |
return fig
|
|
|
463 |
else:
|
464 |
return px.bar(title="Please select exactly 1 or 2 features.")
|
465 |
|
|
|
466 |
######################################
|
467 |
# 7) BUILD GRADIO UI
|
468 |
######################################
|
469 |
with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
|
470 |
|
471 |
+
# ======== TAB 1: Prediction ========
|
472 |
with gr.Tab("Prediction"):
|
473 |
gr.Markdown("### Please provide inputs in each of the four categories below. All fields are required.")
|
474 |
|
475 |
+
# Category 1: Depression & Substance Use Diagnosis (8 features)
|
476 |
gr.Markdown("#### 1. Depression & Substance Use Diagnosis")
|
477 |
cat1_col_labels = [
|
478 |
+
("YMDESUD5ANYO", "YMDESUD5ANYO: ONLY MDE, ONLY SUD, BOTH, OR NEITHER"),
|
479 |
("YMDELT", "YMDELT: Had major depressive episode in lifetime"),
|
480 |
("YMDEYR", "YMDEYR: Past-year major depressive episode"),
|
481 |
("YMDERSUD5ANY", "YMDERSUD5ANY: MDE or SUD in past year?"),
|
|
|
487 |
cat1_inputs = []
|
488 |
for col, label_text in cat1_col_labels:
|
489 |
cat1_inputs.append(
|
490 |
+
gr.Dropdown(
|
491 |
+
choices=list(input_mapping[col].keys()),
|
492 |
+
label=label_text
|
493 |
+
)
|
494 |
)
|
495 |
|
496 |
+
# Category 2: Mental Health Treatment & Professional Consultation (11 features)
|
497 |
gr.Markdown("#### 2. Mental Health Treatment & Professional Consultation")
|
498 |
cat2_col_labels = [
|
499 |
("YMDEHPO", "YMDEHPO: Saw health prof only for MDE"),
|
|
|
511 |
cat2_inputs = []
|
512 |
for col, label_text in cat2_col_labels:
|
513 |
cat2_inputs.append(
|
514 |
+
gr.Dropdown(
|
515 |
+
choices=list(input_mapping[col].keys()),
|
516 |
+
label=label_text
|
517 |
+
)
|
518 |
)
|
519 |
|
520 |
+
# Category 3: Functional & Cognitive Impairment (2 features)
|
521 |
gr.Markdown("#### 3. Functional & Cognitive Impairment")
|
522 |
cat3_col_labels = [
|
523 |
("MDEIMPY", "MDEIMPY: MDE with severe role impairment?"),
|
|
|
526 |
cat3_inputs = []
|
527 |
for col, label_text in cat3_col_labels:
|
528 |
cat3_inputs.append(
|
529 |
+
gr.Dropdown(
|
530 |
+
choices=list(input_mapping[col].keys()),
|
531 |
+
label=label_text
|
532 |
+
)
|
533 |
)
|
534 |
|
535 |
+
# Category 4: Suicidal Thoughts & Behaviors (4 features)
|
536 |
gr.Markdown("#### 4. Suicidal Thoughts & Behaviors")
|
537 |
cat4_col_labels = [
|
538 |
("YUSUITHK", "YUSUITHK: Thought of killing self (past 12 months)?"),
|
|
|
543 |
cat4_inputs = []
|
544 |
for col, label_text in cat4_col_labels:
|
545 |
cat4_inputs.append(
|
546 |
+
gr.Dropdown(
|
547 |
+
choices=list(input_mapping[col].keys()),
|
548 |
+
label=label_text
|
549 |
+
)
|
550 |
)
|
551 |
|
552 |
+
# Combine all inputs in the correct order
|
553 |
all_inputs = cat1_inputs + cat2_inputs + cat3_inputs + cat4_inputs
|
554 |
|
555 |
+
# Output components
|
556 |
predict_btn = gr.Button("Predict")
|
557 |
|
558 |
out_pred_res = gr.Textbox(label="Prediction Results (with Probability)", lines=8)
|
|
|
562 |
out_bar_input= gr.Plot(label="Input Feature Counts")
|
563 |
out_bar_label= gr.Plot(label="Predicted Label Counts")
|
564 |
|
565 |
+
# Connect the predict button to the predict function
|
566 |
predict_btn.click(
|
567 |
fn=predict,
|
568 |
inputs=all_inputs,
|
|
|
579 |
# ======== TAB 2: Unified Distribution/Co-occurrence ========
|
580 |
with gr.Tab("Distribution/Co-occurrence"):
|
581 |
gr.Markdown("### Select 1 or 2 features + 1 label to see a bar chart.")
|
582 |
+
|
583 |
+
# Features can be selected from the dataset's columns
|
584 |
list_of_features = sorted(df.columns)
|
585 |
list_of_labels = sorted(predictor.prediction_map.keys())
|
586 |
|
|
|
602 |
outputs=combined_output
|
603 |
)
|
604 |
|
605 |
+
# Finally, launch the Gradio app
|
606 |
demo.launch()
|