diffusion / generate.py
adamelliotfields's picture
Add AuraSR GAN
cb5daed verified
raw
history blame
12.3 kB
import json
import os
import re
import time
from contextlib import contextmanager
from datetime import datetime
from itertools import product
from typing import Callable
import spaces
import tomesd
import torch
from aura_sr import AuraSR
from compel import Compel, DiffusersTextualInversionManager, ReturnedEmbeddingsType
from compel.prompt_parser import PromptParser
from DeepCache import DeepCacheSDHelper
from diffusers import (
DEISMultistepScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
HeunDiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
StableDiffusionPipeline,
)
from diffusers.models import AutoencoderKL, AutoencoderTiny
from torch._dynamo import OptimizedModule
# Import-time noise reduction: ignore FutureWarning for warnings attributed to
# modules whose name matches "diffusers" or "transformers", then set the
# transformers logger verbosity to error level. `__import__` is used inline so
# that no additional module-level names are bound by this setup.
__import__("warnings").filterwarnings("ignore", category=FutureWarning, module="diffusers")
__import__("warnings").filterwarnings("ignore", category=FutureWarning, module="transformers")
__import__("transformers").logging.set_verbosity_error()
# True when the SPACES_ZERO_GPU environment variable is "true" (any letter case)
# or "1"; False when it is unset or holds any other value.
ZERO_GPU = os.environ.get("SPACES_ZERO_GPU", "").lower() in ("true", "1")
# Textual-inversion embeddings registered on every freshly loaded pipeline.
# Keys are the embedding file paths (relative to the working directory) and
# values are the prompt tokens they are registered under; Loader.load passes
# the keys and values, in this order, to pipe.load_textual_inversion.
EMBEDDINGS = {
"./embeddings/bad_prompt_version2.pt": "<bad_prompt>",
"./embeddings/BadDream.pt": "<bad_dream>",
"./embeddings/FastNegativeV2.pt": "<fast_negative>",
"./embeddings/negative_hand.pt": "<negative_hand>",
"./embeddings/UnrealisticDream.pt": "<unrealistic_dream>",
}
# Style presets, read once at import time. apply_style iterates this value and
# reads the "name", "prompt" and "negative_prompt" keys of each entry.
# JSON text is UTF-8 by specification, so the encoding is stated explicitly
# instead of relying on the platform's default locale encoding.
with open("./styles/twri.json", encoding="utf-8") as f:
    styles = json.load(f)
# inspired by ComfyUI
# https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/model_management.py
class Loader:
    """Singleton that caches the diffusion pipeline and the upscaler between calls.

    Every ``Loader()`` call returns the same instance. ``load`` compares the
    request with what is already cached and only rebuilds the parts that
    changed: the whole pipeline when the model differs, otherwise just the
    scheduler, VAE, DeepCache helper and/or upscaler.

    Attributes set on first construction:
        cpu, gpu: ``torch.device`` handles for "cpu" and "cuda".
        pipe: the cached ``StableDiffusionPipeline`` or ``None``.
        gan: the cached AuraSR upscaler or ``None``.
    """

    _instance = None

    def __new__(cls):
        # initialise the shared instance exactly once; later calls reuse it
        if cls._instance is None:
            cls._instance = super(Loader, cls).__new__(cls)
            cls._instance.cpu = torch.device("cpu")
            cls._instance.gpu = torch.device("cuda")
            cls._instance.gan = None
            cls._instance.pipe = None
        return cls._instance

    def _load_deepcache(self, interval=1):
        """Ensure the pipeline has an enabled DeepCache helper using ``interval``.

        The helper is stored on the pipeline as ``pipe.deepcache``. If one
        already exists with the requested cache interval it is returned
        untouched; otherwise it is disabled (when present) or created, then
        reconfigured and enabled.
        """
        has_deepcache = hasattr(self.pipe, "deepcache")
        if has_deepcache and self.pipe.deepcache.params["cache_interval"] == interval:
            return self.pipe.deepcache
        if has_deepcache:
            # disable before changing parameters, re-enabled below
            self.pipe.deepcache.disable()
        else:
            self.pipe.deepcache = DeepCacheSDHelper(pipe=self.pipe)
        self.pipe.deepcache.set_params(cache_interval=interval)
        self.pipe.deepcache.enable()
        return self.pipe.deepcache

    def _load_vae(self, model_name=None, taesd=False, dtype=None):
        """Swap the pipeline VAE between the model's KL VAE and the Tiny VAE.

        Args:
            model_name: repository to load the KL VAE from (``vae`` subfolder).
            taesd: ``True`` to use ``madebyollin/taesd``, ``False`` for the KL VAE.
            dtype: torch dtype passed to ``from_pretrained``.

        Returns:
            The VAE currently attached to the pipeline.
        """
        vae_type = type(self.pipe.vae)
        # a compiled KL VAE shows up as OptimizedModule, so treat both as "KL"
        is_kl = issubclass(vae_type, (AutoencoderKL, OptimizedModule))
        is_tiny = issubclass(vae_type, AutoencoderTiny)

        # by default all models use KL
        if is_kl and taesd:
            # can't compile tiny VAE
            print("Switching to Tiny VAE...")
            self.pipe.vae = AutoencoderTiny.from_pretrained(
                pretrained_model_name_or_path="madebyollin/taesd",
                use_safetensors=True,
                torch_dtype=dtype,
            ).to(self.gpu)
            return self.pipe.vae

        if is_tiny and not taesd:
            print("Switching to KL VAE...")
            self.pipe.vae = torch.compile(
                fullgraph=True,
                mode="reduce-overhead",
                model=AutoencoderKL.from_pretrained(
                    pretrained_model_name_or_path=model_name,
                    use_safetensors=True,
                    torch_dtype=dtype,
                    subfolder="vae",
                ).to(self.gpu),
            )
        return self.pipe.vae

    def _load_gan(self, upscale=False):
        """Load or release the AuraSR upscaler so that it matches ``upscale``.

        Loads ``fal/AuraSR-v2`` when upscaling is requested and nothing is
        cached; drops the cached upscaler and empties the CUDA cache when
        upscaling is not requested. Returns the current upscaler (or ``None``).
        """
        if upscale and self.gan is None:
            print("Loading fal/AuraSR-v2...")
            self.gan = AuraSR.from_pretrained("fal/AuraSR-v2")
        if not upscale and self.gan is not None:
            print("Unloading fal/AuraSR-v2...")
            self.gan = None
            # must be *called*; a bare attribute reference releases nothing
            torch.cuda.empty_cache()
        return self.gan

    def load(self, model, scheduler, karras, taesd, deepcache_interval, upscale, dtype=None):
        """Return ``(pipe, gan)`` configured for the request, reusing cached parts.

        Args:
            model: model repository id; compared case-insensitively with the
                cached pipeline's ``_name_or_path``.
            scheduler: one of the scheduler names in the ``schedulers`` table below.
            karras: value for ``use_karras_sigmas`` on schedulers configured with it.
            taesd: use the Tiny VAE instead of the KL VAE.
            deepcache_interval: cache interval for the DeepCache helper.
            upscale: whether the AuraSR upscaler should be loaded.
            dtype: torch dtype used when loading weights.

        Returns:
            Tuple of the pipeline and the upscaler (``None`` when ``upscale`` is false).

        Raises:
            KeyError: if ``scheduler`` is not a known scheduler name.
        """
        model_lower = model.lower()

        schedulers = {
            "DEIS 2M": DEISMultistepScheduler,
            "DPM++ 2M": DPMSolverMultistepScheduler,
            "DPM2 a": KDPM2AncestralDiscreteScheduler,
            "Euler a": EulerAncestralDiscreteScheduler,
            "Heun": HeunDiscreteScheduler,
            "LMS": LMSDiscreteScheduler,
            "PNDM": PNDMScheduler,
        }

        scheduler_kwargs = {
            "beta_schedule": "scaled_linear",
            "timestep_spacing": "leading",
            "use_karras_sigmas": karras,
            "beta_start": 0.00085,
            "beta_end": 0.012,
            "steps_offset": 1,
        }

        # presumably these two schedulers do not accept use_karras_sigmas
        if scheduler in ["Euler a", "PNDM"]:
            del scheduler_kwargs["use_karras_sigmas"]

        # already loaded
        if self.pipe is not None:
            model_name = self.pipe.config._name_or_path
            same_model = model_name.lower() == model_lower
            same_scheduler = isinstance(self.pipe.scheduler, schedulers[scheduler])
            # schedulers without the option always count as "same"
            same_karras = (
                not hasattr(self.pipe.scheduler.config, "use_karras_sigmas")
                or self.pipe.scheduler.config.use_karras_sigmas == karras
            )

            if same_model:
                if not same_scheduler:
                    print(f"Switching to {scheduler}...")
                if not same_karras:
                    print(f"{'Enabling' if karras else 'Disabling'} Karras sigmas...")
                if not same_scheduler or not same_karras:
                    self.pipe.scheduler = schedulers[scheduler](**scheduler_kwargs)
                self._load_vae(model_lower, taesd, dtype)
                self._load_deepcache(interval=deepcache_interval)
                # the upscaler setting can change even when the model is cached,
                # so it has to be synced before this early return as well
                self._load_gan(upscale)
                return self.pipe, self.gan

            print(f"Unloading {model_name.lower()}...")
            self.pipe = None
            torch.cuda.empty_cache()

        # built only when a fresh pipeline is needed, so the cached path does
        # not instantiate a scheduler it would throw away
        pipe_kwargs = {
            "scheduler": schedulers[scheduler](**scheduler_kwargs),
            "pretrained_model_name_or_path": model_lower,
            "requires_safety_checker": False,
            "use_safetensors": True,
            "safety_checker": None,
            "torch_dtype": dtype,
        }

        # no fp16 available
        if not ZERO_GPU and model_lower not in [
            "sg161222/realistic_vision_v5.1_novae",
            "prompthero/openjourney-v4",
            "linaqruf/anything-v3-1",
        ]:
            pipe_kwargs["variant"] = "fp16"

        print(f"Loading {model_lower} with {'Tiny' if taesd else 'KL'} VAE...")
        self.pipe = StableDiffusionPipeline.from_pretrained(**pipe_kwargs).to(self.gpu)
        self.pipe.load_textual_inversion(
            pretrained_model_name_or_path=list(EMBEDDINGS.keys()),
            tokens=list(EMBEDDINGS.values()),
        )
        self._load_vae(model_lower, taesd, dtype)
        self._load_deepcache(interval=deepcache_interval)
        self._load_gan(upscale)
        return self.pipe, self.gan
# applies tome to the pipeline
@contextmanager
def token_merging(pipe, tome_ratio=0):
    """Run the managed block with the tomesd patch applied to ``pipe``.

    The patch is applied only when ``tome_ratio`` is greater than zero. On exit
    the patch is removed unconditionally, whether or not it was applied and
    whether or not the managed block raised.
    """
    try:
        merging_enabled = tome_ratio > 0
        if merging_enabled:
            tomesd.apply_patch(
                pipe,
                ratio=tome_ratio,
                max_downsample=1,
                sx=2,
                sy=2,
            )
        yield
    finally:
        # always unpatch; noted as idempotent by the original author
        tomesd.remove_patch(pipe)
# parse prompts with arrays
def parse_prompt(prompt: str) -> list[str]:
    """Expand ``[[a, b, ...]]`` arrays in ``prompt`` into one prompt per combination.

    Each ``[[...]]`` group is split on commas and every combination of choices
    (cartesian product, leftmost group varying slowest) yields one prompt in
    which each group is replaced by the whitespace-stripped choice. A prompt
    without any group is returned unchanged as a single-element list.
    """
    groups = re.findall(r"\[\[(.*?)\]\]", prompt)
    if len(groups) == 0:
        return [prompt]

    placeholders = [f"[[{group}]]" for group in groups]
    choices_per_group = [group.split(",") for group in groups]

    expanded = []
    for selection in product(*choices_per_group):
        text = prompt
        # replace one occurrence at a time so repeated groups are filled left to right
        for placeholder, choice in zip(placeholders, selection):
            text = text.replace(placeholder, choice.strip(), 1)
        expanded.append(text)
    return expanded
def apply_style(prompt, style_name, negative=False):
    """Return ``prompt`` combined with the named style preset.

    For a positive prompt the style's ``prompt`` template is formatted with
    ``prompt``; for a negative prompt the style's ``negative_prompt`` is
    appended after " . ". The prompt is returned unchanged when ``style_name``
    is empty, ``"None"``, or does not match any entry in the module-level
    ``styles``.
    """
    if not style_name or style_name == "None":
        return prompt

    # first preset whose name matches, if any
    matched = next((entry for entry in styles if entry["name"] == style_name), None)
    if matched is None:
        return prompt

    if negative:
        return " . ".join((prompt, matched["negative_prompt"]))
    return matched["prompt"].format(prompt=prompt)
@spaces.GPU(duration=40)
def generate(
    positive_prompt,
    negative_prompt="",
    style=None,
    seed=None,
    model="runwayml/stable-diffusion-v1-5",
    scheduler="PNDM",
    width=512,
    height=512,
    guidance_scale=7.5,
    inference_steps=50,
    num_images=1,
    karras=False,
    taesd=False,
    clip_skip=False,
    truncate_prompts=False,
    increment_seed=True,
    deepcache_interval=1,
    tome_ratio=0,
    upscale=False,
    log: Callable[[str], None] = None,
    Error=Exception,
):
    """Generate ``num_images`` images and return them with the seed used for each.

    Args:
        positive_prompt: prompt text; ``[[a, b]]`` arrays are expanded by
            ``parse_prompt`` and the resulting prompts are cycled across images.
        negative_prompt: negative prompt text.
        style: style preset name passed to ``apply_style`` for both prompts.
        seed: starting seed; ``None`` or a negative value derives one from the
            current time in microseconds.
        model, scheduler, karras, taesd, deepcache_interval, upscale: forwarded
            to ``Loader.load``.
        width, height, guidance_scale, inference_steps: forwarded to the pipeline.
        num_images: number of images to produce.
        clip_skip: use the penultimate hidden states for prompt embeddings
            instead of the last hidden states.
        truncate_prompts: forwarded to Compel as ``truncate_long_prompts``.
        increment_seed: add one to the seed after every image.
        tome_ratio: token-merging ratio passed to ``token_merging``.
        log: optional callback that receives a one-line timing summary.
        Error: exception class raised for user-facing failures.

    Returns:
        List of ``(image, seed_string)`` tuples, one per generated image.

    Raises:
        Error: when CUDA is unavailable or a prompt fails to parse.
    """
    if not torch.cuda.is_available():
        raise Error("CUDA not available")

    # https://pytorch.org/docs/stable/generated/torch.manual_seed.html
    if seed is None or seed < 0:
        seed = int(datetime.now().timestamp() * 1_000_000) % (2**64)

    GPU = torch.device("cuda")

    TORCH_DTYPE = (
        torch.bfloat16
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported(including_emulation=False)
        else torch.float16
    )

    EMBEDDINGS_TYPE = (
        ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NORMALIZED
        if clip_skip
        else ReturnedEmbeddingsType.LAST_HIDDEN_STATES_NORMALIZED
    )

    with torch.inference_mode():
        start = time.perf_counter()
        loader = Loader()
        pipe, gan = loader.load(
            model,
            scheduler,
            karras,
            taesd,
            deepcache_interval,
            upscale,
            TORCH_DTYPE,
        )

        # prompt embeds
        compel = Compel(
            textual_inversion_manager=DiffusersTextualInversionManager(pipe),
            dtype_for_device_getter=lambda _: TORCH_DTYPE,
            returned_embeddings_type=EMBEDDINGS_TYPE,
            truncate_long_prompts=truncate_prompts,
            text_encoder=pipe.text_encoder,
            tokenizer=pipe.tokenizer,
            device=GPU,
        )

        images = []
        current_seed = seed

        try:
            styled_negative_prompt = apply_style(negative_prompt, style, negative=True)
            neg_embeds = compel(styled_negative_prompt)
        except PromptParser.ParsingException as err:
            raise Error("ParsingException: Invalid negative prompt") from err

        # the expansion depends only on the input text, so do it once instead
        # of once per image
        all_positive_prompts = parse_prompt(positive_prompt)

        for i in range(num_images):
            # seeded generator for each iteration
            generator = torch.Generator(device=GPU).manual_seed(current_seed)

            try:
                # cycle through the expanded prompts when there are fewer
                # prompts than images
                prompt_index = i % len(all_positive_prompts)
                pos_prompt = all_positive_prompts[prompt_index]
                styled_pos_prompt = apply_style(pos_prompt, style)
                pos_embeds = compel(styled_pos_prompt)
                pos_embeds, neg_embeds = compel.pad_conditioning_tensors_to_same_length(
                    [pos_embeds, neg_embeds]
                )
            except PromptParser.ParsingException as err:
                raise Error("ParsingException: Invalid prompt") from err

            with token_merging(pipe, tome_ratio=tome_ratio):
                image = pipe(
                    num_inference_steps=inference_steps,
                    negative_prompt_embeds=neg_embeds,
                    guidance_scale=guidance_scale,
                    prompt_embeds=pos_embeds,
                    generator=generator,
                    height=height,
                    width=width,
                ).images[0]

            if upscale:
                print("Upscaling image...")
                batch_size = 12 if ZERO_GPU else 4  # smaller batch to fit in 8GB
                image = gan.upscale_4x_overlapped(image, max_batch_size=batch_size)

            images.append((image, str(current_seed)))

            if increment_seed:
                current_seed += 1

        if ZERO_GPU:
            # spaces always start fresh
            loader.pipe = None
            loader.gan = None

        diff = time.perf_counter() - start
        if log:
            # plural for every count except exactly one (including zero)
            log(f"Generated {len(images)} image{'s' if len(images) != 1 else ''} in {diff:.2f}s")
        return images