File size: 5,886 Bytes
97dcf92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import os
import torchvision
import shap
import torch
import numpy as np
import pathlib
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib import rcParams
from sklearn.metrics import (
    classification_report,
    precision_recall_curve,
    accuracy_score,
    f1_score,
    confusion_matrix,
    ConfusionMatrixDisplay,
    roc_curve,
    auc,
    average_precision_score,
)
from sklearn.preprocessing import label_binarize
from configs import *
from data_loader import load_data  # Import the load_data function

# MobileNet: 0.8731189445475158
# EfficientNet: 0.873118944547516
# SqueezeNet: 0.8865856365856365


# Use Times New Roman for all text rendered in matplotlib figures.
rcParams.update({"font.family": "Times New Roman"})

# Move the configured network to the target device, restore its saved
# weights (remapped onto that same device) and put it in evaluation mode.
model = MODEL.to(DEVICE)
_checkpoint = torch.load(MODEL_SAVE_PATH, map_location=DEVICE)
model.load_state_dict(_checkpoint)
model.eval()

# model2 = EfficientNetB3WithDropout(num_classes=NUM_CLASSES).to(DEVICE)
# model2.load_state_dict(torch.load("output/checkpoints/EfficientNetB3WithDropout.pth"))
# model1 = SqueezeNet1_0WithSE(num_classes=NUM_CLASSES).to(DEVICE)
# model1.load_state_dict(torch.load("output/checkpoints/SqueezeNet1_0WithSE.pth"))
# model3 = MobileNetV2WithDropout(num_classes=NUM_CLASSES).to(DEVICE)
# model3.load_state_dict(torch.load("output\checkpoints\MobileNetV2WithDropout.pth"))

# model1.eval()
# model2.eval()
# model3.eval()

# # Load the model
# model = WeightedVoteEnsemble([model1, model2, model3], [0.38, 0.34, 0.28])
# # model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
# model.load_state_dict(
#     torch.load("output/checkpoints/WeightedVoteEnsemble.pth", map_location=DEVICE)
# )
# model.eval()


def _collect_predictions(images, model, transform):
    """Run ``model`` on each image file and gather labels, predictions and scores.

    Args:
        images: Iterable of ``pathlib.Path`` objects pointing at image files.
            The name of each file's parent directory must be an entry of the
            module-level ``CLASSES`` sequence.
        model: Callable network; invoked on a single-image batch.
        transform: Callable applied to the RGB ``PIL.Image`` before batching.

    Returns:
        Tuple ``(true_classes, predicted_labels, predicted_scores)`` where the
        first two are lists of class indices and the last is a list holding
        one softmax probability vector per image.
    """
    true_classes = []
    predicted_labels = []
    predicted_scores = []

    with torch.no_grad():
        for image_file in images:
            print("---------------------------")
            # The true label is the position of the parent folder name in CLASSES.
            true_class = CLASSES.index(image_file.parent.name)
            print("Image path:", image_file)
            print("True class:", true_class)

            # Close the file handle as soon as the pixels have been decoded.
            with Image.open(image_file) as raw_image:
                rgb_image = raw_image.convert("RGB")
            batch = transform(rgb_image).unsqueeze(0).to(DEVICE)

            output = model(batch)
            predicted_class = torch.argmax(output, dim=1).item()
            print("Predicted class:", predicted_class)

            true_classes.append(true_class)
            predicted_labels.append(predicted_class)
            # Drop the batch axis added by unsqueeze(0) so that stacking the
            # list later yields one row of class probabilities per image.
            predicted_scores.append(output.softmax(dim=1).squeeze(0).cpu().numpy())

    return true_classes, predicted_labels, predicted_scores


def _plot_curve(x_values, y_values, title, xlabel, ylabel, annotation, save_path):
    """Draw a single annotated curve, save it to ``save_path`` and show it."""
    plt.figure(figsize=(10, 6))
    plt.plot(x_values, y_values)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    # Show the AUC value on the plot
    plt.text(
        0.6,
        0.2,
        annotation,
        bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
    )
    plt.savefig(save_path)
    plt.show()


def predict_image(image_path, model, transform, output_dir="docs/efficientnet"):
    """Evaluate ``model`` on every ``*.png`` under ``image_path`` and report metrics.

    Each image is classified individually; the ground-truth label is derived
    from the name of the image's parent directory via the module-level
    ``CLASSES`` sequence. The function prints accuracy, weighted F1, a
    classification report, AUC-ROC and AUC-PRC, shows a confusion matrix, and
    saves/shows micro-averaged precision-recall and ROC curves.

    Relies on the module-level names ``CLASSES``, ``NUM_CLASSES`` and
    ``DEVICE`` (expected to be supplied by the ``configs`` star import).

    Args:
        image_path: Directory searched recursively for ``*.png`` files.
        model: Network to evaluate; it is switched to evaluation mode.
        transform: Preprocessing callable applied to each RGB image.
        output_dir: Directory that receives ``prc.png`` and ``roc.png``.
            Created if missing. Defaults to the previously hard-coded
            ``"docs/efficientnet"``.

    Raises:
        ValueError: If no ``*.png`` file is found under ``image_path``, or if
            an image's parent directory name is not in ``CLASSES``.
    """
    model.eval()

    # Get a list of image files
    images = list(pathlib.Path(image_path).rglob("*.png"))
    if not images:
        # Fail early with a clear message instead of an obscure metric error.
        raise ValueError(f"No .png images found under {str(image_path)!r}")

    true_classes, predicted_labels, predicted_scores = _collect_predictions(
        images, model, transform
    )

    # Calculate accuracy and f1 score
    accuracy = accuracy_score(true_classes, predicted_labels)
    print("Accuracy:", accuracy)
    f1 = f1_score(true_classes, predicted_labels, average="weighted")
    print("Weighted F1 Score:", f1)

    # Pin the label set so the matrix and report always have NUM_CLASSES
    # entries, even when a class never occurs in the evaluated images.
    labels = list(range(NUM_CLASSES))

    conf_matrix = confusion_matrix(true_classes, predicted_labels, labels=labels)
    ConfusionMatrixDisplay(
        confusion_matrix=conf_matrix, display_labels=labels
    ).plot(cmap=plt.cm.Blues)
    plt.title("Confusion Matrix")
    plt.show()

    # Classification report
    class_names = [str(cls) for cls in labels]
    report = classification_report(
        true_classes, predicted_labels, labels=labels, target_names=class_names
    )
    print("Classification Report:\n", report)

    # One-hot encode the true labels. Indexing an identity matrix always gives
    # NUM_CLASSES columns; label_binarize collapses the two-class case to a
    # single column, which would not line up with the per-class scores.
    true_one_hot = np.eye(NUM_CLASSES, dtype=int)[true_classes]
    y_true = true_one_hot.ravel()
    y_score = np.array(predicted_scores).ravel()

    # Micro-averaged curves: every (image, class) pair is one binary decision.
    precision, recall, _ = precision_recall_curve(y_true, y_score)
    fpr, tpr, _ = roc_curve(y_true, y_score)

    auc_roc = auc(fpr, tpr)
    print("AUC-ROC:", auc_roc)
    auc_prc = average_precision_score(y_true, y_score)
    print("AUC PRC:", auc_prc)

    # Make sure the destination exists before matplotlib tries to write to it.
    out_dir = pathlib.Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    _plot_curve(
        recall,
        precision,
        "Precision-Recall Curve",
        "Recall",
        "Precision",
        f"AUC-PRC = {auc_prc:.3f}",
        out_dir / "prc.png",
    )
    _plot_curve(
        fpr,
        tpr,
        "ROC Curve",
        "False Positive Rate",
        "True Positive Rate",
        f"AUC-ROC = {auc_roc:.3f}",
        out_dir / "roc.png",
    )


# Script entry: evaluate the loaded model on the Task 1 test images.
predict_image(
    image_path="data/test/Task 1/",
    model=model,
    transform=preprocess,
)


# 89 EfficientNetB2WithDropout / 0.873118944547516
# 89 MobileNetV2WithDropout / 0.8731189445475158
# 89 SqueezeNet1_0WithSE / .8865856365856365