pantdipendra
commited on
v2_seperate tabs categories in UI
Browse files
app.py
CHANGED
@@ -27,53 +27,27 @@ class ModelPredictor:
|
|
27 |
self.model_filenames = model_filenames
|
28 |
self.models = self.load_models()
|
29 |
|
30 |
-
# The map from each label column to the textual meaning for 0 or 1
|
31 |
-
# (Some columns also mention '2' as positive, so adapt as needed).
|
32 |
self.prediction_map = {
|
33 |
"YOWRCONC": ["Did not have difficulty concentrating", "Had difficulty concentrating"],
|
34 |
"YOSEEDOC": ["Did not feel the need to see a doctor", "Felt the need to see a doctor"],
|
35 |
"YOWRHRS": ["Did not have trouble sleeping", "Had trouble sleeping"],
|
36 |
"YO_MDEA5": ["Others did not notice restlessness/lethargy", "Others noticed restlessness/lethargy"],
|
37 |
"YOWRCHR": ["Did not feel so sad nothing could cheer up", "Felt so sad that nothing could cheer up"],
|
38 |
-
"YOWRLSIN": [
|
39 |
-
"Did not feel bored / lose interest",
|
40 |
-
"Felt bored / lost interest in enjoyable things"
|
41 |
-
],
|
42 |
"YODPPROB": ["No other problems for 2+ weeks", "Had other problems for 2+ weeks"],
|
43 |
"YOWRPROB": ["No 'worst time ever' feeling", "Had 'worst time ever' feeling"],
|
44 |
-
"YODPR2WK": [
|
45 |
-
"No periods with depressed feelings lasting 2+ weeks",
|
46 |
-
"Had depressed feelings for 2+ weeks"
|
47 |
-
],
|
48 |
"YOWRDEPR": ["Did not feel sad/depressed mostly everyday", "Felt sad/depressed mostly everyday"],
|
49 |
-
"YODPDISC": [
|
50 |
-
|
51 |
-
"Overall mood duration was sad/depressed"
|
52 |
-
],
|
53 |
-
"YOLOSEV": [
|
54 |
-
"Did not lose interest in activities",
|
55 |
-
"Lost interest in enjoyable things"
|
56 |
-
],
|
57 |
"YOWRDCSN": ["Was able to make decisions", "Was unable to make decisions"],
|
58 |
-
"YODSMMDE": [
|
59 |
-
|
60 |
-
|
61 |
-
],
|
62 |
-
"YO_MDEA3": [
|
63 |
-
"No changes in appetite/weight",
|
64 |
-
"Had changes in appetite or weight"
|
65 |
-
],
|
66 |
-
"YODPLSIN": ["Never lost interest / felt bored", "Lost interest / felt bored"],
|
67 |
"YOWRELES": ["Did not eat less than usual", "Ate less than usual"],
|
68 |
"YODSCEV": ["Fewer severe depression symptoms", "More severe depression symptoms"],
|
69 |
-
"YOPB2WK":
|
70 |
-
|
71 |
-
"Uneasy feelings lasting 2+ weeks"
|
72 |
-
],
|
73 |
-
"YO_MDEA2": [
|
74 |
-
"No physical/mental issues for 2+ weeks",
|
75 |
-
"Had physical/mental issues for 2+ weeks"
|
76 |
-
]
|
77 |
}
|
78 |
|
79 |
def load_models(self):
|
@@ -85,10 +59,6 @@ class ModelPredictor:
|
|
85 |
return loaded
|
86 |
|
87 |
def make_predictions(self, user_input: pd.DataFrame):
|
88 |
-
"""
|
89 |
-
Return a list of np.ndarrays, each of shape (1,) or (n_samples,),
|
90 |
-
one for each model in self.models, in the same order as model_filenames.
|
91 |
-
"""
|
92 |
predictions = []
|
93 |
for model in self.models:
|
94 |
out = model.predict(user_input)
|
@@ -96,19 +66,10 @@ class ModelPredictor:
|
|
96 |
return predictions
|
97 |
|
98 |
def get_majority_vote(self, predictions):
|
99 |
-
"""
|
100 |
-
Flatten all predictions from each model into a single array
|
101 |
-
and compute the most common value (mode).
|
102 |
-
"""
|
103 |
combined = np.concatenate(predictions)
|
104 |
return np.bincount(combined).argmax()
|
105 |
|
106 |
def evaluate_severity(self, count_ones: int) -> str:
|
107 |
-
"""
|
108 |
-
The user wanted a logic: if >=13 => Severe, >=9 => Moderate, >=5 => Low, else Very Low.
|
109 |
-
Here 'count_ones' is how many '1's (or '2's) across all model predictions.
|
110 |
-
Adjust logic if needed.
|
111 |
-
"""
|
112 |
if count_ones >= 13:
|
113 |
return "Mental Health Severity: Severe"
|
114 |
elif count_ones >= 9:
|
@@ -123,174 +84,175 @@ predictor = ModelPredictor(model_path, model_filenames)
|
|
123 |
|
124 |
|
125 |
######################################
|
126 |
-
# 3)
|
127 |
######################################
|
128 |
-
|
129 |
-
""
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
|
|
|
|
|
|
|
|
140 |
|
141 |
-
|
142 |
-
|
143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
|
|
|
|
|
|
|
|
151 |
|
152 |
-
|
153 |
-
|
154 |
-
neighbors = df.loc[nn_indices]
|
155 |
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
f"Distances Range: {dists[nn_indices].min():.2f} to {dists[nn_indices].max():.2f}",
|
162 |
-
""]
|
163 |
-
for label_col, label_map in predictor.prediction_map.items():
|
164 |
-
if label_col not in neighbors.columns:
|
165 |
-
continue # Not present in df
|
166 |
-
# Values among neighbors
|
167 |
-
vals = neighbors[label_col].value_counts().to_dict()
|
168 |
-
# Example: {0: 3, 1: 2}, or {2: 4, 1: 1}, etc.
|
169 |
-
line = f"{label_col} => "
|
170 |
-
parts = []
|
171 |
-
for val, count_ in vals.items():
|
172 |
-
# If we have a mapping, use it
|
173 |
-
if val in range(len(label_map)):
|
174 |
-
meaning = label_map[val]
|
175 |
-
parts.append(f"{count_} had {meaning}")
|
176 |
-
else:
|
177 |
-
parts.append(f"{count_} had numeric={val}")
|
178 |
-
line += "; ".join(parts)
|
179 |
-
summary_lines.append(line)
|
180 |
-
summary_lines.append("")
|
181 |
-
summary_text = "\n".join(summary_lines)
|
182 |
-
return summary_text
|
183 |
|
184 |
|
185 |
-
######################################
|
186 |
-
# 4) INPUT MAPPING
|
187 |
-
######################################
|
188 |
def validate_inputs(*args):
|
189 |
for arg in args:
|
190 |
if not arg: # empty or None
|
191 |
return False
|
192 |
return True
|
193 |
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
222 |
|
223 |
|
224 |
######################################
|
225 |
-
# 5) PREDICT FUNCTION
|
226 |
######################################
|
227 |
def predict(
|
228 |
-
#
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
|
|
|
|
|
|
|
|
|
|
234 |
):
|
235 |
-
# 1) Validate
|
236 |
if not validate_inputs(
|
237 |
-
YMDESUD5ANYO,
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
|
|
242 |
):
|
243 |
return (
|
244 |
"Please select all required fields.", # 1) Prediction Results
|
245 |
-
"Validation Error", # 2)
|
246 |
-
"No data", # 3) Total
|
247 |
-
"No nearest neighbors info", # 4)
|
248 |
-
None, # 5) Bar
|
249 |
-
None # 6) Bar
|
250 |
)
|
251 |
|
252 |
-
#
|
253 |
user_input_dict = {
|
254 |
'YMDESUD5ANYO': input_mapping['YMDESUD5ANYO'][YMDESUD5ANYO],
|
255 |
-
'YMDEHPO': input_mapping['YMDEHPO'][YMDEHPO],
|
256 |
-
'YMDETXRX': input_mapping['YMDETXRX'][YMDETXRX],
|
257 |
-
'LVLDIFMEM2': input_mapping['LVLDIFMEM2'][LVLDIFMEM2],
|
258 |
-
'YMSUD5YANY': input_mapping['YMSUD5YANY'][YMSUD5YANY],
|
259 |
-
'YPSY2MDE': input_mapping['YPSY2MDE'][YPSY2MDE],
|
260 |
'YMDELT': input_mapping['YMDELT'][YMDELT],
|
261 |
-
'
|
|
|
|
|
|
|
|
|
262 |
'YMIMI5YANY': input_mapping['YMIMI5YANY'][YMIMI5YANY],
|
|
|
|
|
|
|
263 |
'YMDEHARX': input_mapping['YMDEHARX'][YMDEHARX],
|
264 |
-
'MDEIMPY': input_mapping['MDEIMPY'][MDEIMPY],
|
265 |
'YRXMDEYR': input_mapping['YRXMDEYR'][YRXMDEYR],
|
266 |
-
'YMDERSUD5ANY': input_mapping['YMDERSUD5ANY'][YMDERSUD5ANY],
|
267 |
-
'YMIMS5YANY': input_mapping['YMIMS5YANY'][YMIMS5YANY],
|
268 |
-
'YMDEYR': input_mapping['YMDEYR'][YMDEYR],
|
269 |
'YHLTMDE': input_mapping['YHLTMDE'][YHLTMDE],
|
270 |
-
'YUSUIPLNYR': input_mapping['YUSUIPLNYR'][YUSUIPLNYR],
|
271 |
-
'YMDEHPRX': input_mapping['YMDEHPRX'][YMDEHPRX],
|
272 |
-
'YUSUIPLN': input_mapping['YUSUIPLN'][YUSUIPLN],
|
273 |
-
'YPSY1MDE': input_mapping['YPSY1MDE'][YPSY1MDE],
|
274 |
-
'YMIUD5YANY': input_mapping['YMIUD5YANY'][YMIUD5YANY],
|
275 |
-
'YUSUITHK': input_mapping['YUSUITHK'][YUSUITHK],
|
276 |
'YTXMDEYR': input_mapping['YTXMDEYR'][YTXMDEYR],
|
|
|
|
|
|
|
277 |
'YCOUNMDE': input_mapping['YCOUNMDE'][YCOUNMDE],
|
278 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
279 |
}
|
280 |
user_df = pd.DataFrame(user_input_dict, index=[0])
|
281 |
|
282 |
-
#
|
283 |
-
predictions = predictor.make_predictions(user_df)
|
284 |
-
# e.g. predictions[i][0] is the predicted label for model i
|
285 |
-
# Flatten them for counting ones
|
286 |
all_preds = np.concatenate(predictions)
|
287 |
-
# In your logic, "1" might be a positive class, or "2" might be. Adapt if needed:
|
288 |
-
# For now, let's assume "1" is the relevant "positive" count:
|
289 |
count_ones = sum(all_preds == 1)
|
290 |
-
|
291 |
severity_msg = predictor.evaluate_severity(count_ones)
|
292 |
|
293 |
-
#
|
294 |
groups = {
|
295 |
"Concentration_and_Decision_Making": ["YOWRCONC", "YOWRDCSN"],
|
296 |
"Sleep_and_Energy_Levels": ["YOWRHRS", "YO_MDEA5", "YOWRELES", "YO_MDEA2"],
|
@@ -302,20 +264,20 @@ def predict(
|
|
302 |
"YODPPROB", "YOWRPROB", "YODPR2WK", "YODSMMDE", "YOPB2WK"
|
303 |
]
|
304 |
}
|
305 |
-
# Build text for each label in the order they appear in model_filenames
|
306 |
group_text = {g: [] for g in groups}
|
|
|
307 |
for i, arr in enumerate(predictions):
|
308 |
label_col = model_filenames[i].split('.')[0] # e.g. "YOWRCONC"
|
309 |
val = arr[0]
|
310 |
-
#
|
311 |
if label_col in predictor.prediction_map and val in range(len(predictor.prediction_map[label_col])):
|
312 |
text_label = predictor.prediction_map[label_col][val]
|
313 |
else:
|
314 |
text_label = f"Prediction={val}"
|
315 |
|
316 |
-
#
|
317 |
-
for group_name,
|
318 |
-
if label_col in
|
319 |
group_text[group_name].append(f"{label_col} => {text_label}")
|
320 |
break
|
321 |
|
@@ -325,20 +287,20 @@ def predict(
|
|
325 |
gtitle = gname.replace("_", " ")
|
326 |
final_str_parts.append(f"**{gtitle}**")
|
327 |
final_str_parts.append("\n".join(lines))
|
328 |
-
final_str_parts.append("")
|
329 |
if not final_str_parts:
|
330 |
final_str = "No predictions made or no matching group columns."
|
331 |
else:
|
332 |
final_str = "\n".join(final_str_parts)
|
333 |
|
334 |
-
#
|
335 |
total_count = len(df)
|
336 |
total_count_md = f"We have **{total_count}** patients in the dataset."
|
337 |
|
338 |
-
#
|
339 |
nn_md = get_nearest_neighbors_info(user_df, k=5)
|
340 |
|
341 |
-
#
|
342 |
input_counts = {}
|
343 |
for col, val_ in user_input_dict.items():
|
344 |
matched = len(df[df[col] == val_])
|
@@ -351,14 +313,12 @@ def predict(
|
|
351 |
)
|
352 |
fig_in.update_layout(width=1200, height=400)
|
353 |
|
354 |
-
#
|
355 |
-
# For each model’s label_col, see how many in df have the same predicted value
|
356 |
label_counts = {}
|
357 |
for i, arr in enumerate(predictions):
|
358 |
lbl = model_filenames[i].split('.')[0]
|
359 |
pred_val = arr[0]
|
360 |
if lbl in df.columns:
|
361 |
-
# How many in df have this same value
|
362 |
label_counts[lbl] = len(df[df[lbl] == pred_val])
|
363 |
if label_counts:
|
364 |
bar_lbl_df = pd.DataFrame({
|
@@ -379,8 +339,8 @@ def predict(
|
|
379 |
severity_msg, # 2) Mental Health Severity
|
380 |
total_count_md, # 3) Total Patient Count
|
381 |
nn_md, # 4) Nearest Neighbors Summary
|
382 |
-
fig_in, # 5) Bar Chart
|
383 |
-
fig_lbl # 6) Bar Chart
|
384 |
)
|
385 |
|
386 |
|
@@ -388,9 +348,6 @@ def predict(
|
|
388 |
# 6) EXTRA TABS / FUNCTIONS
|
389 |
######################################
|
390 |
def distribution_plot(feature_col, label_col):
|
391 |
-
"""
|
392 |
-
Creates a bar chart grouping by [feature_col, label_col], showing counts.
|
393 |
-
"""
|
394 |
if not feature_col or not label_col:
|
395 |
return px.bar(title="Please select both Feature and Label.")
|
396 |
if (feature_col not in df.columns) or (label_col not in df.columns):
|
@@ -409,9 +366,6 @@ def distribution_plot(feature_col, label_col):
|
|
409 |
|
410 |
|
411 |
def co_occurrence_plot(feature1, feature2, label_col):
|
412 |
-
"""
|
413 |
-
Similar approach but grouping by [feature1, feature2, label_col].
|
414 |
-
"""
|
415 |
if not feature1 or not feature2 or not label_col:
|
416 |
return px.bar(title="Please select all three fields.")
|
417 |
if feature1 not in df.columns or feature2 not in df.columns or label_col not in df.columns:
|
@@ -437,127 +391,104 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
|
|
437 |
|
438 |
# ======== TAB 1: PREDICTION ========
|
439 |
with gr.Tab("Prediction"):
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
452 |
-
),
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
-
),
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
),
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
),
|
466 |
-
|
467 |
-
|
468 |
-
|
469 |
-
)
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
)
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
),
|
478 |
-
|
479 |
-
|
480 |
-
label="YMDEHARX: Saw health professional & received medication for MDE?"
|
481 |
-
),
|
482 |
-
|
483 |
-
gr.Dropdown(
|
484 |
-
list(input_mapping['MDEIMPY'].keys()),
|
485 |
-
label="MDEIMPY: MDE with severe role impairment?"
|
486 |
-
),
|
487 |
-
gr.Dropdown(
|
488 |
-
list(input_mapping['YRXMDEYR'].keys()),
|
489 |
-
label="YRXMDEYR: Used received medication for MDE in past years?"
|
490 |
-
),
|
491 |
-
gr.Dropdown(
|
492 |
-
list(input_mapping['YMDERSUD5ANY'].keys()),
|
493 |
-
label="YMDERSUD5ANY: MDE or substance use disorder - past year?"
|
494 |
-
),
|
495 |
-
gr.Dropdown(
|
496 |
-
list(input_mapping['YMIMS5YANY'].keys()),
|
497 |
-
label="YMIMS5YANY: Past-year MDE + severe impairment + substance use?"
|
498 |
-
),
|
499 |
-
gr.Dropdown(
|
500 |
-
list(input_mapping['YMDEYR'].keys()),
|
501 |
-
label="YMDEYR: Past-year major depressive episode?"
|
502 |
-
),
|
503 |
-
|
504 |
-
gr.Dropdown(
|
505 |
-
list(input_mapping['YHLTMDE'].keys()),
|
506 |
-
label="YHLTMDE: Saw/talk to health professional about MDE in past year?"
|
507 |
-
),
|
508 |
-
gr.Dropdown(
|
509 |
-
list(input_mapping['YUSUIPLNYR'].keys()),
|
510 |
-
label="YUSUIPLNYR: Made plans to kill self in past year?"
|
511 |
-
),
|
512 |
-
gr.Dropdown(
|
513 |
-
list(input_mapping['YMDEHPRX'].keys()),
|
514 |
-
label="YMDEHPRX: Saw health prof or received med for MDE in past year?"
|
515 |
-
),
|
516 |
-
gr.Dropdown(
|
517 |
-
list(input_mapping['YUSUIPLN'].keys()),
|
518 |
-
label="YUSUIPLN: Make plans to kill yourself in past 12 months?"
|
519 |
-
),
|
520 |
-
gr.Dropdown(
|
521 |
-
list(input_mapping['YPSY1MDE'].keys()),
|
522 |
-
label="YPSY1MDE: Saw/talked to psychologist about MDE in past year?"
|
523 |
-
),
|
524 |
-
|
525 |
-
gr.Dropdown(
|
526 |
-
list(input_mapping['YMIUD5YANY'].keys()),
|
527 |
-
label="YMIUD5YANY: Past-year MDE & illicit drug use disorder?"
|
528 |
-
),
|
529 |
-
gr.Dropdown(
|
530 |
-
list(input_mapping['YUSUITHK'].keys()),
|
531 |
-
label="YUSUITHK: Youth seriously think about killing self in past 12 months?"
|
532 |
-
),
|
533 |
-
gr.Dropdown(
|
534 |
-
list(input_mapping['YTXMDEYR'].keys()),
|
535 |
-
label="YTXMDEYR: Saw or talk to doc/health prof for MDE in past year?"
|
536 |
-
),
|
537 |
-
gr.Dropdown(
|
538 |
-
list(input_mapping['YCOUNMDE'].keys()),
|
539 |
-
label="YCOUNMDE: Saw/talk to counselor about MDE in past year?"
|
540 |
-
),
|
541 |
-
gr.Dropdown(
|
542 |
-
list(input_mapping['YUSUITHKYR'].keys()),
|
543 |
-
label="YUSUITHKYR: Seriously thought about killing self?"
|
544 |
-
)
|
545 |
]
|
546 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
547 |
predict_btn = gr.Button("Predict")
|
548 |
|
549 |
-
# 6 outputs
|
550 |
out_pred_res = gr.Textbox(label="Prediction Results", lines=8)
|
551 |
out_sev = gr.Textbox(label="Mental Health Severity", lines=2)
|
552 |
out_count = gr.Markdown(label="Total Patient Count")
|
553 |
-
out_nn = gr.Markdown(label="Nearest Neighbors Summary")
|
554 |
out_bar_input= gr.Plot(label="Input Feature Counts")
|
555 |
out_bar_label= gr.Plot(label="Predicted Label Counts")
|
556 |
|
557 |
-
# Wire up the button
|
558 |
predict_btn.click(
|
559 |
fn=predict,
|
560 |
-
inputs=
|
561 |
outputs=[
|
562 |
out_pred_res, # 1
|
563 |
out_sev, # 2
|
@@ -571,10 +502,7 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
|
|
571 |
# ======== TAB 2: Distribution Analysis ========
|
572 |
with gr.Tab("Distribution Analysis"):
|
573 |
gr.Markdown("## Distribution Plot\nSelect one feature and one label column to see bar counts.")
|
574 |
-
# 1) We gather the 'input features' from input_mapping keys:
|
575 |
list_of_features = sorted(input_mapping.keys())
|
576 |
-
|
577 |
-
# 2) We gather the 'label columns' from predictor.prediction_map keys:
|
578 |
list_of_labels = sorted(predictor.prediction_map.keys())
|
579 |
|
580 |
feat_dd = gr.Dropdown(choices=list_of_features, label="Feature Column")
|
@@ -606,5 +534,5 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
|
|
606 |
outputs=co_occ_output
|
607 |
)
|
608 |
|
609 |
-
# Finally, launch
|
610 |
demo.launch()
|
|
|
27 |
self.model_filenames = model_filenames
|
28 |
self.models = self.load_models()
|
29 |
|
|
|
|
|
30 |
self.prediction_map = {
|
31 |
"YOWRCONC": ["Did not have difficulty concentrating", "Had difficulty concentrating"],
|
32 |
"YOSEEDOC": ["Did not feel the need to see a doctor", "Felt the need to see a doctor"],
|
33 |
"YOWRHRS": ["Did not have trouble sleeping", "Had trouble sleeping"],
|
34 |
"YO_MDEA5": ["Others did not notice restlessness/lethargy", "Others noticed restlessness/lethargy"],
|
35 |
"YOWRCHR": ["Did not feel so sad nothing could cheer up", "Felt so sad that nothing could cheer up"],
|
36 |
+
"YOWRLSIN": ["Did not feel bored / lose interest", "Felt bored / lost interest"],
|
|
|
|
|
|
|
37 |
"YODPPROB": ["No other problems for 2+ weeks", "Had other problems for 2+ weeks"],
|
38 |
"YOWRPROB": ["No 'worst time ever' feeling", "Had 'worst time ever' feeling"],
|
39 |
+
"YODPR2WK": ["No periods with depressed feelings lasting 2+ weeks", "Had depressed feelings 2+ weeks"],
|
|
|
|
|
|
|
40 |
"YOWRDEPR": ["Did not feel sad/depressed mostly everyday", "Felt sad/depressed mostly everyday"],
|
41 |
+
"YODPDISC": ["Overall mood not sad/depressed", "Overall mood was sad/depressed"],
|
42 |
+
"YOLOSEV": ["Did not lose interest", "Lost interest in enjoyable things"],
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
"YOWRDCSN": ["Was able to make decisions", "Was unable to make decisions"],
|
44 |
+
"YODSMMDE": ["Never had 2 weeks depression symptoms", "Had 2+ weeks of depression symptoms"],
|
45 |
+
"YO_MDEA3": ["No changes in appetite/weight", "Had changes in appetite/weight"],
|
46 |
+
"YODPLSIN": ["Never lost interest / felt bored", "Lost interest/felt bored"],
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
"YOWRELES": ["Did not eat less than usual", "Ate less than usual"],
|
48 |
"YODSCEV": ["Fewer severe depression symptoms", "More severe depression symptoms"],
|
49 |
+
"YOPB2WK": ["No uneasy feelings lasting 2+ weeks", "Uneasy feelings lasting 2+ weeks"],
|
50 |
+
"YO_MDEA2": ["No physical/mental issues (2+ weeks)", "Had physical/mental issues (2+ weeks)"]
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
}
|
52 |
|
53 |
def load_models(self):
|
|
|
59 |
return loaded
|
60 |
|
61 |
def make_predictions(self, user_input: pd.DataFrame):
|
|
|
|
|
|
|
|
|
62 |
predictions = []
|
63 |
for model in self.models:
|
64 |
out = model.predict(user_input)
|
|
|
66 |
return predictions
|
67 |
|
68 |
def get_majority_vote(self, predictions):
|
|
|
|
|
|
|
|
|
69 |
combined = np.concatenate(predictions)
|
70 |
return np.bincount(combined).argmax()
|
71 |
|
72 |
def evaluate_severity(self, count_ones: int) -> str:
|
|
|
|
|
|
|
|
|
|
|
73 |
if count_ones >= 13:
|
74 |
return "Mental Health Severity: Severe"
|
75 |
elif count_ones >= 9:
|
|
|
84 |
|
85 |
|
86 |
######################################
|
87 |
+
# 3) FEATURE CATEGORIES + MAPPING
|
88 |
######################################
|
89 |
+
categories_dict = {
|
90 |
+
"1. Depression & Substance Use Diagnosis": [
|
91 |
+
"YMDESUD5ANYO", "YMDELT", "YMDEYR", "YMDERSUD5ANY",
|
92 |
+
"YMSUD5YANY", "YMIUD5YANY", "YMIMS5YANY", "YMIMI5YANY"
|
93 |
+
],
|
94 |
+
"2. Mental Health Treatment & Prof Consultation": [
|
95 |
+
"YMDEHPO", "YMDETXRX", "YMDEHARX", "YRXMDEYR", "YHLTMDE",
|
96 |
+
"YTXMDEYR", "YDOCMDE", "YPSY2MDE", "YPSY1MDE", "YCOUNMDE"
|
97 |
+
],
|
98 |
+
"3. Functional & Cognitive Impairment": [
|
99 |
+
"MDEIMPY", "LVLDIFMEM2"
|
100 |
+
],
|
101 |
+
"4. Suicidal Thoughts & Behaviors": [
|
102 |
+
"YUSUITHK", "YUSUITHKYR", "YUSUIPLNYR", "YUSUIPLN"
|
103 |
+
]
|
104 |
+
}
|
105 |
|
106 |
+
# The numeric mappings for each of the 25 features
|
107 |
+
input_mapping = {
|
108 |
+
'YMDESUD5ANYO': {"SUD only, no MDE": 1, "MDE only, no SUD": 2, "SUD and MDE": 3, "Neither SUD or MDE": 4},
|
109 |
+
'YMDELT': {"Yes": 1, "No": 2},
|
110 |
+
'YMDEYR': {"Yes": 1, "No": 2},
|
111 |
+
'YMDERSUD5ANY': {"Yes": 1, "No": 0},
|
112 |
+
'YMSUD5YANY': {"Yes": 1, "No": 0},
|
113 |
+
'YMIUD5YANY': {"Yes": 1, "No": 0},
|
114 |
+
'YMIMS5YANY': {"Yes": 1, "No": 0},
|
115 |
+
'YMIMI5YANY': {"Yes": 1, "No": 0},
|
116 |
|
117 |
+
'YMDEHPO': {"Yes": 1, "No": 0},
|
118 |
+
'YMDETXRX': {"Yes": 1, "No": 0},
|
119 |
+
'YMDEHARX': {"Yes": 1, "No": 0},
|
120 |
+
'YRXMDEYR': {"Yes": 1, "No": 0},
|
121 |
+
'YHLTMDE': {"Yes": 1, "No": 0},
|
122 |
+
'YTXMDEYR': {"Yes": 1, "No": 0},
|
123 |
+
'YDOCMDE': {"Yes": 1, "No": 0},
|
124 |
+
'YPSY2MDE': {"Yes": 1, "No": 0},
|
125 |
+
'YPSY1MDE': {"Yes": 1, "No": 0},
|
126 |
+
'YCOUNMDE': {"Yes": 1, "No": 0},
|
127 |
|
128 |
+
'MDEIMPY': {"Yes": 1, "No": 2},
|
129 |
+
'LVLDIFMEM2': {"No Difficulty": 1, "Some difficulty": 2, "A lot of difficulty or cannot do at all": 3},
|
|
|
130 |
|
131 |
+
'YUSUITHK': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
|
132 |
+
'YUSUITHKYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
|
133 |
+
'YUSUIPLNYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
|
134 |
+
'YUSUIPLN': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
|
135 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
|
137 |
|
|
|
|
|
|
|
138 |
def validate_inputs(*args):
|
139 |
for arg in args:
|
140 |
if not arg: # empty or None
|
141 |
return False
|
142 |
return True
|
143 |
|
144 |
+
|
145 |
+
######################################
|
146 |
+
# 4) NEAREST NEIGHBORS (Grouped)
|
147 |
+
######################################
|
148 |
+
def get_nearest_neighbors_info(user_input_df: pd.DataFrame, k=5):
|
149 |
+
# Ensure columns exist in df
|
150 |
+
user_cols = user_input_df.columns
|
151 |
+
if not all(col in df.columns for col in user_cols):
|
152 |
+
return "Cannot compute nearest neighbors. Some columns not found in df."
|
153 |
+
|
154 |
+
# Subset df
|
155 |
+
sub_df = df[list(user_cols)].copy()
|
156 |
+
diffs = sub_df - user_input_df.iloc[0]
|
157 |
+
dists = (diffs**2).sum(axis=1)**0.5
|
158 |
+
nn_indices = dists.nsmallest(k).index
|
159 |
+
neighbors = df.loc[nn_indices]
|
160 |
+
|
161 |
+
lines = [f"**Nearest Neighbors (k={k})**",
|
162 |
+
f"Distances Range: {dists[nn_indices].min():.2f} to {dists[nn_indices].max():.2f}",
|
163 |
+
""]
|
164 |
+
|
165 |
+
# Group the features by our categories_dict
|
166 |
+
for cat_name, cat_feats in categories_dict.items():
|
167 |
+
lines.append(f"### {cat_name}")
|
168 |
+
for feat in cat_feats:
|
169 |
+
if feat not in neighbors.columns:
|
170 |
+
continue
|
171 |
+
# Count how many neighbors had each numeric value
|
172 |
+
val_counts = neighbors[feat].value_counts().to_dict()
|
173 |
+
# Build string like: "YMDESUD5ANYO => 3 had 1, 2 had 2..."
|
174 |
+
parts = []
|
175 |
+
for val_, count_ in val_counts.items():
|
176 |
+
parts.append(f"{count_} had '{val_}'")
|
177 |
+
joined = "; ".join(parts)
|
178 |
+
lines.append(f"**{feat}** => {joined}")
|
179 |
+
lines.append("") # blank line
|
180 |
+
|
181 |
+
return "\n".join(lines)
|
182 |
|
183 |
|
184 |
######################################
|
185 |
+
# 5) PREDICT FUNCTION
|
186 |
######################################
|
187 |
def predict(
|
188 |
+
# EXACTLY 25 features, matching categories_dict ordering.
|
189 |
+
# We'll just list them in the dictionary order we want to show them:
|
190 |
+
YMDESUD5ANYO, YMDELT, YMDEYR, YMDERSUD5ANY,
|
191 |
+
YMSUD5YANY, YMIUD5YANY, YMIMS5YANY, YMIMI5YANY,
|
192 |
+
|
193 |
+
YMDEHPO, YMDETXRX, YMDEHARX, YRXMDEYR, YHLTMDE,
|
194 |
+
YTXMDEYR, YDOCMDE, YPSY2MDE, YPSY1MDE, YCOUNMDE,
|
195 |
+
|
196 |
+
MDEIMPY, LVLDIFMEM2,
|
197 |
+
|
198 |
+
YUSUITHK, YUSUITHKYR, YUSUIPLNYR, YUSUIPLN
|
199 |
):
|
|
|
200 |
if not validate_inputs(
|
201 |
+
YMDESUD5ANYO, YMDELT, YMDEYR, YMDERSUD5ANY,
|
202 |
+
YMSUD5YANY, YMIUD5YANY, YMIMS5YANY, YMIMI5YANY,
|
203 |
+
YMDEHPO, YMDETXRX, YMDEHARX, YRXMDEYR, YHLTMDE,
|
204 |
+
YTXMDEYR, YDOCMDE, YPSY2MDE, YPSY1MDE, YCOUNMDE,
|
205 |
+
MDEIMPY, LVLDIFMEM2,
|
206 |
+
YUSUITHK, YUSUITHKYR, YUSUIPLNYR, YUSUIPLN
|
207 |
):
|
208 |
return (
|
209 |
"Please select all required fields.", # 1) Prediction Results
|
210 |
+
"Validation Error", # 2) Severity
|
211 |
+
"No data", # 3) Total Count
|
212 |
+
"No nearest neighbors info", # 4) NN Summary
|
213 |
+
None, # 5) Bar chart (Input)
|
214 |
+
None # 6) Bar chart (Labels)
|
215 |
)
|
216 |
|
217 |
+
# 1) Map user-friendly -> numeric
|
218 |
user_input_dict = {
|
219 |
'YMDESUD5ANYO': input_mapping['YMDESUD5ANYO'][YMDESUD5ANYO],
|
|
|
|
|
|
|
|
|
|
|
220 |
'YMDELT': input_mapping['YMDELT'][YMDELT],
|
221 |
+
'YMDEYR': input_mapping['YMDEYR'][YMDEYR],
|
222 |
+
'YMDERSUD5ANY': input_mapping['YMDERSUD5ANY'][YMDERSUD5ANY],
|
223 |
+
'YMSUD5YANY': input_mapping['YMSUD5YANY'][YMSUD5YANY],
|
224 |
+
'YMIUD5YANY': input_mapping['YMIUD5YANY'][YMIUD5YANY],
|
225 |
+
'YMIMS5YANY': input_mapping['YMIMS5YANY'][YMIMS5YANY],
|
226 |
'YMIMI5YANY': input_mapping['YMIMI5YANY'][YMIMI5YANY],
|
227 |
+
|
228 |
+
'YMDEHPO': input_mapping['YMDEHPO'][YMDEHPO],
|
229 |
+
'YMDETXRX': input_mapping['YMDETXRX'][YMDETXRX],
|
230 |
'YMDEHARX': input_mapping['YMDEHARX'][YMDEHARX],
|
|
|
231 |
'YRXMDEYR': input_mapping['YRXMDEYR'][YRXMDEYR],
|
|
|
|
|
|
|
232 |
'YHLTMDE': input_mapping['YHLTMDE'][YHLTMDE],
|
|
|
|
|
|
|
|
|
|
|
|
|
233 |
'YTXMDEYR': input_mapping['YTXMDEYR'][YTXMDEYR],
|
234 |
+
'YDOCMDE': input_mapping['YDOCMDE'][YDOCMDE],
|
235 |
+
'YPSY2MDE': input_mapping['YPSY2MDE'][YPSY2MDE],
|
236 |
+
'YPSY1MDE': input_mapping['YPSY1MDE'][YPSY1MDE],
|
237 |
'YCOUNMDE': input_mapping['YCOUNMDE'][YCOUNMDE],
|
238 |
+
|
239 |
+
'MDEIMPY': input_mapping['MDEIMPY'][MDEIMPY],
|
240 |
+
'LVLDIFMEM2': input_mapping['LVLDIFMEM2'][LVLDIFMEM2],
|
241 |
+
|
242 |
+
'YUSUITHK': input_mapping['YUSUITHK'][YUSUITHK],
|
243 |
+
'YUSUITHKYR': input_mapping['YUSUITHKYR'][YUSUITHKYR],
|
244 |
+
'YUSUIPLNYR': input_mapping['YUSUIPLNYR'][YUSUIPLNYR],
|
245 |
+
'YUSUIPLN': input_mapping['YUSUIPLN'][YUSUIPLN]
|
246 |
}
|
247 |
user_df = pd.DataFrame(user_input_dict, index=[0])
|
248 |
|
249 |
+
# 2) Predict
|
250 |
+
predictions = predictor.make_predictions(user_df)
|
|
|
|
|
251 |
all_preds = np.concatenate(predictions)
|
|
|
|
|
252 |
count_ones = sum(all_preds == 1)
|
|
|
253 |
severity_msg = predictor.evaluate_severity(count_ones)
|
254 |
|
255 |
+
# 3) Grouped textual results
|
256 |
groups = {
|
257 |
"Concentration_and_Decision_Making": ["YOWRCONC", "YOWRDCSN"],
|
258 |
"Sleep_and_Energy_Levels": ["YOWRHRS", "YO_MDEA5", "YOWRELES", "YO_MDEA2"],
|
|
|
264 |
"YODPPROB", "YOWRPROB", "YODPR2WK", "YODSMMDE", "YOPB2WK"
|
265 |
]
|
266 |
}
|
|
|
267 |
group_text = {g: [] for g in groups}
|
268 |
+
# The model_filenames order determines which label is i
|
269 |
for i, arr in enumerate(predictions):
|
270 |
label_col = model_filenames[i].split('.')[0] # e.g. "YOWRCONC"
|
271 |
val = arr[0]
|
272 |
+
# If we have a textual map, use it
|
273 |
if label_col in predictor.prediction_map and val in range(len(predictor.prediction_map[label_col])):
|
274 |
text_label = predictor.prediction_map[label_col][val]
|
275 |
else:
|
276 |
text_label = f"Prediction={val}"
|
277 |
|
278 |
+
# Put in whichever group
|
279 |
+
for group_name, cols_ in groups.items():
|
280 |
+
if label_col in cols_:
|
281 |
group_text[group_name].append(f"{label_col} => {text_label}")
|
282 |
break
|
283 |
|
|
|
287 |
gtitle = gname.replace("_", " ")
|
288 |
final_str_parts.append(f"**{gtitle}**")
|
289 |
final_str_parts.append("\n".join(lines))
|
290 |
+
final_str_parts.append("")
|
291 |
if not final_str_parts:
|
292 |
final_str = "No predictions made or no matching group columns."
|
293 |
else:
|
294 |
final_str = "\n".join(final_str_parts)
|
295 |
|
296 |
+
# 4) Additional info
|
297 |
total_count = len(df)
|
298 |
total_count_md = f"We have **{total_count}** patients in the dataset."
|
299 |
|
300 |
+
# 5) Nearest Neighbors
|
301 |
nn_md = get_nearest_neighbors_info(user_df, k=5)
|
302 |
|
303 |
+
# 6) Bar chart for input features
|
304 |
input_counts = {}
|
305 |
for col, val_ in user_input_dict.items():
|
306 |
matched = len(df[df[col] == val_])
|
|
|
313 |
)
|
314 |
fig_in.update_layout(width=1200, height=400)
|
315 |
|
316 |
+
# 7) Bar chart for predicted labels
|
|
|
317 |
label_counts = {}
|
318 |
for i, arr in enumerate(predictions):
|
319 |
lbl = model_filenames[i].split('.')[0]
|
320 |
pred_val = arr[0]
|
321 |
if lbl in df.columns:
|
|
|
322 |
label_counts[lbl] = len(df[df[lbl] == pred_val])
|
323 |
if label_counts:
|
324 |
bar_lbl_df = pd.DataFrame({
|
|
|
339 |
severity_msg, # 2) Mental Health Severity
|
340 |
total_count_md, # 3) Total Patient Count
|
341 |
nn_md, # 4) Nearest Neighbors Summary
|
342 |
+
fig_in, # 5) Bar Chart (input features)
|
343 |
+
fig_lbl # 6) Bar Chart (labels)
|
344 |
)
|
345 |
|
346 |
|
|
|
348 |
# 6) EXTRA TABS / FUNCTIONS
|
349 |
######################################
|
350 |
def distribution_plot(feature_col, label_col):
|
|
|
|
|
|
|
351 |
if not feature_col or not label_col:
|
352 |
return px.bar(title="Please select both Feature and Label.")
|
353 |
if (feature_col not in df.columns) or (label_col not in df.columns):
|
|
|
366 |
|
367 |
|
368 |
def co_occurrence_plot(feature1, feature2, label_col):
|
|
|
|
|
|
|
369 |
if not feature1 or not feature2 or not label_col:
|
370 |
return px.bar(title="Please select all three fields.")
|
371 |
if feature1 not in df.columns or feature2 not in df.columns or label_col not in df.columns:
|
|
|
391 |
|
392 |
# ======== TAB 1: PREDICTION ========
|
393 |
with gr.Tab("Prediction"):
|
394 |
+
gr.Markdown(
|
395 |
+
"""
|
396 |
+
### Please provide inputs in each of the four categories below.
|
397 |
+
*All fields are required.*
|
398 |
+
"""
|
399 |
+
)
|
400 |
+
|
401 |
+
# For clarity, we define an ordered list of the features in the exact sequence
|
402 |
+
# matching our predict() function. We’ll group them under the same headings.
|
403 |
+
cat1_col_labels = [
|
404 |
+
("YMDESUD5ANYO", "YMDESUD5ANYO: ONLY MDE, ONLY SUD, BOTH, OR NEITHER"),
|
405 |
+
("YMDELT", "YMDELT: Had major depressive episode in lifetime"),
|
406 |
+
("YMDEYR", "YMDEYR: Past-year major depressive episode"),
|
407 |
+
("YMDERSUD5ANY", "YMDERSUD5ANY: MDE or substance use disorder - past year"),
|
408 |
+
("YMSUD5YANY", "YMSUD5YANY: Past-year MDE & substance use disorder"),
|
409 |
+
("YMIUD5YANY", "YMIUD5YANY: Past-year MDE & illicit drug use disorder"),
|
410 |
+
("YMIMS5YANY", "YMIMS5YANY: Past-year MDE + severe impairment + substance use"),
|
411 |
+
("YMIMI5YANY", "YMIMI5YANY: Past-year MDE with severe impairment & illicit drug use")
|
412 |
+
]
|
413 |
+
cat2_col_labels = [
|
414 |
+
("YMDEHPO", "YMDEHPO: Saw health prof only for MDE in past year"),
|
415 |
+
("YMDETXRX", "YMDETXRX: Received treatment/counseling if saw doc/prof for MDE"),
|
416 |
+
("YMDEHARX", "YMDEHARX: Saw health professional & received medication for MDE"),
|
417 |
+
("YRXMDEYR", "YRXMDEYR: Used received medication for MDE in past years"),
|
418 |
+
("YHLTMDE", "YHLTMDE: Saw/talked to health professional about MDE in past year"),
|
419 |
+
("YTXMDEYR", "YTXMDEYR: Saw or talked to doc/health prof for MDE in past year"),
|
420 |
+
("YDOCMDE", "YDOCMDE: Saw/talked to general practitioner/family MD about MDE"),
|
421 |
+
("YPSY2MDE", "YPSY2MDE: Saw/talked to psychiatrist about MDE"),
|
422 |
+
("YPSY1MDE", "YPSY1MDE: Saw/talked to psychologist about MDE"),
|
423 |
+
("YCOUNMDE", "YCOUNMDE: Saw/talked to counselor about MDE")
|
424 |
+
]
|
425 |
+
cat3_col_labels = [
|
426 |
+
("MDEIMPY", "MDEIMPY: MDE with severe role impairment"),
|
427 |
+
("LVLDIFMEM2", "LVLDIFMEM2: Difficulty remembering/concentrating")
|
428 |
+
]
|
429 |
+
cat4_col_labels = [
|
430 |
+
("YUSUITHK", "YUSUITHK: Youth seriously think about killing self in past 12 months"),
|
431 |
+
("YUSUITHKYR", "YUSUITHKYR: Seriously thought about killing self"),
|
432 |
+
("YUSUIPLNYR", "YUSUIPLNYR: Made plans to kill self in past year"),
|
433 |
+
("YUSUIPLN", "YUSUIPLN: Made plans to kill yourself in past 12 months")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
434 |
]
|
435 |
|
436 |
+
# Category 1 block
|
437 |
+
gr.Markdown("#### 1. Depression & Substance Use Diagnosis")
|
438 |
+
cat1_inputs = []
|
439 |
+
for col, label_text in cat1_col_labels:
|
440 |
+
dd = gr.Dropdown(
|
441 |
+
choices=list(input_mapping[col].keys()),
|
442 |
+
label=label_text
|
443 |
+
)
|
444 |
+
cat1_inputs.append(dd)
|
445 |
+
|
446 |
+
# Category 2 block
|
447 |
+
gr.Markdown("#### 2. Mental Health Treatment & Professional Consultation")
|
448 |
+
cat2_inputs = []
|
449 |
+
for col, label_text in cat2_col_labels:
|
450 |
+
dd = gr.Dropdown(
|
451 |
+
choices=list(input_mapping[col].keys()),
|
452 |
+
label=label_text
|
453 |
+
)
|
454 |
+
cat2_inputs.append(dd)
|
455 |
+
|
456 |
+
# Category 3 block
|
457 |
+
gr.Markdown("#### 3. Functional & Cognitive Impairment")
|
458 |
+
cat3_inputs = []
|
459 |
+
for col, label_text in cat3_col_labels:
|
460 |
+
dd = gr.Dropdown(
|
461 |
+
choices=list(input_mapping[col].keys()),
|
462 |
+
label=label_text
|
463 |
+
)
|
464 |
+
cat3_inputs.append(dd)
|
465 |
+
|
466 |
+
# Category 4 block
|
467 |
+
gr.Markdown("#### 4. Suicidal Thoughts & Behaviors")
|
468 |
+
cat4_inputs = []
|
469 |
+
for col, label_text in cat4_col_labels:
|
470 |
+
dd = gr.Dropdown(
|
471 |
+
choices=list(input_mapping[col].keys()),
|
472 |
+
label=label_text
|
473 |
+
)
|
474 |
+
cat4_inputs.append(dd)
|
475 |
+
|
476 |
+
# The overall input list must match the order in `predict()`
|
477 |
+
all_inputs = cat1_inputs + cat2_inputs + cat3_inputs + cat4_inputs
|
478 |
+
|
479 |
predict_btn = gr.Button("Predict")
|
480 |
|
481 |
+
# 6 outputs
|
482 |
out_pred_res = gr.Textbox(label="Prediction Results", lines=8)
|
483 |
out_sev = gr.Textbox(label="Mental Health Severity", lines=2)
|
484 |
out_count = gr.Markdown(label="Total Patient Count")
|
485 |
+
out_nn = gr.Markdown(label="Nearest Neighbors Summary (Grouped by Category)")
|
486 |
out_bar_input= gr.Plot(label="Input Feature Counts")
|
487 |
out_bar_label= gr.Plot(label="Predicted Label Counts")
|
488 |
|
|
|
489 |
predict_btn.click(
|
490 |
fn=predict,
|
491 |
+
inputs=all_inputs,
|
492 |
outputs=[
|
493 |
out_pred_res, # 1
|
494 |
out_sev, # 2
|
|
|
502 |
# ======== TAB 2: Distribution Analysis ========
|
503 |
with gr.Tab("Distribution Analysis"):
|
504 |
gr.Markdown("## Distribution Plot\nSelect one feature and one label column to see bar counts.")
|
|
|
505 |
list_of_features = sorted(input_mapping.keys())
|
|
|
|
|
506 |
list_of_labels = sorted(predictor.prediction_map.keys())
|
507 |
|
508 |
feat_dd = gr.Dropdown(choices=list_of_features, label="Feature Column")
|
|
|
534 |
outputs=co_occ_output
|
535 |
)
|
536 |
|
537 |
+
# Finally, launch
|
538 |
demo.launch()
|