import os
import torch
import numpy as np
import pathlib
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.metrics import (
    classification_report,
    precision_recall_curve,
    accuracy_score,
    f1_score,
    confusion_matrix,
    ConfusionMatrixDisplay,
)
from sklearn.preprocessing import label_binarize
from torchvision import transforms

from configs import *
# Per-model validation accuracies, used as ensemble weights below:
# EfficientNet: 0.901978973407545
# MobileNet: 0.8731189445475158
# SqueezeNet: 0.8559218559218559

# Constants
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
NUM_AUGMENTATIONS = 10  # Number of augmentations to perform
model2 = EfficientNetB2WithDropout(num_classes=NUM_CLASSES).to(DEVICE)
model2.load_state_dict(torch.load("output/checkpoints/EfficientNetB2WithDropout.pth", map_location=DEVICE))
model1 = SqueezeNet1_0WithSE(num_classes=NUM_CLASSES).to(DEVICE)
model1.load_state_dict(torch.load("output/checkpoints/SqueezeNet1_0WithSE.pth", map_location=DEVICE))
model3 = MobileNetV2WithDropout(num_classes=NUM_CLASSES).to(DEVICE)
model3.load_state_dict(torch.load("output/checkpoints/MobileNetV2WithDropout.pth", map_location=DEVICE))
# Weights ordered to match the model list below: [model1 (SqueezeNet), model2 (EfficientNet), model3 (MobileNet)]
best_weights = [0.8559218559218559, 0.901978973407545, 0.8731189445475158]
# Load the model
model = WeightedVoteEnsemble([model1, model2, model3], best_weights)
# model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
model.load_state_dict(torch.load("output/checkpoints/WeightedVoteEnsemble.pth", map_location=DEVICE))
model.eval()
# define augmentations for TTA
tta_transforms = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
    ]
)
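# Note: these transforms are applied to the already-preprocessed, batched tensor on
# DEVICE inside perform_tta(); torchvision transforms accept tensor inputs as well as
# PIL images, so no conversion back to PIL is needed here.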

def perform_tta(model, image, tta_transforms):
    """Run test-time augmentation on one batched image tensor and return predictions
    from three strategies: max voting, averaged probabilities, rotation-averaged probabilities."""
    augmented_predictions = []
    augmented_scores = []
    for _ in range(NUM_AUGMENTATIONS):
        augmented_image = tta_transforms(image)
        output = model(augmented_image)
        predicted_class = torch.argmax(output, dim=1).item()
        augmented_predictions.append(predicted_class)
        augmented_scores.append(output.softmax(dim=1).cpu().numpy())

    # Strategy 1: max (majority) voting over the augmented predictions
    final_predicted_class_max = max(
        set(augmented_predictions), key=augmented_predictions.count
    )

    # Strategy 2: average the softmax probabilities over the augmentations
    final_predicted_scores_avg = np.mean(np.array(augmented_scores), axis=0)

    # Strategy 3: rotate the image in 30-degree steps and average the probabilities
    rotation_transforms = [
        transforms.RandomRotation(degrees=(i, i))  # fixed rotation of exactly i degrees
        for i in range(0, 360, 30)
    ]
    rotated_scores = []
    for rotation_transform in rotation_transforms:
        augmented_image = rotation_transform(image)
        output = model(augmented_image)
        rotated_scores.append(output.softmax(dim=1).cpu().numpy())
    final_predicted_scores_rotation = np.mean(np.array(rotated_scores), axis=0)

    return (
        final_predicted_class_max,
        final_predicted_scores_avg,
        final_predicted_scores_rotation,
    )
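# Hypothetical single-image usage sketch (not executed here; "sample.png" is a
# placeholder path and `preprocess` is assumed to come from configs, as used below):
# sample = preprocess(Image.open("sample.png").convert("RGB")).unsqueeze(0).to(DEVICE)
# pred_max, scores_avg, scores_rot = perform_tta(model, sample, tta_transforms)
# print(pred_max, scores_avg.argmax(), scores_rot.argmax())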

def predict_image_with_tta(image_path, model, transform, tta_transforms):
    model.eval()
    correct_predictions = 0
    true_classes = []
    predicted_labels_max = []
    predicted_labels_avg = []
    predicted_labels_rotation = []
    all_scores_rotation = []  # per-image rotation-averaged probabilities, used for the PR curve below
    with torch.no_grad():
        images = list(pathlib.Path(image_path).rglob("*.png"))
        total_predictions = len(images)
        for image_file in images:
            # The true class is taken from the image's parent directory name
            true_class = CLASSES.index(image_file.parts[-2])
            original_image = Image.open(image_file).convert("RGB")
            original_image = transform(original_image).unsqueeze(0)
            original_image = original_image.to(DEVICE)
            # Perform TTA once per image; the three strategies come from the same call
            (
                final_predicted_class_max,
                final_predicted_scores_avg,
                final_predicted_scores_rotation,
            ) = perform_tta(model, original_image, tta_transforms)

            true_classes.append(true_class)
            predicted_labels_max.append(final_predicted_class_max)
            predicted_labels_avg.append(np.argmax(final_predicted_scores_avg))
            predicted_labels_rotation.append(np.argmax(final_predicted_scores_rotation))
            all_scores_rotation.append(final_predicted_scores_rotation)
            if final_predicted_class_max == true_class:
                correct_predictions += 1
    # accuracy for each strategy
    accuracy_max = accuracy_score(true_classes, predicted_labels_max)
    accuracy_avg = accuracy_score(true_classes, predicted_labels_avg)
    accuracy_rotation = accuracy_score(true_classes, predicted_labels_rotation)
    print("Accuracy (Max Voting):", accuracy_max)
    print("Accuracy (Average Probabilities):", accuracy_avg)
    print("Accuracy (Rotation and Average):", accuracy_rotation)
    # final prediction: use the strategy with the highest accuracy
    max_strategy_accuracy = max(accuracy_max, accuracy_avg, accuracy_rotation)
    if accuracy_max == max_strategy_accuracy:
        final_predicted_labels = list(predicted_labels_max)
    elif accuracy_avg == max_strategy_accuracy:
        final_predicted_labels = list(predicted_labels_avg)
    else:
        final_predicted_labels = list(predicted_labels_rotation)
    # calculate accuracy and F1 score (ensemble)
    accuracy_ensemble = accuracy_score(true_classes, final_predicted_labels)
    f1_ensemble = f1_score(true_classes, final_predicted_labels, average="weighted")
    print("Ensemble Accuracy:", accuracy_ensemble)
    print("Ensemble Weighted F1 Score:", f1_ensemble)
    # Classification report for the ensemble predictions
    class_names = [str(cls) for cls in range(NUM_CLASSES)]
    report = classification_report(
        true_classes, final_predicted_labels, target_names=class_names
    )
    print("Classification Report of", model.__class__.__name__, ":\n", report)
    # confusion matrix for the ensemble
    conf_matrix_ensemble = confusion_matrix(true_classes, final_predicted_labels)
    ConfusionMatrixDisplay(
        confusion_matrix=conf_matrix_ensemble, display_labels=range(NUM_CLASSES)
    ).plot(cmap=plt.cm.Blues)
    plt.title("Confusion Matrix (Ensemble)")
    plt.show()
    # Micro-averaged precision-recall curve over all classes, using the
    # rotation-averaged probabilities collected for every test image
    true_classes_binary = label_binarize(true_classes, classes=range(NUM_CLASSES))
    all_scores_rotation = np.vstack(all_scores_rotation)  # shape: (num_images, NUM_CLASSES)
    precision, recall, _ = precision_recall_curve(
        true_classes_binary.ravel(), all_scores_rotation.ravel()
    )
    # Plot precision-recall curve
    plt.figure(figsize=(10, 6))
    plt.plot(recall, precision)
    plt.title("Precision-Recall Curve")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.show()
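
# Expected layout (inferred from the directory-based labels above): PNG images under
# data/test/Task 1/<class_name>/..., where <class_name> is an entry of CLASSES from configs.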
predict_image_with_tta("data/test/Task 1/", model, preprocess, tta_transforms)