import os
import torch
from PIL import Image
from diffusers import FluxPipeline
import gradio as gr
import spaces
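
# Target device and model locations. The base FLUX repo, the custom pipeline
# repo, and its subfolder are read from environment variables and are required
# (see the checks below).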
DEVICE = "cuda"
MAIN_MODEL_REPO_ID = os.getenv("MAIN_MODEL_REPO_ID", None)
SUB_MODEL_REPO_ID = os.getenv("SUB_MODEL_REPO_ID", None)
SUB_MODEL_SUBFOLDER = os.getenv("SUB_MODEL_SUBFOLDER", None)

if MAIN_MODEL_REPO_ID is None:
    raise ValueError("MAIN_MODEL_REPO_ID is not set")
if SUB_MODEL_REPO_ID is None:
    raise ValueError("SUB_MODEL_REPO_ID is not set")
if SUB_MODEL_SUBFOLDER is None:
    raise ValueError("SUB_MODEL_SUBFOLDER is not set")
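
# Load the base pipeline with the custom pipeline implementation, move it to
# the GPU, then load the sub-model from the configured repo/subfolder.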
pipeline = FluxPipeline.from_pretrained(
    MAIN_MODEL_REPO_ID,
    custom_pipeline=SUB_MODEL_REPO_ID,
).to(DEVICE)
pipeline.load(
    SUB_MODEL_REPO_ID,
    subfolder=SUB_MODEL_SUBFOLDER,
)


def crop_divisible_by_16(image: Image.Image) -> Image.Image:
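    """Crop the image so its width and height are multiples of 16 (required by the VAE)."""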
    w, h = image.size
    w = w - w % 16
    h = h - h % 16
    return image.crop((0, 0, w, h))


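# On ZeroGPU Spaces, `spaces.GPU` requests a GPU for each call; duration=150
# sets the expected maximum runtime in seconds.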
@spaces.GPU(duration=150)
def predict(
    room_image_input: Image.Image,
    room_image_category: str,
    custom_prompt: str | None = None,
    seed: int = 0,
    num_inference_steps: int = 28,
    max_dimension: int = 1024,
    condition_scale: float = 1.0,
    progress: gr.Progress = gr.Progress(track_tqdm=True),  # noqa: ARG001, B008
) -> Image.Image:
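    """Run virtual staging on the uploaded room image and return the generated result."""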
    # Resize to the maximum dimension
    room_image_input.thumbnail((max_dimension, max_dimension))
    # Ensure dimensions are multiples of 16 (for the VAE)
    room_image_input = crop_divisible_by_16(room_image_input)

    prompt = f"[VIRTUAL STAGING] {room_image_category}\n"
    if custom_prompt:
        prompt += f" {custom_prompt}"

    generator = torch.Generator(device=DEVICE).manual_seed(seed)

    final_images = pipeline(
        condition_image=room_image_input,
        condition_scale=condition_scale,
        prompt=prompt,
        num_inference_steps=num_inference_steps,
        generator=generator,
    )

    return final_images


intro_markdown = r"""
# Virtual Staging Demo
"""
css = r"""
#col-left {
margin: 0 auto;
max-width: 650px;
}
#col-right {
margin: 0 auto;
max-width: 650px;
}
#col-showcase {
margin: 0 auto;
max-width: 1100px;
}
"""
with gr.Blocks(css=css) as demo:
    gr.Markdown(intro_markdown)

    with gr.Row() as content:
        with gr.Column(elem_id="col-left"):
            gr.HTML(
                """
                <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
                    <div>
                        Step 1. Upload a room image ⬇️
                    </div>
                </div>
                """,
                max_height=50,
            )
            room_image_input = gr.Image(
                label="room",
                type="pil",
                sources=["upload"],
                image_mode="RGB",
            )
            room_image_category = gr.Dropdown(
                label="Room category",
                choices=[
                    "bedroom",
                    "living room",
                    "bathroom",
                ],
                info="Select the room category",
                multiselect=False,
            )

        with gr.Column(elem_id="col-right"):
            gr.HTML(
                """
                <div style="display: flex; justify-content: center; align-items: center; text-align: center; font-size: 20px;">
                    <div>
                        Step 2. Press Run to launch
                    </div>
                </div>
                """,
                max_height=50,
            )
            result = gr.Image(label="result")
            run_button = gr.Button("Run")

            with gr.Accordion("Advanced Settings", open=False):
                custom_prompt = gr.Text(
                    label="Prompt",
                    max_lines=3,
                    placeholder="Enter a custom prompt",
                    container=False,
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=100_000,
                    step=1,
                    value=0,
                )
                condition_scale = gr.Slider(
                    label="Condition Scale",
                    minimum=-10.0,
                    maximum=10.0,
                    step=0.10,
                    value=1.0,
                )
                with gr.Column():
                    max_dimension = gr.Slider(
                        label="Max Dimension",
                        minimum=512,
                        maximum=2048,
                        step=128,
                        value=1024,
                    )
                    num_inference_steps = gr.Slider(
                        label="Number of inference steps",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=28,
                    )
    run_button.click(
        fn=predict,
        inputs=[
            room_image_input,
            room_image_category,
            custom_prompt,
            seed,
            num_inference_steps,
            max_dimension,
            condition_scale,
        ],
        outputs=[result],
    )

demo.launch()