Spaces: Running
Federico Galatolo committed on
Commit bc679dd · 1 Parent(s): 6b4ee08
work in progress
- .gitignore +4 -0
- app.py +218 -0
- plots/gradcam/detectron2_gradcam.py +109 -0
- plots/gradcam/gradcam.py +168 -0
- plots/make_plots.py +226 -0
- plots/plot_features.py +68 -0
- plots/plot_gradcam.py +69 -0
- plots/plot_histogram_dist.py +58 -0
- plots/plot_matrix_distance.py +56 -0
- plots/plot_pca_point.py +57 -0
.gitignore
ADDED
@@ -0,0 +1,4 @@
/env
__pycache__/

test.jpg
app.py
ADDED
@@ -0,0 +1,218 @@
import streamlit as st
import cv2
import sys
import argparse
import numpy as np
import os
import json
import torch
import torch.nn.functional as F
import detectron2.data.transforms as T
import torchvision
from collections import OrderedDict
from scipy import spatial
import matplotlib.pyplot as plt

from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
from detectron2.config import get_cfg
from detectron2 import model_zoo
from detectron2.data import Metadata
from detectron2.structures.boxes import Boxes
from detectron2.structures import Instances

from plots.plot_pca_point import plot_pca_point
from plots.plot_histogram_dist import plot_histogram_dist
from plots.plot_gradcam import plot_gradcam


def extract_features(model, img, box):
    # ROI-pool the backbone features inside `box`, then average-pool the
    # 7x7 ROI grid down to a single 256-d vector per box.
    height, width = img.shape[1:3]
    inputs = [{"image": img, "height": height, "width": width}]
    with torch.no_grad():
        img = model.preprocess_image(inputs)
        features = model.backbone(img.tensor)
        features_ = [features[f] for f in model.roi_heads.box_in_features]

        box_features = model.roi_heads.box_pooler(features_, [box])

        output_features = F.avg_pool2d(box_features, [7, 7])
        output_features = output_features.view(-1, 256)

    return output_features


def forward_model_full(model, cfg, cv_img):
    # Unrolls the GeneralizedRCNN inference pipeline step by step so that the
    # per-instance softmax probabilities and ROI features can be attached to
    # the returned Instances (DefaultPredictor discards both).
    height, width = cv_img.shape[:2]
    transform_gen = T.ResizeShortestEdge(
        [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
    )

    image = transform_gen.get_transform(cv_img).apply_image(cv_img)
    image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
    inputs = [{"image": image, "height": height, "width": width}]

    with torch.no_grad():
        images = model.preprocess_image(inputs)
        features = model.backbone(images.tensor)
        proposals, _ = model.proposal_generator(images, features, None)

        features_ = [features[f] for f in model.roi_heads.box_in_features]

        box_features = model.roi_heads.box_pooler(features_, [x.proposal_boxes for x in proposals])
        box_head = model.roi_heads.box_head(box_features)
        predictions = model.roi_heads.box_predictor(box_head)

        output_features = F.avg_pool2d(box_features, [7, 7])
        output_features = output_features.view(-1, 256)

        probs = model.roi_heads.box_predictor.predict_probs(predictions, proposals)

        pred_instances, pred_inds = model.roi_heads.box_predictor.inference(predictions, proposals)
        pred_instances = model.roi_heads.forward_with_given_boxes(features, pred_instances)

        pred_instances = model._postprocess(pred_instances, inputs, images.image_sizes)

        instances = pred_instances[0]["instances"]

        instances.set("probs", probs[0][pred_inds])
        instances.set("features", output_features[pred_inds])

    return instances, cv_img


def load_model():
    cfg = get_cfg()

    cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 3
    cfg.MODEL.WEIGHTS = MODEL
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = TH

    metadata = Metadata()
    metadata.set(
        evaluator_type="coco",
        thing_classes=["neoplastic", "aphthous", "traumatic"],
        thing_dataset_id_to_contiguous_id={"1": 0, "2": 1, "3": 2}
    )

    predictor = DefaultPredictor(cfg)
    model = predictor.model

    return dict(
        predictor=predictor,
        model=model,
        metadata=metadata,
        cfg=cfg
    )


def compute_similarities(features, database):
    # Distance from `features` to every stored ROI, sorted ascending.
    similarities = dict()
    dist_fn = getattr(spatial.distance, DISTANCE)
    for file_name, elems in database.items():
        for elem in elems:
            similarities[file_name] = dict(
                dist=dist_fn(elem["features"], features),
                file_name=file_name,
                box=elem["roi"],
                type=elem["type"]
            )
    similarities = OrderedDict(sorted(similarities.items(), key=lambda e: e[1]["dist"]))
    return similarities


def draw_box(img, box, type, model, resize_input=False):
    height, width, channels = img.shape

    pred_v = Visualizer(img[:, :, ::-1], model["metadata"], scale=1)
    instances = Instances((height, width), pred_boxes=Boxes(torch.tensor(box).unsqueeze(0)), pred_classes=torch.tensor([type]))
    pred_v = pred_v.draw_instance_predictions(instances)

    pred = pred_v.get_image()[:, :, ::-1]
    pred = cv2.resize(pred, (800, 800))

    return pred


def explain(img, model):
    database = json.load(open(FEATURES_DATABASE))
    instances, input = forward_model_full(model["model"], model["cfg"], img)

    instances.remove("pred_masks")

    pred_v = Visualizer(cv2.cvtColor(input, cv2.COLOR_BGR2RGB), model["metadata"], scale=1)
    pred_v = pred_v.draw_instance_predictions(instances.to("cpu"))

    pred = pred_v.get_image()[:, :, ::-1]
    pred = cv2.resize(pred, (800, 800))
    pred = cv2.cvtColor(pred, cv2.COLOR_BGR2RGB)

    tabs = st.tabs(["Detection"] + [f"Lesion #{i}" for i in range(0, len(instances))])
    lesion_tabs = tabs[1:]

    with tabs[0]:
        st.header("Detected lesions")
        state.text("All done...")
        tooltip.success("Use the tabs for a detailed explanation of each lesion")
        st.image(pred)

    for i, (tab, box, type, scores, features) in enumerate(zip(lesion_tabs, instances.pred_boxes, instances.pred_classes, instances.probs, instances.features)):
        healthy_prob = scores[-1].item()  # background ("healthy") is the last softmax column
        scores = scores[:-1]
        features = features.tolist()

        with tab:
            st.header(f"Lesion #{i}")
            lesion_img = draw_box(img, box.cpu(), type, model)
            lesion_img = cv2.cvtColor(lesion_img, cv2.COLOR_BGR2RGB)

            classes = ["healthy", "neoplastic", "aphthous", "traumatic"]
            y_pos = np.arange(len(classes))
            probs = [healthy_prob] + scores.cpu().numpy().tolist()

            probs_fig = plt.figure()
            plt.bar(y_pos, probs, align="center")
            plt.xticks(y_pos, classes)
            plt.ylabel("Probability")
            plt.title("Class")

            st.subheader("Classification")
            col1, col2 = st.columns(2)

            col1.image(lesion_img)
            col2.pyplot(probs_fig)

            st.subheader("Feature space")
            col1, col2 = st.columns(2)

            fig = plot_pca_point(point=features, features_database=FEATURES_DATABASE, pca_model=PCA_MODEL, fig_h=800, fig_w=600, fig_dpi=100)
            col1.pyplot(fig)

            fig = plot_histogram_dist(point=features, features_database=FEATURES_DATABASE, fig_h=800, fig_w=600, fig_dpi=100)
            col2.pyplot(fig)

            st.subheader("Gradcam++")
            fig = plot_gradcam(model=MODEL, file=FILE, instance=i, fig_h=1600, fig_w=1200, fig_dpi=200, th=TH, layer="backbone.bottom_up.res5.2.conv3")
            st.pyplot(fig)


FILE = "./test.jpg"
MODEL = "./models/model.pth"
PCA_MODEL = "./models/pca.pkl"
FEATURES_DATABASE = "./assets/features/features.json"

DISTANCE = "cosine"
TH = 0.5

state = st.empty()
tooltip = st.empty()

state.write("Loading model...")
model = load_model()

img = cv2.imread(FILE)
img = cv2.resize(img, (800, 800))
explain(img, model)
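Note on app.py: forward_model_full unrolls Detectron2's GeneralizedRCNN inference by hand because DefaultPredictor discards the per-proposal softmax scores and pooled features the app needs. A minimal sketch of the extra fields it attaches (assuming the placeholder paths above exist; the app itself is served with `streamlit run app.py`):

# Sketch: inspect the extra per-instance fields attached by forward_model_full.
model = load_model()
instances, _ = forward_model_full(model["model"], model["cfg"], cv2.resize(cv2.imread(FILE), (800, 800)))
print(len(instances))            # number of detected lesions
print(instances.probs.shape)     # (N, 4): softmax over the 3 classes plus background ("healthy"), background last
print(instances.features.shape)  # (N, 256): ROI-pooled feature vectors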
plots/gradcam/detectron2_gradcam.py
ADDED
@@ -0,0 +1,109 @@
# Author: Alexander Riedel
# License: Unlicensed
# Link: https://github.com/alexriedel1/detectron2-GradCAM

from plots.gradcam.gradcam import GradCAM, GradCamPlusPlus
import detectron2.data.transforms as T
import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.detection_utils import read_image
from detectron2.modeling import build_model
from detectron2.data.datasets import register_coco_instances


class Detectron2GradCAM():
    """
    Attributes
    ----------
    config_file : str
        detectron2 model config file path
    cfg_list : list
        List of additional model configurations
    root_dir : str, optional
        directory of coco.json and dataset images for custom dataset registration
    custom_dataset : str, optional
        Name of the custom dataset to register
    """
    def __init__(self, config_file, cfg_list, root_dir=None, custom_dataset=None):
        # load config from file
        cfg = get_cfg()
        cfg.merge_from_file(config_file)

        if custom_dataset:
            register_coco_instances(custom_dataset, {}, root_dir + "coco.json", root_dir)
            cfg.DATASETS.TRAIN = (custom_dataset,)
            MetadataCatalog.get(custom_dataset)
            DatasetCatalog.get(custom_dataset)

        cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

        cfg.merge_from_list(cfg_list)
        cfg.freeze()

        self.cfg = cfg

    def _get_input_dict(self, original_image):
        height, width = original_image.shape[:2]
        transform_gen = T.ResizeShortestEdge(
            [self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MIN_SIZE_TEST], self.cfg.INPUT.MAX_SIZE_TEST
        )
        image = transform_gen.get_transform(original_image).apply_image(original_image)
        image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1)).requires_grad_(True)
        inputs = {"image": image, "height": height, "width": width}
        return inputs

    def get_cam(self, img, target_instance, layer_name, grad_cam_type="GradCAM"):
        """
        Runs GradCAM or GradCAM++ on one image.

        Parameters
        ----------
        img : str
            Path to inference image
        target_instance : int
            The target instance index
        layer_name : str
            Convolutional layer to perform GradCAM on
        grad_cam_type : str
            "GradCAM" or "GradCAM++" (for multiple instances of the same object, GradCAM++ can be favorable)

        Returns
        -------
        image_dict : dict
            {"image": <image>, "cam": <cam>, "output": <output>, "label": <label>}
            <image>  original input image
            <cam>    class activation map resized to original image shape
            <output> instances object generated by the model
            <label>  class label of the target instance
        cam_orig : numpy.ndarray
            unprocessed raw cam
        """
        model = build_model(self.cfg)
        checkpointer = DetectionCheckpointer(model)
        checkpointer.load(self.cfg.MODEL.WEIGHTS)

        image = read_image(img, format="BGR")
        input_image_dict = self._get_input_dict(image)

        if grad_cam_type == "GradCAM":
            grad_cam = GradCAM(model, layer_name)
        elif grad_cam_type == "GradCAM++":
            grad_cam = GradCamPlusPlus(model, layer_name)
        else:
            raise ValueError("Grad CAM type not specified")

        with grad_cam as cam:
            cam, cam_orig, output = cam(input_image_dict, target_category=target_instance)

        image_dict = {}
        image_dict["image"] = image
        image_dict["cam"] = cam
        image_dict["output"] = output[0]
        image_dict["label"] = MetadataCatalog.get(self.cfg.DATASETS.TRAIN[0]).thing_classes[output[0]["instances"].pred_classes[target_instance]]
        return image_dict, cam_orig
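A minimal standalone sketch of driving Detectron2GradCAM outside the app (the weights path and layer name are the placeholders used elsewhere in this commit; plots/plot_gradcam.py below wraps the same calls):

# Sketch: Grad-CAM++ heatmap for the first detected instance.
from detectron2 import model_zoo
from plots.gradcam.detectron2_gradcam import Detectron2GradCAM

config_file = model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
cfg_list = [
    "MODEL.ROI_HEADS.SCORE_THRESH_TEST", "0.5",
    "MODEL.ROI_HEADS.NUM_CLASSES", "3",
    "MODEL.WEIGHTS", "./models/model.pth",  # placeholder path from app.py
]
cam_extractor = Detectron2GradCAM(config_file, cfg_list)
image_dict, cam_orig = cam_extractor.get_cam(
    img="./test.jpg", target_instance=0,
    layer_name="backbone.bottom_up.res5.2.conv3", grad_cam_type="GradCAM++",
)
# image_dict["cam"] is the class activation map resized to the input image.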
plots/gradcam/gradcam.py
ADDED
@@ -0,0 +1,168 @@
# Author: Alexander Riedel
# License: Unlicensed
# Link: https://github.com/alexriedel1/detectron2-GradCAM

import cv2
import numpy as np


class GradCAM():
    """
    Class to implement the GradCAM function with its necessary PyTorch hooks.

    Attributes
    ----------
    model : detectron2 GeneralizedRCNN Model
        A model using the detectron2 API for inferencing
    layer_name : str
        name of the convolutional layer to perform GradCAM with
    """

    def __init__(self, model, target_layer_name):
        self.model = model
        self.target_layer_name = target_layer_name
        self.activations = None
        self.gradient = None
        self.model.eval()
        self.activations_grads = []
        self._register_hook()

    def _get_activations_hook(self, module, input, output):
        self.activations = output

    def _get_grads_hook(self, module, input_grad, output_grad):
        self.gradient = output_grad[0]

    def _register_hook(self):
        # register_backward_hook is deprecated in recent PyTorch releases in
        # favor of register_full_backward_hook, but is kept here as in the source.
        for (name, module) in self.model.named_modules():
            if name == self.target_layer_name:
                self.activations_grads.append(module.register_forward_hook(self._get_activations_hook))
                self.activations_grads.append(module.register_backward_hook(self._get_grads_hook))
                return True
        print(f"Layer {self.target_layer_name} not found in Model!")

    def _release_activations_grads(self):
        for handle in self.activations_grads:
            handle.remove()

    def _postprocess_cam(self, raw_cam, img_width, img_height):
        cam_orig = np.sum(raw_cam, axis=0)  # [H,W]
        cam_orig = np.maximum(cam_orig, 0)  # ReLU
        cam_orig -= np.min(cam_orig)
        cam_orig /= np.max(cam_orig)
        cam = cv2.resize(cam_orig, (img_width, img_height))
        return cam, cam_orig

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # using the object as a context manager guarantees the hooks are removed
        self._release_activations_grads()

    def __call__(self, inputs, target_category):
        """
        Calls the GradCAM instance

        Parameters
        ----------
        inputs : dict
            The input in the standard detectron2 model input format
            https://detectron2.readthedocs.io/en/latest/tutorials/models.html#model-input-format
        target_category : int, optional
            The target category index. If `None` the highest scoring class will be selected

        Returns
        -------
        cam : numpy.ndarray
            Gradient weighted class activation map
        output : list
            list of Instance objects representing the detectron2 model output
        """
        self.model.zero_grad()
        output = self.model.forward([inputs])

        if target_category is None:
            target_category = np.argmax(output[0]["instances"].scores.cpu().data.numpy(), axis=-1)

        score = output[0]["instances"].scores[target_category]
        score.backward()

        gradient = self.gradient[0].cpu().data.numpy()        # [C,H,W]
        activations = self.activations[0].cpu().data.numpy()  # [C,H,W]
        weight = np.mean(gradient, axis=(1, 2))               # [C]

        cam = activations * weight[:, np.newaxis, np.newaxis]  # [C,H,W]
        cam, cam_orig = self._postprocess_cam(cam, inputs["width"], inputs["height"])

        return cam, cam_orig, output


class GradCamPlusPlus(GradCAM):
    """
    Subclass to implement the GradCAM++ function with its necessary PyTorch hooks.

    Attributes
    ----------
    model : detectron2 GeneralizedRCNN Model
        A model using the detectron2 API for inferencing
    target_layer_name : str
        name of the convolutional layer to perform GradCAM++ with
    """
    def __init__(self, model, target_layer_name):
        super(GradCamPlusPlus, self).__init__(model, target_layer_name)

    def __call__(self, inputs, target_category):
        """
        Calls the GradCAM++ instance

        Parameters
        ----------
        inputs : dict
            The input in the standard detectron2 model input format
            https://detectron2.readthedocs.io/en/latest/tutorials/models.html#model-input-format
        target_category : int, optional
            The target category index. If `None` the highest scoring class will be selected

        Returns
        -------
        cam : numpy.ndarray
            Gradient weighted class activation map
        output : list
            list of Instance objects representing the detectron2 model output
        """
        self.model.zero_grad()
        output = self.model.forward([inputs])

        if target_category is None:
            target_category = np.argmax(output[0]["instances"].scores.cpu().data.numpy(), axis=-1)

        score = output[0]["instances"].scores[target_category]
        score.backward()

        gradient = self.gradient[0].cpu().data.numpy()        # [C,H,W]
        activations = self.activations[0].cpu().data.numpy()  # [C,H,W]

        # from https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/grad_cam_plusplus.py
        grads_power_2 = gradient**2
        grads_power_3 = grads_power_2 * gradient
        # Equation 19 in https://arxiv.org/abs/1710.11063
        sum_activations = np.sum(activations, axis=(1, 2))
        eps = 0.000001
        aij = grads_power_2 / (2 * grads_power_2 +
                               sum_activations[:, None, None] * grads_power_3 + eps)
        # Now bring back the ReLU from eq.7 in the paper,
        # and zero out aijs where the activations are 0
        aij = np.where(gradient != 0, aij, 0)

        weights = np.maximum(gradient, 0) * aij
        weight = np.sum(weights, axis=(1, 2))

        cam = activations * weight[:, np.newaxis, np.newaxis]  # [C,H,W]
        cam, cam_orig = self._postprocess_cam(cam, inputs["width"], inputs["height"])

        return cam, cam_orig, output
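For reference, the GradCamPlusPlus weight computation above implements Eq. 19 of the Grad-CAM++ paper (https://arxiv.org/abs/1710.11063) under the common approximation, taken from the linked pytorch-grad-cam code, that the higher-order derivatives are powers of the first gradient $g_{ij} = \partial S / \partial A^{k}_{ij}$:

\[
\alpha^{k}_{ij} = \frac{g_{ij}^{2}}{2\, g_{ij}^{2} + \sum_{a,b} A^{k}_{ab}\, g_{ij}^{3} + \varepsilon},
\qquad
w_{k} = \sum_{i,j} \alpha^{k}_{ij}\, \mathrm{ReLU}(g_{ij}),
\qquad
L = \mathrm{ReLU}\Big(\sum_{k} w_{k} A^{k}\Big)
\]

The final ReLU, min-max normalization, and resize happen in _postprocess_cam.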
plots/make_plots.py
ADDED
@@ -0,0 +1,226 @@
import os
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotszoo


def get_hyperparameters(data_augmentation, sampler):
    hp = ["lr", "rpn_loss_weight", "roi_heads_loss_weight", "rois_per_image"]
    if data_augmentation == "full":
        hp.extend(["random_brightness", "random_contrast"])
    if data_augmentation == "full" or data_augmentation == "crop-flip":
        hp.extend(["random_crop"])
    if sampler == "RepeatFactorTrainingSampler":
        hp.extend(["repeat_factor_th"])

    return ["config/" + i for i in hp]


def plot_study():
    query = {"$or": [{"config.wandb_tag": {"$eq": tag}} for tag in args.tags_study_replicas]}
    data = plotszoo.data.WandbData(args.username, args.project, query, verbose=args.verbose)
    data.pull_scalars(force_update=args.update_scalars)

    group_keys = ["config/sampler", "config/data_augmentation"]

    fig, axes = plt.subplots(nrows=2, ncols=2)

    yticks_fn = lambda index: "Sampler: %s Data Augmentation: %s" % (index[0], index[1])

    test_detection_plot = plotszoo.scalars.grouped.GroupedScalarsBarchart(data, group_keys, "summary/test/results/detection_accuracy")
    test_classification_plot = plotszoo.scalars.grouped.GroupedScalarsBarchart(data, group_keys, "summary/test/results/classification_accuracy")

    test_detection_df = test_detection_plot.plot(axes[0][0], title="Test Detection Accuracy", nbins=20, grid=True, yticks_fn=yticks_fn)
    test_classification_df = test_classification_plot.plot(axes[0][1], title="Test Classification Accuracy", nbins=20, grid=True, yticks_fn=yticks_fn)

    train_detection_plot = plotszoo.scalars.grouped.GroupedScalarsBarchart(data, group_keys, "summary/train/results/detection_accuracy")
    train_classification_plot = plotszoo.scalars.grouped.GroupedScalarsBarchart(data, group_keys, "summary/train/results/classification_accuracy")

    train_detection_df = train_detection_plot.plot(axes[1][0], title="Train Detection Accuracy", nbins=20, grid=True, yticks_fn=yticks_fn)
    train_classification_df = train_classification_plot.plot(axes[1][1], title="Train Classification Accuracy", nbins=20, grid=True, yticks_fn=yticks_fn)

    test_detection_df.to_excel(os.path.join(args.output_directory, "study/test_detection.xlsx"))
    test_classification_df.to_excel(os.path.join(args.output_directory, "study/test_classification.xlsx"))
    train_detection_df.to_excel(os.path.join(args.output_directory, "study/train_detection.xlsx"))
    train_classification_df.to_excel(os.path.join(args.output_directory, "study/train_classification.xlsx"))

    for ax in axes.flatten():
        ax.set_xlim(xmin=0.5)

    fig.set_size_inches(30, 10)
    fig.tight_layout()

    plotszoo.utils.savefig(fig, os.path.join(args.output_directory, "study.png"))


def plot_optimization_history(ax, data, dataset):
    # Track the best-so-far accuracies across the optimization runs
    running_max = dict(accuracy=float("-inf"), detection_accuracy=float("-inf"), classification_accuracy=float("-inf"))
    plots = dict(best_accuracy=[], best_detection_accuracy=[], best_classification_accuracy=[], accuracy=[], detection_accuracy=[], classification_accuracy=[])
    plot_index = []
    for i, row in data.scalars.iterrows():
        if row["summary/" + dataset + "/results/accuracy"] > running_max["accuracy"]:
            running_max = dict(
                accuracy=row["summary/" + dataset + "/results/accuracy"],
                detection_accuracy=row["summary/" + dataset + "/results/detection_accuracy"],
                classification_accuracy=row["summary/" + dataset + "/results/classification_accuracy"]
            )
        plots["accuracy"].append(row["summary/" + dataset + "/results/accuracy"])
        plots["detection_accuracy"].append(row["summary/" + dataset + "/results/detection_accuracy"])
        plots["classification_accuracy"].append(row["summary/" + dataset + "/results/classification_accuracy"])

        plots["best_accuracy"].append(running_max["accuracy"])
        plots["best_detection_accuracy"].append(running_max["detection_accuracy"])
        plots["best_classification_accuracy"].append(running_max["classification_accuracy"])

        plot_index.append(i)

    ax.plot(plot_index, plots["best_accuracy"], "k", label="Best " + dataset + " Accuracy")
    ax.plot(plot_index, plots["best_detection_accuracy"], "b--", label="Best " + dataset + " Detection Accuracy")
    ax.plot(plot_index, plots["best_classification_accuracy"], "g--", label="Best " + dataset + " Classification Accuracy")

    ax.scatter(plot_index, plots["accuracy"], c="k", alpha=0.5)
    ax.scatter(plot_index, plots["detection_accuracy"], c="b", alpha=0.5)
    ax.scatter(plot_index, plots["classification_accuracy"], c="g", alpha=0.5)

    ax.legend(loc="lower right")


def plot_optimization():
    for tag, params in args.tags_optimization.items():
        query = {"config.wandb_tag": {"$eq": tag}}

        parameters = get_hyperparameters(**params)
        parameters.extend(["summary/train/results/detection_accuracy", "summary/train/results/classification_accuracy"])

        data = plotszoo.data.WandbData(args.username, args.project, query, verbose=args.verbose)
        data.pull_scalars(force_update=args.update_scalars)
        assert len(data.scalars) > 0, "No data, check the tag name"
        data.pull_series(force_update=args.update_series)

        data.astype(["summary/train/results/accuracy", "summary/train/results/detection_accuracy", "summary/train/results/classification_accuracy"], float)
        data.dropna(["summary/train/results/accuracy"])

        data.create_scalar_from_series("start_time", lambda s: s["_timestamp"].min())

        fig, axes = plt.subplots(1, len(parameters), sharey=False)

        parallel_plot = plotszoo.scalars.ScalarsParallelCoordinates(data, parameters, "summary/train/results/accuracy")

        parallel_plot.plot(axes)

        fig.set_size_inches(32, 10)
        plotszoo.utils.savefig(fig, os.path.join(args.output_directory, tag, "optim_parallel.png"))

        fig, ax = plt.subplots(2, 1)

        plot_optimization_history(ax[0], data, "train")
        plot_optimization_history(ax[1], data, "test")

        fig.set_size_inches(20, 10)
        plotszoo.utils.savefig(fig, os.path.join(args.output_directory, tag, "optim_history.png"))

        parameters.remove("summary/train/results/detection_accuracy")
        parameters.remove("summary/train/results/classification_accuracy")

        # Reconstruct the CLI arguments of the best run
        args_names = [p.split("/")[1].replace("_", "-") for p in parameters]
        best_run = data.scalars["summary/train/results/accuracy"].idxmax()
        best_args = "".join(["--%s %s " % (n, data.scalars[k][best_run]) for n, k in zip(args_names, parameters)])
        best_args += "".join(["--%s %s " % (k.replace("_", "-"), v) for k, v in params.items()])
        print(best_run)
        print("Tag: %s" % tag)
        print(data.scalars.loc[best_run][["summary/train/results/detection_accuracy", "summary/train/results/classification_accuracy"]])
        print("HP: %s" % best_args)
        print()

        best_args_f = open(os.path.join(args.output_directory, tag, "best_args.txt"), "w")
        best_args_f.write(best_args)
        best_args_f.close()


def plot_replicas():
    query = {"$or": [{"config.wandb_tag": {"$eq": tag}} for tag in args.tags_best_replicas]}
    data = plotszoo.data.WandbData(args.username, args.project, query, verbose=args.verbose)
    data.pull_scalars(force_update=args.update_scalars)

    group_keys = ["config/sampler"]

    fig, axes = plt.subplots(nrows=2, ncols=1)

    yticks_fn = lambda index: "Sampler: %s" % (index,)

    detection_plot = plotszoo.scalars.grouped.GroupedScalarsBarchart(data, group_keys, "summary/test/results/detection_accuracy")
    classification_plot = plotszoo.scalars.grouped.GroupedScalarsBarchart(data, group_keys, "summary/test/results/classification_accuracy")

    detection_df = detection_plot.plot(axes[0], title="Test Detection Accuracy", nbins=20, grid=True, yticks_fn=yticks_fn)
    classification_df = classification_plot.plot(axes[1], title="Test Classification Accuracy", nbins=20, grid=True, yticks_fn=yticks_fn)

    for ax in axes:
        ax.set_xlim(xmin=0.5)

    fig.set_size_inches(20, 10)
    fig.tight_layout()

    classification_df.to_excel(os.path.join(args.output_directory, "result/classification.xlsx"))
    detection_df.to_excel(os.path.join(args.output_directory, "result/detection.xlsx"))

    print(classification_df)
    print(detection_df)

    plotszoo.utils.savefig(fig, os.path.join(args.output_directory, "results.png"))


def plot_tables():
    query = {"$or": [{"config.wandb_tag": {"$eq": tag}} for tag in args.tags_best_replicas]}
    data = plotszoo.data.WandbData(args.username, args.project, query, verbose=args.verbose)
    data.pull_scalars(force_update=args.update_scalars)

    group_keys = ["config/sampler"]
    classes = ["neoplastic", "aphthous", "traumatic"]
    metrics = ["precision", "recall", "f1-score"]

    grouped_df = data.scalars.groupby(group_keys).agg(np.mean)
    for group in grouped_df.index:
        data_df = grouped_df.loc[group]
        table = np.zeros((len(classes), len(metrics)))
        for i, c in enumerate(classes):
            for j, m in enumerate(metrics):
                table[i, j] = data_df["summary/test/report/%s/%s" % (c, m)] * 100

        table_df = pd.DataFrame(table, columns=metrics, index=classes)
        table_df.to_csv(os.path.join(args.output_directory, "%s_table.csv" % (group)))
        print("Sampler: %s" % (group))
        print(table_df)
        print()


parser = argparse.ArgumentParser()

parser.add_argument("--output-directory", type=str, default="./plots")
parser.add_argument("--username", type=str, default="mlpi")
parser.add_argument("--project", type=str, default="oral-ai")
parser.add_argument("--tags-study-replicas", type=str, default=["study-3"], nargs="+")
parser.add_argument("--tags-optimization", type=dict, default={
    "hp-optimization-none-trainingsampler-5": dict(
        data_augmentation="none",
        sampler="TrainingSampler"
    ),
    "hp-optimization-none-repeatfactortrainingsampler-5": dict(
        data_augmentation="none",
        sampler="RepeatFactorTrainingSampler"
    )
}, nargs="+")
parser.add_argument("--tags-best-replicas", type=str, default=["best-replicas-7"], nargs="+")
parser.add_argument("--update-scalars", action="store_true")
parser.add_argument("--update-series", action="store_true")
parser.add_argument("--verbose", action="store_true")

args = parser.parse_args()

plot_study()
plot_optimization()
plot_replicas()
#plot_tables()
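make_plots.py is a plain CLI around the mlpi/oral-ai wandb project; with the defaults baked into argparse, an example invocation (assuming wandb credentials with access to that project) is:

python plots/make_plots.py --update-scalars --update-series --verbose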
plots/plot_features.py
ADDED
@@ -0,0 +1,68 @@
import argparse
import numpy as np
import json

from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from scipy import spatial

parser = argparse.ArgumentParser()

parser.add_argument("--features-database", type=str, required=True)
parser.add_argument("--decomposition", type=str, default="TSNE", choices=["TSNE", "PCA"])
parser.add_argument("--output", type=str, default="")
parser.add_argument("--fig-h", type=int, default=1080)
parser.add_argument("--fig-w", type=int, default=720)
parser.add_argument("--fig-dpi", type=int, default=100)
parser.add_argument("--distance", type=str, default="cosine")

parser.add_argument("--point", type=str, default="")

args = parser.parse_args()
point = None
if args.point != "":
    point = json.loads(args.point)

dist_fn = getattr(spatial.distance, args.distance)
features_database = json.load(open(args.features_database, "r"))

features = []
classes = []
for name, feature_list in features_database.items():
    for feature in feature_list:
        features.append(feature["features"])
        classes.append(feature["type"])

# append the query point so it is embedded together with the database
if point is not None:
    features.append(point)

features = np.array(features)
classes = np.array(classes)

if args.decomposition == "TSNE":
    decomposition = TSNE(n_components=2, metric=dist_fn)
elif args.decomposition == "PCA":
    decomposition = PCA(n_components=2)
transformed = decomposition.fit_transform(features)

if point is not None:
    # the query point is the last row: split it off before plotting
    transformed_point = transformed[-1, :]
    transformed = transformed[:-1, :]

plt.figure(figsize=(args.fig_h/args.fig_dpi, args.fig_w/args.fig_dpi), dpi=args.fig_dpi)
cmap = ListedColormap(["r", "b", "g"])
scatter = plt.scatter(transformed[:, 0], transformed[:, 1], c=classes, cmap=cmap, s=10)

if point is not None:
    plt.scatter(transformed_point[0], transformed_point[1], marker="x", s=200, c="k")

plt.legend(handles=scatter.legend_elements()[0], labels=["neoplastic", "aphthous", "traumatic"])

if args.output == "":
    plt.show()
else:
    plt.savefig(args.output, dpi=args.fig_dpi)
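plot_features.py embeds the whole feature database (plus an optional query point) in 2-D; an example invocation, with the database path taken from app.py's constants:

python plots/plot_features.py --features-database ./assets/features/features.json --decomposition PCA --output features.png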
plots/plot_gradcam.py
ADDED
@@ -0,0 +1,69 @@
import argparse
import torch
import matplotlib
import matplotlib.pyplot as plt
from types import SimpleNamespace

from detectron2.utils.visualizer import Visualizer
from detectron2.data import Metadata
from detectron2 import model_zoo

from plots.gradcam.detectron2_gradcam import Detectron2GradCAM


def plot_gradcam(**kwargs):
    kwargs = SimpleNamespace(**kwargs)

    config_file = model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")

    cfg_list = [
        "MODEL.ROI_HEADS.SCORE_THRESH_TEST", str(kwargs.th),
        "MODEL.ROI_HEADS.NUM_CLASSES", "3",
        "MODEL.WEIGHTS", kwargs.model
    ]

    metadata = Metadata()
    metadata.set(
        evaluator_type="coco",
        thing_classes=["neoplastic", "aphthous", "traumatic"],
        thing_dataset_id_to_contiguous_id={"1": 0, "2": 1, "3": 2}
    )

    cam_extractor = Detectron2GradCAM(config_file, cfg_list)
    image_dict, cam_orig = cam_extractor.get_cam(img=kwargs.file, target_instance=kwargs.instance, layer_name=kwargs.layer, grad_cam_type="GradCAM++")

    with torch.no_grad():
        fig = plt.figure(figsize=(kwargs.fig_h/kwargs.fig_dpi, kwargs.fig_w/kwargs.fig_dpi), dpi=kwargs.fig_dpi)
        v = Visualizer(image_dict["image"], metadata, scale=1.0)
        img = image_dict["output"]["instances"][kwargs.instance]
        img.remove("pred_masks")

        out = v.draw_instance_predictions(img.to("cpu"))

        # draw the detection and overlay the class activation map
        plt.gca().set_axis_off()
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        plt.imshow(out.get_image(), interpolation="none")
        plt.imshow(image_dict["cam"], cmap="jet", alpha=0.5)

    return fig


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--layer", type=str, default="backbone.bottom_up.res5.2.conv3")
    parser.add_argument("--th", type=float, default=0.5)
    parser.add_argument("--file", type=str, required=True)
    parser.add_argument("--instance", type=int, required=True)
    parser.add_argument("--output", type=str, default="")
    parser.add_argument("--fig-h", type=int, default=1080)
    parser.add_argument("--fig-w", type=int, default=720)
    parser.add_argument("--fig-dpi", type=int, default=100)

    args = parser.parse_args()

    fig = plot_gradcam(**vars(args))

    # save or show the figure, following the convention of the other plot scripts
    if args.output == "":
        plt.show()
    else:
        fig.savefig(args.output, dpi=args.fig_dpi)
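Run standalone, plot_gradcam.py renders one instance's Grad-CAM++ overlay; an example invocation with the paths assumed in app.py (run as a module from the repository root so the plots package imports resolve):

python -m plots.plot_gradcam --model ./models/model.pth --file ./test.jpg --instance 0 --output gradcam.png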
plots/plot_histogram_dist.py
ADDED
@@ -0,0 +1,58 @@
import argparse
import numpy as np
import json
import pickle
from scipy import spatial

from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap


def plot_histogram_dist(features_database, fig_h, fig_w, fig_dpi, point, distance="cosine"):
    # Per-class histograms of the distances between `point` and every stored feature vector
    features_database = json.load(open(features_database, "r"))
    dist_fn = getattr(spatial.distance, distance)
    class_names = ["neoplastic", "aphthous", "traumatic"]
    cmap = ListedColormap(["r", "b", "g"])

    dists = dict()
    for name, feature_list in features_database.items():
        for feature in feature_list:
            if feature["type"] not in dists:
                dists[feature["type"]] = []

            dists[feature["type"]].append(dist_fn(point, feature["features"]))

    fig, axes = plt.subplots(len(dists), figsize=(fig_h/fig_dpi, fig_w/fig_dpi), dpi=fig_dpi)

    for k, ax in zip(dists.keys(), axes):
        dist = dists[k]
        ax.set_title(class_names[k])
        ax.set_xlim(0, 1)
        n, bins, patches = ax.hist(dist, "auto", density=True, color=cmap(k))

    fig.tight_layout(pad=3.0)

    return fig


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--features-database", type=str, required=True)
    parser.add_argument("--output", type=str, default="")
    parser.add_argument("--fig-h", type=int, default=1080)
    parser.add_argument("--fig-w", type=int, default=720)
    parser.add_argument("--fig-dpi", type=int, default=100)
    parser.add_argument("--distance", type=str, default="cosine")

    parser.add_argument("--point", type=str, required=True)

    args = parser.parse_args()

    point = json.loads(args.point)

    dict_args = vars(args)
    del dict_args["point"]
    output = dict_args.pop("output")  # not a parameter of plot_histogram_dist

    fig = plot_histogram_dist(**dict_args, point=point)

    if output == "":
        plt.show()
    else:
        fig.savefig(output, dpi=args.fig_dpi)
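All of these scripts share one features-database layout, inferable from the accessors above; a hypothetical entry (the file name and values are placeholders):

# Assumed layout of ./assets/features/features.json once loaded:
features_database = {
    "image_001.jpg": [                 # hypothetical image file name
        {
            "features": [0.0] * 256,   # 256-d ROI-pooled feature vector
            "type": 0,                 # contiguous class id: 0=neoplastic, 1=aphthous, 2=traumatic
            "roi": [10, 20, 110, 140], # bounding box, read as "box" by app.py's compute_similarities
        },
    ],
}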
plots/plot_matrix_distance.py
ADDED
@@ -0,0 +1,56 @@
import json
import argparse
from scipy import spatial
import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt

parser = argparse.ArgumentParser()

parser.add_argument("--rows", type=str, required=True)
parser.add_argument("--cols", type=str, required=True)
parser.add_argument("--distance", type=str, default="cosine")
parser.add_argument("--output", type=str, default="")

args = parser.parse_args()

rows_features = json.load(open(args.rows, "r"))
cols_features = json.load(open(args.cols, "r"))
dist_fn = getattr(spatial.distance, args.distance)

# flatten the per-image feature lists into one list of ROIs per axis
rows_features_rois = []
cols_features_rois = []

for row_feature in rows_features.values():
    for roi_feature in row_feature:
        rows_features_rois.append(roi_feature)

for col_feature in cols_features.values():
    for roi_feature in col_feature:
        cols_features_rois.append(roi_feature)

# sort by class so same-class blocks are contiguous in the matrix
rows_features_rois = sorted(rows_features_rois, key=lambda e: e["type"])
cols_features_rois = sorted(cols_features_rois, key=lambda e: e["type"])

matrix = np.zeros((len(rows_features_rois), len(cols_features_rois)))
for i, row in tqdm(enumerate(rows_features_rois), total=len(rows_features_rois)):
    for j, col in enumerate(cols_features_rois):
        matrix[i, j] = dist_fn(row["features"], col["features"])

fig, ax = plt.subplots()

# rows run along the y axis of imshow, columns along the x axis
ax.set_ylabel(args.rows)
ax.set_xlabel(args.cols)

pos = ax.imshow(matrix)
fig.colorbar(pos, ax=ax)

if args.output == "":
    plt.show()
else:
    plt.savefig(args.output)
plots/plot_pca_point.py
ADDED
@@ -0,0 +1,57 @@
import argparse
import numpy as np
import json
import pickle

from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap

from sklearn.decomposition import PCA


def plot_pca_point(features_database, pca_model, fig_h, fig_w, fig_dpi, point):
    # Project the feature database and the query point with a pre-fitted PCA model
    features_database = json.load(open(features_database, "r"))
    pca = pickle.load(open(pca_model, "rb"))

    features = []
    classes = []
    for name, feature_list in features_database.items():
        for feature in feature_list:
            features.append(feature["features"])
            classes.append(feature["type"])

    features = np.array(features)
    classes = np.array(classes)

    features = pca.transform(features)
    point = pca.transform(np.atleast_2d(point))

    fig = plt.figure(figsize=(fig_h/fig_dpi, fig_w/fig_dpi), dpi=fig_dpi)
    cmap = ListedColormap(["r", "b", "g"])
    scatter = plt.scatter(features[:, 0], features[:, 1], c=classes, cmap=cmap, s=10)

    plt.scatter(point[:, 0], point[:, 1], marker="x", s=200, c="k")

    plt.legend(handles=scatter.legend_elements()[0], labels=["neoplastic", "aphthous", "traumatic"])

    return fig


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--features-database", type=str, required=True)
    parser.add_argument("--pca-model", type=str, required=True)
    parser.add_argument("--output", type=str, default="")
    parser.add_argument("--fig-h", type=int, default=1080)
    parser.add_argument("--fig-w", type=int, default=720)
    parser.add_argument("--fig-dpi", type=int, default=100)

    parser.add_argument("--point", type=str, required=True)

    args = parser.parse_args()

    point = json.loads(args.point)

    dict_args = vars(args)
    del dict_args["point"]
    output = dict_args.pop("output")  # not a parameter of plot_pca_point

    fig = plot_pca_point(**dict_args, point=point)

    if output == "":
        plt.show()
    else:
        fig.savefig(output, dpi=args.fig_dpi)
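The commit ships no script that fits ./models/pca.pkl; a minimal sketch of producing a compatible pickle (an assumption, matching only the calls plot_pca_point makes):

# Hypothetical: fit and pickle a 2-component PCA for plot_pca_point.
import json
import pickle
import numpy as np
from sklearn.decomposition import PCA

database = json.load(open("./assets/features/features.json"))
features = np.array([f["features"] for feature_list in database.values() for f in feature_list])
pca = PCA(n_components=2).fit(features)
pickle.dump(pca, open("./models/pca.pkl", "wb"))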