File size: 1,393 Bytes
97dcf92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
from lime.lime_image import LimeImageExplainer
from PIL import Image

from configs import *


# Build the model on the target device and restore its trained weights.
# map_location loads the checkpoint tensors directly onto DEVICE, so a
# checkpoint saved on a GPU can still be loaded on a CPU-only machine.
model = MODEL.to(DEVICE)
model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
# Inference mode (affects layers such as dropout / batch norm, if present).
model.eval()

# Image to explain. pathlib joins the components with the OS-specific
# separator, so the same path works on Windows and POSIX systems alike.
IMAGE_PATH = Path(
    "data", "test", "Task 1", "Healthy", "0a7259b2-e650-43aa-93a0-e8b1063476fc.png"
)

# Load the image as 3-channel RGB. The context manager closes the underlying
# file; convert() returns an independent, fully loaded copy.
with Image.open(IMAGE_PATH) as raw_image:
    image = raw_image.convert("RGB")

# Apply the project's preprocessing transform (from configs); it yields a
# tensor -- presumably (C, H, W), TODO confirm against configs.preprocess.
image = preprocess(image)
image = image.unsqueeze(0)  # Add batch dimension
image = image.to(DEVICE)


# Define a function to predict with the model
def predict(input_image, *, apply_softmax=True):
    """Classify channels-last images for LIME and return a NumPy array.

    LIME's ``explain_instance`` calls this with a NumPy batch of perturbed
    images laid out as (batch, height, width, channels) and post-processes
    the result with NumPy, so the output must be a host-side array of shape
    (batch, num_classes).

    Args:
        input_image: a 4-D channels-last batch, or a single 3-D
            channels-last image (treated as a batch of one). Tensors of any
            other rank are passed to the model unchanged.
        apply_softmax: if True, convert model outputs to probabilities with
            a softmax over axis 1. Set to False when the model already
            outputs probabilities.

    Returns:
        np.ndarray with the model outputs (probabilities when
        ``apply_softmax`` is True), one row per input image.
    """
    batch = torch.as_tensor(input_image, dtype=torch.float32)
    if batch.dim() == 3:
        batch = batch.unsqueeze(0)  # single HWC image -> batch of one
    if batch.dim() == 4:
        batch = batch.permute(0, 3, 1, 2)  # NHWC (NumPy/LIME) -> NCHW (PyTorch)
    batch = batch.to(DEVICE)  # Move to the appropriate device
    with torch.no_grad():
        output = model(batch)
    if apply_softmax:
        # NOTE(review): assumes the model returns raw logits of shape
        # (batch, num_classes) -- confirm against MODEL's final layer.
        output = torch.softmax(output, dim=1)
    # Hand LIME a host NumPy array; .numpy() cannot read GPU tensors.
    return output.detach().cpu().numpy()


# Create the LIME explainer
explainer = LimeImageExplainer()

# LIME works on channels-last NumPy images: reorder the first (only) image of
# the batch from (C, H, W) to (H, W, C) and bring it to host memory.
# .cpu() is required because .numpy() raises on tensors that live on a GPU.
explanation = explainer.explain_instance(
    image[0].permute(1, 2, 0).cpu().numpy(),
    predict,
    top_labels=5,
    num_samples=2000,
)

# Image and superpixel mask for the most likely class. Per the LIME
# implementation, with positive_only=False the mask is 1 on superpixels that
# support the class, -1 on those that count against it, and 0 elsewhere.
top_label = explanation.top_labels[0]
explained_image, mask = explanation.get_image_and_mask(
    top_label, positive_only=False, num_features=5, hide_rest=False
)

# preprocess() may normalize pixels outside [0, 1]; min-max rescale for
# display so imshow does not clip them. Skip constant images (zero range).
low, high = explained_image.min(), explained_image.max()
if high > low:
    explained_image = (explained_image - low) / (high - low)

# Display the explanation: the image with the mask overlaid
# (green = supports the predicted class, red = counts against it).
plt.imshow(explained_image)
plt.imshow(
    np.ma.masked_where(mask == 0, mask), cmap="RdYlGn", vmin=-1, vmax=1, alpha=0.4
)
plt.title(f"LIME explanation for class {top_label}")
plt.axis("off")
plt.show()