"""Gradio demo for mask-based inpainting with a custom diffusers pipeline."""

import os
from typing import Literal, cast

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import DiffusionPipeline
from gradio.components.image_editor import EditorValue
from PIL import Image, ImageOps

DEVICE = "cuda"

# Model locations are injected via environment variables (e.g. Space secrets).
MAIN_MODEL_REPO_ID = os.getenv("MAIN_MODEL_REPO_ID", None)
SUB_MODEL_REPO_ID = os.getenv("SUB_MODEL_REPO_ID", None)
SUB_MODEL_SUBFOLDER = os.getenv("SUB_MODEL_SUBFOLDER", None)

if MAIN_MODEL_REPO_ID is None:
    raise ValueError("MAIN_MODEL_REPO_ID is not set")
if SUB_MODEL_REPO_ID is None:
    raise ValueError("SUB_MODEL_REPO_ID is not set")
if SUB_MODEL_SUBFOLDER is None:
    raise ValueError("SUB_MODEL_SUBFOLDER is not set")

pipeline = DiffusionPipeline.from_pretrained(
    MAIN_MODEL_REPO_ID,
    torch_dtype=torch.bfloat16,
    custom_pipeline=SUB_MODEL_REPO_ID,
).to(DEVICE)


def crop_divisible_by_16(image: Image.Image) -> Image.Image:
    """Crop an image so both dimensions are multiples of 16 (required by the VAE)."""
    w, h = image.size
    w = w - w % 16
    h = h - h % 16
    return image.crop((0, 0, w, h))


@spaces.GPU(duration=150)
def predict(
    image_and_mask: EditorValue | None,
    prompt: str = "",
    seed: int = 0,
    num_inference_steps: int = 28,
    max_dimension: int = 704,
    condition_scale: float = 1.0,
    checkpoint_step: int = 10_000,
    checkpoint_version: Literal["v1", "v2"] = "v1",
    progress: gr.Progress = gr.Progress(track_tqdm=True),  # noqa: ARG001, B008
) -> Image.Image | None:
    # ) -> tuple[Image.Image, Image.Image] | None:
    if not image_and_mask:
        gr.Info("Please upload an image and draw a mask")
        return None

    image_np = image_and_mask["background"]
    image_np = cast(np.ndarray, image_np)

    # If the image is empty, return None
    if np.sum(image_np) == 0:
        gr.Info("Please upload an image")
        return None

    # The drawn mask lives in the alpha channel of the first editor layer.
    alpha_channel = image_and_mask["layers"][0]
    alpha_channel = cast(np.ndarray, alpha_channel)
    mask_np = np.where(alpha_channel[:, :, 3] == 0, 0, 255).astype(np.uint8)

    # If the mask is empty, return None
    if np.sum(mask_np) == 0:
        gr.Info("Please mark the areas you want to remove")
        return None

    try:
        pipeline.load(
            SUB_MODEL_REPO_ID,
            subfolder=f"model/{checkpoint_version}/ckpt/{checkpoint_step}",
        )
    except Exception as e:
        gr.Info(
            f"Error loading checkpoint (model/{checkpoint_version}/ckpt/{checkpoint_step}): {e}"
        )
        return None

    image = Image.fromarray(image_np)
    # Resize to max dimension
    image.thumbnail((max_dimension, max_dimension))
    # Ensure dimensions are multiples of 16 (for the VAE)
    image = crop_divisible_by_16(image)

    mask = Image.fromarray(mask_np)
    mask.thumbnail((max_dimension, max_dimension))
    mask = crop_divisible_by_16(mask)
    # Invert the mask so the painted region is blacked out when pasting below
    mask = ImageOps.invert(mask)

    # The masked image is the original with the mask applied (black background)
    image_masked = Image.new("RGB", image.size, (0, 0, 0))
    image_masked.paste(image, (0, 0), mask)

    generator = torch.Generator(device="cpu").manual_seed(seed)

    final_image = pipeline(
        condition_image=image_masked,
        condition_scale=condition_scale,
        prompt=prompt,
        num_inference_steps=num_inference_steps,
        generator=generator,
        max_sequence_length=512,
        latent_lora=checkpoint_version == "v2",
    ).images[0]

    return final_image


intro_markdown = r"""
# Inpainting Demo
"""

css = r"""
#col-left {
    margin: 0 auto;
    max-width: 650px;
}
#col-right {
    margin: 0 auto;
    max-width: 650px;
}
#col-showcase {
    margin: 0 auto;
    max-width: 1100px;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(intro_markdown)

    with gr.Row() as content:
        with gr.Column(elem_id="col-left"):
            gr.HTML(
                """