gjP798uy committed
Commit 6126d4f (verified)
Parent: 6bfbcd0

Update README.md

Files changed (1)
  1. README.md +9 -465
README.md CHANGED
@@ -1,471 +1,15 @@
- # main.py
- from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
- import torch
- from PIL import Image, ImageEnhance
- import os
- import random
- import json
- import argparse
- from pathlib import Path
- from torch.utils.data import Dataset, DataLoader
- from torchvision import transforms
- from accelerate import Accelerator
- from diffusers import UNet2DConditionModel, AutoencoderKL
- from diffusers.training_utils import EMAModel
- from diffusers import LoraLoaderMixin
- from diffusers.optimization import get_scheduler
- from torch.optim import AdamW
- import math
- import numpy as np
- from huggingface_hub import create_repo, upload_folder
- from huggingface_hub import HfFolder
- from tqdm.auto import tqdm
-
- # --- Configuration ---
- MODEL_NAME = "photo-fluxXL"
- BASE_MODEL = "kudzueye/Boreal" # Choose one of your base models
- IMAGE_FOLDER = "/content/drive/MyDrive/training_data" # Replace with your image folder path
- OUTPUT_DIR = "/content/drive/MyDrive/my_lora_models" # Replace with your output folder path
- TRAIN_BATCH_SIZE = 1
- GRADIENT_ACCUMULATION_STEPS = 4
- LEARNING_RATE = 1e-4
- NUM_EPOCHS = 10
- SAVE_STEPS = 500
- SEED = 42
- PUSH_TO_HUB = False # Set to True if you want to push to Hugging Face Hub
- HUB_REPO_ID = "your-username/your-repo-name" # Replace with your Hugging Face repo ID
- GENERATE_AFTER_TRAINING = True # Set to True to generate images after training
- PROMPTS_FILE = "/content/drive/MyDrive/prompts.json" # Replace with your prompts file path
- BATCH_SIZE_GENERATE = 4 # Batch size for generation
-
- # --- Load Base Model ---
- pipe = StableDiffusionPipeline.from_pretrained(
-     BASE_MODEL,
-     torch_dtype=torch.float16,
-     safety_checker=None,
-     requires_safety_checker=False,
-     variant="fp16",
-     use_safetensors=True
- ).to("cuda")
-
- pipe.scheduler = DPMSolverMultistepScheduler.from_config(
-     pipe.scheduler.config,
-     algorithm_type="dpmsolver++",
-     solver_order=2
- )
-
- pipe.enable_attention_slicing()
- pipe.enable_xformers_memory_efficient_attention()
-
- # --- Prepare Dataset ---
- class FluxDataset(Dataset):
-     def __init__(self, image_folder, transform=None):
-         self.image_paths = {}
-         for category in os.listdir(image_folder):
-             category_path = os.path.join(image_folder, category)
-             if os.path.isdir(category_path):
-                 self.image_paths[category] = []
-                 for subcategory in os.listdir(category_path):
-                     subcategory_path = os.path.join(category_path, subcategory)
-                     if os.path.isdir(subcategory_path):
-                         self.image_paths[category].extend([os.path.join(subcategory_path, f) for f in os.listdir(subcategory_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
-                     elif subcategory.lower().endswith(('.png', '.jpg', '.jpeg')):
-                         self.image_paths[category].append(subcategory_path)
-
-         self.transform = transform
-         self.categories = list(self.image_paths.keys())
-
-     def __len__(self):
-         return max(len(paths) for paths in self.image_paths.values())
-
-     def __getitem__(self, idx):
-         item = {}
-         for category in self.categories:
-             if self.image_paths[category]:
-                 image_path = self.image_paths[category][idx % len(self.image_paths[category])]
-                 image = Image.open(image_path).convert("RGB")
-                 if self.transform:
-                     image = self.transform(image)
-                 item[category] = image
-         return item
-
- transform = transforms.Compose([
-     transforms.Resize((1024, 1024)),
-     transforms.ToTensor(),
-     transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
- ])
-
- dataset = FluxDataset(IMAGE_FOLDER, transform=transform)
-
- # --- Data Loader ---
- dataloader = DataLoader(dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
-
- # --- Prepare LoRA ---
- unet = pipe.unet
- vae = pipe.vae
- text_encoder = pipe.text_encoder
- text_encoder_2 = pipe.text_encoder_2
-
- unet_lora_layers = LoraLoaderMixin.get_lora_layers(unet)
- vae_lora_layers = LoraLoaderMixin.get_lora_layers(vae)
- text_encoder_lora_layers = LoraLoaderMixin.get_lora_layers(text_encoder)
- text_encoder_2_lora_layers = LoraLoaderMixin.get_lora_layers(text_encoder_2)
-
- # --- Optimizer ---
- optimizer = AdamW(
-     [
-         {"params": unet_lora_layers.parameters(), "lr": LEARNING_RATE},
-         {"params": vae_lora_layers.parameters(), "lr": LEARNING_RATE},
-         {"params": text_encoder_lora_layers.parameters(), "lr": LEARNING_RATE},
-         {"params": text_encoder_2_lora_layers.parameters(), "lr": LEARNING_RATE},
-     ]
- )
-
- # --- Scheduler ---
- lr_scheduler = get_scheduler(
-     "cosine",
-     optimizer=optimizer,
-     num_warmup_steps=math.ceil(len(dataloader) * NUM_EPOCHS * 0.1),
-     num_training_steps=len(dataloader) * NUM_EPOCHS,
- )
-
- # --- Accelerator ---
- accelerator = Accelerator(gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS, mixed_precision="fp16")
-
- unet, vae, text_encoder, text_encoder_2, optimizer, dataloader, lr_scheduler = accelerator.prepare(
-     unet, vae, text_encoder, text_encoder_2, optimizer, dataloader, lr_scheduler
- )
-
- # --- Training Loop ---
- progress_bar = tqdm(range(len(dataloader) * NUM_EPOCHS), desc="Training")
- global_step = 0
-
- for epoch in range(NUM_EPOCHS):
-     for batch in dataloader:
-         with accelerator.accumulate(unet, vae, text_encoder, text_encoder_2):
-             latents = vae.encode(batch["body"].to(accelerator.device)).latent_dist.sample()
-             noise = torch.randn_like(latents)
-             timesteps = torch.randint(0, pipe.scheduler.num_train_timesteps, (latents.shape[0],), device=accelerator.device)
-             noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)
-
-             # --- Generate Prompt ---
-             prompt_parts = []
-             if "body" in batch:
-                 if "belly" in batch:
-                     prompt_parts.append("belly visible")
-                 if "body_shape_1" in batch:
-                     prompt_parts.append("body shape 1")
-                 if "body_shape_2" in batch:
-                     prompt_parts.append("body shape 2")
-                 if "body_type" in batch:
-                     prompt_parts.append("body type")
-                 if "body measurements-proportion" in batch:
-                     prompt_parts.append("body measurements-proportion")
-             if "details" in batch:
-                 if "eyebrows" in batch:
-                     prompt_parts.append("eyebrows")
-                 if "eyelashes_1" in batch:
-                     prompt_parts.append("eyelashes 1")
-                 if "eyelashes_2" in batch:
-                     prompt_parts.append("eyelashes 2")
-                 if "hair" in batch:
-                     prompt_parts.append("hair")
-                 if "lips" in batch:
-                     prompt_parts.append("lips")
-             if "face" in batch:
-                 for i in range(1, 18):
-                     if f"face_{i}" in batch:
-                         prompt_parts.append(f"face {i}")
-             if "pose" in batch:
-                 for i in range(1, 4):
-                     if f"pose_{i}" in batch:
-                         prompt_parts.append(f"pose {i}")
-             if "skin" in batch:
-                 if "skin_tone" in batch:
-                     prompt_parts.append("skin tone")
-             if "textures" in batch:
-                 for texture in batch["textures"]:
-                     prompt_parts.append(f"texture {texture}")
-
-             prompt = "a photo of a woman, " + ", ".join(prompt_parts)
-             prompt_embeds = pipe.text_encoder(pipe.tokenizer(prompt, return_tensors="pt").input_ids.to(accelerator.device))[0]
-             prompt_embeds_2 = pipe.text_encoder_2(pipe.tokenizer(prompt, return_tensors="pt").input_ids.to(accelerator.device))[0]
-
-             model_pred = unet(noisy_latents, timesteps, prompt_embeds, prompt_embeds_2).sample
-
-             loss = torch.nn.functional.mse_loss(model_pred, noise)
-             accelerator.backward(loss)
-             optimizer.step()
-             lr_scheduler.step()
-             optimizer.zero_grad()
-
-         progress_bar.update(1)
-         global_step += 1
-
-         if global_step % SAVE_STEPS == 0:
-             if accelerator.is_main_process:
-                 print(f"Saving checkpoint at step {global_step}")
-                 save_path = os.path.join(OUTPUT_DIR, f"{MODEL_NAME}_{global_step}")
-                 accelerator.save_state(save_path)
-                 # Save LoRA weights
-                 unet_lora_layers.save_pretrained(os.path.join(save_path, "unet_lora"))
-                 vae_lora_layers.save_pretrained(os.path.join(save_path, "vae_lora"))
-                 text_encoder_lora_layers.save_pretrained(os.path.join(save_path, "text_encoder_lora"))
-                 text_encoder_2_lora_layers.save_pretrained(os.path.join(save_path, "text_encoder_2_lora"))
-
-                 # Save model card
-                 model_card = f"""---
- license: apache-2.0
- language:
- - en
- base_model:
- - {BASE_MODEL}
- pipeline_tag: text-to-image
  ---
-
- # Model Description: {MODEL_NAME}
-
- This LoRa model enhances text-to-image generation with a hyperrealistic style focusing on a specific subject.
-
- ## Subject Description:
- (Add detailed subject description here)
-
- ## Hyperrealistic Style: {True}
-
- ## Base Models:
- This model was trained using the following base model:
- - {BASE_MODEL}
-
-
- ## Usage Instructions:
- (Add detailed instructions on how to use this LoRa model here. Include example prompts)
-
-
- ## Training Data:
- (Add information on the training data here)
-
-
- ## Limitations:
- (List any known limitations of the model)
-
- ## Bias and Fairness Considerations:
- (Address potential bias in the model)
-
- ## Known Issues:
- (List any known issues)
-
- """
-                 with open(os.path.join(save_path, "model_card.txt"), "w") as f:
-                     f.write(model_card)
-
- # --- Save Final Model ---
- if accelerator.is_main_process:
-     print("Saving final model")
-     save_path = os.path.join(OUTPUT_DIR, f"{MODEL_NAME}_final")
-     accelerator.save_state(save_path)
-     # Save LoRA weights
-     unet_lora_layers.save_pretrained(os.path.join(save_path, "unet_lora"))
-     vae_lora_layers.save_pretrained(os.path.join(save_path, "vae_lora"))
-     text_encoder_lora_layers.save_pretrained(os.path.join(save_path, "text_encoder_lora"))
-     text_encoder_2_lora_layers.save_pretrained(os.path.join(save_path, "text_encoder_2_lora"))
-
-     # Save model card
-     model_card = f"""---
  license: apache-2.0
  language:
  - en
  base_model:
- - {BASE_MODEL}
  pipeline_tag: text-to-image
- ---
-
- # Model Description: {MODEL_NAME}
-
- This LoRa model enhances text-to-image generation with a hyperrealistic style focusing on a specific subject.
-
- ## Subject Description:
- (Add detailed subject description here)
-
- ## Hyperrealistic Style: {True}
-
- ## Base Models:
- This model was trained using the following base model:
- - {BASE_MODEL}
-
-
- ## Usage Instructions:
- (Add detailed instructions on how to use this LoRa model here. Include example prompts)
-
-
- ## Training Data:
- (Add information on the training data here)
-
-
- ## Limitations:
- (List any known limitations of the model)
-
- ## Bias and Fairness Considerations:
- (Address potential bias in the model)
-
- ## Known Issues:
- (List any known issues)
-
- """
-     with open(os.path.join(save_path, "model_card.txt"), "w") as f:
-         f.write(model_card)
-
- # --- Push to Hub ---
- if PUSH_TO_HUB and accelerator.is_main_process:
-     print("Pushing to Hugging Face Hub")
-     repo_id = HUB_REPO_ID
-     repo_url = create_repo(repo_id, exist_ok=True, repo_type="model", token=HfFolder.get_token()).clone_url
-     upload_folder(repo_id=repo_id, folder_path=save_path, token=HfFolder.get_token())
-     print(f"Model pushed to {repo_url}")
-
- # --- FluxLoraModel Class ---
- class FluxLoraModel:
-     def __init__(self, model_path="your-model-path", device="cuda"):
-         self.device = device
-         self.model = StableDiffusionPipeline.from_pretrained(
-             model_path,
-             torch_dtype=torch.float16,
-             safety_checker=None,
-             requires_safety_checker=False
-         ).to(device)
-
-         self.model.scheduler = DPMSolverMultistepScheduler.from_config(
-             self.model.scheduler.config,
-             algorithm_type="dpmsolver++",
-             solver_order=2
-         )
-
-         self.model.enable_attention_slicing()
-         self.model.enable_xformers_memory_efficient_attention()
-
-         self.quality_modifiers = {
-             'realism': ["hyperrealistic", "photorealistic", "ultra realistic", "octane render", "raw photo", "unedited", "photographic", "35mm film"],
-             'resolution': ["4K UHD", "8K resolution", "ultra high definition", "extremely detailed", "high resolution"],
-             'detail_level': ["ultra detailed", "fine details", "intricate details", "sharp focus", "highly detailed", "maximum detail"]
-         }
-
-         self.texture_modifiers = {
-             'skin_details': ["detailed skin texture", "natural skin pores", "realistic skin subsurface scattering", "fine skin details"],
-             'clothing_details': ["detailed fabric texture", "intricate fabric weave", "realistic cloth folds", "natural fabric wrinkles"],
-             'hair_details': {
-                 'general_quality': ["ultra detailed hair strands", "photorealistic hair texture", "volumetric hair rendering"],
-                 'hair_types': {
-                     'straight': ["silky straight hair", "smooth hair texture"],
-                     'wavy': ["natural wave pattern", "defined hair waves"],
-                     'curly': ["detailed curl pattern", "natural curl definition"],
-                     'coily': ["detailed coil pattern", "natural coil definition"]
-                 }
-             },
-             'eye_details': {
-                 'general_quality': ["ultra detailed iris", "photorealistic eyes", "8K eye details"],
-                 'iris_details': ["detailed iris patterns", "intricate iris fibers"],
-                 'eye_properties': {
-                     'reflection': ["natural catchlights", "realistic eye reflections"],
-                     'moisture': ["natural eye moisture", "subtle tear film"],
-                     'depth': ["volumetric eye depth", "realistic eye socket depth"]
-                 }
-             }
-         }
-
-     def enhance_prompt(self, base_prompt):
-         realism_mod = ", ".join(random.sample(self.quality_modifiers['realism'], 3))
-         resolution_mod = ", ".join(random.sample(self.quality_modifiers['resolution'], 2))
-         detail_mod = ", ".join(random.sample(self.quality_modifiers['detail_level'], 3))
-
-         enhanced_prompt = f"{base_prompt}, {realism_mod}, {resolution_mod}, {detail_mod}, masterpiece, professional photography"
-
-         if any(word in base_prompt.lower() for word in ["person", "portrait", "face"]):
-             skin_mod = ", ".join(random.sample(self.texture_modifiers['skin_details'], 2))
-             eye_mod = ", ".join(random.sample(self.texture_modifiers['eye_details']['general_quality'], 2))
-             enhanced_prompt = f"{enhanced_prompt}, {skin_mod}, {eye_mod}"
-
-         return enhanced_prompt
-
-     def generate_image(self, prompt, negative_prompt="", num_images=1, steps=50, cfg_scale=8.5, width=2048, height=2048, seed=None, output_dir="outputs"):
-         default_negative = "blur, haze, soft, deformed, low quality, low resolution, noise, grainy, bad details"
-         enhanced_negative_prompt = f"{negative_prompt}, {default_negative}"
-         enhanced_prompt = self.enhance_prompt(prompt)
-
-         if width >= 1024 or height >= 1024:
-             self.model.enable_vae_tiling()
-
-         generator = torch.Generator(device=self.device).manual_seed(seed) if seed else None
-
-         images = self.model(
-             prompt=enhanced_prompt,
-             negative_prompt=enhanced_negative_prompt,
-             num_images_per_prompt=num_images,
-             num_inference_steps=steps,
-             guidance_scale=cfg_scale,
-             width=width,
-             height=height,
-             generator=generator
-         ).images
-
-         processed_images = []
-         for img in images:
-             img = img.filter(ImageEnhance.Sharpness(1.2))
-             img = img.filter(ImageEnhance.Contrast(1.1))
-             processed_images.append(img)
-
-         os.makedirs(output_dir, exist_ok=True)
-         saved_paths = []
-         for i, image in enumerate(processed_images):
-             path = os.path.join(output_dir, f"flux_4k_detailed_{i}.png")
-             image.save(path, "PNG", quality=100, optimize=True)
-             saved_paths.append(path)
-
-         return processed_images, saved_paths
-
-     def generate_4k_portrait(self, prompt, **kwargs):
-         return self.generate_image(prompt=prompt, width=3840, height=2160, steps=60, cfg_scale=9.0, **kwargs)
-
-     @staticmethod
-     def image_grid(imgs, rows, cols):
-         w, h = imgs[0].size
-         grid = Image.new('RGB', size=(cols * w, rows * h))
-         for i, img in enumerate(imgs):
-             grid.paste(img, box=(i % cols * w, i // cols * h))
-         return grid
-
- def batch_generate(prompts_file, output_dir="batch_outputs", batch_size=4, **kwargs):
-     model = FluxLoraModel()
-     with open(prompts_file, 'r') as f:
-         prompts = json.load(f)
-
-     output_dir = Path(output_dir)
-     output_dir.mkdir(parents=True, exist_ok=True)
-
-     for i, prompt in enumerate(prompts):
-         try:
-             images, paths = model.generate_image(
-                 prompt=prompt,
-                 num_images=batch_size,
-                 output_dir=str(output_dir / f"prompt_{i}"),
-                 **kwargs
-             )
-             grid = model.image_grid(images, rows=batch_size//2, cols=2)
-             grid.save(output_dir / f"prompt_{i}_grid.png")
-         except Exception as e:
-             print(f"Error processing prompt {i}: {str(e)}")
-
- # --- Generation after training ---
- if GENERATE_AFTER_TRAINING and accelerator.is_main_process:
-     print("Generating images after training...")
-     # Load the trained LoRA model
-     model = FluxLoraModel(model_path=os.path.join(OUTPUT_DIR, f"{MODEL_NAME}_final"))
-     batch_generate(PROMPTS_FILE, output_dir="generated_images", batch_size=BATCH_SIZE_GENERATE)
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser()
-     parser.add_argument("--prompts", type=str, required=False)
-     parser.add_argument("--output", type=str, default="outputs")
-     parser.add_argument("--batch-size", type=int, default=4)
-     args = parser.parse_args()
-
-     if args.prompts:
-         batch_generate(args.prompts, args.output, args.batch_size)
  ---
  license: apache-2.0
  language:
  - en
+ - es
+ - fr
  base_model:
+ - kudzueye/Boreal
+ - adirik/flux-cinestill
+ - Shakker-Labs/FLUX.1-dev-LoRA-add-details
+ - prithivMLmods/Flux-Realism-FineDetailed
+ - prithivMLmods/Canopus-LoRA-Flux-UltraRealism-2.0
+ - Jovie/Midjourney
  pipeline_tag: text-to-image
+ ---
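
The README now carries only the model-card metadata shown above. As a rough usage sketch that is not part of this commit, and assuming the repositories listed under base_model are FLUX.1-dev LoRA adapters, one of them could be loaded with diffusers along these lines; the base checkpoint id, dtype, sampler settings, and prompt are illustrative assumptions rather than values taken from this repository:

import torch
from diffusers import FluxPipeline

# Assumed base checkpoint; it is not named anywhere in this commit.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
# Apply one of the LoRAs listed under base_model; some repositories may require
# an explicit weight_name= argument depending on how their weights are stored.
pipe.load_lora_weights("Shakker-Labs/FLUX.1-dev-LoRA-add-details")
pipe.to("cuda")

image = pipe(
    "a photo of a woman, natural light, 35mm film look",
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("sample.png")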