pantdipendra committed on
Commit
685722d
·
verified ·
1 Parent(s): ea4c0c9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -33
app.py CHANGED
@@ -389,36 +389,21 @@ def predict(
389
  )
390
  fig_in.update_layout(width=1200, height=400)
391
 
392
- # 8) Bar chart for predicted labels (0 and 1) - vertical stacked, actual counts
393
- label_counts_data = []
394
  for lbl_col, (pred_val, _) in label_prediction_info.items():
395
  if lbl_col in df.columns:
396
- # Count how many in df have label_col == 0 vs label_col == 1
397
- val_counts = df[lbl_col].value_counts(dropna=False)
398
- # Only gather data for 0 or 1
399
- for val, cnt in val_counts.items():
400
- if val in [0, 1]: # filter out anything else
401
- label_counts_data.append({
402
- "Label Column": lbl_col,
403
- "Label Value": str(val),
404
- "Count": cnt
405
- })
406
-
407
- if label_counts_data:
408
- bar_lbl_df = pd.DataFrame(label_counts_data)
409
  fig_lbl = px.bar(
410
- bar_lbl_df,
411
- x="Label Column",
412
- y="Count",
413
- color="Label Value",
414
- title="Stacked Bar of 0 and 1 (Exact Counts) per Predicted Label Column"
415
- )
416
- fig_lbl.update_layout(
417
- width=1200,
418
- height=400,
419
- barmode="stack", # stacked bars
420
- barnorm=None # no normalization => actual counts
421
  )
 
422
  else:
423
  fig_lbl = px.bar(title="No valid predicted labels to display.")
424
  fig_lbl.update_layout(width=1200, height=400)
@@ -429,7 +414,7 @@ def predict(
429
  total_count_md, # 3) Total Patient Count
430
  nn_md, # 4) Nearest Neighbors Summary
431
  fig_in, # 5) Bar Chart (input features)
432
- fig_lbl # 6) Bar Chart (labels: vertical stacked, counts for 0/1 only)
433
  )
434
 
435
  ######################################
@@ -440,6 +425,7 @@ def combined_plot(feature_list, label_col):
440
  If user picks 1 feature => distribution plot.
441
  If user picks 2 features => co-occurrence plot.
442
  Otherwise => show error or empty plot.
 
443
  This function also maps numeric codes to text using 'input_mapping'
444
  and 'predictor.prediction_map' so that the plots display more readable labels.
445
  """
@@ -452,10 +438,11 @@ def combined_plot(feature_list, label_col):
452
  # A) Convert numeric codes -> text for each feature in `input_mapping`
453
  for col, text_to_num_dict in input_mapping.items():
454
  if col in df_copy.columns:
 
455
  num_to_text = {v: k for k, v in text_to_num_dict.items()}
456
  df_copy[col] = df_copy[col].map(num_to_text).fillna(df_copy[col])
457
 
458
- # B) Convert label 0/1 to text if label_col is in predictor.prediction_map
459
  if label_col in predictor.prediction_map and label_col in df_copy.columns:
460
  zero_text, one_text = predictor.prediction_map[label_col]
461
  label_map = {0: zero_text, 1: one_text}
@@ -514,8 +501,8 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
514
  ("YMDERSUD5ANY", "YMDERSUD5ANY: MDE or SUD in past year?"),
515
  ("YMSUD5YANY", "YMSUD5YANY: Past-year MDE & substance use disorder"),
516
  ("YMIUD5YANY", "YMIUD5YANY: Past-year MDE & illicit drug use disorder"),
517
- ("YMIMS5YANY", "YMIUD5YANY: Past-year MDE + severe impairment + substance use"),
518
- ("YMIMI5YANY", "YIMI5YANY: Past-year MDE w/ severe impairment & illicit drug use")
519
  ]
520
  cat1_inputs = []
521
  for col, label_text in cat1_col_labels:
@@ -593,7 +580,7 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
593
  out_count = gr.Markdown(label="Total Patient Count")
594
  out_nn = gr.Markdown(label="Nearest Neighbors Summary")
595
  out_bar_input= gr.Plot(label="Input Feature Counts")
596
- out_bar_label= gr.Plot(label="Predicted Label Counts (Stacked Bar: 0 & 1)")
597
 
598
  # Connect the predict button to the predict function
599
  predict_btn.click(
@@ -613,7 +600,7 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
613
  with gr.Tab("Distribution/Co-occurrence"):
614
  gr.Markdown("### Select 1 or 2 features + 1 label to see a bar chart.")
615
 
616
- # Show only your input features
617
  list_of_features = sorted(input_mapping.keys())
618
  # Show all label columns from the predictor map
619
  list_of_labels = sorted(predictor.prediction_map.keys())
@@ -637,4 +624,4 @@ with gr.Blocks(css=".gradio-container {max-width: 1200px;}") as demo:
637
  )
638
 
639
  # Finally, launch the Gradio app
640
- demo.launch()
 
389
  )
390
  fig_in.update_layout(width=1200, height=400)
391
 
392
+ # 8) Bar chart for predicted labels
393
+ label_counts = {}
394
  for lbl_col, (pred_val, _) in label_prediction_info.items():
395
  if lbl_col in df.columns:
396
+ label_counts[lbl_col] = len(df[df[lbl_col] == pred_val])
397
+ if label_counts:
398
+ bar_lbl_df = pd.DataFrame({
399
+ "Label": list(label_counts.keys()),
400
+ "Count": list(label_counts.values())
401
+ })
 
 
 
 
 
 
 
402
  fig_lbl = px.bar(
403
+ bar_lbl_df, x="Label", y="Count",
404
+ title="Number of Patients with the Same Predicted Label"
 
 
 
 
 
 
 
 
 
405
  )
406
+ fig_lbl.update_layout(width=1200, height=400)
407
  else:
408
  fig_lbl = px.bar(title="No valid predicted labels to display.")
409
  fig_lbl.update_layout(width=1200, height=400)
 
414
  total_count_md, # 3) Total Patient Count
415
  nn_md, # 4) Nearest Neighbors Summary
416
  fig_in, # 5) Bar Chart (input features)
417
+ fig_lbl # 6) Bar Chart (labels)
418
  )
419
 
420
  ######################################
 
425
  If user picks 1 feature => distribution plot.
426
  If user picks 2 features => co-occurrence plot.
427
  Otherwise => show error or empty plot.
428
+
429
  This function also maps numeric codes to text using 'input_mapping'
430
  and 'predictor.prediction_map' so that the plots display more readable labels.
431
  """
 
438
  # A) Convert numeric codes -> text for each feature in `input_mapping`
439
  for col, text_to_num_dict in input_mapping.items():
440
  if col in df_copy.columns:
441
+ # Reverse mapping: "Yes"->1 becomes 1->"Yes"
442
  num_to_text = {v: k for k, v in text_to_num_dict.items()}
443
  df_copy[col] = df_copy[col].map(num_to_text).fillna(df_copy[col])
444
 
445
+ # B) Convert label 0/1 to text in df_copy if label_col is in predictor.prediction_map
446
  if label_col in predictor.prediction_map and label_col in df_copy.columns:
447
  zero_text, one_text = predictor.prediction_map[label_col]
448
  label_map = {0: zero_text, 1: one_text}
 
501
  ("YMDERSUD5ANY", "YMDERSUD5ANY: MDE or SUD in past year?"),
502
  ("YMSUD5YANY", "YMSUD5YANY: Past-year MDE & substance use disorder"),
503
  ("YMIUD5YANY", "YMIUD5YANY: Past-year MDE & illicit drug use disorder"),
504
+ ("YMIMS5YANY", "YMIMS5YANY: Past-year MDE + severe impairment + substance use"),
505
+ ("YMIMI5YANY", "YMIMI5YANY: Past-year MDE w/ severe impairment & illicit drug use")
506
  ]
507
  cat1_inputs = []
508
  for col, label_text in cat1_col_labels:
 
580
  out_count = gr.Markdown(label="Total Patient Count")
581
  out_nn = gr.Markdown(label="Nearest Neighbors Summary")
582
  out_bar_input= gr.Plot(label="Input Feature Counts")
583
+ out_bar_label= gr.Plot(label="Predicted Label Counts")
584
 
585
  # Connect the predict button to the predict function
586
  predict_btn.click(
 
600
  with gr.Tab("Distribution/Co-occurrence"):
601
  gr.Markdown("### Select 1 or 2 features + 1 label to see a bar chart.")
602
 
603
+ # Show only your 25 input features
604
  list_of_features = sorted(input_mapping.keys())
605
  # Show all label columns from the predictor map
606
  list_of_labels = sorted(predictor.prediction_map.keys())
 
624
  )
625
 
626
  # Finally, launch the Gradio app
627
+ demo.launch()