|
import os
import pickle

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
|
|
|
|
|
# Merged training features + labels; read once at import time and used by
# predict() for the patient count, the distribution plot and the neighbour
# search. The path is resolved relative to the current working directory.
df = pd.read_csv(filepath_or_buffer="X_train_Y_Train_merged_train.csv")
|
|
|
|
|
|
|
|
|
class ModelPredictor:
    """Load a set of pickled binary classifiers and combine their outputs.

    Each file in ``model_filenames`` (looked up inside ``model_path``) must
    unpickle to an object exposing ``predict(X)``. Outputs are expected to be
    0/1 labels; ``prediction_map`` turns them into human-readable text keyed by
    the model's file stem.
    """

    def __init__(self, model_path, model_filenames):
        """Store the model location and eagerly load every model.

        Args:
            model_path: Directory (str or path-like) containing the model files.
            model_filenames: File names of the pickled models, in the order
                their predictions will be returned by ``make_predictions``.
        """
        self.model_path = model_path
        self.model_filenames = model_filenames
        self.models = self.load_models()

        # Display text for each model's output, keyed by model file stem:
        # index 0 describes a 0 prediction, index 1 describes a 1 prediction.
        self.prediction_map = {
            "YOWRCONC": ["Did not have difficulty concentrating", "Had difficulty concentrating"],
            "YOSEEDOC": ["Did not feel the need to see a doctor", "Felt the need to see a doctor"],
            "YOWRHRS": ["Did not have trouble sleeping", "Had trouble sleeping"],
            "YO_MDEA5": ["Others did not notice restlessness/lethargy", "Others noticed restlessness/lethargy"],
            "YOWRCHR": ["Did not feel so sad that nothing could cheer up", "Felt so sad that nothing could cheer up"],
            "YOWRLSIN": ["Did not feel bored and lose interest in all enjoyable things",
                         "Felt bored and lost interest in all enjoyable things"],
            "YODPPROB": ["Did not have other problems for 2+ weeks", "Had other problems for 2+ weeks"],
            "YOWRPROB": ["Did not have the worst time ever feeling", "Had the worst time ever feeling"],
            "YODPR2WK": ["Did not have periods where feelings lasted 2+ weeks",
                         "Had periods where feelings lasted 2+ weeks"],
            "YOWRDEPR": ["Did not feel sad/depressed mostly everyday", "Felt sad/depressed mostly everyday"],
            "YODPDISC": ["Overall mood duration was not sad/depressed",
                         "Overall mood duration was sad/depressed (discrepancy)"],
            "YOLOSEV": ["Did not lose interest in enjoyable things and activities",
                        "Lost interest in enjoyable things and activities"],
            "YOWRDCSN": ["Was able to make decisions", "Was unable to make decisions"],
            "YODSMMDE": ["Never had depression symptoms lasting 2 weeks or longer",
                         "Had depression symptoms lasting 2 weeks or longer"],
            "YO_MDEA3": ["Did not experience changes in appetite or weight",
                         "Experienced changes in appetite or weight"],
            "YODPLSIN": ["Never lost interest and felt bored", "Lost interest and felt bored"],
            "YOWRELES": ["Did not eat less than usual", "Ate less than usual"],
            "YODSCEV": ["Had fewer severe symptoms of depression", "Had more severe symptoms of depression"],
            "YOPB2WK": ["Did not experience uneasy feelings lasting every day for 2+ weeks or longer",
                        "Experienced uneasy feelings lasting every day for 2+ weeks or longer"],
            "YO_MDEA2": ["Did not have issues with physical and mental well-being every day for 2 weeks or longer",
                         "Had issues with physical and mental well-being every day for 2 weeks or longer"],
        }

    def load_models(self):
        """Unpickle every file in ``self.model_filenames`` and return them in order.

        Raises:
            OSError: If a model file cannot be opened.
        """
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # model_path at trusted, locally produced model files.
        models = []
        for filename in self.model_filenames:
            # os.path.join works whether or not model_path ends with a
            # separator (plain concatenation silently produced a wrong path).
            filepath = os.path.join(self.model_path, filename)
            with open(filepath, 'rb') as file:
                models.append(pickle.load(file))
        return models

    def make_predictions(self, user_input):
        """Run every loaded model on ``user_input``.

        Returns a list of 1-D numpy arrays; the i-th array holds the flattened
        output of the i-th model in ``self.models``.
        """
        predictions = []
        for model in self.models:
            pred = model.predict(user_input)
            predictions.append(np.array(pred).flatten())
        return predictions

    def get_majority_vote(self, predictions):
        """Return the most frequent class across all models' predictions.

        All arrays are flattened and pooled. Values are cast to ``int`` so
        float-typed 0.0/1.0 outputs are accepted. On a tie the smaller class
        wins (``np.argmax`` returns the first maximum).

        Raises:
            ValueError: If there are no predictions to vote on.
        """
        flat = [np.asarray(p).ravel() for p in predictions]
        combined = np.concatenate(flat) if flat else np.empty(0)
        if combined.size == 0:
            raise ValueError("get_majority_vote() requires at least one prediction")
        # bincount only accepts non-negative integers.
        return np.bincount(combined.astype(int)).argmax()

    def evaluate_severity(self, majority_vote_count):
        """Bucket a positive-vote count into a severity message.

        >= 13 -> Severe, >= 9 -> Moderate, >= 5 -> Low, otherwise Very Low.
        """
        if majority_vote_count >= 13:
            return "Mental health severity: Severe"
        elif majority_vote_count >= 9:
            return "Mental health severity: Moderate"
        elif majority_vote_count >= 5:
            return "Mental health severity: Low"
        else:
            return "Mental health severity: Very Low"
|
|
|
|
|
|
|
|
|
# One pickled classifier per symptom label; each "<stem>.pkl" lives under
# model_path. The order here is the order predictions are reported in.
model_filenames = [
    f"{stem}.pkl"
    for stem in (
        "YOWRCONC", "YOSEEDOC", "YO_MDEA5", "YOWRLSIN",
        "YODPPROB", "YOWRPROB", "YODPR2WK", "YOWRDEPR",
        "YODPDISC", "YOLOSEV", "YOWRDCSN", "YODSMMDE",
        "YO_MDEA3", "YODPLSIN", "YOWRELES", "YOPB2WK",
    )
]

model_path = "models/"
|
# Shared predictor instance; its constructor loads every model once, at import time.
predictor = ModelPredictor(model_path=model_path, model_filenames=model_filenames)
|
|
|
|
|
|
|
|
|
def validate_inputs(*args):
    """Return True when every answer is filled in.

    An answer counts as missing when it is ``None`` or the empty string.
    With no arguments the result is True.
    """
    return not any(value is None or value == '' for value in args)
|
|
|
|
|
|
|
|
|
|
|
# Dropdown answer text -> numeric code fed to the models, one entry per input
# column. Inner dict order matters: it is the order of the dropdown choices.
input_mapping = {
    'YNURSMDE': {"Yes": 1, "No": 0},
    'YMDEYR': {"Yes": 1, "No": 2},
    'YSOCMDE': {"Yes": 1, "No": 0},
    'YMDESUD5ANYO': {"SUD only, no MDE": 1, "MDE only, no SUD": 2, "SUD and MDE": 3, "Neither SUD or MDE": 4},
    'YMSUD5YANY': {"Yes": 1, "No": 0},
    'YUSUITHK': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
    'YMDETXRX': {"Yes": 1, "No": 0},
    'YUSUITHKYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
    'YMDERSUD5ANY': {"Yes": 1, "No": 0},
    'YUSUIPLNYR': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
    'YCOUNMDE': {"Yes": 1, "No": 0},
    'YPSY1MDE': {"Yes": 1, "No": 0},
    'YHLTMDE': {"Yes": 1, "No": 0},
    'YDOCMDE': {"Yes": 1, "No": 0},
    'YPSY2MDE': {"Yes": 1, "No": 0},
    'YMDEHARX': {"Yes": 1, "No": 0},
    'LVLDIFMEM2': {"No Difficulty": 1, "Some difficulty": 2, "A lot of difficulty or cannot do at all": 3},
    'MDEIMPY': {"Yes": 1, "No": 2},
    'YMDEHPO': {"Yes": 1, "No": 0},
    'YMIMS5YANY': {"Yes": 1, "No": 0},
    'YMDEIMAD5YR': {"Yes": 1, "No": 0},
    'YMIUD5YANY': {"Yes": 1, "No": 0},
    'YMDEHPRX': {"Yes": 1, "No": 0},
    'YMIMI5YANY': {"Yes": 1, "No": 0},
    'YUSUIPLN': {"Yes": 1, "No": 2, "I'm not sure": 3, "I don't want to answer": 4},
    'YTXMDEYR': {"Yes": 1, "No": 0},
    'YMDEAUD5YR': {"Yes": 1, "No": 0},
    'YRXMDEYR': {"Yes": 1, "No": 0},
    'YMDELT': {"Yes": 1, "No": 2},
}

# Inverse lookup (numeric code -> answer text) per column, used to render
# dataset values as the same text the user picked from.
reverse_mapping = {
    column: {code: answer for answer, code in answers.items()}
    for column, answers in input_mapping.items()
}
|
|
|
|
|
|
|
|
|
# Label column whose distribution and nearest-neighbour breakdown are shown.
_DISTRIBUTION_LABEL = "YOWRCONC"

# Number of nearest neighbours summarised in the UI.
_NEIGHBOR_COUNT = 5

# Display groups for the per-model predictions (underscores become spaces in
# the output). Each model stem is listed in exactly one group; YOWRELES
# ("ate less than usual") is grouped with appetite changes.
_PREDICTION_GROUPS = {
    "Concentration_and_Decision_Making": ["YOWRCONC", "YOWRDCSN"],
    "Sleep_and_Energy_Levels": ["YOWRHRS", "YO_MDEA5", "YO_MDEA2"],
    "Mood_and_Emotional_State": ["YOWRCHR", "YOWRLSIN", "YOWRDEPR", "YODPDISC",
                                 "YOLOSEV", "YODPLSIN", "YODSCEV"],
    "Appetite_and_Weight_Changes": ["YO_MDEA3", "YOWRELES"],
    "Duration_and_Severity_of_Depression_Symptoms": ["YODPPROB", "YOWRPROB",
                                                     "YODPR2WK", "YODSMMDE",
                                                     "YOPB2WK"],
}

# Catch-all group so models missing from _PREDICTION_GROUPS (e.g. YOSEEDOC)
# are still reported instead of being silently dropped.
_OTHER_GROUP = "Other_Indicators"


def _binary_label_text(label_name, value):
    """Return the display text for a 0/1 ``value`` of ``label_name``, or None.

    None is returned when the label has no ``predictor.prediction_map`` entry
    or the value is not 0/1. ``int(value)`` lets float-typed 0.0/1.0 values
    (e.g. from a float CSV column) index the text list.
    """
    if label_name in predictor.prediction_map and value in (0, 1):
        return predictor.prediction_map[label_name][int(value)]
    return None


def _format_model_results(predictions):
    """Group per-model prediction texts and render the results text box.

    Returns ``(formatted_text, num_unknown)``; ``num_unknown`` is the number
    of models whose output could not be mapped to a description.
    """
    grouped = {name: [] for name in _PREDICTION_GROUPS}
    grouped[_OTHER_GROUP] = []
    num_unknown = 0

    for filename, pred in zip(predictor.model_filenames, predictions):
        model_name = filename.split('.')[0]
        text = _binary_label_text(model_name, pred[0]) if len(pred) else None
        if text is None:
            num_unknown += 1
            result_text = f"Model {model_name}: Unknown or out-of-range"
        else:
            result_text = f"Model {model_name}: {text}"
        group = next(
            (g for g, members in _PREDICTION_GROUPS.items() if model_name in members),
            _OTHER_GROUP,
        )
        grouped[group].append(result_text)

    lines = []
    for group, texts in grouped.items():
        if texts:
            lines.append(f"Group {group.replace('_', ' ')}:")
            lines.append("\n".join(texts))
            lines.append("\n")
    formatted = "\n".join(lines).strip()
    if not formatted:
        formatted = "No predictions made. Please check your inputs."
    return formatted, num_unknown


def _build_distribution_figure(feature_columns):
    """Return a faceted bar chart of each input feature split by _DISTRIBUTION_LABEL.

    Counts come from the training dataframe ``df``; feature values are shown
    with their answer text from ``reverse_mapping`` where available. An empty,
    titled figure is returned when the label column is missing from ``df``.
    """
    label = _DISTRIBUTION_LABEL
    if label not in df.columns:
        return px.bar(title=f"Label {label} not found in dataset. Distribution not available.")

    input_cols_in_df = [c for c in feature_columns if c in df.columns]
    melted = df[input_cols_in_df + [label]].melt(
        id_vars=[label],
        var_name="FeatureName",
        value_name="FeatureValue",
    )
    dist_data = (
        melted.groupby(["FeatureName", "FeatureValue", label])
        .size()
        .reset_index(name="count")
    )

    # Comprehensions (rather than DataFrame.apply) also behave on an empty frame.
    dist_data["FeatureValueText"] = [
        reverse_mapping.get(name, {}).get(value, value)
        for name, value in zip(dist_data["FeatureName"], dist_data["FeatureValue"])
    ]
    if label in predictor.prediction_map:
        label_texts = []
        for value in dist_data[label]:
            text = _binary_label_text(label, value)
            label_texts.append(text if text is not None else f"Unknown label {value}")
        dist_data["LabelText"] = label_texts
    else:
        dist_data["LabelText"] = dist_data[label].astype(str)

    fig = px.bar(
        dist_data,
        x="FeatureValueText",
        y="count",
        color="LabelText",
        facet_col="FeatureName",
        facet_col_wrap=4,
        title=f"Distribution of All Input Features vs. {label}",
        height=800,
    )
    fig.update_layout(legend=dict(title=label))
    fig.update_xaxes(tickangle=45)
    return fig


def _nearest_neighbor_markdown(user_row, feature_columns):
    """Summarise the training rows closest to ``user_row`` as Markdown.

    Distance is the Hamming distance: the number of compared feature columns
    whose value differs from the user's answer code.
    """
    label = _DISTRIBUTION_LABEL
    features = [c for c in feature_columns if c in df.columns]

    # Vectorised Hamming distance (one pass in pandas instead of a Python
    # loop over every row and column).
    distances = df[features].ne(user_row[features], axis=1).sum(axis=1)
    # Stable sort keeps ties in dataset order so results are reproducible.
    nearest = (
        df.assign(distance=distances)
        .sort_values("distance", kind="stable")
        .head(_NEIGHBOR_COUNT)
    )

    nn_label_0 = nn_label_1 = 0
    label0_text, label1_text = "Label=0", "Label=1"
    if label in nearest.columns:
        nn_label_0 = int((nearest[label] == 0).sum())
        nn_label_1 = int((nearest[label] == 1).sum())
        if label in predictor.prediction_map:
            label0_text = predictor.prediction_map[label][0]
            label1_text = predictor.prediction_map[label][1]

    # to_dict("records") keeps per-column types; iterrows() would upcast every
    # value of a row to float when any column is float.
    neighbor_text_rows = []
    for record in nearest.to_dict("records"):
        parts = [f"distance={int(record['distance'])}"]
        for fcol in features:
            val = record[fcol]
            parts.append(f"{fcol}={reverse_mapping.get(fcol, {}).get(val, val)}")
        if label in record:
            lbl_val = record[label]
            lbl_text = _binary_label_text(label, lbl_val)
            parts.append(f"{label}={lbl_text if lbl_text is not None else lbl_val}")
        neighbor_text_rows.append(" | ".join(parts))
    neighbor_text_block = "\n".join(neighbor_text_rows)

    return (
        "### Nearest Neighbors (Simple Hamming Distance)\n"
        "“Nearest neighbor” methods for high-dimensional or purely categorical data can be non-trivial. "
        f"This demo simply uses a Hamming distance over all input features and picks **K={_NEIGHBOR_COUNT}** neighbors.\n\n"
        "In a real application, you would refine which features are most relevant, how to encode them, "
        "and how many neighbors to select.\n\n"
        f"Among these **{len(nearest)}** nearest neighbors:\n"
        f"- **{nn_label_0}** had {label0_text}\n"
        f"- **{nn_label_1}** had {label1_text}\n\n"
        "Below is a breakdown of each neighbor's key features in user-friendly text:\n\n"
        f"```\n{neighbor_text_block}\n```"
    )


def predict(
    YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
    YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
    YMDESUD5ANYO, YNURSMDE, YSOCMDE, YCOUNMDE, YPSY1MDE, YPSY2MDE,
    YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
    YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
):
    """Run every model on one set of encoded answers and build the UI outputs.

    Each parameter is the numeric code (anything ``int()`` accepts) for the
    input column of the same name. Uses the module-level ``predictor``,
    ``df`` and ``reverse_mapping``.

    Returns an 8-tuple:
        grouped per-model results text, severity text, patient-count Markdown,
        feature-distribution figure, nearest-neighbour Markdown, and three
        ``None`` placeholders for the remaining plot outputs.
    """
    # Column order matters: it must match the feature order the models were
    # trained on.
    user_input_data = {
        'YNURSMDE': [int(YNURSMDE)],
        'YMDEYR': [int(YMDEYR)],
        'YSOCMDE': [int(YSOCMDE)],
        'YMDESUD5ANYO': [int(YMDESUD5ANYO)],
        'YMSUD5YANY': [int(YMSUD5YANY)],
        'YUSUITHK': [int(YUSUITHK)],
        'YMDETXRX': [int(YMDETXRX)],
        'YUSUITHKYR': [int(YUSUITHKYR)],
        'YMDERSUD5ANY': [int(YMDERSUD5ANY)],
        'YUSUIPLNYR': [int(YUSUIPLNYR)],
        'YCOUNMDE': [int(YCOUNMDE)],
        'YPSY1MDE': [int(YPSY1MDE)],
        'YHLTMDE': [int(YHLTMDE)],
        'YDOCMDE': [int(YDOCMDE)],
        'YPSY2MDE': [int(YPSY2MDE)],
        'YMDEHARX': [int(YMDEHARX)],
        'LVLDIFMEM2': [int(LVLDIFMEM2)],
        'MDEIMPY': [int(MDEIMPY)],
        'YMDEHPO': [int(YMDEHPO)],
        'YMIMS5YANY': [int(YMIMS5YANY)],
        'YMDEIMAD5YR': [int(YMDEIMAD5YR)],
        'YMIUD5YANY': [int(YMIUD5YANY)],
        'YMDEHPRX': [int(YMDEHPRX)],
        'YMIMI5YANY': [int(YMIMI5YANY)],
        'YUSUIPLN': [int(YUSUIPLN)],
        'YTXMDEYR': [int(YTXMDEYR)],
        'YMDEAUD5YR': [int(YMDEAUD5YR)],
        'YRXMDEYR': [int(YRXMDEYR)],
        'YMDELT': [int(YMDELT)],
    }
    user_input = pd.DataFrame(user_input_data)

    predictions = predictor.make_predictions(user_input)

    # Number of predicted 1s across all models drives the severity bucket.
    positive_count = sum(int(np.count_nonzero(np.asarray(p) == 1)) for p in predictions)
    severity = predictor.evaluate_severity(positive_count)

    formatted_results, num_unknown = _format_model_results(predictions)
    # Counts unknown *models* (not groups), so this warning can actually fire.
    if num_unknown > len(predictor.model_filenames) / 2:
        severity += " (Unknown prediction count is high. Please consult with a human.)"

    total_patient_count_markdown = (
        "### Total Patient Count\n"
        f"There are **{len(df)}** total patients in the dataset.\n\n"
        "This number helps you understand the size of the dataset used."
    )

    feature_columns = list(user_input_data)
    fig_distribution = _build_distribution_figure(feature_columns)
    similar_patient_markdown = _nearest_neighbor_markdown(user_input.iloc[0], feature_columns)

    return (
        formatted_results,
        severity,
        total_patient_count_markdown,
        fig_distribution,
        similar_patient_markdown,
        None,  # placeholder plot outputs
        None,
        None,
    )
|
|
|
|
|
|
|
|
|
def predict_with_text(
    YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
    YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
    YMDESUD5ANYO, YNURSMDE, YSOCMDE, YCOUNMDE, YPSY1MDE, YPSY2MDE,
    YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
    YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
):
    """Gradio entry point: encode the dropdown answers and call predict().

    Each parameter is the answer text chosen for the input column of the same
    name. If any answer is missing (None or ''), an 8-tuple of placeholder
    outputs is returned instead of running the models. An answer that is not
    a key of ``input_mapping[column]`` raises KeyError.
    """
    # Snapshot the parameters before any other local exists, so the dict holds
    # exactly the answers keyed by column name.
    answers = dict(locals())

    if not validate_inputs(*answers.values()):
        return (
            "Please select all required fields.",
            "Validation Error",
            "No data",
            None,
            "No data",
            None, None, None
        )

    # Answer text -> numeric survey code, one entry per input column.
    encoded = {
        column: input_mapping[column][answers[column]]
        for column in input_mapping
    }
    return predict(**encoded)
|
|
|
|
|
|
|
|
|
import sys |
|
|
|
|
|
# Output components, in the same order as the tuple returned by predict().
outputs = [
    gr.Textbox(label="Prediction Results", lines=30),
    gr.Textbox(label="Mental Health Severity", lines=4),
    gr.Markdown(label="Total Patient Count"),
    gr.Plot(label="Distribution of All Input Features vs. One Label"),
    gr.Markdown(label="Nearest Neighbors Summary"),
    # Three slots that predict() currently always fills with None.
    *[gr.Plot(label="Placeholder Plot") for _ in range(3)],
]
|
|
|
|
|
# One dropdown per survey question. The order must match the positional
# parameters of predict_with_text; each label reads "<COLUMN>: <question>".
inputs = [
    gr.Dropdown(list(input_mapping[column].keys()), label=f"{column}: {question}")
    for column, question in (
        ("YMDEYR", "PAST YEAR MDE?"),
        ("YMDERSUD5ANY", "MDE OR SUBSTANCE USE DISORDER - ANY"),
        ("YMDEIMAD5YR", "MDE + ALCOHOL USE DISORDER?"),
        ("YMIMS5YANY", "MDE + SUBSTANCE USE DISORDER?"),
        ("YMDELT", "EVER HAD MDE LIFETIME?"),
        ("YMDEHARX", "SAW HEALTH PROF + MEDS FOR MDE"),
        ("YMDEHPRX", "SAW HEALTH PROF OR MEDS FOR MDE"),
        ("YMDETXRX", "TREATMENT/COUNSELING FOR MDE"),
        ("YMDEHPO", "HEALTH PROF ONLY FOR MDE"),
        ("YMDEAUD5YR", "MDE + ALCOHOL USE DISORDER"),
        ("YMIMI5YANY", "MDE + ILL DRUG USE DISORDER"),
        ("YMIUD5YANY", "MDE + ILL DRUG USE DISORDER"),
        ("YMDESUD5ANYO", "MDE vs. SUD vs. BOTH vs. NEITHER"),
        # Who the respondent saw for MDE.
        ("YNURSMDE", "NURSE / OT FOR MDE"),
        ("YSOCMDE", "SOCIAL WORKER FOR MDE"),
        ("YCOUNMDE", "COUNSELOR FOR MDE"),
        ("YPSY1MDE", "PSYCHOLOGIST FOR MDE"),
        ("YPSY2MDE", "PSYCHIATRIST FOR MDE"),
        ("YHLTMDE", "HEALTH PROF FOR MDE"),
        ("YDOCMDE", "GP/FAMILY MD FOR MDE"),
        ("YTXMDEYR", "DOCTOR/HEALTH PROF FOR MDE THIS YEAR"),
        # Suicidal thoughts and plans.
        ("YUSUITHKYR", "SERIOUSLY THOUGHT ABOUT KILLING SELF"),
        ("YUSUIPLNYR", "MADE PLANS TO KILL SELF"),
        ("YUSUITHK", "THINK ABOUT KILLING SELF (12 MONTHS)"),
        ("YUSUIPLN", "MADE PLANS TO KILL SELF (12 MONTHS)"),
        # Impairment, substance use and medication.
        ("MDEIMPY", "MDE WITH SEVERE ROLE IMPAIRMENT?"),
        ("LVLDIFMEM2", "DIFFICULTY REMEMBERING/CONCENTRATING"),
        ("YMSUD5YANY", "MDE + SUBSTANCE USE DISORDER?"),
        ("YRXMDEYR", "USED MEDS FOR MDE IN PAST YEAR?"),
    )
]
|
|
|
|
|
# Stylesheet passed to gr.Interface(css=...). It forces the text colour
# #1B1212 (with !important) on every element inside .gradio-container; the
# later rules repeat that colour for form labels and output textboxes.
custom_css = """
.gradio-container * {
    color: #1B1212 !important;
}
.gradio-container .form .form-group label {
    color: #1B1212 !important;
}
.gradio-container .output-textbox,
.gradio-container .output-textbox textarea {
    color: #1B1212 !important;
}
.gradio-container .label,
.gradio-container .input-label {
    color: #1B1212 !important;
}
"""
|
|
|
|
|
def _build_interface():
    """Assemble the screening form: dropdown inputs -> predict_with_text -> outputs."""
    return gr.Interface(
        fn=predict_with_text,
        inputs=inputs,
        outputs=outputs,
        title="Adolescents with Substance Use Mental Health Screening (NSDUH Data)",
        css=custom_css,
    )


interface = _build_interface()

# Only start the web server when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()
|
|