import gc import os import random import numpy as np import json import torch import uuid from PIL import Image, PngImagePlugin from datetime import datetime from dataclasses import dataclass from typing import Callable, Dict, Optional, Tuple, Any, List from diffusers import ( DDIMScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, AutoencoderKL, StableDiffusionXLPipeline, ) import logging def load_pipeline(model_name: str, device: torch.device, hf_token: Optional[str] = None, vae: Optional[AutoencoderKL] = None) -> Any: """Load the Stable Diffusion pipeline.""" try: pipeline = ( StableDiffusionXLPipeline.from_single_file if model_name.endswith(".safetensors") else StableDiffusionXLPipeline.from_pretrained ) pipe = pipeline( model_name, vae=vae, torch_dtype=torch.float16, custom_pipeline="lpw_stable_diffusion_xl", use_safetensors=True, add_watermarker=False ) pipe.to(device) return pipe except Exception as e: logging.error(f"Failed to load pipeline: {str(e)}", exc_info=True) raise def seed_everything(seed: int) -> torch.Generator: torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) np.random.seed(seed) generator = torch.Generator() generator.manual_seed(seed) return generator def preprocess_image_dimensions(width, height): if width % 8 != 0: width = width - (width % 8) if height % 8 != 0: height = height - (height % 8) return width, height def get_scheduler(scheduler_config: Dict, name: str) -> Optional[Callable]: scheduler_factory_map = { "DPM++ 2M Karras": lambda: DPMSolverMultistepScheduler.from_config( scheduler_config, use_karras_sigmas=True ), "DPM++ SDE Karras": lambda: DPMSolverSinglestepScheduler.from_config( scheduler_config, use_karras_sigmas=True ), "DPM++ 2M SDE Karras": lambda: DPMSolverMultistepScheduler.from_config( scheduler_config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++" ), "Euler": lambda: EulerDiscreteScheduler.from_config(scheduler_config), "Euler a": lambda: EulerAncestralDiscreteScheduler.from_config( scheduler_config ), "DDIM": lambda: DDIMScheduler.from_config(scheduler_config), } return scheduler_factory_map.get(name, lambda: None)() def common_upscale( samples: torch.Tensor, width: int, height: int, upscale_method: str, ) -> torch.Tensor: return torch.nn.functional.interpolate( samples, size=(height, width), mode=upscale_method ) def upscale( samples: torch.Tensor, upscale_method: str, scale_by: float ) -> torch.Tensor: width = round(samples.shape[3] * scale_by) height = round(samples.shape[2] * scale_by) return common_upscale(samples, width, height, upscale_method) def free_memory() -> None: """Free up GPU and system memory.""" if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.ipc_collect() gc.collect() def save_image(image, output_dir): filename = str(uuid.uuid4()) + ".jpg" os.makedirs(output_dir, exist_ok=True) filepath = os.path.join(output_dir, filename) image.save(filepath, "JPEG", quality=80) return filepath