File size: 4,136 Bytes
73666ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
import cv2
import numpy as np
import torch
import time
import torch.nn as nn  # Replace with your model
from configs import *
import os, random

# Build the model, restore its trained weights and switch it to inference
# mode. MODEL, DEVICE and MODEL_SAVE_PATH are not defined in this file;
# presumably they come from the `configs` star import -- verify there.
model = MODEL.to(DEVICE)
model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
model.eval()

# Pick the Grad-CAM target layer: the last nn.Conv2d that is a direct child of
# the final entry in `model.features` (adjust this for other architectures).
# `.children()` is used instead of iterating the module itself so that a
# non-container final stage yields the explicit ValueError below rather than
# an unrelated "object is not iterable" TypeError.
target_layer = None
for child in model.features[-1].children():
    if isinstance(child, nn.Conv2d):
        target_layer = child

if target_layer is None:
    raise ValueError(
        "No nn.Conv2d found among the direct children of model.features[-1]; "
        "select the Grad-CAM target layer manually for this architecture."
    )

print(target_layer)

def _load_image_as_float(image_path):
    """Read *image_path* with OpenCV and return it as float32 divided by 255.

    Returns:
        The image array, or ``None`` when ``cv2.imread`` could not read the
        file (missing file, unsupported or corrupt data).
    """
    pixels = cv2.imread(image_path, 1)
    if pixels is None:
        return None
    return np.float32(pixels) / 255


def _render_gradcam_overlay(cam, image, alpha):
    """Blend the CAM heatmap of *image* over *image* and return it as uint8.

    Args:
        cam: CAM object that is called with ``input_tensor=...``.
        image: float32 image as produced by ``_load_image_as_float``.
        alpha: Blend weight of the original image; the heatmap is weighted
            ``1 - alpha``.
    """
    # NOTE(review): cv2.imread yields channels in BGR order and the image is
    # passed to the model unchanged -- confirm this matches how the model was
    # trained before reordering channels here.
    input_tensor = preprocess_image(
        image, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]
    )
    input_tensor = input_tensor.to(DEVICE)
    input_tensor.requires_grad = True

    # The CAM call returns one map per batch item; the batch holds one image.
    grayscale_cam = cam(input_tensor=input_tensor)[0]

    # Colourise the map and bring it to the same dtype/scale as `image` so
    # that cv2.addWeighted can blend the two arrays.
    heatmap = cv2.applyColorMap(np.uint8(255 * grayscale_cam), cv2.COLORMAP_JET)
    heatmap = heatmap.astype(np.float32) / 255

    blended = cv2.addWeighted(image, alpha, heatmap, 1 - alpha, 0)
    return (blended * 255).astype(np.uint8)


def extract_gradcam(image_path=None, save_path=None, *, alpha=0.3):
    """Write Grad-CAM++ overlays for one image or for the whole test set.

    With ``image_path`` given, the overlay of that image is written to
    ``save_path``. Without it, every file in ``data/test/Task 1/<class>`` for
    each class in ``CLASSES`` is processed and written to
    ``docs/evaluation/gradcam/<class>/<image name>.jpg``.

    Args:
        image_path: Image to explain, or ``None`` for batch mode.
        save_path: Output file; required when ``image_path`` is given.
        alpha: Blend weight of the original image (the heatmap gets
            ``1 - alpha``). The default matches the previously hard-coded
            0.3 / 0.7 blend.

    Returns:
        ``save_path`` in single-image mode, ``None`` in batch mode.

    Raises:
        ValueError: In single-image mode, if ``save_path`` is missing or the
            image cannot be read.
    """
    if image_path is not None:
        if save_path is None:
            raise ValueError("save_path is required when image_path is given")
        image = _load_image_as_float(image_path)
        if image is None:
            raise ValueError(f"Could not read image: {image_path}")

        # cv2.imwrite does not create missing directories.
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        # The context manager releases the hooks the CAM object registers on
        # the target layer once the overlay has been produced.
        with GradCAMPlusPlus(model=model, target_layers=[target_layer]) as cam:
            cv2.imwrite(save_path, _render_gradcam_overlay(cam, image, alpha))
            return save_path

    # Batch mode: one CAM object is shared by all images instead of building
    # (and leaking) a new one per file. No `use_cuda` flag is passed, so the
    # model and inputs stay on DEVICE, as in single-image mode.
    with GradCAMPlusPlus(model=model, target_layers=[target_layer]) as cam:
        for disease in CLASSES:
            print("Processing", disease)
            # os.path.join keeps the paths valid on non-Windows systems too.
            source_dir = os.path.join("data", "test", "Task 1", disease)
            output_dir = os.path.join("docs", "evaluation", "gradcam", disease)
            os.makedirs(output_dir, exist_ok=True)

            for file_name in sorted(os.listdir(source_dir)):
                print("Processing", file_name)
                image = _load_image_as_float(os.path.join(source_dir, file_name))
                if image is None:
                    # Stray non-image files must not abort the whole run.
                    print("Skipping unreadable image", file_name)
                    continue

                # splitext keeps dots inside the name ("a.1.jpg" -> "a.1").
                image_name = os.path.splitext(file_name)[0]
                cv2.imwrite(
                    os.path.join(output_dir, f"{image_name}.jpg"),
                    _render_gradcam_overlay(cam, image, alpha),
                )
    return None

# start = time.time()
# extract_gradcam()
# end = time.time()
# print("Time taken:", end - start)