pantdipendra
committed on
Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
-
import gradio as gr
|
2 |
import pickle
|
|
|
|
|
3 |
import numpy as np
|
4 |
import pandas as pd
|
5 |
import plotly.express as px
|
@@ -7,15 +8,13 @@ import plotly.express as px
|
|
7 |
# Load the training CSV once (outside the functions so it is read only once).
|
8 |
df = pd.read_csv("X_train_Y_Train_merged_train.csv")
|
9 |
|
10 |
-
##############################################################################
|
11 |
-
# MODEL PREDICTOR CLASS
|
12 |
-
##############################################################################
|
13 |
-
|
14 |
class ModelPredictor:
|
15 |
def __init__(self, model_path, model_filenames):
|
16 |
self.model_path = model_path
|
17 |
self.model_filenames = model_filenames
|
18 |
self.models = self.load_models()
|
|
|
|
|
19 |
self.prediction_map = {
|
20 |
"YOWRCONC": ["Did not have difficulty concentrating", "Had difficulty concentrating"],
|
21 |
"YOSEEDOC": ["Did not feel the need to see a doctor", "Felt the need to see a doctor"],
|
@@ -57,7 +56,10 @@ class ModelPredictor:
|
|
57 |
return models
|
58 |
|
59 |
def make_predictions(self, user_input):
|
60 |
-
"""
|
|
|
|
|
|
|
61 |
predictions = []
|
62 |
for model in self.models:
|
63 |
pred = model.predict(user_input)
|
@@ -68,13 +70,17 @@ class ModelPredictor:
|
|
68 |
def get_majority_vote(self, predictions):
    """
    Return the most frequent class label across every model's predictions.

    Args:
        predictions: Non-empty sequence of array-likes, one per model, each
            holding that model's predicted label(s) (expected 0 or 1).

    Returns:
        The label that occurs most often (a NumPy integer). On a tie the
        smallest label wins, because ``argmax`` returns the first maximum.

    Raises:
        ValueError: If ``predictions`` is empty (raised by ``np.concatenate``)
            or contains a negative label (raised by ``np.bincount``).
    """
    # Flatten each entry so 2-D outputs (e.g. shape (1, 1)) are accepted too.
    combined_predictions = np.concatenate([np.ravel(p) for p in predictions])
    # np.bincount only accepts integer input; estimators that emit float
    # labels (0.0 / 1.0) would otherwise raise a TypeError here.
    combined_predictions = combined_predictions.astype(np.int64)
    majority_vote = np.bincount(combined_predictions).argmax()
    return majority_vote
|
76 |
|
77 |
-
#
|
|
|
|
|
|
|
|
|
78 |
def evaluate_severity(self, majority_vote_count):
|
79 |
if majority_vote_count >= 13:
|
80 |
return "Mental health severity: Severe"
|
@@ -85,6 +91,7 @@ class ModelPredictor:
|
|
85 |
else:
|
86 |
return "Mental health severity: Very Low"
|
87 |
|
|
|
88 |
model_filenames = [
|
89 |
"YOWRCONC.pkl", "YOSEEDOC.pkl", "YO_MDEA5.pkl", "YOWRLSIN.pkl",
|
90 |
"YODPPROB.pkl", "YOWRPROB.pkl", "YODPR2WK.pkl", "YOWRDEPR.pkl",
|
@@ -94,21 +101,12 @@ model_filenames = [
|
|
94 |
model_path = "models/"
|
95 |
predictor = ModelPredictor(model_path, model_filenames)
|
96 |
|
97 |
-
##############################################################################
|
98 |
-
# INPUT VALIDATION
|
99 |
-
##############################################################################
|
100 |
-
|
101 |
def validate_inputs(*args):
    """
    Report whether every supplied value has been filled in.

    A value counts as missing when it equals the empty string or is ``None``.

    Returns:
        ``True`` when no argument is missing (also when called with no
        arguments), otherwise ``False``.
    """
    return not any(value == '' or value is None for value in args)
|
107 |
|
108 |
-
##############################################################################
|
109 |
-
# MAIN PREDICT FUNCTION
|
110 |
-
##############################################################################
|
111 |
-
|
112 |
def predict(
|
113 |
YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
|
114 |
YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
|
@@ -116,6 +114,20 @@ def predict(
|
|
116 |
YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
|
117 |
YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
|
118 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
# Prepare user_input dataframe for prediction
|
120 |
user_input_data = {
|
121 |
'YNURSMDE': [int(YNURSMDE)],
|
@@ -150,18 +162,21 @@ def predict(
|
|
150 |
}
|
151 |
user_input = pd.DataFrame(user_input_data)
|
152 |
|
153 |
-
#
|
|
|
|
|
154 |
predictions = predictor.make_predictions(user_input)
|
155 |
-
|
|
|
156 |
majority_vote = predictor.get_majority_vote(predictions)
|
157 |
-
|
|
|
158 |
majority_vote_count = sum([1 for pred in np.concatenate(predictions) if pred == 1])
|
|
|
159 |
# 4) Evaluate severity
|
160 |
severity = predictor.evaluate_severity(majority_vote_count)
|
161 |
|
162 |
-
|
163 |
-
# (A) Summarize per-model predictions
|
164 |
-
############################################################################
|
165 |
results = {
|
166 |
"Concentration_and_Decision_Making": [],
|
167 |
"Sleep_and_Energy_Levels": [],
|
@@ -180,221 +195,73 @@ def predict(
|
|
180 |
"YODPR2WK", "YODSMMDE",
|
181 |
"YOPB2WK"]
|
182 |
}
|
183 |
-
|
184 |
for i, pred in enumerate(predictions):
|
185 |
model_name = model_filenames[i].split('.')[0] # e.g. 'YOWRCONC'
|
186 |
pred_value = pred[0]
|
|
|
187 |
if model_name in predictor.prediction_map and pred_value in [0, 1]:
|
188 |
result_text = f"Model {model_name}: {predictor.prediction_map[model_name][pred_value]}"
|
|
|
|
|
|
|
189 |
else:
|
190 |
-
result_text = f"Model {model_name}: Unknown
|
191 |
|
|
|
192 |
found_group = False
|
193 |
for group_name, group_models in prediction_groups.items():
|
194 |
if model_name in group_models:
|
195 |
results[group_name].append(result_text)
|
196 |
found_group = True
|
197 |
break
|
|
|
|
|
|
|
198 |
|
|
|
199 |
formatted_results = []
|
200 |
for group, preds in results.items():
|
201 |
if preds:
|
202 |
formatted_results.append(f"Group {group.replace('_', ' ')}:")
|
203 |
formatted_results.append("\n".join(preds))
|
204 |
-
formatted_results.append("")
|
205 |
-
if not formatted_results:
|
206 |
-
formatted_results = ["No predictions made. Please check your inputs."]
|
207 |
-
|
208 |
-
prediction_summary_text = "\n".join(formatted_results).strip()
|
209 |
|
210 |
-
|
211 |
-
|
212 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
213 |
total_patients = len(df)
|
214 |
-
|
215 |
"### Total Patient Count\n"
|
216 |
-
f"
|
217 |
-
"
|
218 |
-
|
219 |
-
|
220 |
-
############################################################################
|
221 |
-
# (C) CROSS-TABULATION & GROUPED BAR CHART (EXAMPLE)
|
222 |
-
# We'll demonstrate with one feature (e.g., 'YMDEYR') vs. the actual label 'YOWRCONC'
|
223 |
-
############################################################################
|
224 |
-
# Explanation:
|
225 |
-
cross_tab_explanation = (
|
226 |
-
"### Cross-Tabulation & Grouped Bar Chart\n"
|
227 |
-
"This chart shows how often each category of a given feature (X-axis) co-occurs with each **actual label** (0 or 1). "
|
228 |
-
"Interpreting this helps clinicians see which categories have a higher proportion of positive vs. negative outcomes. "
|
229 |
-
"For instance, if 'Yes' in YMDEYR heavily corresponds to label=1, that suggests a stronger link between that feature and the mental health outcome."
|
230 |
-
)
|
231 |
-
|
232 |
-
if "YOWRCONC" in df.columns and "YMDEYR" in df.columns:
|
233 |
-
# Make sure we actually have the columns needed
|
234 |
-
ctab = pd.crosstab(df["YMDEYR"], df["YOWRCONC"])
|
235 |
-
# ctab might have column names [0,1] for the label
|
236 |
-
ctab.reset_index(inplace=True)
|
237 |
-
# rename for clarity
|
238 |
-
ctab.columns = ["YMDEYR_Value", "Label0_Count", "Label1_Count"]
|
239 |
-
|
240 |
-
fig_crosstab = px.bar(
|
241 |
-
ctab,
|
242 |
-
x="YMDEYR_Value",
|
243 |
-
y=["Label0_Count", "Label1_Count"],
|
244 |
-
barmode="group",
|
245 |
-
title="YMDEYR vs. YOWRCONC (Actual Label)",
|
246 |
-
labels={
|
247 |
-
"YMDEYR_Value": "YMDEYR Feature Categories",
|
248 |
-
"value": "Count of Patients",
|
249 |
-
"variable": "Label"
|
250 |
-
}
|
251 |
-
)
|
252 |
-
else:
|
253 |
-
# fallback if we don't have those columns
|
254 |
-
fig_crosstab = px.bar(
|
255 |
-
x=["Data Error"], y=[0],
|
256 |
-
title="Could not generate cross-tab: 'YOWRCONC' or 'YMDEYR' not in df"
|
257 |
-
)
|
258 |
-
|
259 |
-
############################################################################
|
260 |
-
# (D) "SIMILAR PATIENT" / NEAREST-NEIGHBORS DEMO
|
261 |
-
# We'll pick a small set of "key features", measure Hamming distance,
|
262 |
-
# and find the top-K closest rows. Then we'll show how many had label=1.
|
263 |
-
############################################################################
|
264 |
-
similar_explanation = (
|
265 |
-
"### Similar Patients (Nearest Neighbors)\n"
|
266 |
-
"Here we define a small set of key features and use a simple Hamming distance "
|
267 |
-
"(count of mismatched categories) to find patients who are 'closest' to the current input. "
|
268 |
-
"This helps clinicians see how similar patients were labeled or what interventions they needed."
|
269 |
)
|
270 |
|
271 |
-
#
|
272 |
-
|
273 |
-
if all(kf in df.columns for kf in key_features) and "YOWRCONC" in df.columns:
|
274 |
-
# Compute distance for each row
|
275 |
-
user_vector = [user_input_data[kf][0] for kf in key_features]
|
276 |
-
distances = []
|
277 |
-
for idx, row in df[key_features].iterrows():
|
278 |
-
# Compare row to user_vector
|
279 |
-
row_vector = row.values
|
280 |
-
# Hamming distance = sum(row_vector[i] != user_vector[i])
|
281 |
-
dist = sum(rv != uv for rv, uv in zip(row_vector, user_vector))
|
282 |
-
distances.append(dist)
|
283 |
-
|
284 |
-
# Add distances to a copy of df
|
285 |
-
temp_df = df.copy()
|
286 |
-
temp_df["HammingDist"] = distances
|
287 |
-
# Sort ascending by distance, take top-K (e.g., 20)
|
288 |
-
top_k = temp_df.nsmallest(20, "HammingDist")
|
289 |
-
# Count how many have label=1 in top_k
|
290 |
-
if "YOWRCONC" in top_k.columns:
|
291 |
-
similar_label_1_count = (top_k["YOWRCONC"] == 1).sum()
|
292 |
-
similar_label_0_count = (top_k["YOWRCONC"] == 0).sum()
|
293 |
-
similar_text = (
|
294 |
-
f"Out of the 20 most similar patients:\n"
|
295 |
-
f"- {similar_label_1_count} had label=1\n"
|
296 |
-
f"- {similar_label_0_count} had label=0\n"
|
297 |
-
f"(Distances ranged from {top_k['HammingDist'].min()} to {top_k['HammingDist'].max()})."
|
298 |
-
)
|
299 |
-
else:
|
300 |
-
similar_text = "Label column 'YOWRCONC' missing in dataset."
|
301 |
-
else:
|
302 |
-
similar_text = "Cannot compute nearest neighbors: some key features or label column are missing."
|
303 |
-
|
304 |
-
############################################################################
|
305 |
-
# (E) CO-OCCURRENCE PLOT (TWO FEATURES) vs. LABEL
|
306 |
-
############################################################################
|
307 |
-
cooccurrence_explanation = (
|
308 |
-
"### Co-Occurrence of Two Features vs. Label\n"
|
309 |
-
"This shows how two categorical features combine, and how many patients in each combination are labeled 0 or 1. "
|
310 |
-
"Clinicians can spot if certain feature-combinations are particularly high-risk or high-incidence of label=1."
|
311 |
-
)
|
312 |
-
|
313 |
-
# Example: co-occurrence of 'YMDEYR' and 'YMDERSUD5ANY' vs. 'YOWRCONC'
|
314 |
-
if all(col in df.columns for col in ["YMDEYR", "YMDERSUD5ANY", "YOWRCONC"]):
|
315 |
-
co_tab = pd.crosstab([df["YMDEYR"], df["YMDERSUD5ANY"]], df["YOWRCONC"])
|
316 |
-
co_tab.reset_index(inplace=True)
|
317 |
-
# co_tab columns: ["YMDEYR", "YMDERSUD5ANY", "0", "1"]
|
318 |
-
co_tab.columns = ["YMDEYR", "YMDERSUD5ANY", "Label0", "Label1"]
|
319 |
-
|
320 |
-
# We'll create a stacked or grouped bar. Let's do grouped by label.
|
321 |
-
# Construct a single column "Count" and a single column "Label" to let plotly group them
|
322 |
-
data_list = []
|
323 |
-
for i, row in co_tab.iterrows():
|
324 |
-
data_list.append({
|
325 |
-
"YMDEYR_Val": row["YMDEYR"],
|
326 |
-
"YMDERSUD5ANY_Val": row["YMDERSUD5ANY"],
|
327 |
-
"Label": "Label=0",
|
328 |
-
"Count": row["Label0"]
|
329 |
-
})
|
330 |
-
data_list.append({
|
331 |
-
"YMDEYR_Val": row["YMDEYR"],
|
332 |
-
"YMDERSUD5ANY_Val": row["YMDERSUD5ANY"],
|
333 |
-
"Label": "Label=1",
|
334 |
-
"Count": row["Label1"]
|
335 |
-
})
|
336 |
-
df_co = pd.DataFrame(data_list)
|
337 |
-
|
338 |
-
fig_cooccur = px.bar(
|
339 |
-
df_co,
|
340 |
-
x="YMDEYR_Val",
|
341 |
-
y="Count",
|
342 |
-
color="Label",
|
343 |
-
facet_col="YMDERSUD5ANY_Val", # separate subplots by second feature
|
344 |
-
barmode="group",
|
345 |
-
title="Co-Occurrence: YMDEYR & YMDERSUD5ANY vs. YOWRCONC",
|
346 |
-
labels={"YMDEYR_Val": "YMDEYR", "YMDERSUD5ANY_Val": "YMDERSUD5ANY"}
|
347 |
-
)
|
348 |
-
fig_cooccur.update_layout(
|
349 |
-
legend_title_text="Actual Label",
|
350 |
-
xaxis_title="YMDEYR Categories",
|
351 |
-
yaxis_title="Number of Patients"
|
352 |
-
)
|
353 |
-
else:
|
354 |
-
fig_cooccur = px.bar(
|
355 |
-
x=["Data Error"], y=[0],
|
356 |
-
title="Could not generate co-occurrence chart: missing columns"
|
357 |
-
)
|
358 |
-
|
359 |
-
#------------------------------------------------------------------------------
|
360 |
-
# RETURN / RENDER
|
361 |
-
#------------------------------------------------------------------------------
|
362 |
-
# We have 6 outputs total (the code is set up for that).
|
363 |
-
# We'll map them as follows:
|
364 |
-
# 1) "Prediction Results" (Textbox)
|
365 |
-
# 2) "Mental Health Severity" (Textbox)
|
366 |
-
# 3) A Markdown that combines: total_patients_text + cross_tab_explanation + similar_explanation + cooccurrence_explanation + the nearest-neighbors result
|
367 |
-
# 4) Cross-Tab Bar Chart
|
368 |
-
# 5) "Number of Patients with the Same Value for Each Input Feature"
|
369 |
-
# 6) "Number of Patients with Predicted Labels"
|
370 |
-
|
371 |
-
# (i) Provide text results for the user’s predictions
|
372 |
-
# (ii) Provide severity
|
373 |
-
|
374 |
-
# Build the big markdown text for (3)
|
375 |
-
big_markdown = (
|
376 |
-
total_patients_text
|
377 |
-
+ "\n\n"
|
378 |
-
+ cross_tab_explanation
|
379 |
-
+ "\n\n"
|
380 |
-
+ f"**Crosstab Example**: See the bar chart below comparing 'YMDEYR' vs. actual label 'YOWRCONC'.\n\n"
|
381 |
-
+ similar_explanation
|
382 |
-
+ "\n\n"
|
383 |
-
+ similar_text
|
384 |
-
+ "\n\n"
|
385 |
-
+ cooccurrence_explanation
|
386 |
-
+ "\n\n"
|
387 |
-
+ "See the final chart below for how 'YMDEYR' & 'YMDERSUD5ANY' co-occur with label 'YOWRCONC'."
|
388 |
-
)
|
389 |
-
|
390 |
-
# (F) Bar Chart for each input feature
|
391 |
-
# We'll keep the logic for counting how many in df have the same value for each feature
|
392 |
input_counts = {}
|
393 |
-
for col
|
394 |
-
val =
|
395 |
same_val_count = len(df[df[col] == val])
|
396 |
input_counts[col] = same_val_count
|
397 |
|
|
|
398 |
bar_input_data = pd.DataFrame({
|
399 |
"Feature": list(input_counts.keys()),
|
400 |
"Count": list(input_counts.values())
|
@@ -408,13 +275,14 @@ def predict(
|
|
408 |
)
|
409 |
fig_bar_input.update_layout(xaxis={'categoryorder':'total descending'})
|
410 |
|
411 |
-
#
|
412 |
-
#
|
|
|
413 |
label_counts = {}
|
414 |
for i, pred in enumerate(predictions):
|
415 |
model_name = model_filenames[i].split('.')[0]
|
416 |
pred_value = pred[0]
|
417 |
-
if pred_value in [0, 1]
|
418 |
label_counts[model_name] = len(df[df[model_name] == pred_value])
|
419 |
|
420 |
if len(label_counts) > 0:
|
@@ -426,12 +294,12 @@ def predict(
|
|
426 |
bar_label_data,
|
427 |
x="Model",
|
428 |
y="Count",
|
429 |
-
title="Number of Patients with the
|
430 |
labels={"Model": "Predicted Column", "Count": "Number of Patients"}
|
431 |
)
|
432 |
fig_bar_labels.update_layout(xaxis={'categoryorder':'total descending'})
|
433 |
else:
|
434 |
-
# fallback
|
435 |
bar_label_data = pd.DataFrame({"Model": [], "Count": []})
|
436 |
fig_bar_labels = px.bar(
|
437 |
bar_label_data,
|
@@ -440,20 +308,128 @@ def predict(
|
|
440 |
title="No valid predicted labels to display"
|
441 |
)
|
442 |
|
443 |
-
#
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
451 |
)
|
452 |
|
453 |
-
|
454 |
-
#
|
455 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
456 |
|
|
|
|
|
|
|
457 |
input_mapping = {
|
458 |
'YNURSMDE': {"Yes": 1, "No": 0},
|
459 |
'YMDEYR': {"Yes": 1, "No": 2},
|
@@ -486,8 +462,23 @@ input_mapping = {
|
|
486 |
'YMDELT': {"Yes": 1, "No": 2}
|
487 |
}
|
488 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
489 |
# Define the "inputs" in the same order used in the function signature
|
490 |
inputs = [
|
|
|
|
|
491 |
gr.Dropdown(list(input_mapping['YMDEYR'].keys()), label="YMDEYR: PAST YEARS MAJOR DEPRESSIVE EPISODE"),
|
492 |
gr.Dropdown(list(input_mapping['YMDERSUD5ANY'].keys()), label="YMDERSUD5ANY: MDE OR SUBSTANCE USE DISORDER - ANY"),
|
493 |
gr.Dropdown(list(input_mapping['YMDEIMAD5YR'].keys()), label="YMDEIMAD5YR: MDE WITH SEV. IMP + ALCOHOL USE DISORDER"),
|
@@ -501,6 +492,8 @@ inputs = [
|
|
501 |
gr.Dropdown(list(input_mapping['YMIMI5YANY'].keys()), label="YMIMI5YANY: MDE W/ ILL DRUG USE DISORDER"),
|
502 |
gr.Dropdown(list(input_mapping['YMIUD5YANY'].keys()), label="YMIUD5YANY: MDE + ILL DRUG USE DISORDER"),
|
503 |
gr.Dropdown(list(input_mapping['YMDESUD5ANYO'].keys()), label="YMDESUD5ANYO: MDE vs. SUD vs. BOTH vs. NEITHER"),
|
|
|
|
|
504 |
gr.Dropdown(list(input_mapping['YNURSMDE'].keys()), label="YNURSMDE: SAW/TALK TO NURSE/OT ABOUT MDE"),
|
505 |
gr.Dropdown(list(input_mapping['YSOCMDE'].keys()), label="YSOCMDE: SAW/TALK TO SOCIAL WORKER ABOUT MDE"),
|
506 |
gr.Dropdown(list(input_mapping['YCOUNMDE'].keys()), label="YCOUNMDE: SAW/TALK TO COUNSELOR ABOUT MDE"),
|
@@ -509,22 +502,28 @@ inputs = [
|
|
509 |
gr.Dropdown(list(input_mapping['YHLTMDE'].keys()), label="YHLTMDE: SAW/TALK TO HEALTH PROFESSIONAL ABOUT MDE"),
|
510 |
gr.Dropdown(list(input_mapping['YDOCMDE'].keys()), label="YDOCMDE: SAW/TALK TO GP/FAMILY MD ABOUT MDE"),
|
511 |
gr.Dropdown(list(input_mapping['YTXMDEYR'].keys()), label="YTXMDEYR: SAW/TALK DOCTOR/HEALTH PROF FOR MDE"),
|
|
|
|
|
512 |
gr.Dropdown(list(input_mapping['YUSUITHKYR'].keys()), label="YUSUITHKYR: SERIOUSLY THOUGHT ABOUT KILLING SELF"),
|
513 |
gr.Dropdown(list(input_mapping['YUSUIPLNYR'].keys()), label="YUSUIPLNYR: MADE PLANS TO KILL SELF"),
|
514 |
gr.Dropdown(list(input_mapping['YUSUITHK'].keys()), label="YUSUITHK: THINK ABOUT KILLING SELF (12 MONTHS)"),
|
515 |
gr.Dropdown(list(input_mapping['YUSUIPLN'].keys()), label="YUSUIPLN: MADE PLANS TO KILL SELF (12 MONTHS)"),
|
|
|
|
|
516 |
gr.Dropdown(list(input_mapping['MDEIMPY'].keys()), label="MDEIMPY: MDE W/ SEVERE ROLE IMPAIRMENT"),
|
517 |
gr.Dropdown(list(input_mapping['LVLDIFMEM2'].keys()), label="LVLDIFMEM2: LEVEL OF DIFFICULTY REMEMBERING/CONCENTRATING"),
|
518 |
gr.Dropdown(list(input_mapping['YMSUD5YANY'].keys()), label="YMSUD5YANY: MDE + SUBSTANCE USE DISORDER - ANY"),
|
519 |
gr.Dropdown(list(input_mapping['YRXMDEYR'].keys()), label="YRXMDEYR: USED MEDS FOR MDE IN PAST YEAR"),
|
520 |
]
|
521 |
|
522 |
-
# We have
|
523 |
outputs = [
|
524 |
gr.Textbox(label="Prediction Results", lines=30),
|
525 |
gr.Textbox(label="Mental Health Severity", lines=4),
|
526 |
-
gr.Markdown(
|
527 |
-
gr.Plot(label="Cross-Tab
|
|
|
|
|
528 |
gr.Plot(label="Number of Patients per Input Feature"),
|
529 |
gr.Plot(label="Number of Patients with Predicted Labels")
|
530 |
]
|
@@ -545,10 +544,14 @@ def predict_with_text(
|
|
545 |
YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
|
546 |
):
|
547 |
return (
|
548 |
-
"Please select all required fields.", #
|
549 |
"Validation Error", # Severity
|
550 |
-
"",
|
551 |
-
None,
|
|
|
|
|
|
|
|
|
552 |
)
|
553 |
|
554 |
# Map from user-friendly text to int
|
@@ -587,7 +590,6 @@ def predict_with_text(
|
|
587 |
# Pass our mapped values into the original 'predict' function
|
588 |
return predict(**user_inputs)
|
589 |
|
590 |
-
|
591 |
# Custom CSS (optional)
|
592 |
custom_css = """
|
593 |
.gradio-container * {
|
@@ -606,10 +608,7 @@ custom_css = """
|
|
606 |
}
|
607 |
"""
|
608 |
|
609 |
-
|
610 |
-
# LAUNCH INTERFACE
|
611 |
-
##############################################################################
|
612 |
-
|
613 |
interface = gr.Interface(
|
614 |
fn=predict_with_text,
|
615 |
inputs=inputs,
|
|
|
|
|
1 |
import pickle
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
import numpy as np
|
5 |
import pandas as pd
|
6 |
import plotly.express as px
|
|
|
8 |
# Load the training CSV once (outside the functions so it is read only once).
|
9 |
df = pd.read_csv("X_train_Y_Train_merged_train.csv")
|
10 |
|
|
|
|
|
|
|
|
|
11 |
class ModelPredictor:
|
12 |
def __init__(self, model_path, model_filenames):
|
13 |
self.model_path = model_path
|
14 |
self.model_filenames = model_filenames
|
15 |
self.models = self.load_models()
|
16 |
+
# For readability, you might want to keep only a few keys here if you want
|
17 |
+
# to demonstrate partial cross-tabs, etc.
|
18 |
self.prediction_map = {
|
19 |
"YOWRCONC": ["Did not have difficulty concentrating", "Had difficulty concentrating"],
|
20 |
"YOSEEDOC": ["Did not feel the need to see a doctor", "Felt the need to see a doctor"],
|
|
|
56 |
return models
|
57 |
|
58 |
def make_predictions(self, user_input):
|
59 |
+
"""
|
60 |
+
Returns a list of numpy arrays, each array is [0] or [1].
|
61 |
+
The i-th array corresponds to the i-th model in self.models.
|
62 |
+
"""
|
63 |
predictions = []
|
64 |
for model in self.models:
|
65 |
pred = model.predict(user_input)
|
|
|
70 |
def get_majority_vote(self, predictions):
    """
    Return the most frequent class label across every model's predictions.

    Args:
        predictions: Non-empty sequence of array-likes, one per model, each
            holding that model's predicted label(s) (expected 0 or 1).

    Returns:
        The label that occurs most often (a NumPy integer). On a tie the
        smallest label wins, because ``argmax`` returns the first maximum.

    Raises:
        ValueError: If ``predictions`` is empty (raised by ``np.concatenate``)
            or contains a negative label (raised by ``np.bincount``).
    """
    # Flatten each entry so 2-D outputs (e.g. shape (1, 1)) are accepted too.
    combined_predictions = np.concatenate([np.ravel(p) for p in predictions])
    # np.bincount only accepts integer input; estimators that emit float
    # labels (0.0 / 1.0) would otherwise raise a TypeError here.
    combined_predictions = combined_predictions.astype(np.int64)
    majority_vote = np.bincount(combined_predictions).argmax()
    return majority_vote
|
78 |
|
79 |
+
# Based on Equal Interval and Percentage-Based Method
|
80 |
+
# Severe: 13 to 16 votes (upper 25%)
|
81 |
+
# Moderate: 9 to 12 votes (upper-middle 25%)
|
82 |
+
# Low: 5 to 8 votes (lower-middle 25%)
|
83 |
+
# Very Low: 0 to 4 votes (lower 25%)
|
84 |
def evaluate_severity(self, majority_vote_count):
|
85 |
if majority_vote_count >= 13:
|
86 |
return "Mental health severity: Severe"
|
|
|
91 |
else:
|
92 |
return "Mental health severity: Very Low"
|
93 |
|
94 |
+
# List of model filenames
|
95 |
model_filenames = [
|
96 |
"YOWRCONC.pkl", "YOSEEDOC.pkl", "YO_MDEA5.pkl", "YOWRLSIN.pkl",
|
97 |
"YODPPROB.pkl", "YOWRPROB.pkl", "YODPR2WK.pkl", "YOWRDEPR.pkl",
|
|
|
101 |
model_path = "models/"
|
102 |
predictor = ModelPredictor(model_path, model_filenames)
|
103 |
|
|
|
|
|
|
|
|
|
104 |
def validate_inputs(*args):
    """
    Check that none of the given values was left unselected.

    An argument is treated as unselected when it equals the empty string
    or is ``None``.

    Returns:
        ``False`` as soon as one unselected argument is found, ``True``
        otherwise (including when no arguments are passed).
    """
    return not any(value == '' or value is None for value in args)
|
109 |
|
|
|
|
|
|
|
|
|
110 |
def predict(
|
111 |
YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
|
112 |
YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
|
|
|
114 |
YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
|
115 |
YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
|
116 |
):
|
117 |
+
"""
|
118 |
+
Core prediction function that:
|
119 |
+
1) Predicts with each model
|
120 |
+
2) Aggregates results
|
121 |
+
3) Produces an overall 'severity'
|
122 |
+
4) Returns detailed per-model predictions
|
123 |
+
5) Returns bar charts about how many in the dataset share the same inputs/predicted labels
|
124 |
+
6) ***Now includes custom sections for:
|
125 |
+
- Total patient count (markdown)
|
126 |
+
- Cross-tab & grouped bar chart
|
127 |
+
- Similar Patient (Nearest Neighbors)
|
128 |
+
- Co-occurrence plot
|
129 |
+
"""
|
130 |
+
|
131 |
# Prepare user_input dataframe for prediction
|
132 |
user_input_data = {
|
133 |
'YNURSMDE': [int(YNURSMDE)],
|
|
|
162 |
}
|
163 |
user_input = pd.DataFrame(user_input_data)
|
164 |
|
165 |
+
# -----------------------
|
166 |
+
# 1) Make predictions
|
167 |
+
# -----------------------
|
168 |
predictions = predictor.make_predictions(user_input)
|
169 |
+
|
170 |
+
# 2) Calculate majority vote (0 or 1) across all models
|
171 |
majority_vote = predictor.get_majority_vote(predictions)
|
172 |
+
|
173 |
+
# 3) Count how many 1's in all predictions combined
|
174 |
majority_vote_count = sum([1 for pred in np.concatenate(predictions) if pred == 1])
|
175 |
+
|
176 |
# 4) Evaluate severity
|
177 |
severity = predictor.evaluate_severity(majority_vote_count)
|
178 |
|
179 |
+
# 5) Prepare detailed results for each model group
|
|
|
|
|
180 |
results = {
|
181 |
"Concentration_and_Decision_Making": [],
|
182 |
"Sleep_and_Energy_Levels": [],
|
|
|
195 |
"YODPR2WK", "YODSMMDE",
|
196 |
"YOPB2WK"]
|
197 |
}
|
198 |
+
|
199 |
for i, pred in enumerate(predictions):
|
200 |
model_name = model_filenames[i].split('.')[0] # e.g. 'YOWRCONC'
|
201 |
pred_value = pred[0]
|
202 |
+
# Map the prediction value to a human-readable string
|
203 |
if model_name in predictor.prediction_map and pred_value in [0, 1]:
|
204 |
result_text = f"Model {model_name}: {predictor.prediction_map[model_name][pred_value]}"
|
205 |
+
elif model_name in predictor.prediction_map:
|
206 |
+
# Out of known range => "Unknown"
|
207 |
+
result_text = f"Model {model_name}: Unknown prediction value {pred_value}"
|
208 |
else:
|
209 |
+
result_text = f"Model {model_name}: Unknown model"
|
210 |
|
211 |
+
# Append to the appropriate group
|
212 |
found_group = False
|
213 |
for group_name, group_models in prediction_groups.items():
|
214 |
if model_name in group_models:
|
215 |
results[group_name].append(result_text)
|
216 |
found_group = True
|
217 |
break
|
218 |
+
if not found_group:
|
219 |
+
# If model doesn't match any group, skip or store it in a catch-all
|
220 |
+
pass
|
221 |
|
222 |
+
# 6) Nicely format the results
|
223 |
formatted_results = []
|
224 |
for group, preds in results.items():
|
225 |
if preds:
|
226 |
formatted_results.append(f"Group {group.replace('_', ' ')}:")
|
227 |
formatted_results.append("\n".join(preds))
|
228 |
+
formatted_results.append("\n")
|
|
|
|
|
|
|
|
|
229 |
|
230 |
+
formatted_results = "\n".join(formatted_results).strip()
|
231 |
+
|
232 |
+
if len(formatted_results) == 0:
|
233 |
+
formatted_results = "No predictions made. Please check your inputs."
|
234 |
+
|
235 |
+
# Heuristic: if too many unknown predictions, append note
|
236 |
+
num_unknown = len([
|
237 |
+
pred for group, preds in results.items()
|
238 |
+
for pred in preds if "Unknown prediction value" in pred or "Unknown model" in pred
|
239 |
+
])
|
240 |
+
if num_unknown > len(model_filenames) / 2:
|
241 |
+
severity += " (Unknown prediction count is high. Please consult with a human.)"
|
242 |
+
|
243 |
+
# ------------------------
|
244 |
+
# ADDITIONAL FEATURES
|
245 |
+
# ------------------------
|
246 |
+
|
247 |
+
# A) Total Patient Count (instead of the old "Pie" chart)
|
248 |
total_patients = len(df)
|
249 |
+
total_patient_count_markdown = (
|
250 |
"### Total Patient Count\n"
|
251 |
+
f"There are **{total_patients}** total patients in the dataset.\n\n"
|
252 |
+
"This count can help you understand the overall dataset size. "
|
253 |
+
"All subsequent analyses are relative to these patients."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
254 |
)
|
255 |
|
256 |
+
# B) Analyze Each Input Feature
|
257 |
+
# For each feature in user_input, compute how many patients have that same value.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
258 |
input_counts = {}
|
259 |
+
for col in user_input_data.keys():
|
260 |
+
val = user_input_data[col][0]
|
261 |
same_val_count = len(df[df[col] == val])
|
262 |
input_counts[col] = same_val_count
|
263 |
|
264 |
+
# Plot: Bar Chart for each input feature
|
265 |
bar_input_data = pd.DataFrame({
|
266 |
"Feature": list(input_counts.keys()),
|
267 |
"Count": list(input_counts.values())
|
|
|
275 |
)
|
276 |
fig_bar_input.update_layout(xaxis={'categoryorder':'total descending'})
|
277 |
|
278 |
+
# C) Analyze Predicted Labels
|
279 |
+
# For each model's predicted label (0 or 1), count how many patients in the CSV
|
280 |
+
# have that label. We skip unknown if pred_value not in [0, 1].
|
281 |
label_counts = {}
|
282 |
for i, pred in enumerate(predictions):
|
283 |
model_name = model_filenames[i].split('.')[0]
|
284 |
pred_value = pred[0]
|
285 |
+
if pred_value in [0, 1]:
|
286 |
label_counts[model_name] = len(df[df[model_name] == pred_value])
|
287 |
|
288 |
if len(label_counts) > 0:
|
|
|
294 |
bar_label_data,
|
295 |
x="Model",
|
296 |
y="Count",
|
297 |
+
title="Number of Patients with the Predicted Label (0 or 1) by Model",
|
298 |
labels={"Model": "Predicted Column", "Count": "Number of Patients"}
|
299 |
)
|
300 |
fig_bar_labels.update_layout(xaxis={'categoryorder':'total descending'})
|
301 |
else:
|
302 |
+
# If everything was unknown, produce an empty figure or a fallback message
|
303 |
bar_label_data = pd.DataFrame({"Model": [], "Count": []})
|
304 |
fig_bar_labels = px.bar(
|
305 |
bar_label_data,
|
|
|
308 |
title="No valid predicted labels to display"
|
309 |
)
|
310 |
|
311 |
+
# D) Cross-Tabulation & Grouped Bar Chart
|
312 |
+
# Example: Show how a single input feature (YMDEYR) relates to one actual label (YOWRCONC).
|
313 |
+
# For demonstration only — in practice you might do this for multiple features/labels.
|
314 |
+
# NOTE: If the columns don't exist in the dataset (some code merges them differently),
|
315 |
+
# you might adapt accordingly.
|
316 |
+
if "YMDEYR" in df.columns and "YOWRCONC" in df.columns:
|
317 |
+
cross_tab_data = df.groupby(["YMDEYR", "YOWRCONC"]).size().reset_index(name="count")
|
318 |
+
fig_cross_tab = px.bar(
|
319 |
+
cross_tab_data,
|
320 |
+
x="YMDEYR",
|
321 |
+
y="count",
|
322 |
+
color="YOWRCONC",
|
323 |
+
barmode="group",
|
324 |
+
title="Cross-Tab: YMDEYR vs YOWRCONC (Grouped Bar Chart)",
|
325 |
+
labels={"YMDEYR": "Feature: YMDEYR", "YOWRCONC": "Label: YOWRCONC"}
|
326 |
+
)
|
327 |
+
else:
|
328 |
+
# Provide a fallback message if columns not found
|
329 |
+
fig_cross_tab = px.bar(title="YMDEYR or YOWRCONC not found in dataset. Cross-tab not available.")
|
330 |
+
|
331 |
+
# E) Similar Patient (Nearest Neighbors) via simple Hamming distance
|
332 |
+
# We'll pick K=5 neighbors. Then see how many had label=0 vs label=1 for
|
333 |
+
# one example label: YOWRCONC.
|
334 |
+
# (You can adapt to do multiple labels, but that can get lengthy.)
|
335 |
+
def hamming_distance(row, user_row):
    """Count the labels of ``user_row`` whose value differs in ``row``.

    Args:
        row: Record that supports label-based indexing (``row[c]``) for
            every label in ``user_row.index``.
        user_row: Record exposing ``.index`` (its labels) and label-based
            indexing, e.g. a pandas Series.

    Returns:
        int: Number of labels ``c`` for which ``row[c] != user_row[c]``.

    Note:
        NaN compares unequal to itself, so a value that is missing on both
        sides is counted as a mismatch.
    """
    # One point per differing label; an empty ``user_row`` yields 0.
    return sum(1 for c in user_row.index if row[c] != user_row[c])
|
341 |
+
|
342 |
+
# Create a single row for easy iteration
|
343 |
+
user_series = user_input.iloc[0]
|
344 |
+
|
345 |
+
# We'll compute distance for all rows in df on the same features
|
346 |
+
# that were used in the user_input.
|
347 |
+
# NOTE: In real usage, confirm these columns exist in df.
|
348 |
+
# If df lacks them or is encoded differently, you'd adapt.
|
349 |
+
features_to_compare = list(user_input.columns)
|
350 |
+
# For Hamming, ensure we pick only the columns present in df
|
351 |
+
features_to_compare = [f for f in features_to_compare if f in df.columns]
|
352 |
+
|
353 |
+
# Build a DataFrame we can safely compare
|
354 |
+
subset_df = df[features_to_compare].copy()
|
355 |
+
|
356 |
+
# Calculate distances
|
357 |
+
distances = []
|
358 |
+
for idx, row in subset_df.iterrows():
|
359 |
+
d = 0
|
360 |
+
for col in features_to_compare:
|
361 |
+
if row[col] != user_series[col]:
|
362 |
+
d += 1
|
363 |
+
distances.append(d)
|
364 |
+
|
365 |
+
# Attach distances
|
366 |
+
df_with_dist = df.copy()
|
367 |
+
df_with_dist["distance"] = distances
|
368 |
+
|
369 |
+
# Sort by distance ascending, pick top K=5
|
370 |
+
K = 5
|
371 |
+
nearest_neighbors = df_with_dist.sort_values("distance", ascending=True).head(K)
|
372 |
+
|
373 |
+
# For demonstration, let's show how many had YOWRCONC=0 vs. 1
|
374 |
+
nn_label_0 = nn_label_1 = 0
|
375 |
+
if "YOWRCONC" in nearest_neighbors.columns:
|
376 |
+
nn_label_0 = len(nearest_neighbors[nearest_neighbors["YOWRCONC"] == 0])
|
377 |
+
nn_label_1 = len(nearest_neighbors[nearest_neighbors["YOWRCONC"] == 1])
|
378 |
+
|
379 |
+
# Summarize in markdown
|
380 |
+
similar_patient_markdown = (
|
381 |
+
"### Nearest Neighbors (Simple Hamming Distance)\n"
|
382 |
+
f"We searched for the top **{K}** patients in the dataset whose categorical features "
|
383 |
+
"most closely match your input (Hamming distance).\n\n"
|
384 |
+
"**For the label `YOWRCONC`** among these neighbors:\n"
|
385 |
+
f"- {nn_label_0} had label=0\n"
|
386 |
+
f"- {nn_label_1} had label=1\n\n"
|
387 |
+
"(This is a simple illustration. In real practice, you'd refine which columns to use, "
|
388 |
+
"how to encode them, and how many neighbors to consider.)"
|
389 |
)
|
390 |
|
391 |
+
# F) Co-Occurrence Plot
|
392 |
+
# Example: How two features (YMDEYR, YMDERSUD5ANY) combine with label (YOWRCONC).
|
393 |
+
# We'll produce a multi-way distribution using facet_col.
|
394 |
+
if all(col in df.columns for col in ["YMDEYR", "YMDERSUD5ANY", "YOWRCONC"]):
|
395 |
+
co_occ_data = df.groupby(["YMDEYR", "YMDERSUD5ANY", "YOWRCONC"]).size().reset_index(name="count")
|
396 |
+
fig_co_occ = px.bar(
|
397 |
+
co_occ_data,
|
398 |
+
x="YMDEYR",
|
399 |
+
y="count",
|
400 |
+
color="YOWRCONC",
|
401 |
+
facet_col="YMDERSUD5ANY",
|
402 |
+
title="Co-Occurrence Plot: YMDEYR and YMDERSUD5ANY vs YOWRCONC"
|
403 |
+
)
|
404 |
+
else:
|
405 |
+
fig_co_occ = px.bar(title="Co-occurrence plot not available (columns not found).")
|
406 |
+
|
407 |
+
# ------------------------
|
408 |
+
# Return everything
|
409 |
+
# ------------------------
|
410 |
+
# We now have 8 items to return:
|
411 |
+
# 1) Prediction Results (Textbox)
|
412 |
+
# 2) Mental Health Severity (Textbox)
|
413 |
+
# 3) Total Patient Count (Markdown)
|
414 |
+
# 4) Cross-Tab & Grouped Bar Chart (Plot)
|
415 |
+
# 5) Nearest Neighbors Summary (Markdown)
|
416 |
+
# 6) Co-Occurrence Plot (Plot)
|
417 |
+
# 7) Bar Chart for input features (Plot)
|
418 |
+
# 8) Bar Chart for predicted labels (Plot)
|
419 |
+
return (
|
420 |
+
formatted_results,
|
421 |
+
severity,
|
422 |
+
total_patient_count_markdown,
|
423 |
+
fig_cross_tab,
|
424 |
+
similar_patient_markdown,
|
425 |
+
fig_co_occ,
|
426 |
+
fig_bar_input,
|
427 |
+
fig_bar_labels
|
428 |
+
)
|
429 |
|
430 |
+
# -----------------------------------------------------------------------------
|
431 |
+
# MAPPING user-friendly text => numeric values
|
432 |
+
# -----------------------------------------------------------------------------
|
433 |
input_mapping = {
|
434 |
'YNURSMDE': {"Yes": 1, "No": 0},
|
435 |
'YMDEYR': {"Yes": 1, "No": 2},
|
|
|
462 |
'YMDELT': {"Yes": 1, "No": 2}
|
463 |
}
|
464 |
|
465 |
+
# -----------------------------------------------------------------------------
|
466 |
+
# Create the Gradio interface
|
467 |
+
# -----------------------------------------------------------------------------
|
468 |
+
# We have 8 outputs now:
|
469 |
+
# 1) Prediction Results (Textbox)
|
470 |
+
# 2) Mental Health Severity (Textbox)
|
471 |
+
# 3) Total Patient Count (Markdown)
|
472 |
+
# 4) Cross-Tab & Grouped Bar Chart (Plot)
|
473 |
+
# 5) Nearest Neighbors Summary (Markdown)
|
474 |
+
# 6) Co-Occurrence Plot (Plot)
|
475 |
+
# 7) Bar Chart for input features (Plot)
|
476 |
+
# 8) Bar Chart for predicted labels (Plot)
|
477 |
+
|
478 |
# Define the "inputs" in the same order used in the function signature
|
479 |
inputs = [
|
480 |
+
################# Ordered and grouped ##########################
|
481 |
+
# Questions related to Major Depressive Episode (MDE) and related impairments or disorders
|
482 |
gr.Dropdown(list(input_mapping['YMDEYR'].keys()), label="YMDEYR: PAST YEARS MAJOR DEPRESSIVE EPISODE"),
|
483 |
gr.Dropdown(list(input_mapping['YMDERSUD5ANY'].keys()), label="YMDERSUD5ANY: MDE OR SUBSTANCE USE DISORDER - ANY"),
|
484 |
gr.Dropdown(list(input_mapping['YMDEIMAD5YR'].keys()), label="YMDEIMAD5YR: MDE WITH SEV. IMP + ALCOHOL USE DISORDER"),
|
|
|
492 |
gr.Dropdown(list(input_mapping['YMIMI5YANY'].keys()), label="YMIMI5YANY: MDE W/ ILL DRUG USE DISORDER"),
|
493 |
gr.Dropdown(list(input_mapping['YMIUD5YANY'].keys()), label="YMIUD5YANY: MDE + ILL DRUG USE DISORDER"),
|
494 |
gr.Dropdown(list(input_mapping['YMDESUD5ANYO'].keys()), label="YMDESUD5ANYO: MDE vs. SUD vs. BOTH vs. NEITHER"),
|
495 |
+
|
496 |
+
# Questions related to consultations with professionals about MDE
|
497 |
gr.Dropdown(list(input_mapping['YNURSMDE'].keys()), label="YNURSMDE: SAW/TALK TO NURSE/OT ABOUT MDE"),
|
498 |
gr.Dropdown(list(input_mapping['YSOCMDE'].keys()), label="YSOCMDE: SAW/TALK TO SOCIAL WORKER ABOUT MDE"),
|
499 |
gr.Dropdown(list(input_mapping['YCOUNMDE'].keys()), label="YCOUNMDE: SAW/TALK TO COUNSELOR ABOUT MDE"),
|
|
|
502 |
gr.Dropdown(list(input_mapping['YHLTMDE'].keys()), label="YHLTMDE: SAW/TALK TO HEALTH PROFESSIONAL ABOUT MDE"),
|
503 |
gr.Dropdown(list(input_mapping['YDOCMDE'].keys()), label="YDOCMDE: SAW/TALK TO GP/FAMILY MD ABOUT MDE"),
|
504 |
gr.Dropdown(list(input_mapping['YTXMDEYR'].keys()), label="YTXMDEYR: SAW/TALK DOCTOR/HEALTH PROF FOR MDE"),
|
505 |
+
|
506 |
+
# Questions related to suicidal thoughts and plans
|
507 |
gr.Dropdown(list(input_mapping['YUSUITHKYR'].keys()), label="YUSUITHKYR: SERIOUSLY THOUGHT ABOUT KILLING SELF"),
|
508 |
gr.Dropdown(list(input_mapping['YUSUIPLNYR'].keys()), label="YUSUIPLNYR: MADE PLANS TO KILL SELF"),
|
509 |
gr.Dropdown(list(input_mapping['YUSUITHK'].keys()), label="YUSUITHK: THINK ABOUT KILLING SELF (12 MONTHS)"),
|
510 |
gr.Dropdown(list(input_mapping['YUSUIPLN'].keys()), label="YUSUIPLN: MADE PLANS TO KILL SELF (12 MONTHS)"),
|
511 |
+
|
512 |
+
# Questions related to impairment due to MDE
|
513 |
gr.Dropdown(list(input_mapping['MDEIMPY'].keys()), label="MDEIMPY: MDE W/ SEVERE ROLE IMPAIRMENT"),
|
514 |
gr.Dropdown(list(input_mapping['LVLDIFMEM2'].keys()), label="LVLDIFMEM2: LEVEL OF DIFFICULTY REMEMBERING/CONCENTRATING"),
|
515 |
gr.Dropdown(list(input_mapping['YMSUD5YANY'].keys()), label="YMSUD5YANY: MDE + SUBSTANCE USE DISORDER - ANY"),
|
516 |
gr.Dropdown(list(input_mapping['YRXMDEYR'].keys()), label="YRXMDEYR: USED MEDS FOR MDE IN PAST YEAR"),
|
517 |
]
|
518 |
|
519 |
+
# The 8 output components, listed in the same order as the tuple returned by predict():
|
520 |
outputs = [
|
521 |
gr.Textbox(label="Prediction Results", lines=30),
|
522 |
gr.Textbox(label="Mental Health Severity", lines=4),
|
523 |
+
gr.Markdown(label="Total Patient Count"),
|
524 |
+
gr.Plot(label="Cross-Tab & Grouped Bar Chart"),
|
525 |
+
gr.Markdown(label="Nearest Neighbors Summary"),
|
526 |
+
gr.Plot(label="Co-Occurrence Plot"),
|
527 |
gr.Plot(label="Number of Patients per Input Feature"),
|
528 |
gr.Plot(label="Number of Patients with Predicted Labels")
|
529 |
]
|
|
|
544 |
YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
|
545 |
):
|
546 |
return (
|
547 |
+
"Please select all required fields.", # Prediction Results
|
548 |
"Validation Error", # Severity
|
549 |
+
"No data", # Total Patient Count
|
550 |
+
None, # Cross-Tab figure
|
551 |
+
"No data", # Nearest Neighbors
|
552 |
+
None, # Co-Occurrence
|
553 |
+
None, # Input Features Bar
|
554 |
+
None # Predicted Labels Bar
|
555 |
)
|
556 |
|
557 |
# Map from user-friendly text to int
|
|
|
590 |
# Pass our mapped values into the original 'predict' function
|
591 |
return predict(**user_inputs)
|
592 |
|
|
|
593 |
# Custom CSS (optional)
|
594 |
custom_css = """
|
595 |
.gradio-container * {
|
|
|
608 |
}
|
609 |
"""
|
610 |
|
611 |
+
# Finally, launch the app with 8 outputs
|
|
|
|
|
|
|
612 |
interface = gr.Interface(
|
613 |
fn=predict_with_text,
|
614 |
inputs=inputs,
|