import pickle

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px

######################################
# 1) LOAD DATA & MODELS
######################################

# Load your dataset
df = pd.read_csv("X_train_test_combined_dataset_Filtered_dataset.csv")

# Ensure 'YMDESUD5ANYO' exists in your DataFrame
if 'YMDESUD5ANYO' not in df.columns:
    raise ValueError("The column 'YMDESUD5ANYO' is missing from the dataset. Please check your CSV file.")

# List of model filenames
model_filenames = [
    "YOWRCONC.pkl", "YOSEEDOC.pkl", "YO_MDEA5.pkl", "YOWRLSIN.pkl",
    "YODPPROB.pkl", "YOWRPROB.pkl", "YODPR2WK.pkl", "YOWRDEPR.pkl",
    "YODPDISC.pkl", "YOLOSEV.pkl", "YOWRDCSN.pkl", "YODSMMDE.pkl",
    "YO_MDEA3.pkl", "YODPLSIN.pkl", "YOWRELES.pkl", "YOPB2WK.pkl"
]
model_path = "models/"

######################################
# 2) MODEL PREDICTOR
######################################

class ModelPredictor:
    def __init__(self, model_path, model_filenames):
        self.model_path = model_path
        self.model_filenames = model_filenames
        self.models = self.load_models()
        # Mapping for each label (column):
        #   1 = the symptomatic / more severe outcome
        #   2 = the non-symptomatic outcome
        self.prediction_map = {
            "YOWRCONC": {2: "Did NOT have difficulty concentrating", 1: "Had difficulty concentrating"},
            "YOSEEDOC": {2: "Did NOT feel the need to see a doctor", 1: "Felt the need to see a doctor"},
            "YO_MDEA5": {2: "No restlessness/lethargy noticed", 1: "Others noticed restlessness/lethargy"},
            "YOWRLSIN": {2: "Did NOT feel bored/lose interest", 1: "Felt bored/lost interest"},
            "YODPPROB": {2: "No other problems for 2+ weeks", 1: "Had other problems for 2+ weeks"},
            "YOWRPROB": {2: "No 'worst time ever' feeling", 1: "Had 'worst time ever' feeling"},
            "YODPR2WK": {2: "No depressed feelings for 2+ wks", 1: "Had depressed feelings for 2+ wks"},
            "YOWRDEPR": {2: "Did NOT feel sad/depressed daily", 1: "Felt sad/depressed mostly every day"},
            "YODPDISC": {2: "Overall mood not sad/depressed", 1: "Overall mood was sad/depressed"},
            "YOLOSEV": {2: "Did NOT lose interest in things", 1: "Lost interest in enjoyable things"},
            "YOWRDCSN": {2: "Was able to make decisions", 1: "Was unable to make decisions"},
            "YODSMMDE": {2: "No 2+ wks depression symptoms", 1: "Had 2+ wks depression symptoms"},
            "YO_MDEA3": {2: "No appetite/weight changes", 1: "Had changes in appetite/weight"},
            "YODPLSIN": {2: "Never lost interest/felt bored", 1: "Lost interest/felt bored"},
            "YOWRELES": {2: "Did NOT eat less than usual", 1: "Ate less than usual"},
            "YOPB2WK": {2: "No uneasy feelings 2+ weeks", 1: "Uneasy feelings 2+ weeks"}
        }

    def load_models(self):
        loaded = []
        for fname in self.model_filenames:
            try:
                with open(self.model_path + fname, "rb") as f:
                    model = pickle.load(f)
                loaded.append(model)
            except FileNotFoundError:
                raise FileNotFoundError(f"Model file '{fname}' not found in path '{self.model_path}'.")
            except Exception as e:
                raise Exception(f"Error loading model '{fname}': {e}")
        return loaded

    def make_predictions(self, user_input: pd.DataFrame):
        """
        Return:
          - a list of np.array predictions (1 or 2), one per model
          - a list of np.array probabilities for class 2, or np.nan where the
            model has no predict_proba

        IMPORTANT: This code assumes each model predicts the classes [1, 2].
        If a model was trained on [0, 1], transform its output or retrain it
        to return [1, 2].
        """
        preds = []
        probs = []
        for model in self.models:
            y_pred = model.predict(user_input)  # expected to return 1 or 2
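            # Added note: flatten() below guards against estimators that return
            # a column vector; class codes are expected to be 1 or 2.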
            preds.append(y_pred.flatten())
            # If the model supports predict_proba, report the probability of
            # class 2. Use model.classes_ to locate the right column rather
            # than assuming it is always the second one.
            if hasattr(model, "predict_proba"):
                proba = model.predict_proba(user_input)
                classes = list(getattr(model, "classes_", []))
                col_idx = classes.index(2) if 2 in classes else -1
                probs.append(proba[:, col_idx])
            else:
                probs.append(np.full(len(user_input), np.nan))
        return preds, probs

    def evaluate_severity(self, count_ones: int) -> str:
        """
        Evaluate severity based on how many labels were predicted as 1.
        The more 1's, the more severe the condition.
        """
        if count_ones >= 13:
            return "Mental Health Severity: Severe"
        elif count_ones >= 9:
            return "Mental Health Severity: Moderate"
        elif count_ones >= 5:
            return "Mental Health Severity: Low"
        else:
            return "Mental Health Severity: Very Low"

predictor = ModelPredictor(model_path, model_filenames)

######################################
# 3) FEATURE CATEGORIES + MAPPING
######################################

categories_dict = {
    "1. Depression & Substance Use Diagnosis": [
        "YMDESUD5ANYO", "YMDELT", "YMDEYR", "YMDERSUD5ANY",
        "YMSUD5YANY", "YMIUD5YANY", "YMIMS5YANY", "YMIMI5YANY"
    ],
    "2. Mental Health Treatment & Prof Consultation": [
        "YMDEHPO", "YMDETXRX", "YMDEHARX", "YMDEHPRX", "YRXMDEYR",
        "YHLTMDE", "YTXMDEYR", "YDOCMDE", "YPSY2MDE", "YPSY1MDE", "YCOUNMDE"
    ],
    "3. Functional & Cognitive Impairment": [
        "MDEIMPY", "LVLDIFMEM2"
    ],
    "4. Suicidal Thoughts & Behaviors": [
        "YUSUITHK", "YUSUITHKYR", "YUSUIPLNYR", "YUSUIPLN"
    ]
}

# NOTE: input_mapping converts user choices into numeric codes.
# If the models expect [1, 2] for "Yes"/"No", make sure the codes here match.
input_mapping = {
    'YMDESUD5ANYO': {
        "SUD only, no MDE": 1,
        "MDE only, no SUD": 2,
        "SUD and MDE": 3,
        "Neither SUD nor MDE": 4
    },
    'YMDELT': {"Yes": 1, "No": 2},
    'YMDEYR': {"Yes": 1, "No": 2},
    'YMDERSUD5ANY': {"Yes": 1, "No": 0},
    'YMSUD5YANY': {"Yes": 1, "No": 0},
    'YMIUD5YANY': {"Yes": 1, "No": 0},
    'YMIMS5YANY': {"Yes": 1, "No": 0},
    'YMIMI5YANY': {"Yes": 1, "No": 0},
    'YMDEHPO': {"Yes": 1, "No": 0},
    'YMDETXRX': {"Yes": 1, "No": 0},
    'YMDEHARX': {"Yes": 1, "No": 0},
    'YMDEHPRX': {"Yes": 1, "No": 0},
    'YRXMDEYR': {"Yes": 1, "No": 0},
    'YHLTMDE': {"Yes": 1, "No": 0},
    'YTXMDEYR': {"Yes": 1, "No": 0},
    'YDOCMDE': {"Yes": 1, "No": 0},
    'YPSY2MDE': {"Yes": 1, "No": 0},
    'YPSY1MDE': {"Yes": 1, "No": 0},
    'YCOUNMDE': {"Yes": 1, "No": 0},
    'MDEIMPY': {"Yes": 1, "No": 2},
    'LVLDIFMEM2': {
        "No Difficulty": 1,
        "Some difficulty": 2,
        "A lot of difficulty or cannot do at all": 3
    },
    'YUSUITHK': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
    'YUSUITHKYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
    'YUSUIPLNYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
    'YUSUIPLN': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4}
}

def validate_inputs(*args):
    for arg in args:
        if arg is None or arg == "":
            return False
    return True

######################################
# 4) NEAREST NEIGHBORS
######################################

def get_nearest_neighbors_info(user_input_df: pd.DataFrame, k=5):
    user_cols = user_input_df.columns
    if not all(col in df.columns for col in user_cols):
        return "Cannot compute nearest neighbors. Some columns not found in df."
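    # Added note: the neighbor search below uses plain Euclidean distance over
    # the raw numeric codes (no scaling or one-hot encoding), so features with
    # wider code ranges carry slightly more weight than binary ones.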
    sub_df = df[user_cols].copy()
    diffs = sub_df - user_input_df.iloc[0]
    dists = (diffs ** 2).sum(axis=1) ** 0.5
    nn_indices = dists.nsmallest(k).index
    neighbors = df.loc[nn_indices]

    lines = [
        f"**Nearest Neighbors (k={k})**",
        f"Distances range: {dists[nn_indices].min():.2f} to {dists[nn_indices].max():.2f}",
        ""
    ]

    # Show user input in numeric -> text form
    lines.append("**User Input (numeric -> text)**")
    for col in user_cols:
        val_numeric = user_input_df.iloc[0][col]
        text_val = None
        if col in input_mapping:
            # Reverse lookup to find the textual label
            for txt_key, num_val in input_mapping[col].items():
                if val_numeric == num_val:
                    text_val = txt_key
                    break
        if not text_val:
            text_val = f"{val_numeric} (no mapping found)"
        lines.append(f"- {col} = {val_numeric} => '{text_val}'")
    lines.append("")

    # Show label columns among neighbors
    label_cols = list(predictor.prediction_map.keys())
    lines.append("**Label Distribution Among Neighbors**")
    for lbl in label_cols:
        if lbl not in neighbors.columns:
            continue
        val_counts = neighbors[lbl].value_counts().to_dict()
        parts = []
        for val_, count_ in val_counts.items():
            # Only values in [1, 2] have a textual mapping
            if lbl in predictor.prediction_map and val_ in [1, 2]:
                label_text = predictor.prediction_map[lbl][val_]
                parts.append(f"{count_} had '{label_text}' (value={val_})")
            else:
                parts.append(f"{count_} had numeric={val_}")
        lines.append(f"- {lbl}: " + "; ".join(parts))
    lines.append("")

    return "\n".join(lines)

######################################
# 5) PREDICT FUNCTION
######################################
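# Added note: predict() below returns a 6-tuple consumed by the Gradio UI:
#   (prediction text, severity message, dataset-size markdown,
#    nearest-neighbor summary, input-feature bar chart, predicted-label bar chart).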
def predict(
    # Category 1 (8)
    YMDESUD5ANYO, YMDELT, YMDEYR, YMDERSUD5ANY,
    YMSUD5YANY, YMIUD5YANY, YMIMS5YANY, YMIMI5YANY,
    # Category 2 (11)
    YMDEHPO, YMDETXRX, YMDEHARX, YMDEHPRX, YRXMDEYR, YHLTMDE,
    YTXMDEYR, YDOCMDE, YPSY2MDE, YPSY1MDE, YCOUNMDE,
    # Category 3 (2)
    MDEIMPY, LVLDIFMEM2,
    # Category 4 (4)
    YUSUITHK, YUSUITHKYR, YUSUIPLNYR, YUSUIPLN
):
    # 1) Validate
    if not validate_inputs(
        YMDESUD5ANYO, YMDELT, YMDEYR, YMDERSUD5ANY,
        YMSUD5YANY, YMIUD5YANY, YMIMS5YANY, YMIMI5YANY,
        YMDEHPO, YMDETXRX, YMDEHARX, YMDEHPRX, YRXMDEYR, YHLTMDE,
        YTXMDEYR, YDOCMDE, YPSY2MDE, YPSY1MDE, YCOUNMDE,
        MDEIMPY, LVLDIFMEM2,
        YUSUITHK, YUSUITHKYR, YUSUIPLNYR, YUSUIPLN
    ):
        return (
            "Please select all required fields.",  # 1) Prediction Results
            "Validation Error",                    # 2) Severity
            "No data",                             # 3) Total Count
            "No nearest neighbors info",           # 4) NN Summary
            None,                                  # 5) Bar chart (Input)
            None                                   # 6) Bar chart (Labels)
        )

    # 2) Convert text -> numeric
    try:
        user_input_dict = {
            'YMDESUD5ANYO': input_mapping['YMDESUD5ANYO'][YMDESUD5ANYO],
            'YMDELT': input_mapping['YMDELT'][YMDELT],
            'YMDEYR': input_mapping['YMDEYR'][YMDEYR],
            'YMDERSUD5ANY': input_mapping['YMDERSUD5ANY'][YMDERSUD5ANY],
            'YMSUD5YANY': input_mapping['YMSUD5YANY'][YMSUD5YANY],
            'YMIUD5YANY': input_mapping['YMIUD5YANY'][YMIUD5YANY],
            'YMIMS5YANY': input_mapping['YMIMS5YANY'][YMIMS5YANY],
            'YMIMI5YANY': input_mapping['YMIMI5YANY'][YMIMI5YANY],
            'YMDEHPO': input_mapping['YMDEHPO'][YMDEHPO],
            'YMDETXRX': input_mapping['YMDETXRX'][YMDETXRX],
            'YMDEHARX': input_mapping['YMDEHARX'][YMDEHARX],
            'YMDEHPRX': input_mapping['YMDEHPRX'][YMDEHPRX],
            'YRXMDEYR': input_mapping['YRXMDEYR'][YRXMDEYR],
            'YHLTMDE': input_mapping['YHLTMDE'][YHLTMDE],
            'YTXMDEYR': input_mapping['YTXMDEYR'][YTXMDEYR],
            'YDOCMDE': input_mapping['YDOCMDE'][YDOCMDE],
            'YPSY2MDE': input_mapping['YPSY2MDE'][YPSY2MDE],
            'YPSY1MDE': input_mapping['YPSY1MDE'][YPSY1MDE],
            'YCOUNMDE': input_mapping['YCOUNMDE'][YCOUNMDE],
            'MDEIMPY': input_mapping['MDEIMPY'][MDEIMPY],
            'LVLDIFMEM2': input_mapping['LVLDIFMEM2'][LVLDIFMEM2],
            'YUSUITHK': input_mapping['YUSUITHK'][YUSUITHK],
            'YUSUITHKYR': input_mapping['YUSUITHKYR'][YUSUITHKYR],
            'YUSUIPLNYR': input_mapping['YUSUIPLNYR'][YUSUIPLNYR],
            'YUSUIPLN': input_mapping['YUSUIPLN'][YUSUIPLN]
        }
    except KeyError as e:
        missing_key = e.args[0]
        return (
            f"Input mapping missing for key: {missing_key}. Please check your `input_mapping` dictionary.",
            "Mapping Error",
            "No data",
            "No nearest neighbors info",
            None,
            None
        )

    user_df = pd.DataFrame(user_input_dict, index=[0])

    # 3) Make predictions
    try:
        preds, probs = predictor.make_predictions(user_df)
    except Exception as e:
        return (
            f"Error during prediction: {e}",
            "Prediction Error",
            "No data",
            "No nearest neighbors info",
            None,
            None
        )

    # Flatten predictions into a single array
    all_preds = np.concatenate(preds)

    # Count how many labels were predicted as "1" (the more severe category)
    count_ones = int(np.sum(all_preds == 1))

    # Evaluate severity using count_ones
    severity_msg = predictor.evaluate_severity(count_ones)

    # 4) Summarize predictions (with probabilities)
    # Build label -> (pred_value, prob_value)
    label_prediction_info = {}
    for i, fname in enumerate(model_filenames):
        lbl_col = fname.split('.')[0]
        pred_val = preds[i][0]   # e.g. 1 or 2
        prob_val = probs[i][0]   # probability of class 2 (NaN if unavailable)
        label_prediction_info[lbl_col] = (pred_val, prob_val)

    # Group the labels by domain
    domain_groups = {
        "Concentration and Decision Making": ["YOWRCONC", "YOWRDCSN"],
        "Sleep and Energy Levels": ["YO_MDEA5", "YOSEEDOC"],
        "Mood and Emotional State": [
            "YOWRLSIN", "YOWRDEPR", "YODPDISC", "YOLOSEV", "YODPLSIN"
        ],
        "Appetite and Weight Changes": ["YO_MDEA3", "YOWRELES"],
        "Duration and Severity of Depression Symptoms": [
            "YODPPROB", "YOWRPROB", "YODPR2WK", "YODSMMDE", "YOPB2WK"
        ]
    }

    final_str_parts = []
    for gname, lbls in domain_groups.items():
        group_lines = []
        for lbl in lbls:
            if lbl in label_prediction_info:
                pred_val, prob_val = label_prediction_info[lbl]
                # If pred_val is 1 or 2, we have a textual mapping
                if lbl in predictor.prediction_map and pred_val in [1, 2]:
                    text_pred = predictor.prediction_map[lbl][pred_val]
                else:
                    text_pred = f"Prediction={pred_val}"
                if not np.isnan(prob_val):
                    text_prob = f"(Prob= {prob_val:.2f})"
                else:
                    text_prob = "(No probability available)"
                group_lines.append(f"{lbl} => {text_pred} {text_prob}")
        if group_lines:
            final_str_parts.append(f"**{gname}**")
            final_str_parts.append("\n".join(group_lines))
            final_str_parts.append("")  # spacing
    if final_str_parts:
        final_str = "\n".join(final_str_parts)
    else:
        final_str = "No predictions made or no matching group columns."

    # 5) Additional info
    total_count_md = f"We have **{len(df)}** patients in the dataset."
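    # Added note: steps 6-8 below are descriptive statistics over the loaded
    # CSV -- nearest neighbors of this input and counts of patients sharing
    # each input value / predicted label. They do not affect the predictions.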
    # 6) Nearest Neighbors
    nn_md = get_nearest_neighbors_info(user_df, k=5)

    # 7) Bar chart for input features
    input_counts = {}
    for col, val_ in user_input_dict.items():
        matched = len(df[df[col] == val_])
        input_counts[col] = matched
    bar_in_df = pd.DataFrame({
        "Feature": list(input_counts.keys()),
        "Count": list(input_counts.values())
    })
    fig_in = px.bar(
        bar_in_df,
        x="Feature",
        y="Count",
        title="Number of Patients with the Same Input Feature Values"
    )
    fig_in.update_layout(width=1200, height=400)

    # 8) Bar chart for predicted labels
    label_counts = {}
    for lbl_col, (pred_val, _) in label_prediction_info.items():
        if lbl_col in df.columns:
            label_counts[lbl_col] = len(df[df[lbl_col] == pred_val])
    if label_counts:
        bar_lbl_df = pd.DataFrame({
            "Label": list(label_counts.keys()),
            "Count": list(label_counts.values())
        })
        fig_lbl = px.bar(
            bar_lbl_df,
            x="Label",
            y="Count",
            title="Number of Patients with the Same Predicted Label"
        )
        fig_lbl.update_layout(width=1200, height=400)
    else:
        fig_lbl = px.bar(title="No valid predicted labels to display.")
        fig_lbl.update_layout(width=1200, height=400)

    return (
        final_str,       # 1) Prediction Results
        severity_msg,    # 2) Mental Health Severity
        total_count_md,  # 3) Total Patient Count
        nn_md,           # 4) Nearest Neighbors Summary
        fig_in,          # 5) Bar Chart (input features)
        fig_lbl          # 6) Bar Chart (labels)
    )

######################################
# 6) UNIFIED DISTRIBUTION/CO-OCCURRENCE
######################################

def combined_plot(feature_list, label_col):
    """
    If the user picks 1 feature  => distribution plot.
    If the user picks 2 features => co-occurrence plot.
    Otherwise => show an error / empty plot.

    This function also maps numeric codes back to text using 'input_mapping'
    and 'predictor.prediction_map' so that the plots display readable labels.
    """
    if not label_col:
        return px.bar(title="Please select a label column.")

    df_copy = df.copy()

    # Convert numeric codes -> text for features
    for col, text_to_num_dict in input_mapping.items():
        if col in df_copy.columns:
            num_to_text = {v: k for k, v in text_to_num_dict.items()}
            df_copy[col] = df_copy[col].map(num_to_text).fillna(df_copy[col])

    # Convert label 1/2 -> text if label_col is in predictor.prediction_map
    if label_col in predictor.prediction_map and label_col in df_copy.columns:
        # Example: {1: "First meaning", 2: "Second meaning"}
        map_12 = predictor.prediction_map[label_col]
        df_copy[label_col] = df_copy[label_col].map(map_12).fillna(df_copy[label_col])

    if len(feature_list) == 1:
        f_ = feature_list[0]
        if f_ not in df_copy.columns or label_col not in df_copy.columns:
            return px.bar(title="Selected columns not found in dataset.")
        grouped = df_copy.groupby([f_, label_col]).size().reset_index(name="count")
        fig = px.bar(
            grouped,
            x=f_,
            y="count",
            color=label_col,
            title=f"Distribution of {f_} vs {label_col} (Mapped)"
        )
        fig.update_layout(width=1200, height=600)
        return fig
    elif len(feature_list) == 2:
        f1, f2 = feature_list
        if (f1 not in df_copy.columns) or (f2 not in df_copy.columns) or (label_col not in df_copy.columns):
            return px.bar(title="Selected columns not found in dataset.")
        grouped = df_copy.groupby([f1, f2, label_col]).size().reset_index(name="count")
        fig = px.bar(
            grouped,
            x=f1,
            y="count",
            color=label_col,
            facet_col=f2,
            title=f"Co-occurrence: {f1}, {f2} vs {label_col} (Mapped)"
        )
        fig.update_layout(width=1200, height=600)
        return fig
    else:
        return px.bar(title="Please select exactly 1 or 2 features.")
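# ---------------------------------------------------------------------------
# Illustrative, optional usage (a minimal sketch, not part of the app flow):
# combined_plot can also be used headlessly, e.g. to export a chart to HTML.
# The feature/label names, the EXPORT_DEMO_PLOT variable, and the output
# filename below are assumptions for demonstration only.
import os

if os.environ.get("EXPORT_DEMO_PLOT") == "1":
    _demo_fig = combined_plot(["YMDEYR", "MDEIMPY"], "YOWRCONC")
    _demo_fig.write_html("co_occurrence_demo.html")  # interactive HTML export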
######################################
# 7) BUILD GRADIO UI
######################################

with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
    # ======== TAB 1: Prediction ========
    with gr.Tab("Prediction"):
        gr.Markdown("### Please provide inputs in each category below. All fields are required.")

        # Category 1: Depression & Substance Use Diagnosis (8 features)
        gr.Markdown("#### 1. Depression & Substance Use Diagnosis")
        cat1_col_labels = [
            ("YMDESUD5ANYO", "YMDESUD5ANYO: Only MDE, only SUD, both, or neither"),
            ("YMDELT", "YMDELT: Had major depressive episode in lifetime"),
            ("YMDEYR", "YMDEYR: Past-year major depressive episode"),
            ("YMDERSUD5ANY", "YMDERSUD5ANY: MDE or SUD in past year?"),
            ("YMSUD5YANY", "YMSUD5YANY: Past-year MDE & substance use disorder"),
            ("YMIUD5YANY", "YMIUD5YANY: Past-year MDE & illicit drug use disorder"),
            ("YMIMS5YANY", "YMIMS5YANY: Past-year MDE + severe impairment + substance use"),
            ("YMIMI5YANY", "YMIMI5YANY: Past-year MDE w/ severe impairment & illicit drug use")
        ]
        cat1_inputs = []
        for col, label_text in cat1_col_labels:
            cat1_inputs.append(
                gr.Dropdown(
                    choices=list(input_mapping[col].keys()),
                    label=label_text
                )
            )

        # Category 2: Mental Health Treatment & Professional Consultation (11 features)
        gr.Markdown("#### 2. Mental Health Treatment & Professional Consultation")
        cat2_col_labels = [
            ("YMDEHPO", "YMDEHPO: Saw health prof only for MDE"),
            ("YMDETXRX", "YMDETXRX: Received treatment/counseling if saw doc/prof for MDE"),
            ("YMDEHARX", "YMDEHARX: Saw health prof & medication for MDE"),
            ("YMDEHPRX", "YMDEHPRX: Saw health prof or med for MDE in past year?"),
            ("YRXMDEYR", "YRXMDEYR: Used medication for MDE in past year"),
            ("YHLTMDE", "YHLTMDE: Saw/talked to health prof about MDE"),
            ("YTXMDEYR", "YTXMDEYR: Saw/talked to doc/prof for MDE in past year"),
            ("YDOCMDE", "YDOCMDE: Saw/talked to general practitioner/family MD"),
            ("YPSY2MDE", "YPSY2MDE: Saw/talked to psychiatrist"),
            ("YPSY1MDE", "YPSY1MDE: Saw/talked to psychologist"),
            ("YCOUNMDE", "YCOUNMDE: Saw/talked to counselor")
        ]
        cat2_inputs = []
        for col, label_text in cat2_col_labels:
            cat2_inputs.append(
                gr.Dropdown(
                    choices=list(input_mapping[col].keys()),
                    label=label_text
                )
            )

        # Category 3: Functional & Cognitive Impairment (2 features)
        gr.Markdown("#### 3. Functional & Cognitive Impairment")
        cat3_col_labels = [
            ("MDEIMPY", "MDEIMPY: MDE with severe role impairment?"),
            ("LVLDIFMEM2", "LVLDIFMEM2: Difficulty remembering/concentrating")
        ]
        cat3_inputs = []
        for col, label_text in cat3_col_labels:
            cat3_inputs.append(
                gr.Dropdown(
                    choices=list(input_mapping[col].keys()),
                    label=label_text
                )
            )

        # Category 4: Suicidal Thoughts & Behaviors (4 features)
        gr.Markdown("#### 4. Suicidal Thoughts & Behaviors")
        cat4_col_labels = [
            ("YUSUITHK", "YUSUITHK: Thought of killing self (past 12 months)?"),
            ("YUSUITHKYR", "YUSUITHKYR: Seriously thought about killing self?"),
            ("YUSUIPLNYR", "YUSUIPLNYR: Made plans to kill self in past year?"),
            ("YUSUIPLN", "YUSUIPLN: Made plans to kill yourself in past 12 months?")
        ]
        cat4_inputs = []
        for col, label_text in cat4_col_labels:
            cat4_inputs.append(
                gr.Dropdown(
                    choices=list(input_mapping[col].keys()),
                    label=label_text
                )
            )

        # Combine all inputs
        all_inputs = cat1_inputs + cat2_inputs + cat3_inputs + cat4_inputs

        # Outputs
        predict_btn = gr.Button("Predict")
        out_pred_res = gr.Textbox(label="Prediction Results (with Probability)", lines=8)
        out_sev = gr.Textbox(label="Mental Health Severity", lines=2)
        out_count = gr.Markdown(label="Total Patient Count")
        out_nn = gr.Markdown(label="Nearest Neighbors Summary")
        out_bar_input = gr.Plot(label="Input Feature Counts")
        out_bar_label = gr.Plot(label="Predicted Label Counts")

        # Connect the predict button
        predict_btn.click(
            fn=predict,
            inputs=all_inputs,
            outputs=[
                out_pred_res,
                out_sev,
                out_count,
                out_nn,
                out_bar_input,
                out_bar_label
            ]
        )

    # ======== TAB 2: Unified Distribution/Co-occurrence ========
    with gr.Tab("Distribution/Co-occurrence"):
        gr.Markdown("### Select 1 or 2 features + 1 label to see a bar chart.")

        list_of_features = sorted(input_mapping.keys())
        list_of_labels = sorted(predictor.prediction_map.keys())

        selected_features = gr.CheckboxGroup(
            choices=list_of_features,
            label="Select 1 or 2 features"
        )
        label_dd = gr.Dropdown(
            choices=list_of_labels,
            label="Label Column (e.g. YOWRCONC, YOSEEDOC, etc.)"
        )

        generate_combined_btn = gr.Button("Generate Plot")
        combined_output = gr.Plot()

        generate_combined_btn.click(
            fn=combined_plot,
            inputs=[selected_features, label_dd],
            outputs=combined_output
        )

# Launch the Gradio app
demo.launch()
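# Added note: by default Gradio serves the app locally (http://127.0.0.1:7860).
# Passing share=True to demo.launch() creates a temporary public link, and
# server_name / server_port can be used to bind a specific host and port.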