Update README.md
README.md
CHANGED
@@ -1,13 +1,471 @@
---
license: apache-2.0
base_model:
- prithivMLmods/Flux-Realism-FineDetailed
- prithivMLmods/Canopus-LoRA-Flux-UltraRealism-2.0
- kudzueye/boreal-flux-dev-v2
- kudzueye/Boreal
- adirik/flux-cinestill
- Schaffsch/ostris_flux-dev-lora-trainer
- Jovie/Midjourney
pipeline_tag: text-to-image
---
# main.py
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
import torch
from PIL import Image, ImageEnhance
import os
import random
import json
import argparse
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from accelerate import Accelerator
from diffusers import UNet2DConditionModel, AutoencoderKL
from diffusers.training_utils import EMAModel
from diffusers import LoraLoaderMixin
from diffusers.optimization import get_scheduler
from torch.optim import AdamW
import math
import numpy as np
from huggingface_hub import create_repo, upload_folder
from huggingface_hub import HfFolder
from tqdm.auto import tqdm

# --- Configuration ---
MODEL_NAME = "photo-fluxXL"
BASE_MODEL = "kudzueye/Boreal"  # Choose one of your base models
IMAGE_FOLDER = "/content/drive/MyDrive/training_data"  # Replace with your image folder path
OUTPUT_DIR = "/content/drive/MyDrive/my_lora_models"  # Replace with your output folder path
TRAIN_BATCH_SIZE = 1
GRADIENT_ACCUMULATION_STEPS = 4
LEARNING_RATE = 1e-4
NUM_EPOCHS = 10
SAVE_STEPS = 500
SEED = 42
PUSH_TO_HUB = False  # Set to True if you want to push to Hugging Face Hub
HUB_REPO_ID = "your-username/your-repo-name"  # Replace with your Hugging Face repo ID
GENERATE_AFTER_TRAINING = True  # Set to True to generate images after training
PROMPTS_FILE = "/content/drive/MyDrive/prompts.json"  # Replace with your prompts file path
BATCH_SIZE_GENERATE = 4  # Batch size for generation
# --- Load Base Model ---
pipe = StableDiffusionPipeline.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16,
    safety_checker=None,
    requires_safety_checker=False,
    variant="fp16",
    use_safetensors=True
).to("cuda")

pipe.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe.scheduler.config,
    algorithm_type="dpmsolver++",
    solver_order=2
)

pipe.enable_attention_slicing()
pipe.enable_xformers_memory_efficient_attention()

# --- Prepare Dataset ---
class FluxDataset(Dataset):
    def __init__(self, image_folder, transform=None):
        self.image_paths = {}
        for category in os.listdir(image_folder):
            category_path = os.path.join(image_folder, category)
            if os.path.isdir(category_path):
                self.image_paths[category] = []
                for subcategory in os.listdir(category_path):
                    subcategory_path = os.path.join(category_path, subcategory)
                    if os.path.isdir(subcategory_path):
                        self.image_paths[category].extend(
                            [os.path.join(subcategory_path, f)
                             for f in os.listdir(subcategory_path)
                             if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
                        )
                    elif subcategory.lower().endswith(('.png', '.jpg', '.jpeg')):
                        self.image_paths[category].append(subcategory_path)

        self.transform = transform
        self.categories = list(self.image_paths.keys())

    def __len__(self):
        return max(len(paths) for paths in self.image_paths.values())

    def __getitem__(self, idx):
        item = {}
        for category in self.categories:
            if self.image_paths[category]:
                image_path = self.image_paths[category][idx % len(self.image_paths[category])]
                image = Image.open(image_path).convert("RGB")
                if self.transform:
                    image = self.transform(image)
                item[category] = image
        return item

transform = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

dataset = FluxDataset(IMAGE_FOLDER, transform=transform)

# --- Data Loader ---
dataloader = DataLoader(dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
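# Each top-level folder of IMAGE_FOLDER becomes one key in the batches yielded
# above (images are gathered from that folder and its immediate subfolders).
# The prompt-building logic later checks for keys such as "body", "belly",
# "face_1" and "skin_tone", so a hypothetical layout could look like:
#
#   training_data/
#     body/        img_0001.jpg ...
#     belly/       img_0101.jpg ...
#     face_1/      img_0201.png ...
#     skin_tone/   img_0301.jpg ...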
# --- Prepare LoRA ---
unet = pipe.unet
vae = pipe.vae
text_encoder = pipe.text_encoder
# NOTE: a second text encoder only exists on SDXL-style pipelines;
# a plain StableDiffusionPipeline exposes just pipe.text_encoder.
text_encoder_2 = pipe.text_encoder_2

# NOTE: get_lora_layers does not appear in the public diffusers LoraLoaderMixin API;
# treat these four lines as shorthand for "collect the trainable LoRA layers".
unet_lora_layers = LoraLoaderMixin.get_lora_layers(unet)
vae_lora_layers = LoraLoaderMixin.get_lora_layers(vae)
text_encoder_lora_layers = LoraLoaderMixin.get_lora_layers(text_encoder)
text_encoder_2_lora_layers = LoraLoaderMixin.get_lora_layers(text_encoder_2)
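# A minimal sketch of the more common PEFT-based way to attach LoRA adapters in
# recent diffusers releases (an assumed alternative, not wired into this script;
# the rank and target_modules values are illustrative):
#
#   from peft import LoraConfig
#   lora_config = LoraConfig(r=16, lora_alpha=16,
#                            target_modules=["to_q", "to_k", "to_v", "to_out.0"])
#   unet.add_adapter(lora_config)
#   lora_params = [p for p in unet.parameters() if p.requires_grad]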
# --- Optimizer ---
optimizer = AdamW(
    [
        {"params": unet_lora_layers.parameters(), "lr": LEARNING_RATE},
        {"params": vae_lora_layers.parameters(), "lr": LEARNING_RATE},
        {"params": text_encoder_lora_layers.parameters(), "lr": LEARNING_RATE},
        {"params": text_encoder_2_lora_layers.parameters(), "lr": LEARNING_RATE},
    ]
)

# --- Scheduler ---
lr_scheduler = get_scheduler(
    "cosine",
    optimizer=optimizer,
    num_warmup_steps=math.ceil(len(dataloader) * NUM_EPOCHS * 0.1),
    num_training_steps=len(dataloader) * NUM_EPOCHS,
)

# --- Accelerator ---
accelerator = Accelerator(gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS, mixed_precision="fp16")

unet, vae, text_encoder, text_encoder_2, optimizer, dataloader, lr_scheduler = accelerator.prepare(
    unet, vae, text_encoder, text_encoder_2, optimizer, dataloader, lr_scheduler
)
# --- Training Loop ---
progress_bar = tqdm(range(len(dataloader) * NUM_EPOCHS), desc="Training")
global_step = 0

for epoch in range(NUM_EPOCHS):
    for batch in dataloader:
        with accelerator.accumulate(unet, vae, text_encoder, text_encoder_2):
            latents = vae.encode(batch["body"].to(accelerator.device)).latent_dist.sample()
            noise = torch.randn_like(latents)
            timesteps = torch.randint(0, pipe.scheduler.config.num_train_timesteps, (latents.shape[0],), device=accelerator.device)
            noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)

            # --- Generate Prompt ---
            prompt_parts = []
            if "body" in batch:
                if "belly" in batch:
                    prompt_parts.append("belly visible")
                if "body_shape_1" in batch:
                    prompt_parts.append("body shape 1")
                if "body_shape_2" in batch:
                    prompt_parts.append("body shape 2")
                if "body_type" in batch:
                    prompt_parts.append("body type")
                if "body measurements-proportion" in batch:
                    prompt_parts.append("body measurements-proportion")
            if "details" in batch:
                if "eyebrows" in batch:
                    prompt_parts.append("eyebrows")
                if "eyelashes_1" in batch:
                    prompt_parts.append("eyelashes 1")
                if "eyelashes_2" in batch:
                    prompt_parts.append("eyelashes 2")
                if "hair" in batch:
                    prompt_parts.append("hair")
                if "lips" in batch:
                    prompt_parts.append("lips")
            if "face" in batch:
                for i in range(1, 18):
                    if f"face_{i}" in batch:
                        prompt_parts.append(f"face {i}")
            if "pose" in batch:
                for i in range(1, 4):
                    if f"pose_{i}" in batch:
                        prompt_parts.append(f"pose {i}")
            if "skin" in batch:
                if "skin_tone" in batch:
                    prompt_parts.append("skin tone")
            if "textures" in batch:
                for texture in batch["textures"]:
                    prompt_parts.append(f"texture {texture}")

            prompt = "a photo of a woman, " + ", ".join(prompt_parts)
            prompt_embeds = pipe.text_encoder(pipe.tokenizer(prompt, return_tensors="pt").input_ids.to(accelerator.device))[0]
            # NOTE: an SDXL-style second text encoder is normally paired with pipe.tokenizer_2.
            prompt_embeds_2 = pipe.text_encoder_2(pipe.tokenizer(prompt, return_tensors="pt").input_ids.to(accelerator.device))[0]

            model_pred = unet(noisy_latents, timesteps, prompt_embeds, prompt_embeds_2).sample

            loss = torch.nn.functional.mse_loss(model_pred, noise)
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        progress_bar.update(1)
        global_step += 1

        if global_step % SAVE_STEPS == 0:
            if accelerator.is_main_process:
                print(f"Saving checkpoint at step {global_step}")
                save_path = os.path.join(OUTPUT_DIR, f"{MODEL_NAME}_{global_step}")
                accelerator.save_state(save_path)
                # Save LoRA weights
                unet_lora_layers.save_pretrained(os.path.join(save_path, "unet_lora"))
                vae_lora_layers.save_pretrained(os.path.join(save_path, "vae_lora"))
                text_encoder_lora_layers.save_pretrained(os.path.join(save_path, "text_encoder_lora"))
                text_encoder_2_lora_layers.save_pretrained(os.path.join(save_path, "text_encoder_2_lora"))

                # Save model card
                model_card = f"""---
license: apache-2.0
language:
- en
base_model:
- {BASE_MODEL}
pipeline_tag: text-to-image
---

# Model Description: {MODEL_NAME}

This LoRA model enhances text-to-image generation with a hyperrealistic style focused on a specific subject.

## Subject Description:
(Add detailed subject description here)

## Hyperrealistic Style: {True}

## Base Models:
This model was trained using the following base model:
- {BASE_MODEL}


## Usage Instructions:
(Add detailed instructions on how to use this LoRA model here. Include example prompts)


## Training Data:
(Add information on the training data here)


## Limitations:
(List any known limitations of the model)

## Bias and Fairness Considerations:
(Address potential bias in the model)

## Known Issues:
(List any known issues)

"""
                with open(os.path.join(save_path, "model_card.txt"), "w") as f:
                    f.write(model_card)
# --- Save Final Model ---
if accelerator.is_main_process:
    print("Saving final model")
    save_path = os.path.join(OUTPUT_DIR, f"{MODEL_NAME}_final")
    accelerator.save_state(save_path)
    # Save LoRA weights
    unet_lora_layers.save_pretrained(os.path.join(save_path, "unet_lora"))
    vae_lora_layers.save_pretrained(os.path.join(save_path, "vae_lora"))
    text_encoder_lora_layers.save_pretrained(os.path.join(save_path, "text_encoder_lora"))
    text_encoder_2_lora_layers.save_pretrained(os.path.join(save_path, "text_encoder_2_lora"))

    # Save model card
    model_card = f"""---
license: apache-2.0
language:
- en
base_model:
- {BASE_MODEL}
pipeline_tag: text-to-image
---

# Model Description: {MODEL_NAME}

This LoRA model enhances text-to-image generation with a hyperrealistic style focused on a specific subject.

## Subject Description:
(Add detailed subject description here)

## Hyperrealistic Style: {True}

## Base Models:
This model was trained using the following base model:
- {BASE_MODEL}


## Usage Instructions:
(Add detailed instructions on how to use this LoRA model here. Include example prompts)


## Training Data:
(Add information on the training data here)


## Limitations:
(List any known limitations of the model)

## Bias and Fairness Considerations:
(Address potential bias in the model)

## Known Issues:
(List any known issues)

"""
    with open(os.path.join(save_path, "model_card.txt"), "w") as f:
        f.write(model_card)

# --- Push to Hub ---
if PUSH_TO_HUB and accelerator.is_main_process:
    print("Pushing to Hugging Face Hub")
    repo_id = HUB_REPO_ID
    # create_repo returns the repository URL (a str subclass), which can be printed directly
    repo_url = create_repo(repo_id, exist_ok=True, repo_type="model", token=HfFolder.get_token())
    upload_folder(repo_id=repo_id, folder_path=save_path, token=HfFolder.get_token())
    print(f"Model pushed to {repo_url}")
# --- FluxLoraModel Class ---
class FluxLoraModel:
    def __init__(self, model_path="your-model-path", device="cuda"):
        self.device = device
        self.model = StableDiffusionPipeline.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            safety_checker=None,
            requires_safety_checker=False
        ).to(device)

        self.model.scheduler = DPMSolverMultistepScheduler.from_config(
            self.model.scheduler.config,
            algorithm_type="dpmsolver++",
            solver_order=2
        )

        self.model.enable_attention_slicing()
        self.model.enable_xformers_memory_efficient_attention()

        self.quality_modifiers = {
            'realism': ["hyperrealistic", "photorealistic", "ultra realistic", "octane render", "raw photo", "unedited", "photographic", "35mm film"],
            'resolution': ["4K UHD", "8K resolution", "ultra high definition", "extremely detailed", "high resolution"],
            'detail_level': ["ultra detailed", "fine details", "intricate details", "sharp focus", "highly detailed", "maximum detail"]
        }

        self.texture_modifiers = {
            'skin_details': ["detailed skin texture", "natural skin pores", "realistic skin subsurface scattering", "fine skin details"],
            'clothing_details': ["detailed fabric texture", "intricate fabric weave", "realistic cloth folds", "natural fabric wrinkles"],
            'hair_details': {
                'general_quality': ["ultra detailed hair strands", "photorealistic hair texture", "volumetric hair rendering"],
                'hair_types': {
                    'straight': ["silky straight hair", "smooth hair texture"],
                    'wavy': ["natural wave pattern", "defined hair waves"],
                    'curly': ["detailed curl pattern", "natural curl definition"],
                    'coily': ["detailed coil pattern", "natural coil definition"]
                }
            },
            'eye_details': {
                'general_quality': ["ultra detailed iris", "photorealistic eyes", "8K eye details"],
                'iris_details': ["detailed iris patterns", "intricate iris fibers"],
                'eye_properties': {
                    'reflection': ["natural catchlights", "realistic eye reflections"],
                    'moisture': ["natural eye moisture", "subtle tear film"],
                    'depth': ["volumetric eye depth", "realistic eye socket depth"]
                }
            }
        }

    def enhance_prompt(self, base_prompt):
        realism_mod = ", ".join(random.sample(self.quality_modifiers['realism'], 3))
        resolution_mod = ", ".join(random.sample(self.quality_modifiers['resolution'], 2))
        detail_mod = ", ".join(random.sample(self.quality_modifiers['detail_level'], 3))

        enhanced_prompt = f"{base_prompt}, {realism_mod}, {resolution_mod}, {detail_mod}, masterpiece, professional photography"

        if any(word in base_prompt.lower() for word in ["person", "portrait", "face"]):
            skin_mod = ", ".join(random.sample(self.texture_modifiers['skin_details'], 2))
            eye_mod = ", ".join(random.sample(self.texture_modifiers['eye_details']['general_quality'], 2))
            enhanced_prompt = f"{enhanced_prompt}, {skin_mod}, {eye_mod}"

        return enhanced_prompt

    def generate_image(self, prompt, negative_prompt="", num_images=1, steps=50, cfg_scale=8.5, width=2048, height=2048, seed=None, output_dir="outputs"):
        default_negative = "blur, haze, soft, deformed, low quality, low resolution, noise, grainy, bad details"
        enhanced_negative_prompt = f"{negative_prompt}, {default_negative}"
        enhanced_prompt = self.enhance_prompt(prompt)

        if width >= 1024 or height >= 1024:
            self.model.enable_vae_tiling()

        generator = torch.Generator(device=self.device).manual_seed(seed) if seed is not None else None

        images = self.model(
            prompt=enhanced_prompt,
            negative_prompt=enhanced_negative_prompt,
            num_images_per_prompt=num_images,
            num_inference_steps=steps,
            guidance_scale=cfg_scale,
            width=width,
            height=height,
            generator=generator
        ).images

        processed_images = []
        for img in images:
            # PIL ImageEnhance enhancers are applied via .enhance(), not Image.filter()
            img = ImageEnhance.Sharpness(img).enhance(1.2)
            img = ImageEnhance.Contrast(img).enhance(1.1)
            processed_images.append(img)

        os.makedirs(output_dir, exist_ok=True)
        saved_paths = []
        for i, image in enumerate(processed_images):
            path = os.path.join(output_dir, f"flux_4k_detailed_{i}.png")
            image.save(path, "PNG", optimize=True)
            saved_paths.append(path)

        return processed_images, saved_paths

    def generate_4k_portrait(self, prompt, **kwargs):
        return self.generate_image(prompt=prompt, width=3840, height=2160, steps=60, cfg_scale=9.0, **kwargs)

    @staticmethod
    def image_grid(imgs, rows, cols):
        w, h = imgs[0].size
        grid = Image.new('RGB', size=(cols * w, rows * h))
        for i, img in enumerate(imgs):
            grid.paste(img, box=(i % cols * w, i // cols * h))
        return grid
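# Example standalone use of FluxLoraModel (hypothetical path and prompt, shown as a sketch):
#   model = FluxLoraModel(model_path="/content/drive/MyDrive/my_lora_models/photo-fluxXL_final")
#   images, paths = model.generate_image("portrait of a woman", num_images=2,
#                                        width=1024, height=1024, seed=SEED)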
def batch_generate(prompts_file, output_dir="batch_outputs", batch_size=4, **kwargs):
    model = FluxLoraModel()
    with open(prompts_file, 'r') as f:
        prompts = json.load(f)

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    for i, prompt in enumerate(prompts):
        try:
            images, paths = model.generate_image(
                prompt=prompt,
                num_images=batch_size,
                output_dir=str(output_dir / f"prompt_{i}"),
                **kwargs
            )
            grid = model.image_grid(images, rows=batch_size // 2, cols=2)
            grid.save(output_dir / f"prompt_{i}_grid.png")
        except Exception as e:
            print(f"Error processing prompt {i}: {str(e)}")
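# batch_generate expects the prompts file to be a JSON list of prompt strings.
# A hypothetical prompts.json (illustrative content only):
#
#   [
#     "portrait of a woman in soft window light",
#     "full-body photo of a woman on a city street at dusk"
#   ]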
# --- Generation after training ---
if GENERATE_AFTER_TRAINING and accelerator.is_main_process:
    print("Generating images after training...")
    # Load the trained LoRA model
    # NOTE: batch_generate constructs its own FluxLoraModel with the default
    # model_path, so the instance created here is not actually used by it.
    model = FluxLoraModel(model_path=os.path.join(OUTPUT_DIR, f"{MODEL_NAME}_final"))
    batch_generate(PROMPTS_FILE, output_dir="generated_images", batch_size=BATCH_SIZE_GENERATE)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompts", type=str, required=False)
    parser.add_argument("--output", type=str, default="outputs")
    parser.add_argument("--batch-size", type=int, default=4)
    args = parser.parse_args()

    if args.prompts:
        batch_generate(args.prompts, args.output, args.batch_size)