import gradio as gr
import spaces
from gradio_litmodel3d import LitModel3D

import os
# Kept ahead of the trellis imports, exactly where the original sets it.
# NOTE(review): presumably read when trellis/spconv is imported - confirm.
os.environ['SPCONV_ALGO'] = 'native'
from typing import *
import torch
import numpy as np
import imageio
import uuid
from easydict import EasyDict as edict
from PIL import Image
from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.representations import Gaussian, MeshExtractResult
from trellis.utils import render_utils, postprocessing_utils


# Basic settings
MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = "/tmp/Trellis-demo"
os.makedirs(TMP_DIR, exist_ok=True)


def init_cuda():
    """Pick the torch device to run on.

    Returns:
        ``torch.device('cuda')`` when CUDA is available, otherwise
        ``torch.device('cpu')``. Any exception raised while probing CUDA is
        logged and also results in the CPU device.
    """
    try:
        if torch.cuda.is_available():
            device = torch.device('cuda')
            print("CUDA 초기화 성공")
        else:
            device = torch.device('cpu')
            print("CUDA를 사용할 수 없어 CPU를 사용합니다")
        return device
    except Exception as e:
        print(f"CUDA 초기화 중 오류 발생: {e}")
        return torch.device('cpu')


def preprocess_image(image: Image.Image) -> Tuple[str, Image.Image]:
    """Preprocess the input image and store it under a fresh trial id.

    Uses the module-level ``pipeline`` created in the ``__main__`` section.

    Args:
        image: The image supplied by the user.

    Returns:
        ``(trial_id, processed_image)``; the processed image is also saved as
        ``{TMP_DIR}/{trial_id}.png`` so that ``image_to_3d`` can reload it.
    """
    trial_id = str(uuid.uuid4())
    processed_image = pipeline.preprocess_image(image)
    processed_image.save(f"{TMP_DIR}/{trial_id}.png")
    return trial_id, processed_image


def pack_state(gs: Gaussian, mesh: MeshExtractResult, trial_id: str) -> dict:
    """Pack a generated Gaussian and mesh into a dict of NumPy arrays.

    Args:
        gs: Generated Gaussian representation.
        mesh: Generated mesh.
        trial_id: Identifier of the generation run.

    Returns:
        A dict with keys ``'gaussian'`` (the Gaussian's ``init_params`` plus
        its tensors copied to CPU as NumPy arrays), ``'mesh'`` (vertices and
        faces as NumPy arrays) and ``'trial_id'``.
    """
    return {
        'gaussian': {
            **gs.init_params,
            '_xyz': gs._xyz.cpu().numpy(),
            '_features_dc': gs._features_dc.cpu().numpy(),
            '_scaling': gs._scaling.cpu().numpy(),
            '_rotation': gs._rotation.cpu().numpy(),
            '_opacity': gs._opacity.cpu().numpy(),
        },
        'mesh': {
            'vertices': mesh.vertices.cpu().numpy(),
            'faces': mesh.faces.cpu().numpy(),
        },
        'trial_id': trial_id,
    }


def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
    """Rebuild the Gaussian and mesh from a dict produced by ``pack_state``.

    Args:
        state: Dict in the layout returned by ``pack_state``.

    Returns:
        ``(gaussian, mesh, trial_id)`` with all tensors placed on the device
        chosen by ``init_cuda``.
    """
    device = init_cuda()
    packed_gs = state['gaussian']
    gs = Gaussian(
        aabb=packed_gs['aabb'],
        sh_degree=packed_gs['sh_degree'],
        mininum_kernel_size=packed_gs['mininum_kernel_size'],
        scaling_bias=packed_gs['scaling_bias'],
        opacity_bias=packed_gs['opacity_bias'],
        scaling_activation=packed_gs['scaling_activation'],
    )
    gs._xyz = torch.tensor(packed_gs['_xyz'], device=device)
    gs._features_dc = torch.tensor(packed_gs['_features_dc'], device=device)
    gs._scaling = torch.tensor(packed_gs['_scaling'], device=device)
    gs._rotation = torch.tensor(packed_gs['_rotation'], device=device)
    gs._opacity = torch.tensor(packed_gs['_opacity'], device=device)

    mesh = edict(
        vertices=torch.tensor(state['mesh']['vertices'], device=device),
        faces=torch.tensor(state['mesh']['faces'], device=device),
    )

    return gs, mesh, state['trial_id']


@spaces.GPU
def image_to_3d(
    trial_id: str,
    seed: int,
    randomize_seed: bool,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
) -> Tuple[dict, str]:
    """Convert the preprocessed image of a trial into a 3D model.

    Args:
        trial_id: Id returned by ``preprocess_image``; selects the PNG to load.
        seed: Random seed, used unless ``randomize_seed`` is set.
        randomize_seed: Draw a fresh seed in ``[0, MAX_SEED)`` when true.
        ss_guidance_strength: ``cfg_strength`` for the sparse-structure sampler.
        ss_sampling_steps: ``steps`` for the sparse-structure sampler.
        slat_guidance_strength: ``cfg_strength`` for the structured-latent sampler.
        slat_sampling_steps: ``steps`` for the structured-latent sampler.

    Returns:
        ``(state, video_path)``: the packed state (see ``pack_state``) and the
        path of the rendered preview video.

    Raises:
        gr.Error: If ``trial_id`` is not a UUID (e.g. no image was uploaded)
            or if generation fails.
    """
    # trial_id arrives from a client-side textbox and is interpolated into a
    # file path below, so only accept values that parse as a UUID. This also
    # covers the empty value present before any image has been uploaded.
    try:
        uuid.UUID(str(trial_id))
    except ValueError:
        raise gr.Error("Please upload an image before generating.") from None

    try:
        if randomize_seed:
            seed = np.random.randint(0, MAX_SEED)

        outputs = pipeline.run(
            Image.open(f"{TMP_DIR}/{trial_id}.png"),
            seed=seed,
            formats=["gaussian", "mesh"],
            preprocess_image=False,
            sparse_structure_sampler_params={
                "steps": ss_sampling_steps,
                "cfg_strength": ss_guidance_strength,
            },
            slat_sampler_params={
                "steps": slat_sampling_steps,
                "cfg_strength": slat_guidance_strength,
            },
        )

        video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
        video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
        # Place the colour frame and the normal frame side by side.
        video = [
            np.concatenate([color_frame, normal_frame], axis=1)
            for color_frame, normal_frame in zip(video, video_geo)
        ]

        # A new id for the generated artefacts; stored as str so the packed
        # state matches the ``trial_id: str`` contract of pack_state.
        trial_id = str(uuid.uuid4())
        video_path = f"{TMP_DIR}/{trial_id}.mp4"
        os.makedirs(os.path.dirname(video_path), exist_ok=True)
        imageio.mimsave(video_path, video, fps=15)

        state = pack_state(outputs['gaussian'][0], outputs['mesh'][0], trial_id)
        return state, video_path

    except Exception as e:
        print(f"3D 변환 중 오류 발생: {e}")
        # Surface the failure in the UI instead of silently returning None.
        raise gr.Error(f"3D generation failed: {e}") from e


@spaces.GPU
def extract_glb(state: dict, mesh_simplify: float, texture_size: int) -> Tuple[str, str]:
    """Extract a GLB file from a generated 3D model.

    Args:
        state: Packed state returned by ``image_to_3d``.
        mesh_simplify: Passed to ``postprocessing_utils.to_glb`` as ``simplify``.
        texture_size: Passed to ``postprocessing_utils.to_glb`` as ``texture_size``.

    Returns:
        The GLB path twice: once for the 3D viewer, once for the download
        button.

    Raises:
        gr.Error: If there is no generated state or the extraction fails.
    """
    if state is None:
        raise gr.Error("Generate a 3D asset before extracting a GLB.")

    try:
        gs, mesh, trial_id = unpack_state(state)
        glb = postprocessing_utils.to_glb(
            gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False
        )
        glb_path = f"{TMP_DIR}/{trial_id}.glb"
        glb.export(glb_path)
        return glb_path, glb_path
    except Exception as e:
        print(f"GLB 추출 중 오류 발생: {e}")
        # Surface the failure in the UI instead of silently returning None.
        raise gr.Error(f"GLB extraction failed: {e}") from e


def activate_button() -> gr.Button:
    """Return a button update that makes the target interactive."""
    return gr.Button(interactive=True)


def deactivate_button() -> gr.Button:
    """Return a button update that makes the target non-interactive."""
    return gr.Button(interactive=False)


# Gradio interface setup
css = """
footer {
    visibility: hidden;
}
"""

with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
    gr.Markdown("""
    ## Roblox3D""")

    with gr.Row():
        with gr.Column():
            image_prompt = gr.Image(label="Image Prompt", image_mode="RGBA", type="pil", height=300)

            with gr.Accordion(label="Generation Settings", open=False):
                seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                gr.Markdown("Stage 1: Sparse Structure Generation")
                with gr.Row():
                    ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                gr.Markdown("Stage 2: Structured Latent Generation")
                with gr.Row():
                    slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
                    slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)

            generate_btn = gr.Button("Generate")

            with gr.Accordion(label="GLB Extraction Settings", open=False):
                mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
                texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)

            extract_glb_btn = gr.Button("Extract GLB", interactive=False)

        with gr.Column():
            video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
            model_output = LitModel3D(label="Extracted GLB", exposure=20.0, height=300)
            download_glb = gr.DownloadButton(label="Download GLB", interactive=False)

    trial_id = gr.Textbox(visible=False)
    output_buf = gr.State()

    # Example images
    with gr.Row():
        examples = gr.Examples(
            examples=[
                f'assets/example_image/{image}'
                # os.listdir returns entries in arbitrary order; sort so the
                # example gallery is laid out the same way on every start.
                for image in sorted(os.listdir("assets/example_image"))
            ],
            inputs=[image_prompt],
            fn=preprocess_image,
            outputs=[trial_id, image_prompt],
            run_on_click=True,
            examples_per_page=64,
        )

    # Event handlers
    image_prompt.upload(
        preprocess_image,
        inputs=[image_prompt],
        outputs=[trial_id, image_prompt],
    )

    image_prompt.clear(
        lambda: '',
        outputs=[trial_id],
    )

    # .success (not .then): only enable the follow-up button when the
    # preceding step finished without raising.
    generate_btn.click(
        image_to_3d,
        inputs=[trial_id, seed, randomize_seed, ss_guidance_strength, ss_sampling_steps,
                slat_guidance_strength, slat_sampling_steps],
        outputs=[output_buf, video_output],
    ).success(
        activate_button,
        outputs=[extract_glb_btn],
    )

    video_output.clear(
        deactivate_button,
        outputs=[extract_glb_btn],
    )

    extract_glb_btn.click(
        extract_glb,
        inputs=[output_buf, mesh_simplify, texture_size],
        outputs=[model_output, download_glb],
    ).success(
        activate_button,
        outputs=[download_glb],
    )

    model_output.clear(
        deactivate_button,
        outputs=[download_glb],
    )


# Main entry point
if __name__ == "__main__":
    try:
        device = init_cuda()
        pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
        pipeline.to(device)

        # Try to preload rembg; a failure here is logged and not fatal.
        try:
            pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
        except Exception as e:
            print(f"사전 로드 중 오류 발생: {e}")

        # Launch the demo with a bounded queue for the shared-GPU environment
        demo.queue(max_size=10).launch(share=True)

    except Exception as e:
        print(f"애플리케이션 시작 중 오류 발생: {e}")
        # Re-raise so a failed start-up yields a traceback and a non-zero
        # exit status instead of a clean-looking exit.
        raise