pantdipendra
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -389,36 +389,21 @@ def predict(
|
|
389 |
)
|
390 |
fig_in.update_layout(width=1200, height=400)
|
391 |
|
392 |
-
# 8) Bar chart for predicted labels
|
393 |
-
|
394 |
for lbl_col, (pred_val, _) in label_prediction_info.items():
|
395 |
if lbl_col in df.columns:
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
"Label Column": lbl_col,
|
403 |
-
"Label Value": str(val),
|
404 |
-
"Count": cnt
|
405 |
-
})
|
406 |
-
|
407 |
-
if label_counts_data:
|
408 |
-
bar_lbl_df = pd.DataFrame(label_counts_data)
|
409 |
fig_lbl = px.bar(
|
410 |
-
bar_lbl_df,
|
411 |
-
|
412 |
-
y="Count",
|
413 |
-
color="Label Value",
|
414 |
-
title="Stacked Bar of 0 and 1 (Exact Counts) per Predicted Label Column"
|
415 |
-
)
|
416 |
-
fig_lbl.update_layout(
|
417 |
-
width=1200,
|
418 |
-
height=400,
|
419 |
-
barmode="stack", # stacked bars
|
420 |
-
barnorm=None # no normalization => actual counts
|
421 |
)
|
|
|
422 |
else:
|
423 |
fig_lbl = px.bar(title="No valid predicted labels to display.")
|
424 |
fig_lbl.update_layout(width=1200, height=400)
|
@@ -429,7 +414,7 @@ def predict(
|
|
429 |
total_count_md, # 3) Total Patient Count
|
430 |
nn_md, # 4) Nearest Neighbors Summary
|
431 |
fig_in, # 5) Bar Chart (input features)
|
432 |
-
fig_lbl # 6) Bar Chart (labels
|
433 |
)
|
434 |
|
435 |
######################################
|
@@ -440,6 +425,7 @@ def combined_plot(feature_list, label_col):
|
|
440 |
If user picks 1 feature => distribution plot.
|
441 |
If user picks 2 features => co-occurrence plot.
|
442 |
Otherwise => show error or empty plot.
|
|
|
443 |
This function also maps numeric codes to text using 'input_mapping'
|
444 |
and 'predictor.prediction_map' so that the plots display more readable labels.
|
445 |
"""
|
@@ -452,10 +438,11 @@ def combined_plot(feature_list, label_col):
|
|
452 |
# A) Convert numeric codes -> text for each feature in `input_mapping`
|
453 |
for col, text_to_num_dict in input_mapping.items():
|
454 |
if col in df_copy.columns:
|
|
|
455 |
num_to_text = {v: k for k, v in text_to_num_dict.items()}
|
456 |
df_copy[col] = df_copy[col].map(num_to_text).fillna(df_copy[col])
|
457 |
|
458 |
-
# B) Convert label 0/1 to text if label_col is in predictor.prediction_map
|
459 |
if label_col in predictor.prediction_map and label_col in df_copy.columns:
|
460 |
zero_text, one_text = predictor.prediction_map[label_col]
|
461 |
label_map = {0: zero_text, 1: one_text}
|
@@ -514,8 +501,8 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
|
|
514 |
("YMDERSUD5ANY", "YMDERSUD5ANY: MDE or SUD in past year?"),
|
515 |
("YMSUD5YANY", "YMSUD5YANY: Past-year MDE & substance use disorder"),
|
516 |
("YMIUD5YANY", "YMIUD5YANY: Past-year MDE & illicit drug use disorder"),
|
517 |
-
("YMIMS5YANY", "
|
518 |
-
("YMIMI5YANY", "
|
519 |
]
|
520 |
cat1_inputs = []
|
521 |
for col, label_text in cat1_col_labels:
|
@@ -593,7 +580,7 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
|
|
593 |
out_count = gr.Markdown(label="Total Patient Count")
|
594 |
out_nn = gr.Markdown(label="Nearest Neighbors Summary")
|
595 |
out_bar_input= gr.Plot(label="Input Feature Counts")
|
596 |
-
out_bar_label= gr.Plot(label="Predicted Label Counts
|
597 |
|
598 |
# Connect the predict button to the predict function
|
599 |
predict_btn.click(
|
@@ -613,7 +600,7 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
|
|
613 |
with gr.Tab("Distribution/Co-occurrence"):
|
614 |
gr.Markdown("### Select 1 or 2 features + 1 label to see a bar chart.")
|
615 |
|
616 |
-
# Show only your input features
|
617 |
list_of_features = sorted(input_mapping.keys())
|
618 |
# Show all label columns from the predictor map
|
619 |
list_of_labels = sorted(predictor.prediction_map.keys())
|
@@ -637,4 +624,4 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
|
|
637 |
)
|
638 |
|
639 |
# Finally, launch the Gradio app
|
640 |
-
demo.launch()
|
|
|
389 |
)
|
390 |
fig_in.update_layout(width=1200, height=400)
|
391 |
|
392 |
+
# 8) Bar chart for predicted labels
|
393 |
+
label_counts = {}
|
394 |
for lbl_col, (pred_val, _) in label_prediction_info.items():
|
395 |
if lbl_col in df.columns:
|
396 |
+
label_counts[lbl_col] = len(df[df[lbl_col] == pred_val])
|
397 |
+
if label_counts:
|
398 |
+
bar_lbl_df = pd.DataFrame({
|
399 |
+
"Label": list(label_counts.keys()),
|
400 |
+
"Count": list(label_counts.values())
|
401 |
+
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
402 |
fig_lbl = px.bar(
|
403 |
+
bar_lbl_df, x="Label", y="Count",
|
404 |
+
title="Number of Patients with the Same Predicted Label"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
405 |
)
|
406 |
+
fig_lbl.update_layout(width=1200, height=400)
|
407 |
else:
|
408 |
fig_lbl = px.bar(title="No valid predicted labels to display.")
|
409 |
fig_lbl.update_layout(width=1200, height=400)
|
|
|
414 |
total_count_md, # 3) Total Patient Count
|
415 |
nn_md, # 4) Nearest Neighbors Summary
|
416 |
fig_in, # 5) Bar Chart (input features)
|
417 |
+
fig_lbl # 6) Bar Chart (labels)
|
418 |
)
|
419 |
|
420 |
######################################
|
|
|
425 |
If user picks 1 feature => distribution plot.
|
426 |
If user picks 2 features => co-occurrence plot.
|
427 |
Otherwise => show error or empty plot.
|
428 |
+
|
429 |
This function also maps numeric codes to text using 'input_mapping'
|
430 |
and 'predictor.prediction_map' so that the plots display more readable labels.
|
431 |
"""
|
|
|
438 |
# A) Convert numeric codes -> text for each feature in `input_mapping`
|
439 |
for col, text_to_num_dict in input_mapping.items():
|
440 |
if col in df_copy.columns:
|
441 |
+
# Reverse mapping: "Yes"->1 becomes 1->"Yes"
|
442 |
num_to_text = {v: k for k, v in text_to_num_dict.items()}
|
443 |
df_copy[col] = df_copy[col].map(num_to_text).fillna(df_copy[col])
|
444 |
|
445 |
+
# B) Convert label 0/1 to text in df_copy if label_col is in predictor.prediction_map
|
446 |
if label_col in predictor.prediction_map and label_col in df_copy.columns:
|
447 |
zero_text, one_text = predictor.prediction_map[label_col]
|
448 |
label_map = {0: zero_text, 1: one_text}
|
|
|
501 |
("YMDERSUD5ANY", "YMDERSUD5ANY: MDE or SUD in past year?"),
|
502 |
("YMSUD5YANY", "YMSUD5YANY: Past-year MDE & substance use disorder"),
|
503 |
("YMIUD5YANY", "YMIUD5YANY: Past-year MDE & illicit drug use disorder"),
|
504 |
+
("YMIMS5YANY", "YMIMS5YANY: Past-year MDE + severe impairment + substance use"),
|
505 |
+
("YMIMI5YANY", "YMIMI5YANY: Past-year MDE w/ severe impairment & illicit drug use")
|
506 |
]
|
507 |
cat1_inputs = []
|
508 |
for col, label_text in cat1_col_labels:
|
|
|
580 |
out_count = gr.Markdown(label="Total Patient Count")
|
581 |
out_nn = gr.Markdown(label="Nearest Neighbors Summary")
|
582 |
out_bar_input= gr.Plot(label="Input Feature Counts")
|
583 |
+
out_bar_label= gr.Plot(label="Predicted Label Counts")
|
584 |
|
585 |
# Connect the predict button to the predict function
|
586 |
predict_btn.click(
|
|
|
600 |
with gr.Tab("Distribution/Co-occurrence"):
|
601 |
gr.Markdown("### Select 1 or 2 features + 1 label to see a bar chart.")
|
602 |
|
603 |
+
# Show only your 25 input features
|
604 |
list_of_features = sorted(input_mapping.keys())
|
605 |
# Show all label columns from the predictor map
|
606 |
list_of_labels = sorted(predictor.prediction_map.keys())
|
|
|
624 |
)
|
625 |
|
626 |
# Finally, launch the Gradio app
|
627 |
+
demo.launch()
|