import os
import numpy as np
from lime.lime_image import LimeImageExplainer
from PIL import Image
import torch
import matplotlib.pyplot as plt
from configs import *
import time
# Load the trained weights and switch the model to evaluation mode
model = MODEL.to(DEVICE)
model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
model.eval()
# Prediction function for LIME: takes a batch of HWC images and returns class scores
def predict(input_image):
    input_image = torch.tensor(input_image, dtype=torch.float32)
    if input_image.dim() == 4:
        input_image = input_image.permute(0, 3, 1, 2)  # Permute NHWC -> NCHW
    input_image = input_image.to(DEVICE)  # Move to the appropriate device
    with torch.no_grad():
        output = model(input_image)
    # LIME expects a NumPy array, so move the output back to the CPU
    return output.cpu().numpy()
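# Quick, illustrative check of predict()'s output shape. The 224x224 input size
# below is an assumption and may differ from what preprocess() in configs produces.
# dummy_batch = np.random.rand(2, 224, 224, 3).astype(np.float32)
# print(predict(dummy_batch).shape)  # e.g. (2, len(CLASSES))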
def generate_lime(image_path=None, save_path=None):
    # Batch mode: no image_path given, so walk every class folder in the test set
    if image_path is None:
        for disease in CLASSES:
            print("Processing", disease)
            for image_file in os.listdir(r"data\test\Task 1\{}".format(disease)):
                print("Processing", image_file)
                image_path = r"data\test\Task 1\{}\{}".format(disease, image_file)
                image_name = os.path.splitext(os.path.basename(image_path))[0]
                image = Image.open(image_path).convert("RGB")
                width, height = image.size
                image = preprocess(image)
                image = image.unsqueeze(0)  # Add batch dimension
                image = image.to(DEVICE)
                # Create the LIME explainer
                explainer = LimeImageExplainer()
                # Explain the model's prediction for this image
                explanation = explainer.explain_instance(
                    image[0].permute(1, 2, 0).cpu().numpy(),
                    predict,
                    top_labels=5,
                    num_samples=1000,
                )
                # Get the explanation overlay and mask for the top predicted label
                image, mask = explanation.get_image_and_mask(
                    explanation.top_labels[0],
                    positive_only=False,
                    num_features=10,
                    hide_rest=False,
                )
                # Normalize the overlay to the [0, 1] range before saving
                image = (image - np.min(image)) / (np.max(image) - np.min(image))
                os.makedirs(f"docs/evaluation/lime/{disease}", exist_ok=True)
                plt.imsave(f"docs/evaluation/lime/{disease}/{image_name}.jpg", image)
                # Resize the saved explanation back to the original image size
                image = Image.open(f"docs/evaluation/lime/{disease}/{image_name}.jpg")
                image = image.resize((width, height))
                image.save(f"docs/evaluation/lime/{disease}/{image_name}.jpg")
    # Single-image mode: explain one image and write the result to save_path
    else:
        print("Processing", image_path)
        image = Image.open(image_path).convert("RGB")
        width, height = image.size
        image = preprocess(image)
        image = image.unsqueeze(0)  # Add batch dimension
        image = image.to(DEVICE)
        # Create the LIME explainer
        explainer = LimeImageExplainer()
        # Explain the model's prediction for this image
        explanation = explainer.explain_instance(
            image[0].permute(1, 2, 0).cpu().numpy(), predict, top_labels=5, num_samples=1000
        )
        # Get the explanation overlay and mask for the top predicted label
        image, mask = explanation.get_image_and_mask(
            explanation.top_labels[0],
            positive_only=False,
            num_features=10,
            hide_rest=False,
        )
        # Normalize the overlay to the [0, 1] range before saving
        image = (image - np.min(image)) / (np.max(image) - np.min(image))
        plt.imsave(save_path, image)
        # Resize the saved explanation back to the original image size
        image = Image.open(save_path)
        image = image.resize((width, height))
        image.save(save_path)
# start = time.time()
# generate_lime()
# end = time.time()
# print("Time taken:", end - start)