Create new file: lib/core/function.py (ADDED, +508 -0)
@@ -0,0 +1,508 @@
import time
from lib.core.evaluate import ConfusionMatrix, SegmentationMetric
from lib.core.general import non_max_suppression, check_img_size, scale_coords, xyxy2xywh, xywh2xyxy, box_iou, coco80_to_coco91_class, plot_images, ap_per_class, output_to_target
from lib.utils.utils import time_synchronized
from lib.utils import plot_img_and_mask, plot_one_box, show_seg_result
import torch
from threading import Thread
import numpy as np
from PIL import Image
from torchvision import transforms
from pathlib import Path
import json
import random
import cv2
import os
import math
from torch.cuda import amp
from tqdm import tqdm
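
Note on `time_synchronized`, imported from `lib.utils.utils` and used below to time inference and NMS: the actual implementation lives in that module, but as a reference, a minimal sketch of what such a helper typically looks like in YOLO-style repositories (this sketch is an assumption, not the project's code):

import time
import torch

def time_synchronized():
    # wait for all pending CUDA kernels so the wall-clock reading is accurate
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()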

def train(cfg, train_loader, model, criterion, optimizer, scaler, epoch, num_batch, num_warmup,
          writer_dict, logger, device, rank=-1):
    """
    Train for one epoch.
    Inputs:
    - cfg: configurations
    - train_loader: loader for training data
    - model: the multi-task network
    - criterion: (function) calculates all the losses, returns total_loss, head_losses
    - writer_dict: tensorboard writer and global step counters
    outputs(3,)
    output[0] len:3, [1,3,32,32,85], [1,3,16,16,85], [1,3,8,8,85]
    output[1] len:1, [2,256,256]
    output[2] len:1, [2,256,256]
    target(3,)
    target[0] [1,n,5]
    target[1] [2,256,256]
    target[2] [2,256,256]
    Returns:
    None
    """
    batch_time = AverageMeter()  # AverageMeter is defined at the bottom of this file
    data_time = AverageMeter()
    losses = AverageMeter()

    # switch to train mode
    model.train()
    start = time.time()
    for i, (input, target, paths, shapes) in enumerate(train_loader):
        intermediate = time.time()
        # print('time:{}'.format(intermediate - start))
        num_iter = i + num_batch * (epoch - 1)

        if num_iter < num_warmup:
            # warm up
            lf = lambda x: ((1 + math.cos(x * math.pi / cfg.TRAIN.END_EPOCH)) / 2) * \
                           (1 - cfg.TRAIN.LRF) + cfg.TRAIN.LRF  # cosine
            xi = [0, num_warmup]
            # model.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
            for j, x in enumerate(optimizer.param_groups):
                # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                x['lr'] = np.interp(num_iter, xi, [cfg.TRAIN.WARMUP_BIASE_LR if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
                if 'momentum' in x:
                    x['momentum'] = np.interp(num_iter, xi, [cfg.TRAIN.WARMUP_MOMENTUM, cfg.TRAIN.MOMENTUM])

        data_time.update(time.time() - start)
        if not cfg.DEBUG:
            input = input.to(device, non_blocking=True)
            assign_target = []
            for tgt in target:
                assign_target.append(tgt.to(device))
            target = assign_target
        with amp.autocast(enabled=device.type != 'cpu'):
            outputs = model(input)
            total_loss, head_losses = criterion(outputs, target, shapes, model)
            # print(head_losses)

        # compute gradient and do update step
        optimizer.zero_grad()
        scaler.scale(total_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        if rank in [-1, 0]:
            # measure accuracy and record loss
            losses.update(total_loss.item(), input.size(0))

            # _, avg_acc, cnt, pred = accuracy(output.detach().cpu().numpy(),
            #                                  target.detach().cpu().numpy())
            # acc.update(avg_acc, cnt)

            # measure elapsed time
            batch_time.update(time.time() - start)
            start = time.time()  # reset the timer so batch_time/data_time measure a single batch
            if i % cfg.PRINT_FREQ == 0:
                msg = 'Epoch: [{0}][{1}/{2}]\t' \
                      'Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \
                      'Speed {speed:.1f} samples/s\t' \
                      'Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t' \
                      'Loss {loss.val:.5f} ({loss.avg:.5f})'.format(
                          epoch, i, len(train_loader), batch_time=batch_time,
                          speed=input.size(0) / batch_time.val,
                          data_time=data_time, loss=losses)
                logger.info(msg)

                writer = writer_dict['writer']
                global_steps = writer_dict['train_global_steps']
                writer.add_scalar('train_loss', losses.val, global_steps)
                # writer.add_scalar('train_acc', acc.val, global_steps)
                writer_dict['train_global_steps'] = global_steps + 1
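
The warm-up block above linearly interpolates each parameter group's learning rate between a warm-up start value and the cosine-annealed target `initial_lr * lf(epoch)`. A minimal standalone sketch of that interpolation for the bias group (the values END_EPOCH=240, LRF=0.01, WARMUP_BIASE_LR=0.1, initial_lr=0.01, num_warmup=1000 and epoch=1 are assumptions for illustration, not values taken from the config):

import math
import numpy as np

END_EPOCH, LRF = 240, 0.01                 # assumed cfg.TRAIN values
WARMUP_BIASE_LR, initial_lr = 0.1, 0.01    # assumed warm-up start lr / base lr
num_warmup = 1000

lf = lambda x: ((1 + math.cos(x * math.pi / END_EPOCH)) / 2) * (1 - LRF) + LRF  # cosine schedule
for num_iter in (0, 250, 500, 1000):
    # for the bias group (j == 2) the lr falls from WARMUP_BIASE_LR to initial_lr * lf(epoch)
    lr = np.interp(num_iter, [0, num_warmup], [WARMUP_BIASE_LR, initial_lr * lf(1)])
    print(num_iter, float(lr))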

def validate(epoch, config, val_loader, val_dataset, model, criterion, output_dir,
             tb_log_dir, writer_dict=None, logger=None, device='cpu', rank=-1):
    """
    Validate after one epoch.
    Inputs:
    - config: configurations
    - val_loader: loader for validation data
    - model: the multi-task network
    - criterion: (function) calculates all the losses
    - writer_dict: tensorboard writer and global step counters
    Returns:
    da_segment_result, ll_segment_result, detect_result, losses.avg, maps, t
    """
    # settings
    max_stride = 32
    weights = None

    save_dir = output_dir + os.path.sep + 'visualization'
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)

    # print(save_dir)
    _, imgsz = [check_img_size(x, s=max_stride) for x in config.MODEL.IMAGE_SIZE]  # imgsz is a multiple of max_stride
    batch_size = config.TRAIN.BATCH_SIZE_PER_GPU * len(config.GPUS)
    test_batch_size = config.TEST.BATCH_SIZE_PER_GPU * len(config.GPUS)
    training = False
    is_coco = False  # is coco dataset
    save_conf = False  # save auto-label confidences
    verbose = False
    save_hybrid = False
    log_imgs, wandb = min(16, 100), None

    nc = 1
    iouv = torch.linspace(0.5, 0.95, 10).to(device)  # iou vector for mAP@0.5:0.95
    niou = iouv.numel()

    try:
        import wandb
    except ImportError:
        wandb = None
        log_imgs = 0

    seen = 0
    confusion_matrix = ConfusionMatrix(nc=model.nc)  # detector confusion matrix
    da_metric = SegmentationMetric(config.num_seg_class)  # driving area segmentation metric
    ll_metric = SegmentationMetric(2)  # lane line segmentation metric

    names = {k: v for k, v in enumerate(model.names if hasattr(model, 'names') else model.module.names)}
    colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]
    coco91class = coco80_to_coco91_class()

    s = ('%20s' + '%12s' * 6) % ('Class', 'Images', 'Targets', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
    p, r, f1, mp, mr, map50, map, t_inf, t_nms = 0., 0., 0., 0., 0., 0., 0., 0., 0.

    losses = AverageMeter()

    da_acc_seg = AverageMeter()
    da_IoU_seg = AverageMeter()
    da_mIoU_seg = AverageMeter()

    ll_acc_seg = AverageMeter()
    ll_IoU_seg = AverageMeter()
    ll_mIoU_seg = AverageMeter()

    T_inf = AverageMeter()
    T_nms = AverageMeter()

    # switch to eval mode
    model.eval()
    jdict, stats, ap, ap_class, wandb_images = [], [], [], [], []

    for batch_i, (img, target, paths, shapes) in tqdm(enumerate(val_loader), total=len(val_loader)):
        if not config.DEBUG:
            img = img.to(device, non_blocking=True)
            assign_target = []
            for tgt in target:
                assign_target.append(tgt.to(device))
            target = assign_target
            nb, _, height, width = img.shape  # batch size, channel, height, width

        with torch.no_grad():
            pad_w, pad_h = shapes[0][1][1]
            pad_w = int(pad_w)
            pad_h = int(pad_h)
            ratio = shapes[0][1][0][0]

            t = time_synchronized()
            det_out, da_seg_out, ll_seg_out = model(img)
            t_inf = time_synchronized() - t
            if batch_i > 0:
                T_inf.update(t_inf / img.size(0), img.size(0))

            inf_out, train_out = det_out

            # driving area segmentation evaluation
            _, da_predict = torch.max(da_seg_out, 1)
            _, da_gt = torch.max(target[1], 1)
            da_predict = da_predict[:, pad_h:height - pad_h, pad_w:width - pad_w]
            da_gt = da_gt[:, pad_h:height - pad_h, pad_w:width - pad_w]

            da_metric.reset()
            da_metric.addBatch(da_predict.cpu(), da_gt.cpu())
            da_acc = da_metric.pixelAccuracy()
            da_IoU = da_metric.IntersectionOverUnion()
            da_mIoU = da_metric.meanIntersectionOverUnion()

            da_acc_seg.update(da_acc, img.size(0))
            da_IoU_seg.update(da_IoU, img.size(0))
            da_mIoU_seg.update(da_mIoU, img.size(0))

            # lane line segmentation evaluation
            _, ll_predict = torch.max(ll_seg_out, 1)
            _, ll_gt = torch.max(target[2], 1)
            ll_predict = ll_predict[:, pad_h:height - pad_h, pad_w:width - pad_w]
            ll_gt = ll_gt[:, pad_h:height - pad_h, pad_w:width - pad_w]

            ll_metric.reset()
            ll_metric.addBatch(ll_predict.cpu(), ll_gt.cpu())
            ll_acc = ll_metric.lineAccuracy()
            ll_IoU = ll_metric.IntersectionOverUnion()
            ll_mIoU = ll_metric.meanIntersectionOverUnion()

            ll_acc_seg.update(ll_acc, img.size(0))
            ll_IoU_seg.update(ll_IoU, img.size(0))
            ll_mIoU_seg.update(ll_mIoU, img.size(0))

            total_loss, head_losses = criterion((train_out, da_seg_out, ll_seg_out), target, shapes, model)  # compute loss
            losses.update(total_loss.item(), img.size(0))

            # NMS
            t = time_synchronized()
            target[0][:, 2:] *= torch.Tensor([width, height, width, height]).to(device)  # to pixels
            lb = [target[0][target[0][:, 0] == i, 1:] for i in range(nb)] if save_hybrid else []  # for autolabelling
            output = non_max_suppression(inf_out, conf_thres=config.TEST.NMS_CONF_THRESHOLD, iou_thres=config.TEST.NMS_IOU_THRESHOLD, labels=lb)
            # output = non_max_suppression(inf_out, conf_thres=0.001, iou_thres=0.6)
            # output = non_max_suppression(inf_out, conf_thres=config.TEST.NMS_CONF_THRES, iou_thres=config.TEST.NMS_IOU_THRES)
            t_nms = time_synchronized() - t
            if batch_i > 0:
                T_nms.update(t_nms / img.size(0), img.size(0))

            if config.TEST.PLOTS:
                if batch_i == 0:
                    for i in range(test_batch_size):
                        img_test = cv2.imread(paths[i])
                        da_seg_mask = da_seg_out[i][:, pad_h:height - pad_h, pad_w:width - pad_w].unsqueeze(0)
                        da_seg_mask = torch.nn.functional.interpolate(da_seg_mask, scale_factor=int(1 / ratio), mode='bilinear')
                        _, da_seg_mask = torch.max(da_seg_mask, 1)

                        da_gt_mask = target[1][i][:, pad_h:height - pad_h, pad_w:width - pad_w].unsqueeze(0)
                        da_gt_mask = torch.nn.functional.interpolate(da_gt_mask, scale_factor=int(1 / ratio), mode='bilinear')
                        _, da_gt_mask = torch.max(da_gt_mask, 1)

                        da_seg_mask = da_seg_mask.int().squeeze().cpu().numpy()
                        da_gt_mask = da_gt_mask.int().squeeze().cpu().numpy()
                        # seg_mask = seg_mask > 0.5
                        # plot_img_and_mask(img_test, seg_mask, i, epoch, save_dir)
                        img_test1 = img_test.copy()
                        _ = show_seg_result(img_test, da_seg_mask, i, epoch, save_dir)
                        _ = show_seg_result(img_test1, da_gt_mask, i, epoch, save_dir, is_gt=True)

                        img_ll = cv2.imread(paths[i])
                        ll_seg_mask = ll_seg_out[i][:, pad_h:height - pad_h, pad_w:width - pad_w].unsqueeze(0)
                        ll_seg_mask = torch.nn.functional.interpolate(ll_seg_mask, scale_factor=int(1 / ratio), mode='bilinear')
                        _, ll_seg_mask = torch.max(ll_seg_mask, 1)

                        ll_gt_mask = target[2][i][:, pad_h:height - pad_h, pad_w:width - pad_w].unsqueeze(0)
                        ll_gt_mask = torch.nn.functional.interpolate(ll_gt_mask, scale_factor=int(1 / ratio), mode='bilinear')
                        _, ll_gt_mask = torch.max(ll_gt_mask, 1)

                        ll_seg_mask = ll_seg_mask.int().squeeze().cpu().numpy()
                        ll_gt_mask = ll_gt_mask.int().squeeze().cpu().numpy()
                        # seg_mask = seg_mask > 0.5
                        # plot_img_and_mask(img_test, seg_mask, i, epoch, save_dir)
                        img_ll1 = img_ll.copy()
                        _ = show_seg_result(img_ll, ll_seg_mask, i, epoch, save_dir, is_ll=True)
                        _ = show_seg_result(img_ll1, ll_gt_mask, i, epoch, save_dir, is_ll=True, is_gt=True)

                        img_det = cv2.imread(paths[i])
                        img_gt = img_det.copy()
                        det = output[i].clone()
                        if len(det):
                            det[:, :4] = scale_coords(img[i].shape[1:], det[:, :4], img_det.shape).round()
                        for *xyxy, conf, cls in reversed(det):
                            # print(cls)
                            label_det_pred = f'{names[int(cls)]} {conf:.2f}'
                            plot_one_box(xyxy, img_det, label=label_det_pred, color=colors[int(cls)], line_thickness=3)
                        cv2.imwrite(save_dir + "/batch_{}_{}_det_pred.png".format(epoch, i), img_det)

                        labels = target[0][target[0][:, 0] == i, 1:]
                        # print(labels)
                        labels[:, 1:5] = xywh2xyxy(labels[:, 1:5])
                        if len(labels):
                            labels[:, 1:5] = scale_coords(img[i].shape[1:], labels[:, 1:5], img_gt.shape).round()
                        for cls, x1, y1, x2, y2 in labels:
                            # print(names)
                            # print(cls)
                            label_det_gt = f'{names[int(cls)]}'
                            xyxy = (x1, y1, x2, y2)
                            plot_one_box(xyxy, img_gt, label=label_det_gt, color=colors[int(cls)], line_thickness=3)
                        cv2.imwrite(save_dir + "/batch_{}_{}_det_gt.png".format(epoch, i), img_gt)

        # Statistics per image
        # output([xyxy, conf, cls])
        # target[0] ([img_id, cls, xyxy])
        for si, pred in enumerate(output):
            labels = target[0][target[0][:, 0] == si, 1:]  # all objects in one image
            nl = len(labels)  # number of objects
            tcls = labels[:, 0].tolist() if nl else []  # target classes
            path = Path(paths[si])
            seen += 1

            if len(pred) == 0:
                if nl:
                    stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls))
                continue

            # Predictions
            predn = pred.clone()
            scale_coords(img[si].shape[1:], predn[:, :4], shapes[si][0], shapes[si][1])  # native-space pred

            # Append to text file
            if config.TEST.SAVE_TXT:
                gn = torch.tensor(shapes[si][0])[[1, 0, 1, 0]]  # normalization gain whwh
                os.makedirs(os.path.join(save_dir, 'labels'), exist_ok=True)  # save_dir is a str, so build paths with os.path.join
                for *xyxy, conf, cls in predn.tolist():
                    xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                    line = (cls, *xywh, conf) if save_conf else (cls, *xywh)  # label format
                    with open(os.path.join(save_dir, 'labels', path.stem + '.txt'), 'a') as f:
                        f.write(('%g ' * len(line)).rstrip() % line + '\n')

            # W&B logging
            if config.TEST.PLOTS and len(wandb_images) < log_imgs:
                box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
                             "class_id": int(cls),
                             "box_caption": "%s %.3f" % (names[cls], conf),
                             "scores": {"class_score": conf},
                             "domain": "pixel"} for *xyxy, conf, cls in pred.tolist()]
                boxes = {"predictions": {"box_data": box_data, "class_labels": names}}  # inference-space
                wandb_images.append(wandb.Image(img[si], boxes=boxes, caption=path.name))

            # Append to pycocotools JSON dictionary
            if config.TEST.SAVE_JSON:
                # [{"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}, ...
                image_id = int(path.stem) if path.stem.isnumeric() else path.stem
                box = xyxy2xywh(predn[:, :4])  # xywh
                box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
                for p, b in zip(pred.tolist(), box.tolist()):
                    jdict.append({'image_id': image_id,
                                  'category_id': coco91class[int(p[5])] if is_coco else int(p[5]),
                                  'bbox': [round(x, 3) for x in b],
                                  'score': round(p[4], 5)})

            # Assign all predictions as incorrect
            correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool, device=device)
            if nl:
                detected = []  # target indices
                tcls_tensor = labels[:, 0]

                # target boxes
                tbox = xywh2xyxy(labels[:, 1:5])
                scale_coords(img[si].shape[1:], tbox, shapes[si][0], shapes[si][1])  # native-space labels
                if config.TEST.PLOTS:
                    confusion_matrix.process_batch(pred, torch.cat((labels[:, 0:1], tbox), 1))

                # Per target class
                for cls in torch.unique(tcls_tensor):
                    ti = (cls == tcls_tensor).nonzero(as_tuple=False).view(-1)  # target indices
                    pi = (cls == pred[:, 5]).nonzero(as_tuple=False).view(-1)  # prediction indices

                    # Search for detections
                    if pi.shape[0]:
                        # Prediction-to-target ious
                        # n*m, n: pred, m: label
                        ious, i = box_iou(predn[pi, :4], tbox[ti]).max(1)  # best ious, indices
                        # Append detections
                        detected_set = set()
                        for j in (ious > iouv[0]).nonzero(as_tuple=False):
                            d = ti[i[j]]  # detected target
                            if d.item() not in detected_set:
                                detected_set.add(d.item())
                                detected.append(d)
                                correct[pi[j]] = ious[j] > iouv  # iou_thres is 1xn
                                if len(detected) == nl:  # all targets already located in image
                                    break

            # Append statistics (correct, conf, pcls, tcls)
            stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))

        if config.TEST.PLOTS and batch_i < 3:
            f = save_dir + '/' + f'test_batch{batch_i}_labels.jpg'  # labels
            # Thread(target=plot_images, args=(img, target[0], paths, f, names), daemon=True).start()
            f = save_dir + '/' + f'test_batch{batch_i}_pred.jpg'  # predictions
            # Thread(target=plot_images, args=(img, output_to_target(output), paths, f, names), daemon=True).start()

    # Compute statistics
    # stats: [[all_img_correct]...[all_img_tcls]]
    stats = [np.concatenate(x, 0) for x in zip(*stats)]  # to numpy; zip(*) unzips

    map70 = None
    map75 = None
    if len(stats) and stats[0].any():
        p, r, ap, f1, ap_class = ap_per_class(*stats, plot=False, save_dir=save_dir, names=names)
        ap50, ap70, ap75, ap = ap[:, 0], ap[:, 4], ap[:, 5], ap.mean(1)  # [P, R, mAP@0.5, mAP@0.5:0.95]
        mp, mr, map50, map70, map75, map = p.mean(), r.mean(), ap50.mean(), ap70.mean(), ap75.mean(), ap.mean()
        nt = np.bincount(stats[3].astype(np.int64), minlength=nc)  # number of targets per class
    else:
        nt = torch.zeros(1)

    # Print results
    pf = '%20s' + '%12.3g' * 6  # print format
    print(pf % ('all', seen, nt.sum(), mp, mr, map50, map))
    # print(map70)
    # print(map75)

    # Print results per class
    if (verbose or (nc <= 20 and not training)) and nc > 1 and len(stats):
        for i, c in enumerate(ap_class):
            print(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i]))

    # Print speeds
    t = tuple(x / seen * 1E3 for x in (t_inf, t_nms, t_inf + t_nms)) + (imgsz, imgsz, batch_size)  # tuple
    if not training:
        print('Speed: %.1f/%.1f/%.1f ms inference/NMS/total per %gx%g image at batch-size %g' % t)

    # Plots
    if config.TEST.PLOTS:
        confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
        if wandb and wandb.run:
            wandb.log({"Images": wandb_images})
            wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in sorted(Path(save_dir).glob('test*.jpg'))]})

    # Save JSON
    if config.TEST.SAVE_JSON and len(jdict):
        w = Path(weights[0] if isinstance(weights, list) else weights).stem if weights is not None else ''  # weights
        anno_json = '../coco/annotations/instances_val2017.json'  # annotations json
        pred_json = os.path.join(save_dir, f"{w}_predictions.json")  # predictions json
        print('\nEvaluating pycocotools mAP... saving %s...' % pred_json)
        with open(pred_json, 'w') as f:
            json.dump(jdict, f)

        try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
            from pycocotools.coco import COCO
            from pycocotools.cocoeval import COCOeval

            anno = COCO(anno_json)  # init annotations api
            pred = anno.loadRes(pred_json)  # init predictions api
            eval = COCOeval(anno, pred, 'bbox')
            if is_coco:
                eval.params.imgIds = [int(Path(x).stem) for x in val_loader.dataset.img_files]  # image IDs to evaluate
            eval.evaluate()
            eval.accumulate()
            eval.summarize()
            map, map50 = eval.stats[:2]  # update results (mAP@0.5:0.95, mAP@0.5)
        except Exception as e:
            print(f'pycocotools unable to run: {e}')

    # Return results
    if not training:
        s = f"\n{len(list(Path(save_dir).glob('labels/*.txt')))} labels saved to {os.path.join(save_dir, 'labels')}" if config.TEST.SAVE_TXT else ''
        print(f"Results saved to {save_dir}{s}")
    model.float()  # for training
    maps = np.zeros(nc) + map
    for i, c in enumerate(ap_class):
        maps[c] = ap[i]

    da_segment_result = (da_acc_seg.avg, da_IoU_seg.avg, da_mIoU_seg.avg)
    ll_segment_result = (ll_acc_seg.avg, ll_IoU_seg.avg, ll_mIoU_seg.avg)

    # print(da_segment_result)
    # print(ll_segment_result)
    detect_result = np.asarray([mp, mr, map50, map])
    # print('mp:{},mr:{},map50:{},map:{}'.format(mp, mr, map50, map))
    # print segment_result
    t = [T_inf.avg, T_nms.avg]
    return da_segment_result, ll_segment_result, detect_result, losses.avg, maps, t
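
For clarity on the detection statistics above: `correct` holds one boolean per prediction per IoU threshold in `iouv`, and `ap_per_class` later averages over those thresholds to obtain mAP@0.5:0.95. A self-contained sketch of the thresholding step (the example IoU values are made up for illustration):

import torch

iouv = torch.linspace(0.5, 0.95, 10)     # same thresholds as in validate()
ious = torch.tensor([0.92, 0.55, 0.40])  # hypothetical best-match IoUs for 3 predictions
correct = ious.unsqueeze(1) > iouv       # (3, 10) boolean matrix, one column per threshold
print(correct.sum(1))                    # tensor([9, 1, 0]): thresholds passed per prediction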
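
One detail worth calling out from the SAVE_JSON branch: COCO expects boxes as [x_top_left, y_top_left, width, height], so the code converts xyxy to center-based xywh and then shifts the center to the top-left corner. A standalone sketch of that conversion, inlining the xyxy-to-xywh math instead of calling the project's xyxy2xywh (the input box is a made-up example):

import torch

pred_xyxy = torch.tensor([[100., 50., 200., 150.]])           # x1, y1, x2, y2
box = torch.cat(((pred_xyxy[:, :2] + pred_xyxy[:, 2:]) / 2,   # center x, y
                 pred_xyxy[:, 2:] - pred_xyxy[:, :2]), 1)     # width, height
box[:, :2] -= box[:, 2:] / 2                                  # center -> top-left corner
print(box)                                                    # tensor([[100., 50., 100., 100.]])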

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count if self.count != 0 else 0
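
Usage sketch for AverageMeter: `val` is the last value seen, `avg` the running average weighted by `n` (here, batch sizes; the loss values are made up):

meter = AverageMeter()
for batch_loss, batch_size in [(0.9, 32), (0.7, 32), (0.5, 16)]:
    meter.update(batch_loss, n=batch_size)
print(meter.val, meter.avg)  # 0.5 0.74 (weighted over 80 samples)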