# app.py — author: pantdipendra; last commit "Update app.py" (69090fc, verified); 24.2 kB.
import os
import pickle

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
# Training data (model inputs and label columns merged into one table). It is
# read once at import time and shared by every request for the count charts,
# the distribution/co-occurrence plots and the nearest-neighbour lookup.
TRAIN_CSV_PATH = "X_train_Y_Train_merged_train.csv"
df = pd.read_csv(TRAIN_CSV_PATH)
######################################
# 1) MODEL PREDICTOR CLASS
######################################
class ModelPredictor:
    """Load a set of pickled classifiers and aggregate their 0/1 predictions.

    The only method called on each loaded object is ``predict(X)``, so any
    scikit-learn style estimator works.
    """

    def __init__(self, model_path, model_filenames):
        """
        Args:
            model_path: Directory that holds the pickles. It is joined to each
                filename with ``os.path.join``, so a trailing separator is
                optional ("models" and "models/" both work).
            model_filenames: Pickle file names, loaded in this order. The
                order of ``make_predictions`` output follows it.
        """
        self.model_path = model_path
        self.model_filenames = model_filenames
        self.models = self.load_models()
        # Label column -> [text for prediction 0, text for prediction 1]
        self.prediction_map = {
            "YOWRCONC": ["No difficulty concentrating", "Had difficulty concentrating"],
            "YOSEEDOC": ["No need to see doctor", "Needed to see doctor"],
            "YOWRHRS": ["No trouble sleeping", "Had trouble sleeping"],
            "YO_MDEA5": ["Others didn't notice restlessness", "Others noticed restlessness"],
            "YOWRCHR": ["Not sad beyond cheering", "Felt so sad no one could cheer up"],
            "YOWRLSIN": ["Never felt bored/lost interest", "Felt bored/lost interest"],
            "YODPPROB": ["No other problems for 2+ weeks", "Had other problems for 2+ weeks"],
            "YOWRPROB": ["No worst time feeling", "Felt worst time ever"],
            "YODPR2WK": ["No depressed feelings for 2+ wks", "Depressed feelings for 2+ wks"],
            "YOWRDEPR": ["Not sad or depressed most days", "Sad or depressed most days"],
            "YODPDISC": ["Mood not depressed overall", "Mood depressed overall (discrepancy)"],
            "YOLOSEV": ["Did not lose interest in activities", "Lost interest in activities"],
            "YOWRDCSN": ["Could make decisions", "Could not make decisions"],
            "YODSMMDE": ["No 2+ week depression episodes", "Had 2+ week depression episodes"],
            "YO_MDEA3": ["No appetite/weight changes", "Yes appetite/weight changes"],
            "YODPLSIN": ["Never bored/lost interest", "Often bored/lost interest"],
            "YOWRELES": ["Did not eat less", "Ate less than usual"],
            "YODSCEV": ["Fewer severe symptoms", "More severe symptoms"],
            "YOPB2WK": ["No uneasy feelings daily 2+ wks", "Uneasy feelings daily 2+ wks"],
            "YO_MDEA2": ["No issues physical/mental daily", "Issues physical/mental daily 2+ wks"]
        }

    def load_models(self):
        """Unpickle every file in ``model_filenames``, in order, and return them as a list.

        Raises:
            OSError (e.g. FileNotFoundError): if a model file cannot be opened.

        SECURITY: ``pickle.load`` can execute arbitrary code stored in the file.
        Only point this at trusted, locally bundled model files, never at
        user-supplied data.
        """
        models = []
        for fn in self.model_filenames:
            filepath = os.path.join(self.model_path, fn)
            with open(filepath, "rb") as file:
                models.append(pickle.load(file))
        return models

    def make_predictions(self, user_input):
        """Run every model on ``user_input``.

        Returns:
            One flattened numpy array per model, in ``self.models`` order.
            For a single-row input each array has one element.
        """
        preds = []
        for m in self.models:
            out = m.predict(user_input)
            preds.append(np.array(out).flatten())
        return preds

    def get_majority_vote(self, predictions):
        """Return the most frequent label across all prediction arrays.

        Ties resolve to the smallest label, matching the former
        ``np.bincount(...).argmax()`` behaviour. Unlike ``bincount``, this also
        works for float labels such as ``1.0`` and for any sortable label type.

        Raises:
            ValueError: if ``predictions`` contains no values at all.
        """
        combined = np.concatenate(predictions) if len(predictions) else np.array([])
        if combined.size == 0:
            raise ValueError("get_majority_vote() needs at least one prediction")
        # np.unique returns sorted labels, so argmax picks the smallest label on ties.
        labels, counts = np.unique(combined, return_counts=True)
        return labels[np.argmax(counts)]

    def evaluate_severity(self, majority_vote_count):
        """Map the number of models predicting 1 to a severity string.

        The thresholds assume 16 models in total:
        0-4 = Very Low, 5-8 = Low, 9-12 = Moderate, 13-16 = Severe.
        """
        if majority_vote_count >= 13:
            return "Mental health severity: Severe"
        elif majority_vote_count >= 9:
            return "Mental health severity: Moderate"
        elif majority_vote_count >= 5:
            return "Mental health severity: Low"
        else:
            return "Mental health severity: Very Low"
######################################
# 2) CONFIGURATIONS
######################################
# Label columns that have a trained model, one pickle per label. Order matters:
# predictions come back in this order and are matched to model_filenames by index.
_MODEL_LABELS = (
    "YOWRCONC", "YOSEEDOC", "YO_MDEA5", "YOWRLSIN",
    "YODPPROB", "YOWRPROB", "YODPR2WK", "YOWRDEPR",
    "YODPDISC", "YOLOSEV", "YOWRDCSN", "YODSMMDE",
    "YO_MDEA3", "YODPLSIN", "YOWRELES", "YOPB2WK",
)
model_filenames = [f"{label}.pkl" for label in _MODEL_LABELS]
model_path = "models/"

# Loads every pickle at import time; shared by all requests.
predictor = ModelPredictor(model_path, model_filenames)
######################################
# 3) INPUT VALIDATION
######################################
def validate_inputs(*args):
    """Return True only when every required answer has been provided.

    An answer counts as missing if it is the empty string or ``None``
    (presumably what an untouched Gradio dropdown yields). Only the required,
    non-co-occurrence fields are passed in.
    """
    missing = (value == '' or value is None for value in args)
    return not any(missing)
######################################
# 4) PREDICTION FUNCTION
######################################
# Symptom area -> label columns listed under it in the text summary. A column
# that appears under several areas (YOWRELES) is shown under each of them.
# Labels not listed anywhere (e.g. YOSEEDOC) are left out of the summary.
_PREDICTION_GROUPS = {
    "Concentration_and_Decision_Making": ["YOWRCONC", "YOWRDCSN"],
    "Sleep_and_Energy_Levels": ["YOWRHRS", "YO_MDEA5", "YOWRELES", "YO_MDEA2"],
    "Mood_and_Emotional_State": ["YOWRCHR", "YOWRLSIN", "YOWRDEPR", "YODPDISC",
                                 "YOLOSEV", "YODPLSIN", "YODSCEV"],
    "Appetite_and_Weight_Changes": ["YO_MDEA3", "YOWRELES"],
    "Duration_and_Severity_of_Depression_Symptoms": ["YODPPROB", "YOWRPROB",
                                                     "YODPR2WK", "YODSMMDE",
                                                     "YOPB2WK"],
}


def _label_names():
    # Label column behind each model file, e.g. "YOWRCONC.pkl" -> "YOWRCONC".
    return [fn.split(".")[0] for fn in model_filenames]


def _format_grouped_predictions(predictions):
    # Render each model's prediction as "<LABEL>: <text>" and group the lines by symptom area.
    results_by_group = {group: [] for group in _PREDICTION_GROUPS}
    for col_name, pred_array in zip(_label_names(), predictions):
        val = pred_array[0]
        texts = predictor.prediction_map.get(col_name)
        if texts is not None and val in (0, 1):
            # int() so a float prediction such as 1.0 can still index the list.
            out_line = f"{col_name}: {texts[int(val)]}"
        else:
            out_line = f"{col_name}: Prediction={val}"
        for group, group_cols in _PREDICTION_GROUPS.items():
            if col_name in group_cols:
                results_by_group[group].append(out_line)

    lines = []
    for group, group_lines in results_by_group.items():
        if group_lines:
            lines.extend([f"Group {group}:", "\n".join(group_lines), ""])
    if not lines:
        return "No predictions made. Check inputs."
    return "\n".join(lines).strip()


def _same_input_value_bar(user_input_data):
    # Bar chart: how many training rows share the user's value for each input feature.
    counts = {col: int((df[col] == values[0]).sum()) for col, values in user_input_data.items()}
    counts_df = pd.DataFrame({"Feature": list(counts), "Count": list(counts.values())})
    fig = px.bar(
        counts_df,
        x="Feature",
        y="Count",
        title="Number of Patients with the Same Value for Each Input Feature"
    )
    fig.update_layout(xaxis={"categoryorder": "total descending"})
    return fig


def _same_label_value_bar(predictions):
    # Bar chart: how many training rows carry the same value as each 0/1 predicted label.
    counts = {}
    for col_name, pred_array in zip(_label_names(), predictions):
        val = pred_array[0]
        if val in (0, 1) and col_name in df.columns:
            counts[col_name] = int((df[col_name] == val).sum())
    if not counts:
        return px.bar(title="No valid predicted labels to display")
    counts_df = pd.DataFrame({"Label Column": list(counts), "Count": list(counts.values())})
    return px.bar(
        counts_df,
        x="Label Column",
        y="Count",
        title="Number of Patients with the Same Predicted Label"
    )


def _sample_distribution_plot(feature_cols, label_cols, n_features=4, n_labels=3):
    # Faceted (feature value x label value) count bars for the first few features/labels.
    feats = [c for c in feature_cols[:n_features] if c in df.columns]
    labels = [c for c in label_cols[:n_labels] if c in df.columns]
    segments = []
    for feat in feats:
        for lbl in labels:
            seg = (
                df.groupby([feat, lbl]).size().reset_index(name="count")
                # Give every segment the same column names. Otherwise the concat
                # below spreads values across per-feature columns, and all facets
                # except the first plot NaN.
                .rename(columns={feat: "feature_value", lbl: "label_value"})
            )
            seg["feature"] = feat
            seg["label"] = lbl
            segments.append(seg)
    if not segments:
        return px.bar(title="No distribution plot generated (columns not found).")
    dist_df = pd.concat(segments, ignore_index=True)
    # Strings make plotly use a discrete colour legend instead of a continuous scale.
    dist_df["label_value"] = dist_df["label_value"].astype(str)
    fig = px.bar(
        dist_df,
        x="feature_value",
        y="count",
        color="label_value",
        facet_row="feature",
        facet_col="label",
        title=f"Sample Distribution Plot (first {n_features} features vs first {n_labels} labels)"
    )
    fig.update_layout(height=700)
    return fig


def _nearest_neighbors_markdown(user_input, k=2):
    # Markdown list of the k training rows with the smallest Hamming distance to the user's row.
    cols = list(user_input.columns)
    user_row = user_input.iloc[0]
    # Number of input columns in which each training row differs from the user
    # (NaN counts as "different", as with a plain `!=`).
    distances = df[cols].ne(user_row).sum(axis=1)
    nearest = (
        df.assign(distance=distances.to_numpy())
        .sort_values("distance", kind="stable")  # stable: ties keep file order
        .head(k)
    )
    lines = [
        f"### Nearest Neighbors (K={k})",
        "(In a real application, you'd refine which features matter, how to encode them, etc.)\n",
    ]
    for neighbor_id, dist in nearest["distance"].items():
        lines.append(f"- **Neighbor ID {neighbor_id}**: distance={dist}")
    return "\n".join(lines)


def _co_occurrence_plot(feature1, feature2, label):
    # Count bars of (feature1, feature2, label) combinations; placeholder figure when unusable.
    chosen = (feature1, feature2, label)
    if any(c is None or c == "None" for c in chosen):
        return px.bar(title="No co-occurrence plot (choose two features + one label).")
    if not all(c in df.columns for c in chosen):
        return px.bar(title="One or more selected columns not found in dataframe.")
    if len(set(chosen)) < len(chosen):
        # groupby(...).reset_index() cannot create duplicate column names.
        return px.bar(title="Co-occurrence plot needs three different columns.")
    co_data = df.groupby(list(chosen)).size().reset_index(name="count")
    return px.bar(
        co_data,
        x=feature1,
        y="count",
        color=label,
        facet_col=feature2,
        title=f"Co-occurrence: {feature1} & {feature2} vs {label}"
    )


def predict(
    # Original required features
    YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
    YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
    YMDESUD5ANYO, YNURSMDE, YSOCMDE, YCOUNMDE, YPSY1MDE, YPSY2MDE,
    YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
    YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR,
    # Optional picks for the co-occurrence plot
    co_occ_feature1, co_occ_feature2, co_occ_label
):
    """Run all models on one respondent's answers and build every UI output.

    The 29 required arguments are numeric answer codes (anything ``int()``
    accepts). ``co_occ_feature1``, ``co_occ_feature2`` and ``co_occ_label``
    are column names, or ``None`` / the string ``"None"`` to skip the
    co-occurrence plot.

    Returns:
        An 8-tuple in the order of the Gradio ``outputs`` list: prediction
        text, severity text, patient-count markdown, distribution figure,
        nearest-neighbour markdown, co-occurrence figure, same-input-value
        bar figure, same-predicted-label bar figure.
    """
    # 1) Single-row model input. The column order is the order the models receive.
    user_input_data = {
        'YNURSMDE': [int(YNURSMDE)],
        'YMDEYR': [int(YMDEYR)],
        'YSOCMDE': [int(YSOCMDE)],
        'YMDESUD5ANYO': [int(YMDESUD5ANYO)],
        'YMSUD5YANY': [int(YMSUD5YANY)],
        'YUSUITHK': [int(YUSUITHK)],
        'YMDETXRX': [int(YMDETXRX)],
        'YUSUITHKYR': [int(YUSUITHKYR)],
        'YMDERSUD5ANY': [int(YMDERSUD5ANY)],
        'YUSUIPLNYR': [int(YUSUIPLNYR)],
        'YCOUNMDE': [int(YCOUNMDE)],
        'YPSY1MDE': [int(YPSY1MDE)],
        'YHLTMDE': [int(YHLTMDE)],
        'YDOCMDE': [int(YDOCMDE)],
        'YPSY2MDE': [int(YPSY2MDE)],
        'YMDEHARX': [int(YMDEHARX)],
        'LVLDIFMEM2': [int(LVLDIFMEM2)],
        'MDEIMPY': [int(MDEIMPY)],
        'YMDEHPO': [int(YMDEHPO)],
        'YMIMS5YANY': [int(YMIMS5YANY)],
        'YMDEIMAD5YR': [int(YMDEIMAD5YR)],
        'YMIUD5YANY': [int(YMIUD5YANY)],
        'YMDEHPRX': [int(YMDEHPRX)],
        'YMIMI5YANY': [int(YMIMI5YANY)],
        'YUSUIPLN': [int(YUSUIPLN)],
        'YTXMDEYR': [int(YTXMDEYR)],
        'YMDEAUD5YR': [int(YMDEAUD5YR)],
        'YRXMDEYR': [int(YRXMDEYR)],
        'YMDELT': [int(YMDELT)]
    }
    user_input = pd.DataFrame(user_input_data)

    # 2) Model predictions; severity = how many models predicted 1.
    predictions = predictor.make_predictions(user_input)
    majority_vote_count = np.sum(np.concatenate(predictions) == 1)
    severity = predictor.evaluate_severity(majority_vote_count)

    # 3) Text summary grouped by symptom area.
    final_result_text = _format_grouped_predictions(predictions)

    # 4) Dataset-level context for the user's answers.
    total_count_md = (
        "### Total Patient Count\n"
        f"**{len(df)}** total patients in the dataset."
    )
    fig_input_bar = _same_input_value_bar(user_input_data)
    fig_label_bar = _same_label_value_bar(predictions)
    fig_dist = _sample_distribution_plot(list(user_input_data), _label_names())
    nn_md_str = _nearest_neighbors_markdown(user_input, k=2)
    fig_co_occ = _co_occurrence_plot(co_occ_feature1, co_occ_feature2, co_occ_label)

    return (
        final_result_text,  # (1) Predictions
        severity,           # (2) Severity
        total_count_md,     # (3) Total patient count
        fig_dist,           # (4) Distribution Plot
        nn_md_str,          # (5) Nearest Neighbors
        fig_co_occ,         # (6) Co-occurrence
        fig_input_bar,      # (7) Bar Chart (input features)
        fig_label_bar       # (8) Bar Chart (labels)
    )
######################################
# 5) MAPPING (user -> int)
######################################
def _yes_no(no_code=0):
    """Return a fresh {"Yes": 1, "No": no_code} option map.

    Most questions in this app code "No" as 0; a few (YMDEYR, MDEIMPY, YMDELT) use 2.
    """
    return {"Yes": 1, "No": no_code}


def _coded_options(*labels):
    """Return a fresh {label: code} map with codes 1, 2, 3, ... in the given order."""
    return {label: code for code, label in enumerate(labels, start=1)}


_SUICIDALITY_LABELS = ("Yes", "No", "Unsure", "Don't want to answer")

# Dropdown text -> numeric code for each required field. The inner key order
# is also the order of the choices shown in each dropdown.
input_mapping = {
    'YNURSMDE': _yes_no(),
    'YMDEYR': _yes_no(2),
    'YSOCMDE': _yes_no(),
    'YMDESUD5ANYO': _coded_options("SUD only", "MDE only", "SUD & MDE", "Neither"),
    'YMSUD5YANY': _yes_no(),
    'YUSUITHK': _coded_options(*_SUICIDALITY_LABELS),
    'YMDETXRX': _yes_no(),
    'YUSUITHKYR': _coded_options(*_SUICIDALITY_LABELS),
    'YMDERSUD5ANY': _yes_no(),
    'YUSUIPLNYR': _coded_options(*_SUICIDALITY_LABELS),
    'YCOUNMDE': _yes_no(),
    'YPSY1MDE': _yes_no(),
    'YHLTMDE': _yes_no(),
    'YDOCMDE': _yes_no(),
    'YPSY2MDE': _yes_no(),
    'YMDEHARX': _yes_no(),
    'LVLDIFMEM2': _coded_options("No Difficulty", "Some Difficulty", "A lot or cannot do"),
    'MDEIMPY': _yes_no(2),
    'YMDEHPO': _yes_no(),
    'YMIMS5YANY': _yes_no(),
    'YMDEIMAD5YR': _yes_no(),
    'YMIUD5YANY': _yes_no(),
    'YMDEHPRX': _yes_no(),
    'YMIMI5YANY': _yes_no(),
    'YUSUIPLN': _coded_options(*_SUICIDALITY_LABELS),
    'YTXMDEYR': _yes_no(),
    'YMDEAUD5YR': _yes_no(),
    'YRXMDEYR': _yes_no(),
    'YMDELT': _yes_no(2),
}
######################################
# 6) THE GRADIO INTERFACE
######################################
import gradio as gr
# (A) Required inputs as (input_mapping key, question) pairs. The order is the
# exact positional order of predict_with_text()'s first 29 parameters.
# NOTE(review): YMIMI5YANY and YMIUD5YANY share the same question text, as do
# the two "Made Plans" items apart from "(12 mo)". Confirm these against the
# NSDUH codebook.
_REQUIRED_INPUT_QUESTIONS = [
    # MDE / substance use
    ('YMDEYR', "Past Year MDE?"),
    ('YMDERSUD5ANY', "MDE or SUD - ANY?"),
    ('YMDEIMAD5YR', "MDE + ALCOHOL?"),
    ('YMIMS5YANY', "MDE + SUBSTANCE?"),
    ('YMDELT', "MDE in Lifetime?"),
    ('YMDEHARX', "Saw Health Prof + Meds?"),
    ('YMDEHPRX', "Saw Health Prof or Meds?"),
    ('YMDETXRX', "Received Treatment?"),
    ('YMDEHPO', "Saw Health Prof Only?"),
    ('YMDEAUD5YR', "MDE + Alcohol Use?"),
    ('YMIMI5YANY', "MDE + ILL Drug Use?"),
    ('YMIUD5YANY', "MDE + ILL Drug Use?"),
    ('YMDESUD5ANYO', "MDE vs SUD vs BOTH vs NEITHER"),
    # Consultations
    ('YNURSMDE', "Nurse/OT about MDE?"),
    ('YSOCMDE', "Social Worker?"),
    ('YCOUNMDE', "Counselor?"),
    ('YPSY1MDE', "Psychologist?"),
    ('YPSY2MDE', "Psychiatrist?"),
    ('YHLTMDE', "Health Prof?"),
    ('YDOCMDE', "GP/Family MD?"),
    ('YTXMDEYR', "Doctor/Health Prof?"),
    # Suicidal
    ('YUSUITHKYR', "Serious Suicide Thoughts?"),
    ('YUSUIPLNYR', "Made Plans?"),
    ('YUSUITHK', "Suicide Thoughts (12 mo)?"),
    ('YUSUIPLN', "Made Plans (12 mo)?"),
    # Impairments
    ('MDEIMPY', "Severe Role Impairment?"),
    ('LVLDIFMEM2', "Difficulty Remembering/Concentrating?"),
    ('YMSUD5YANY', "MDE + Substance?"),
    ('YRXMDEYR', "Used Meds for MDE (12 mo)?"),
]
# Each dropdown offers that field's option texts and is labelled "<KEY>: <question>".
original_inputs = [
    gr.Dropdown(list(input_mapping[key].keys()), label=f"{key}: {question}")
    for key, question in _REQUIRED_INPUT_QUESTIONS
]

# (B) Optional co-occurrence pickers. "None" is the default meaning "no selection".
all_cols = ["None"] + df.columns.tolist()
co_occ_feature1 = gr.Dropdown(all_cols, label="Co-Occ Feature 1", value="None")
co_occ_feature2 = gr.Dropdown(all_cols, label="Co-Occ Feature 2", value="None")
all_label_cols = ["None"] + list(predictor.prediction_map.keys())
co_occ_label = gr.Dropdown(all_label_cols, label="Co-Occ Label", value="None")

# Full positional input list: 29 required answers, then the 3 co-occurrence picks.
inputs = [*original_inputs, co_occ_feature1, co_occ_feature2, co_occ_label]

# Eight output widgets, in the same order as the tuple predict() returns.
outputs = [
    gr.Textbox(label="Prediction Results", lines=15),
    gr.Textbox(label="Mental Health Severity", lines=2),
    gr.Markdown(label="Total Patient Count"),
    gr.Plot(label="Distribution Plot (Sample)"),
    gr.Markdown(label="Nearest Neighbors (K=2)"),
    gr.Plot(label="Co-occurrence Plot"),
    gr.Plot(label="Same Value Bar (Inputs)"),
    gr.Plot(label="Predicted Label Bar")
]
######################################
# 7) WRAPPER
######################################
def predict_with_text(
    # Same 29 required answers as predict(), in the Gradio input order, plus 3 co-occurrence picks
    YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
    YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
    YMDESUD5ANYO, YNURSMDE, YSOCMDE, YCOUNMDE, YPSY1MDE, YPSY2MDE,
    YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
    YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR,
    co_occ_feature1, co_occ_feature2, co_occ_label
):
    """Gradio callback: convert dropdown text to numeric codes and call predict().

    If any required answer is missing (per validate_inputs), returns an
    8-tuple of placeholder messages/None instead, matching the outputs list.
    The co-occurrence picks are passed through to predict() unchanged.
    """
    answers = {
        'YNURSMDE': YNURSMDE,
        'YMDEYR': YMDEYR,
        'YSOCMDE': YSOCMDE,
        'YMDESUD5ANYO': YMDESUD5ANYO,
        'YMSUD5YANY': YMSUD5YANY,
        'YUSUITHK': YUSUITHK,
        'YMDETXRX': YMDETXRX,
        'YUSUITHKYR': YUSUITHKYR,
        'YMDERSUD5ANY': YMDERSUD5ANY,
        'YUSUIPLNYR': YUSUIPLNYR,
        'YCOUNMDE': YCOUNMDE,
        'YPSY1MDE': YPSY1MDE,
        'YHLTMDE': YHLTMDE,
        'YDOCMDE': YDOCMDE,
        'YPSY2MDE': YPSY2MDE,
        'YMDEHARX': YMDEHARX,
        'LVLDIFMEM2': LVLDIFMEM2,
        'MDEIMPY': MDEIMPY,
        'YMDEHPO': YMDEHPO,
        'YMIMS5YANY': YMIMS5YANY,
        'YMDEIMAD5YR': YMDEIMAD5YR,
        'YMIUD5YANY': YMIUD5YANY,
        'YMDEHPRX': YMDEHPRX,
        'YMIMI5YANY': YMIMI5YANY,
        'YUSUIPLN': YUSUIPLN,
        'YTXMDEYR': YTXMDEYR,
        'YMDEAUD5YR': YMDEAUD5YR,
        'YRXMDEYR': YRXMDEYR,
        'YMDELT': YMDELT
    }

    # Only the 29 required answers are validated; the co-occurrence picks are optional.
    if not validate_inputs(*answers.values()):
        return (
            "Please select all required fields.",
            "Validation Error",
            "No data",
            None,
            "No data",
            None,
            None,
            None
        )

    coded_answers = {field: input_mapping[field][choice] for field, choice in answers.items()}
    return predict(
        **coded_answers,
        co_occ_feature1=co_occ_feature1,
        co_occ_feature2=co_occ_feature2,
        co_occ_label=co_occ_label
    )
# Force dark text on every element inside the Gradio container.
custom_css = """
.gradio-container * {
color: #1B1212 !important;
}
"""

# Markdown shown above the form.
_INTERFACE_DESCRIPTION = """
**Instructions**:
1. Fill out all required fields regarding MDE/Substance Use/Consultations/Suicidal/Impairments.
2. (Optional) Choose 2 features and 1 label for the *Co-occurrence* plot.
- If you do not select them (or leave them as "None"), that plot will be skipped.
3. Click "Submit" to get predictions, severity, distribution plots, nearest neighbors, and your custom co-occurrence chart.
"""

# A single form: the 32 widgets in `inputs` feed predict_with_text, and its
# 8-tuple result fills `outputs` in order.
interface = gr.Interface(
    fn=predict_with_text,
    inputs=inputs,
    outputs=outputs,
    title="Mental Health Screening (NSDUH) with Selective Co-Occurrence",
    description=_INTERFACE_DESCRIPTION,
    css=custom_css,
)


def main():
    """Start the Gradio server for the screening form."""
    interface.launch()


if __name__ == "__main__":
    main()