hank1996 commited on
Commit
7e10b33
·
1 Parent(s): 79c80cd

Create new file

Browse files
Files changed (1) hide show
  1. lib/core/general.py +466 -0
lib/core/general.py ADDED
@@ -0,0 +1,466 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import glob
3
+ import logging
4
+ import os
5
+ import platform
6
+ import random
7
+ import re
8
+ import shutil
9
+ import subprocess
10
+ import time
11
+ import torchvision
12
+ from contextlib import contextmanager
13
+ from copy import copy
14
+ from pathlib import Path
15
+
16
+ import cv2
17
+ import math
18
+ import matplotlib
19
+ import matplotlib.pyplot as plt
20
+ import numpy as np
21
+ import torch
22
+ import torch.nn as nn
23
+ import yaml
24
+ from PIL import Image
25
+ from scipy.cluster.vq import kmeans
26
+ from scipy.signal import butter, filtfilt
27
+ from tqdm import tqdm
28
+
29
+
30
+ def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-9):
31
+ # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
32
+ box2 = box2.T
33
+
34
+ # Get the coordinates of bounding boxes
35
+ if x1y1x2y2: # x1, y1, x2, y2 = box1
36
+ b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
37
+ b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
38
+ else: # transform from xywh to xyxy
39
+ b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
40
+ b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
41
+ b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
42
+ b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
43
+
44
+ # Intersection area
45
+ inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
46
+ (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
47
+
48
+ # Union Area
49
+ w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
50
+ w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
51
+ union = w1 * h1 + w2 * h2 - inter + eps
52
+
53
+ iou = inter / union
54
+ if GIoU or DIoU or CIoU:
55
+ cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex (smallest enclosing box) width
56
+ ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height
57
+ if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
58
+ c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared
59
+ rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
60
+ (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center distance squared
61
+ if DIoU:
62
+ return iou - rho2 / c2 # DIoU
63
+ elif CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
64
+ v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
65
+ with torch.no_grad():
66
+ alpha = v / ((1 + eps) - iou + v)
67
+ return iou - (rho2 / c2 + v * alpha) # CIoU
68
+ else: # GIoU https://arxiv.org/pdf/1902.09630.pdf
69
+ c_area = cw * ch + eps # convex area
70
+ return iou - (c_area - union) / c_area # GIoU
71
+ else:
72
+ return iou # IoU
73
+
74
+
75
+ def box_iou(box1, box2):
76
+ # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
77
+ """
78
+ Return intersection-over-union (Jaccard index) of boxes.
79
+ Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
80
+ Arguments:
81
+ box1 (Tensor[N, 4])
82
+ box2 (Tensor[M, 4])
83
+ Returns:
84
+ iou (Tensor[N, M]): the NxM matrix containing the pairwise
85
+ IoU values for every element in boxes1 and boxes2
86
+ """
87
+
88
+ def box_area(box):
89
+ # box = 4xn
90
+ return (box[2] - box[0]) * (box[3] - box[1]) #(x2-x1)*(y2-y1)
91
+
92
+ area1 = box_area(box1.T)
93
+ area2 = box_area(box2.T)
94
+
95
+ # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
96
+ inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
97
+ return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
98
+
99
+ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()):
100
+ """Performs Non-Maximum Suppression (NMS) on inference results
101
+ Returns:
102
+ detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
103
+ """
104
+
105
+ nc = prediction.shape[2] - 5 # number of classes
106
+ xc = prediction[..., 4] > conf_thres # candidates
107
+
108
+ # Settings
109
+ min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
110
+ max_det = 300 # maximum number of detections per image
111
+ max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
112
+ time_limit = 10.0 # seconds to quit after
113
+ redundant = True # require redundant detections
114
+ multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)
115
+ merge = False # use merge-NMS
116
+
117
+ t = time.time()
118
+ output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
119
+ for xi, x in enumerate(prediction): # image index, image inference
120
+ # Apply constraints
121
+ # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
122
+ x = x[xc[xi]] # confidence
123
+
124
+ # Cat apriori labels if autolabelling
125
+ if labels and len(labels[xi]):
126
+ l = labels[xi]
127
+ v = torch.zeros((len(l), nc + 5), device=x.device)
128
+ v[:, :4] = l[:, 1:5] # box
129
+ v[:, 4] = 1.0 # conf
130
+ v[range(len(l)), l[:, 0].long() + 5] = 1.0 # cls
131
+ x = torch.cat((x, v), 0)
132
+
133
+ # If none remain process next image
134
+ if not x.shape[0]:
135
+ continue
136
+
137
+ # Compute conf
138
+ x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
139
+
140
+ # Box (center x, center y, width, height) to (x1, y1, x2, y2)
141
+ box = xywh2xyxy(x[:, :4])
142
+
143
+ # Detections matrix nx6 (xyxy, conf, cls)
144
+ if multi_label:
145
+ i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
146
+ x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
147
+ else: # best class only
148
+ conf, j = x[:, 5:].max(1, keepdim=True)
149
+ x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
150
+
151
+ # Filter by class
152
+ if classes is not None:
153
+ x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
154
+
155
+ # Apply finite constraint
156
+ # if not torch.isfinite(x).all():
157
+ # x = x[torch.isfinite(x).all(1)]
158
+
159
+ # Check shape
160
+ n = x.shape[0] # number of boxes
161
+ if not n: # no boxes
162
+ continue
163
+ elif n > max_nms: # excess boxes
164
+ x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
165
+
166
+ # Batched NMS
167
+ c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
168
+ boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
169
+ i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
170
+ if i.shape[0] > max_det: # limit detections
171
+ i = i[:max_det]
172
+ if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
173
+ # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
174
+ iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
175
+ weights = iou * scores[None] # box weights
176
+ x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
177
+ if redundant:
178
+ i = i[iou.sum(1) > 1] # require redundancy
179
+
180
+ output[xi] = x[i]
181
+ if (time.time() - t) > time_limit:
182
+ print(f'WARNING: NMS time limit {time_limit}s exceeded')
183
+ break # time limit exceeded
184
+
185
+ return output
186
+
187
+
188
+ def xywh2xyxy(x):
189
+ # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
190
+ y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
191
+ y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
192
+ y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
193
+ y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
194
+ y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
195
+ return y
196
+
197
+ def fitness(x):
198
+ # Returns fitness (for use with results.txt or evolve.txt)
199
+ w = [0.0, 0.0, 0.1, 0.9] # weights for [P, R, [email protected], [email protected]:0.95]
200
+ return (x[:, :4] * w).sum(1)
201
+
202
+ def check_img_size(img_size, s=32):
203
+ # Verify img_size is a multiple of stride s
204
+ new_size = make_divisible(img_size, int(s)) # ceil gs-multiple
205
+ if new_size != img_size:
206
+ print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size))
207
+ return new_size
208
+
209
+ def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
210
+ # Rescale coords (xyxy) from img1_shape to img0_shape
211
+ if ratio_pad is None: # calculate from img0_shape
212
+ gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
213
+ pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
214
+ else:
215
+ gain = ratio_pad[0][0]
216
+ pad = ratio_pad[1]
217
+
218
+ coords[:, [0, 2]] -= pad[0] # x padding
219
+ coords[:, [1, 3]] -= pad[1] # y padding
220
+ coords[:, :4] /= gain
221
+ clip_coords(coords, img0_shape)
222
+ return coords
223
+
224
+ def clip_coords(boxes, img_shape):
225
+ # Clip bounding xyxy bounding boxes to image shape (height, width)
226
+ boxes[:, 0].clamp_(0, img_shape[1]) # x1
227
+ boxes[:, 1].clamp_(0, img_shape[0]) # y1
228
+ boxes[:, 2].clamp_(0, img_shape[1]) # x2
229
+ boxes[:, 3].clamp_(0, img_shape[0]) # y2
230
+
231
+ def make_divisible(x, divisor):
232
+ # Returns x evenly divisible by divisor
233
+ return math.ceil(x / divisor) * divisor
234
+
235
+ def xyxy2xywh(x):
236
+ # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
237
+ y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
238
+ y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
239
+ y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
240
+ y[:, 2] = x[:, 2] - x[:, 0] # width
241
+ y[:, 3] = x[:, 3] - x[:, 1] # height
242
+ return y
243
+
244
+ def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=640, max_subplots=16):
245
+ # Plot image grid with labels
246
+
247
+ if isinstance(images, torch.Tensor):
248
+ images = images.cpu().float().numpy()
249
+ if isinstance(targets, torch.Tensor):
250
+ targets = targets.cpu().numpy()
251
+
252
+ # un-normalise
253
+ if np.max(images[0]) <= 1:
254
+ images *= 255
255
+
256
+ tl = 3 # line thickness
257
+ tf = max(tl - 1, 1) # font thickness
258
+ bs, _, h, w = images.shape # batch size, _, height, width
259
+ bs = min(bs, max_subplots) # limit plot images
260
+ ns = np.ceil(bs ** 0.5) # number of subplots (square)
261
+
262
+ # Check if we should resize
263
+ scale_factor = max_size / max(h, w)
264
+ if scale_factor < 1:
265
+ h = math.ceil(scale_factor * h)
266
+ w = math.ceil(scale_factor * w)
267
+
268
+ colors = color_list() # list of colors
269
+ mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
270
+ for i, img in enumerate(images):
271
+ if i == max_subplots: # if last batch has fewer images than we expect
272
+ break
273
+
274
+ block_x = int(w * (i // ns))
275
+ block_y = int(h * (i % ns))
276
+
277
+ img = img.transpose(1, 2, 0)
278
+ if scale_factor < 1:
279
+ img = cv2.resize(img, (w, h))
280
+
281
+ mosaic[block_y:block_y + h, block_x:block_x + w, :] = img
282
+ if len(targets) > 0:
283
+ image_targets = targets[targets[:, 0] == i]
284
+ boxes = xywh2xyxy(image_targets[:, 2:6]).T
285
+ classes = image_targets[:, 1].astype('int')
286
+ labels = image_targets.shape[1] == 6 # labels if no conf column
287
+ conf = None if labels else image_targets[:, 6] # check for confidence presence (label vs pred)
288
+
289
+ if boxes.shape[1]:
290
+ if boxes.max() <= 1.01: # if normalized with tolerance 0.01
291
+ boxes[[0, 2]] *= w # scale to pixels
292
+ boxes[[1, 3]] *= h
293
+ elif scale_factor < 1: # absolute coords need scale if image scales
294
+ boxes *= scale_factor
295
+ boxes[[0, 2]] += block_x
296
+ boxes[[1, 3]] += block_y
297
+ for j, box in enumerate(boxes.T):
298
+ cls = int(classes[j])
299
+ color = colors[cls % len(colors)]
300
+ cls = names[cls] if names else cls
301
+ if labels or conf[j] > 0.25: # 0.25 conf thresh
302
+ label = '%s' % cls if labels else '%s %.1f' % (cls, conf[j])
303
+ plot_one_box(box, mosaic, label=label, color=color, line_thickness=tl)
304
+
305
+ # Draw image filename labels
306
+ if paths:
307
+ label = Path(paths[i]).name[:40] # trim to 40 char
308
+ t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
309
+ cv2.putText(mosaic, label, (block_x + 5, block_y + t_size[1] + 5), 0, tl / 3, [220, 220, 220], thickness=tf,
310
+ lineType=cv2.LINE_AA)
311
+
312
+ # Image border
313
+ cv2.rectangle(mosaic, (block_x, block_y), (block_x + w, block_y + h), (255, 255, 255), thickness=3)
314
+
315
+ if fname:
316
+ r = min(1280. / max(h, w) / ns, 1.0) # ratio to limit image size
317
+ mosaic = cv2.resize(mosaic, (int(ns * w * r), int(ns * h * r)), interpolation=cv2.INTER_AREA)
318
+ # cv2.imwrite(fname, cv2.cvtColor(mosaic, cv2.COLOR_BGR2RGB)) # cv2 save
319
+ Image.fromarray(mosaic).save(fname) # PIL save
320
+ return mosaic
321
+
322
+ def plot_one_box(x, img, color=None, label=None, line_thickness=None):
323
+ # Plots one bounding box on image img
324
+ tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness
325
+ color = color or [random.randint(0, 255) for _ in range(3)]
326
+ c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
327
+ cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
328
+ if label:
329
+ tf = max(tl - 1, 1) # font thickness
330
+ t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
331
+ c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
332
+ cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA) # filled
333
+ cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
334
+
335
+ def color_list():
336
+ # Return first 10 plt colors as (r,g,b) https://stackoverflow.com/questions/51350872/python-from-color-name-to-rgb
337
+ def hex2rgb(h):
338
+ return tuple(int(str(h[1 + i:1 + i + 2]), 16) for i in (0, 2, 4))
339
+
340
+ return [hex2rgb(h) for h in plt.rcParams['axes.prop_cycle'].by_key()['color']]
341
+
342
+ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='precision-recall_curve.png', names=[]):
343
+ """ Compute the average precision, given the recall and precision curves.
344
+ Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
345
+ # Arguments
346
+ tp: True positives (nparray, nx1 or nx10).
347
+ conf: Objectness value from 0-1 (nparray).
348
+ pred_cls: Predicted object classes (nparray).
349
+ target_cls: True object classes (nparray).
350
+ plot: Plot precision-recall curve at [email protected]
351
+ save_dir: Plot save directory
352
+ # Returns
353
+ The average precision as computed in py-faster-rcnn.
354
+ """
355
+
356
+ # Sort by objectness
357
+ i = np.argsort(-conf)
358
+ tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
359
+
360
+ # Find unique classes
361
+ unique_classes = np.unique(target_cls)
362
+
363
+ # Create Precision-Recall curve and compute AP for each class
364
+ px, py = np.linspace(0, 1, 1000), [] # for plotting
365
+ pr_score = 0.1 # score to evaluate P and R https://github.com/ultralytics/yolov3/issues/898
366
+ s = [unique_classes.shape[0], tp.shape[1]] # number class, number iou thresholds (i.e. 10 for mAP0.5...0.95)
367
+ ap, p, r = np.zeros(s), np.zeros((unique_classes.shape[0], 1000)), np.zeros((unique_classes.shape[0], 1000))
368
+ for ci, c in enumerate(unique_classes):
369
+ i = pred_cls == c
370
+ n_l = (target_cls == c).sum() # number of labels
371
+ n_p = i.sum() # number of predictions
372
+
373
+ if n_p == 0 or n_l == 0:
374
+ continue
375
+ else:
376
+ # Accumulate FPs and TPs
377
+ fpc = (1 - tp[i]).cumsum(0)
378
+ tpc = tp[i].cumsum(0)
379
+
380
+ # Recall
381
+ recall = tpc / (n_l + 1e-16) # recall curve
382
+ r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases
383
+
384
+ # Precision
385
+ precision = tpc / (tpc + fpc) # precision curve
386
+ p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1) # p at pr_score
387
+ # AP from recall-precision curve
388
+ for j in range(tp.shape[1]):
389
+ ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
390
+ if plot and (j == 0):
391
+ py.append(np.interp(px, mrec, mpre)) # precision at [email protected]
392
+
393
+ # Compute F1 score (harmonic mean of precision and recall)
394
+ f1 = 2 * p * r / (p + r + 1e-16)
395
+ i=r.mean(0).argmax()
396
+
397
+ if plot:
398
+ plot_pr_curve(px, py, ap, save_dir, names)
399
+
400
+ return p[:, i], r[:, i], ap, f1[:, i], unique_classes.astype('int32')
401
+
402
+ def compute_ap(recall, precision):
403
+ """ Compute the average precision, given the recall and precision curves.
404
+ Source: https://github.com/rbgirshick/py-faster-rcnn.
405
+ # Arguments
406
+ recall: The recall curve (list).
407
+ precision: The precision curve (list).
408
+ # Returns
409
+ The average precision as computed in py-faster-rcnn.
410
+ """
411
+
412
+ # Append sentinel values to beginning and end
413
+ mrec = np.concatenate(([0.], recall, [recall[-1] + 1E-3]))
414
+ mpre = np.concatenate(([1.], precision, [0.]))
415
+
416
+ # Compute the precision envelope
417
+ mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))
418
+
419
+ # Integrate area under curve
420
+ method = 'interp' # methods: 'continuous', 'interp'
421
+ if method == 'interp':
422
+ x = np.linspace(0, 1, 101) # 101-point interp (COCO)
423
+ ap = np.trapz(np.interp(x, mrec, mpre), x) # integrate
424
+
425
+ else: # 'continuous'
426
+ i = np.where(mrec[1:] != mrec[:-1])[0] # points where x axis (recall) changes
427
+ ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) # area under curve
428
+
429
+ return ap, mpre, mrec
430
+
431
+ def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
432
+ # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
433
+ # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
434
+ # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
435
+ # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
436
+ # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
437
+ x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
438
+ 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
439
+ 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
440
+ return x
441
+
442
+ def output_to_target(output):
443
+ # Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
444
+ targets = []
445
+ for i, o in enumerate(output):
446
+ for *box, conf, cls in o.cpu().numpy():
447
+ targets.append([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf])
448
+ return np.array(targets)
449
+
450
+ def plot_pr_curve(px, py, ap, save_dir='.', names=()):
451
+ fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
452
+ py = np.stack(py, axis=1)
453
+
454
+ if 0 < len(names) < 21: # show mAP in legend if < 10 classes
455
+ for i, y in enumerate(py.T):
456
+ ax.plot(px, y, linewidth=1, label=f'{names[i]} %.3f' % ap[i, 0]) # plot(recall, precision)
457
+ else:
458
+ ax.plot(px, py, linewidth=1, color='grey') # plot(recall, precision)
459
+
460
+ ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f [email protected]' % ap[:, 0].mean())
461
+ ax.set_xlabel('Recall')
462
+ ax.set_ylabel('Precision')
463
+ ax.set_xlim(0, 1)
464
+ ax.set_ylim(0, 1)
465
+ plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
466
+ fig.savefig(Path(save_dir) / 'precision_recall_curve.png', dpi=250)