imabackstabber committed on
Commit
db8354d
·
1 Parent(s): 4923179

refine layout

Browse files
Files changed (2) hide show
  1. app.py +7 -3
  2. main/inference.py +17 -4
app.py CHANGED
@@ -32,9 +32,9 @@ def infer(image_input, in_threshold=0.5, num_people="Single person", render_mesh
32
  inferer = Inferer(DEFAULT_MODEL, num_gpus, OUT_FOLDER)
33
  os.system(f'rm -rf {OUT_FOLDER}/*')
34
  multi_person = False if (num_people == "Single person") else True
35
- vis_img, num_bbox, mmdet_box = inferer.infer(image_input, in_threshold, multi_person, not(render_mesh))
36
 
37
- return vis_img, "bbox num: {}, bbox meta: {}".format(num_bbox, mmdet_box)
38
 
39
  TITLE = '''<h1 align="center">PostoMETRO: Pose Token Enhanced Mesh Transformer for Robust 3D Human Mesh Recovery</h1>'''
40
  DESCRIPTION = '''
@@ -43,6 +43,9 @@ DESCRIPTION = '''
43
  Note: You can drop an image at the panel (or select one of the examples)
44
  to obtain the 3D parametric reconstructions of the detected humans.
45
  </p>
 
 
 
46
  '''
47
 
48
  with gr.Blocks(title="PostoMETRO", css=".gradio-container") as demo:
@@ -71,10 +74,11 @@ with gr.Blocks(title="PostoMETRO", css=".gradio-container") as demo:
71
  send_button = gr.Button("Infer")
72
  with gr.Column():
73
  processed_frames = gr.Image(label="Rendered Results")
 
74
  debug_textbox = gr.Textbox(label="Debug information")
75
 
76
  # example_images = gr.Examples([])
77
- send_button.click(fn=infer, inputs=[image_input, threshold, num_people, mesh_as_vertices], outputs=[processed_frames, debug_textbox])
78
  # with gr.Row():
79
  example_images = gr.Examples([
80
  ['/home/user/app/assets/01.jpg'],
 
32
  inferer = Inferer(DEFAULT_MODEL, num_gpus, OUT_FOLDER)
33
  os.system(f'rm -rf {OUT_FOLDER}/*')
34
  multi_person = False if (num_people == "Single person") else True
35
+ vis_img, bbox_img, num_bbox, mmdet_box = inferer.infer(image_input, in_threshold, multi_person, not(render_mesh))
36
 
37
+ return vis_img, bbox_img, "bbox num: {}\nbbox meta: {}".format(num_bbox, mmdet_box)
38
 
39
  TITLE = '''<h1 align="center">PostoMETRO: Pose Token Enhanced Mesh Transformer for Robust 3D Human Mesh Recovery</h1>'''
40
  DESCRIPTION = '''
 
43
  Note: You can drop an image at the panel (or select one of the examples)
44
  to obtain the 3D parametric reconstructions of the detected humans.
45
  </p>
46
+ <p>
47
+ Check out <b><a href="https://arxiv.org/abs/2403.12473">our paper</a></b> on the arXiv page!
48
+ </p>
49
  '''
50
 
51
  with gr.Blocks(title="PostoMETRO", css=".gradio-container") as demo:
 
74
  send_button = gr.Button("Infer")
75
  with gr.Column():
76
  processed_frames = gr.Image(label="Rendered Results")
77
+ bbox_frames = gr.Image(label="Bbox Results")
78
  debug_textbox = gr.Textbox(label="Debug information")
79
 
80
  # example_images = gr.Examples([])
81
+ send_button.click(fn=infer, inputs=[image_input, threshold, num_people, mesh_as_vertices], outputs=[processed_frames, bbox_frames, debug_textbox])
82
  # with gr.Row():
83
  example_images = gr.Examples([
84
  ['/home/user/app/assets/01.jpg'],
main/inference.py CHANGED
@@ -57,6 +57,7 @@ class Inferer:
57
  transform = transforms.Normalize(mean=[0.485, 0.456, 0.406],
58
  std=[0.229, 0.224, 0.225])
59
  vis_img = original_img.copy()
 
60
  original_img_height, original_img_width = original_img.shape[:2]
61
 
62
  # load renderer
@@ -97,13 +98,14 @@ class Inferer:
97
  # align these pre-processing steps
98
  bbox = process_bbox(mmdet_box_xywh, original_img_width, original_img_height)
99
 
100
- ok_bboxes.append(bbox)
101
 
102
  # [DEBUG] test mmdet pipeline
103
  if bbox is not None:
104
  top_left = (int(bbox[0]), int(bbox[1]))
105
  bottom_right = (int(bbox[0] + bbox[2]), int(bbox[1] + bbox[3]))
106
  cv2.rectangle(vis_img, top_left, bottom_right, (0, 0, 255), 2)
 
107
 
108
  # human model inference
109
  img, img2bb_trans, bb2img_trans = generate_patch_image(original_img, bbox, 1.0, 0.0, False, self.cfg.input_img_shape)
@@ -136,18 +138,19 @@ class Inferer:
136
  pred_cam[2] + cy_delta / (pred_cam[0] / (original_img_height / bbox[3]))],
137
  mesh_as_vertices=mesh_as_vertices)
138
  vis_img = vis_img.astype('uint8')
139
- return vis_img, len(ok_bboxes), ok_bboxes
140
 
141
 
142
  if __name__ == '__main__':
143
  from PIL import Image
144
  inferer = Inferer('postometro', 1, './out_folder') # gpu
145
- image_path = f'../assets/07.jpg'
146
  image = Image.open(image_path)
147
  # Convert the PIL image to a NumPy array
148
  image_np = np.array(image)
149
- vis_img, _ , _ = inferer.infer(image_np, 0.2, multi_person=True, mesh_as_vertices=True)
150
  save_path = f'./saved_vis_07.jpg'
 
151
 
152
  # Ensure the image is in the correct format (PIL expects uint8)
153
  if vis_img.dtype != np.uint8:
@@ -157,3 +160,13 @@ if __name__ == '__main__':
157
  image = Image.fromarray(vis_img)
158
  image.save(save_path)
159
 
 
 
 
 
 
 
 
 
 
 
 
57
  transform = transforms.Normalize(mean=[0.485, 0.456, 0.406],
58
  std=[0.229, 0.224, 0.225])
59
  vis_img = original_img.copy()
60
+ bbox_img = original_img.copy()
61
  original_img_height, original_img_width = original_img.shape[:2]
62
 
63
  # load renderer
 
98
  # align these pre-processing steps
99
  bbox = process_bbox(mmdet_box_xywh, original_img_width, original_img_height)
100
 
101
+ ok_bboxes.append(bbox.tolist())
102
 
103
  # [DEBUG] test mmdet pipeline
104
  if bbox is not None:
105
  top_left = (int(bbox[0]), int(bbox[1]))
106
  bottom_right = (int(bbox[0] + bbox[2]), int(bbox[1] + bbox[3]))
107
  cv2.rectangle(vis_img, top_left, bottom_right, (0, 0, 255), 2)
108
+ cv2.rectangle(bbox_img, top_left, bottom_right, (0, 0, 255), 2)
109
 
110
  # human model inference
111
  img, img2bb_trans, bb2img_trans = generate_patch_image(original_img, bbox, 1.0, 0.0, False, self.cfg.input_img_shape)
 
138
  pred_cam[2] + cy_delta / (pred_cam[0] / (original_img_height / bbox[3]))],
139
  mesh_as_vertices=mesh_as_vertices)
140
  vis_img = vis_img.astype('uint8')
141
+ return vis_img, bbox_img, len(ok_bboxes), ok_bboxes
142
 
143
 
144
  if __name__ == '__main__':
145
  from PIL import Image
146
  inferer = Inferer('postometro', 1, './out_folder') # gpu
147
+ image_path = f'../assets/06.jpg'
148
  image = Image.open(image_path)
149
  # Convert the PIL image to a NumPy array
150
  image_np = np.array(image)
151
+ vis_img, bbox_img, num_bbox, mmdet_box = inferer.infer(image_np, 0.2, multi_person=True, mesh_as_vertices=True)
152
  save_path = f'./saved_vis_07.jpg'
153
+ bbox_save_path = f'./bbox_saved_vis_07.jpg'
154
 
155
  # Ensure the image is in the correct format (PIL expects uint8)
156
  if vis_img.dtype != np.uint8:
 
160
  image = Image.fromarray(vis_img)
161
  image.save(save_path)
162
 
163
+ # Ensure the image is in the correct format (PIL expects uint8)
164
+ if bbox_img.dtype != np.uint8:
165
+ bbox_img = bbox_img.astype('uint8')
166
+
167
+ # Convert the Numpy array (if RGB) to a PIL image and save
168
+ image = Image.fromarray(bbox_img)
169
+ image.save(bbox_save_path)
170
+
171
+ print("bbox num: {}\nbbox meta: {}".format(num_bbox, mmdet_box))
172
+