Update README.md
README.md
CHANGED
@@ -1,471 +1,15 @@
-# main.py
-from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
-import torch
-from PIL import Image, ImageEnhance
-import os
-import random
-import json
-import argparse
-from pathlib import Path
-from torch.utils.data import Dataset, DataLoader
-from torchvision import transforms
-from accelerate import Accelerator
-from diffusers import UNet2DConditionModel, AutoencoderKL
-from diffusers.training_utils import EMAModel
-from diffusers import LoraLoaderMixin
-from diffusers.optimization import get_scheduler
-from torch.optim import AdamW
-import math
-import numpy as np
-from huggingface_hub import create_repo, upload_folder
-from huggingface_hub import HfFolder
-from tqdm.auto import tqdm
-
-# --- Configuration ---
-MODEL_NAME = "photo-fluxXL"
-BASE_MODEL = "kudzueye/Boreal" # Choose one of your base models
-IMAGE_FOLDER = "/content/drive/MyDrive/training_data" # Replace with your image folder path
-OUTPUT_DIR = "/content/drive/MyDrive/my_lora_models" # Replace with your output folder path
-TRAIN_BATCH_SIZE = 1
-GRADIENT_ACCUMULATION_STEPS = 4
-LEARNING_RATE = 1e-4
-NUM_EPOCHS = 10
-SAVE_STEPS = 500
-SEED = 42
-PUSH_TO_HUB = False # Set to True if you want to push to Hugging Face Hub
-HUB_REPO_ID = "your-username/your-repo-name" # Replace with your Hugging Face repo ID
-GENERATE_AFTER_TRAINING = True # Set to True to generate images after training
-PROMPTS_FILE = "/content/drive/MyDrive/prompts.json" # Replace with your prompts file path
-BATCH_SIZE_GENERATE = 4 # Batch size for generation
-
-# --- Load Base Model ---
-pipe = StableDiffusionPipeline.from_pretrained(
-    BASE_MODEL,
-    torch_dtype=torch.float16,
-    safety_checker=None,
-    requires_safety_checker=False,
-    variant="fp16",
-    use_safetensors=True
-).to("cuda")
-
-pipe.scheduler = DPMSolverMultistepScheduler.from_config(
-    pipe.scheduler.config,
-    algorithm_type="dpmsolver++",
-    solver_order=2
-)
-
-pipe.enable_attention_slicing()
-pipe.enable_xformers_memory_efficient_attention()
-
-# --- Prepare Dataset ---
-class FluxDataset(Dataset):
-    def __init__(self, image_folder, transform=None):
-        self.image_paths = {}
-        for category in os.listdir(image_folder):
-            category_path = os.path.join(image_folder, category)
-            if os.path.isdir(category_path):
-                self.image_paths[category] = []
-                for subcategory in os.listdir(category_path):
-                    subcategory_path = os.path.join(category_path, subcategory)
-                    if os.path.isdir(subcategory_path):
-                        self.image_paths[category].extend([os.path.join(subcategory_path, f) for f in os.listdir(subcategory_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
-                    elif subcategory.lower().endswith(('.png', '.jpg', '.jpeg')):
-                        self.image_paths[category].append(subcategory_path)
-
-        self.transform = transform
-        self.categories = list(self.image_paths.keys())
-
-    def __len__(self):
-        return max(len(paths) for paths in self.image_paths.values())
-
-    def __getitem__(self, idx):
-        item = {}
-        for category in self.categories:
-            if self.image_paths[category]:
-                image_path = self.image_paths[category][idx % len(self.image_paths[category])]
-                image = Image.open(image_path).convert("RGB")
-                if self.transform:
-                    image = self.transform(image)
-                item[category] = image
-        return item
-
-transform = transforms.Compose([
-    transforms.Resize((1024, 1024)),
-    transforms.ToTensor(),
-    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
-])
-
-dataset = FluxDataset(IMAGE_FOLDER, transform=transform)
-
-# --- Data Loader ---
-dataloader = DataLoader(dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
-
-# --- Prepare LoRA ---
-unet = pipe.unet
-vae = pipe.vae
-text_encoder = pipe.text_encoder
-text_encoder_2 = pipe.text_encoder_2
-
-unet_lora_layers = LoraLoaderMixin.get_lora_layers(unet)
-vae_lora_layers = LoraLoaderMixin.get_lora_layers(vae)
-text_encoder_lora_layers = LoraLoaderMixin.get_lora_layers(text_encoder)
-text_encoder_2_lora_layers = LoraLoaderMixin.get_lora_layers(text_encoder_2)
-
-# --- Optimizer ---
-optimizer = AdamW(
-    [
-        {"params": unet_lora_layers.parameters(), "lr": LEARNING_RATE},
-        {"params": vae_lora_layers.parameters(), "lr": LEARNING_RATE},
-        {"params": text_encoder_lora_layers.parameters(), "lr": LEARNING_RATE},
-        {"params": text_encoder_2_lora_layers.parameters(), "lr": LEARNING_RATE},
-    ]
-)
-
-# --- Scheduler ---
-lr_scheduler = get_scheduler(
-    "cosine",
-    optimizer=optimizer,
-    num_warmup_steps=math.ceil(len(dataloader) * NUM_EPOCHS * 0.1),
-    num_training_steps=len(dataloader) * NUM_EPOCHS,
-)
-
-# --- Accelerator ---
-accelerator = Accelerator(gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS, mixed_precision="fp16")
-
-unet, vae, text_encoder, text_encoder_2, optimizer, dataloader, lr_scheduler = accelerator.prepare(
-    unet, vae, text_encoder, text_encoder_2, optimizer, dataloader, lr_scheduler
-)
-
-# --- Training Loop ---
-progress_bar = tqdm(range(len(dataloader) * NUM_EPOCHS), desc="Training")
-global_step = 0
-
-for epoch in range(NUM_EPOCHS):
-    for batch in dataloader:
-        with accelerator.accumulate(unet, vae, text_encoder, text_encoder_2):
-            latents = vae.encode(batch["body"].to(accelerator.device)).latent_dist.sample()
-            noise = torch.randn_like(latents)
-            timesteps = torch.randint(0, pipe.scheduler.num_train_timesteps, (latents.shape[0],), device=accelerator.device)
-            noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)
-
-            # --- Generate Prompt ---
-            prompt_parts = []
-            if "body" in batch:
-                if "belly" in batch:
-                    prompt_parts.append("belly visible")
-                if "body_shape_1" in batch:
-                    prompt_parts.append("body shape 1")
-                if "body_shape_2" in batch:
-                    prompt_parts.append("body shape 2")
-                if "body_type" in batch:
-                    prompt_parts.append("body type")
-                if "body measurements-proportion" in batch:
-                    prompt_parts.append("body measurements-proportion")
-            if "details" in batch:
-                if "eyebrows" in batch:
-                    prompt_parts.append("eyebrows")
-                if "eyelashes_1" in batch:
-                    prompt_parts.append("eyelashes 1")
-                if "eyelashes_2" in batch:
-                    prompt_parts.append("eyelashes 2")
-                if "hair" in batch:
-                    prompt_parts.append("hair")
-                if "lips" in batch:
-                    prompt_parts.append("lips")
-            if "face" in batch:
-                for i in range(1, 18):
-                    if f"face_{i}" in batch:
-                        prompt_parts.append(f"face {i}")
-            if "pose" in batch:
-                for i in range(1, 4):
-                    if f"pose_{i}" in batch:
-                        prompt_parts.append(f"pose {i}")
-            if "skin" in batch:
-                if "skin_tone" in batch:
-                    prompt_parts.append("skin tone")
-            if "textures" in batch:
-                for texture in batch["textures"]:
-                    prompt_parts.append(f"texture {texture}")
-
-            prompt = "a photo of a woman, " + ", ".join(prompt_parts)
-            prompt_embeds = pipe.text_encoder(pipe.tokenizer(prompt, return_tensors="pt").input_ids.to(accelerator.device))[0]
-            prompt_embeds_2 = pipe.text_encoder_2(pipe.tokenizer(prompt, return_tensors="pt").input_ids.to(accelerator.device))[0]
-
-            model_pred = unet(noisy_latents, timesteps, prompt_embeds, prompt_embeds_2).sample
-
-            loss = torch.nn.functional.mse_loss(model_pred, noise)
-            accelerator.backward(loss)
-            optimizer.step()
-            lr_scheduler.step()
-            optimizer.zero_grad()
-
-        progress_bar.update(1)
-        global_step += 1
-
-        if global_step % SAVE_STEPS == 0:
-            if accelerator.is_main_process:
-                print(f"Saving checkpoint at step {global_step}")
-                save_path = os.path.join(OUTPUT_DIR, f"{MODEL_NAME}_{global_step}")
-                accelerator.save_state(save_path)
-                # Save LoRA weights
-                unet_lora_layers.save_pretrained(os.path.join(save_path, "unet_lora"))
-                vae_lora_layers.save_pretrained(os.path.join(save_path, "vae_lora"))
-                text_encoder_lora_layers.save_pretrained(os.path.join(save_path, "text_encoder_lora"))
-                text_encoder_2_lora_layers.save_pretrained(os.path.join(save_path, "text_encoder_2_lora"))
-
-                # Save model card
-                model_card = f"""---
-license: apache-2.0
-language:
-- en
-base_model:
-- {BASE_MODEL}
-pipeline_tag: text-to-image
 ---
-
-# Model Description: {MODEL_NAME}
-
-This LoRa model enhances text-to-image generation with a hyperrealistic style focusing on a specific subject.
-
-## Subject Description:
-(Add detailed subject description here)
-
-## Hyperrealistic Style: {True}
-
-## Base Models:
-This model was trained using the following base model:
-- {BASE_MODEL}
-
-
-## Usage Instructions:
-(Add detailed instructions on how to use this LoRa model here. Include example prompts)
-
-
-## Training Data:
-(Add information on the training data here)
-
-
-## Limitations:
-(List any known limitations of the model)
-
-## Bias and Fairness Considerations:
-(Address potential bias in the model)
-
-## Known Issues:
-(List any known issues)
-
-"""
-                with open(os.path.join(save_path, "model_card.txt"), "w") as f:
-                    f.write(model_card)
-
-# --- Save Final Model ---
-if accelerator.is_main_process:
-    print("Saving final model")
-    save_path = os.path.join(OUTPUT_DIR, f"{MODEL_NAME}_final")
-    accelerator.save_state(save_path)
-    # Save LoRA weights
-    unet_lora_layers.save_pretrained(os.path.join(save_path, "unet_lora"))
-    vae_lora_layers.save_pretrained(os.path.join(save_path, "vae_lora"))
-    text_encoder_lora_layers.save_pretrained(os.path.join(save_path, "text_encoder_lora"))
-    text_encoder_2_lora_layers.save_pretrained(os.path.join(save_path, "text_encoder_2_lora"))
-
-    # Save model card
-    model_card = f"""---
 license: apache-2.0
 language:
 - en
+- es
+- fr
 base_model:
-- {BASE_MODEL}
+- kudzueye/Boreal
+- adirik/flux-cinestill
+- Shakker-Labs/FLUX.1-dev-LoRA-add-details
+- prithivMLmods/Flux-Realism-FineDetailed
+- prithivMLmods/Canopus-LoRA-Flux-UltraRealism-2.0
+- Jovie/Midjourney
 pipeline_tag: text-to-image
----
-
-# Model Description: {MODEL_NAME}
-
-This LoRa model enhances text-to-image generation with a hyperrealistic style focusing on a specific subject.
-
-## Subject Description:
-(Add detailed subject description here)
-
-## Hyperrealistic Style: {True}
-
-## Base Models:
-This model was trained using the following base model:
-- {BASE_MODEL}
-
-
-## Usage Instructions:
-(Add detailed instructions on how to use this LoRa model here. Include example prompts)
-
-
-## Training Data:
-(Add information on the training data here)
-
-
-## Limitations:
-(List any known limitations of the model)
-
-## Bias and Fairness Considerations:
-(Address potential bias in the model)
-
-## Known Issues:
-(List any known issues)
-
-"""
-    with open(os.path.join(save_path, "model_card.txt"), "w") as f:
-        f.write(model_card)
-
-# --- Push to Hub ---
-if PUSH_TO_HUB and accelerator.is_main_process:
-    print("Pushing to Hugging Face Hub")
-    repo_id = HUB_REPO_ID
-    repo_url = create_repo(repo_id, exist_ok=True, repo_type="model", token=HfFolder.get_token()).clone_url
-    upload_folder(repo_id=repo_id, folder_path=save_path, token=HfFolder.get_token())
-    print(f"Model pushed to {repo_url}")
-
-# --- FluxLoraModel Class ---
-class FluxLoraModel:
-    def __init__(self, model_path="your-model-path", device="cuda"):
-        self.device = device
-        self.model = StableDiffusionPipeline.from_pretrained(
-            model_path,
-            torch_dtype=torch.float16,
-            safety_checker=None,
-            requires_safety_checker=False
-        ).to(device)
-
-        self.model.scheduler = DPMSolverMultistepScheduler.from_config(
-            self.model.scheduler.config,
-            algorithm_type="dpmsolver++",
-            solver_order=2
-        )
-
-        self.model.enable_attention_slicing()
-        self.model.enable_xformers_memory_efficient_attention()
-
-        self.quality_modifiers = {
-            'realism': ["hyperrealistic", "photorealistic", "ultra realistic", "octane render", "raw photo", "unedited", "photographic", "35mm film"],
-            'resolution': ["4K UHD", "8K resolution", "ultra high definition", "extremely detailed", "high resolution"],
-            'detail_level': ["ultra detailed", "fine details", "intricate details", "sharp focus", "highly detailed", "maximum detail"]
-        }
-
-        self.texture_modifiers = {
-            'skin_details': ["detailed skin texture", "natural skin pores", "realistic skin subsurface scattering", "fine skin details"],
-            'clothing_details': ["detailed fabric texture", "intricate fabric weave", "realistic cloth folds", "natural fabric wrinkles"],
-            'hair_details': {
-                'general_quality': ["ultra detailed hair strands", "photorealistic hair texture", "volumetric hair rendering"],
-                'hair_types': {
-                    'straight': ["silky straight hair", "smooth hair texture"],
-                    'wavy': ["natural wave pattern", "defined hair waves"],
-                    'curly': ["detailed curl pattern", "natural curl definition"],
-                    'coily': ["detailed coil pattern", "natural coil definition"]
-                }
-            },
-            'eye_details': {
-                'general_quality': ["ultra detailed iris", "photorealistic eyes", "8K eye details"],
-                'iris_details': ["detailed iris patterns", "intricate iris fibers"],
-                'eye_properties': {
-                    'reflection': ["natural catchlights", "realistic eye reflections"],
-                    'moisture': ["natural eye moisture", "subtle tear film"],
-                    'depth': ["volumetric eye depth", "realistic eye socket depth"]
-                }
-            }
-        }
-
-    def enhance_prompt(self, base_prompt):
-        realism_mod = ", ".join(random.sample(self.quality_modifiers['realism'], 3))
-        resolution_mod = ", ".join(random.sample(self.quality_modifiers['resolution'], 2))
-        detail_mod = ", ".join(random.sample(self.quality_modifiers['detail_level'], 3))
-
-        enhanced_prompt = f"{base_prompt}, {realism_mod}, {resolution_mod}, {detail_mod}, masterpiece, professional photography"
-
-        if any(word in base_prompt.lower() for word in ["person", "portrait", "face"]):
-            skin_mod = ", ".join(random.sample(self.texture_modifiers['skin_details'], 2))
-            eye_mod = ", ".join(random.sample(self.texture_modifiers['eye_details']['general_quality'], 2))
-            enhanced_prompt = f"{enhanced_prompt}, {skin_mod}, {eye_mod}"
-
-        return enhanced_prompt
-
-    def generate_image(self, prompt, negative_prompt="", num_images=1, steps=50, cfg_scale=8.5, width=2048, height=2048, seed=None, output_dir="outputs"):
-        default_negative = "blur, haze, soft, deformed, low quality, low resolution, noise, grainy, bad details"
-        enhanced_negative_prompt = f"{negative_prompt}, {default_negative}"
-        enhanced_prompt = self.enhance_prompt(prompt)
-
-        if width >= 1024 or height >= 1024:
-            self.model.enable_vae_tiling()
-
-        generator = torch.Generator(device=self.device).manual_seed(seed) if seed else None
-
-        images = self.model(
-            prompt=enhanced_prompt,
-            negative_prompt=enhanced_negative_prompt,
-            num_images_per_prompt=num_images,
-            num_inference_steps=steps,
-            guidance_scale=cfg_scale,
-            width=width,
-            height=height,
-            generator=generator
-        ).images
-
-        processed_images = []
-        for img in images:
-            img = img.filter(ImageEnhance.Sharpness(1.2))
-            img = img.filter(ImageEnhance.Contrast(1.1))
-            processed_images.append(img)
-
-        os.makedirs(output_dir, exist_ok=True)
-        saved_paths = []
-        for i, image in enumerate(processed_images):
-            path = os.path.join(output_dir, f"flux_4k_detailed_{i}.png")
-            image.save(path, "PNG", quality=100, optimize=True)
-            saved_paths.append(path)
-
-        return processed_images, saved_paths
-
-    def generate_4k_portrait(self, prompt, **kwargs):
-        return self.generate_image(prompt=prompt, width=3840, height=2160, steps=60, cfg_scale=9.0, **kwargs)
-
-    @staticmethod
-    def image_grid(imgs, rows, cols):
-        w, h = imgs[0].size
-        grid = Image.new('RGB', size=(cols * w, rows * h))
-        for i, img in enumerate(imgs):
-            grid.paste(img, box=(i % cols * w, i // cols * h))
-        return grid
-
-def batch_generate(prompts_file, output_dir="batch_outputs", batch_size=4, **kwargs):
-    model = FluxLoraModel()
-    with open(prompts_file, 'r') as f:
-        prompts = json.load(f)
-
-    output_dir = Path(output_dir)
-    output_dir.mkdir(parents=True, exist_ok=True)
-
-    for i, prompt in enumerate(prompts):
-        try:
-            images, paths = model.generate_image(
-                prompt=prompt,
-                num_images=batch_size,
-                output_dir=str(output_dir / f"prompt_{i}"),
-                **kwargs
-            )
-            grid = model.image_grid(images, rows=batch_size//2, cols=2)
-            grid.save(output_dir / f"prompt_{i}_grid.png")
-        except Exception as e:
-            print(f"Error processing prompt {i}: {str(e)}")
-
-# --- Generation after training ---
-if GENERATE_AFTER_TRAINING and accelerator.is_main_process:
-    print("Generating images after training...")
-    # Load the trained LoRA model
-    model = FluxLoraModel(model_path=os.path.join(OUTPUT_DIR, f"{MODEL_NAME}_final"))
-    batch_generate(PROMPTS_FILE, output_dir="generated_images", batch_size=BATCH_SIZE_GENERATE)
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--prompts", type=str, required=False)
-    parser.add_argument("--output", type=str, default="outputs")
-    parser.add_argument("--batch-size", type=int, default=4)
-    args = parser.parse_args()
-
-    if args.prompts:
-        batch_generate(args.prompts, args.output, args.batch_size)
+---