Federico Galatolo committed on
Commit bc679dd
1 Parent(s): 6b4ee08

work in progress

.gitignore ADDED
@@ -0,0 +1,4 @@
+ /env
+ __pycache__/
+
+ test.jpg
app.py ADDED
@@ -0,0 +1,218 @@
+ import streamlit as st
+ import cv2
+ import sys
+ import argparse
+ import numpy as np
+ import os
+ import json
+ import torch
+ import torch.nn.functional as F
+ import detectron2.data.transforms as T
+ import torchvision
+ from collections import OrderedDict
+ from scipy import spatial
+ import matplotlib.pyplot as plt
+
+ from detectron2.engine import DefaultPredictor
+ from detectron2.utils.visualizer import Visualizer
+ from detectron2.config import get_cfg
+ from detectron2 import model_zoo
+ from detectron2.data import Metadata
+ from detectron2.structures.boxes import Boxes
+ from detectron2.structures import Instances
+
+ from plots.plot_pca_point import plot_pca_point
+ from plots.plot_histogram_dist import plot_histogram_dist
+ from plots.plot_gradcam import plot_gradcam
+
+ def extract_features(model, img, box):
+     # img is a CHW tensor, so shape[1:3] is (height, width)
+     height, width = img.shape[1:3]
+     inputs = [{"image": img, "height": height, "width": width}]
+     with torch.no_grad():
+         img = model.preprocess_image(inputs)
+         features = model.backbone(img.tensor)
+         features_ = [features[f] for f in model.roi_heads.box_in_features]
+
+         box_features = model.roi_heads.box_pooler(features_, [box])
+
+         # average-pool the 7x7 ROI features down to one 256-d vector per box
+         output_features = F.avg_pool2d(box_features, [7, 7])
+         output_features = output_features.view(-1, 256)
+
+     return output_features
+
+ def forward_model_full(model, cfg, cv_img):
+     height, width = cv_img.shape[:2]
+     transform_gen = T.ResizeShortestEdge(
+         [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
+     )
+
+     image = transform_gen.get_transform(cv_img).apply_image(cv_img)
+     image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
+     inputs = [{"image": image, "height": height, "width": width}]
+
+     with torch.no_grad():
+         images = model.preprocess_image(inputs)
+         features = model.backbone(images.tensor)
+         proposals, _ = model.proposal_generator(images, features, None)
+
+         features_ = [features[f] for f in model.roi_heads.box_in_features]
+
+         box_features = model.roi_heads.box_pooler(features_, [x.proposal_boxes for x in proposals])
+         box_head = model.roi_heads.box_head(box_features)
+         predictions = model.roi_heads.box_predictor(box_head)
+
+         # per-proposal feature vectors, pooled the same way as extract_features
+         output_features = F.avg_pool2d(box_features, [7, 7])
+         output_features = output_features.view(-1, 256)
+
+         probs = model.roi_heads.box_predictor.predict_probs(predictions, proposals)
+
+         pred_instances, pred_inds = model.roi_heads.box_predictor.inference(predictions, proposals)
+         pred_instances = model.roi_heads.forward_with_given_boxes(features, pred_instances)
+
+         pred_instances = model._postprocess(pred_instances, inputs, images.image_sizes)
+
+     instances = pred_instances[0]["instances"]
+
+     # attach the class probabilities and pooled features of the kept proposals
+     instances.set("probs", probs[0][pred_inds])
+     instances.set("features", output_features[pred_inds])
+
+     return instances, cv_img
+
+
+ def load_model():
+     cfg = get_cfg()
+
+     cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
+     cfg.MODEL.ROI_HEADS.NUM_CLASSES = 3
+     cfg.MODEL.WEIGHTS = MODEL
+     cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = TH
+
+     metadata = Metadata()
+     metadata.set(
+         evaluator_type="coco",
+         thing_classes=["neoplastic", "aphthous", "traumatic"],
+         thing_dataset_id_to_contiguous_id={"1": 0, "2": 1, "3": 2}
+     )
+
+     predictor = DefaultPredictor(cfg)
+     model = predictor.model
+
+     return dict(
+         predictor=predictor,
+         model=model,
+         metadata=metadata,
+         cfg=cfg
+     )
+
+
+ def compute_similarities(features, database):
+     similarities = dict()
+     dist_fn = getattr(spatial.distance, DISTANCE)
+     for file_name, elems in database.items():
+         for elem in elems:
+             dist = dist_fn(elem["features"], features)
+             # keep only the closest ROI of each file
+             if file_name not in similarities or dist < similarities[file_name]["dist"]:
+                 similarities[file_name] = dict(
+                     dist=dist,
+                     file_name=file_name,
+                     box=elem["roi"],
+                     type=elem["type"]
+                 )
+     similarities = OrderedDict(sorted(similarities.items(), key=lambda e: e[1]["dist"]))
+     return similarities
+
+
+ def draw_box(img, box, type, model, resize_input=False):
+     height, width, channels = img.shape
+
+     pred_v = Visualizer(img[:, :, ::-1], model["metadata"], scale=1)
+     instances = Instances((height, width), pred_boxes=Boxes(torch.tensor(box).unsqueeze(0)), pred_classes=torch.tensor([type]))
+     pred_v = pred_v.draw_instance_predictions(instances)
+
+     pred = pred_v.get_image()[:, :, ::-1]
+     pred = cv2.resize(pred, (800, 800))
+
+     return pred
+
+
+ def explain(img, model):
+     database = json.load(open(FEATURES_DATABASE))
+     instances, input = forward_model_full(model["model"], model["cfg"], img)
+
+     instances.remove("pred_masks")
+
+     pred_v = Visualizer(cv2.cvtColor(input, cv2.COLOR_BGR2RGB), model["metadata"], scale=1)
+     pred_v = pred_v.draw_instance_predictions(instances.to("cpu"))
+
+     pred = pred_v.get_image()[:, :, ::-1]
+     pred = cv2.resize(pred, (800, 800))
+     pred = cv2.cvtColor(pred, cv2.COLOR_BGR2RGB)
+
+     tabs = st.tabs(["Detection"] + [f"Lesion #{i}" for i in range(0, len(instances))])
+     lesion_tabs = tabs[1:]
+
+     with tabs[0]:
+         st.header("Detected lesions")
+         state.text("All done...")
+         tooltip.success("Use the tabs for a detailed explanation of each lesion")
+         st.image(pred)
+
+     for i, (tab, box, type, scores, features) in enumerate(zip(lesion_tabs, instances.pred_boxes, instances.pred_classes, instances.probs, instances.features)):
+         # the last score is the background ("healthy") probability
+         healthy_prob = scores[-1].item()
+         scores = scores[:-1]
+         features = features.tolist()
+
+         with tab:
+             st.header(f"Lesion #{i}")
+             lesion_img = draw_box(img, box.cpu(), type, model)
+             lesion_img = cv2.cvtColor(lesion_img, cv2.COLOR_BGR2RGB)
+
+             classes = ["healthy", "neoplastic", "aphthous", "traumatic"]
+             y_pos = np.arange(len(classes))
+             probs = [healthy_prob] + scores.cpu().numpy().tolist()
+
+             probs_fig = plt.figure()
+             plt.bar(y_pos, probs, align="center")
+             plt.xticks(y_pos, classes)
+             plt.ylabel("Probability")
+             plt.title("Class")
+
+             st.subheader("Classification")
+             col1, col2 = st.columns(2)
+
+             col1.image(lesion_img)
+             col2.pyplot(probs_fig)
+
+             st.subheader("Feature space")
+             col1, col2 = st.columns(2)
+
+             fig = plot_pca_point(point=features, features_database=FEATURES_DATABASE, pca_model=PCA_MODEL, fig_h=800, fig_w=600, fig_dpi=100)
+             col1.pyplot(fig)
+
+             fig = plot_histogram_dist(point=features, features_database=FEATURES_DATABASE, fig_h=800, fig_w=600, fig_dpi=100)
+             col2.pyplot(fig)
+
+             st.subheader("Gradcam++")
+             fig = plot_gradcam(model=MODEL, file=FILE, instance=i, fig_h=1600, fig_w=1200, fig_dpi=200, th=TH, layer="backbone.bottom_up.res5.2.conv3")
+             st.pyplot(fig)
+
+ FILE = "./test.jpg"
+ MODEL = "./models/model.pth"
+ PCA_MODEL = "./models/pca.pkl"
+ FEATURES_DATABASE = "./assets/features/features.json"
+
+ DISTANCE = "cosine"
+ TH = 0.5
+
+ state = st.empty()
+ tooltip = st.empty()
+
+ state.write("Loading model...")
+ model = load_model()
+
+ img = cv2.imread(FILE)
+ img = cv2.resize(img, (800, 800))
+ explain(img, model)
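
Note: app.py expects FEATURES_DATABASE to map each image file name to a list of ROI entries with "features" (the 256-d vector), "roi" (the box), and "type" (the class id) keys, as read by compute_similarities and the plotting helpers. The script that builds this file is not part of this commit; what follows is a minimal sketch of how it could be produced with extract_features above. The annotations list, file paths, and device handling are assumptions, not part of the commit:

import json
import cv2
import torch
from detectron2.structures.boxes import Boxes

# hypothetical annotations: (file name, [x1, y1, x2, y2] box, class id)
annotations = [("img_001.jpg", [120.0, 80.0, 360.0, 290.0], 0)]

model = load_model()  # from app.py above
database = {}
for file_name, xyxy, class_id in annotations:
    cv_img = cv2.resize(cv2.imread(file_name), (800, 800))
    tensor = torch.as_tensor(cv_img.astype("float32").transpose(2, 0, 1))
    box = Boxes(torch.tensor(xyxy).unsqueeze(0)).to(model["cfg"].MODEL.DEVICE)
    feats = extract_features(model["model"], tensor, box)  # shape (1, 256)
    database.setdefault(file_name, []).append(
        {"features": feats[0].tolist(), "roi": xyxy, "type": class_id}
    )

with open("./assets/features/features.json", "w") as f:
    json.dump(database, f)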
plots/gradcam/detectron2_gradcam.py ADDED
@@ -0,0 +1,109 @@
+ # Author: Alexander Riedel
+ # License: Unlicensed
+ # Link: https://github.com/alexriedel1/detectron2-GradCAM
+
+ from plots.gradcam.gradcam import GradCAM, GradCamPlusPlus
+ import detectron2.data.transforms as T
+ import torch
+ from detectron2.checkpoint import DetectionCheckpointer
+ from detectron2.config import get_cfg
+ from detectron2.data import DatasetCatalog, MetadataCatalog
+ from detectron2.data.detection_utils import read_image
+ from detectron2.modeling import build_model
+ from detectron2.data.datasets import register_coco_instances
+
+ class Detectron2GradCAM():
+     """
+     Attributes
+     ----------
+     config_file : str
+         detectron2 model config file path
+     cfg_list : list
+         List of additional model configurations
+     root_dir : str [optional]
+         directory of coco.json and dataset images for custom dataset registration
+     custom_dataset : str [optional]
+         Name of the custom dataset to register
+     """
+     def __init__(self, config_file, cfg_list, root_dir=None, custom_dataset=None):
+         # load config from file
+         cfg = get_cfg()
+         cfg.merge_from_file(config_file)
+
+         if custom_dataset:
+             register_coco_instances(custom_dataset, {}, root_dir + "coco.json", root_dir)
+             cfg.DATASETS.TRAIN = (custom_dataset,)
+             MetadataCatalog.get(custom_dataset)
+             DatasetCatalog.get(custom_dataset)
+
+         if torch.cuda.is_available():
+             cfg.MODEL.DEVICE = "cuda"
+         else:
+             cfg.MODEL.DEVICE = "cpu"
+
+         cfg.merge_from_list(cfg_list)
+         cfg.freeze()
+
+         self.cfg = cfg
+
+     def _get_input_dict(self, original_image):
+         height, width = original_image.shape[:2]
+         transform_gen = T.ResizeShortestEdge(
+             [self.cfg.INPUT.MIN_SIZE_TEST, self.cfg.INPUT.MIN_SIZE_TEST], self.cfg.INPUT.MAX_SIZE_TEST
+         )
+         image = transform_gen.get_transform(original_image).apply_image(original_image)
+         image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1)).requires_grad_(True)
+         inputs = {"image": image, "height": height, "width": width}
+         return inputs
+
+     def get_cam(self, img, target_instance, layer_name, grad_cam_type="GradCAM"):
+         """
+         Runs GradCAM or GradCAM++ on one target instance
+
+         Parameters
+         ----------
+         img : str
+             Path to inference image
+         target_instance : int
+             The target instance index
+         layer_name : str
+             Convolutional layer to perform GradCAM on
+         grad_cam_type : str
+             GradCAM or GradCAM++ (for multiple instances of the same object, GradCAM++ can be favorable)
+
+         Returns
+         -------
+         image_dict : dict
+             {"image" : <image>, "cam" : <cam>, "output" : <output>, "label" : <label>}
+             <image> original input image
+             <cam> class activation map resized to original image shape
+             <output> instances object generated by the model
+             <label> predicted class label of the target instance
+         cam_orig : numpy.ndarray
+             unprocessed raw cam
+         """
+         model = build_model(self.cfg)
+         checkpointer = DetectionCheckpointer(model)
+         checkpointer.load(self.cfg.MODEL.WEIGHTS)
+
+         image = read_image(img, format="BGR")
+         input_image_dict = self._get_input_dict(image)
+
+         if grad_cam_type == "GradCAM":
+             grad_cam = GradCAM(model, layer_name)
+         elif grad_cam_type == "GradCAM++":
+             grad_cam = GradCamPlusPlus(model, layer_name)
+         else:
+             raise ValueError("Unknown Grad-CAM type: %s" % grad_cam_type)
+
+         with grad_cam as cam:
+             cam, cam_orig, output = cam(input_image_dict, target_category=target_instance)
+
+         image_dict = {}
+         image_dict["image"] = image
+         image_dict["cam"] = cam
+         image_dict["output"] = output[0]
+         image_dict["label"] = MetadataCatalog.get(self.cfg.DATASETS.TRAIN[0]).thing_classes[output[0]["instances"].pred_classes[target_instance]]
+         return image_dict, cam_orig
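
For reference, plots/plot_gradcam.py (later in this commit) drives this class roughly as follows; the weights path below is a placeholder:

from detectron2 import model_zoo
from plots.gradcam.detectron2_gradcam import Detectron2GradCAM

config_file = model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
cfg_list = [
    "MODEL.ROI_HEADS.SCORE_THRESH_TEST", "0.5",
    "MODEL.ROI_HEADS.NUM_CLASSES", "3",
    "MODEL.WEIGHTS", "./models/model.pth",  # placeholder path
]

cam_extractor = Detectron2GradCAM(config_file, cfg_list)
image_dict, cam_orig = cam_extractor.get_cam(
    img="./test.jpg",
    target_instance=0,
    layer_name="backbone.bottom_up.res5.2.conv3",
    grad_cam_type="GradCAM++",
)
# image_dict["cam"] can then be overlaid on image_dict["image"]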
plots/gradcam/gradcam.py ADDED
@@ -0,0 +1,168 @@
+ # Author: Alexander Riedel
+ # License: Unlicensed
+ # Link: https://github.com/alexriedel1/detectron2-GradCAM
+
+ import cv2
+ import numpy as np
+
+ class GradCAM():
+     """
+     Class to implement the GradCAM function with its necessary PyTorch hooks.
+
+     Attributes
+     ----------
+     model : detectron2 GeneralizedRCNN Model
+         A model using the detectron2 API for inferencing
+     layer_name : str
+         name of the convolutional layer to perform GradCAM with
+     """
+
+     def __init__(self, model, target_layer_name):
+         self.model = model
+         self.target_layer_name = target_layer_name
+         self.activations = None
+         self.gradient = None
+         self.model.eval()
+         self.activations_grads = []
+         self._register_hook()
+
+     def _get_activations_hook(self, module, input, output):
+         self.activations = output
+
+     def _get_grads_hook(self, module, input_grad, output_grad):
+         self.gradient = output_grad[0]
+
+     def _register_hook(self):
+         for (name, module) in self.model.named_modules():
+             if name == self.target_layer_name:
+                 self.activations_grads.append(module.register_forward_hook(self._get_activations_hook))
+                 # note: register_backward_hook is deprecated in recent PyTorch
+                 # releases; register_full_backward_hook is the modern equivalent
+                 self.activations_grads.append(module.register_backward_hook(self._get_grads_hook))
+                 return True
+         print(f"Layer {self.target_layer_name} not found in Model!")
+
+     def _release_activations_grads(self):
+         for handle in self.activations_grads:
+             handle.remove()
+
+     def _postprocess_cam(self, raw_cam, img_width, img_height):
+         cam_orig = np.sum(raw_cam, axis=0)  # [H,W]
+         cam_orig = np.maximum(cam_orig, 0)  # ReLU
+         cam_orig -= np.min(cam_orig)
+         cam_orig /= np.max(cam_orig)
+         cam = cv2.resize(cam_orig, (img_width, img_height))
+         return cam, cam_orig
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, exc_type, exc_value, exc_tb):
+         self._release_activations_grads()
+
+     def __call__(self, inputs, target_category):
+         """
+         Calls the GradCAM instance
+
+         Parameters
+         ----------
+         inputs : dict
+             The input in the standard detectron2 model input format
+             https://detectron2.readthedocs.io/en/latest/tutorials/models.html#model-input-format
+         target_category : int, optional
+             The target category index. If `None` the highest scoring class will be selected
+
+         Returns
+         -------
+         cam : np.array()
+             Gradient weighted class activation map
+         cam_orig : np.array()
+             Unprocessed raw cam
+         output : list
+             list of Instance objects representing the detectron2 model output
+         """
+         self.model.zero_grad()
+         output = self.model.forward([inputs])
+
+         if target_category is None:
+             target_category = np.argmax(output[0]['instances'].scores.cpu().data.numpy(), axis=-1)
+
+         score = output[0]['instances'].scores[target_category]
+         #box0 = output[0]['instances'].pred_boxes[0].tensor[0][target_category]
+         #print(box0)
+         #box0.backward()
+         score.backward()
+
+         gradient = self.gradient[0].cpu().data.numpy()  # [C,H,W]
+         activations = self.activations[0].cpu().data.numpy()  # [C,H,W]
+         weight = np.mean(gradient, axis=(1, 2))  # [C]
+
+         cam = activations * weight[:, np.newaxis, np.newaxis]  # [C,H,W]
+         cam, cam_orig = self._postprocess_cam(cam, inputs["width"], inputs["height"])
+
+         return cam, cam_orig, output
+
+ class GradCamPlusPlus(GradCAM):
+     """
+     Subclass to implement the GradCAM++ function with its necessary PyTorch hooks.
+
+     Attributes
+     ----------
+     model : detectron2 GeneralizedRCNN Model
+         A model using the detectron2 API for inferencing
+     target_layer_name : str
+         name of the convolutional layer to perform GradCAM++ with
+     """
+     def __init__(self, model, target_layer_name):
+         super(GradCamPlusPlus, self).__init__(model, target_layer_name)
+
+     def __call__(self, inputs, target_category):
+         """
+         Calls the GradCAM++ instance
+
+         Parameters
+         ----------
+         inputs : dict
+             The input in the standard detectron2 model input format
+             https://detectron2.readthedocs.io/en/latest/tutorials/models.html#model-input-format
+         target_category : int, optional
+             The target category index. If `None` the highest scoring class will be selected
+
+         Returns
+         -------
+         cam : np.array()
+             Gradient weighted class activation map
+         cam_orig : np.array()
+             Unprocessed raw cam
+         output : list
+             list of Instance objects representing the detectron2 model output
+         """
+         self.model.zero_grad()
+         output = self.model.forward([inputs])
+
+         if target_category is None:
+             target_category = np.argmax(output[0]['instances'].scores.cpu().data.numpy(), axis=-1)
+
+         score = output[0]['instances'].scores[target_category]
+         score.backward()
+
+         gradient = self.gradient[0].cpu().data.numpy()  # [C,H,W]
+         activations = self.activations[0].cpu().data.numpy()  # [C,H,W]
+
+         # from https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/grad_cam_plusplus.py
+         grads_power_2 = gradient**2
+         grads_power_3 = grads_power_2 * gradient
+         # Equation 19 in https://arxiv.org/abs/1710.11063
+         sum_activations = np.sum(activations, axis=(1, 2))
+         eps = 0.000001
+         aij = grads_power_2 / (2 * grads_power_2 +
+                                sum_activations[:, None, None] * grads_power_3 + eps)
+         # Now bring back the ReLU from eq. 7 in the paper,
+         # and zero out aijs where the activations are 0
+         aij = np.where(gradient != 0, aij, 0)
+
+         weights = np.maximum(gradient, 0) * aij
+         weight = np.sum(weights, axis=(1, 2))
+
+         cam = activations * weight[:, np.newaxis, np.newaxis]  # [C,H,W]
+         cam, cam_orig = self._postprocess_cam(cam, inputs["width"], inputs["height"])
+
+         return cam, cam_orig, output
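
For reference, the `aij` block in `GradCamPlusPlus.__call__` implements the GradCAM++ pixel weights of Chattopadhyay et al. (Equation 19 of https://arxiv.org/abs/1710.11063), under the usual approximation that writes the higher-order derivatives as powers of the first-order gradient $g^k_{ij} = \partial S / \partial A^k_{ij}$:

$$\alpha^k_{ij} = \frac{(g^k_{ij})^2}{2\,(g^k_{ij})^2 + \sum_{a,b} A^k_{ab}\,(g^k_{ij})^3}, \qquad w_k = \sum_{i,j} \alpha^k_{ij}\,\mathrm{ReLU}(g^k_{ij}), \qquad L = \mathrm{ReLU}\Big(\sum_k w_k A^k\Big)$$

The `eps` term guards the division against a zero denominator, and `_postprocess_cam` applies the final ReLU, min-max normalization, and resize.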
plots/make_plots.py ADDED
@@ -0,0 +1,226 @@
+ import os
+ import argparse
+ import numpy as np
+ import pandas as pd
+ import matplotlib.pyplot as plt
+ import plotszoo
+
+
+ def get_hyperparameters(data_augmentation, sampler):
+     hp = ["lr", "rpn_loss_weight", "roi_heads_loss_weight", "rois_per_image"]
+     if data_augmentation == "full":
+         hp.extend(["random_brightness", "random_contrast"])
+     if data_augmentation == "full" or data_augmentation == "crop-flip":
+         hp.extend(["random_crop"])
+     if sampler == "RepeatFactorTrainingSampler":
+         hp.extend(["repeat_factor_th"])
+
+     return ["config/"+i for i in hp]
+
+
+ def plot_study():
+     query = {"$or": [{"config.wandb_tag": {"$eq": tag}} for tag in args.tags_study_replicas]}
+     data = plotszoo.data.WandbData(args.username, args.project, query, verbose=args.verbose)
+     data.pull_scalars(force_update=args.update_scalars)
+
+     group_keys = ["config/sampler", "config/data_augmentation"]
+
+     fig, axes = plt.subplots(nrows=2, ncols=2)
+
+     yticks_fn = lambda index: "Sampler: %s Data Augmentation: %s" % (index[0], index[1])
+
+     test_detection_plot = plotszoo.scalars.grouped.GroupedScalarsBarchart(data, group_keys, "summary/test/results/detection_accuracy")
+     test_classification_plot = plotszoo.scalars.grouped.GroupedScalarsBarchart(data, group_keys, "summary/test/results/classification_accuracy")
+
+     test_detection_df = test_detection_plot.plot(axes[0][0], title="Test Detection Accuracy", nbins=20, grid=True, yticks_fn=yticks_fn)
+     test_classification_df = test_classification_plot.plot(axes[0][1], title="Test Classification Accuracy", nbins=20, grid=True, yticks_fn=yticks_fn)
+
+     train_detection_plot = plotszoo.scalars.grouped.GroupedScalarsBarchart(data, group_keys, "summary/train/results/detection_accuracy")
+     train_classification_plot = plotszoo.scalars.grouped.GroupedScalarsBarchart(data, group_keys, "summary/train/results/classification_accuracy")
+
+     train_detection_df = train_detection_plot.plot(axes[1][0], title="Train Detection Accuracy", nbins=20, grid=True, yticks_fn=yticks_fn)
+     train_classification_df = train_classification_plot.plot(axes[1][1], title="Train Classification Accuracy", nbins=20, grid=True, yticks_fn=yticks_fn)
+
+     test_detection_df.to_excel(os.path.join(args.output_directory, "study/test_detection.xlsx"))
+     test_classification_df.to_excel(os.path.join(args.output_directory, "study/test_classification.xlsx"))
+     train_detection_df.to_excel(os.path.join(args.output_directory, "study/train_detection.xlsx"))
+     train_classification_df.to_excel(os.path.join(args.output_directory, "study/train_classification.xlsx"))
+
+     for ax in axes.flatten():
+         ax.set_xlim(xmin=0.5)
+
+     fig.set_size_inches(30, 10)
+     fig.tight_layout()
+
+     plotszoo.utils.savefig(fig, os.path.join(args.output_directory, "study.png"))
+
+
+ def plot_optimization_history(ax, data, dataset):
+     running_max = dict(accuracy=float("-inf"), detection_accuracy=float("-inf"), classification_accuracy=float("-inf"))
+     plots = dict(best_accuracy=[], best_detection_accuracy=[], best_classification_accuracy=[], accuracy=[], detection_accuracy=[], classification_accuracy=[])
+     plot_index = []
+     for i, row in data.scalars.iterrows():
+         if row["summary/"+dataset+"/results/accuracy"] > running_max["accuracy"]:
+             running_max = dict(
+                 accuracy=row["summary/"+dataset+"/results/accuracy"],
+                 detection_accuracy=row["summary/"+dataset+"/results/detection_accuracy"],
+                 classification_accuracy=row["summary/"+dataset+"/results/classification_accuracy"]
+             )
+         plots["accuracy"].append(row["summary/"+dataset+"/results/accuracy"])
+         plots["detection_accuracy"].append(row["summary/"+dataset+"/results/detection_accuracy"])
+         plots["classification_accuracy"].append(row["summary/"+dataset+"/results/classification_accuracy"])
+
+         plots["best_accuracy"].append(running_max["accuracy"])
+         plots["best_detection_accuracy"].append(running_max["detection_accuracy"])
+         plots["best_classification_accuracy"].append(running_max["classification_accuracy"])
+
+         plot_index.append(i)
+
+     ax.plot(plot_index, plots["best_accuracy"], "k", label="Best "+dataset+" Accuracy")
+     ax.plot(plot_index, plots["best_detection_accuracy"], "b--", label="Best "+dataset+" Detection Accuracy")
+     ax.plot(plot_index, plots["best_classification_accuracy"], "g--", label="Best "+dataset+" Classification Accuracy")
+
+     ax.scatter(plot_index, plots["accuracy"], c="k", alpha=0.5)
+     ax.scatter(plot_index, plots["detection_accuracy"], c="b", alpha=0.5)
+     ax.scatter(plot_index, plots["classification_accuracy"], c="g", alpha=0.5)
+
+     ax.legend(loc="lower right")
+
+
+ def plot_optimization():
+     for tag, params in args.tags_optimization.items():
+         query = {"config.wandb_tag": {"$eq": tag}}
+
+         parameters = get_hyperparameters(**params)
+         parameters.extend(["summary/train/results/detection_accuracy", "summary/train/results/classification_accuracy"])
+
+         data = plotszoo.data.WandbData(args.username, args.project, query, verbose=args.verbose)
+         data.pull_scalars(force_update=args.update_scalars)
+         assert len(data.scalars) > 0, "No data, check the tag name"
+         data.pull_series(force_update=args.update_series)
+
+         data.astype(["summary/train/results/accuracy", "summary/train/results/detection_accuracy", "summary/train/results/classification_accuracy"], float)
+         data.dropna(["summary/train/results/accuracy"])
+
+         data.create_scalar_from_series("start_time", lambda s: s["_timestamp"].min())
+
+         fig, axes = plt.subplots(1, len(parameters), sharey=False)
+
+         parallel_plot = plotszoo.scalars.ScalarsParallelCoordinates(data, parameters, "summary/train/results/accuracy")
+
+         parallel_plot.plot(axes)
+
+         fig.set_size_inches(32, 10)
+         plotszoo.utils.savefig(fig, os.path.join(args.output_directory, tag, "optim_parallel.png"))
+
+         fig, ax = plt.subplots(2, 1)
+
+         plot_optimization_history(ax[0], data, "train")
+         plot_optimization_history(ax[1], data, "test")
+
+         fig.set_size_inches(20, 10)
+         plotszoo.utils.savefig(fig, os.path.join(args.output_directory, tag, "optim_history.png"))
+
+         parameters.remove("summary/train/results/detection_accuracy")
+         parameters.remove("summary/train/results/classification_accuracy")
+
+         args_names = [p.split("/")[1].replace("_", "-") for p in parameters]
+         best_run = data.scalars["summary/train/results/accuracy"].idxmax()
+         best_args = "".join(["--%s %s " % (n, data.scalars[k][best_run]) for n, k in zip(args_names, parameters)])
+         best_args += "".join(["--%s %s " % (k.replace("_", "-"), v) for k, v in params.items()])
+         print(best_run)
+         print("Tag: %s" % tag)
+         print(data.scalars.loc[best_run][["summary/train/results/detection_accuracy", "summary/train/results/classification_accuracy"]])
+         print("HP: %s" % best_args)
+         print()
+
+         with open(os.path.join(args.output_directory, tag, "best_args.txt"), "w") as best_args_f:
+             best_args_f.write(best_args)
+
+ def plot_replicas():
+     query = {"$or": [{"config.wandb_tag": {"$eq": tag}} for tag in args.tags_best_replicas]}
+     data = plotszoo.data.WandbData(args.username, args.project, query, verbose=args.verbose)
+     data.pull_scalars(force_update=args.update_scalars)
+
+     group_keys = ["config/sampler"]
+
+     fig, axes = plt.subplots(nrows=2, ncols=1)
+
+     yticks_fn = lambda index: "Sampler: %s" % (index, )
+
+     detection_plot = plotszoo.scalars.grouped.GroupedScalarsBarchart(data, group_keys, "summary/test/results/detection_accuracy")
+     classification_plot = plotszoo.scalars.grouped.GroupedScalarsBarchart(data, group_keys, "summary/test/results/classification_accuracy")
+
+     detection_df = detection_plot.plot(axes[0], title="Test Detection Accuracy", nbins=20, grid=True, yticks_fn=yticks_fn)
+     classification_df = classification_plot.plot(axes[1], title="Test Classification Accuracy", nbins=20, grid=True, yticks_fn=yticks_fn)
+
+     for ax in axes: ax.set_xlim(xmin=0.5)
+
+     fig.set_size_inches(20, 10)
+     fig.tight_layout()
+
+     classification_df.to_excel(os.path.join(args.output_directory, "result/classification.xlsx"))
+     detection_df.to_excel(os.path.join(args.output_directory, "result/detection.xlsx"))
+
+     print(classification_df)
+     print(detection_df)
+
+     plotszoo.utils.savefig(fig, os.path.join(args.output_directory, "results.png"))
+
+
+ def plot_tables():
+     query = {"$or": [{"config.wandb_tag": {"$eq": tag}} for tag in args.tags_best_replicas]}
+     data = plotszoo.data.WandbData(args.username, args.project, query, verbose=args.verbose)
+     data.pull_scalars(force_update=args.update_scalars)
+
+     group_keys = ["config/sampler"]
+     classes = ["neoplastic", "aphthous", "traumatic"]
+     metrics = ["precision", "recall", "f1-score"]
+
+     grouped_df = data.scalars.groupby(group_keys).agg(np.mean)
+     for group in grouped_df.index:
+         data_df = grouped_df.loc[group]
+         table = np.zeros((len(classes), len(metrics)))
+         for i, c in enumerate(classes):
+             for j, m in enumerate(metrics):
+                 table[i, j] = data_df["summary/test/report/%s/%s" % (c, m)]*100
+
+         table_df = pd.DataFrame(table, columns=metrics, index=classes)
+         table_df.to_csv(os.path.join(args.output_directory, "%s_table.csv" % (group)))
+         print("Sampler: %s" % (group))
+         print(table_df)
+         print()
+
+
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--output-directory", type=str, default="./plots")
+ parser.add_argument("--username", type=str, default="mlpi")
+ parser.add_argument("--project", type=str, default="oral-ai")
+ parser.add_argument("--tags-study-replicas", type=str, default=["study-3"], nargs="+")
+ # note: --tags-optimization is effectively default-only
+ # (argparse cannot parse a dict from the command line)
+ parser.add_argument("--tags-optimization", type=dict, default={
+     "hp-optimization-none-trainingsampler-5": dict(
+         data_augmentation="none",
+         sampler="TrainingSampler"
+     ),
+     "hp-optimization-none-repeatfactortrainingsampler-5": dict(
+         data_augmentation="none",
+         sampler="RepeatFactorTrainingSampler"
+     )
+ }, nargs="+")
+ parser.add_argument("--tags-best-replicas", type=str, default=["best-replicas-7"], nargs="+")
+ parser.add_argument("--update-scalars", action="store_true")
+ parser.add_argument("--update-series", action="store_true")
+ parser.add_argument("--verbose", action="store_true")
+
+
+ args = parser.parse_args()
+
+ plot_study()
+ plot_optimization()
+ plot_replicas()
+ #plot_tables()
plots/plot_features.py ADDED
@@ -0,0 +1,68 @@
+ import argparse
+ import numpy as np
+ import json
+
+ from matplotlib import pyplot as plt
+ from matplotlib.colors import ListedColormap
+
+ from sklearn.decomposition import PCA
+ from sklearn.manifold import TSNE
+ from scipy import spatial
+
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--features-database", type=str, required=True)
+ parser.add_argument("--decomposition", type=str, default="TSNE", choices=["TSNE", "PCA"])
+ parser.add_argument("--output", type=str, default="")
+ parser.add_argument("--fig-h", type=int, default=1080)
+ parser.add_argument("--fig-w", type=int, default=720)
+ parser.add_argument("--fig-dpi", type=int, default=100)
+ parser.add_argument("--distance", type=str, default="cosine")
+
+ parser.add_argument("--point", type=str, default="")
+
+ args = parser.parse_args()
+ point = None
+ if args.point != "":
+     point = json.loads(args.point)
+
+
+ dist_fn = getattr(spatial.distance, args.distance)
+ features_database = json.load(open(args.features_database, "r"))
+
+ features = []
+ classes = []
+ for name, feature_list in features_database.items():
+     for feature in feature_list:
+         features.append(feature["features"])
+         classes.append(feature["type"])
+
+ if point is not None:
+     features.append(point)
+
+ features = np.array(features)
+ classes = np.array(classes)
+
+ if args.decomposition == "TSNE":
+     decomposition = TSNE(n_components=2, metric=dist_fn)
+ elif args.decomposition == "PCA":
+     decomposition = PCA(n_components=2)
+ transformed = decomposition.fit_transform(features)
+
+ if point is not None:
+     # the query point was appended last: read it out before dropping its row
+     transformed_point = transformed[-1, :]
+     transformed = transformed[:-1, :]
+
+ plt.figure(figsize=(args.fig_h/args.fig_dpi, args.fig_w/args.fig_dpi), dpi=args.fig_dpi)
+ cmap = ListedColormap(["r", "b", "g"])
+ scatter = plt.scatter(transformed[:, 0], transformed[:, 1], c=classes, cmap=cmap, s=10)
+
+ if point is not None:
+     plt.scatter(transformed_point[0], transformed_point[1], marker="x", s=200, c="k")
+
+ plt.legend(handles=scatter.legend_elements()[0], labels=["neoplastic", "aphthous", "traumatic"])
+
+ if args.output == "":
+     plt.show()
+ else:
+     plt.savefig(args.output, dpi=args.fig_dpi)
plots/plot_gradcam.py ADDED
@@ -0,0 +1,69 @@
+ import argparse
+ import torch
+ import matplotlib
+ import matplotlib.pyplot as plt
+ from types import SimpleNamespace
+
+ from detectron2.utils.visualizer import Visualizer
+ from detectron2.data import Metadata
+ from detectron2 import model_zoo
+
+ from plots.gradcam.detectron2_gradcam import Detectron2GradCAM
+
+
+ def plot_gradcam(**kwargs):
+     kwargs = SimpleNamespace(**kwargs)
+
+     config_file = model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
+
+     cfg_list = [
+         "MODEL.ROI_HEADS.SCORE_THRESH_TEST", str(kwargs.th),
+         "MODEL.ROI_HEADS.NUM_CLASSES", "3",
+         "MODEL.WEIGHTS", kwargs.model
+     ]
+
+     metadata = Metadata()
+     metadata.set(
+         evaluator_type="coco",
+         thing_classes=["neoplastic", "aphthous", "traumatic"],
+         thing_dataset_id_to_contiguous_id={"1": 0, "2": 1, "3": 2}
+     )
+
+     cam_extractor = Detectron2GradCAM(config_file, cfg_list)
+     image_dict, cam_orig = cam_extractor.get_cam(img=kwargs.file, target_instance=kwargs.instance, layer_name=kwargs.layer, grad_cam_type="GradCAM++")
+
+     with torch.no_grad():
+         fig = plt.figure(figsize=(kwargs.fig_h/kwargs.fig_dpi, kwargs.fig_w/kwargs.fig_dpi), dpi=kwargs.fig_dpi)
+         v = Visualizer(image_dict["image"], metadata, scale=1.0)
+         img = image_dict["output"]["instances"][kwargs.instance]
+         img.remove("pred_masks")
+
+         out = v.draw_instance_predictions(img.to("cpu"))
+
+         plt.gca().set_axis_off()
+         plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+         plt.margins(0, 0)
+         plt.imshow(out.get_image(), interpolation="none")
+         # overlay the class activation map on the detection
+         plt.imshow(image_dict["cam"], cmap="jet", alpha=0.5)
+
+     return fig
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument("--model", type=str, required=True)
+     parser.add_argument("--layer", type=str, default="backbone.bottom_up.res5.2.conv3")
+     parser.add_argument("--th", type=float, default=0.5)
+     parser.add_argument("--file", type=str, required=True)
+     parser.add_argument("--instance", type=int, required=True)
+     parser.add_argument("--output", type=str, default="")
+     parser.add_argument("--fig-h", type=int, default=1080)
+     parser.add_argument("--fig-w", type=int, default=720)
+     parser.add_argument("--fig-dpi", type=int, default=100)
+
+     args = parser.parse_args()
+
+     fig = plot_gradcam(**vars(args))
+     # honor --output: save the figure if a path is given, otherwise show it
+     if args.output == "":
+         plt.show()
+     else:
+         fig.savefig(args.output, dpi=args.fig_dpi)
plots/plot_histogram_dist.py ADDED
@@ -0,0 +1,58 @@
+ import argparse
+ import numpy as np
+ import json
+ import pickle
+ from scipy import spatial
+
+ from matplotlib import pyplot as plt
+ from matplotlib.colors import ListedColormap
+
+
+ def plot_histogram_dist(features_database, fig_h, fig_w, fig_dpi, point, distance="cosine"):
+     features_database = json.load(open(features_database, "r"))
+     dist_fn = getattr(spatial.distance, distance)
+     class_names = ["neoplastic", "aphthous", "traumatic"]
+     cmap = ListedColormap(["r", "b", "g"])
+
+     # distances from the query point to every ROI, grouped by class
+     dists = dict()
+     for name, feature_list in features_database.items():
+         for feature in feature_list:
+             if feature["type"] not in dists:
+                 dists[feature["type"]] = []
+
+             dists[feature["type"]].append(dist_fn(point, feature["features"]))
+
+     # honor the requested figure size
+     fig, axes = plt.subplots(len(dists), figsize=(fig_h/fig_dpi, fig_w/fig_dpi), dpi=fig_dpi)
+
+     for k, ax in zip(dists.keys(), axes):
+         dist = dists[k]
+         ax.set_title(class_names[k])
+         ax.set_xlim(0, 1)
+         n, bins, patches = ax.hist(dist, "auto", density=True, color=cmap(k))
+
+     fig.tight_layout(pad=3.0)
+
+     return fig
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument("--features-database", type=str, required=True)
+     parser.add_argument("--output", type=str, default="")
+     parser.add_argument("--fig-h", type=int, default=1080)
+     parser.add_argument("--fig-w", type=int, default=720)
+     parser.add_argument("--fig-dpi", type=int, default=100)
+     parser.add_argument("--distance", type=str, default="cosine")
+
+     parser.add_argument("--point", type=str, required=True)
+
+     args = parser.parse_args()
+
+     dict_args = vars(args)
+     point = json.loads(dict_args.pop("point"))
+     # --output is not a plot_histogram_dist() argument: pop it and use it here
+     output = dict_args.pop("output")
+     fig = plot_histogram_dist(**dict_args, point=point)
+     if output == "":
+         plt.show()
+     else:
+         fig.savefig(output)
plots/plot_matrix_distance.py ADDED
@@ -0,0 +1,56 @@
+ import json
+ import argparse
+ from scipy import spatial
+ import numpy as np
+ from tqdm import tqdm
+ from matplotlib import pyplot as plt
+
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--rows", type=str, required=True)
+ parser.add_argument("--cols", type=str, required=True)
+ parser.add_argument("--distance", type=str, default="cosine")
+ parser.add_argument("--output", type=str, default="")
+
+ args = parser.parse_args()
+
+ rows_features = json.load(open(args.rows, "r"))
+ cols_features = json.load(open(args.cols, "r"))
+ dist_fn = getattr(spatial.distance, args.distance)
+
+
+ rows_features_rois = []
+ cols_features_rois = []
+
+ for row_feature in rows_features.values():
+     for roi_feature in row_feature:
+         rows_features_rois.append(roi_feature)
+
+ for col_feature in cols_features.values():
+     for roi_feature in col_feature:
+         cols_features_rois.append(roi_feature)
+
+
+ rows_features_rois = sorted(rows_features_rois, key=lambda e: e["type"])
+ cols_features_rois = sorted(cols_features_rois, key=lambda e: e["type"])
+
+
+ matrix = np.zeros((len(rows_features_rois), len(cols_features_rois)))
+ for i, row in tqdm(enumerate(rows_features_rois), total=len(rows_features_rois)):
+     for j, col in enumerate(cols_features_rois):
+         matrix[i, j] = dist_fn(row["features"], col["features"])
+
+ fig, ax = plt.subplots()
+
+ # rows run along the vertical axis of imshow, cols along the horizontal one
+ ax.set_xlabel(args.cols)
+ ax.set_ylabel(args.rows)
+
+ pos = ax.imshow(matrix)
+ fig.colorbar(pos, ax=ax)
+
+ if args.output == "":
+     plt.show()
+ else:
+     plt.savefig(args.output)
plots/plot_pca_point.py ADDED
@@ -0,0 +1,57 @@
+ import argparse
+ import numpy as np
+ import json
+ import pickle
+
+ from matplotlib import pyplot as plt
+ from matplotlib.colors import ListedColormap
+
+ from sklearn.decomposition import PCA
+
+
+ def plot_pca_point(features_database, pca_model, fig_h, fig_w, fig_dpi, point):
+     features_database = json.load(open(features_database, "r"))
+     pca = pickle.load(open(pca_model, "rb"))
+
+     features = []
+     classes = []
+     for name, feature_list in features_database.items():
+         for feature in feature_list:
+             features.append(feature["features"])
+             classes.append(feature["type"])
+
+     features = np.array(features)
+     classes = np.array(classes)
+
+     # project both the database and the query point with the pre-fitted PCA
+     features = pca.transform(features)
+     point = pca.transform(np.atleast_2d(point))
+
+     fig = plt.figure(figsize=(fig_h/fig_dpi, fig_w/fig_dpi), dpi=fig_dpi)
+     cmap = ListedColormap(["r", "b", "g"])
+     scatter = plt.scatter(features[:, 0], features[:, 1], c=classes, cmap=cmap, s=10)
+
+     plt.scatter(point[:, 0], point[:, 1], marker="x", s=200, c="k")
+
+     plt.legend(handles=scatter.legend_elements()[0], labels=["neoplastic", "aphthous", "traumatic"])
+
+     return fig
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+
+     parser.add_argument("--features-database", type=str, required=True)
+     parser.add_argument("--pca-model", type=str, required=True)
+     parser.add_argument("--output", type=str, default="")
+     parser.add_argument("--fig-h", type=int, default=1080)
+     parser.add_argument("--fig-w", type=int, default=720)
+     parser.add_argument("--fig-dpi", type=int, default=100)
+
+     parser.add_argument("--point", type=str, required=True)
+
+     args = parser.parse_args()
+
+     dict_args = vars(args)
+     point = json.loads(dict_args.pop("point"))
+     # --output is not a plot_pca_point() argument: pop it and use it here
+     output = dict_args.pop("output")
+     fig = plot_pca_point(**dict_args, point=point)
+     if output == "":
+         plt.show()
+     else:
+         fig.savefig(output)
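
Note: plot_pca_point only loads a pre-fitted PCA from pca_model; the fitting step is not part of this commit. A minimal sketch of how ./models/pca.pkl could be produced from the same features database (the file paths are the defaults used by app.py, assumed here):

import json
import pickle
import numpy as np
from sklearn.decomposition import PCA

features_database = json.load(open("./assets/features/features.json", "r"))
features = np.array([
    feature["features"]
    for feature_list in features_database.values()
    for feature in feature_list
])

# project the 256-d ROI features to 2-D, as plot_pca_point expects
pca = PCA(n_components=2).fit(features)
with open("./models/pca.pkl", "wb") as f:
    pickle.dump(pca, f)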