hank1996 commited on
Commit
0d391f9
·
1 Parent(s): 1f1a29d

Create new file

Browse files
Files changed (1) hide show
  1. lib/utils/autoanchor.py +133 -0
lib/utils/autoanchor.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ # Auto-anchor utils
4
+
5
+ import numpy as np
6
+ import torch
7
+ import yaml
8
+ from scipy.cluster.vq import kmeans
9
+ from tqdm import tqdm
10
+ from lib.utils import is_parallel
11
+
12
+
13
+ def check_anchor_order(m):
14
+ # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
15
+ a = m.anchor_grid.prod(-1).view(-1) # anchor area
16
+ da = a[-1] - a[0] # delta a
17
+ ds = m.stride[-1] - m.stride[0] # delta s
18
+ if da.sign() != ds.sign(): # same order
19
+ print('Reversing anchor order')
20
+ m.anchors[:] = m.anchors.flip(0)
21
+ m.anchor_grid[:] = m.anchor_grid.flip(0)
22
+
23
+
24
+ def run_anchor(logger,dataset, model, thr=4.0, imgsz=640):
25
+ det = model.module.model[model.module.detector_index] if is_parallel(model) \
26
+ else model.model[model.detector_index]
27
+ anchor_num = det.na * det.nl
28
+ new_anchors = kmean_anchors(dataset, n=anchor_num, img_size=imgsz, thr=thr, gen=1000, verbose=False)
29
+ new_anchors = torch.tensor(new_anchors, device=det.anchors.device).type_as(det.anchors)
30
+ det.anchor_grid[:] = new_anchors.clone().view_as(det.anchor_grid) # for inference
31
+ det.anchors[:] = new_anchors.clone().view_as(det.anchors) / det.stride.to(det.anchors.device).view(-1, 1, 1) # loss
32
+ check_anchor_order(det)
33
+ logger.info(str(det.anchors))
34
+ print('New anchors saved to model. Update model config to use these anchors in the future.')
35
+
36
+
37
+ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True):
38
+ """ Creates kmeans-evolved anchors from training dataset
39
+ Arguments:
40
+ path: path to dataset *.yaml, or a loaded dataset
41
+ n: number of anchors
42
+ img_size: image size used for training
43
+ thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0
44
+ gen: generations to evolve anchors using genetic algorithm
45
+ verbose: print all results
46
+ Return:
47
+ k: kmeans evolved anchors
48
+ Usage:
49
+ from utils.autoanchor import *; _ = kmean_anchors()
50
+ """
51
+ thr = 1. / thr
52
+
53
+ def metric(k, wh): # compute metrics
54
+ r = wh[:, None] / k[None]
55
+ x = torch.min(r, 1. / r).min(2)[0] # ratio metric
56
+ # x = wh_iou(wh, torch.tensor(k)) # iou metric
57
+ return x, x.max(1)[0] # x, best_x
58
+
59
+ def anchor_fitness(k): # mutation fitness
60
+ _, best = metric(torch.tensor(k, dtype=torch.float32), wh)
61
+ return (best * (best > thr).float()).mean() # fitness
62
+
63
+ def print_results(k):
64
+ k = k[np.argsort(k.prod(1))] # sort small to large
65
+ x, best = metric(k, wh0)
66
+ bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n # best possible recall, anch > thr
67
+ print('thr=%.2f: %.4f best possible recall, %.2f anchors past thr' % (thr, bpr, aat))
68
+ print('n=%g, img_size=%s, metric_all=%.3f/%.3f-mean/best, past_thr=%.3f-mean: ' %
69
+ (n, img_size, x.mean(), best.mean(), x[x > thr].mean()), end='')
70
+ for i, x in enumerate(k):
71
+ print('%i,%i' % (round(x[0]), round(x[1])), end=', ' if i < len(k) - 1 else '\n') # use in *.cfg
72
+ return k
73
+
74
+ if isinstance(path, str): # not class
75
+ raise TypeError('Dataset must be class, but found str')
76
+ else:
77
+ dataset = path # dataset
78
+
79
+ labels = [db['label'] for db in dataset.db]
80
+ labels = np.vstack(labels)
81
+ if not (labels[:, 1:] <= 1).all():
82
+ # normalize label
83
+ labels[:, [2, 4]] /= dataset.shapes[0]
84
+ labels[:, [1, 3]] /= dataset.shapes[1]
85
+ # Get label wh
86
+ shapes = img_size * dataset.shapes / dataset.shapes.max()
87
+ # wh0 = np.concatenate([l[:, 3:5] * shapes for l in labels]) # wh
88
+ wh0 = labels[:, 3:5] * shapes
89
+ # Filter
90
+ i = (wh0 < 3.0).any(1).sum()
91
+ if i:
92
+ print('WARNING: Extremely small objects found. '
93
+ '%g of %g labels are < 3 pixels in width or height.' % (i, len(wh0)))
94
+ wh = wh0[(wh0 >= 2.0).any(1)] # filter > 2 pixels
95
+
96
+ # Kmeans calculation
97
+ print('Running kmeans for %g anchors on %g points...' % (n, len(wh)))
98
+ s = wh.std(0) # sigmas for whitening
99
+ k, dist = kmeans(wh / s, n, iter=30) # points, mean distance
100
+ k *= s
101
+ wh = torch.tensor(wh, dtype=torch.float32) # filtered
102
+ wh0 = torch.tensor(wh0, dtype=torch.float32) # unfiltered
103
+ k = print_results(k)
104
+
105
+ # Plot
106
+ # k, d = [None] * 20, [None] * 20
107
+ # for i in tqdm(range(1, 21)):
108
+ # k[i-1], d[i-1] = kmeans(wh / s, i) # points, mean distance
109
+ # fig, ax = plt.subplots(1, 2, figsize=(14, 7), tight_layout=True)
110
+ # ax = ax.ravel()
111
+ # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.')
112
+ # fig, ax = plt.subplots(1, 2, figsize=(14, 7)) # plot wh
113
+ # ax[0].hist(wh[wh[:, 0]<100, 0],400)
114
+ # ax[1].hist(wh[wh[:, 1]<100, 1],400)
115
+ # fig.savefig('wh.png', dpi=200)
116
+
117
+ # Evolve
118
+ npr = np.random
119
+ f, sh, mp, s = anchor_fitness(k), k.shape, 0.9, 0.1 # fitness, generations, mutation prob, sigma
120
+ pbar = tqdm(range(gen), desc='Evolving anchors with Genetic Algorithm') # progress bar
121
+ for _ in pbar:
122
+ v = np.ones(sh)
123
+ while (v == 1).all(): # mutate until a change occurs (prevent duplicates)
124
+ v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
125
+ kg = (k.copy() * v).clip(min=2.0)
126
+ fg = anchor_fitness(kg)
127
+ if fg > f:
128
+ f, k = fg, kg.copy()
129
+ pbar.desc = 'Evolving anchors with Genetic Algorithm: fitness = %.4f' % f
130
+ if verbose:
131
+ print_results(k)
132
+
133
+ return print_results(k)