import os
from typing import Literal, cast

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import DiffusionPipeline
from gradio.components.image_editor import EditorValue
from PIL import Image, ImageOps

DEVICE = "cuda"

MAIN_MODEL_REPO_ID = os.getenv("MAIN_MODEL_REPO_ID", None)
SUB_MODEL_REPO_ID = os.getenv("SUB_MODEL_REPO_ID", None)
SUB_MODEL_SUBFOLDER = os.getenv("SUB_MODEL_SUBFOLDER", None)

if MAIN_MODEL_REPO_ID is None:
    raise ValueError("MAIN_MODEL_REPO_ID is not set")
if SUB_MODEL_REPO_ID is None:
    raise ValueError("SUB_MODEL_REPO_ID is not set")
if SUB_MODEL_SUBFOLDER is None:
    raise ValueError("SUB_MODEL_SUBFOLDER is not set")

pipeline = DiffusionPipeline.from_pretrained(
    MAIN_MODEL_REPO_ID,
    torch_dtype=torch.bfloat16,
    custom_pipeline=SUB_MODEL_REPO_ID,
).to(DEVICE)


def crop_divisible_by_16(image: Image.Image) -> Image.Image:
    # Crop from the top-left so both dimensions are multiples of 16,
    # as the VAE requires (e.g. 1023x769 -> 1008x768).
    w, h = image.size
    w = w - w % 16
    h = h - h % 16
    return image.crop((0, 0, w, h))


@spaces.GPU(duration=150)
def predict(
    image_and_mask: EditorValue | None,
    prompt: str = "",
    seed: int = 0,
    num_inference_steps: int = 28,
    max_dimension: int = 704,
    condition_scale: float = 1.0,
    checkpoint_step: int = 10_000,
    checkpoint_version: Literal["v1", "v2"] = "v1",
    progress: gr.Progress = gr.Progress(track_tqdm=True),  # noqa: ARG001, B008
) -> Image.Image | None:
    # ) -> tuple[Image.Image, Image.Image] | None:
    if not image_and_mask:
        gr.Info("Please upload an image and draw a mask")
        return None

    image_np = image_and_mask["background"]
    image_np = cast(np.ndarray, image_np)

    # If the image is empty, return None
    if np.sum(image_np) == 0:
        gr.Info("Please upload an image")
        return None

    # The painted mask lives in the alpha channel of the first editor layer
    alpha_channel = image_and_mask["layers"][0]
    alpha_channel = cast(np.ndarray, alpha_channel)
    mask_np = np.where(alpha_channel[:, :, 3] == 0, 0, 255).astype(np.uint8)

    # If the mask is empty, return None
    if np.sum(mask_np) == 0:
        gr.Info("Please mark the areas you want to remove")
        return None

    try:
        pipeline.load(
            SUB_MODEL_REPO_ID,
            subfolder=f"model/{checkpoint_version}/ckpt/{checkpoint_step}",
        )
    except Exception as e:
        gr.Info(
            f"Error loading checkpoint (model/{checkpoint_version}/ckpt/{checkpoint_step}): {e}"
        )
        return None

    image = Image.fromarray(image_np)
    # Resize to max dimension
    image.thumbnail((max_dimension, max_dimension))
    # Ensure dimensions are multiples of 16 (for the VAE)
    image = crop_divisible_by_16(image)

    mask = Image.fromarray(mask_np)
    mask.thumbnail((max_dimension, max_dimension))
    mask = crop_divisible_by_16(mask)
    # Invert the mask: painted regions become 0 so the paste below skips them
    mask = ImageOps.invert(mask)

    # image_masked keeps the unpainted areas; painted regions stay black
    image_masked = Image.new("RGB", image.size, (0, 0, 0))
    image_masked.paste(image, (0, 0), mask)

    generator = torch.Generator(device="cpu").manual_seed(seed)

    final_image = pipeline(
        condition_image=image_masked,
        condition_scale=condition_scale,
        prompt=prompt,
        num_inference_steps=num_inference_steps,
        generator=generator,
        max_sequence_length=512,
        latent_lora=checkpoint_version == "v2",
    ).images[0]

    return final_image


intro_markdown = r"""
# Inpainting Demo
"""

css = r"""
#col-left {
    margin: 0 auto;
    max-width: 650px;
}
#col-right {
    margin: 0 auto;
    max-width: 650px;
}
#col-showcase {
    margin: 0 auto;
    max-width: 1100px;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(intro_markdown)
    with gr.Row():
        with gr.Column(elem_id="col-left"):
            gr.HTML(
                """
                Step 1. Upload a room image ⬇️
                """,
                max_height=50,
            )
            image_and_mask = gr.ImageMask(
                label="Image and Mask",
                layers=False,
                height="full",
                width="full",
                show_fullscreen_button=False,
                sources=["upload"],
                show_download_button=False,
                interactive=True,
                brush=gr.Brush(default_size=75, colors=["#000000"], color_mode="fixed"),
                transforms=[],
            )
            prompt = gr.Textbox(label="Prompt", value="An empty room")
            checkpoint_step = gr.Slider(
                label="Checkpoint Step",
                minimum=1_000,
                maximum=20_000,
                step=1_000,
                value=10_000,
            )
            checkpoint_version = gr.Radio(
                ["v1", "v2"],
                label="Checkpoint Version",
                value="v1",  # default matches predict(); without it the radio starts unset
                interactive=True,
            )
        with gr.Column(elem_id="col-right"):
            gr.HTML(
                """
                Step 2. Press Run to launch
                """,
                max_height=50,
            )
            # image_slider = ImageSlider(
            #     label="Result",
            #     interactive=False,
            # )
            result = gr.Image(label="Result")
            run_button = gr.Button("Run")
            with gr.Accordion("Advanced Settings", open=False):
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=100_000,
                    step=1,
                    value=0,
                )
                condition_scale = gr.Slider(
                    label="Condition Scale",
                    minimum=-10.0,
                    maximum=10.0,
                    step=0.10,
                    value=1.0,
                )
                with gr.Column():
                    max_dimension = gr.Slider(
                        label="Max Dimension",
                        minimum=512,
                        maximum=2048,
                        step=128,
                        value=704,
                    )
                    num_inference_steps = gr.Slider(
                        label="Number of inference steps",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=28,
                    )

    run_button.click(
        fn=predict,
        inputs=[
            image_and_mask,
            prompt,
            seed,
            num_inference_steps,
            max_dimension,
            condition_scale,
            checkpoint_step,
            checkpoint_version,
        ],
        # outputs=[image_slider],
        outputs=[result],
    )

demo.launch()