Spaces:
Runtime error
Runtime error
import os | |
import torch | |
import torch.nn as nn | |
from torchvision import transforms | |
from PIL import Image | |
from models import * # Make sure you import your model correctly from the 'models' module | |
from torchmetrics import ConfusionMatrix | |
import matplotlib.pyplot as plt | |
import pathlib | |
# Path to the trained model checkpoint (a state_dict, see load_state_dict below).
model_checkpoint_path = "model.pth"

# Root directory of the test set; each class lives in its own sub-directory.
image_path = "data/test/Task 1/"  # Use forward slashes for file paths

# Recursively collect every PNG below image_path. Sorted so that the
# processing order (and therefore the log output) is reproducible.
images = sorted(pathlib.Path(image_path).rglob("*.png"))

# Class names are the sub-directory names of image_path. os.listdir() returns
# entries in arbitrary order and also lists plain files, so keep directories
# only and sort them to obtain stable label indices.
# NOTE(review): assumes the model was trained with alphabetically ordered
# class indices (as torchvision's ImageFolder does) - confirm against training.
classes = sorted(
    entry
    for entry in os.listdir(image_path)
    if os.path.isdir(os.path.join(image_path, entry))
)
print(images)

# Accumulators filled by predict_image() and consumed by the confusion matrix.
true_classs = []
predicted_labels = []

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
NUM_CLASSES = 5  # Update with the correct number of classes

# Load your model (change this according to your model definition).
model = resnet18(pretrained=False, num_classes=NUM_CLASSES)
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load(model_checkpoint_path, map_location=DEVICE))
model.eval()
model = model.to(DEVICE)

# Preprocessing applied to every input image before inference.
preprocess = transforms.Compose(
    [
        # Presumably the input size used during training - confirm.
        transforms.Resize((64, 64)),
        # Grayscale, replicated to 3 channels to match the 3-value Normalize below.
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 per channel: maps ToTensor's [0, 1] range to [-1, 1].
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)
def predict_image(image_path, model, transform):
    """Classify every image in the module-level ``images`` list and report accuracy.

    For each image the true class index is the position of the image's parent
    folder name in the module-level ``classes`` list. The image is preprocessed
    with ``transform``, given a batch dimension, moved to ``DEVICE`` and passed
    through ``model``. Per-image results are printed, and the true/predicted
    indices are appended to the module-level ``true_classs`` and
    ``predicted_labels`` lists.

    Args:
        image_path: Currently unused; kept so existing calls keep working.
            The images evaluated are those in the module-level ``images`` list.
        model: Network that is called on the preprocessed batch; it is switched
            to evaluation mode first.
        transform: Callable applied to the opened PIL image; its result must
            support ``.unsqueeze(0)`` and ``.to(DEVICE)``.

    Returns:
        The fraction of images classified correctly, or ``0.0`` when there are
        no images to classify.

    Raises:
        ValueError: If an image's parent folder name is not in ``classes``.
    """
    model.eval()
    total_predictions = len(images)
    if total_predictions == 0:
        # Avoid a ZeroDivisionError when the image search found nothing.
        print("No images found to classify.")
        return 0.0

    correct_predictions = 0
    with torch.no_grad():
        for i in images:
            print('---------------------------')
            # The true label is the index of the image's parent folder in classes.
            true_class = classes.index(i.parts[-2])
            print("Image path:", i)
            print("True class:", true_class)

            # Open inside a context manager so the file handle is released
            # once the image has been converted to a tensor.
            with Image.open(i) as raw_image:
                image = transform(raw_image).unsqueeze(0)
            image = image.to(DEVICE)
            output = model(image)

            # Softmax over the class dimension, expressed as percentages.
            probabilities = torch.softmax(output, dim=1)[0] * 100
            predicted_class = torch.argmax(output, dim=1).item()

            # Record labels for the confusion matrix built by the caller.
            true_classs.append(true_class)
            predicted_labels.append(predicted_class)

            is_correct = predicted_class == true_class
            if is_correct:
                correct_predictions += 1

            # Report the prediction.
            print("Predicted class:", predicted_class)
            print("Probability:", probabilities[predicted_class].item())
            print("Predicted label:", classes[predicted_class])
            print("Correct predictions:", correct_predictions)
            print("Correct?", "Yes" if is_correct else "No")
            print("---------------------------")

    # Calculate accuracy over all evaluated images.
    accuracy = correct_predictions / total_predictions
    print("Accuracy:", accuracy)
    return accuracy
# Run inference over the test images; as a side effect this fills the
# module-level true_classs and predicted_labels lists.
predict_image(image_path, model, preprocess)

# The metric works on tensors, so convert the collected label lists.
true_classs_tensor = torch.tensor(true_classs)
predicted_labels_tensor = torch.tensor(predicted_labels)

# Accumulate predictions and targets into a multiclass confusion matrix.
conf_matrix = ConfusionMatrix(task='multiclass', num_classes=NUM_CLASSES)
conf_matrix.update(predicted_labels_tensor, true_classs_tensor)

# Draw the confusion matrix and display the figure.
conf_matrix.plot()
plt.show()