color , status, probablity sum added
Browse files
app.py
CHANGED
@@ -83,9 +83,6 @@ class ModelPredictor:
|
|
83 |
|
84 |
# If model can do predict_proba, we interpret the "2" class as the second column
|
85 |
if hasattr(model, "predict_proba"):
|
86 |
-
# Usually classes_ might be something like [0,1], but we want [1,2].
|
87 |
-
# You may need to check model.classes_ or do a small transformation.
|
88 |
-
# For demonstration, we'll assume the second column is the "2" probability:
|
89 |
y_prob_2 = model.predict_proba(user_input)[:, 1]
|
90 |
probs.append(y_prob_2)
|
91 |
else:
|
@@ -129,8 +126,7 @@ categories_dict = {
|
|
129 |
]
|
130 |
}
|
131 |
|
132 |
-
# NOTE: input_mapping below for capturing user choices => numeric codes.
|
133 |
-
# If the models expect [1,2] for "Yes"/"No", ensure those are correct here too.
|
134 |
input_mapping = {
|
135 |
'YMDESUD5ANYO': {
|
136 |
"SUD only, no MDE": 1,
|
@@ -330,21 +326,24 @@ def predict(
|
|
330 |
count_ones = np.sum(all_preds == 1)
|
331 |
|
332 |
# Evaluate severity using count_ones
|
333 |
-
|
334 |
-
|
335 |
-
#
|
336 |
-
#
|
|
|
|
|
|
|
|
|
337 |
sum_prob_2 = sum(prob[0] for prob in probs if not np.isnan(prob[0]))
|
338 |
-
|
339 |
-
severity_msg = f"{
|
340 |
|
341 |
# 4) Summarize predictions (with probabilities)
|
342 |
-
# Build label -> (pred_value, prob_value)
|
343 |
label_prediction_info = {}
|
344 |
for i, fname in enumerate(model_filenames):
|
345 |
lbl_col = fname.split('.')[0]
|
346 |
pred_val = preds[i][0] # e.g. 1 or 2
|
347 |
-
prob_val = probs[i][0] # probability for class=2
|
348 |
label_prediction_info[lbl_col] = (pred_val, prob_val)
|
349 |
|
350 |
# Group them by domain
|
@@ -379,7 +378,7 @@ def predict(
|
|
379 |
icon = "❌" # red cross
|
380 |
|
381 |
if not np.isnan(prob_val):
|
382 |
-
text_prob = f"(Prob= {prob_val:.2f})"
|
383 |
else:
|
384 |
text_prob = "(No probability available)"
|
385 |
|
@@ -426,14 +425,22 @@ def predict(
|
|
426 |
"Count": list(label_counts.values()),
|
427 |
"Pred_Val": [label_prediction_info[lbl_col][0] for lbl_col in label_counts.keys()]
|
428 |
})
|
429 |
-
# Assign
|
430 |
-
|
|
|
|
|
|
|
|
|
|
|
431 |
fig_lbl = px.bar(
|
432 |
bar_lbl_df,
|
433 |
x="Label",
|
434 |
y="Count",
|
435 |
-
color="
|
436 |
-
color_discrete_map={
|
|
|
|
|
|
|
437 |
title="Number of Patients with the Same Predicted Label"
|
438 |
)
|
439 |
fig_lbl.update_layout(width=1200, height=400)
|
@@ -443,7 +450,7 @@ def predict(
|
|
443 |
|
444 |
return (
|
445 |
final_str, # 1) Prediction Results
|
446 |
-
severity_msg, # 2) Mental Health Severity
|
447 |
total_count_md, # 3) Total Patient Count
|
448 |
nn_md, # 4) Nearest Neighbors Summary
|
449 |
fig_in, # 5) Bar Chart (input features)
|
|
|
83 |
|
84 |
# If model can do predict_proba, we interpret the "2" class as the second column
|
85 |
if hasattr(model, "predict_proba"):
|
|
|
|
|
|
|
86 |
y_prob_2 = model.predict_proba(user_input)[:, 1]
|
87 |
probs.append(y_prob_2)
|
88 |
else:
|
|
|
126 |
]
|
127 |
}
|
128 |
|
129 |
+
# NOTE: input_mapping below for capturing user choices => numeric codes.
|
|
|
130 |
input_mapping = {
|
131 |
'YMDESUD5ANYO': {
|
132 |
"SUD only, no MDE": 1,
|
|
|
326 |
count_ones = np.sum(all_preds == 1)
|
327 |
|
328 |
# Evaluate severity using count_ones
|
329 |
+
severity_base = predictor.evaluate_severity(count_ones)
|
330 |
+
|
331 |
+
# -------------------------------
|
332 |
+
# Sum of predicted probabilities
|
333 |
+
# -------------------------------
|
334 |
+
# 'probs' is a list of arrays; each array is the prob for class=2 from each model.
|
335 |
+
# If classes_=[1,2], then prob[0] is P(class=2). Probability of class=1 is (1 - P(class=2)).
|
336 |
+
# We'll sum them:
|
337 |
sum_prob_2 = sum(prob[0] for prob in probs if not np.isnan(prob[0]))
|
338 |
+
sum_prob_1 = sum((1 - prob[0]) for prob in probs if not np.isnan(prob[0]))
|
339 |
+
severity_msg = f"{severity_base} (Sum of Prob=1={sum_prob_1:.2f}, Prob=2={sum_prob_2:.2f})"
|
340 |
|
341 |
# 4) Summarize predictions (with probabilities)
|
|
|
342 |
label_prediction_info = {}
|
343 |
for i, fname in enumerate(model_filenames):
|
344 |
lbl_col = fname.split('.')[0]
|
345 |
pred_val = preds[i][0] # e.g. 1 or 2
|
346 |
+
prob_val = probs[i][0] # probability for class=2
|
347 |
label_prediction_info[lbl_col] = (pred_val, prob_val)
|
348 |
|
349 |
# Group them by domain
|
|
|
378 |
icon = "❌" # red cross
|
379 |
|
380 |
if not np.isnan(prob_val):
|
381 |
+
text_prob = f"(Prob= {prob_val:.2f} for class=2)"
|
382 |
else:
|
383 |
text_prob = "(No probability available)"
|
384 |
|
|
|
425 |
"Count": list(label_counts.values()),
|
426 |
"Pred_Val": [label_prediction_info[lbl_col][0] for lbl_col in label_counts.keys()]
|
427 |
})
|
428 |
+
# Assign legend text & color based on predicted value
|
429 |
+
# - 2 => "Ok Mental Status" (green)
|
430 |
+
# - 1 => "Bad Mental Status" (red)
|
431 |
+
bar_lbl_df["Mental Status"] = bar_lbl_df["Pred_Val"].apply(
|
432 |
+
lambda x: "Ok Mental Status" if x == 2 else "Bad Mental Status"
|
433 |
+
)
|
434 |
+
|
435 |
fig_lbl = px.bar(
|
436 |
bar_lbl_df,
|
437 |
x="Label",
|
438 |
y="Count",
|
439 |
+
color="Mental Status",
|
440 |
+
color_discrete_map={
|
441 |
+
"Ok Mental Status": "green",
|
442 |
+
"Bad Mental Status": "red"
|
443 |
+
},
|
444 |
title="Number of Patients with the Same Predicted Label"
|
445 |
)
|
446 |
fig_lbl.update_layout(width=1200, height=400)
|
|
|
450 |
|
451 |
return (
|
452 |
final_str, # 1) Prediction Results
|
453 |
+
severity_msg, # 2) Mental Health Severity (with sums of prob=1 & prob=2)
|
454 |
total_count_md, # 3) Total Patient Count
|
455 |
nn_md, # 4) Nearest Neighbors Summary
|
456 |
fig_in, # 5) Bar Chart (input features)
|