# Spaces: Runtime error  (page-header residue captured with the script; not part of the program)
import os | |
import torch | |
from torchvision.transforms import transforms | |
import pathlib | |
from PIL import Image | |
from torchmetrics import ConfusionMatrix, Accuracy, F1Score | |
import matplotlib.pyplot as plt | |
from configs import * | |
from data_loader import load_data # Import the load_data function | |
# Directory that is searched for the Task 1 test images.
image_path = "data/test/Task 1/"

# Pick the first CUDA device when one is available; otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

# Move the model to DEVICE, restore the weights stored at MODEL_SAVE_PATH
# (remapped onto DEVICE), and switch the model to evaluation mode.
MODEL = MODEL.to(DEVICE)
MODEL.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
MODEL.eval()
def predict_image(image_path, model, transform):
    """Evaluate ``model`` on every PNG found under ``image_path``.

    The true label of an image is the index of its parent folder's name in
    ``CLASSES``. For each image the path, the true class index and the
    predicted class index are printed. Afterwards the overall accuracy and
    the weighted F1 score are printed and a confusion matrix is plotted.

    Args:
        image_path: Directory searched recursively for ``*.png`` files.
        model: Network called on a one-image batch; the arg-max over dim 1
            of its output is taken as the predicted class index.
        transform: Callable applied to the RGB PIL image; its result must
            support ``unsqueeze(0)`` and ``to(DEVICE)``.

    Returns:
        None. Results are reported through ``print`` and a matplotlib plot.

    Raises:
        ValueError: If an image's parent folder name is not in ``CLASSES``.
    """
    model.eval()

    # Sorted so the per-image log is reproducible; rglob order is unspecified.
    images = sorted(pathlib.Path(image_path).rglob("*.png"))
    if not images:
        # Nothing to score: bail out instead of dividing by zero below.
        print("No PNG images found under:", image_path)
        return

    true_classes = []
    predicted_labels = []

    with torch.no_grad():
        for image_file in images:
            print("---------------------------")
            # The name of the folder containing the image is its class name.
            true_class = CLASSES.index(image_file.parts[-2])
            print("Image path:", image_file)
            print("True class:", true_class)

            # Close the file handle as soon as the pixels have been converted.
            with Image.open(image_file) as raw_image:
                image = transform(raw_image.convert("RGB"))
            image = image.unsqueeze(0).to(DEVICE)

            output = model(image)
            predicted_class = torch.argmax(output, dim=1).item()
            print("Predicted class:", predicted_class)

            true_classes.append(true_class)
            predicted_labels.append(predicted_class)

    # Accuracy is the fraction of images whose prediction matches the label.
    correct_predictions = sum(
        predicted == actual
        for predicted, actual in zip(predicted_labels, true_classes)
    )
    accuracy = correct_predictions / len(images)
    print("Accuracy:", accuracy)

    predicted_labels_tensor = torch.tensor(predicted_labels)
    true_classes_tensor = torch.tensor(true_classes)

    # average="weighted" is required for the value to match its label; the
    # task wrapper's default ("micro") just reproduces the accuracy.
    f1_metric = F1Score(
        num_classes=NUM_CLASSES, task="multiclass", average="weighted"
    )
    f1 = f1_metric(predicted_labels_tensor, true_classes_tensor).item()
    print("Weighted F1 Score:", f1)

    # Accumulate all predictions, then let the metric plot its own matrix.
    conf_matrix = ConfusionMatrix(num_classes=NUM_CLASSES, task="multiclass")
    conf_matrix.update(predicted_labels_tensor, true_classes_tensor)
    conf_matrix.plot()
    plt.show()
# Score the loaded model on the test images using the configured preprocessing.
predict_image(image_path=image_path, model=MODEL, transform=preprocess)