import os
import torchvision
import shap
import torch
import numpy as np
import pathlib
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib import rcParams
from sklearn.metrics import (
    classification_report,
    precision_recall_curve,
    accuracy_score,
    f1_score,
    confusion_matrix,
    ConfusionMatrixDisplay,
    roc_curve,
    auc,
    average_precision_score,
)
from sklearn.preprocessing import label_binarize
from configs import *
from data_loader import load_data # Import the load_data function
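# NOTE: MODEL, DEVICE, MODEL_SAVE_PATH, NUM_CLASSES, CLASSES and the `preprocess`
# transform used below are all expected to come from the wildcard configs import above.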
# MobileNet: 0.8731189445475158
# EfficientNet: 0.873118944547516
# SqueezeNet: 0.8865856365856365
rcParams['font.family'] = 'Times New Roman'
# Load the model
model = MODEL.to(DEVICE)
model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
model.eval()
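# Optional: the commented-out block below builds a weighted-vote ensemble of the
# three individual models and can be evaluated in place of the single model above.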
# model2 = EfficientNetB3WithDropout(num_classes=NUM_CLASSES).to(DEVICE)
# model2.load_state_dict(torch.load("output/checkpoints/EfficientNetB3WithDropout.pth"))
# model1 = SqueezeNet1_0WithSE(num_classes=NUM_CLASSES).to(DEVICE)
# model1.load_state_dict(torch.load("output/checkpoints/SqueezeNet1_0WithSE.pth"))
# model3 = MobileNetV2WithDropout(num_classes=NUM_CLASSES).to(DEVICE)
# model3.load_state_dict(torch.load("output/checkpoints/MobileNetV2WithDropout.pth"))
# model1.eval()
# model2.eval()
# model3.eval()
# # Load the model
# model = WeightedVoteEnsemble([model1, model2, model3], [0.38, 0.34, 0.28])
# # model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
# model.load_state_dict(
# torch.load("output/checkpoints/WeightedVoteEnsemble.pth", map_location=DEVICE)
# )
# model.eval()
def predict_image(image_path, model, transform):
    model.eval()
    correct_predictions = 0
    # Get a list of image files
    images = list(pathlib.Path(image_path).rglob("*.png"))
    total_predictions = len(images)
    true_classes = []
    predicted_labels = []
    predicted_scores = []  # To store predicted class probabilities
    with torch.no_grad():
        for image_file in images:
            print("---------------------------")
            # Derive the true class from the parent folder name (folders are named after the classes)
            true_class = CLASSES.index(image_file.parts[-2])
            print("Image path:", image_file)
            print("True class:", true_class)
            image = Image.open(image_file).convert("RGB")
            image = transform(image).unsqueeze(0)
            image = image.to(DEVICE)
            output = model(image)
            predicted_class = torch.argmax(output, dim=1).item()
            # Print the predicted class
            print("Predicted class:", predicted_class)
            # Append true and predicted labels to their respective lists
            true_classes.append(true_class)
            predicted_labels.append(predicted_class)
            predicted_scores.append(
                output.softmax(dim=1).cpu().numpy()
            )  # Store predicted class probabilities
            # Check if the prediction is correct
            if predicted_class == true_class:
                correct_predictions += 1
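    # Note: correct_predictions / total_predictions are tallied above but not used
    # further here; accuracy is computed below with sklearn's accuracy_score instead.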
    # Calculate accuracy and F1 score
    accuracy = accuracy_score(true_classes, predicted_labels)
    print("Accuracy:", accuracy)
    f1 = f1_score(true_classes, predicted_labels, average="weighted")
    print("Weighted F1 Score:", f1)
    # Convert the lists to tensors
    predicted_labels_tensor = torch.tensor(predicted_labels)
    true_classes_tensor = torch.tensor(true_classes)
    # Calculate the confusion matrix
    conf_matrix = confusion_matrix(true_classes, predicted_labels)
    # Plot the confusion matrix
    ConfusionMatrixDisplay(
        confusion_matrix=conf_matrix, display_labels=range(NUM_CLASSES)
    ).plot(cmap=plt.cm.Blues)
    plt.title("Confusion Matrix")
    plt.show()
    # Classification report
    class_names = [str(cls) for cls in range(NUM_CLASSES)]
    report = classification_report(
        true_classes, predicted_labels, target_names=class_names
    )
    print("Classification Report:\n", report)
    # Binarize the labels one-vs-rest and flatten both labels and predicted
    # probabilities so the PR and ROC curves below are micro-averaged over all
    # classes (with N images, both raveled arrays have length N * NUM_CLASSES,
    # assuming a multi-class setup where NUM_CLASSES > 2).
    true_classes_binary = label_binarize(true_classes, classes=range(NUM_CLASSES))
    predicted_scores_flat = np.array(predicted_scores).ravel()
    precision, recall, _ = precision_recall_curve(
        true_classes_binary.ravel(), predicted_scores_flat
    )
    fpr, tpr, _ = roc_curve(
        true_classes_binary.ravel(), predicted_scores_flat
    )
    auc_roc = auc(fpr, tpr)
    print("AUC-ROC:", auc_roc)
    # Calculate PRC AUC (average precision)
    auc_prc = average_precision_score(
        true_classes_binary.ravel(), predicted_scores_flat
    )
    print("AUC PRC:", auc_prc)
    # Plot precision-recall curve
    plt.figure(figsize=(10, 6))
    plt.plot(recall, precision)
    plt.title("Precision-Recall Curve")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    # Show the AUC value on the plot
    plt.text(
        0.6,
        0.2,
        "AUC-PRC = {:.3f}".format(auc_prc),
        bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
    )
    plt.savefig("docs/efficientnet/prc.png")
    plt.show()
    # Plot ROC curve
    plt.figure(figsize=(10, 6))
    plt.plot(fpr, tpr)
    plt.title("ROC Curve")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    # Show the AUC value on the plot
    plt.text(
        0.6,
        0.2,
        "AUC-ROC = {:.3f}".format(auc_roc),
        bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
    )
    plt.savefig("docs/efficientnet/roc.png")
    plt.show()
predict_image("data/test/Task 1/", model, preprocess)
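# Expected test-set layout: data/test/Task 1/<class name>/*.png, where each class
# folder name matches an entry in CLASSES (the parent folder is used as the true label).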
# 89 EfficientNetB2WithDropout / 0.873118944547516
# 89 MobileNetV2WithDropout / 0.8731189445475158
# 89 SqueezeNet1_0WithSE / 0.8865856365856365