|
import os
import pickle

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
|
|
|
|
|
|
|
|
|
# Combined train/test dataset. It backs the patient count, the nearest-neighbour
# summary, the per-feature match counts and the analysis plots below.
df = pd.read_csv("X_train_test_combined_dataset_Filtered_dataset.csv")

# One pickled model per predicted label. The file stem (text before the first
# ".") is used as the label column name when reporting predictions.
model_filenames = [
    "YOWRCONC.pkl",
    "YOSEEDOC.pkl",
    "YO_MDEA5.pkl",
    "YOWRLSIN.pkl",
    "YODPPROB.pkl",
    "YOWRPROB.pkl",
    "YODPR2WK.pkl",
    "YOWRDEPR.pkl",
    "YODPDISC.pkl",
    "YOLOSEV.pkl",
    "YOWRDCSN.pkl",
    "YODSMMDE.pkl",
    "YO_MDEA3.pkl",
    "YODPLSIN.pkl",
    "YOWRELES.pkl",
    "YOPB2WK.pkl",
]

# Directory holding the pickled model files listed above.
model_path = "models/"
|
|
|
|
|
|
|
|
|
|
|
class ModelPredictor:
    """Load a fixed set of pickled models and interpret their outputs.

    Each loaded model must expose ``predict(DataFrame)`` returning an
    array-like with ``.flatten()`` (presumably scikit-learn estimators --
    confirm against the training code). Models are kept in the same order as
    ``model_filenames``.
    """

    def __init__(self, model_path, model_filenames):
        """Load every model from ``model_path`` and set up the label texts.

        Args:
            model_path: directory containing the model files (str or
                path-like; a trailing separator is optional).
            model_filenames: file names to load, in order.
        """
        self.model_path = model_path
        self.model_filenames = model_filenames
        self.models = self.load_models()

        # Label column -> [text for prediction 0, text for prediction 1].
        # Some keys (e.g. YOWRHRS, YODSCEV) have no model file in this app;
        # they are kept so the label dropdowns can still offer them.
        self.prediction_map = {
            "YOWRCONC": ["Did not have difficulty concentrating", "Had difficulty concentrating"],
            "YOSEEDOC": ["Did not feel the need to see a doctor", "Felt the need to see a doctor"],
            "YOWRHRS": ["Did not have trouble sleeping", "Had trouble sleeping"],
            "YO_MDEA5": ["Others did not notice restlessness/lethargy", "Others noticed restlessness/lethargy"],
            "YOWRCHR": ["Did not feel so sad nothing could cheer up", "Felt so sad that nothing could cheer up"],
            "YOWRLSIN": ["Did not feel bored / lose interest", "Felt bored / lost interest"],
            "YODPPROB": ["No other problems for 2+ weeks", "Had other problems for 2+ weeks"],
            "YOWRPROB": ["No 'worst time ever' feeling", "Had 'worst time ever' feeling"],
            "YODPR2WK": ["No periods with depressed feelings lasting 2+ weeks", "Had depressed feelings 2+ weeks"],
            "YOWRDEPR": ["Did not feel sad/depressed mostly everyday", "Felt sad/depressed mostly everyday"],
            "YODPDISC": ["Overall mood not sad/depressed", "Overall mood was sad/depressed"],
            "YOLOSEV": ["Did not lose interest", "Lost interest in enjoyable things"],
            "YOWRDCSN": ["Was able to make decisions", "Was unable to make decisions"],
            "YODSMMDE": ["Never had 2 weeks depression symptoms", "Had 2+ weeks of depression symptoms"],
            "YO_MDEA3": ["No changes in appetite/weight", "Had changes in appetite/weight"],
            "YODPLSIN": ["Never lost interest / felt bored", "Lost interest/felt bored"],
            "YOWRELES": ["Did not eat less than usual", "Ate less than usual"],
            "YODSCEV": ["Fewer severe depression symptoms", "More severe depression symptoms"],
            "YOPB2WK": ["No uneasy feelings lasting 2+ weeks", "Uneasy feelings lasting 2+ weeks"],
            "YO_MDEA2": ["No physical/mental issues (2+ weeks)", "Had physical/mental issues (2+ weeks)"],
        }

    def load_models(self):
        """Unpickle each file in ``model_filenames`` from ``model_path``, in order.

        SECURITY: ``pickle.load`` can execute arbitrary code. Only point this
        at model files you produced or trust, never at user-supplied uploads.

        Returns:
            list of the unpickled model objects.

        Raises:
            FileNotFoundError / pickle.UnpicklingError if a file is missing or corrupt.
        """
        loaded = []
        for fname in self.model_filenames:
            # os.path.join works whether or not model_path ends with a separator.
            with open(os.path.join(self.model_path, fname), "rb") as f:
                loaded.append(pickle.load(f))
        return loaded

    def make_predictions(self, user_input: pd.DataFrame):
        """Run every model on ``user_input``.

        Returns:
            list with one flattened prediction array per model, in model order.
        """
        return [model.predict(user_input).flatten() for model in self.models]

    def get_majority_vote(self, predictions):
        """Return the most frequent value across all prediction arrays.

        Ties go to the smallest value. Unlike ``np.bincount``, this also
        accepts float or negative predictions.

        Raises:
            ValueError: if ``predictions`` is empty or contains no values.
        """
        if not predictions:
            raise ValueError("need at least one prediction array to vote on")
        combined = np.concatenate([np.ravel(p) for p in predictions])
        # np.unique returns sorted values, so argmax picks the smallest of tied values.
        values, counts = np.unique(combined, return_counts=True)
        return values[counts.argmax()]

    def evaluate_severity(self, count_ones: int) -> str:
        """Map the number of positive (==1) predictions to a severity message.

        Thresholds: >=13 Severe, >=9 Moderate, >=5 Low, otherwise Very Low.
        """
        if count_ones >= 13:
            return "Mental Health Severity: Severe"
        elif count_ones >= 9:
            return "Mental Health Severity: Moderate"
        elif count_ones >= 5:
            return "Mental Health Severity: Low"
        else:
            return "Mental Health Severity: Very Low"
|
|
|
|
|
# Single shared predictor: loads every model once at start-up.
predictor = ModelPredictor(model_path=model_path, model_filenames=model_filenames)
|
|
|
|
|
|
|
|
|
|
|
# Input feature columns grouped into the four sections shown in the UI.
categories_dict = {
    "1. Depression & Substance Use Diagnosis": [
        "YMDESUD5ANYO",
        "YMDELT",
        "YMDEYR",
        "YMDERSUD5ANY",
        "YMSUD5YANY",
        "YMIUD5YANY",
        "YMIMS5YANY",
        "YMIMI5YANY",
    ],
    "2. Mental Health Treatment & Prof Consultation": [
        "YMDEHPO",
        "YMDETXRX",
        "YMDEHARX",
        "YRXMDEYR",
        "YHLTMDE",
        "YTXMDEYR",
        "YDOCMDE",
        "YPSY2MDE",
        "YPSY1MDE",
        "YCOUNMDE",
    ],
    "3. Functional & Cognitive Impairment": [
        "MDEIMPY",
        "LVLDIFMEM2",
    ],
    "4. Suicidal Thoughts & Behaviors": [
        "YUSUITHK",
        "YUSUITHKYR",
        "YUSUIPLNYR",
        "YUSUIPLN",
    ],
}


def _coded_answers(*labels):
    """Return a new dict mapping each answer label to its 1-based position."""
    return {label: code for code, label in enumerate(labels, start=1)}


# Dropdown answer text -> numeric code fed to the models, per feature column.
# Key order is significant: it is the dropdown choice order, and the feature
# order matches categories_dict. Every feature gets its own dict object.
input_mapping = {
    'YMDESUD5ANYO': _coded_answers(
        "SUD only, no MDE", "MDE only, no SUD", "SUD and MDE", "Neither SUD or MDE"
    ),
    'YMDELT': _coded_answers("Yes", "No"),
    'YMDEYR': _coded_answers("Yes", "No"),
    # Binary flags coded Yes=1 / No=0.
    **{
        col: {"Yes": 1, "No": 0}
        for col in (
            'YMDERSUD5ANY', 'YMSUD5YANY', 'YMIUD5YANY', 'YMIMS5YANY', 'YMIMI5YANY',
            'YMDEHPO', 'YMDETXRX', 'YMDEHARX', 'YRXMDEYR', 'YHLTMDE',
            'YTXMDEYR', 'YDOCMDE', 'YPSY2MDE', 'YPSY1MDE', 'YCOUNMDE',
        )
    },
    'MDEIMPY': _coded_answers("Yes", "No"),
    'LVLDIFMEM2': _coded_answers(
        "No Difficulty", "Some difficulty", "A lot of difficulty or cannot do at all"
    ),
    # Suicidality questions share a four-way answer scale.
    **{
        col: _coded_answers("Yes", "No", "I'm not sure", "I don't want to answer")
        for col in ('YUSUITHK', 'YUSUITHKYR', 'YUSUIPLNYR', 'YUSUIPLN')
    },
}
|
|
|
|
|
def validate_inputs(*args):
    """Report whether every supplied value was filled in.

    Returns:
        False if any argument is falsy (None, empty string, ...), else True.
        With no arguments, returns True.
    """
    return all(args)
|
|
|
|
|
|
|
|
|
|
|
def get_nearest_neighbors_info(user_input_df: pd.DataFrame, k=5, data=None, categories=None):
    """Summarise the ``k`` dataset rows closest to the user's encoded answers.

    Distance is Euclidean over the columns of ``user_input_df``, using only its
    first row. For every feature in ``categories`` that exists in the data, the
    neighbours' value counts are listed.

    Args:
        user_input_df: encoded user answers; only row 0 is used.
        k: number of neighbours to report.
        data: dataset to search; defaults to the module-level ``df``.
        categories: mapping of section title -> feature column names;
            defaults to the module-level ``categories_dict``.

    Returns:
        A Markdown string, or an explanatory message when the lookup cannot be
        performed (unknown columns or no input row).
    """
    if data is None:
        data = df
    if categories is None:
        categories = categories_dict

    user_cols = list(user_input_df.columns)
    if not all(col in data.columns for col in user_cols):
        return "Cannot compute nearest neighbors. Some columns not found in df."
    if len(user_input_df.index) == 0:
        return "Cannot compute nearest neighbors. No user input row provided."

    # Subtraction aligns the user's row on column names. reset_index makes the
    # distance index positional, so neighbours are selected by row position
    # (data.iloc). Label-based .loc would return extra rows if the index had
    # duplicate labels.
    diffs = data[user_cols] - user_input_df.iloc[0]
    dists = ((diffs ** 2).sum(axis=1) ** 0.5).reset_index(drop=True)
    nearest = dists.nsmallest(k)
    neighbors = data.iloc[nearest.index]

    lines = [
        f"**Nearest Neighbors (k={k})**",
        f"Distances Range: {nearest.min():.2f} to {nearest.max():.2f}",
        "",
    ]

    for cat_name, cat_feats in categories.items():
        lines.append(f"### {cat_name}")
        for feat in cat_feats:
            if feat not in neighbors.columns:
                continue
            # value_counts orders by descending frequency.
            val_counts = neighbors[feat].value_counts().to_dict()
            joined = "; ".join(f"{count_} had '{val_}'" for val_, count_ in val_counts.items())
            lines.append(f"**{feat}** => {joined}")
        lines.append("")

    return "\n".join(lines)
|
|
|
|
|
|
|
|
|
|
|
# Input feature columns, in the same order as predict()'s parameters. Gradio
# passes the dropdown values positionally in this order, and the model input
# DataFrame is built with its columns in this order.
# NOTE(review): presumably this must also match the column order the pickled
# models were fitted on -- keep it in sync with the training pipeline.
_FEATURE_ORDER = (
    "YMDESUD5ANYO", "YMDELT", "YMDEYR", "YMDERSUD5ANY",
    "YMSUD5YANY", "YMIUD5YANY", "YMIMS5YANY", "YMIMI5YANY",
    "YMDEHPO", "YMDETXRX", "YMDEHARX", "YRXMDEYR", "YHLTMDE",
    "YTXMDEYR", "YDOCMDE", "YPSY2MDE", "YPSY1MDE", "YCOUNMDE",
    "MDEIMPY", "LVLDIFMEM2",
    "YUSUITHK", "YUSUITHKYR", "YUSUIPLNYR", "YUSUIPLN",
)

# Predicted label columns grouped into the sections of the "Prediction Results"
# text. Each label belongs to at most one group. YOWRELES ("ate less than
# usual") lives under appetite only; listing it twice would hide it from one
# of the sections, because the first matching group wins.
_PREDICTION_GROUPS = {
    "Concentration_and_Decision_Making": ["YOWRCONC", "YOWRDCSN"],
    "Sleep_and_Energy_Levels": ["YOWRHRS", "YO_MDEA5", "YO_MDEA2"],
    "Mood_and_Emotional_State": [
        "YOWRCHR", "YOWRLSIN", "YOWRDEPR", "YODPDISC", "YOLOSEV", "YODPLSIN", "YODSCEV"
    ],
    "Appetite_and_Weight_Changes": ["YO_MDEA3", "YOWRELES"],
    "Duration_and_Severity_of_Depression_Symptoms": [
        "YODPPROB", "YOWRPROB", "YODPR2WK", "YODSMMDE", "YOPB2WK"
    ],
}

# Section for model outputs whose label is in none of the groups above (e.g.
# YOSEEDOC), so that no model's prediction is silently omitted.
_OTHER_GROUP = "Other_Predictions"


def _prediction_text(label_col, val):
    """Translate one model output into its human-readable description.

    Uses ``predictor.prediction_map[label_col][int(val)]`` when ``val`` is an
    integral value in range (ints, bools and integral floats such as ``1.0``).
    Otherwise returns ``"Prediction=<val>"``.
    """
    options = predictor.prediction_map.get(label_col)
    try:
        idx = int(val)
    except (TypeError, ValueError, OverflowError):
        return f"Prediction={val}"
    # idx == val rejects non-integral values such as 1.5 or the string "1".
    if options is not None and idx == val and 0 <= idx < len(options):
        return options[idx]
    return f"Prediction={val}"


def _grouped_predictions_text(predictions):
    """Build the grouped "Prediction Results" text, one line per model."""
    group_lines = {group: [] for group in _PREDICTION_GROUPS}
    group_lines[_OTHER_GROUP] = []

    # predictions are index-aligned with model_filenames: the predictor loads
    # its models in model_filenames order and predicts in model order.
    for fname, arr in zip(model_filenames, predictions):
        label_col = fname.split('.')[0]
        group = next(
            (g for g, cols in _PREDICTION_GROUPS.items() if label_col in cols),
            _OTHER_GROUP,
        )
        group_lines[group].append(f"{label_col} => {_prediction_text(label_col, arr[0])}")

    parts = []
    for group, lines in group_lines.items():
        if lines:
            parts.extend([f"**{group.replace('_', ' ')}**", "\n".join(lines), ""])
    if not parts:
        return "No predictions made or no matching group columns."
    return "\n".join(parts)


def _input_match_figure(user_input_dict):
    """Bar chart of how many dataset rows share each encoded input value."""
    # Columns absent from the dataset are skipped (like the label chart does)
    # rather than raising KeyError.
    input_counts = {
        col: int((df[col] == val).sum())
        for col, val in user_input_dict.items()
        if col in df.columns
    }
    bar_df = pd.DataFrame({"Feature": list(input_counts.keys()),
                           "Count": list(input_counts.values())})
    fig = px.bar(
        bar_df, x="Feature", y="Count",
        title="Number of Patients with the Same Input Feature Values"
    )
    fig.update_layout(width=1200, height=400)
    return fig


def _label_match_figure(predictions):
    """Bar chart of how many dataset rows carry each model's predicted label."""
    label_counts = {}
    for fname, arr in zip(model_filenames, predictions):
        lbl = fname.split('.')[0]
        if lbl in df.columns:
            label_counts[lbl] = int((df[lbl] == arr[0]).sum())

    if label_counts:
        bar_df = pd.DataFrame({
            "Label": list(label_counts.keys()),
            "Count": list(label_counts.values())
        })
        fig = px.bar(
            bar_df, x="Label", y="Count",
            title="Number of Patients with the Same Predicted Label"
        )
    else:
        fig = px.bar(title="No valid predicted labels to display.")
    fig.update_layout(width=1200, height=400)
    return fig


def predict(
    YMDESUD5ANYO, YMDELT, YMDEYR, YMDERSUD5ANY,
    YMSUD5YANY, YMIUD5YANY, YMIMS5YANY, YMIMI5YANY,
    YMDEHPO, YMDETXRX, YMDEHARX, YRXMDEYR, YHLTMDE,
    YTXMDEYR, YDOCMDE, YPSY2MDE, YPSY1MDE, YCOUNMDE,
    MDEIMPY, LVLDIFMEM2,
    YUSUITHK, YUSUITHKYR, YUSUIPLNYR, YUSUIPLN
):
    """Run every model on one set of dropdown answers and build the UI outputs.

    Each parameter is the answer text chosen in the dropdown of the same name.
    It is encoded to a number through ``input_mapping``.

    Returns:
        A 6-tuple matching the Gradio outputs: (grouped prediction text,
        severity message, patient-count Markdown, nearest-neighbour Markdown,
        input-count figure, label-count figure). If any answer is missing,
        returns validation messages and ``None`` for both figures.
    """
    answers = (
        YMDESUD5ANYO, YMDELT, YMDEYR, YMDERSUD5ANY,
        YMSUD5YANY, YMIUD5YANY, YMIMS5YANY, YMIMI5YANY,
        YMDEHPO, YMDETXRX, YMDEHARX, YRXMDEYR, YHLTMDE,
        YTXMDEYR, YDOCMDE, YPSY2MDE, YPSY1MDE, YCOUNMDE,
        MDEIMPY, LVLDIFMEM2,
        YUSUITHK, YUSUITHKYR, YUSUIPLNYR, YUSUIPLN,
    )
    if not validate_inputs(*answers):
        return (
            "Please select all required fields.",
            "Validation Error",
            "No data",
            "No nearest neighbors info",
            None,
            None
        )

    # Encode answers; the dict keeps _FEATURE_ORDER, so the DataFrame columns do too.
    user_input_dict = {
        col: input_mapping[col][answer]
        for col, answer in zip(_FEATURE_ORDER, answers)
    }
    user_df = pd.DataFrame(user_input_dict, index=[0])

    predictions = predictor.make_predictions(user_df)
    all_preds = np.concatenate(predictions) if predictions else np.array([])
    count_ones = int(np.count_nonzero(all_preds == 1))
    severity_msg = predictor.evaluate_severity(count_ones)

    total_count_md = f"We have **{len(df)}** patients in the dataset."
    nn_md = get_nearest_neighbors_info(user_df, k=5)

    return (
        _grouped_predictions_text(predictions),
        severity_msg,
        total_count_md,
        nn_md,
        _input_match_figure(user_input_dict),
        _label_match_figure(predictions),
    )
|
|
|
|
|
|
|
|
|
|
|
def distribution_plot(feature_col, label_col):
    """Plot row counts for every (feature value, label value) pair in ``df``.

    Bars are placed by ``feature_col`` value and coloured by ``label_col``.
    A title-only placeholder figure is returned when a selection is missing or
    not present in ``df``.
    """
    if not (feature_col and label_col):
        return px.bar(title="Please select both Feature and Label.")

    unknown = [c for c in (feature_col, label_col) if c not in df.columns]
    if unknown:
        return px.bar(title="Selected columns not found in the dataset.")

    counts = (
        df.groupby([feature_col, label_col])
        .size()
        .reset_index(name="count")
    )
    figure = px.bar(
        counts,
        x=feature_col,
        y="count",
        color=label_col,
        title=f"Distribution of {feature_col} vs {label_col}",
    )
    figure.update_layout(width=1200, height=600)
    return figure
|
|
|
|
|
def co_occurrence_plot(feature1, feature2, label_col):
    """Plot row counts for every (feature1, feature2, label) combination in ``df``.

    ``feature1`` is on the x-axis, bars are coloured by ``label_col``, and there
    is one facet column per value of ``feature2``. A title-only placeholder
    figure is returned when a selection is missing, not in ``df``, or repeated.
    """
    if not feature1 or not feature2 or not label_col:
        return px.bar(title="Please select all three fields.")
    if feature1 not in df.columns or feature2 not in df.columns or label_col not in df.columns:
        return px.bar(title="Selected columns not found in the dataset.")
    # groupby([...]).reset_index() cannot create two columns with the same
    # name, so selecting one column twice would raise ValueError.
    if len({feature1, feature2, label_col}) < 3:
        return px.bar(title="Please select three different columns.")

    grouped = df.groupby([feature1, feature2, label_col]).size().reset_index(name="count")
    fig = px.bar(
        grouped,
        x=feature1,
        y="count",
        color=label_col,
        facet_col=feature2,
        title=f"Co-occurrence: {feature1}, {feature2} vs {label_col}"
    )
    fig.update_layout(width=1200, height=600)
    return fig
|
|
|
|
|
|
|
|
|
|
|
def _make_dropdowns(col_labels):
    """Create one dropdown per ``(column, label)`` pair and return them in order.

    Each dropdown offers the answer strings of ``input_mapping[column]``.
    Gradio attaches the components to whichever Blocks/Tab context is active
    when this is called, so call it inside the layout where they belong.
    """
    return [
        gr.Dropdown(choices=list(input_mapping[col].keys()), label=label_text)
        for col, label_text in col_labels
    ]


with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:

    # ---- Tab 1: per-user prediction ------------------------------------
    with gr.Tab("Prediction"):
        gr.Markdown(
            """
            ### Please provide inputs in each of the four categories below.
            *All fields are required.*
            """
        )

        # (column, dropdown label) pairs per input section. Concatenated, the
        # four lists must follow predict()'s parameter order, because Gradio
        # passes the dropdown values positionally.
        cat1_col_labels = [
            ("YMDESUD5ANYO", "YMDESUD5ANYO: ONLY MDE, ONLY SUD, BOTH, OR NEITHER"),
            ("YMDELT", "YMDELT: Had major depressive episode in lifetime"),
            ("YMDEYR", "YMDEYR: Past-year major depressive episode"),
            ("YMDERSUD5ANY", "YMDERSUD5ANY: MDE or substance use disorder - past year"),
            ("YMSUD5YANY", "YMSUD5YANY: Past-year MDE & substance use disorder"),
            ("YMIUD5YANY", "YMIUD5YANY: Past-year MDE & illicit drug use disorder"),
            ("YMIMS5YANY", "YMIMS5YANY: Past-year MDE + severe impairment + substance use"),
            ("YMIMI5YANY", "YMIMI5YANY: Past-year MDE with severe impairment & illicit drug use")
        ]
        cat2_col_labels = [
            ("YMDEHPO", "YMDEHPO: Saw health prof only for MDE in past year"),
            ("YMDETXRX", "YMDETXRX: Received treatment/counseling if saw doc/prof for MDE"),
            ("YMDEHARX", "YMDEHARX: Saw health professional & received medication for MDE"),
            ("YRXMDEYR", "YRXMDEYR: Used received medication for MDE in past years"),
            ("YHLTMDE", "YHLTMDE: Saw/talked to health professional about MDE in past year"),
            ("YTXMDEYR", "YTXMDEYR: Saw or talked to doc/health prof for MDE in past year"),
            ("YDOCMDE", "YDOCMDE: Saw/talked to general practitioner/family MD about MDE"),
            ("YPSY2MDE", "YPSY2MDE: Saw/talked to psychiatrist about MDE"),
            ("YPSY1MDE", "YPSY1MDE: Saw/talked to psychologist about MDE"),
            ("YCOUNMDE", "YCOUNMDE: Saw/talked to counselor about MDE")
        ]
        cat3_col_labels = [
            ("MDEIMPY", "MDEIMPY: MDE with severe role impairment"),
            ("LVLDIFMEM2", "LVLDIFMEM2: Difficulty remembering/concentrating")
        ]
        cat4_col_labels = [
            ("YUSUITHK", "YUSUITHK: Youth seriously think about killing self in past 12 months"),
            ("YUSUITHKYR", "YUSUITHKYR: Seriously thought about killing self"),
            ("YUSUIPLNYR", "YUSUIPLNYR: Made plans to kill self in past year"),
            ("YUSUIPLN", "YUSUIPLN: Made plans to kill yourself in past 12 months")
        ]

        gr.Markdown("#### 1. Depression & Substance Use Diagnosis")
        cat1_inputs = _make_dropdowns(cat1_col_labels)

        gr.Markdown("#### 2. Mental Health Treatment & Professional Consultation")
        cat2_inputs = _make_dropdowns(cat2_col_labels)

        gr.Markdown("#### 3. Functional & Cognitive Impairment")
        cat3_inputs = _make_dropdowns(cat3_col_labels)

        gr.Markdown("#### 4. Suicidal Thoughts & Behaviors")
        cat4_inputs = _make_dropdowns(cat4_col_labels)

        all_inputs = cat1_inputs + cat2_inputs + cat3_inputs + cat4_inputs

        predict_btn = gr.Button("Predict")

        # Output components, in the order of predict()'s return tuple.
        out_pred_res = gr.Textbox(label="Prediction Results", lines=8)
        out_sev = gr.Textbox(label="Mental Health Severity", lines=2)
        out_count = gr.Markdown(label="Total Patient Count")
        out_nn = gr.Markdown(label="Nearest Neighbors Summary (Grouped by Category)")
        out_bar_input = gr.Plot(label="Input Feature Counts")
        out_bar_label = gr.Plot(label="Predicted Label Counts")

        predict_btn.click(
            fn=predict,
            inputs=all_inputs,
            outputs=[
                out_pred_res,
                out_sev,
                out_count,
                out_nn,
                out_bar_input,
                out_bar_label
            ]
        )

    # ---- Tab 2: feature vs label distribution --------------------------
    with gr.Tab("Distribution Analysis"):
        gr.Markdown("## Distribution Plot\nSelect one feature and one label column to see bar counts.")
        list_of_features = sorted(input_mapping.keys())
        list_of_labels = sorted(predictor.prediction_map.keys())

        feat_dd = gr.Dropdown(choices=list_of_features, label="Feature Column")
        lbl_dd = gr.Dropdown(choices=list_of_labels, label="Label Column")

        generate_dist_btn = gr.Button("Generate Distribution Plot")
        dist_output = gr.Plot()

        generate_dist_btn.click(
            fn=distribution_plot,
            inputs=[feat_dd, lbl_dd],
            outputs=dist_output
        )

    # ---- Tab 3: two features vs one label ------------------------------
    with gr.Tab("Co-occurrence"):
        gr.Markdown("## Co-Occurrence Plot\nSelect two features + one label to see a 3-way distribution.")

        feat1_dd = gr.Dropdown(choices=list_of_features, label="Feature 1")
        feat2_dd = gr.Dropdown(choices=list_of_features, label="Feature 2")
        label_dd = gr.Dropdown(choices=list_of_labels, label="Label Column")

        generate_btn = gr.Button("Generate Co-occurrence Plot")
        co_occ_output = gr.Plot()

        generate_btn.click(
            fn=co_occurrence_plot,
            inputs=[feat1_dd, feat2_dd, label_dd],
            outputs=co_occ_output
        )


# Start the server only when run as a script, so importing this module
# (e.g. from tests or another app) does not launch it as a side effect.
if __name__ == "__main__":
    demo.launch()
|
|