|
import os
import pickle

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
|
|
|
|
|
|
|
|
|
# Training table (input feature columns merged with the target label columns).
# The UI reads it for patient counts, comparison plots and the co-occurrence tab.
df = pd.read_csv(filepath_or_buffer="X_train_Y_Train_merged_train.csv")
|
|
|
|
|
# One pickled classifier per target column; the file stem is the target name.
model_filenames = [
    f"{target}.pkl"
    for target in (
        "YOWRCONC", "YOSEEDOC", "YO_MDEA5", "YOWRLSIN",
        "YODPPROB", "YOWRPROB", "YODPR2WK", "YOWRDEPR",
        "YODPDISC", "YOLOSEV", "YOWRDCSN", "YODSMMDE",
        "YO_MDEA3", "YODPLSIN", "YOWRELES", "YOPB2WK",
    )
]

# Directory (with trailing slash) holding the files listed above.
model_path = "models/"
|
|
|
|
|
|
|
|
|
|
|
class ModelPredictor:
    """Load a set of pickled classifiers and describe their binary predictions.

    Callers in this module derive each model's target column from its file
    stem (``"YOWRCONC.pkl"`` -> ``YOWRCONC``) and look that name up in
    ``prediction_map`` to get human-readable text for class 0 / class 1.
    """

    def __init__(self, model_path, model_filenames):
        """Store the model location and eagerly load every model.

        Args:
            model_path: Directory containing the pickled model files (a
                trailing separator is optional).
            model_filenames: File names, relative to ``model_path``, in the
                order predictions should be returned.
        """
        self.model_path = model_path
        self.model_filenames = model_filenames
        self.models = self.load_models()

        # Target column -> [text for predicted class 0, text for class 1].
        # Entries without a loaded model are simply never looked up.
        self.prediction_map = {
            "YOWRCONC": ["Did not have difficulty concentrating", "Had difficulty concentrating"],
            "YOSEEDOC": ["Did not feel the need to see a doctor", "Felt the need to see a doctor"],
            "YOWRHRS": ["Did not have trouble sleeping", "Had trouble sleeping"],
            "YO_MDEA5": ["Others did not notice restlessness/lethargy", "Others noticed restlessness/lethargy"],
            "YOWRCHR": ["Did not feel so sad", "Felt so sad nothing could cheer up"],
            "YOWRLSIN": ["Did not feel bored and lose interest", "Felt bored and lost interest"],
            "YODPPROB": ["No other problems for 2+ weeks", "Had other problems for 2+ weeks"],
            "YOWRPROB": ["Did not have the worst time ever feeling", "Had the worst time ever feeling"],
            "YODPR2WK": ["No periods of 2+ weeks feelings", "Had periods of 2+ weeks feelings"],
            "YOWRDEPR": ["Did not feel depressed mostly everyday", "Felt depressed mostly everyday"],
            "YODPDISC": ["Mood not depressed overall", "Mood depressed overall (discrepancy)"],
            "YOLOSEV": ["Did not lose interest in enjoyable things", "Lost interest in enjoyable things"],
            "YOWRDCSN": ["Was able to make decisions", "Was unable to make decisions"],
            "YODSMMDE": ["Never had depression for 2+ weeks", "Had depression for 2+ weeks"],
            "YO_MDEA3": ["No appetite/weight changes", "Had appetite/weight changes"],
            "YODPLSIN": ["Never bored/lost interest", "Felt bored/lost interest"],
            "YOWRELES": ["Did not eat less than usual", "Ate less than usual"],
            "YODSCEV": ["Fewer severe symptoms", "More severe symptoms"],
            "YOPB2WK": ["No uneasy feelings 2+ weeks", "Had uneasy feelings 2+ weeks"],
            "YO_MDEA2": ["No issues w/ physical/mental well-being", "Issues w/ physical/mental well-being"],
        }

    def load_models(self):
        """Unpickle every file in ``self.model_filenames`` from ``self.model_path``.

        Returns:
            List of loaded model objects, in the same order as
            ``self.model_filenames``.

        Raises:
            OSError: if a model file cannot be opened.
        """
        models = []
        # Use the instance's own file list (not a module global) so a predictor
        # built with a custom list loads exactly those files.
        for filename in self.model_filenames:
            filepath = os.path.join(self.model_path, filename)
            # NOTE: pickle.load can execute arbitrary code; only point
            # model_path at trusted model files.
            with open(filepath, 'rb') as file:
                models.append(pickle.load(file))
        return models

    def make_predictions(self, user_input):
        """Run every loaded model on ``user_input``.

        Returns:
            List of 1-D numpy arrays. The i-th array holds the i-th model's
            predictions (one element per input row, so ``[0]`` or ``[1]`` for
            a single-row input).
        """
        # np.asarray lets models that return plain lists work too.
        return [np.asarray(model.predict(user_input)).ravel() for model in self.models]

    def get_majority_vote(self, predictions):
        """Return the most frequent class label across all prediction arrays.

        Ties resolve to the smaller label (``np.bincount(...).argmax()``).

        Raises:
            ValueError: if ``predictions`` is empty or contains negative labels.
        """
        if len(predictions) == 0:
            raise ValueError("get_majority_vote() needs at least one prediction array")
        combined = np.concatenate([np.asarray(p).ravel() for p in predictions])
        # bincount only accepts integer input; models may emit 0.0/1.0 floats.
        return np.bincount(combined.astype(int)).argmax()

    def evaluate_severity(self, majority_vote_count):
        """Map the number of positive (class 1) predictions to a severity label.

        >= 13 -> Severe, 9-12 -> Moderate, 5-8 -> Low, otherwise Very Low.
        """
        if majority_vote_count >= 13:
            return "Mental Health Severity: Severe"
        elif majority_vote_count >= 9:
            return "Mental Health Severity: Moderate"
        elif majority_vote_count >= 5:
            return "Mental Health Severity: Low"
        else:
            return "Mental Health Severity: Very Low"
|
|
|
|
|
|
|
|
|
|
|
def validate_inputs(*args):
    """Report whether every supplied field has a value.

    A field counts as missing when it is ``None`` or the empty string.
    Returns ``True`` for no arguments at all.
    """
    return not any(arg == '' or arg is None for arg in args)
|
|
|
|
|
|
|
|
|
|
|
# Models are loaded once, at import time, and shared by every predict() call.
predictor = ModelPredictor(model_path=model_path, model_filenames=model_filenames)
|
|
|
def _encode_choice(column, value):
    """Translate one dropdown selection into the integer code the models expect.

    The dropdowns offer the keys of ``input_mapping[column]`` (e.g. "Yes"/"No"),
    so a selected label is looked up there; anything else falls back to
    ``int(value)`` so already-numeric inputs keep working.

    Raises:
        ValueError, TypeError: if ``value`` is neither a known label nor an int.
    """
    choices = input_mapping.get(column, {})
    if value in choices:
        return choices[value]
    return int(value)


def _label_name(filename):
    """Return the target column a model file predicts ("YOWRCONC.pkl" -> "YOWRCONC")."""
    return filename.split('.')[0]


def _error_outputs(message):
    """Return the 8-tuple ``predict`` emits when the form cannot be processed."""
    return (message, "Validation Error", "No data", None, "No data", None, None, None)


def _format_grouped_predictions(predictions, filenames, prediction_map):
    """Render each model's prediction as text, grouped by symptom area.

    Each group becomes a bold heading followed by one
    ``"<target> => <description>"`` line per model; groups are separated by a
    blank line.
    """
    groups = {
        "Concentration_and_Decision_Making": ["YOWRCONC", "YOWRDCSN"],
        "Sleep_and_Energy_Levels": ["YOWRHRS", "YO_MDEA5", "YOWRELES", "YO_MDEA2"],
        "Mood_and_Emotional_State": ["YOWRCHR", "YOWRLSIN", "YOWRDEPR", "YODPDISC",
                                     "YOLOSEV", "YODPLSIN", "YODSCEV"],
        "Appetite_and_Weight_Changes": ["YO_MDEA3", "YOWRELES"],
        "Duration_and_Severity_of_Depression_Symptoms": ["YODPPROB", "YOWRPROB",
                                                         "YODPR2WK", "YODSMMDE",
                                                         "YOPB2WK"],
    }
    grouped_text = {name: [] for name in groups}
    # Targets not listed in any group (e.g. YOSEEDOC) go here instead of
    # being silently dropped.
    grouped_text["Other"] = []

    for filename, arr in zip(filenames, predictions):
        col_name = _label_name(filename)
        pred_val = arr[0] if len(arr) else None
        if col_name in prediction_map and pred_val in (0, 1):
            text_val = prediction_map[col_name][int(pred_val)]
        else:
            text_val = f"Prediction={pred_val}"
        # NOTE(review): YOWRELES is listed in two groups; the first match
        # (Sleep_and_Energy_Levels) wins — confirm that is intended.
        group = next((name for name, cols in groups.items() if col_name in cols), "Other")
        grouped_text[group].append(f"{col_name} => {text_val}")

    sections = [
        f"**{name.replace('_', ' ')}**\n" + "\n".join(items)
        for name, items in grouped_text.items()
        if items
    ]
    return "\n\n".join(sections) or "No predictions made. Please check inputs."


def _input_match_figure(user_input_data):
    """Bar chart: for each input feature, how many dataset rows share the user's value."""
    same_val_counts = {
        col: int((df[col] == vals[0]).sum())
        for col, vals in user_input_data.items()
        if col in df.columns  # skip features absent from the dataset instead of KeyError
    }
    bar_input_df = pd.DataFrame({"Feature": list(same_val_counts),
                                 "Count": list(same_val_counts.values())})
    fig = px.bar(bar_input_df, x="Feature", y="Count",
                 title="Number of Patients with Same Input Feature Values")
    fig.update_layout(width=800, height=500)
    return fig


def _label_match_figure(predictions, filenames):
    """Bar chart: for each target, how many dataset rows share the predicted class."""
    label_counts = {}
    for filename, arr in zip(filenames, predictions):
        lbl_col = _label_name(filename)
        pred_val = arr[0] if len(arr) else None
        if pred_val in (0, 1) and lbl_col in df.columns:
            label_counts[lbl_col] = int((df[lbl_col] == pred_val).sum())

    if label_counts:
        bar_label_df = pd.DataFrame({"Label": list(label_counts),
                                     "Count": list(label_counts.values())})
        fig = px.bar(bar_label_df, x="Label", y="Count",
                     title="Number of Patients with the Same Predicted Label")
    else:
        fig = px.bar(title="No valid predicted labels to display.")
    fig.update_layout(width=800, height=500)
    return fig


def _distribution_figure(user_input_data, filenames, n_features=4, n_labels=3):
    """Faceted bar chart of row counts for a sample of (input feature, target) pairs.

    Uses the first ``n_features`` input columns and the first ``n_labels``
    model targets that exist in the dataset.
    """
    features = [c for c in list(user_input_data)[:n_features] if c in df.columns]
    labels = [c for c in map(_label_name, filenames[:n_labels]) if c in df.columns]

    frames = []
    for feat in features:
        for label_col in labels:
            counts = df.groupby([feat, label_col]).size().reset_index(name="count")
            # Give every pair the same column names so the concatenated frame
            # has one x column and one colour column (otherwise each pair's
            # values live in a different column and most facets are NaN).
            counts = counts.rename(columns={feat: "feature_value", label_col: "label_value"})
            counts["feature"] = feat
            counts["label"] = label_col
            frames.append(counts)

    if not frames:
        return px.bar(title="Distribution plot not generated.")

    big_dist_df = pd.concat(frames, ignore_index=True)
    # Treat the class label as categorical so it gets a discrete legend.
    big_dist_df["label_value"] = big_dist_df["label_value"].astype(str)
    fig = px.bar(
        big_dist_df,
        x="feature_value",
        y="count",
        color="label_value",
        facet_row="feature",
        facet_col="label",
        title="Distribution of Sample Input Features vs. Sample Predicted Labels",
    )
    fig.update_layout(width=1000, height=700)
    return fig


def predict(
    YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
    YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
    YMDESUD5ANYO, YNURSMDE, YSOCMDE, YCOUNMDE, YPSY1MDE, YPSY2MDE,
    YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
    YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
):
    """Run every model on one respondent's answers and build all UI outputs.

    Each argument is the dropdown selection for the feature column of the same
    name, i.e. one of the labels listed in ``input_mapping`` (e.g. "Yes"/"No").

    Returns:
        8-tuple ``(prediction_text, severity, patient_count_markdown,
        distribution_fig, nearest_neighbors_markdown, co_occurrence_placeholder,
        input_counts_fig, label_counts_fig)``. If a field is missing or holds an
        unrecognised value, a tuple with an error message and ``None`` figures
        is returned instead.
    """
    # Keep this exact key order: it becomes the DataFrame column order, and
    # models fitted on a DataFrame may check feature-name order.
    raw_inputs = {
        'YNURSMDE': YNURSMDE,
        'YMDEYR': YMDEYR,
        'YSOCMDE': YSOCMDE,
        'YMDESUD5ANYO': YMDESUD5ANYO,
        'YMSUD5YANY': YMSUD5YANY,
        'YUSUITHK': YUSUITHK,
        'YMDETXRX': YMDETXRX,
        'YUSUITHKYR': YUSUITHKYR,
        'YMDERSUD5ANY': YMDERSUD5ANY,
        'YUSUIPLNYR': YUSUIPLNYR,
        'YCOUNMDE': YCOUNMDE,
        'YPSY1MDE': YPSY1MDE,
        'YHLTMDE': YHLTMDE,
        'YDOCMDE': YDOCMDE,
        'YPSY2MDE': YPSY2MDE,
        'YMDEHARX': YMDEHARX,
        'LVLDIFMEM2': LVLDIFMEM2,
        'MDEIMPY': MDEIMPY,
        'YMDEHPO': YMDEHPO,
        'YMIMS5YANY': YMIMS5YANY,
        'YMDEIMAD5YR': YMDEIMAD5YR,
        'YMIUD5YANY': YMIUD5YANY,
        'YMDEHPRX': YMDEHPRX,
        'YMIMI5YANY': YMIMI5YANY,
        'YUSUIPLN': YUSUIPLN,
        'YTXMDEYR': YTXMDEYR,
        'YMDEAUD5YR': YMDEAUD5YR,
        'YRXMDEYR': YRXMDEYR,
        'YMDELT': YMDELT,
    }

    if not validate_inputs(*raw_inputs.values()):
        return _error_outputs("Please select all required fields.")

    # Dropdowns deliver labels such as "Yes"; int("Yes") raises, so translate
    # each selection through input_mapping before building the model input.
    try:
        user_input_data = {col: [_encode_choice(col, val)] for col, val in raw_inputs.items()}
    except (TypeError, ValueError):
        return _error_outputs("Unrecognized value selected. Please re-select the fields.")
    user_input = pd.DataFrame(user_input_data)

    predictions = predictor.make_predictions(user_input)
    filenames = predictor.model_filenames

    # Severity is driven by how many models predicted class 1.
    num_ones = sum(int(np.count_nonzero(np.asarray(arr) == 1)) for arr in predictions)
    severity = predictor.evaluate_severity(num_ones)

    final_str = _format_grouped_predictions(predictions, filenames, predictor.prediction_map)

    total_patient_markdown = (
        f"### Total Patient Count\nThere are **{len(df)}** patients in the dataset."
    )

    return (
        final_str,
        severity,
        total_patient_markdown,
        _distribution_figure(user_input_data, filenames),
        "Nearest neighbors omitted or placed here if needed...",
        None,  # co-occurrence output on this tab is a placeholder
        _input_match_figure(user_input_data),
        _label_match_figure(predictions, filenames),
    )
|
|
|
|
|
|
|
|
|
|
|
# Reusable answer scales as (dropdown label, integer code) pairs, in display order.
_YES1_NO0 = (("Yes", 1), ("No", 0))
_YES1_NO2 = (("Yes", 1), ("No", 2))
_YES_NO_UNSURE_DECLINE = (
    ("Yes", 1),
    ("No", 2),
    ("I'm not sure", 3),
    ("I don't want to answer", 4),
)

# Feature column -> {dropdown label: integer code}. Every field gets its own
# dict (built with dict(...)) so no two fields share a mutable mapping.
input_mapping = {
    'YNURSMDE': dict(_YES1_NO0),
    'YMDEYR': dict(_YES1_NO2),
    'YSOCMDE': dict(_YES1_NO0),
    'YMDESUD5ANYO': {
        "SUD only, no MDE": 1,
        "MDE only, no SUD": 2,
        "SUD and MDE": 3,
        "Neither SUD or MDE": 4,
    },
    'YMSUD5YANY': dict(_YES1_NO0),
    'YUSUITHK': dict(_YES_NO_UNSURE_DECLINE),
    'YMDETXRX': dict(_YES1_NO0),
    'YUSUITHKYR': dict(_YES_NO_UNSURE_DECLINE),
    'YMDERSUD5ANY': dict(_YES1_NO0),
    'YUSUIPLNYR': dict(_YES_NO_UNSURE_DECLINE),
    'YCOUNMDE': dict(_YES1_NO0),
    'YPSY1MDE': dict(_YES1_NO0),
    'YHLTMDE': dict(_YES1_NO0),
    'YDOCMDE': dict(_YES1_NO0),
    'YPSY2MDE': dict(_YES1_NO0),
    'YMDEHARX': dict(_YES1_NO0),
    'LVLDIFMEM2': {
        "No Difficulty": 1,
        "Some difficulty": 2,
        "A lot of difficulty or cannot do at all": 3,
    },
    'MDEIMPY': dict(_YES1_NO2),
    'YMDEHPO': dict(_YES1_NO0),
    'YMIMS5YANY': dict(_YES1_NO0),
    'YMDEIMAD5YR': dict(_YES1_NO0),
    'YMIUD5YANY': dict(_YES1_NO0),
    'YMDEHPRX': dict(_YES1_NO0),
    'YMIMI5YANY': dict(_YES1_NO0),
    'YUSUIPLN': dict(_YES_NO_UNSURE_DECLINE),
    'YTXMDEYR': dict(_YES1_NO0),
    'YMDEAUD5YR': dict(_YES1_NO0),
    'YRXMDEYR': dict(_YES1_NO0),
    'YMDELT': dict(_YES1_NO2),
}
|
|
|
|
|
|
|
|
|
|
|
def co_occurrence_plot(feature1, feature2, label_col):
    """Bar chart of dataset row counts for every (feature1, feature2, label_col) combination.

    Bars are positioned by ``feature1`` value, coloured by ``label_col`` and
    split into one facet column per ``feature2`` value.

    Args:
        feature1: Column name for the x axis.
        feature2: Column name used for facet columns.
        label_col: Column name used for bar colour.

    Returns:
        A plotly figure. If a selection is missing, names an unknown column, or
        repeats a column, an empty figure whose title explains the problem.
    """
    if not feature1 or not feature2 or not label_col:
        return px.bar(title="Please select all three fields.")

    selected = [feature1, feature2, label_col]
    if any(col not in df.columns for col in selected):
        return px.bar(title="Selected columns not found in the dataset.")
    if len(set(selected)) < len(selected):
        # Grouping by the same column twice yields duplicate index level
        # names, which reset_index() rejects with a ValueError.
        return px.bar(title="Please select three different columns.")

    grouped_df = df.groupby(selected).size().reset_index(name="count")
    fig = px.bar(
        grouped_df,
        x=feature1,
        y="count",
        color=label_col,
        facet_col=feature2,
        title=f"Co-Occurrence Plot: {feature1} & {feature2} vs. {label_col}",
    )
    fig.update_layout(width=1000, height=600)
    return fig
|
|
|
|
|
|
|
|
|
|
|
# Feature fields in the positional order predict() accepts them; this is also
# the order the dropdowns are rendered on the Prediction tab.
_PREDICT_FIELDS = (
    "YMDEYR", "YMDERSUD5ANY", "YMDEIMAD5YR", "YMIMS5YANY", "YMDELT", "YMDEHARX",
    "YMDEHPRX", "YMDETXRX", "YMDEHPO", "YMDEAUD5YR", "YMIMI5YANY", "YMIUD5YANY",
    "YMDESUD5ANYO", "YNURSMDE", "YSOCMDE", "YCOUNMDE", "YPSY1MDE", "YPSY2MDE",
    "YHLTMDE", "YDOCMDE", "YTXMDEYR", "YUSUITHKYR", "YUSUIPLNYR", "YUSUITHK",
    "YUSUIPLN", "MDEIMPY", "LVLDIFMEM2", "YMSUD5YANY", "YRXMDEYR",
)

with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:

    with gr.Tab("Prediction"):
        # One dropdown per feature; choices are that field's labels in input_mapping.
        prediction_inputs = [
            gr.Dropdown(list(input_mapping[field].keys()), label=field)
            for field in _PREDICT_FIELDS
        ]

        predict_btn = gr.Button("Predict")

        out_pred_res = gr.Textbox(label="Prediction Results", lines=8)
        out_sev = gr.Textbox(label="Mental Health Severity", lines=2)
        out_count = gr.Markdown(label="Total Patient Count")
        out_distplot = gr.Plot(label="Distribution Plot")
        out_nn = gr.Markdown(label="Nearest Neighbors Summary")
        out_cooc = gr.Plot(label="Co-occurrence Plot Placeholder")
        out_bar_input = gr.Plot(label="Input Feature Counts")
        out_bar_labels = gr.Plot(label="Predicted Label Counts")

        # Output order matches the 8-tuple predict() returns.
        prediction_outputs = [
            out_pred_res, out_sev, out_count, out_distplot,
            out_nn, out_cooc, out_bar_input, out_bar_labels,
        ]
        predict_btn.click(fn=predict, inputs=prediction_inputs, outputs=prediction_outputs)

    with gr.Tab("Co-occurrence"):
        gr.Markdown("## Generate a Co-Occurrence Plot on Demand\nSelect two features and one label:")
        with gr.Row():
            feature1_dd = gr.Dropdown(sorted(df.columns), label="Feature 1")
            feature2_dd = gr.Dropdown(sorted(df.columns), label="Feature 2")
            label_dd = gr.Dropdown(sorted(df.columns), label="Label Column")
        out_co_occ_plot = gr.Plot(label="Co-occurrence Plot")

        co_occ_btn = gr.Button("Generate Plot")
        co_occ_btn.click(
            fn=co_occurrence_plot,
            inputs=[feature1_dd, feature2_dd, label_dd],
            outputs=out_co_occ_plot,
        )

demo.launch(server_name="0.0.0.0", server_port=7860)
|
|