"""Evaluate a trained classifier on a directory of test images.

Runs the model saved at MODEL_SAVE_PATH over every PNG under the given
test directory, then reports accuracy, weighted F1, a confusion matrix,
a classification report, ROC / PRC curves with their AUCs, Matthews
correlation coefficient and Cohen's kappa. Plots are saved under
docs/evaluation/ and also shown interactively.
"""

import pathlib

import numpy as np
import torch
import matplotlib.pyplot as plt
from matplotlib import rcParams
from PIL import Image
from sklearn.metrics import (
    classification_report,
    precision_recall_curve,
    accuracy_score,
    f1_score,
    confusion_matrix,
    matthews_corrcoef,
    ConfusionMatrixDisplay,
    roc_curve,
    auc,
    average_precision_score,
    cohen_kappa_score,
)
from sklearn.preprocessing import label_binarize

# Provides MODEL, DEVICE, MODEL_SAVE_PATH, CLASSES, NUM_CLASSES, preprocess.
from configs import *

rcParams["font.family"] = "Times New Roman"

# Load the trained weights onto the configured device and switch to eval mode.
model = MODEL.to(DEVICE)
model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
model.eval()


def _collect_predictions(image_path, model, transform):
    """Run the model over every PNG found (recursively) under image_path.

    The ground-truth label is derived from each image's parent directory
    name, which must be a member of CLASSES — TODO confirm the test
    directory layout matches this convention.

    Returns:
        (true_classes, predicted_labels, predicted_scores) — parallel
        lists; each score entry is a (1, num_classes) softmax array.
    """
    true_classes = []
    predicted_labels = []
    predicted_scores = []  # per-image class probabilities, for ROC / PRC
    images = list(pathlib.Path(image_path).rglob("*.png"))
    with torch.no_grad():
        for image_file in images:
            print("---------------------------")
            # The parent folder name encodes the true class.
            true_class = CLASSES.index(image_file.parts[-2])
            print("Image path:", image_file)
            print("True class:", true_class)
            image = Image.open(image_file).convert("RGB")
            image = transform(image).unsqueeze(0).to(DEVICE)
            output = model(image)
            predicted_class = torch.argmax(output, dim=1).item()
            print("Predicted class:", predicted_class)
            true_classes.append(true_class)
            predicted_labels.append(predicted_class)
            predicted_scores.append(output.softmax(dim=1).cpu().numpy())
    return true_classes, predicted_labels, predicted_scores


def _plot_confusion_matrix(true_classes, predicted_labels):
    """Render, save and show the confusion matrix for the predictions."""
    conf_matrix = confusion_matrix(true_classes, predicted_labels)
    ConfusionMatrixDisplay(
        confusion_matrix=conf_matrix, display_labels=CLASSES
    ).plot(cmap=plt.cm.Blues, xticks_rotation=25)
    # Tuned margins so rotated tick labels fit in the saved figure.
    plt.subplots_adjust(
        top=0.935,
        bottom=0.155,
        left=0.125,
        right=0.905,
        hspace=0.2,
        wspace=0.2,
    )
    plt.title("Confusion Matrix")
    # Maximize the window before saving; requires an interactive backend.
    manager = plt.get_current_fig_manager()
    manager.full_screen_toggle()
    plt.savefig("docs/evaluation/confusion_matrix.png")
    plt.show()


def _plot_prc(recall, precision, auc_prc):
    """Save and show the micro-averaged precision-recall curve."""
    plt.figure(figsize=(10, 6))
    plt.plot(recall, precision)
    plt.title("Precision-Recall Curve")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.text(
        0.6,
        0.2,
        "AUC-PRC = {:.3f}".format(auc_prc),
        bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
    )
    plt.savefig("docs/evaluation/prc.png")
    plt.show()


def _plot_roc(fpr, tpr, auc_roc):
    """Save and show the micro-averaged ROC curve."""
    plt.figure(figsize=(10, 6))
    plt.plot(fpr, tpr)
    plt.title("ROC Curve")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.text(
        0.6,
        0.2,
        "AUC-ROC = {:.3f}".format(auc_roc),
        bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
    )
    plt.savefig("docs/evaluation/roc.png")
    plt.show()


def predict_image(image_path, model, transform):
    """Evaluate `model` on every PNG under `image_path` and report metrics.

    Args:
        image_path: directory searched recursively for *.png test images,
            organized in per-class subfolders named after CLASSES entries.
        model: a torch module already on DEVICE.
        transform: preprocessing callable mapping a PIL image to a tensor.

    Side effects: prints per-image predictions and all metrics; saves
    confusion-matrix, PRC and ROC plots under docs/evaluation/.
    """
    model.eval()
    true_classes, predicted_labels, predicted_scores = _collect_predictions(
        image_path, model, transform
    )

    accuracy = accuracy_score(true_classes, predicted_labels)
    print("Accuracy:", accuracy)
    f1 = f1_score(true_classes, predicted_labels, average="weighted")
    print("Weighted F1 Score:", f1)

    _plot_confusion_matrix(true_classes, predicted_labels)

    report = classification_report(
        true_classes, predicted_labels, target_names=CLASSES
    )
    print("Classification Report:\n", report)

    # Micro-average curves: flatten the one-hot labels and the stacked
    # per-class probabilities into matching 1-D arrays.
    # NOTE(review): with NUM_CLASSES == 2, label_binarize yields a single
    # column while the scores have two — verify shapes match if this is
    # ever run on a binary task.
    true_classes_binary = label_binarize(true_classes, classes=range(NUM_CLASSES))
    y_true = true_classes_binary.ravel()
    y_score = np.array(predicted_scores).ravel()

    fpr, tpr, _ = roc_curve(y_true, y_score)
    auc_roc = auc(fpr, tpr)
    print("AUC-ROC:", auc_roc)

    # Computed once (the original evaluated precision_recall_curve twice).
    precision, recall, _ = precision_recall_curve(y_true, y_score)
    auc_prc = average_precision_score(y_true, y_score)
    print("AUC PRC:", auc_prc)

    _plot_prc(recall, precision, auc_prc)
    _plot_roc(fpr, tpr, auc_roc)

    print(
        "Matthew's correlation coefficient:",
        matthews_corrcoef(true_classes, predicted_labels),
    )
    print("Cohen's kappa:", cohen_kappa_score(true_classes, predicted_labels))


predict_image("data/test/Task 1/", model, preprocess)

# 89 EfficientNetB2WithDropout / 0.873118944547516
# 89 MobileNetV2WithDropout / 0.8731189445475158
# 89 SqueezeNet1_0WithSE / .8865856365856365