"""Gradio app that runs a set of pickled binary classifiers on survey answers.

The app loads a CSV of training data plus one pickled model per label column,
asks the user for the input features through dropdowns, and shows the
per-label predictions, an overall severity bucket and a few summary plots.
"""

import os
import pickle

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px

######################################
# 1) Load Data & Prepare
######################################
df = pd.read_csv("X_train_Y_Train_merged_train.csv")

# List of model filenames (adjust if needed). The stem of each filename is
# also used as the name of the label column that the model predicts.
model_filenames = [
    "YOWRCONC.pkl", "YOSEEDOC.pkl", "YO_MDEA5.pkl", "YOWRLSIN.pkl",
    "YODPPROB.pkl", "YOWRPROB.pkl", "YODPR2WK.pkl", "YOWRDEPR.pkl",
    "YODPDISC.pkl", "YOLOSEV.pkl", "YOWRDCSN.pkl", "YODSMMDE.pkl",
    "YO_MDEA3.pkl", "YODPLSIN.pkl", "YOWRELES.pkl", "YOPB2WK.pkl"
]
model_path = "models/"


######################################
# 2) Model Predictor
######################################
class ModelPredictor:
    """Load a list of pickled models and run them on one row of user input."""

    def __init__(self, model_path, model_filenames):
        """Store the model location and eagerly load every model.

        Args:
            model_path: Directory that contains the pickled models.
            model_filenames: Filenames (inside ``model_path``) to load, in the
                order in which predictions should be reported.
        """
        self.model_path = model_path
        self.model_filenames = model_filenames
        self.models = self.load_models()

        # Mapping from label column to human-readable strings for 0/1
        self.prediction_map = {
            "YOWRCONC": ["Did not have difficulty concentrating", "Had difficulty concentrating"],
            "YOSEEDOC": ["Did not feel the need to see a doctor", "Felt the need to see a doctor"],
            "YOWRHRS": ["Did not have trouble sleeping", "Had trouble sleeping"],
            "YO_MDEA5": ["Others did not notice restlessness/lethargy", "Others noticed restlessness/lethargy"],
            "YOWRCHR": ["Did not feel so sad", "Felt so sad nothing could cheer up"],
            "YOWRLSIN": ["Did not feel bored and lose interest", "Felt bored and lost interest"],
            "YODPPROB": ["No other problems for 2+ weeks", "Had other problems for 2+ weeks"],
            "YOWRPROB": ["Did not have the worst time ever feeling", "Had the worst time ever feeling"],
            "YODPR2WK": ["No periods of 2+ weeks feelings", "Had periods of 2+ weeks feelings"],
            "YOWRDEPR": ["Did not feel depressed mostly everyday", "Felt depressed mostly everyday"],
            "YODPDISC": ["Mood not depressed overall", "Mood depressed overall (discrepancy)"],
            "YOLOSEV": ["Did not lose interest in enjoyable things", "Lost interest in enjoyable things"],
            "YOWRDCSN": ["Was able to make decisions", "Was unable to make decisions"],
            "YODSMMDE": ["Never had depression for 2+ weeks", "Had depression for 2+ weeks"],
            "YO_MDEA3": ["No appetite/weight changes", "Had appetite/weight changes"],
            "YODPLSIN": ["Never bored/lost interest", "Felt bored/lost interest"],
            "YOWRELES": ["Did not eat less than usual", "Ate less than usual"],
            "YODSCEV": ["Fewer severe symptoms", "More severe symptoms"],
            "YOPB2WK": ["No uneasy feelings 2+ weeks", "Had uneasy feelings 2+ weeks"],
            "YO_MDEA2": ["No issues w/ physical/mental well-being", "Issues w/ physical/mental well-being"]
        }

    def load_models(self):
        """Unpickle every file in ``self.model_filenames`` and return the list.

        Returns:
            The loaded model objects, in the same order as the filenames.

        Raises:
            OSError: If a model file cannot be opened.
        """
        models = []
        # Use the filenames given to this instance, not the module-level list.
        for filename in self.model_filenames:
            filepath = os.path.join(self.model_path, filename)
            with open(filepath, 'rb') as file:
                # SECURITY: pickle executes arbitrary code while loading.
                # Only ship model files that come from a trusted source.
                models.append(pickle.load(file))
        return models

    def make_predictions(self, user_input):
        """Run every loaded model on ``user_input``.

        Returns a list of 1-D numpy arrays; the i-th array is the flattened
        output of the i-th model in ``self.models``.
        """
        return [
            np.asarray(model.predict(user_input)).flatten()
            for model in self.models
        ]

    def get_majority_vote(self, predictions):
        """Return the most frequent class across all models' predictions.

        All prediction arrays are concatenated and cast to int (``bincount``
        rejects float input) before counting.
        """
        combined = np.concatenate(predictions).astype(int)
        return int(np.bincount(combined).argmax())

    def evaluate_severity(self, majority_vote_count):
        """Bucket the number of positive predictions into a severity label.

        Thresholds: 0-4 => Very Low, 5-8 => Low, 9-12 => Moderate,
        13 or more => Severe.
        """
        if majority_vote_count >= 13:
            return "Mental Health Severity: Severe"
        if majority_vote_count >= 9:
            return "Mental Health Severity: Moderate"
        if majority_vote_count >= 5:
            return "Mental Health Severity: Low"
        return "Mental Health Severity: Very Low"


######################################
# 3) Validate Inputs
######################################
def validate_inputs(*args):
    """Return True only when no argument is ``None`` or the empty string."""
    return all(arg is not None and arg != '' for arg in args)


######################################
# 4) Core Prediction
######################################
predictor = ModelPredictor(model_path, model_filenames)

# Display groups for the textual results. A label is reported under the first
# group that lists it; labels that appear in no group are left out of the text
# (they still count towards the severity score).
_LABEL_GROUPS = {
    "Concentration_and_Decision_Making": ["YOWRCONC", "YOWRDCSN"],
    "Sleep_and_Energy_Levels": ["YOWRHRS", "YO_MDEA5", "YOWRELES", "YO_MDEA2"],
    "Mood_and_Emotional_State": ["YOWRCHR", "YOWRLSIN", "YOWRDEPR", "YODPDISC", "YOLOSEV", "YODPLSIN", "YODSCEV"],
    "Appetite_and_Weight_Changes": ["YO_MDEA3", "YOWRELES"],
    "Duration_and_Severity_of_Depression_Symptoms": ["YODPPROB", "YOWRPROB", "YODPR2WK", "YODSMMDE", "YOPB2WK"]
}


def _label_name(filename):
    """Return the label column name encoded in a model filename."""
    return filename.split('.')[0]


def _encode_input(name, value):
    """Convert one dropdown value into the integer code the models expect.

    The dropdowns emit the human-readable labels (for example "Yes"), so the
    label is translated through ``input_mapping``. Values that are not a known
    label are assumed to be numeric codes already and are passed to ``int``.
    """
    choices = input_mapping.get(name, {})
    if value in choices:
        return choices[value]
    return int(value)


def _format_prediction_text(predictions):
    """Build the grouped, human-readable summary of all predictions."""
    grouped_text = {group: [] for group in _LABEL_GROUPS}
    for filename, arr in zip(predictor.model_filenames, predictions):
        col_name = _label_name(filename)
        pred_val = arr[0]
        if col_name in predictor.prediction_map and pred_val in (0, 1):
            # int() so that a float prediction such as 1.0 can index the list.
            text_val = predictor.prediction_map[col_name][int(pred_val)]
        else:
            text_val = f"Prediction={pred_val}"
        for gname, gcols in _LABEL_GROUPS.items():
            if col_name in gcols:
                grouped_text[gname].append(f"{col_name} => {text_val}")
                break

    parts = []
    for gname, items in grouped_text.items():
        if items:
            parts.append(f"**{gname.replace('_',' ')}**")
            parts.append("\n".join(items))
            parts.append("\n")
    text = "\n".join(parts).strip()
    return text or "No predictions made. Please check inputs."


def _count_matching(column, value):
    """Return how many dataset rows hold ``value`` in ``column``."""
    return int((df[column] == value).sum())


def _input_feature_bar(user_input_data):
    """Bar chart: number of dataset rows sharing each submitted input value."""
    counts = {
        col: _count_matching(col, values[0])
        for col, values in user_input_data.items()
        if col in df.columns  # skip features missing from the CSV
    }
    bar_df = pd.DataFrame({"Feature": list(counts.keys()),
                           "Count": list(counts.values())})
    fig = px.bar(
        bar_df,
        x="Feature",
        y="Count",
        title="Number of Patients with Same Input Feature Values"
    )
    fig.update_layout(width=800, height=500)
    return fig


def _predicted_label_bar(predictions):
    """Bar chart: number of dataset rows sharing each predicted label value."""
    counts = {}
    for filename, arr in zip(predictor.model_filenames, predictions):
        lbl_col = _label_name(filename)
        pred_val = arr[0]
        if pred_val in (0, 1) and lbl_col in df.columns:
            counts[lbl_col] = _count_matching(lbl_col, pred_val)

    if counts:
        bar_df = pd.DataFrame({"Label": list(counts.keys()),
                               "Count": list(counts.values())})
        fig = px.bar(bar_df, x="Label", y="Count",
                     title="Number of Patients with the Same Predicted Label")
    else:
        fig = px.bar(title="No valid predicted labels to display.")
    fig.update_layout(width=800, height=500)
    return fig


def _distribution_plot(feature_names):
    """Faceted bar chart of dataset counts for a small feature/label sample.

    Uses the first four entries of ``feature_names`` and the first three label
    columns. Each (feature, label) pair becomes one facet.
    """
    subset_input_cols = list(feature_names)[:4]
    subset_labels = [_label_name(fn) for fn in predictor.model_filenames[:3]]

    dist_rows = []
    for feat in subset_input_cols:
        if feat not in df.columns:
            continue
        for label_col in subset_labels:
            if label_col not in df.columns:
                continue
            tmp = df.groupby([feat, label_col]).size().reset_index(name="count")
            # Give every pair the same column names; otherwise the concat
            # below yields one column per feature and all but the first
            # facet would plot NaN values.
            tmp = tmp.rename(columns={feat: "feature_value",
                                      label_col: "label_value"})
            tmp["feature"] = feat
            tmp["label"] = label_col
            dist_rows.append(tmp)

    if not dist_rows:
        return px.bar(title="Distribution plot not generated.")

    big_dist_df = pd.concat(dist_rows, ignore_index=True)
    fig = px.bar(
        big_dist_df,
        x="feature_value",
        y="count",
        color="label_value",
        facet_row="feature",
        facet_col="label",
        title="Distribution of Sample Input Features vs. Sample Predicted Labels"
    )
    fig.update_layout(width=1000, height=700)
    return fig


def predict(
    YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
    YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
    YMDESUD5ANYO, YNURSMDE, YSOCMDE, YCOUNMDE, YPSY1MDE, YPSY2MDE,
    YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
    YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
):
    """Run all models on one set of answers and assemble the UI outputs.

    Each argument is the value of the dropdown with the same name: either the
    selected label (e.g. "Yes") or an already-encoded integer code.

    Returns:
        An 8-tuple, in the order of the Gradio output components:
        prediction text, severity text, patient-count markdown, distribution
        figure, nearest-neighbours markdown, co-occurrence placeholder
        (always ``None``), input-feature bar chart, predicted-label bar chart.
        When a field is missing, a validation message tuple of the same
        length is returned instead.

    Raises:
        ValueError: If a value is neither a known label nor convertible to int.
    """
    # Keys are listed in the column order used to build the model input frame
    # (presumably the order the models were trained with - keep it stable).
    raw_inputs = {
        'YNURSMDE': YNURSMDE,
        'YMDEYR': YMDEYR,
        'YSOCMDE': YSOCMDE,
        'YMDESUD5ANYO': YMDESUD5ANYO,
        'YMSUD5YANY': YMSUD5YANY,
        'YUSUITHK': YUSUITHK,
        'YMDETXRX': YMDETXRX,
        'YUSUITHKYR': YUSUITHKYR,
        'YMDERSUD5ANY': YMDERSUD5ANY,
        'YUSUIPLNYR': YUSUIPLNYR,
        'YCOUNMDE': YCOUNMDE,
        'YPSY1MDE': YPSY1MDE,
        'YHLTMDE': YHLTMDE,
        'YDOCMDE': YDOCMDE,
        'YPSY2MDE': YPSY2MDE,
        'YMDEHARX': YMDEHARX,
        'LVLDIFMEM2': LVLDIFMEM2,
        'MDEIMPY': MDEIMPY,
        'YMDEHPO': YMDEHPO,
        'YMIMS5YANY': YMIMS5YANY,
        'YMDEIMAD5YR': YMDEIMAD5YR,
        'YMIUD5YANY': YMIUD5YANY,
        'YMDEHPRX': YMDEHPRX,
        'YMIMI5YANY': YMIMI5YANY,
        'YUSUIPLN': YUSUIPLN,
        'YTXMDEYR': YTXMDEYR,
        'YMDEAUD5YR': YMDEAUD5YR,
        'YRXMDEYR': YRXMDEYR,
        'YMDELT': YMDELT
    }

    # Validate
    if not validate_inputs(*raw_inputs.values()):
        return (
            "Please select all required fields.",
            "Validation Error",
            "No data",
            None,
            "No data",
            None,
            None,
            None
        )

    # Build dataframe from user inputs (labels -> integer codes)
    user_input_data = {
        name: [_encode_input(name, value)]
        for name, value in raw_inputs.items()
    }
    user_input = pd.DataFrame(user_input_data)

    # 1) Predictions
    predictions = predictor.make_predictions(user_input)

    # 2) Count of '1's across all models
    num_ones = int(np.sum(np.concatenate(predictions) == 1))

    # 3) Severity
    severity = predictor.evaluate_severity(num_ones)

    # 4) Group textual results
    final_str = _format_prediction_text(predictions)

    # Additional info
    total_patients = len(df)
    total_patient_markdown = (
        f"### Total Patient Count\nThere are **{total_patients}** patients in the dataset."
    )

    # A) Bar chart for input features
    fig_bar_input = _input_feature_bar(user_input_data)

    # B) Bar chart for predicted labels
    fig_bar_labels = _predicted_label_bar(predictions)

    # C) Distribution Plot (small sample)
    fig_dist = _distribution_plot(user_input_data.keys())

    # D) Nearest neighbors (placeholder or your own logic)
    nearest_neighbors_markdown = "Nearest neighbors omitted or placed here if needed..."

    # We won't produce a co-occurrence plot by default here, so set to None
    co_occurrence_placeholder = None

    # Return the 8 outputs
    return (
        final_str,                   # 1) Prediction Results
        severity,                    # 2) Mental Health Severity
        total_patient_markdown,      # 3) Total Patient Count
        fig_dist,                    # 4) Distribution Plot
        nearest_neighbors_markdown,  # 5) Nearest Neighbors
        co_occurrence_placeholder,   # 6) Co-occurrence Plot placeholder
        fig_bar_input,               # 7) Bar Chart for input features
        fig_bar_labels               # 8) Bar Chart for predicted labels
    )


######################################
# 5) Input Mapping
######################################
# Dropdown label -> integer code, per input feature. Used both to populate
# the dropdown choices and to encode the selected label inside predict().
input_mapping = {
    'YNURSMDE': {"Yes": 1, "No": 0},
    'YMDEYR': {"Yes": 1, "No": 2},
    'YSOCMDE': {"Yes": 1, "No": 0},
    'YMDESUD5ANYO': {"SUD only, no MDE": 1, "MDE only, no SUD": 2, "SUD and MDE": 3, "Neither SUD or MDE": 4},
    'YMSUD5YANY': {"Yes": 1, "No": 0},
    'YUSUITHK': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
    'YMDETXRX': {"Yes": 1, "No": 0},
    'YUSUITHKYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
    'YMDERSUD5ANY': {"Yes": 1, "No": 0},
    'YUSUIPLNYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
    'YCOUNMDE': {"Yes": 1, "No": 0},
    'YPSY1MDE': {"Yes": 1, "No": 0},
    'YHLTMDE': {"Yes": 1, "No": 0},
    'YDOCMDE': {"Yes": 1, "No": 0},
    'YPSY2MDE': {"Yes": 1, "No": 0},
    'YMDEHARX': {"Yes": 1, "No": 0},
    'LVLDIFMEM2': {"No Difficulty": 1, "Some difficulty": 2, "A lot of difficulty or cannot do at all": 3},
    'MDEIMPY': {"Yes": 1, "No": 2},
    'YMDEHPO': {"Yes": 1, "No": 0},
    'YMIMS5YANY': {"Yes": 1, "No": 0},
    'YMDEIMAD5YR': {"Yes": 1, "No": 0},
    'YMIUD5YANY': {"Yes": 1, "No": 0},
    'YMDEHPRX': {"Yes": 1, "No": 0},
    'YMIMI5YANY': {"Yes": 1, "No": 0},
    'YUSUIPLN': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
    'YTXMDEYR': {"Yes": 1, "No": 0},
    'YMDEAUD5YR': {"Yes": 1, "No": 0},
    'YRXMDEYR': {"Yes": 1, "No": 0},
    'YMDELT': {"Yes": 1, "No": 2}
}


######################################
# 6) Co-Occurrence Function
######################################
def co_occurrence_plot(feature1, feature2, label_col):
    """
    Generate a single co-occurrence bar chart grouping by
    [feature1, feature2, label_col].

    Returns an empty, titled figure instead of raising when a selection is
    missing, repeated, or not a column of the dataset.
    """
    if not feature1 or not feature2 or not label_col:
        return px.bar(title="Please select all three fields.")
    selected = (feature1, feature2, label_col)
    if any(col not in df.columns for col in selected):
        return px.bar(title="Selected columns not found in the dataset.")
    # Grouping by the same column twice makes reset_index() fail because the
    # column name would be inserted more than once.
    if len(set(selected)) < len(selected):
        return px.bar(title="Please select three different columns.")

    grouped_df = df.groupby([feature1, feature2, label_col]).size().reset_index(name="count")
    fig = px.bar(
        grouped_df,
        x=feature1,
        y="count",
        color=label_col,
        facet_col=feature2,
        title=f"Co-Occurrence Plot: {feature1} & {feature2} vs. {label_col}"
    )
    fig.update_layout(width=1000, height=600)
    return fig


######################################
# 7) Gradio Interface with Tabs
######################################
def _dropdown(name):
    """Create the dropdown for input feature ``name`` from ``input_mapping``."""
    return gr.Dropdown(list(input_mapping[name].keys()), label=name)


with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
    with gr.Tab("Prediction"):
        # --------- INPUT FIELDS --------- #
        YMDEYR_dd = _dropdown('YMDEYR')
        YMDERSUD5ANY_dd = _dropdown('YMDERSUD5ANY')
        YMDEIMAD5YR_dd = _dropdown('YMDEIMAD5YR')
        YMIMS5YANY_dd = _dropdown('YMIMS5YANY')
        YMDELT_dd = _dropdown('YMDELT')
        YMDEHARX_dd = _dropdown('YMDEHARX')
        YMDEHPRX_dd = _dropdown('YMDEHPRX')
        YMDETXRX_dd = _dropdown('YMDETXRX')
        YMDEHPO_dd = _dropdown('YMDEHPO')
        YMDEAUD5YR_dd = _dropdown('YMDEAUD5YR')
        YMIMI5YANY_dd = _dropdown('YMIMI5YANY')
        YMIUD5YANY_dd = _dropdown('YMIUD5YANY')
        YMDESUD5ANYO_dd = _dropdown('YMDESUD5ANYO')

        # Consultations
        YNURSMDE_dd = _dropdown('YNURSMDE')
        YSOCMDE_dd = _dropdown('YSOCMDE')
        YCOUNMDE_dd = _dropdown('YCOUNMDE')
        YPSY1MDE_dd = _dropdown('YPSY1MDE')
        YPSY2MDE_dd = _dropdown('YPSY2MDE')
        YHLTMDE_dd = _dropdown('YHLTMDE')
        YDOCMDE_dd = _dropdown('YDOCMDE')
        YTXMDEYR_dd = _dropdown('YTXMDEYR')

        # Suicidal thoughts/plans
        YUSUITHKYR_dd = _dropdown('YUSUITHKYR')
        YUSUIPLNYR_dd = _dropdown('YUSUIPLNYR')
        YUSUITHK_dd = _dropdown('YUSUITHK')
        YUSUIPLN_dd = _dropdown('YUSUIPLN')

        # Impairments
        MDEIMPY_dd = _dropdown('MDEIMPY')
        LVLDIFMEM2_dd = _dropdown('LVLDIFMEM2')
        YMSUD5YANY_dd = _dropdown('YMSUD5YANY')
        YRXMDEYR_dd = _dropdown('YRXMDEYR')

        # --------- PREDICT BUTTON (BEFORE OUTPUTS) --------- #
        predict_btn = gr.Button("Predict")

        # --------- OUTPUTS (IN THE SAME ORDER AS THE RETURN TUPLE) --------- #
        out_pred_res = gr.Textbox(label="Prediction Results", lines=8)
        out_sev = gr.Textbox(label="Mental Health Severity", lines=2)
        out_count = gr.Markdown(label="Total Patient Count")
        out_distplot = gr.Plot(label="Distribution Plot")
        out_nn = gr.Markdown(label="Nearest Neighbors Summary")
        out_cooc = gr.Plot(label="Co-occurrence Plot Placeholder")
        out_bar_input = gr.Plot(label="Input Feature Counts")
        out_bar_labels = gr.Plot(label="Predicted Label Counts")

        # Link button to the function. The inputs must stay in the same
        # order as the parameters of predict().
        predict_btn.click(
            fn=predict,
            inputs=[
                YMDEYR_dd, YMDERSUD5ANY_dd, YMDEIMAD5YR_dd, YMIMS5YANY_dd,
                YMDELT_dd, YMDEHARX_dd, YMDEHPRX_dd, YMDETXRX_dd,
                YMDEHPO_dd, YMDEAUD5YR_dd, YMIMI5YANY_dd, YMIUD5YANY_dd,
                YMDESUD5ANYO_dd, YNURSMDE_dd, YSOCMDE_dd, YCOUNMDE_dd,
                YPSY1MDE_dd, YPSY2MDE_dd, YHLTMDE_dd, YDOCMDE_dd,
                YTXMDEYR_dd, YUSUITHKYR_dd, YUSUIPLNYR_dd, YUSUITHK_dd,
                YUSUIPLN_dd, MDEIMPY_dd, LVLDIFMEM2_dd, YMSUD5YANY_dd,
                YRXMDEYR_dd
            ],
            outputs=[
                out_pred_res,
                out_sev,
                out_count,
                out_distplot,
                out_nn,
                out_cooc,
                out_bar_input,
                out_bar_labels
            ]
        )

    # ------------- SECOND TAB (CO-OCCURRENCE) -------------
    with gr.Tab("Co-occurrence"):
        gr.Markdown("## Generate a Co-Occurrence Plot on Demand\nSelect two features and one label:")
        with gr.Row():
            feature1_dd = gr.Dropdown(sorted(df.columns), label="Feature 1")
            feature2_dd = gr.Dropdown(sorted(df.columns), label="Feature 2")
            label_dd = gr.Dropdown(sorted(df.columns), label="Label Column")
        out_co_occ_plot = gr.Plot(label="Co-occurrence Plot")
        co_occ_btn = gr.Button("Generate Plot")
        co_occ_btn.click(
            fn=co_occurrence_plot,
            inputs=[feature1_dd, feature2_dd, label_dd],
            outputs=out_co_occ_plot
        )

# Optionally, you can customize your CSS or server launch parameters
demo.launch(server_name="0.0.0.0", server_port=7860)