pantdipendra's picture
v2_separate tabs categories in UI
e9e83fc verified
raw
history blame
22.4 kB
import os
import pickle

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
######################################
# 1) LOAD DATA & MODELS
######################################
# Combined train/test dataset; used for the patient count, the nearest-neighbor
# lookup and the analysis plots further down.
df = pd.read_csv("X_train_test_combined_dataset_Filtered_dataset.csv")

# Target labels that have a trained model. Each model is stored as
# "<label>.pkl"; this order is the order predictions come back in.
_MODEL_LABELS = (
    "YOWRCONC", "YOSEEDOC", "YO_MDEA5", "YOWRLSIN",
    "YODPPROB", "YOWRPROB", "YODPR2WK", "YOWRDEPR",
    "YODPDISC", "YOLOSEV", "YOWRDCSN", "YODSMMDE",
    "YO_MDEA3", "YODPLSIN", "YOWRELES", "YOPB2WK",
)
model_filenames = [label + ".pkl" for label in _MODEL_LABELS]

# Directory holding the pickled model files.
model_path = "models/"
######################################
# 2) MODEL PREDICTOR
######################################
class ModelPredictor:
    """Load a set of pickled classifiers and run them over one input frame.

    Each model file is named ``<LABEL>.pkl``. ``prediction_map`` translates a
    model's 0/1 output for ``LABEL`` into a human-readable sentence: index 0 is
    the text for prediction 0, index 1 the text for prediction 1.
    """

    def __init__(self, model_path, model_filenames):
        """Load every model immediately.

        Args:
            model_path: Directory that contains the model files. A trailing
                path separator is optional.
            model_filenames: File names, relative to ``model_path``, in the
                order their predictions are returned by ``make_predictions``.
        """
        self.model_path = model_path
        self.model_filenames = model_filenames
        self.models = self.load_models()
        # label -> [text for prediction 0, text for prediction 1]
        self.prediction_map = {
            "YOWRCONC": ["Did not have difficulty concentrating", "Had difficulty concentrating"],
            "YOSEEDOC": ["Did not feel the need to see a doctor", "Felt the need to see a doctor"],
            "YOWRHRS": ["Did not have trouble sleeping", "Had trouble sleeping"],
            "YO_MDEA5": ["Others did not notice restlessness/lethargy", "Others noticed restlessness/lethargy"],
            "YOWRCHR": ["Did not feel so sad nothing could cheer up", "Felt so sad that nothing could cheer up"],
            "YOWRLSIN": ["Did not feel bored / lose interest", "Felt bored / lost interest"],
            "YODPPROB": ["No other problems for 2+ weeks", "Had other problems for 2+ weeks"],
            "YOWRPROB": ["No 'worst time ever' feeling", "Had 'worst time ever' feeling"],
            "YODPR2WK": ["No periods with depressed feelings lasting 2+ weeks", "Had depressed feelings 2+ weeks"],
            "YOWRDEPR": ["Did not feel sad/depressed mostly everyday", "Felt sad/depressed mostly everyday"],
            "YODPDISC": ["Overall mood not sad/depressed", "Overall mood was sad/depressed"],
            "YOLOSEV": ["Did not lose interest", "Lost interest in enjoyable things"],
            "YOWRDCSN": ["Was able to make decisions", "Was unable to make decisions"],
            "YODSMMDE": ["Never had 2 weeks depression symptoms", "Had 2+ weeks of depression symptoms"],
            "YO_MDEA3": ["No changes in appetite/weight", "Had changes in appetite/weight"],
            "YODPLSIN": ["Never lost interest / felt bored", "Lost interest/felt bored"],
            "YOWRELES": ["Did not eat less than usual", "Ate less than usual"],
            "YODSCEV": ["Fewer severe depression symptoms", "More severe depression symptoms"],
            "YOPB2WK": ["No uneasy feelings lasting 2+ weeks", "Uneasy feelings lasting 2+ weeks"],
            "YO_MDEA2": ["No physical/mental issues (2+ weeks)", "Had physical/mental issues (2+ weeks)"]
        }

    def load_models(self):
        """Unpickle every file in ``model_filenames`` and return them in order.

        SECURITY: ``pickle.load`` can execute arbitrary code. Only load model
        files that come from a trusted source.
        """
        loaded = []
        for fname in self.model_filenames:
            # os.path.join works whether or not model_path ends with a separator.
            with open(os.path.join(self.model_path, fname), "rb") as f:
                loaded.append(pickle.load(f))
        return loaded

    def make_predictions(self, user_input: pd.DataFrame):
        """Run every loaded model's ``predict`` on ``user_input``.

        Returns:
            A list with one flattened prediction array per model, in the same
            order as ``model_filenames``.
        """
        return [model.predict(user_input).flatten() for model in self.models]

    def get_majority_vote(self, predictions):
        """Return the most frequent class value across all prediction arrays.

        Ties resolve to the smallest class value, because ``argmax`` returns
        the first maximum. Values are cast to integers first, because
        ``np.bincount`` rejects float arrays such as ``[1.0, 0.0]``.
        """
        combined = np.concatenate(predictions).astype(np.int64)
        return np.bincount(combined).argmax()

    def evaluate_severity(self, count_ones: int) -> str:
        """Map the number of models that predicted 1 to a severity message.

        The thresholds (13 / 9 / 5) appear tuned for the 16 models configured
        in this app.
        """
        if count_ones >= 13:
            return "Mental Health Severity: Severe"
        elif count_ones >= 9:
            return "Mental Health Severity: Moderate"
        elif count_ones >= 5:
            return "Mental Health Severity: Low"
        else:
            return "Mental Health Severity: Very Low"
# Module-wide predictor used by predict() and by the label dropdowns in the UI.
predictor = ModelPredictor(model_path=model_path, model_filenames=model_filenames)
######################################
# 3) FEATURE CATEGORIES + MAPPING
######################################
# Input features grouped under the headings used by the nearest-neighbor
# summary; order inside each list is the order they are listed there.
categories_dict = {
    "1. Depression & Substance Use Diagnosis": [
        "YMDESUD5ANYO", "YMDELT", "YMDEYR", "YMDERSUD5ANY",
        "YMSUD5YANY", "YMIUD5YANY", "YMIMS5YANY", "YMIMI5YANY",
    ],
    "2. Mental Health Treatment & Prof Consultation": [
        "YMDEHPO", "YMDETXRX", "YMDEHARX", "YRXMDEYR", "YHLTMDE",
        "YTXMDEYR", "YDOCMDE", "YPSY2MDE", "YPSY1MDE", "YCOUNMDE",
    ],
    "3. Functional & Cognitive Impairment": [
        "MDEIMPY", "LVLDIFMEM2",
    ],
    "4. Suicidal Thoughts & Behaviors": [
        "YUSUITHK", "YUSUITHKYR", "YUSUIPLNYR", "YUSUIPLN",
    ],
}

# Answer encodings (answer text -> numeric code) shared by several questions.
# Every feature below receives its own copy via dict(...).
_YES1_NO2 = {"Yes": 1, "No": 2}
_YES1_NO0 = {"Yes": 1, "No": 0}
_YES_NO_UNSURE_DECLINE = {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4}

# Numeric encoding for each of the 24 input features. The inner keys are
# the dropdown choices shown to the user, in this insertion order.
input_mapping = {
    "YMDESUD5ANYO": {"SUD only, no MDE": 1, "MDE only, no SUD": 2, "SUD and MDE": 3, "Neither SUD or MDE": 4},
    "YMDELT": dict(_YES1_NO2),
    "YMDEYR": dict(_YES1_NO2),
    **{
        feature: dict(_YES1_NO0)
        for feature in (
            "YMDERSUD5ANY", "YMSUD5YANY", "YMIUD5YANY", "YMIMS5YANY", "YMIMI5YANY",
            "YMDEHPO", "YMDETXRX", "YMDEHARX", "YRXMDEYR", "YHLTMDE",
            "YTXMDEYR", "YDOCMDE", "YPSY2MDE", "YPSY1MDE", "YCOUNMDE",
        )
    },
    "MDEIMPY": dict(_YES1_NO2),
    "LVLDIFMEM2": {"No Difficulty": 1, "Some difficulty": 2, "A lot of difficulty or cannot do at all": 3},
    **{
        feature: dict(_YES_NO_UNSURE_DECLINE)
        for feature in ("YUSUITHK", "YUSUITHKYR", "YUSUIPLNYR", "YUSUIPLN")
    },
}
def validate_inputs(*args):
    """Return True only when every supplied value is truthy.

    predict() uses this to reject submissions where a dropdown was left
    unset (``None``) or empty (``""``). With no arguments the result is True.
    """
    return all(args)
######################################
# 4) NEAREST NEIGHBORS (Grouped)
######################################
def get_nearest_neighbors_info(user_input_df: pd.DataFrame, k=5, data=None, categories=None):
    """Summarize the ``k`` dataset rows closest to the user's encoded answers.

    Distance is Euclidean over the numeric codes of the columns present in
    ``user_input_df``. Only its first row is used as the query. Rows with a
    missing value in any of those columns are skipped rather than treated as
    a perfect match on that column.

    Args:
        user_input_df: Frame of encoded input features (first row is used).
        k: Number of neighbors to report.
        data: Frame to search. Defaults to the module-level ``df``.
        categories: Mapping of heading -> feature names used to group the
            summary. Defaults to the module-level ``categories_dict``.

    Returns:
        A Markdown string: a header with the neighbor distance range, then,
        per category, how many neighbors had each value of each feature.
        When the search cannot run, a one-line explanation instead.
    """
    if data is None:
        data = df
    if categories is None:
        categories = categories_dict

    user_cols = list(user_input_df.columns)
    missing = [col for col in user_cols if col not in data.columns]
    if missing:
        # Name the offending columns so schema drift is easy to diagnose.
        return ("Cannot compute nearest neighbors. Some columns not found in df: "
                + ", ".join(map(str, missing)))
    if k < 1 or data.empty or user_input_df.empty:
        return "Cannot compute nearest neighbors. No data rows to compare or k < 1."

    diffs = data[user_cols] - user_input_df.iloc[0]
    # skipna=False: a missing value must not count as a zero difference.
    dists = ((diffs ** 2).sum(axis=1, skipna=False) ** 0.5).dropna()
    if dists.empty:
        return "Cannot compute nearest neighbors. No complete rows to compare against."

    nearest = dists.nsmallest(k)
    neighbors = data.loc[nearest.index]

    lines = [
        f"**Nearest Neighbors (k={k})**",
        f"Distances Range: {nearest.min():.2f} to {nearest.max():.2f}",
        "",
    ]
    for cat_name, cat_feats in categories.items():
        lines.append(f"### {cat_name}")
        for feat in cat_feats:
            if feat not in neighbors.columns:
                continue
            # e.g. "**YMDELT** => 3 had '1'; 2 had '2'" (most frequent value first)
            counts = neighbors[feat].value_counts()
            joined = "; ".join(f"{count_} had '{val_}'" for val_, count_ in counts.items())
            lines.append(f"**{feat}** => {joined}")
        lines.append("")  # blank line between categories
    return "\n".join(lines)
######################################
# 5) PREDICT FUNCTION
######################################
# Input features in the exact order of predict()'s parameters, which is also
# the order of the Gradio input components wired to it.
_PREDICT_FEATURES = (
    "YMDESUD5ANYO", "YMDELT", "YMDEYR", "YMDERSUD5ANY",
    "YMSUD5YANY", "YMIUD5YANY", "YMIMS5YANY", "YMIMI5YANY",
    "YMDEHPO", "YMDETXRX", "YMDEHARX", "YRXMDEYR", "YHLTMDE",
    "YTXMDEYR", "YDOCMDE", "YPSY2MDE", "YPSY1MDE", "YCOUNMDE",
    "MDEIMPY", "LVLDIFMEM2",
    "YUSUITHK", "YUSUITHKYR", "YUSUIPLNYR", "YUSUIPLN",
)

# Sections of the textual result. Each label is listed in exactly one group.
# Model labels not listed here go to "Other_Predictions" instead of being
# dropped.
_PREDICTION_GROUPS = {
    "Concentration_and_Decision_Making": ["YOWRCONC", "YOWRDCSN"],
    "Sleep_and_Energy_Levels": ["YOWRHRS", "YO_MDEA5", "YO_MDEA2"],
    "Mood_and_Emotional_State": [
        "YOWRCHR", "YOWRLSIN", "YOWRDEPR", "YODPDISC", "YOLOSEV", "YODPLSIN", "YODSCEV"
    ],
    "Appetite_and_Weight_Changes": ["YO_MDEA3", "YOWRELES"],
    "Duration_and_Severity_of_Depression_Symptoms": [
        "YODPPROB", "YOWRPROB", "YODPR2WK", "YODSMMDE", "YOPB2WK"
    ],
}
_FALLBACK_GROUP = "Other_Predictions"


def _label_name(filename):
    """Return the target label a model file predicts (its name before the first '.')."""
    return filename.split('.')[0]


def _prediction_text(label_col, val):
    """Translate one model output into its sentence from predictor.prediction_map."""
    options = predictor.prediction_map.get(label_col)
    if options is not None and val in range(len(options)):
        # Models may emit floats such as 1.0; list indexing needs a real int.
        return options[int(val)]
    return f"Prediction={val}"


def _grouped_prediction_text(predictions):
    """Render per-label predictions as bold group titles followed by their lines."""
    group_text = {g: [] for g in _PREDICTION_GROUPS}
    for fname, arr in zip(model_filenames, predictions):
        label_col = _label_name(fname)
        group_name = next(
            (g for g, cols_ in _PREDICTION_GROUPS.items() if label_col in cols_),
            _FALLBACK_GROUP,
        )
        group_text.setdefault(group_name, []).append(
            f"{label_col} => {_prediction_text(label_col, arr[0])}"
        )

    parts = []
    for gname, lines in group_text.items():
        if lines:
            parts += [f"**{gname.replace('_', ' ')}**", "\n".join(lines), ""]
    if not parts:
        return "No predictions made or no matching group columns."
    return "\n".join(parts)


def _input_feature_figure(user_input_dict):
    """Bar chart: how many dataset rows share each of the user's encoded answers."""
    # Skip columns absent from df (same guard as the label chart) instead of a KeyError.
    input_counts = {
        col: int((df[col] == val_).sum())
        for col, val_ in user_input_dict.items()
        if col in df.columns
    }
    bar_in_df = pd.DataFrame({"Feature": list(input_counts.keys()),
                              "Count": list(input_counts.values())})
    fig_in = px.bar(
        bar_in_df, x="Feature", y="Count",
        title="Number of Patients with the Same Input Feature Values"
    )
    fig_in.update_layout(width=1200, height=400)
    return fig_in


def _predicted_label_figure(predictions):
    """Bar chart: how many dataset rows carry the same value as each prediction."""
    label_counts = {}
    for fname, arr in zip(model_filenames, predictions):
        lbl = _label_name(fname)
        if lbl in df.columns:
            label_counts[lbl] = int((df[lbl] == arr[0]).sum())
    if label_counts:
        bar_lbl_df = pd.DataFrame({
            "Label": list(label_counts.keys()),
            "Count": list(label_counts.values())
        })
        fig_lbl = px.bar(
            bar_lbl_df, x="Label", y="Count",
            title="Number of Patients with the Same Predicted Label"
        )
    else:
        fig_lbl = px.bar(title="No valid predicted labels to display.")
    fig_lbl.update_layout(width=1200, height=400)
    return fig_lbl


def predict(
    # EXACTLY 24 features, in the same order as _PREDICT_FEATURES and the
    # Gradio `all_inputs` list.
    YMDESUD5ANYO, YMDELT, YMDEYR, YMDERSUD5ANY,
    YMSUD5YANY, YMIUD5YANY, YMIMS5YANY, YMIMI5YANY,
    YMDEHPO, YMDETXRX, YMDEHARX, YRXMDEYR, YHLTMDE,
    YTXMDEYR, YDOCMDE, YPSY2MDE, YPSY1MDE, YCOUNMDE,
    MDEIMPY, LVLDIFMEM2,
    YUSUITHK, YUSUITHKYR, YUSUIPLNYR, YUSUIPLN
):
    """Gradio callback: encode the answers, run all models, and build the outputs.

    Each parameter is the answer text picked in the matching dropdown. It is
    converted to a numeric code through ``input_mapping``.

    Returns:
        A 6-tuple: (grouped prediction text, severity message, patient-count
        Markdown, nearest-neighbor Markdown, input-feature bar chart,
        predicted-label bar chart). If any answer is missing, the tuple holds
        placeholder texts and ``None`` for both charts.
    """
    answers = dict(zip(_PREDICT_FEATURES, (
        YMDESUD5ANYO, YMDELT, YMDEYR, YMDERSUD5ANY,
        YMSUD5YANY, YMIUD5YANY, YMIMS5YANY, YMIMI5YANY,
        YMDEHPO, YMDETXRX, YMDEHARX, YRXMDEYR, YHLTMDE,
        YTXMDEYR, YDOCMDE, YPSY2MDE, YPSY1MDE, YCOUNMDE,
        MDEIMPY, LVLDIFMEM2,
        YUSUITHK, YUSUITHKYR, YUSUIPLNYR, YUSUIPLN,
    )))
    if not validate_inputs(*answers.values()):
        return (
            "Please select all required fields.",  # 1) Prediction Results
            "Validation Error",                    # 2) Severity
            "No data",                             # 3) Total Count
            "No nearest neighbors info",           # 4) NN Summary
            None,                                  # 5) Bar chart (Input)
            None                                   # 6) Bar chart (Labels)
        )

    # 1) Map user-friendly answers -> numeric codes (column order = feature order)
    user_input_dict = {col: input_mapping[col][ans] for col, ans in answers.items()}
    user_df = pd.DataFrame(user_input_dict, index=[0])

    # 2) Predict; severity is driven by how many model outputs equal 1
    predictions = predictor.make_predictions(user_df)
    all_preds = np.concatenate(predictions)
    count_ones = int(np.count_nonzero(all_preds == 1))
    severity_msg = predictor.evaluate_severity(count_ones)

    # 3) Grouped textual results
    final_str = _grouped_prediction_text(predictions)

    # 4) Additional info
    total_count_md = f"We have **{len(df)}** patients in the dataset."

    # 5) Nearest neighbors
    nn_md = get_nearest_neighbors_info(user_df, k=5)

    # 6) + 7) Bar charts
    fig_in = _input_feature_figure(user_input_dict)
    fig_lbl = _predicted_label_figure(predictions)

    return (
        final_str,       # 1) Prediction Results
        severity_msg,    # 2) Mental Health Severity
        total_count_md,  # 3) Total Patient Count
        nn_md,           # 4) Nearest Neighbors Summary
        fig_in,          # 5) Bar Chart (input features)
        fig_lbl          # 6) Bar Chart (labels)
    )
######################################
# 6) EXTRA TABS / FUNCTIONS
######################################
def distribution_plot(feature_col, label_col):
    """Bar chart of row counts per value of ``feature_col``, colored by ``label_col``.

    Returns a titled empty figure instead of raising when a selection is
    missing or unknown, or when the same column is chosen twice.
    """
    if not feature_col or not label_col:
        return px.bar(title="Please select both Feature and Label.")
    if (feature_col not in df.columns) or (label_col not in df.columns):
        return px.bar(title="Selected columns not found in the dataset.")
    if feature_col == label_col:
        # groupby([c, c]).reset_index() would fail on the duplicate column name.
        return px.bar(title="Please select two different columns.")
    grouped = df.groupby([feature_col, label_col]).size().reset_index(name="count")
    # Label values are class codes. As strings they get a discrete legend
    # instead of a continuous color bar.
    grouped[label_col] = grouped[label_col].astype(str)
    fig = px.bar(
        grouped,
        x=feature_col,
        y="count",
        color=label_col,
        title=f"Distribution of {feature_col} vs {label_col}"
    )
    fig.update_layout(width=1200, height=600)
    return fig
def co_occurrence_plot(feature1, feature2, label_col):
    """Bar chart of ``feature1`` counts, colored by ``label_col``, faceted by ``feature2``.

    Returns a titled empty figure instead of raising when a selection is
    missing or unknown, or when a column is chosen more than once. Both
    feature dropdowns share one choice list, so picking the same feature twice
    is possible from the UI.
    """
    if not feature1 or not feature2 or not label_col:
        return px.bar(title="Please select all three fields.")
    if feature1 not in df.columns or feature2 not in df.columns or label_col not in df.columns:
        return px.bar(title="Selected columns not found in the dataset.")
    if len({feature1, feature2, label_col}) < 3:
        # groupby on a repeated column makes reset_index() fail on the duplicate name.
        return px.bar(title="Please select three different columns.")
    grouped = df.groupby([feature1, feature2, label_col]).size().reset_index(name="count")
    # Label values are class codes. As strings they get a discrete legend
    # instead of a continuous color bar.
    grouped[label_col] = grouped[label_col].astype(str)
    fig = px.bar(
        grouped,
        x=feature1,
        y="count",
        color=label_col,
        facet_col=feature2,
        title=f"Co-occurrence: {feature1}, {feature2} vs {label_col}"
    )
    fig.update_layout(width=1200, height=600)
    return fig
######################################
# 7) BUILD GRADIO UI
######################################
# (Markdown heading, [(feature, dropdown label), ...]) for each input section.
# Flattened in this order, the features must match predict()'s parameters.
_UI_SECTIONS = [
    ("#### 1. Depression & Substance Use Diagnosis", [
        ("YMDESUD5ANYO", "YMDESUD5ANYO: ONLY MDE, ONLY SUD, BOTH, OR NEITHER"),
        ("YMDELT", "YMDELT: Had major depressive episode in lifetime"),
        ("YMDEYR", "YMDEYR: Past-year major depressive episode"),
        ("YMDERSUD5ANY", "YMDERSUD5ANY: MDE or substance use disorder - past year"),
        ("YMSUD5YANY", "YMSUD5YANY: Past-year MDE & substance use disorder"),
        ("YMIUD5YANY", "YMIUD5YANY: Past-year MDE & illicit drug use disorder"),
        ("YMIMS5YANY", "YMIMS5YANY: Past-year MDE + severe impairment + substance use"),
        ("YMIMI5YANY", "YMIMI5YANY: Past-year MDE with severe impairment & illicit drug use"),
    ]),
    ("#### 2. Mental Health Treatment & Professional Consultation", [
        ("YMDEHPO", "YMDEHPO: Saw health prof only for MDE in past year"),
        ("YMDETXRX", "YMDETXRX: Received treatment/counseling if saw doc/prof for MDE"),
        ("YMDEHARX", "YMDEHARX: Saw health professional & received medication for MDE"),
        ("YRXMDEYR", "YRXMDEYR: Used received medication for MDE in past years"),
        ("YHLTMDE", "YHLTMDE: Saw/talked to health professional about MDE in past year"),
        ("YTXMDEYR", "YTXMDEYR: Saw or talked to doc/health prof for MDE in past year"),
        ("YDOCMDE", "YDOCMDE: Saw/talked to general practitioner/family MD about MDE"),
        ("YPSY2MDE", "YPSY2MDE: Saw/talked to psychiatrist about MDE"),
        ("YPSY1MDE", "YPSY1MDE: Saw/talked to psychologist about MDE"),
        ("YCOUNMDE", "YCOUNMDE: Saw/talked to counselor about MDE"),
    ]),
    ("#### 3. Functional & Cognitive Impairment", [
        ("MDEIMPY", "MDEIMPY: MDE with severe role impairment"),
        ("LVLDIFMEM2", "LVLDIFMEM2: Difficulty remembering/concentrating"),
    ]),
    ("#### 4. Suicidal Thoughts & Behaviors", [
        ("YUSUITHK", "YUSUITHK: Youth seriously think about killing self in past 12 months"),
        ("YUSUITHKYR", "YUSUITHKYR: Seriously thought about killing self"),
        ("YUSUIPLNYR", "YUSUIPLNYR: Made plans to kill self in past year"),
        ("YUSUIPLN", "YUSUIPLN: Made plans to kill yourself in past 12 months"),
    ]),
]

with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
    # ======== TAB 1: PREDICTION ========
    with gr.Tab("Prediction"):
        gr.Markdown(
            "\n### Please provide inputs in each of the four categories below.\n*All fields are required.*\n"
        )

        # One heading per section followed by its dropdowns. Components render
        # in creation order, and all_inputs follows the same order.
        all_inputs = []
        for heading, section_fields in _UI_SECTIONS:
            gr.Markdown(heading)
            for feature, dropdown_label in section_fields:
                all_inputs.append(
                    gr.Dropdown(choices=list(input_mapping[feature].keys()), label=dropdown_label)
                )

        predict_btn = gr.Button("Predict")

        # The six outputs, in the order predict() returns them.
        prediction_outputs = [
            gr.Textbox(label="Prediction Results", lines=8),
            gr.Textbox(label="Mental Health Severity", lines=2),
            gr.Markdown(label="Total Patient Count"),
            gr.Markdown(label="Nearest Neighbors Summary (Grouped by Category)"),
            gr.Plot(label="Input Feature Counts"),
            gr.Plot(label="Predicted Label Counts"),
        ]
        predict_btn.click(fn=predict, inputs=all_inputs, outputs=prediction_outputs)

    # ======== TAB 2: Distribution Analysis ========
    with gr.Tab("Distribution Analysis"):
        gr.Markdown("## Distribution Plot\nSelect one feature and one label column to see bar counts.")
        list_of_features = sorted(input_mapping)
        list_of_labels = sorted(predictor.prediction_map)
        dist_feature_dd = gr.Dropdown(choices=list_of_features, label="Feature Column")
        dist_label_dd = gr.Dropdown(choices=list_of_labels, label="Label Column")
        dist_btn = gr.Button("Generate Distribution Plot")
        dist_plot = gr.Plot()
        dist_btn.click(fn=distribution_plot, inputs=[dist_feature_dd, dist_label_dd], outputs=dist_plot)

    # ======== TAB 3: Co-occurrence ========
    with gr.Tab("Co-occurrence"):
        gr.Markdown("## Co-Occurrence Plot\nSelect two features + one label to see a 3-way distribution.")
        co_feature1_dd = gr.Dropdown(choices=list_of_features, label="Feature 1")
        co_feature2_dd = gr.Dropdown(choices=list_of_features, label="Feature 2")
        co_label_dd = gr.Dropdown(choices=list_of_labels, label="Label Column")
        co_btn = gr.Button("Generate Co-occurrence Plot")
        co_plot = gr.Plot()
        co_btn.click(
            fn=co_occurrence_plot,
            inputs=[co_feature1_dd, co_feature2_dd, co_label_dd],
            outputs=co_plot,
        )

# Finally, launch
demo.launch()