asdflkjh commited on
Commit
13f4638
·
1 Parent(s): 47a02e2

obj and png outputs work

Browse files
Files changed (2) hide show
  1. gradio_app.py +57 -16
  2. run.py +12 -2
gradio_app.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import random
 
3
  import tempfile
4
  import time
5
  import zipfile
@@ -65,6 +66,10 @@ example_files = [
65
 
66
 
67
  def create_zip_file(glb_file, pc_file, illumination_file):
 
 
 
 
68
  if not all([glb_file, pc_file, illumination_file]):
69
  return None
70
 
@@ -206,22 +211,47 @@ def run_model(
206
 
207
  # Create new tmp file
208
  temp_dir = tempfile.mkdtemp()
209
- tmp_file = os.path.join(temp_dir, "mesh.glb")
210
 
211
- trimesh_mesh.export(tmp_file, file_type="glb", include_normals=True)
212
  generated_files.append(tmp_file)
213
 
214
- tmp_file_pc = os.path.join(temp_dir, "points.ply")
215
- trimesh_pc.export(tmp_file_pc)
216
- generated_files.append(tmp_file_pc)
217
-
218
- tmp_file_illumination = os.path.join(temp_dir, "illumination.hdr")
219
- cv2.imwrite(tmp_file_illumination, illumination_map)
220
- generated_files.append(tmp_file_illumination)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
 
222
  print("Generation took:", time.time() - start, "s")
223
 
224
- return tmp_file, tmp_file_pc, tmp_file_illumination, trimesh_pc
225
 
226
 
227
  def create_batch(input_image: Image) -> dict[str, Any]:
@@ -300,7 +330,7 @@ def process_model_run(
300
  f"Final vertex count: {final_vertex_count} with type {vertex_count_type} and vertex count {vertex_count}"
301
  )
302
 
303
- glb_file, pc_file, illumination_file, pc_plot = run_model(
304
  background_state,
305
  guidance_scale,
306
  random_seed,
@@ -323,7 +353,7 @@ def process_model_run(
323
  ]
324
  )
325
 
326
- return glb_file, pc_file, illumination_file, point_list
327
 
328
 
329
  def regenerate_run(
@@ -336,7 +366,7 @@ def regenerate_run(
336
  vertex_count,
337
  texture_resolution,
338
  ):
339
- glb_file, pc_file, illumination_file, point_list = process_model_run(
340
  background_state,
341
  guidance_scale,
342
  random_seed,
@@ -394,7 +424,7 @@ def run_button(
394
  else:
395
  pc_cond = None
396
 
397
- glb_file, pc_file, illumination_file, pc_list = process_model_run(
398
  background_state,
399
  guidance_scale,
400
  random_seed,
@@ -405,7 +435,8 @@ def run_button(
405
  texture_resolution,
406
  )
407
 
408
- zip_file = create_zip_file(glb_file, pc_file, illumination_file)
 
409
 
410
  if torch.cuda.is_available():
411
  print("Peak Memory:", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")
@@ -426,7 +457,8 @@ def run_button(
426
  gr.update(value=pc_list), # point_cloud_editor
427
  gr.update(value=pc_file), # pc_download
428
  gr.update(visible=False), # regenerate_btn
429
- gr.update(value=zip_file, visible=True), # download_all_btn
 
430
  )
431
 
432
  elif run_btn == "Remove Background":
@@ -725,6 +757,14 @@ with gr.Blocks() as demo:
725
  contrast=1.0,
726
  scale=1.0,
727
  )
 
 
 
 
 
 
 
 
728
  with gr.Column(visible=False, scale=1.0) as hdr_row:
729
  gr.Markdown(
730
  """## HDR Environment Map
@@ -875,6 +915,7 @@ with gr.Blocks() as demo:
875
  pc_download,
876
  regenerate_btn,
877
  download_all_btn,
 
878
  ],
879
  )
880
 
 
1
  import os
2
  import random
3
+ from re import I
4
  import tempfile
5
  import time
6
  import zipfile
 
66
 
67
 
68
  def create_zip_file(glb_file, pc_file, illumination_file):
69
+ print("zip disabled for perf")
70
+
71
+ return "zip disabled for perf"
72
+
73
  if not all([glb_file, pc_file, illumination_file]):
74
  return None
75
 
 
211
 
212
  # Create new tmp file
213
  temp_dir = tempfile.mkdtemp()
214
+ tmp_file = os.path.join(temp_dir, "mesh.obj")
215
 
216
+ trimesh_mesh.export(tmp_file, file_type="obj", include_normals=True)
217
  generated_files.append(tmp_file)
218
 
219
+ # tmp_file_pc = os.path.join(temp_dir, "points.ply")
220
+ # trimesh_pc.export(tmp_file_pc)
221
+ # generated_files.append(tmp_file_pc)
222
+
223
+ # tmp_file_illumination = os.path.join(temp_dir, "illumination.hdr")
224
+ # cv2.imwrite(tmp_file_illumination, illumination_map)
225
+ # generated_files.append(tmp_file_illumination)
226
+
227
+ # Extract textures
228
+ texture_full_path = None
229
+ if trimesh_mesh.visual is not None:
230
+ material = trimesh_mesh.visual.material
231
+ if hasattr(material, 'baseColorTexture'):
232
+ texture_path = 'texture.png'
233
+ texture_full_path = os.path.join(temp_dir, texture_path)
234
+
235
+ # Convert the texture data to an image and save
236
+ texture_data = material.baseColorTexture
237
+ if isinstance(texture_data, np.ndarray):
238
+ texture_img = Image.fromarray(texture_data)
239
+ print("Saving texture to", texture_full_path)
240
+ texture_img.save(texture_full_path)
241
+ generated_files.append(texture_full_path)
242
+ elif isinstance(texture_data, Image.Image):
243
+ print("Saving texture to", texture_full_path)
244
+ texture_data.save(texture_full_path)
245
+ generated_files.append(texture_full_path)
246
+ else:
247
+ print("Texture data is not a numpy array, but instead: " + str(type(texture_data)))
248
+ else:
249
+ print("Material has no baseColorTexture, but only:")
250
+ print(vars(material))
251
 
252
  print("Generation took:", time.time() - start, "s")
253
 
254
+ return tmp_file, None, None, trimesh_pc, texture_full_path
255
 
256
 
257
  def create_batch(input_image: Image) -> dict[str, Any]:
 
330
  f"Final vertex count: {final_vertex_count} with type {vertex_count_type} and vertex count {vertex_count}"
331
  )
332
 
333
+ glb_file, pc_file, illumination_file, pc_plot, texture_file = run_model(
334
  background_state,
335
  guidance_scale,
336
  random_seed,
 
353
  ]
354
  )
355
 
356
+ return glb_file, pc_file, illumination_file, point_list, texture_file
357
 
358
 
359
  def regenerate_run(
 
366
  vertex_count,
367
  texture_resolution,
368
  ):
369
+ glb_file, pc_file, illumination_file, point_list, texture_file = process_model_run(
370
  background_state,
371
  guidance_scale,
372
  random_seed,
 
424
  else:
425
  pc_cond = None
426
 
427
+ glb_file, pc_file, illumination_file, pc_list, texture_file = process_model_run(
428
  background_state,
429
  guidance_scale,
430
  random_seed,
 
435
  texture_resolution,
436
  )
437
 
438
+ # Disabled to improve performance
439
+ # zip_file = create_zip_file(glb_file, pc_file, illumination_file)
440
 
441
  if torch.cuda.is_available():
442
  print("Peak Memory:", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")
 
457
  gr.update(value=pc_list), # point_cloud_editor
458
  gr.update(value=pc_file), # pc_download
459
  gr.update(visible=False), # regenerate_btn
460
+ gr.update(value=None, visible=True), # download_all_btn
461
+ gr.update(value=texture_file, visible=True), # download_all_btn
462
  )
463
 
464
  elif run_btn == "Remove Background":
 
757
  contrast=1.0,
758
  scale=1.0,
759
  )
760
+
761
+ output_texture = gr.File(
762
+ label="Download Texture",
763
+ file_types=[".png"],
764
+ file_count="single",
765
+ visible=False,
766
+ )
767
+
768
  with gr.Column(visible=False, scale=1.0) as hdr_row:
769
  gr.Markdown(
770
  """## HDR Environment Map
 
915
  pc_download,
916
  regenerate_btn,
917
  download_all_btn,
918
+ output_texture,
919
  ],
920
  )
921
 
run.py CHANGED
@@ -10,6 +10,7 @@ from transparent_background import Remover
10
  from spar3d.models.mesh import QUAD_REMESH_AVAILABLE, TRIANGLE_REMESH_AVAILABLE
11
  from spar3d.system import SPAR3D
12
  from spar3d.utils import foreground_crop, get_device, remove_background
 
13
 
14
 
15
  def check_positive(value):
@@ -178,13 +179,22 @@ if __name__ == "__main__":
178
  )
179
 
180
  if len(image) == 1:
181
- out_mesh_path = os.path.join(output_dir, str(i), "mesh.glb")
 
182
  mesh.export(out_mesh_path, include_normals=True)
183
  out_points_path = os.path.join(output_dir, str(i), "points.ply")
184
  glob_dict["point_clouds"][0].export(out_points_path)
 
 
 
 
 
 
 
 
185
  else:
186
  for j in range(len(mesh)):
187
- out_mesh_path = os.path.join(output_dir, str(i + j), "mesh.glb")
188
  mesh[j].export(out_mesh_path, include_normals=True)
189
  out_points_path = os.path.join(output_dir, str(i + j), "points.ply")
190
  glob_dict["point_clouds"][j].export(out_points_path)
 
10
  from spar3d.models.mesh import QUAD_REMESH_AVAILABLE, TRIANGLE_REMESH_AVAILABLE
11
  from spar3d.system import SPAR3D
12
  from spar3d.utils import foreground_crop, get_device, remove_background
13
+ import trimesh
14
 
15
 
16
  def check_positive(value):
 
179
  )
180
 
181
  if len(image) == 1:
182
+ assert(isinstance(mesh, trimesh.Trimesh))
183
+ out_mesh_path = os.path.join(output_dir, str(i), "mesh.obj")
184
  mesh.export(out_mesh_path, include_normals=True)
185
  out_points_path = os.path.join(output_dir, str(i), "points.ply")
186
  glob_dict["point_clouds"][0].export(out_points_path)
187
+
188
+ # Extract textures
189
+ for j, material in enumerate(mesh.visual.material):
190
+ if hasattr(material, 'image'):
191
+ texture_path = f'texture_{j}.png'
192
+ with open(os.path.join(output_dir, str(i), texture_path), 'wb') as f:
193
+ f.write(material.image.tobytes())
194
+
195
  else:
196
  for j in range(len(mesh)):
197
+ out_mesh_path = os.path.join(output_dir, str(i + j), "mesh.obj")
198
  mesh[j].export(out_mesh_path, include_normals=True)
199
  out_points_path = os.path.join(output_dir, str(i + j), "points.ply")
200
  glob_dict["point_clouds"][j].export(out_points_path)