"""Gradio screening app: runs one pickled model per label column on the user's
answers and shows the predictions together with dataset statistics."""

import os
import pickle

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px

# Load the training CSV once.
df = pd.read_csv("X_train_Y_Train_merged_train.csv")


######################################
# 1) MODEL PREDICTOR CLASS
######################################
class ModelPredictor:
    """Load a set of pickled models and run all of them on one input row.

    Attributes:
        model_path: Directory that contains the pickle files.
        model_filenames: File names of the pickled models, one per label.
        models: Unpickled model objects, in ``model_filenames`` order.
        prediction_map: Label column -> [text for 0, text for 1].
    """

    def __init__(self, model_path, model_filenames):
        self.model_path = model_path
        self.model_filenames = model_filenames
        self.models = self.load_models()
        # Mapping from label column to human-readable strings for 0/1
        self.prediction_map = {
            "YOWRCONC": ["No difficulty concentrating", "Had difficulty concentrating"],
            "YOSEEDOC": ["No need to see doctor", "Needed to see doctor"],
            "YOWRHRS": ["No trouble sleeping", "Had trouble sleeping"],
            "YO_MDEA5": ["Others didn't notice restlessness", "Others noticed restlessness"],
            "YOWRCHR": ["Not sad beyond cheering", "Felt so sad no one could cheer up"],
            "YOWRLSIN": ["Never felt bored/lost interest", "Felt bored/lost interest"],
            "YODPPROB": ["No other problems for 2+ weeks", "Had other problems for 2+ weeks"],
            "YOWRPROB": ["No worst time feeling", "Felt worst time ever"],
            "YODPR2WK": ["No depressed feelings for 2+ wks", "Depressed feelings for 2+ wks"],
            "YOWRDEPR": ["Not sad or depressed most days", "Sad or depressed most days"],
            "YODPDISC": ["Mood not depressed overall", "Mood depressed overall (discrepancy)"],
            "YOLOSEV": ["Did not lose interest in activities", "Lost interest in activities"],
            "YOWRDCSN": ["Could make decisions", "Could not make decisions"],
            "YODSMMDE": ["No 2+ week depression episodes", "Had 2+ week depression episodes"],
            "YO_MDEA3": ["No appetite/weight changes", "Yes appetite/weight changes"],
            "YODPLSIN": ["Never bored/lost interest", "Often bored/lost interest"],
            "YOWRELES": ["Did not eat less", "Ate less than usual"],
            "YODSCEV": ["Fewer severe symptoms", "More severe symptoms"],
            "YOPB2WK": ["No uneasy feelings daily 2+ wks", "Uneasy feelings daily 2+ wks"],
            "YO_MDEA2": ["No issues physical/mental daily", "Issues physical/mental daily 2+ wks"],
        }

    def load_models(self):
        """Unpickle every file in ``model_filenames`` and return the models.

        Raises:
            OSError: If a model file cannot be opened.
        """
        models = []
        for fn in self.model_filenames:
            filepath = os.path.join(self.model_path, fn)
            # SECURITY: pickle.load executes arbitrary code embedded in the
            # file; only load model files that come from a trusted source.
            with open(filepath, "rb") as file:
                models.append(pickle.load(file))
        return models

    def make_predictions(self, user_input):
        """Return one flattened numpy array of predictions per loaded model."""
        return [np.array(m.predict(user_input)).flatten() for m in self.models]

    def get_majority_vote(self, predictions):
        """Return the most frequent predicted class across all models."""
        combined = np.concatenate(predictions)
        # np.bincount only accepts integer arrays; models may emit floats.
        return np.bincount(combined.astype(int)).argmax()

    def evaluate_severity(self, majority_vote_count):
        """Map the number of positive predictions to a severity message.

        Heuristic based on 16 total models:
        0-4 = Very Low, 5-8 = Low, 9-12 = Moderate, 13-16 = Severe.
        """
        if majority_vote_count >= 13:
            return "Mental health severity: Severe"
        if majority_vote_count >= 9:
            return "Mental health severity: Moderate"
        if majority_vote_count >= 5:
            return "Mental health severity: Low"
        return "Mental health severity: Very Low"


######################################
# 2) CONFIGURATIONS
######################################
model_filenames = [
    "YOWRCONC.pkl", "YOSEEDOC.pkl", "YO_MDEA5.pkl", "YOWRLSIN.pkl",
    "YODPPROB.pkl", "YOWRPROB.pkl", "YODPR2WK.pkl", "YOWRDEPR.pkl",
    "YODPDISC.pkl", "YOLOSEV.pkl", "YOWRDCSN.pkl", "YODSMMDE.pkl",
    "YO_MDEA3.pkl", "YODPLSIN.pkl", "YOWRELES.pkl", "YOPB2WK.pkl",
]
model_path = "models/"
predictor = ModelPredictor(model_path, model_filenames)

# Label columns shown under each heading of the textual result.
# NOTE: "YOWRELES" is listed twice; the first matching group wins, so it is
# reported under "Sleep_and_Energy_Levels".
_GROUP_MAP = {
    "Concentration_and_Decision_Making": ["YOWRCONC", "YOWRDCSN"],
    "Sleep_and_Energy_Levels": ["YOWRHRS", "YO_MDEA5", "YOWRELES", "YO_MDEA2"],
    "Mood_and_Emotional_State": ["YOWRCHR", "YOWRLSIN", "YOWRDEPR", "YODPDISC",
                                 "YOLOSEV", "YODPLSIN", "YODSCEV"],
    "Appetite_and_Weight_Changes": ["YO_MDEA3", "YOWRELES"],
    "Duration_and_Severity_of_Depression_Symptoms": ["YODPPROB", "YOWRPROB", "YODPR2WK",
                                                     "YODSMMDE", "YOPB2WK"],
}
# Heading for predictions whose label is in no group (e.g. "YOSEEDOC").
_UNGROUPED_KEY = "Other"

# Size of the sample distribution plot, matching its title.
_DIST_MAX_FEATURES = 4
_DIST_MAX_LABELS = 3


######################################
# 3) INPUT VALIDATION
######################################
def validate_inputs(*args):
    """Return True when no argument is ``None`` or the empty string."""
    # Just ensure all required (non-co-occurrence) fields are picked
    return all(arg is not None and arg != '' for arg in args)


######################################
# 4) PREDICTION FUNCTION
######################################
def _format_grouped_predictions(predictions):
    """Turn per-model predictions into one text block, grouped by topic."""
    results_by_group = {group: [] for group in _GROUP_MAP}
    results_by_group[_UNGROUPED_KEY] = []

    for filename, pred_array in zip(model_filenames, predictions):
        col_name = filename.split(".")[0]  # e.g., "YOWRCONC"
        val = pred_array[0]
        if col_name in predictor.prediction_map and val in (0, 1):
            # int() so a float prediction such as 1.0 can index the list.
            text = predictor.prediction_map[col_name][int(val)]
            out_line = f"{col_name}: {text}"
        else:
            out_line = f"{col_name}: Prediction={val}"

        # First group that lists the column wins; anything unlisted is kept
        # under the fallback heading instead of being dropped.
        group_key = next(
            (g_key for g_key, g_cols in _GROUP_MAP.items() if col_name in g_cols),
            _UNGROUPED_KEY,
        )
        results_by_group[group_key].append(out_line)

    grouped_output_lines = []
    for group_label, pred_lines in results_by_group.items():
        if pred_lines:
            grouped_output_lines.append(f"Group {group_label}:")
            grouped_output_lines.append("\n".join(pred_lines))
            grouped_output_lines.append("")

    if not grouped_output_lines:
        return "No predictions made. Check inputs."
    return "\n".join(grouped_output_lines).strip()


def _build_distribution_figure(feature_names):
    """Plot dataset counts for the first few features against the first few labels."""
    sample_feats = [f for f in feature_names[:_DIST_MAX_FEATURES] if f in df.columns]
    sample_labels = [
        lbl
        for lbl in (fn.split(".")[0] for fn in model_filenames[:_DIST_MAX_LABELS])
        if lbl in df.columns
    ]

    dist_segments = []
    for feat in sample_feats:
        for lbl in sample_labels:
            segment = df.groupby([feat, lbl]).size().reset_index(name="count")
            # Give every segment the same column names so that the
            # concatenated frame has one x column and one colour column.
            segment = segment.rename(columns={feat: "feature_value", lbl: "label_value"})
            segment["feature"] = feat
            segment["label"] = lbl
            dist_segments.append(segment)

    if not dist_segments:
        return px.bar(title="No distribution plot generated (columns not found).")

    big_dist_df = pd.concat(dist_segments, ignore_index=True)
    fig_dist = px.bar(
        big_dist_df,
        x="feature_value",
        y="count",
        color="label_value",
        facet_row="feature",
        facet_col="label",
        title="Sample Distribution Plot (first 4 features vs first 3 labels)",
    )
    fig_dist.update_layout(height=700)
    return fig_dist


def _nearest_neighbors_markdown(user_input, k=2):
    """Describe the k dataset rows with the smallest Hamming distance to the input."""
    columns_for_distance = list(user_input.columns)
    user_row = user_input.iloc[0]
    # Hamming distance: number of columns whose value differs from the input.
    distances = df[columns_for_distance].ne(user_row, axis="columns").sum(axis=1)
    nearest = distances.nsmallest(k)

    nn_md = [
        f"### Nearest Neighbors (K={k})",
        "(In a real application, you'd refine which features matter, how to encode them, etc.)\n",
    ]
    for row_id, distance in nearest.items():
        nn_md.append(f"- **Neighbor ID {row_id}**: distance={distance}")
    return "\n".join(nn_md)


def _build_co_occurrence_figure(feature1, feature2, label):
    """Plot counts of ``feature1`` x ``feature2`` x ``label``, or a placeholder figure."""
    chosen = (feature1, feature2, label)
    if any(col is None or col == "None" for col in chosen):
        return px.bar(title="No co-occurrence plot (choose two features + one label).")
    if any(col not in df.columns for col in chosen):
        return px.bar(title="One or more selected columns not found in dataframe.")
    if len(set(chosen)) < len(chosen):
        # Grouping by a repeated column makes reset_index raise ValueError.
        return px.bar(title="Co-occurrence plot needs three different columns.")

    co_data = df.groupby([feature1, feature2, label]).size().reset_index(name="count")
    return px.bar(
        co_data,
        x=feature1,
        y="count",
        color=label,
        facet_col=feature2,
        title=f"Co-occurrence: {feature1} & {feature2} vs {label}",
    )


def predict(
    # Original required features
    YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX, YMDEHPRX, YMDETXRX,
    YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY, YMDESUD5ANYO,
    YNURSMDE, YSOCMDE, YCOUNMDE, YPSY1MDE, YPSY2MDE, YHLTMDE, YDOCMDE, YTXMDEYR,
    YUSUITHKYR, YUSUIPLNYR, YUSUITHK, YUSUIPLN,
    MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR,
    # **New** optional picks for co-occurrence
    co_occ_feature1, co_occ_feature2, co_occ_label
):
    """Run every model on the given answers and build all eight outputs.

    The 29 feature arguments must be convertible with ``int()``. The three
    ``co_occ_*`` arguments name dataframe columns, or are ``None`` / "None"
    to skip the co-occurrence plot.

    Returns:
        An 8-tuple: prediction text, severity text, total-count markdown,
        distribution figure, nearest-neighbors markdown, co-occurrence
        figure, input-feature bar chart, predicted-label bar chart.
    """
    # 1) Build user_input for models. The key order is kept exactly as it
    # was because it becomes the column order handed to the models.
    user_input_data = {
        'YNURSMDE': [int(YNURSMDE)],
        'YMDEYR': [int(YMDEYR)],
        'YSOCMDE': [int(YSOCMDE)],
        'YMDESUD5ANYO': [int(YMDESUD5ANYO)],
        'YMSUD5YANY': [int(YMSUD5YANY)],
        'YUSUITHK': [int(YUSUITHK)],
        'YMDETXRX': [int(YMDETXRX)],
        'YUSUITHKYR': [int(YUSUITHKYR)],
        'YMDERSUD5ANY': [int(YMDERSUD5ANY)],
        'YUSUIPLNYR': [int(YUSUIPLNYR)],
        'YCOUNMDE': [int(YCOUNMDE)],
        'YPSY1MDE': [int(YPSY1MDE)],
        'YHLTMDE': [int(YHLTMDE)],
        'YDOCMDE': [int(YDOCMDE)],
        'YPSY2MDE': [int(YPSY2MDE)],
        'YMDEHARX': [int(YMDEHARX)],
        'LVLDIFMEM2': [int(LVLDIFMEM2)],
        'MDEIMPY': [int(MDEIMPY)],
        'YMDEHPO': [int(YMDEHPO)],
        'YMIMS5YANY': [int(YMIMS5YANY)],
        'YMDEIMAD5YR': [int(YMDEIMAD5YR)],
        'YMIUD5YANY': [int(YMIUD5YANY)],
        'YMDEHPRX': [int(YMDEHPRX)],
        'YMIMI5YANY': [int(YMIMI5YANY)],
        'YUSUIPLN': [int(YUSUIPLN)],
        'YTXMDEYR': [int(YTXMDEYR)],
        'YMDEAUD5YR': [int(YMDEAUD5YR)],
        'YRXMDEYR': [int(YRXMDEYR)],
        'YMDELT': [int(YMDELT)],
    }
    user_input = pd.DataFrame(user_input_data)

    # 2) Model Predictions
    predictions = predictor.make_predictions(user_input)
    majority_vote_count = np.sum(np.concatenate(predictions) == 1)
    severity = predictor.evaluate_severity(majority_vote_count)

    # 3) Summarize textual results
    final_result_text = _format_grouped_predictions(predictions)

    # 4) Additional Features
    # A) Total patient count
    total_patients = len(df)
    total_count_md = (
        "### Total Patient Count\n"
        f"**{total_patients}** total patients in the dataset."
    )

    # B) Bar chart of how many have same inputs
    input_counts = {
        col: int((df[col] == values[0]).sum())
        for col, values in user_input_data.items()
    }
    df_input_counts = pd.DataFrame({
        "Feature": list(input_counts.keys()),
        "Count": list(input_counts.values()),
    })
    fig_input_bar = px.bar(
        df_input_counts,
        x="Feature",
        y="Count",
        title="Number of Patients with the Same Value for Each Input Feature",
    )
    fig_input_bar.update_layout(xaxis={"categoryorder": "total descending"})

    # C) Bar chart for predicted labels
    label_counts = {}
    for filename, pred_array in zip(model_filenames, predictions):
        col_name = filename.split(".")[0]
        val = pred_array[0]
        if val in (0, 1):
            label_counts[col_name] = int((df[col_name] == val).sum())

    if label_counts:
        df_label_counts = pd.DataFrame({
            "Label Column": list(label_counts.keys()),
            "Count": list(label_counts.values()),
        })
        fig_label_bar = px.bar(
            df_label_counts,
            x="Label Column",
            y="Count",
            title="Number of Patients with the Same Predicted Label",
        )
    else:
        fig_label_bar = px.bar(title="No valid predicted labels to display")

    # D) Simple Distribution Plot (demo for first 3 labels & 4 inputs)
    fig_dist = _build_distribution_figure(list(user_input_data))

    # E) Nearest Neighbors with K=2
    nn_md_str = _nearest_neighbors_markdown(user_input, k=2)

    # F) Co-occurrence Plot for user-chosen feature1, feature2, label
    fig_co_occ = _build_co_occurrence_figure(co_occ_feature1, co_occ_feature2, co_occ_label)

    # Return all 8 outputs
    return (
        final_result_text,  # (1) Predictions
        severity,           # (2) Severity
        total_count_md,     # (3) Total patient count
        fig_dist,           # (4) Distribution Plot
        nn_md_str,          # (5) Nearest Neighbors
        fig_co_occ,         # (6) Co-occurrence
        fig_input_bar,      # (7) Bar Chart (input features)
        fig_label_bar,      # (8) Bar Chart (labels)
    )


######################################
# 5) MAPPING (user -> int)
######################################
input_mapping = {
    'YNURSMDE': {"Yes": 1, "No": 0},
    'YMDEYR': {"Yes": 1, "No": 2},
    'YSOCMDE': {"Yes": 1, "No": 0},
    'YMDESUD5ANYO': {"SUD only": 1, "MDE only": 2, "SUD & MDE": 3, "Neither": 4},
    'YMSUD5YANY': {"Yes": 1, "No": 0},
    'YUSUITHK': {"Yes": 1, "No": 2, "Unsure": 3, "Don't want to answer": 4},
    'YMDETXRX': {"Yes": 1, "No": 0},
    'YUSUITHKYR': {"Yes": 1, "No": 2, "Unsure": 3, "Don't want to answer": 4},
    'YMDERSUD5ANY': {"Yes": 1, "No": 0},
    'YUSUIPLNYR': {"Yes": 1, "No": 2, "Unsure": 3, "Don't want to answer": 4},
    'YCOUNMDE': {"Yes": 1, "No": 0},
    'YPSY1MDE': {"Yes": 1, "No": 0},
    'YHLTMDE': {"Yes": 1, "No": 0},
    'YDOCMDE': {"Yes": 1, "No": 0},
    'YPSY2MDE': {"Yes": 1, "No": 0},
    'YMDEHARX': {"Yes": 1, "No": 0},
    'LVLDIFMEM2': {"No Difficulty": 1, "Some Difficulty": 2, "A lot or cannot do": 3},
    'MDEIMPY': {"Yes": 1, "No": 2},
    'YMDEHPO': {"Yes": 1, "No": 0},
    'YMIMS5YANY': {"Yes": 1, "No": 0},
    'YMDEIMAD5YR': {"Yes": 1, "No": 0},
    'YMIUD5YANY': {"Yes": 1, "No": 0},
    'YMDEHPRX': {"Yes": 1, "No": 0},
    'YMIMI5YANY': {"Yes": 1, "No": 0},
    'YUSUIPLN': {"Yes": 1, "No": 2, "Unsure": 3, "Don't want to answer": 4},
    'YTXMDEYR': {"Yes": 1, "No": 0},
    'YMDEAUD5YR': {"Yes": 1, "No": 0},
    'YRXMDEYR': {"Yes": 1, "No": 0},
    'YMDELT': {"Yes": 1, "No": 2},
}


######################################
# 6) THE GRADIO INTERFACE
######################################
# (A) The original required inputs, as (feature name, dropdown label) pairs.
# The order must match the parameter order of predict_with_text because
# Gradio passes the widget values positionally.
_REQUIRED_INPUT_LABELS = [
    ('YMDEYR', "YMDEYR: Past Year MDE?"),
    ('YMDERSUD5ANY', "YMDERSUD5ANY: MDE or SUD - ANY?"),
    ('YMDEIMAD5YR', "YMDEIMAD5YR: MDE + ALCOHOL?"),
    ('YMIMS5YANY', "YMIMS5YANY: MDE + SUBSTANCE?"),
    ('YMDELT', "YMDELT: MDE in Lifetime?"),
    ('YMDEHARX', "YMDEHARX: Saw Health Prof + Meds?"),
    ('YMDEHPRX', "YMDEHPRX: Saw Health Prof or Meds?"),
    ('YMDETXRX', "YMDETXRX: Received Treatment?"),
    ('YMDEHPO', "YMDEHPO: Saw Health Prof Only?"),
    ('YMDEAUD5YR', "YMDEAUD5YR: MDE + Alcohol Use?"),
    ('YMIMI5YANY', "YMIMI5YANY: MDE + ILL Drug Use?"),
    ('YMIUD5YANY', "YMIUD5YANY: MDE + ILL Drug Use?"),
    ('YMDESUD5ANYO', "YMDESUD5ANYO: MDE vs SUD vs BOTH vs NEITHER"),
    # Consultations
    ('YNURSMDE', "YNURSMDE: Nurse/OT about MDE?"),
    ('YSOCMDE', "YSOCMDE: Social Worker?"),
    ('YCOUNMDE', "YCOUNMDE: Counselor?"),
    ('YPSY1MDE', "YPSY1MDE: Psychologist?"),
    ('YPSY2MDE', "YPSY2MDE: Psychiatrist?"),
    ('YHLTMDE', "YHLTMDE: Health Prof?"),
    ('YDOCMDE', "YDOCMDE: GP/Family MD?"),
    ('YTXMDEYR', "YTXMDEYR: Doctor/Health Prof?"),
    # Suicidal
    ('YUSUITHKYR', "YUSUITHKYR: Serious Suicide Thoughts?"),
    ('YUSUIPLNYR', "YUSUIPLNYR: Made Plans?"),
    ('YUSUITHK', "YUSUITHK: Suicide Thoughts (12 mo)?"),
    ('YUSUIPLN', "YUSUIPLN: Made Plans (12 mo)?"),
    # Impairments
    ('MDEIMPY', "MDEIMPY: Severe Role Impairment?"),
    ('LVLDIFMEM2', "LVLDIFMEM2: Difficulty Remembering/Concentrating?"),
    ('YMSUD5YANY', "YMSUD5YANY: MDE + Substance?"),
    ('YRXMDEYR', "YRXMDEYR: Used Meds for MDE (12 mo)?"),
]
original_inputs = [
    gr.Dropdown(list(input_mapping[name].keys()), label=label)
    for name, label in _REQUIRED_INPUT_LABELS
]

# (B) The new co-occurrence inputs
# We'll give them defaults of "None" to indicate no selection.
all_cols = ["None"] + df.columns.tolist()  # 'None' plus the actual columns from your df
co_occ_feature1 = gr.Dropdown(all_cols, label="Co-Occ Feature 1", value="None")
co_occ_feature2 = gr.Dropdown(all_cols, label="Co-Occ Feature 2", value="None")
all_label_cols = ["None"] + list(predictor.prediction_map.keys())  # e.g., "YOWRCONC", "YOWRHRS", ...
co_occ_label = gr.Dropdown(all_label_cols, label="Co-Occ Label", value="None")

# Combine them into a single input list
inputs = original_inputs + [co_occ_feature1, co_occ_feature2, co_occ_label]

# 8 outputs as before
outputs = [
    gr.Textbox(label="Prediction Results", lines=15),
    gr.Textbox(label="Mental Health Severity", lines=2),
    gr.Markdown(label="Total Patient Count"),
    gr.Plot(label="Distribution Plot (Sample)"),
    gr.Markdown(label="Nearest Neighbors (K=2)"),
    gr.Plot(label="Co-occurrence Plot"),
    gr.Plot(label="Same Value Bar (Inputs)"),
    gr.Plot(label="Predicted Label Bar"),
]


######################################
# 7) WRAPPER
######################################
def predict_with_text(
    # match the function signature exactly (29 required + 3 for co-occ)
    YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX, YMDEHPRX, YMDETXRX,
    YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY, YMDESUD5ANYO,
    YNURSMDE, YSOCMDE, YCOUNMDE, YPSY1MDE, YPSY2MDE, YHLTMDE, YDOCMDE, YTXMDEYR,
    YUSUITHKYR, YUSUIPLNYR, YUSUITHK, YUSUIPLN,
    MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR,
    co_occ_feature1, co_occ_feature2, co_occ_label
):
    """Validate the dropdown choices, map them to integers and call ``predict``.

    Returns:
        The 8-tuple produced by ``predict``, or an 8-tuple of placeholder
        values when a required field is missing.
    """
    choices = {
        'YMDEYR': YMDEYR, 'YMDERSUD5ANY': YMDERSUD5ANY, 'YMDEIMAD5YR': YMDEIMAD5YR,
        'YMIMS5YANY': YMIMS5YANY, 'YMDELT': YMDELT, 'YMDEHARX': YMDEHARX,
        'YMDEHPRX': YMDEHPRX, 'YMDETXRX': YMDETXRX, 'YMDEHPO': YMDEHPO,
        'YMDEAUD5YR': YMDEAUD5YR, 'YMIMI5YANY': YMIMI5YANY, 'YMIUD5YANY': YMIUD5YANY,
        'YMDESUD5ANYO': YMDESUD5ANYO, 'YNURSMDE': YNURSMDE, 'YSOCMDE': YSOCMDE,
        'YCOUNMDE': YCOUNMDE, 'YPSY1MDE': YPSY1MDE, 'YPSY2MDE': YPSY2MDE,
        'YHLTMDE': YHLTMDE, 'YDOCMDE': YDOCMDE, 'YTXMDEYR': YTXMDEYR,
        'YUSUITHKYR': YUSUITHKYR, 'YUSUIPLNYR': YUSUIPLNYR, 'YUSUITHK': YUSUITHK,
        'YUSUIPLN': YUSUIPLN, 'MDEIMPY': MDEIMPY, 'LVLDIFMEM2': LVLDIFMEM2,
        'YMSUD5YANY': YMSUD5YANY, 'YRXMDEYR': YRXMDEYR,
    }

    # Validate the original 29 fields
    if not validate_inputs(*choices.values()):
        return (
            "Please select all required fields.",
            "Validation Error",
            "No data",
            None,
            "No data",
            None,
            None,
            None,
        )

    # Map to numeric
    user_inputs = {
        name: input_mapping[name][choice] for name, choice in choices.items()
    }

    # Call the core predict function with the co-occ choices as well
    return predict(
        **user_inputs,
        co_occ_feature1=co_occ_feature1,
        co_occ_feature2=co_occ_feature2,
        co_occ_label=co_occ_label,
    )


custom_css = """
.gradio-container * {
    color: #1B1212 !important;
}
"""

# Kept flush-left so Markdown does not render the text as a code block.
_DESCRIPTION = """
**Instructions**:
1. Fill out all required fields regarding MDE/Substance Use/Consultations/Suicidal/Impairments.
2. (Optional) Choose 2 features and 1 label for the *Co-occurrence* plot.
   - If you do not select them (or leave them as "None"), that plot will be skipped.
3. Click "Submit" to get predictions, severity, distribution plots, nearest neighbors, and your custom co-occurrence chart.
"""

interface = gr.Interface(
    fn=predict_with_text,
    inputs=inputs,
    outputs=outputs,
    title="Mental Health Screening (NSDUH) with Selective Co-Occurrence",
    css=custom_css,
    description=_DESCRIPTION,
)

if __name__ == "__main__":
    interface.launch()