pantdipendra
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -9,6 +9,7 @@ import gradio as gr
|
|
9 |
######################################
|
10 |
df = pd.read_csv("X_train_Y_Train_merged_train.csv")
|
11 |
|
|
|
12 |
model_filenames = [
|
13 |
"YOWRCONC.pkl", "YOSEEDOC.pkl", "YO_MDEA5.pkl", "YOWRLSIN.pkl",
|
14 |
"YODPPROB.pkl", "YOWRPROB.pkl", "YODPR2WK.pkl", "YOWRDEPR.pkl",
|
@@ -17,6 +18,7 @@ model_filenames = [
|
|
17 |
]
|
18 |
model_path = "models/"
|
19 |
|
|
|
20 |
######################################
|
21 |
# 2) Model Predictor
|
22 |
######################################
|
@@ -59,7 +61,10 @@ class ModelPredictor:
|
|
59 |
return models
|
60 |
|
61 |
def make_predictions(self, user_input):
|
62 |
-
|
|
|
|
|
|
|
63 |
predictions = []
|
64 |
for model in self.models:
|
65 |
pred = model.predict(user_input)
|
@@ -67,13 +72,16 @@ class ModelPredictor:
|
|
67 |
return predictions
|
68 |
|
69 |
def get_majority_vote(self, predictions):
|
|
|
|
|
|
|
|
|
70 |
combined = np.concatenate(predictions)
|
71 |
-
|
72 |
-
|
73 |
-
return majority_vote
|
74 |
|
|
|
75 |
def evaluate_severity(self, majority_vote_count):
|
76 |
-
# Simple threshold approach
|
77 |
if majority_vote_count >= 13:
|
78 |
return "Mental Health Severity: Severe"
|
79 |
elif majority_vote_count >= 9:
|
@@ -83,6 +91,7 @@ class ModelPredictor:
|
|
83 |
else:
|
84 |
return "Mental Health Severity: Very Low"
|
85 |
|
|
|
86 |
######################################
|
87 |
# 3) Validate Inputs
|
88 |
######################################
|
@@ -92,6 +101,7 @@ def validate_inputs(*args):
|
|
92 |
return False
|
93 |
return True
|
94 |
|
|
|
95 |
######################################
|
96 |
# 4) Core Prediction
|
97 |
######################################
|
@@ -104,6 +114,26 @@ def predict(
|
|
104 |
YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
|
105 |
YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
|
106 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
user_input_data = {
|
108 |
'YNURSMDE': [int(YNURSMDE)],
|
109 |
'YMDEYR': [int(YMDEYR)],
|
@@ -137,48 +167,46 @@ def predict(
|
|
137 |
}
|
138 |
user_input = pd.DataFrame(user_input_data)
|
139 |
|
140 |
-
# 1)
|
141 |
predictions = predictor.make_predictions(user_input)
|
142 |
|
143 |
# 2) Majority vote
|
144 |
majority_vote = predictor.get_majority_vote(predictions)
|
145 |
|
146 |
-
# 3) Count
|
147 |
num_ones = sum(np.concatenate(predictions) == 1)
|
148 |
|
149 |
# 4) Severity
|
150 |
severity = predictor.evaluate_severity(num_ones)
|
151 |
|
152 |
-
# 5)
|
153 |
-
# [Same grouping logic as before, or adapt as needed]
|
154 |
groups = {
|
155 |
"Concentration_and_Decision_Making": ["YOWRCONC", "YOWRDCSN"],
|
156 |
"Sleep_and_Energy_Levels": ["YOWRHRS", "YO_MDEA5", "YOWRELES", "YO_MDEA2"],
|
157 |
-
"Mood_and_Emotional_State": ["YOWRCHR", "YOWRLSIN", "YOWRDEPR", "YODPDISC",
|
158 |
"YOLOSEV", "YODPLSIN", "YODSCEV"],
|
159 |
"Appetite_and_Weight_Changes": ["YO_MDEA3", "YOWRELES"],
|
160 |
-
"Duration_and_Severity_of_Depression_Symptoms": ["YODPPROB", "YOWRPROB",
|
161 |
-
"YODPR2WK", "YODSMMDE",
|
162 |
"YOPB2WK"]
|
163 |
}
|
|
|
164 |
grouped_text = {k: [] for k in groups}
|
165 |
-
for i,
|
166 |
col_name = model_filenames[i].split('.')[0]
|
167 |
-
pred_val =
|
168 |
if col_name in predictor.prediction_map and pred_val in [0,1]:
|
169 |
text_val = predictor.prediction_map[col_name][pred_val]
|
170 |
else:
|
171 |
text_val = f"Prediction={pred_val}"
|
172 |
-
|
173 |
-
|
174 |
for gname, gcols in groups.items():
|
175 |
if col_name in gcols:
|
176 |
grouped_text[gname].append(f"{col_name} => {text_val}")
|
177 |
-
|
178 |
break
|
179 |
-
|
180 |
-
# Or skip
|
181 |
-
pass
|
182 |
|
183 |
final_str = []
|
184 |
for gname, items in grouped_text.items():
|
@@ -190,7 +218,7 @@ def predict(
|
|
190 |
if not final_str:
|
191 |
final_str = "No predictions made. Please check inputs."
|
192 |
|
193 |
-
#
|
194 |
total_patients = len(df)
|
195 |
total_patient_markdown = (
|
196 |
f"### Total Patient Count\nThere are **{total_patients}** patients in the dataset."
|
@@ -203,18 +231,20 @@ def predict(
|
|
203 |
same_val_counts[col] = len(df[df[col] == val_])
|
204 |
bar_input_df = pd.DataFrame({"Feature": list(same_val_counts.keys()),
|
205 |
"Count": list(same_val_counts.values())})
|
206 |
-
fig_bar_input = px.bar(
|
207 |
-
|
|
|
|
|
208 |
fig_bar_input.update_layout(width=800, height=500)
|
209 |
|
210 |
# B) Bar chart for predicted labels
|
211 |
label_counts = {}
|
212 |
-
all_preds_flat = np.concatenate(predictions)
|
213 |
for i, arr in enumerate(predictions):
|
214 |
lbl_col = model_filenames[i].split('.')[0]
|
215 |
pred_val = arr[0]
|
216 |
if pred_val in [0,1]:
|
217 |
label_counts[lbl_col] = len(df[df[lbl_col] == pred_val])
|
|
|
218 |
if label_counts:
|
219 |
bar_label_df = pd.DataFrame({"Label": list(label_counts.keys()),
|
220 |
"Count": list(label_counts.values())})
|
@@ -226,9 +256,8 @@ def predict(
|
|
226 |
fig_bar_labels.update_layout(width=800, height=500)
|
227 |
|
228 |
# C) Distribution Plot (small sample)
|
229 |
-
|
230 |
-
|
231 |
-
subset_labels = [fn.split('.')[0] for fn in model_filenames[:3]]
|
232 |
dist_rows = []
|
233 |
for feat in subset_input_cols:
|
234 |
if feat not in df.columns:
|
@@ -242,37 +271,38 @@ def predict(
|
|
242 |
dist_rows.append(tmp)
|
243 |
if dist_rows:
|
244 |
big_dist_df = pd.concat(dist_rows, ignore_index=True)
|
245 |
-
fig_dist = px.bar(
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
|
|
|
|
252 |
fig_dist.update_layout(width=1000, height=700)
|
253 |
else:
|
254 |
fig_dist = px.bar(title="Distribution plot not generated.")
|
255 |
|
256 |
-
# D) Nearest
|
257 |
-
|
258 |
-
# or keep it.
|
259 |
-
# For now, let's produce an empty markdown
|
260 |
-
nearest_neighbors_markdown = "Nearest neighbors omitted here for brevity..."
|
261 |
|
262 |
-
# We won't produce a
|
|
|
263 |
|
264 |
-
# Return 8
|
265 |
return (
|
266 |
-
final_str,
|
267 |
-
severity,
|
268 |
-
total_patient_markdown,
|
269 |
-
fig_dist,
|
270 |
-
nearest_neighbors_markdown,
|
271 |
-
|
272 |
-
fig_bar_input,
|
273 |
-
fig_bar_labels
|
274 |
)
|
275 |
|
|
|
276 |
######################################
|
277 |
# 5) Input Mapping
|
278 |
######################################
|
@@ -308,20 +338,19 @@ input_mapping = {
|
|
308 |
'YMDELT': {"Yes": 1, "No": 2}
|
309 |
}
|
310 |
|
|
|
311 |
######################################
|
312 |
-
# 6) Co-Occurrence Function
|
313 |
######################################
|
314 |
def co_occurrence_plot(feature1, feature2, label_col):
|
315 |
"""
|
316 |
Generate a single co-occurrence bar chart grouping by [feature1, feature2, label_col].
|
317 |
-
We set a custom width/height so it's clearly visible.
|
318 |
"""
|
319 |
if not feature1 or not feature2 or not label_col:
|
320 |
return px.bar(title="Please select all three fields.")
|
321 |
if feature1 not in df.columns or feature2 not in df.columns or label_col not in df.columns:
|
322 |
return px.bar(title="Selected columns not found in the dataset.")
|
323 |
|
324 |
-
# Group
|
325 |
grouped_df = df.groupby([feature1, feature2, label_col]).size().reset_index(name="count")
|
326 |
fig = px.bar(
|
327 |
grouped_df,
|
@@ -334,13 +363,14 @@ def co_occurrence_plot(feature1, feature2, label_col):
|
|
334 |
fig.update_layout(width=1000, height=600)
|
335 |
return fig
|
336 |
|
|
|
337 |
######################################
|
338 |
-
# 7) Gradio with Tabs
|
339 |
######################################
|
340 |
with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
|
341 |
|
342 |
with gr.Tab("Prediction"):
|
343 |
-
#
|
344 |
YMDEYR_dd = gr.Dropdown(list(input_mapping['YMDEYR'].keys()), label="YMDEYR")
|
345 |
YMDERSUD5ANY_dd = gr.Dropdown(list(input_mapping['YMDERSUD5ANY'].keys()), label="YMDERSUD5ANY")
|
346 |
YMDEIMAD5YR_dd = gr.Dropdown(list(input_mapping['YMDEIMAD5YR'].keys()), label="YMDEIMAD5YR")
|
@@ -377,7 +407,10 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
|
|
377 |
YMSUD5YANY_dd = gr.Dropdown(list(input_mapping['YMSUD5YANY'].keys()), label="YMSUD5YANY")
|
378 |
YRXMDEYR_dd = gr.Dropdown(list(input_mapping['YRXMDEYR'].keys()), label="YRXMDEYR")
|
379 |
|
380 |
-
#
|
|
|
|
|
|
|
381 |
out_pred_res = gr.Textbox(label="Prediction Results", lines=8)
|
382 |
out_sev = gr.Textbox(label="Mental Health Severity", lines=2)
|
383 |
out_count = gr.Markdown(label="Total Patient Count")
|
@@ -387,9 +420,6 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
|
|
387 |
out_bar_input = gr.Plot(label="Input Feature Counts")
|
388 |
out_bar_labels = gr.Plot(label="Predicted Label Counts")
|
389 |
|
390 |
-
# Button
|
391 |
-
predict_btn = gr.Button("Predict")
|
392 |
-
|
393 |
# Link button to the function
|
394 |
predict_btn.click(
|
395 |
fn=predict,
|
@@ -401,10 +431,12 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
|
|
401 |
YUSUIPLN_dd, MDEIMPY_dd, LVLDIFMEM2_dd, YMSUD5YANY_dd, YRXMDEYR_dd
|
402 |
],
|
403 |
outputs=[
|
404 |
-
out_pred_res, out_sev, out_count, out_distplot,
|
|
|
405 |
]
|
406 |
)
|
407 |
|
|
|
408 |
with gr.Tab("Co-occurrence"):
|
409 |
gr.Markdown("## Generate a Co-Occurrence Plot on Demand\nSelect two features and one label:")
|
410 |
with gr.Row():
|
@@ -414,22 +446,11 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
|
|
414 |
out_co_occ_plot = gr.Plot(label="Co-occurrence Plot")
|
415 |
|
416 |
co_occ_btn = gr.Button("Generate Plot")
|
417 |
-
|
418 |
-
# Link to co_occurrence_plot function
|
419 |
co_occ_btn.click(
|
420 |
fn=co_occurrence_plot,
|
421 |
inputs=[feature1_dd, feature2_dd, label_dd],
|
422 |
outputs=out_co_occ_plot
|
423 |
)
|
424 |
|
425 |
-
#
|
426 |
-
custom_css = """
|
427 |
-
.gradio-container {
|
428 |
-
max-width: 1200px;
|
429 |
-
margin-left: auto;
|
430 |
-
margin-right: auto;
|
431 |
-
}
|
432 |
-
"""
|
433 |
-
|
434 |
-
# Launch
|
435 |
demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
9 |
######################################
|
10 |
df = pd.read_csv("X_train_Y_Train_merged_train.csv")
|
11 |
|
12 |
+
# List of model filenames (adjust if needed)
|
13 |
model_filenames = [
|
14 |
"YOWRCONC.pkl", "YOSEEDOC.pkl", "YO_MDEA5.pkl", "YOWRLSIN.pkl",
|
15 |
"YODPPROB.pkl", "YOWRPROB.pkl", "YODPR2WK.pkl", "YOWRDEPR.pkl",
|
|
|
18 |
]
|
19 |
model_path = "models/"
|
20 |
|
21 |
+
|
22 |
######################################
|
23 |
# 2) Model Predictor
|
24 |
######################################
|
|
|
61 |
return models
|
62 |
|
63 |
def make_predictions(self, user_input):
|
64 |
+
"""
|
65 |
+
Returns a list of numpy arrays, each array is [0] or [1].
|
66 |
+
The i-th array corresponds to the i-th model in self.models.
|
67 |
+
"""
|
68 |
predictions = []
|
69 |
for model in self.models:
|
70 |
pred = model.predict(user_input)
|
|
|
72 |
return predictions
|
73 |
|
74 |
def get_majority_vote(self, predictions):
|
75 |
+
"""
|
76 |
+
Flatten all predictions from all models, combine them,
|
77 |
+
then find the majority class (0 or 1).
|
78 |
+
"""
|
79 |
combined = np.concatenate(predictions)
|
80 |
+
majority = np.bincount(combined).argmax()
|
81 |
+
return majority
|
|
|
82 |
|
83 |
+
# Simple threshold approach (0-4 => Very Low, 5-8 => Low, etc.)
|
84 |
def evaluate_severity(self, majority_vote_count):
|
|
|
85 |
if majority_vote_count >= 13:
|
86 |
return "Mental Health Severity: Severe"
|
87 |
elif majority_vote_count >= 9:
|
|
|
91 |
else:
|
92 |
return "Mental Health Severity: Very Low"
|
93 |
|
94 |
+
|
95 |
######################################
|
96 |
# 3) Validate Inputs
|
97 |
######################################
|
|
|
101 |
return False
|
102 |
return True
|
103 |
|
104 |
+
|
105 |
######################################
|
106 |
# 4) Core Prediction
|
107 |
######################################
|
|
|
114 |
YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
|
115 |
YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
|
116 |
):
|
117 |
+
# Validate
|
118 |
+
if not validate_inputs(
|
119 |
+
YMDEYR, YMDERSUD5ANY, YMDEIMAD5YR, YMIMS5YANY, YMDELT, YMDEHARX,
|
120 |
+
YMDEHPRX, YMDETXRX, YMDEHPO, YMDEAUD5YR, YMIMI5YANY, YMIUD5YANY,
|
121 |
+
YMDESUD5ANYO, YNURSMDE, YSOCMDE, YCOUNMDE, YPSY1MDE, YPSY2MDE,
|
122 |
+
YHLTMDE, YDOCMDE, YTXMDEYR, YUSUITHKYR, YUSUIPLNYR, YUSUITHK,
|
123 |
+
YUSUIPLN, MDEIMPY, LVLDIFMEM2, YMSUD5YANY, YRXMDEYR
|
124 |
+
):
|
125 |
+
return (
|
126 |
+
"Please select all required fields.",
|
127 |
+
"Validation Error",
|
128 |
+
"No data",
|
129 |
+
None,
|
130 |
+
"No data",
|
131 |
+
None,
|
132 |
+
None,
|
133 |
+
None
|
134 |
+
)
|
135 |
+
|
136 |
+
# Build dataframe from user inputs
|
137 |
user_input_data = {
|
138 |
'YNURSMDE': [int(YNURSMDE)],
|
139 |
'YMDEYR': [int(YMDEYR)],
|
|
|
167 |
}
|
168 |
user_input = pd.DataFrame(user_input_data)
|
169 |
|
170 |
+
# 1) Predictions
|
171 |
predictions = predictor.make_predictions(user_input)
|
172 |
|
173 |
# 2) Majority vote
|
174 |
majority_vote = predictor.get_majority_vote(predictions)
|
175 |
|
176 |
+
# 3) Count of '1's
|
177 |
num_ones = sum(np.concatenate(predictions) == 1)
|
178 |
|
179 |
# 4) Severity
|
180 |
severity = predictor.evaluate_severity(num_ones)
|
181 |
|
182 |
+
# 5) Group textual results
|
|
|
183 |
groups = {
|
184 |
"Concentration_and_Decision_Making": ["YOWRCONC", "YOWRDCSN"],
|
185 |
"Sleep_and_Energy_Levels": ["YOWRHRS", "YO_MDEA5", "YOWRELES", "YO_MDEA2"],
|
186 |
+
"Mood_and_Emotional_State": ["YOWRCHR", "YOWRLSIN", "YOWRDEPR", "YODPDISC",
|
187 |
"YOLOSEV", "YODPLSIN", "YODSCEV"],
|
188 |
"Appetite_and_Weight_Changes": ["YO_MDEA3", "YOWRELES"],
|
189 |
+
"Duration_and_Severity_of_Depression_Symptoms": ["YODPPROB", "YOWRPROB",
|
190 |
+
"YODPR2WK", "YODSMMDE",
|
191 |
"YOPB2WK"]
|
192 |
}
|
193 |
+
|
194 |
grouped_text = {k: [] for k in groups}
|
195 |
+
for i, arr in enumerate(predictions):
|
196 |
col_name = model_filenames[i].split('.')[0]
|
197 |
+
pred_val = arr[0]
|
198 |
if col_name in predictor.prediction_map and pred_val in [0,1]:
|
199 |
text_val = predictor.prediction_map[col_name][pred_val]
|
200 |
else:
|
201 |
text_val = f"Prediction={pred_val}"
|
202 |
+
|
203 |
+
found_group = False
|
204 |
for gname, gcols in groups.items():
|
205 |
if col_name in gcols:
|
206 |
grouped_text[gname].append(f"{col_name} => {text_val}")
|
207 |
+
found_group = True
|
208 |
break
|
209 |
+
# If not found_group, we do nothing (skip or put in a "misc" group)
|
|
|
|
|
210 |
|
211 |
final_str = []
|
212 |
for gname, items in grouped_text.items():
|
|
|
218 |
if not final_str:
|
219 |
final_str = "No predictions made. Please check inputs."
|
220 |
|
221 |
+
# Additional info
|
222 |
total_patients = len(df)
|
223 |
total_patient_markdown = (
|
224 |
f"### Total Patient Count\nThere are **{total_patients}** patients in the dataset."
|
|
|
231 |
same_val_counts[col] = len(df[df[col] == val_])
|
232 |
bar_input_df = pd.DataFrame({"Feature": list(same_val_counts.keys()),
|
233 |
"Count": list(same_val_counts.values())})
|
234 |
+
fig_bar_input = px.bar(
|
235 |
+
bar_input_df, x="Feature", y="Count",
|
236 |
+
title="Number of Patients with Same Input Feature Values"
|
237 |
+
)
|
238 |
fig_bar_input.update_layout(width=800, height=500)
|
239 |
|
240 |
# B) Bar chart for predicted labels
|
241 |
label_counts = {}
|
|
|
242 |
for i, arr in enumerate(predictions):
|
243 |
lbl_col = model_filenames[i].split('.')[0]
|
244 |
pred_val = arr[0]
|
245 |
if pred_val in [0,1]:
|
246 |
label_counts[lbl_col] = len(df[df[lbl_col] == pred_val])
|
247 |
+
|
248 |
if label_counts:
|
249 |
bar_label_df = pd.DataFrame({"Label": list(label_counts.keys()),
|
250 |
"Count": list(label_counts.values())})
|
|
|
256 |
fig_bar_labels.update_layout(width=800, height=500)
|
257 |
|
258 |
# C) Distribution Plot (small sample)
|
259 |
+
subset_input_cols = list(user_input_data.keys())[:4] # first 4 input columns
|
260 |
+
subset_labels = [fn.split('.')[0] for fn in model_filenames[:3]] # first 3 label columns
|
|
|
261 |
dist_rows = []
|
262 |
for feat in subset_input_cols:
|
263 |
if feat not in df.columns:
|
|
|
271 |
dist_rows.append(tmp)
|
272 |
if dist_rows:
|
273 |
big_dist_df = pd.concat(dist_rows, ignore_index=True)
|
274 |
+
fig_dist = px.bar(
|
275 |
+
big_dist_df,
|
276 |
+
x=big_dist_df.columns[0],
|
277 |
+
y="count",
|
278 |
+
color=big_dist_df.columns[1],
|
279 |
+
facet_row="feature",
|
280 |
+
facet_col="label",
|
281 |
+
title="Distribution of Sample Input Features vs. Sample Predicted Labels"
|
282 |
+
)
|
283 |
fig_dist.update_layout(width=1000, height=700)
|
284 |
else:
|
285 |
fig_dist = px.bar(title="Distribution plot not generated.")
|
286 |
|
287 |
+
# D) Nearest neighbors (placeholder or your own logic)
|
288 |
+
nearest_neighbors_markdown = "Nearest neighbors omitted or placed here if needed..."
|
|
|
|
|
|
|
289 |
|
290 |
+
# We won't produce a co-occurrence plot by default here, so set to None
|
291 |
+
co_occurrence_placeholder = None
|
292 |
|
293 |
+
# Return the 8 outputs
|
294 |
return (
|
295 |
+
final_str, # 1) Prediction Results
|
296 |
+
severity, # 2) Mental Health Severity
|
297 |
+
total_patient_markdown, # 3) Total Patient Count
|
298 |
+
fig_dist, # 4) Distribution Plot
|
299 |
+
nearest_neighbors_markdown, # 5) Nearest Neighbors
|
300 |
+
co_occurrence_placeholder, # 6) Co-occurrence Plot placeholder
|
301 |
+
fig_bar_input, # 7) Bar Chart for input features
|
302 |
+
fig_bar_labels # 8) Bar Chart for predicted labels
|
303 |
)
|
304 |
|
305 |
+
|
306 |
######################################
|
307 |
# 5) Input Mapping
|
308 |
######################################
|
|
|
338 |
'YMDELT': {"Yes": 1, "No": 2}
|
339 |
}
|
340 |
|
341 |
+
|
342 |
######################################
|
343 |
+
# 6) Co-Occurrence Function
|
344 |
######################################
|
345 |
def co_occurrence_plot(feature1, feature2, label_col):
|
346 |
"""
|
347 |
Generate a single co-occurrence bar chart grouping by [feature1, feature2, label_col].
|
|
|
348 |
"""
|
349 |
if not feature1 or not feature2 or not label_col:
|
350 |
return px.bar(title="Please select all three fields.")
|
351 |
if feature1 not in df.columns or feature2 not in df.columns or label_col not in df.columns:
|
352 |
return px.bar(title="Selected columns not found in the dataset.")
|
353 |
|
|
|
354 |
grouped_df = df.groupby([feature1, feature2, label_col]).size().reset_index(name="count")
|
355 |
fig = px.bar(
|
356 |
grouped_df,
|
|
|
363 |
fig.update_layout(width=1000, height=600)
|
364 |
return fig
|
365 |
|
366 |
+
|
367 |
######################################
|
368 |
+
# 7) Gradio Interface with Tabs
|
369 |
######################################
|
370 |
with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
|
371 |
|
372 |
with gr.Tab("Prediction"):
|
373 |
+
# --------- INPUT FIELDS --------- #
|
374 |
YMDEYR_dd = gr.Dropdown(list(input_mapping['YMDEYR'].keys()), label="YMDEYR")
|
375 |
YMDERSUD5ANY_dd = gr.Dropdown(list(input_mapping['YMDERSUD5ANY'].keys()), label="YMDERSUD5ANY")
|
376 |
YMDEIMAD5YR_dd = gr.Dropdown(list(input_mapping['YMDEIMAD5YR'].keys()), label="YMDEIMAD5YR")
|
|
|
407 |
YMSUD5YANY_dd = gr.Dropdown(list(input_mapping['YMSUD5YANY'].keys()), label="YMSUD5YANY")
|
408 |
YRXMDEYR_dd = gr.Dropdown(list(input_mapping['YRXMDEYR'].keys()), label="YRXMDEYR")
|
409 |
|
410 |
+
# --------- PREDICT BUTTON (BEFORE OUTPUTS) --------- #
|
411 |
+
predict_btn = gr.Button("Predict")
|
412 |
+
|
413 |
+
# --------- OUTPUTS (IN THE SAME ORDER AS THE RETURN TUPLE) --------- #
|
414 |
out_pred_res = gr.Textbox(label="Prediction Results", lines=8)
|
415 |
out_sev = gr.Textbox(label="Mental Health Severity", lines=2)
|
416 |
out_count = gr.Markdown(label="Total Patient Count")
|
|
|
420 |
out_bar_input = gr.Plot(label="Input Feature Counts")
|
421 |
out_bar_labels = gr.Plot(label="Predicted Label Counts")
|
422 |
|
|
|
|
|
|
|
423 |
# Link button to the function
|
424 |
predict_btn.click(
|
425 |
fn=predict,
|
|
|
431 |
YUSUIPLN_dd, MDEIMPY_dd, LVLDIFMEM2_dd, YMSUD5YANY_dd, YRXMDEYR_dd
|
432 |
],
|
433 |
outputs=[
|
434 |
+
out_pred_res, out_sev, out_count, out_distplot,
|
435 |
+
out_nn, out_cooc, out_bar_input, out_bar_labels
|
436 |
]
|
437 |
)
|
438 |
|
439 |
+
# ------------- SECOND TAB (CO-OCCURRENCE) -------------
|
440 |
with gr.Tab("Co-occurrence"):
|
441 |
gr.Markdown("## Generate a Co-Occurrence Plot on Demand\nSelect two features and one label:")
|
442 |
with gr.Row():
|
|
|
446 |
out_co_occ_plot = gr.Plot(label="Co-occurrence Plot")
|
447 |
|
448 |
co_occ_btn = gr.Button("Generate Plot")
|
|
|
|
|
449 |
co_occ_btn.click(
|
450 |
fn=co_occurrence_plot,
|
451 |
inputs=[feature1_dd, feature2_dd, label_dd],
|
452 |
outputs=out_co_occ_plot
|
453 |
)
|
454 |
|
455 |
+
# Optionally, you can customize your CSS or server launch parameters
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
456 |
demo.launch(server_name="0.0.0.0", server_port=7860)
|