|
import os
import pickle

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
|
|
|
|
|
# Dataset loaded once at import time. Judging by the filename it holds the
# training features merged with the label columns (TODO confirm). It is used
# below for patient counts, plots, nearest neighbours and dropdown choices.
df = pd.read_csv(filepath_or_buffer="X_train_Y_Train_merged_train.csv")
|
|
|
|
|
|
|
|
|
class ModelPredictor:
    """Load a set of pickled classifiers and summarise their predictions.

    Each model is unpickled from ``os.path.join(model_path, filename)`` for
    every entry of ``model_filenames``. ``prediction_map`` maps a label name
    to ``[text for prediction 0, text for prediction 1]``. Entries need not
    all have a corresponding loaded model.
    """

    def __init__(self, model_path, model_filenames):
        self.model_path = model_path
        self.model_filenames = model_filenames
        self.models = self.load_models()

        self.prediction_map = {
            "YOWRCONC": ["No difficulty concentrating", "Had difficulty concentrating"],
            "YOSEEDOC": ["No need to see doctor", "Needed to see doctor"],
            "YOWRHRS": ["No trouble sleeping", "Had trouble sleeping"],
            "YO_MDEA5": ["Others didn't notice restlessness", "Others noticed restlessness"],
            "YOWRCHR": ["Not sad beyond cheering", "Felt so sad no one could cheer up"],
            "YOWRLSIN": ["Never felt bored/lost interest", "Felt bored/lost interest"],
            "YODPPROB": ["No other problems for 2+ weeks", "Had other problems for 2+ weeks"],
            "YOWRPROB": ["No worst time feeling", "Felt worst time ever"],
            "YODPR2WK": ["No depressed feelings for 2+ wks", "Depressed feelings for 2+ wks"],
            "YOWRDEPR": ["Not sad or depressed most days", "Sad or depressed most days"],
            "YODPDISC": ["Mood not depressed overall", "Mood depressed overall (discrepancy)"],
            "YOLOSEV": ["Did not lose interest in activities", "Lost interest in activities"],
            "YOWRDCSN": ["Could make decisions", "Could not make decisions"],
            "YODSMMDE": ["No 2+ week depression episodes", "Had 2+ week depression episodes"],
            "YO_MDEA3": ["No appetite/weight changes", "Yes appetite/weight changes"],
            "YODPLSIN": ["Never bored/lost interest", "Often bored/lost interest"],
            "YOWRELES": ["Did not eat less", "Ate less than usual"],
            "YODSCEV": ["Fewer severe symptoms", "More severe symptoms"],
            "YOPB2WK": ["No uneasy feelings daily 2+ wks", "Uneasy feelings daily 2+ wks"],
            "YO_MDEA2": ["No issues physical/mental daily", "Issues physical/mental daily 2+ wks"],
        }

    def load_models(self):
        """Unpickle every model file and return the models in filename order.

        Paths are built with ``os.path.join``, so ``model_path`` works with or
        without a trailing separator.

        Security: ``pickle.load`` can execute arbitrary code. Only point
        ``model_path`` at model files from a trusted source.
        """
        models = []
        for filename in self.model_filenames:
            with open(os.path.join(self.model_path, filename), "rb") as fh:
                models.append(pickle.load(fh))
        return models

    def make_predictions(self, user_input):
        """Return one flattened numpy array of predictions per loaded model.

        Args:
            user_input: whatever each model's ``predict`` accepts (the app
                passes a one-row DataFrame).
        """
        return [np.array(model.predict(user_input)).flatten() for model in self.models]

    def get_majority_vote(self, predictions):
        """Return the most frequent class label across all prediction arrays.

        Values are cast to int before counting, so float outputs such as 1.0
        are accepted (``np.bincount`` rejects floats). Ties resolve to the
        smaller label because ``argmax`` returns the first maximum.

        Raises:
            ValueError: if ``predictions`` is empty, or labels are negative.
        """
        if len(predictions) == 0:
            raise ValueError("get_majority_vote() needs at least one prediction array")
        combined = np.concatenate(predictions).astype(int)
        return np.bincount(combined).argmax()

    def evaluate_severity(self, majority_vote_count, total_models=16):
        """Map the number of positive (1) predictions to a severity string.

        Bands are quarters of ``total_models``:
        - more than 3/4 -> Severe
        - more than 1/2 -> Moderate
        - more than 1/4 -> Low
        - otherwise Very Low

        With the default of 16 models, integer counts give 0-4 Very Low,
        5-8 Low, 9-12 Moderate and 13-16 Severe.
        """
        # Integer cross-multiplication avoids float rounding at band edges.
        if 4 * majority_vote_count > 3 * total_models:
            return "Mental health severity: Severe"
        if 2 * majority_vote_count > total_models:
            return "Mental health severity: Moderate"
        if 4 * majority_vote_count > total_models:
            return "Mental health severity: Low"
        return "Mental health severity: Very Low"
|
|
|
|
|
|
|
|
|
# Label names with a trained model. The file stem of each model file is later
# used as the matching label column name.
_MODEL_LABELS = (
    "YOWRCONC", "YOSEEDOC", "YO_MDEA5", "YOWRLSIN",
    "YODPPROB", "YOWRPROB", "YODPR2WK", "YOWRDEPR",
    "YODPDISC", "YOLOSEV", "YOWRDCSN", "YODSMMDE",
    "YO_MDEA3", "YODPLSIN", "YOWRELES", "YOPB2WK",
)
model_filenames = [f"{label}.pkl" for label in _MODEL_LABELS]

model_path = "models/"

# Models are unpickled once, at import time.
predictor = ModelPredictor(model_path=model_path, model_filenames=model_filenames)
|
|
|
|
|
|
|
|
|
def validate_inputs(*args):
    """Return True if no argument is missing.

    An argument counts as missing when it is the empty string or None.
    Calling with no arguments returns True.
    """
    return not any(value == '' or value is None for value in args)
|
|
|
|
|
|
|
|
|
# Display groups for the per-label prediction text, in output order. A label
# must appear in at most one group. YOWRELES ("ate less than usual") is an
# appetite item, so it is listed only under appetite.
_GROUP_MAP = {
    "Concentration_and_Decision_Making": ("YOWRCONC", "YOWRDCSN"),
    "Sleep_and_Energy_Levels": ("YOWRHRS", "YO_MDEA5", "YO_MDEA2"),
    "Mood_and_Emotional_State": (
        "YOWRCHR", "YOWRLSIN", "YOWRDEPR", "YODPDISC",
        "YOLOSEV", "YODPLSIN", "YODSCEV",
    ),
    "Appetite_and_Weight_Changes": ("YO_MDEA3", "YOWRELES"),
    "Duration_and_Severity_of_Depression_Symptoms": (
        "YODPPROB", "YOWRPROB", "YODPR2WK", "YODSMMDE", "YOPB2WK",
    ),
}

# Catch-all group. Labels missing from _GROUP_MAP (e.g. YOSEEDOC) are still
# reported instead of being silently dropped.
_OTHER_GROUP = "Other"


def _to_binary(value):
    """Return *value* as the int 0 or 1 if it equals one of them, else None.

    Accepts numpy ints and floats such as ``1.0``, so the result can be used
    as a list index. Anything else (2, 0.5, NaN, strings, None) yields None.
    """
    try:
        as_int = int(value)
    except (TypeError, ValueError, OverflowError):
        return None
    return as_int if as_int in (0, 1) and as_int == value else None


def _first_prediction(pred_array):
    """Return the first element of one model's prediction array, or None if empty."""
    return pred_array[0] if len(pred_array) else None


def _format_grouped_predictions(predictions, label_names, prediction_map):
    """Render one text line per model, grouped under _GROUP_MAP headings.

    Args:
        predictions: one array per model, as returned by make_predictions.
        label_names: the label column name for each model, in the same order.
        prediction_map: label name -> [text for 0, text for 1].

    Returns:
        The grouped text, or a fallback message when there is nothing to show.
    """
    by_group = {group: [] for group in _GROUP_MAP}
    by_group[_OTHER_GROUP] = []
    for name, pred_array in zip(label_names, predictions):
        raw = _first_prediction(pred_array)
        binary = _to_binary(raw)
        if name in prediction_map and binary is not None:
            line = f"{name}: {prediction_map[name][binary]}"
        else:
            line = f"{name}: Prediction={raw}"
        group = next(
            (g for g, members in _GROUP_MAP.items() if name in members),
            _OTHER_GROUP,
        )
        by_group[group].append(line)

    out_lines = []
    for group, lines in by_group.items():
        if lines:
            out_lines.extend([f"Group {group}:", "\n".join(lines), ""])
    if not out_lines:
        return "No predictions made. Check inputs."
    return "\n".join(out_lines).strip()


def _input_value_bar(data, user_input_data):
    """Bar chart of how many rows share the user's value for each input feature.

    Features absent from ``data`` are skipped rather than raising KeyError.
    """
    counts = {
        col: int((data[col] == values[0]).sum())
        for col, values in user_input_data.items()
        if col in data.columns
    }
    frame = pd.DataFrame({"Feature": list(counts), "Count": list(counts.values())})
    fig = px.bar(
        frame,
        x="Feature",
        y="Count",
        title="Number of Patients with the Same Value for Each Input Feature",
    )
    fig.update_layout(xaxis={"categoryorder": "total descending"})
    return fig


def _predicted_label_bar(data, predictions, label_names):
    """Bar chart of how many rows share each model's 0/1 predicted label.

    Non-binary predictions and label columns absent from ``data`` are skipped.
    """
    counts = {}
    for name, pred_array in zip(label_names, predictions):
        binary = _to_binary(_first_prediction(pred_array))
        if binary is not None and name in data.columns:
            counts[name] = int((data[name] == binary).sum())
    if not counts:
        return px.bar(title="No valid predicted labels to display")
    frame = pd.DataFrame({"Label Column": list(counts), "Count": list(counts.values())})
    return px.bar(
        frame,
        x="Label Column",
        y="Count",
        title="Number of Patients with the Same Predicted Label",
    )


def _distribution_plot(data, feature_names, label_names, n_features=4, n_labels=3):
    """Faceted bar chart of value counts for the first few features x labels.

    Only names present in ``data.columns`` are used. Each facet (one feature
    per row, one label per column) shows how often each feature value occurs
    with each label value. Per-pair columns are renamed to shared names before
    concatenating, so every facet has its own data on the x axis and in the
    colour.
    """
    features = [f for f in feature_names if f in data.columns][:n_features]
    labels = [lab for lab in label_names if lab in data.columns][:n_labels]
    segments = []
    for feat in features:
        for lbl in labels:
            if feat == lbl:
                continue  # grouping a column against itself is meaningless
            counts = (
                data.groupby([feat, lbl]).size().reset_index(name="count")
                .rename(columns={feat: "feature_value", lbl: "label_value"})
            )
            counts["feature"] = feat
            counts["label"] = lbl
            segments.append(counts)
    if not segments:
        return px.bar(title="No distribution plot generated (columns not found).")

    dist = pd.concat(segments, ignore_index=True)
    # Treat label values as categories: discrete legend, not a colour scale.
    dist["label_value"] = dist["label_value"].astype(str)
    fig = px.bar(
        dist,
        x="feature_value",
        y="count",
        color="label_value",
        facet_row="feature",
        facet_col="label",
        title=(
            f"Sample Distribution Plot (first {len(features)} features "
            f"vs first {len(labels)} labels)"
        ),
    )
    fig.update_layout(height=700)
    return fig


def _nearest_neighbors_markdown(data, user_input, k=2):
    """Markdown list of the k rows of ``data`` closest to the single user row.

    Distance is the number of input columns whose value differs (Hamming
    distance). Input columns absent from ``data`` are ignored. Ties keep
    dataset order.
    """
    columns = [c for c in user_input.columns if c in data.columns]
    user_row = user_input.iloc[0][columns]
    distances = data[columns].ne(user_row).sum(axis=1)
    nearest = distances.sort_values(kind="stable").head(k)

    lines = [
        f"### Nearest Neighbors (K={k})",
        "(In a real application, you'd refine which features matter, how to encode them, etc.)\n",
    ]
    lines.extend(f"- **Neighbor ID {idx}**: distance={dist}" for idx, dist in nearest.items())
    return "\n".join(lines)


def _co_occurrence_plot(data, feature1, feature2, label):
    """Count plot of feature1 x feature2 x label, or a placeholder figure.

    A placeholder is returned when any selection is missing ("None"/None),
    not a column of ``data``, or when the same column is chosen twice.
    """
    chosen = (feature1, feature2, label)
    if any(col is None or col == "None" for col in chosen):
        return px.bar(title="No co-occurrence plot (choose two features + one label).")
    if any(col not in data.columns for col in chosen):
        return px.bar(title="One or more selected columns not found in dataframe.")
    if len(set(chosen)) < len(chosen):
        # reset_index would fail on the duplicated column name.
        return px.bar(title="Co-occurrence plot needs three different columns.")

    co_data = data.groupby(list(chosen)).size().reset_index(name="count")
    return px.bar(
        co_data,
        x=feature1,
        y="count",
        color=label,
        facet_col=feature2,
        title=f"Co-occurrence: {feature1} & {feature2} vs {label}",
    )


def predict(
    YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
    YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
    YMDESUD5ANYO, YNURSMDE, YSOCMDE, YCOUNMDE, YPSY1MDE, YPSY2MDE,
    YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
    YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR,
    co_occ_feature1, co_occ_feature2, co_occ_label
):
    """Score one respondent and build every output shown by the UI.

    The 29 feature arguments are numeric codes (anything ``int()`` accepts).
    ``co_occ_feature1``, ``co_occ_feature2`` and ``co_occ_label`` are column
    names, or "None"/None to skip the co-occurrence plot.

    Steps:
    - Predict with every loaded model and derive a severity string from the
      number of positive predictions.
    - Group the per-label predictions into readable text.
    - Bar charts of how many dataset rows share each input value and each
      predicted label.
    - A sample faceted distribution plot.
    - The K=2 nearest neighbours by Hamming distance over the input features.
    - One co-occurrence plot for the user-chosen columns.

    Returns:
        (result_text, severity, total_count_markdown, distribution_fig,
         neighbors_markdown, co_occurrence_fig, input_bar_fig, label_bar_fig)
    """
    # Column order is kept unchanged. It presumably must match the feature
    # order the models were fitted with; confirm before reordering.
    user_input_data = {
        'YNURSMDE': [int(YNURSMDE)],
        'YMDEYR': [int(YMDEYR)],
        'YSOCMDE': [int(YSOCMDE)],
        'YMDESUD5ANYO': [int(YMDESUD5ANYO)],
        'YMSUD5YANY': [int(YMSUD5YANY)],
        'YUSUITHK': [int(YUSUITHK)],
        'YMDETXRX': [int(YMDETXRX)],
        'YUSUITHKYR': [int(YUSUITHKYR)],
        'YMDERSUD5ANY': [int(YMDERSUD5ANY)],
        'YUSUIPLNYR': [int(YUSUIPLNYR)],
        'YCOUNMDE': [int(YCOUNMDE)],
        'YPSY1MDE': [int(YPSY1MDE)],
        'YHLTMDE': [int(YHLTMDE)],
        'YDOCMDE': [int(YDOCMDE)],
        'YPSY2MDE': [int(YPSY2MDE)],
        'YMDEHARX': [int(YMDEHARX)],
        'LVLDIFMEM2': [int(LVLDIFMEM2)],
        'MDEIMPY': [int(MDEIMPY)],
        'YMDEHPO': [int(YMDEHPO)],
        'YMIMS5YANY': [int(YMIMS5YANY)],
        'YMDEIMAD5YR': [int(YMDEIMAD5YR)],
        'YMIUD5YANY': [int(YMIUD5YANY)],
        'YMDEHPRX': [int(YMDEHPRX)],
        'YMIMI5YANY': [int(YMIMI5YANY)],
        'YUSUIPLN': [int(YUSUIPLN)],
        'YTXMDEYR': [int(YTXMDEYR)],
        'YMDEAUD5YR': [int(YMDEAUD5YR)],
        'YRXMDEYR': [int(YRXMDEYR)],
        'YMDELT': [int(YMDELT)],
    }
    user_input = pd.DataFrame(user_input_data)

    predictions = predictor.make_predictions(user_input)
    # Severity is driven by how many models predict the positive class (1).
    positive_count = sum(int(np.count_nonzero(np.asarray(p) == 1)) for p in predictions)
    severity = predictor.evaluate_severity(positive_count)

    label_names = [fn.split(".")[0] for fn in model_filenames]
    final_result_text = _format_grouped_predictions(
        predictions, label_names, predictor.prediction_map
    )

    total_count_md = (
        "### Total Patient Count\n"
        f"**{len(df)}** total patients in the dataset."
    )
    fig_input_bar = _input_value_bar(df, user_input_data)
    fig_label_bar = _predicted_label_bar(df, predictions, label_names)
    fig_dist = _distribution_plot(df, list(user_input_data), label_names)
    nn_md_str = _nearest_neighbors_markdown(df, user_input, k=2)
    fig_co_occ = _co_occurrence_plot(df, co_occ_feature1, co_occ_feature2, co_occ_label)

    return (
        final_result_text,
        severity,
        total_count_md,
        fig_dist,
        nn_md_str,
        fig_co_occ,
        fig_input_bar,
        fig_label_bar,
    )
|
|
|
|
|
|
|
|
|
input_mapping = { |
|
'YNURSMDE': {"Yes": 1, "No": 0}, |
|
'YMDEYR': {"Yes": 1, "No": 2}, |
|
'YSOCMDE': {"Yes": 1, "No": 0}, |
|
'YMDESUD5ANYO': {"SUD only": 1, "MDE only": 2, "SUD & MDE": 3, "Neither": 4}, |
|
'YMSUD5YANY': {"Yes": 1, "No": 0}, |
|
'YUSUITHK': {"Yes": 1, "No": 2, "Unsure": 3, "Don't want to answer": 4}, |
|
'YMDETXRX': {"Yes": 1, "No": 0}, |
|
'YUSUITHKYR': {"Yes": 1, "No": 2, "Unsure": 3, "Don't want to answer": 4}, |
|
'YMDERSUD5ANY': {"Yes": 1, "No": 0}, |
|
'YUSUIPLNYR': {"Yes": 1, "No": 2, "Unsure": 3, "Don't want to answer": 4}, |
|
'YCOUNMDE': {"Yes": 1, "No": 0}, |
|
'YPSY1MDE': {"Yes": 1, "No": 0}, |
|
'YHLTMDE': {"Yes": 1, "No": 0}, |
|
'YDOCMDE': {"Yes": 1, "No": 0}, |
|
'YPSY2MDE': {"Yes": 1, "No": 0}, |
|
'YMDEHARX': {"Yes": 1, "No": 0}, |
|
'LVLDIFMEM2': {"No Difficulty": 1, "Some Difficulty": 2, "A lot or cannot do": 3}, |
|
'MDEIMPY': {"Yes": 1, "No": 2}, |
|
'YMDEHPO': {"Yes": 1, "No": 0}, |
|
'YMIMS5YANY': {"Yes": 1, "No": 0}, |
|
'YMDEIMAD5YR': {"Yes": 1, "No": 0}, |
|
'YMIUD5YANY': {"Yes": 1, "No": 0}, |
|
'YMDEHPRX': {"Yes": 1, "No": 0}, |
|
'YMIMI5YANY': {"Yes": 1, "No": 0}, |
|
'YUSUIPLN': {"Yes": 1, "No": 2, "Unsure": 3, "Don't want to answer": 4}, |
|
'YTXMDEYR': {"Yes": 1, "No": 0}, |
|
'YMDEAUD5YR': {"Yes": 1, "No": 0}, |
|
'YRXMDEYR': {"Yes": 1, "No": 0}, |
|
'YMDELT': {"Yes": 1, "No": 2} |
|
} |
|
|
|
|
|
|
|
|
|
import gradio as gr |
|
|
|
|
|
# (input_mapping key, dropdown label) for each required field. The order must
# line up with predict_with_text's parameters.
_DROPDOWN_SPECS = (
    ('YMDEYR', "YMDEYR: Past Year MDE?"),
    ('YMDERSUD5ANY', "YMDERSUD5ANY: MDE or SUD - ANY?"),
    ('YMDEIMAD5YR', "YMDEIMAD5YR: MDE + ALCOHOL?"),
    ('YMIMS5YANY', "YMIMS5YANY: MDE + SUBSTANCE?"),
    ('YMDELT', "YMDELT: MDE in Lifetime?"),
    ('YMDEHARX', "YMDEHARX: Saw Health Prof + Meds?"),
    ('YMDEHPRX', "YMDEHPRX: Saw Health Prof or Meds?"),
    ('YMDETXRX', "YMDETXRX: Received Treatment?"),
    ('YMDEHPO', "YMDEHPO: Saw Health Prof Only?"),
    ('YMDEAUD5YR', "YMDEAUD5YR: MDE + Alcohol Use?"),
    ('YMIMI5YANY', "YMIMI5YANY: MDE + ILL Drug Use?"),
    ('YMIUD5YANY', "YMIUD5YANY: MDE + ILL Drug Use?"),
    ('YMDESUD5ANYO', "YMDESUD5ANYO: MDE vs SUD vs BOTH vs NEITHER"),
    # Who the respondent consulted.
    ('YNURSMDE', "YNURSMDE: Nurse/OT about MDE?"),
    ('YSOCMDE', "YSOCMDE: Social Worker?"),
    ('YCOUNMDE', "YCOUNMDE: Counselor?"),
    ('YPSY1MDE', "YPSY1MDE: Psychologist?"),
    ('YPSY2MDE', "YPSY2MDE: Psychiatrist?"),
    ('YHLTMDE', "YHLTMDE: Health Prof?"),
    ('YDOCMDE', "YDOCMDE: GP/Family MD?"),
    ('YTXMDEYR', "YTXMDEYR: Doctor/Health Prof?"),
    # Suicidal thoughts and plans.
    ('YUSUITHKYR', "YUSUITHKYR: Serious Suicide Thoughts?"),
    ('YUSUIPLNYR', "YUSUIPLNYR: Made Plans?"),
    ('YUSUITHK', "YUSUITHK: Suicide Thoughts (12 mo)?"),
    ('YUSUIPLN', "YUSUIPLN: Made Plans (12 mo)?"),
    # Impairment and remaining items.
    ('MDEIMPY', "MDEIMPY: Severe Role Impairment?"),
    ('LVLDIFMEM2', "LVLDIFMEM2: Difficulty Remembering/Concentrating?"),
    ('YMSUD5YANY', "YMSUD5YANY: MDE + Substance?"),
    ('YRXMDEYR', "YRXMDEYR: Used Meds for MDE (12 mo)?"),
)

original_inputs = [
    gr.Dropdown(list(input_mapping[key].keys()), label=text)
    for key, text in _DROPDOWN_SPECS
]

# Optional co-occurrence selectors. "None" (the string) means "skip the plot".
all_cols = ["None"] + df.columns.tolist()
co_occ_feature1 = gr.Dropdown(all_cols, label="Co-Occ Feature 1", value="None")
co_occ_feature2 = gr.Dropdown(all_cols, label="Co-Occ Feature 2", value="None")
all_label_cols = ["None"] + list(predictor.prediction_map.keys())
co_occ_label = gr.Dropdown(all_label_cols, label="Co-Occ Label", value="None")

inputs = [*original_inputs, co_occ_feature1, co_occ_feature2, co_occ_label]

# One component per element of the tuple returned by predict / predict_with_text.
outputs = [
    gr.Textbox(label="Prediction Results", lines=15),
    gr.Textbox(label="Mental Health Severity", lines=2),
    gr.Markdown(label="Total Patient Count"),
    gr.Plot(label="Distribution Plot (Sample)"),
    gr.Markdown(label="Nearest Neighbors (K=2)"),
    gr.Plot(label="Co-occurrence Plot"),
    gr.Plot(label="Same Value Bar (Inputs)"),
    gr.Plot(label="Predicted Label Bar"),
]
|
|
|
|
|
|
|
|
|
# Names of the dropdown-driven parameters of predict_with_text, in signature
# order. Each is also a key of input_mapping and a keyword of predict.
_TEXT_FIELDS = (
    "YMDEYR", "YMDERSUD5ANY", "YMDEIMAD5YR", "YMIMS5YANY", "YMDELT", "YMDEHARX",
    "YMDEHPRX", "YMDETXRX", "YMDEHPO", "YMDEAUD5YR", "YMIMI5YANY", "YMIUD5YANY",
    "YMDESUD5ANYO", "YNURSMDE", "YSOCMDE", "YCOUNMDE", "YPSY1MDE", "YPSY2MDE",
    "YHLTMDE", "YDOCMDE", "YTXMDEYR", "YUSUITHKYR", "YUSUIPLNYR", "YUSUITHK",
    "YUSUIPLN", "MDEIMPY", "LVLDIFMEM2", "YMSUD5YANY", "YRXMDEYR",
)


def _validation_failure(message):
    """Return the 8 output values shown when the form cannot be scored."""
    return (message, "Validation Error", "No data", None, "No data", None, None, None)


def predict_with_text(
    YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
    YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
    YMDESUD5ANYO, YNURSMDE, YSOCMDE, YCOUNMDE, YPSY1MDE, YPSY2MDE,
    YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
    YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR,
    co_occ_feature1, co_occ_feature2, co_occ_label
):
    """UI entry point: validate the dropdown labels, encode them, call predict.

    Each feature argument is a dropdown label (a key of
    ``input_mapping[<field>]``). The co-occurrence arguments are passed
    through unchanged.

    Returns:
        predict's 8-tuple. If a field is empty or holds an unrecognised choice,
        returns a validation-error tuple of the same length instead of raising.
    """
    answers = dict(zip(_TEXT_FIELDS, (
        YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
        YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
        YMDESUD5ANYO, YNURSMDE, YSOCMDE, YCOUNMDE, YPSY1MDE, YPSY2MDE,
        YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
        YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR,
    )))

    if not validate_inputs(*answers.values()):
        return _validation_failure("Please select all required fields.")

    # Without this check, a value outside the dropdown choices would raise a
    # KeyError below.
    unknown = [field for field, answer in answers.items() if answer not in input_mapping[field]]
    if unknown:
        return _validation_failure("Unrecognised selection for: " + ", ".join(unknown))

    encoded = {field: input_mapping[field][answer] for field, answer in answers.items()}
    return predict(
        **encoded,
        co_occ_feature1=co_occ_feature1,
        co_occ_feature2=co_occ_feature2,
        co_occ_label=co_occ_label,
    )
|
|
|
|
|
# Forces one dark text colour on every element inside the Gradio container.
custom_css = """

.gradio-container * {

color: #1B1212 !important;

}

"""

_DESCRIPTION = """

**Instructions**:

1. Fill out all required fields regarding MDE/Substance Use/Consultations/Suicidal/Impairments.

2. (Optional) Choose 2 features and 1 label for the *Co-occurrence* plot.

- If you do not select them (or leave them as "None"), that plot will be skipped.

3. Click "Submit" to get predictions, severity, distribution plots, nearest neighbors, and your custom co-occurrence chart.

"""

interface = gr.Interface(
    fn=predict_with_text,
    inputs=inputs,
    outputs=outputs,
    title="Mental Health Screening (NSDUH) with Selective Co-Occurrence",
    description=_DESCRIPTION,
    css=custom_css,
)

if __name__ == "__main__":
    interface.launch()
|
|