Spaces:
Runtime error
Runtime error
import os | |
import torch | |
from torchvision.transforms import transforms | |
from sklearn.metrics import f1_score | |
from handetect.models import * | |
import pathlib | |
from PIL import Image | |
from torchmetrics import ConfusionMatrix | |
import matplotlib.pyplot as plt | |
from handetect.configs import * | |
# Root of the evaluation set: one sub-folder per class, PNG images inside.
image_path = "data/test/Task 1/"

# ---- constants -------------------------------------------------------------
NUM_CLASSES = 6  # number of classes used for the confusion matrix below
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Discover the evaluation images and the class folders.
# Both are sorted: os.listdir returns entries in arbitrary order and rglob's
# order is unspecified, so an unsorted ``classes`` list could map folder names
# to different label indices on different machines.
# NOTE(review): alphabetical order matches torchvision's ImageFolder
# convention; confirm it matches the label order the model was trained with.
images = sorted(pathlib.Path(image_path).rglob("*.png"))
# Only directories are classes; a stray file in the root must not shift indices.
classes = sorted(
    entry.name for entry in pathlib.Path(image_path).iterdir() if entry.is_dir()
)
print(images)
# Accumulators filled by predict_image(): true and predicted class indices.
true_classs = []
predicted_labels = []
# Restore the trained weights and put the network in inference mode on DEVICE.
# MODEL and MODEL_SAVE_PATH are not defined in this file; presumably they come
# from the handetect star imports -- confirm.
# NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
MODEL.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
MODEL = MODEL.eval().to(DEVICE)
# Preprocessing applied to every evaluation image:
#   1. resize to 64x64;
#   2. collapse colour to one luminance channel, replicated across 3 channels
#      so the model still receives a 3-channel input;
#   3. convert to a float tensor;
#   4. normalize with mean 0.5 / std 0.5 (the single-value tuples are applied
#      to every channel).
preprocess = transforms.Compose(
    transforms=[
        transforms.Resize(size=(64, 64)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
# Module-level accumulators.
# NOTE(review): nothing visible in this script reads or appends to these;
# evaluation results are collected in true_classs / predicted_labels instead.
all_predictions, true_labels = [], []
def predict_image(image_path, model, transform):
    """Classify every PNG under *image_path* and report accuracy and weighted F1.

    The true class of an image is the index of the top-level sub-folder of
    *image_path* that contains it, with sub-folders in alphabetical order.
    NOTE(review): alphabetical order matches torchvision's ImageFolder
    convention; confirm it matches the label order used during training.

    Side effects: prints per-image diagnostics and the final metrics, and
    appends each true / predicted class index to the module-level lists
    ``true_classs`` and ``predicted_labels`` (used afterwards for the
    confusion matrix).

    Args:
        image_path: Root directory of the evaluation set.
        model: Classifier whose output is argmax-ed along dim 1 to get the
            predicted class index.
        transform: Callable turning a PIL image into a tensor; the result is
            unsqueezed into a batch of one before the forward pass.

    Returns:
        ``(accuracy, f1)`` for this call, or ``None`` when no PNG was found.

    Raises:
        ValueError: if a PNG is not located inside a class sub-folder.
    """
    root = pathlib.Path(image_path)
    # Sorted: directory-listing order is arbitrary, which would make the
    # folder -> label mapping machine-dependent. Only directories are classes.
    class_names = sorted(p.name for p in root.iterdir() if p.is_dir())
    class_index = {name: idx for idx, name in enumerate(class_names)}
    image_files = sorted(root.rglob("*.png"))
    if not image_files:
        # Without this guard the accuracy below divides by zero.
        print("No .png images found under:", root)
        return None

    model.eval()
    call_true = []
    call_pred = []
    with torch.no_grad():
        for img_file in image_files:
            print('---------------------------')
            # Label = top-level folder under root; parts[-2] would pick a
            # nested sub-folder instead of the class folder.
            folder = img_file.relative_to(root).parts[0]
            if folder not in class_index:
                raise ValueError(
                    f"{img_file} is not inside a class sub-folder of {root}"
                )
            true_class = class_index[folder]
            print("Image path:", img_file)
            print("True class:", true_class)
            # Close the file handle once the transform has read the pixels.
            with Image.open(img_file) as img:
                batch = transform(img).unsqueeze(0).to(DEVICE)
            output = model(batch)
            predicted_class = torch.argmax(output, dim=1).item()
            print("Predicted class:", predicted_class)
            call_true.append(true_class)
            call_pred.append(predicted_class)
            # Keep the module-level lists in sync for the confusion matrix.
            true_classs.append(true_class)
            predicted_labels.append(predicted_class)

    # Metrics cover this call only, not entries left over from earlier calls.
    correct = sum(t == p for t, p in zip(call_true, call_pred))
    accuracy = correct / len(image_files)
    print("Accuracy:", accuracy)
    f1 = f1_score(call_true, call_pred, average='weighted')
    print("Weighted F1 Score:", f1)
    return accuracy, f1
# Run the evaluation; it appends to true_classs / predicted_labels.
predict_image(image_path, MODEL, preprocess)

# Convert the collected labels to tensors. dtype=torch.long keeps them integer
# class indices even when the lists are empty (torch.tensor([]) would default
# to float32, and torchmetrics expects int class indices for 1-D preds).
predicted_labels_tensor = torch.tensor(predicted_labels, dtype=torch.long)
true_classs_tensor = torch.tensor(true_classs, dtype=torch.long)
conf_matrix = ConfusionMatrix(num_classes=NUM_CLASSES, task='multiclass')
# torchmetrics signature: update(preds, target).
conf_matrix.update(predicted_labels_tensor, true_classs_tensor)

# Plot confusion matrix
conf_matrix.plot()
plt.show()