# Source: SpiralSense / eval.py (commit 97dcf92 "Update" by cycool29, 7.1 kB)
import os
import torch
import numpy as np
import pathlib
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.metrics import (
classification_report,
precision_recall_curve,
accuracy_score,
f1_score,
confusion_matrix,
ConfusionMatrixDisplay,
)
from sklearn.preprocessing import label_binarize
from torchvision import transforms
from configs import *
# EfficientNet: 0.901978973407545
# MobileNet: 0.8731189445475158
# SqueezeNet: 0.8559218559218559
# Constants
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
NUM_AUGMENTATIONS = 10  # Number of random augmented views per image during TTA

# Checkpoint directory. Built with pathlib so the same path works on every
# platform (a backslash-separated literal only resolves on Windows and also
# contains invalid escape sequences such as "\c").
_CHECKPOINT_DIR = pathlib.Path("output") / "checkpoints"

# Base models. map_location=DEVICE lets checkpoints that were saved on a GPU
# be loaded on a CPU-only machine as well.
model2 = EfficientNetB2WithDropout(num_classes=NUM_CLASSES).to(DEVICE)
model2.load_state_dict(
    torch.load(_CHECKPOINT_DIR / "EfficientNetB2WithDropout.pth", map_location=DEVICE)
)
model1 = SqueezeNet1_0WithSE(num_classes=NUM_CLASSES).to(DEVICE)
model1.load_state_dict(
    torch.load(_CHECKPOINT_DIR / "SqueezeNet1_0WithSE.pth", map_location=DEVICE)
)
model3 = MobileNetV2WithDropout(num_classes=NUM_CLASSES).to(DEVICE)
model3.load_state_dict(
    torch.load(_CHECKPOINT_DIR / "MobileNetV2WithDropout.pth", map_location=DEVICE)
)

# NOTE(review): the score comments near the top of this file list these values
# as EfficientNet, MobileNet, SqueezeNet, while the models below are passed as
# [SqueezeNet, EfficientNet, MobileNet]. Confirm the intended pairing; it is
# left untouched here because the saved ensemble checkpoint may depend on it.
best_weights = [0.901978973407545, 0.8731189445475158, 0.8559218559218559]

# Build the weighted-vote ensemble and restore its saved state.
model = WeightedVoteEnsemble([model1, model2, model3], best_weights)
model.load_state_dict(
    torch.load(_CHECKPOINT_DIR / "WeightedVoteEnsemble.pth", map_location=DEVICE)
)
model.eval()
# Test-time augmentation pipeline: a horizontal flip followed by a vertical
# flip, each applied independently with probability 0.5 on every call.
tta_transforms = transforms.Compose(
    [
        flip(p=0.5)
        for flip in (transforms.RandomHorizontalFlip, transforms.RandomVerticalFlip)
    ]
)
def perform_tta(model, image, tta_transforms):
    """Run test-time augmentation on one image and aggregate the predictions.

    Two sets of augmented views are passed through ``model``:

    * ``NUM_AUGMENTATIONS`` views produced by ``tta_transforms``; these feed
      both the majority vote and the averaged probabilities.
    * Twelve views rotated by the fixed angles 0, 30, ..., 330 degrees.

    Args:
        model: Callable mapping an image batch to per-class logits.
        image: Image tensor holding exactly one sample (``.item()`` is called
            on the per-view argmax, which requires a single prediction).
        tta_transforms: Callable applied to ``image`` to create each random
            view.

    Returns:
        A tuple ``(voted_class, avg_scores, rotation_scores)``:
        ``voted_class`` is the class index predicted most often across the
        random views (ties go to the lowest class index); ``avg_scores`` and
        ``rotation_scores`` are NumPy arrays of softmax probabilities averaged
        over the random views and the rotated views respectively.
    """
    # Inference only: disabling autograd also guarantees that the
    # ``.cpu().numpy()`` conversions below cannot fail on tensors that
    # require grad when the caller did not set up ``torch.no_grad()``.
    with torch.no_grad():
        votes = []
        view_scores = []
        for _ in range(NUM_AUGMENTATIONS):
            logits = model(tta_transforms(image))
            votes.append(torch.argmax(logits, dim=1).item())
            view_scores.append(logits.softmax(dim=1).cpu().numpy())

        # Rotate by fixed angles. ``RandomRotation(degrees=a)`` would instead
        # draw a random angle from [-a, +a], so it cannot be used for a
        # deterministic sweep around the full circle.
        rotated_scores = []
        for angle in range(0, 360, 30):
            rotated = transforms.functional.rotate(image, float(angle))
            logits = model(rotated)
            rotated_scores.append(logits.softmax(dim=1).cpu().numpy())

    # Majority vote; argmax over the histogram picks the lowest class index
    # when several classes share the highest count.
    final_predicted_class_max = int(np.bincount(votes).argmax())
    # Average probabilities over the random views.
    final_predicted_scores_avg = np.mean(np.array(view_scores), axis=0)
    # Average probabilities over the rotated views.
    final_predicted_scores_rotation = np.mean(np.array(rotated_scores), axis=0)

    return (
        final_predicted_class_max,
        final_predicted_scores_avg,
        final_predicted_scores_rotation,
    )
def predict_image_with_tta(image_path, model, transform, tta_transforms):
    """Evaluate ``model`` with test-time augmentation on a folder of PNG images.

    Every ``*.png`` found recursively under ``image_path`` is classified with
    three TTA strategies (majority vote, averaged probabilities, rotation
    average). Per-strategy accuracy is printed, the labels of the most
    accurate strategy are used for the final report, and a confusion matrix
    and a precision-recall curve are shown with matplotlib.

    Args:
        image_path: Root directory of the test images. The name of each
            image's parent directory must be an entry of ``CLASSES``; it is
            used as the ground-truth label.
        model: Network to evaluate; it is switched to eval mode.
        transform: Preprocessing applied to each RGB PIL image; its result
            must support ``.unsqueeze(0)``.
        tta_transforms: Random augmentation callable forwarded to
            ``perform_tta``.

    Raises:
        FileNotFoundError: If no ``*.png`` file exists under ``image_path``.
        ValueError: If a parent directory name is not in ``CLASSES``.
    """
    model.eval()

    image_files = sorted(pathlib.Path(image_path).rglob("*.png"))
    if not image_files:
        raise FileNotFoundError(f"No .png images found under {image_path!r}")

    true_classes = []
    predicted_labels_max = []
    predicted_labels_avg = []
    predicted_labels_rotation = []
    rotation_scores = []  # one probability vector per image, for the PR curve

    with torch.no_grad():
        for image_file in image_files:
            true_classes.append(CLASSES.index(image_file.parts[-2]))

            # Close the file handle as soon as the tensor has been built.
            with Image.open(image_file) as raw_image:
                batch = transform(raw_image.convert("RGB")).unsqueeze(0)
            batch = batch.to(DEVICE)

            # A single TTA pass yields all three strategies at once, so every
            # strategy is judged on the same forward passes.
            voted_class, avg_scores, rot_scores = perform_tta(
                model, batch, tta_transforms
            )

            predicted_labels_max.append(voted_class)
            predicted_labels_avg.append(int(np.argmax(avg_scores)))
            predicted_labels_rotation.append(int(np.argmax(rot_scores)))
            rotation_scores.append(np.asarray(rot_scores).reshape(-1))

    # accuracy for each strategy
    accuracy_max = accuracy_score(true_classes, predicted_labels_max)
    accuracy_avg = accuracy_score(true_classes, predicted_labels_avg)
    accuracy_rotation = accuracy_score(true_classes, predicted_labels_rotation)
    print("Accuracy (Max Voting):", accuracy_max)
    print("Accuracy (Average Probabilities):", accuracy_avg)
    print("Accuracy (Rotation and Average):", accuracy_rotation)

    # Final prediction: labels of the most accurate strategy. ``max`` returns
    # the first maximum, so ties resolve as max voting > average > rotation.
    # NOTE(review): the strategy is selected on the same data it is then
    # scored on, so the "ensemble" figures below are optimistic.
    candidates = [
        (accuracy_max, predicted_labels_max),
        (accuracy_avg, predicted_labels_avg),
        (accuracy_rotation, predicted_labels_rotation),
    ]
    final_predicted_labels = list(max(candidates, key=lambda pair: pair[0])[1])

    # calculate accuracy and f1 score (ensemble)
    accuracy_ensemble = accuracy_score(true_classes, final_predicted_labels)
    f1_ensemble = f1_score(true_classes, final_predicted_labels, average="weighted")
    print("Ensemble Accuracy:", accuracy_ensemble)
    print("Ensemble Weighted F1 Score:", f1_ensemble)

    # Passing ``labels`` explicitly keeps the report and the matrix aligned
    # with the class names even when a class is missing from the test set.
    label_ids = list(range(NUM_CLASSES))
    class_names = [str(cls) for cls in label_ids]
    report = classification_report(
        true_classes,
        final_predicted_labels,
        labels=label_ids,
        target_names=class_names,
    )
    print("Classification Report of", model.__class__.__name__, ":\n", report)

    # confusion matrix and classification report for the ensemble
    conf_matrix_ensemble = confusion_matrix(
        true_classes, final_predicted_labels, labels=label_ids
    )
    ConfusionMatrixDisplay(
        confusion_matrix=conf_matrix_ensemble, display_labels=label_ids
    ).plot(cmap=plt.cm.Blues)
    plt.title("Confusion Matrix (Ensemble)")
    plt.show()
    print("Classification Report (Ensemble):\n", report)

    # Precision-recall curve pooled over all (image, class) pairs, using the
    # rotation-averaged scores of every image. The one-hot targets are built
    # with ``np.eye`` so they always have NUM_CLASSES columns, including the
    # two-class case.
    score_matrix = np.vstack(rotation_scores)
    true_classes_binary = np.eye(NUM_CLASSES, dtype=int)[true_classes]
    precision, recall, _ = precision_recall_curve(
        true_classes_binary.ravel(), score_matrix.ravel()
    )

    # Plot precision-recall curve
    plt.figure(figsize=(10, 6))
    plt.plot(recall, precision)
    plt.title("Precision-Recall Curve")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.show()
# Script entry: evaluate the loaded ensemble on the Task 1 test images.
predict_image_with_tta(
    image_path="data/test/Task 1/",
    model=model,
    transform=preprocess,
    tta_transforms=tta_transforms,
)