hank1996 committed
Commit 3e44564 · 1 Parent(s): 35c8a1c

Update app.py

Files changed (1)
  1. app.py +153 -3
app.py CHANGED
@@ -34,8 +34,6 @@ from utils.functions import \
 from PIL import Image
 
 
-
-
 def detect(img,model):
     #with torch.no_grad():
     parser = argparse.ArgumentParser()
@@ -86,6 +84,158 @@ def detect(img,model):
     print(weights)
     if weights == 'yolop.pt':
         weights = 'End-to-end.pth'
+        import argparse
+        import os, sys
+        import shutil
+        import time
+        from pathlib import Path
+        import imageio
+
+        BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+        sys.path.append(BASE_DIR)
+
+        print(sys.path)
+        import cv2
+        import torch
+        import torch.backends.cudnn as cudnn
+        from numpy import random
+        import scipy.special
+        import numpy as np
+        import torchvision.transforms as transforms
+        import PIL.Image as image
+
+        from lib.config import cfg
+        from lib.config import update_config
+        from lib.utils.utils import create_logger, select_device, time_synchronized
+        from lib.models import get_net
+        from lib.dataset import LoadImages, LoadStreams
+        from lib.core.general import non_max_suppression, scale_coords
+        from lib.utils import plot_one_box, show_seg_result
+        from lib.core.function import AverageMeter
+        from lib.core.postprocess import morphological_process, connect_lane
+        from tqdm import tqdm
+        normalize = transforms.Normalize(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+        )
+
+        transform = transforms.Compose([
+            transforms.ToTensor(),
+            normalize,
+        ])
+
+        logger, _, _ = create_logger(
+            cfg, cfg.LOG_DIR, 'demo')
+
+        if os.path.exists(opt.save_dir):  # output dir
+            shutil.rmtree(opt.save_dir)  # delete dir
+        os.makedirs(opt.save_dir)  # make new dir
+
+        # Load model
+        model = get_net(cfg)
+        checkpoint = torch.load(weights, map_location=device)
+        model.load_state_dict(checkpoint['state_dict'])
+        model = model.to(device)
+        #if half:
+        #    model.half()  # to FP16
+
+        # Set Dataloader
+
+        dataset = LoadImages(opt.source, img_size=opt.img_size)
+        bs = 1  # batch_size
+
+
+        # Get names and colors
+        names = model.module.names if hasattr(model, 'module') else model.names
+        colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(names))]
+
+        # Run inference
+        t0 = time.time()
+
+        vid_path, vid_writer = None, None
+        img = torch.zeros((1, 3, opt.img_size, opt.img_size), device=device)  # init img
+        _ = model(img.half() if half else img) if device.type != 'cpu' else None  # run once
+        model.eval()
+
+        inf_time = AverageMeter()
+        nms_time = AverageMeter()
+
+        for i, (path, img, img_det, vid_cap, shapes) in tqdm(enumerate(dataset), total=len(dataset)):
+            img = transform(img).to(device)
+            img = img.half() if half else img.float()  # uint8 to fp16/32
+            if img.ndimension() == 3:
+                img = img.unsqueeze(0)
+            # Inference
+            t1 = time_synchronized()
+            det_out, da_seg_out, ll_seg_out = model(img)
+            t2 = time_synchronized()
+            # if i == 0:
+            #     print(det_out)
+            inf_out, _ = det_out
+            inf_time.update(t2 - t1, img.size(0))
+
+            # Apply NMS
+            t3 = time_synchronized()
+            det_pred = non_max_suppression(inf_out, conf_thres=opt.conf_thres, iou_thres=opt.iou_thres, classes=None, agnostic=False)
+            t4 = time_synchronized()
+
+            nms_time.update(t4 - t3, img.size(0))
+            det = det_pred[0]
+
+            save_path = str(opt.save_dir + '/' + Path(path).name) if dataset.mode != 'stream' else str(opt.save_dir + '/' + "web.mp4")
+
+            _, _, height, width = img.shape
+            h, w, _ = img_det.shape
+            pad_w, pad_h = shapes[1][1]
+            pad_w = int(pad_w)
+            pad_h = int(pad_h)
+            ratio = shapes[1][0][1]
+
+            da_predict = da_seg_out[:, :, pad_h:(height - pad_h), pad_w:(width - pad_w)]
+            da_seg_mask = torch.nn.functional.interpolate(da_predict, scale_factor=int(1 / ratio), mode='bilinear')
+            _, da_seg_mask = torch.max(da_seg_mask, 1)
+            da_seg_mask = da_seg_mask.int().squeeze().cpu().numpy()
+            # da_seg_mask = morphological_process(da_seg_mask, kernel_size=7)
+
+
+            ll_predict = ll_seg_out[:, :, pad_h:(height - pad_h), pad_w:(width - pad_w)]
+            ll_seg_mask = torch.nn.functional.interpolate(ll_predict, scale_factor=int(1 / ratio), mode='bilinear')
+            _, ll_seg_mask = torch.max(ll_seg_mask, 1)
+            ll_seg_mask = ll_seg_mask.int().squeeze().cpu().numpy()
+            # Lane line post-processing
+            #ll_seg_mask = morphological_process(ll_seg_mask, kernel_size=7, func_type=cv2.MORPH_OPEN)
+            #ll_seg_mask = connect_lane(ll_seg_mask)
+
+            img_det = show_seg_result(img_det, (da_seg_mask, ll_seg_mask), _, _, is_demo=True)
+
+            if len(det):
+                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], img_det.shape).round()
+                for *xyxy, conf, cls in reversed(det):
+                    label_det_pred = f'{names[int(cls)]} {conf:.2f}'
+                    plot_one_box(xyxy, img_det, label=label_det_pred, color=colors[int(cls)], line_thickness=2)
+
+            if dataset.mode == 'images':
+                cv2.imwrite(save_path, img_det)
+
+            elif dataset.mode == 'video':
+                if vid_path != save_path:  # new video
+                    vid_path = save_path
+                    if isinstance(vid_writer, cv2.VideoWriter):
+                        vid_writer.release()  # release previous video writer
+
+                    fourcc = 'mp4v'  # output video codec
+                    fps = vid_cap.get(cv2.CAP_PROP_FPS)
+                    h, w, _ = img_det.shape
+                    vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*fourcc), fps, (w, h))
+                vid_writer.write(img_det)
+
+            else:
+                cv2.imshow('image', img_det)
+                cv2.waitKey(1)  # 1 millisecond
+
+        print('Results saved to %s' % Path(opt.save_dir))
+        print('Done. (%.3fs)' % (time.time() - t0))
+        print('inf : (%.4fs/frame) nms : (%.4fs/frame)' % (inf_time.avg, nms_time.avg))
+
 
     if weights == 'yolopv2.pt':
         stride =32
@@ -194,4 +344,4 @@ def detect(img,model):
     return Image.fromarray(im0[:,:,::-1])
 
 
-gr.Interface(detect,[gr.Image(type="pil"),gr.Dropdown(choices=["yolopv2","yolop"])], gr.Image(type="pil"),title="Yolopv2",examples=[["example.jpeg", "yolopv2"]],description="demo for <a href='https://github.com/CAIC-AD/YOLOPv2' style='text-decoration: underline' target='_blank'>YOLOPv2</a> 🚀: Better, Faster, Stronger for Panoptic driving Perception").launch()
+gr.Interface(detect,[gr.Image(type="pil"),gr.Dropdown(choices=["yolopv2","yolop"])], gr.Image(type="pil"),title="Yolopv2",examples=[["example.jpeg", "yolop"]],description="demo for <a href='https://github.com/CAIC-AD/YOLOPv2' style='text-decoration: underline' target='_blank'>YOLOPv2</a> 🚀: Better, Faster, Stronger for Panoptic driving Perception").launch()
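
For reference, a minimal sketch of how the detect() wiring above can be exercised outside the Gradio UI. This is not part of the commit; it assumes example.jpeg and the weight files referenced in app.py ('yolopv2.pt' / 'End-to-end.pth') are present in the working directory:

    # Hypothetical smoke test, not part of the commit.
    from PIL import Image

    img = Image.open("example.jpeg")   # input as a PIL image, matching gr.Image(type="pil")
    result = detect(img, "yolop")      # model name from the Dropdown choices; "yolopv2" selects the other branch
    result.save("result.jpeg")         # detect() returns a PIL image (im0 flipped from BGR to RGB)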