from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
import cv2
import numpy as np
import torch
import time
import torch.nn as nn

# Replace with your model
from configs import *
import os
import random

# Folder holding one sub-folder of test images per class name in CLASSES.
_TEST_IMAGE_ROOT = os.path.join("data", "test", "Task 1")
# Folder batch mode writes overlays to (one sub-folder per class name).
_GRADCAM_OUTPUT_ROOT = os.path.join("docs", "evaluation", "gradcam")

# Load your model (replace with your model class)
model = MODEL.to(DEVICE)
# NOTE(review): torch.load unpickles the checkpoint, so only load trusted files
# (consider weights_only=True if the installed torch version supports it).
model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
model.eval()

# Find the target layer (modify this based on your model architecture):
# the last nn.Conv2d among the direct children of model.features[-1].
target_layer = None
for child in model.features[-1]:
    if isinstance(child, nn.Conv2d):
        target_layer = child

if target_layer is None:
    raise ValueError(
        "No nn.Conv2d layer found among the children of model.features[-1]; "
        "adjust the target-layer search for this architecture"
    )
print(target_layer)


def _gradcam_overlay(cam, image_path, alpha):
    """Return the GradCAM++ overlay for one image, or None if it can't be read.

    Args:
        cam: A GradCAMPlusPlus instance bound to ``model`` / ``target_layer``.
        image_path: Path of the image file to explain.
        alpha: Blend weight of the original image; the heatmap gets
            ``1 - alpha``.

    Returns:
        A uint8 image in OpenCV channel order (ready for ``cv2.imwrite``), or
        None when ``cv2.imread`` cannot decode ``image_path``.
    """
    bgr_img = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if bgr_img is None:  # cv2.imread signals unreadable / non-image files with None
        return None
    img = np.float32(bgr_img) / 255
    # NOTE(review): cv2 loads channels in BGR order and this array is fed to the
    # model as-is. If the model was trained on RGB input, pass img[:, :, ::-1]
    # to preprocess_image instead (but keep BGR for the overlay below).
    input_tensor = preprocess_image(img, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    input_tensor = input_tensor.to(DEVICE)
    input_tensor.requires_grad = True

    # No explicit targets are passed, so the library's default target is used.
    grayscale_cam = cam(input_tensor=input_tensor)[0]

    # Colour the [0, 1] heatmap, then bring it to float32 in [0, 1] so it can be
    # blended with img (cv2.addWeighted requires matching dtypes).
    heatmap = cv2.applyColorMap(np.uint8(255 * grayscale_cam), cv2.COLORMAP_JET)
    heatmap = heatmap.astype(np.float32) / 255

    overlay = cv2.addWeighted(img, alpha, heatmap, 1 - alpha, 0)
    return (overlay * 255).astype(np.uint8)


def extract_gradcam(image_path=None, save_path=None, alpha=0.3):
    """Write GradCAM++ heatmap overlays for one image or the whole test set.

    Args:
        image_path: Image to explain. When None, batch mode runs instead: every
            entry of ``data/test/Task 1/<class>/`` for each class in
            ``CLASSES`` is processed and written to
            ``docs/evaluation/gradcam/<class>/<file stem>.jpg``; entries that
            cv2 cannot read are skipped with a message.
        save_path: Output file for single-image mode (unused in batch mode).
        alpha: Weight of the original image in the blend (the heatmap gets
            ``1 - alpha``); lower values make the original more transparent.

    Returns:
        ``save_path``.

    Raises:
        ValueError: single-image mode without a ``save_path``.
        FileNotFoundError: single-image mode and ``image_path`` can't be read.
        OSError: single-image mode and the overlay could not be written.
    """
    if image_path is not None and save_path is None:
        raise ValueError("save_path is required when image_path is given")

    # Built once per call: each instance registers hooks on target_layer, and
    # the tensor is already on DEVICE, so no CUDA flag is needed.
    cam = GradCAMPlusPlus(model=model, target_layers=[target_layer])

    if image_path is not None:
        overlay = _gradcam_overlay(cam, image_path, alpha)
        if overlay is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        if not cv2.imwrite(save_path, overlay):
            raise OSError(f"Could not write GradCAM overlay to: {save_path}")
        return save_path

    for disease in CLASSES:
        print("Processing", disease)
        in_dir = os.path.join(_TEST_IMAGE_ROOT, disease)
        out_dir = os.path.join(_GRADCAM_OUTPUT_ROOT, disease)
        os.makedirs(out_dir, exist_ok=True)
        for file_name in os.listdir(in_dir):
            print("Processing", file_name)
            overlay = _gradcam_overlay(cam, os.path.join(in_dir, file_name), alpha)
            if overlay is None:
                print("Skipping unreadable file", file_name)
                continue
            # splitext keeps inner dots (e.g. "img.v2.jpg" -> "img.v2"), so
            # distinct inputs never collide on the same output name.
            image_name = os.path.splitext(file_name)[0]
            out_path = os.path.join(out_dir, f"{image_name}.jpg")
            if not cv2.imwrite(out_path, overlay):
                print("Could not write", out_path)
    return save_path


# Example: time a full batch run.
# start = time.time()
# extract_gradcam()
# end = time.time()
# print("Time taken:", end - start)