# Redux / app.py
# nftnik's picture
# Update app.py
# 250758a verified
# raw
# history blame
# 6.26 kB
import os
import sys
import random
import torch
from pathlib import Path
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
import spaces
from typing import Union, Sequence, Mapping, Any
# Startup diagnostics: print the interpreter, PyTorch and CUDA state.
print("Python version:", sys.version)
print("Torch version:", torch.__version__)
print("CUDA disponível:", torch.cuda.is_available())
print("Quantidade de GPUs:", torch.cuda.device_count())

# The device name is only queried when CUDA reports itself as available.
if torch.cuda.is_available():
    print("GPU atual:", torch.cuda.get_device_name(0))
# Append the "ComfyUI" folder that sits next to this script to sys.path.
current_dir = os.path.abspath(os.path.dirname(__file__))
comfyui_path = os.path.join(current_dir, "ComfyUI")
sys.path.append(comfyui_path)
# Importar ComfyUI components
from nodes import NODE_CLASS_MAPPINGS, init_extra_nodes
from comfy import model_management
import folder_paths
# Directory setup: create an "output" folder next to this script (symlinks
# resolved) and pass it to folder_paths.set_output_directory.
BASE_DIR = os.path.dirname(os.path.realpath(__file__))
output_dir = os.path.join(BASE_DIR, "output")
Path(output_dir).mkdir(parents=True, exist_ok=True)
folder_paths.set_output_directory(output_dir)

# Initialise the extra nodes before any node class is looked up.
print("Inicializando nós extras...")
init_extra_nodes()
# Helper function
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Return ``obj[index]``, falling back to ``obj["result"][index]``.

    The fallback is used only when the direct lookup raises ``KeyError``;
    any other exception (for example ``IndexError``) propagates unchanged.
    """
    try:
        value = obj[index]
    except KeyError:
        # Mapping-style result: the indexable payload lives under "result".
        value = obj["result"][index]
    return value
# Model files needed by the pipeline, as (repo_id, filename, local_dir).
_DEFAULT_MODELS = (
    ("black-forest-labs/FLUX.1-Redux-dev", "flux1-redux-dev.safetensors", "models/style_models"),
    ("comfyanonymous/flux_text_encoders", "t5xxl_fp16.safetensors", "models/text_encoders"),
    ("zer0int/CLIP-GmP-ViT-L-14", "ViT-L-14-TEXT-detail-improved-hiT-GmP-HF.safetensors", "models/text_encoders"),
    ("black-forest-labs/FLUX.1-dev", "ae.safetensors", "models/vae"),
    ("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors", "models/diffusion_models"),
    ("google/siglip-so400m-patch14-384", "model.safetensors", "models/clip_vision"),
    ("nftnik/NFTNIK-FLUX.1-dev-LoRA", "NFTNIK_FLUX.1[dev]_LoRA.safetensors", "models/lora"),
)


def download_models(models=None):
    """Download model files from the Hugging Face Hub, best effort.

    Each entry is fetched with ``hf_hub_download`` into its ``local_dir``.
    ``local_dir`` is passed through unchanged, so relative paths resolve
    against the current working directory. A failing entry is logged and
    skipped so that the remaining files are still attempted.

    NOTE(review): the loader nodes used later resolve model names through
    ComfyUI's own folder configuration -- confirm that these directories are
    the ones ComfyUI actually searches.

    Args:
        models: Optional iterable of ``(repo_id, filename, local_dir)``
            tuples. When ``None``, the built-in list of files is used.

    Returns:
        A list of ``(repo_id, filename)`` pairs whose download failed; the
        list is empty when every file was fetched.
    """
    print("Baixando modelos...")
    if models is None:
        models = _DEFAULT_MODELS

    failed = []
    for repo_id, filename, local_dir in models:
        try:
            os.makedirs(local_dir, exist_ok=True)
            print(f"Baixando {filename} de {repo_id}...")
            hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
        except Exception as e:
            # Deliberately broad: one failed download must not stop the
            # others. The failure is recorded so the caller can react to it.
            print(f"Erro ao baixar {filename} de {repo_id}: {e}")
            failed.append((repo_id, filename))
    return failed
# Fetch the model files before any loader node runs.
download_models()

# Instantiate the loader nodes and load every model once, at import time.
print("Inicializando modelos...")
with torch.inference_mode():
    # NOTE(review): `intconstant` is not referenced anywhere else in the
    # visible file -- confirm it is still needed.
    intconstant = NODE_CLASS_MAPPINGS["INTConstant"]()

    # NOTE(review): the model names below carry a "models/<category>/" prefix;
    # confirm the loader nodes resolve them to where download_models() stored
    # the files.
    dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
    dualcliploader_357 = dualcliploader.load_clip(
        type="flux",
        clip_name1="models/text_encoders/t5xxl_fp16.safetensors",
        clip_name2="models/text_encoders/ViT-L-14-TEXT-detail-improved-hiT-GmP-HF.safetensors",
    )

    stylemodelloader = NODE_CLASS_MAPPINGS["StyleModelLoader"]()
    stylemodelloader_441 = stylemodelloader.load_style_model(
        style_model_name="models/style_models/flux1-redux-dev.safetensors"
    )

    vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
    vaeloader_359 = vaeloader.load_vae(vae_name="models/vae/ae.safetensors")

    # Load the models onto the GPU. For each loader result, take its first
    # element and use that element's `patcher` attribute when it has one,
    # otherwise the element itself. Entries where either is a dict are skipped.
    model_loaders = [dualcliploader_357, vaeloader_359, stylemodelloader_441]
    valid_models = [
        getattr(first, "patcher", first)
        for first in (loader[0] for loader in model_loaders)
        if not (
            isinstance(first, dict)
            or isinstance(getattr(first, "patcher", None), dict)
        )
    ]
    model_management.load_models_gpu(valid_models)
def _decoded_image_to_pil(image):
    """Convert a decoded image (torch tensor or NumPy array) to a PIL image.

    Assumes float pixel values in [0, 1], as implied by the scaling by 255 --
    TODO confirm against the decoder's output.
    """
    if isinstance(image, torch.Tensor):
        # torch tensors have no ``astype``; convert to a host NumPy array
        # first (``float()`` so half/bfloat16 tensors convert as well).
        image = image.detach().float().cpu().numpy()
    if image.ndim == 4:
        # Batched result: keep the first image only.
        image = image[0]
    # Clip before the cast so out-of-range values do not wrap around in uint8.
    pixels = (image * 255.0).clip(0, 255).astype("uint8")
    return Image.fromarray(pixels)


@spaces.GPU
def generate_image(prompt, input_image, lora_weight, progress=gr.Progress(track_tqdm=True)):
    """Main generation entry point, with Gradio progress tracking.

    Args:
        prompt: Text passed to the ``CLIPTextEncode`` node.
        input_image: Value passed to the ``LoadImage`` node as ``image``
            (the UI in this file supplies a file path).
        lora_weight: Passed to the LoRA loader as ``strength_model``.
        progress: Gradio progress tracker; not referenced in the body.

    Returns:
        Path of the PNG written to ``output_dir``, or ``None`` when any step
        raised (the error and its traceback are printed).
    """
    import traceback

    try:
        with torch.inference_mode():
            # Encode the text prompt.
            # NOTE(review): `encoded_text` is never consumed below -- confirm
            # whether a conditioning/sampling stage is missing.
            cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
            encoded_text = cliptextencode.encode(
                text=prompt,
                clip=get_value_at_index(dualcliploader_357, 0)
            )

            # Load the LoRA.
            # NOTE(review): `model` receives the style-model loader's output;
            # confirm this is the intended input for LoraLoaderModelOnly.
            loraloadermodelonly = NODE_CLASS_MAPPINGS["LoraLoaderModelOnly"]()
            lora_model = loraloadermodelonly.load_lora_model_only(
                lora_name="models/lora/NFTNIK_FLUX.1[dev]_LoRA.safetensors",
                strength_model=lora_weight,
                model=get_value_at_index(stylemodelloader_441, 0)
            )

            # Load the input image.
            # NOTE(review): `loaded_image` is never consumed below.
            loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
            loaded_image = loadimage.load_image(image=input_image)

            # Decode.
            # NOTE(review): `samples` receives the LoRA loader's output, not
            # the result of a sampler -- confirm the pipeline wiring.
            vaedecode = NODE_CLASS_MAPPINGS["VAEDecode"]()
            decoded = vaedecode.decode(
                samples=get_value_at_index(lora_model, 0),
                vae=get_value_at_index(vaeloader_359, 0)
            )

            # Save the image under a randomised file name in output_dir.
            temp_filename = f"Flux_{random.randint(0, 99999)}.png"
            temp_path = os.path.join(output_dir, temp_filename)
            _decoded_image_to_pil(get_value_at_index(decoded, 0)).save(temp_path)
            return temp_path
    except Exception as e:
        # Best effort for the UI: report the failure and return no image,
        # but keep the traceback so the failing step can be located.
        print(f"Erro ao gerar imagem: {str(e)}")
        traceback.print_exc()
        return None
# Gradio interface
with gr.Blocks() as app:
    gr.Markdown("# Gerador de Imagens FLUX Redux")

    with gr.Row():
        # First column: request inputs and the trigger button.
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Digite seu prompt aqui...",
                lines=5,
            )
            input_image = gr.Image(label="Imagem de Entrada", type="filepath")
            lora_weight = gr.Slider(
                minimum=0,
                maximum=2,
                step=0.1,
                value=0.6,
                label="Peso LoRA",
            )
            generate_btn = gr.Button("Gerar Imagem")

        # Second column: the generated result.
        with gr.Column():
            output_image = gr.Image(label="Imagem Gerada", type="filepath")

    # Clicking the button runs generate_image with the three inputs above.
    generate_btn.click(
        fn=generate_image,
        inputs=[prompt_input, input_image, lora_weight],
        outputs=[output_image],
    )

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    app.launch()