fantaxy committed on
Commit
719d3c9
·
verified ·
1 Parent(s): 28f1c45

Update app.py

Files changed (1)
  1. app.py +529 -556
app.py CHANGED
@@ -25,98 +25,303 @@ import gc
25
  import csv
26
  from datetime import datetime
27
  from openai import OpenAI
28
- import spaces
29
- import argparse
30
31
  import time
32
  from os import path
33
  import shutil
34
- from datetime import datetime
35
  from safetensors.torch import load_file
36
- from huggingface_hub import hf_hub_download
37
- import gradio as gr
38
- import torch
39
  from diffusers import FluxPipeline
40
  from diffusers.pipelines.stable_diffusion import safety_checker
41
- from PIL import Image
42
- from transformers import pipeline
43
  import replicate
44
  import logging
45
  import requests
46
  from pathlib import Path
47
- import cv2
48
- import numpy as np
49
  import sys
50
  import io
51
 
52
- # Modified the FluxPipeline import section
53
- from diffusers import StableDiffusionPipeline, DiffusionPipeline
54
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
55
- from diffusers import AutoPipelineForText2Image
56
-
57
- # Modified the model initialization section
58
- if not path.exists(cache_path):
59
- os.makedirs(cache_path, exist_ok=True)
60
-
61
- # ํŒŒ์ดํ”„๋ผ์ธ ์ดˆ๊ธฐํ™” ์ˆ˜์ •
62
- pipe = AutoPipelineForText2Image.from_pretrained(
63
- "stabilityai/stable-diffusion-xl-base-1.0",
64
- torch_dtype=torch.float16,
65
- use_safetensors=True,
66
- variant="fp16"
67
- )
68
- pipe.to("cuda")
69
-
70
- # Safety checker setup
71
- pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained(
72
- "CompVis/stable-diffusion-safety-checker"
73
- )
74
-
75
- # Modified the process_and_save_image function
76
- @spaces.GPU
77
- def process_and_save_image(height, width, steps, scales, prompt, seed):
78
- is_safe, processed_prompt = process_prompt(prompt)
79
- if not is_safe:
80
- gr.Warning("๋ถ€์ ์ ˆํ•œ ๋‚ด์šฉ์ด ํฌํ•จ๋œ ํ”„๋กฌํ”„ํŠธ์ž…๋‹ˆ๋‹ค.")
81
- return None, load_gallery()
82
-
83
- with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
84
- try:
85
- generated_image = pipe(
86
- prompt=processed_prompt,
87
- negative_prompt="low quality, worst quality, bad anatomy, bad composition, poor, low effort",
88
- num_inference_steps=steps,
89
- guidance_scale=scales,
90
- height=height,
91
- width=width,
92
- generator=torch.Generator("cuda").manual_seed(int(seed))
93
- ).images[0]
94
-
95
- # Make sure the output is a PIL Image
96
- if not isinstance(generated_image, Image.Image):
97
- generated_image = Image.fromarray(generated_image)
98
-
99
- # Convert to RGB mode
100
- if generated_image.mode != 'RGB':
101
- generated_image = generated_image.convert('RGB')
102
-
103
- # Convert to PNG in memory
104
- img_byte_arr = io.BytesIO()
105
- generated_image.save(img_byte_arr, format='PNG')
106
- img_byte_arr = img_byte_arr.getvalue()
107
-
108
- # Save to disk
109
- saved_path = save_image(generated_image)
110
- if saved_path is None:
111
- logger.warning("Failed to save generated image")
112
- return None, load_gallery()
113
-
114
- # Reload in PNG format
115
- return Image.open(io.BytesIO(img_byte_arr)), load_gallery()
116
- except Exception as e:
117
- logger.error(f"Error in image generation: {str(e)}")
118
- return None, load_gallery()
119
-
120
  logging.basicConfig(level=logging.INFO)
121
  logger = logging.getLogger(__name__)
122
 
@@ -138,6 +343,9 @@ os.environ["HF_HOME"] = cache_path
138
  # CUDA settings
139
  torch.backends.cuda.matmul.allow_tf32 = True
140
 
141
  # Create directories
142
  for dir_path in [gallery_path, video_gallery_path]:
143
  if not path.exists(dir_path):
@@ -155,7 +363,7 @@ def check_api_key():
155
  def translate_if_korean(text):
156
  """Translate to English when the text contains Korean"""
157
  if any(ord(char) >= 0xAC00 and ord(char) <= 0xD7A3 for char in text):
158
- translation = translator(text)[0]['translation_text']
159
  return translation
160
  return text
161
 
@@ -173,28 +381,11 @@ def filter_prompt(prompt):
173
  return False, "๋ถ€์ ์ ˆํ•œ ๋‚ด์šฉ์ด ํฌํ•จ๋œ ํ”„๋กฌํ”„ํŠธ์ž…๋‹ˆ๋‹ค."
174
  return True, prompt
175
 
176
- def process_prompt(prompt):
177
  """Preprocess the prompt (translation and filtering)"""
178
- # ํ•œ๊ธ€์ด ํฌํ•จ๋˜์–ด ์žˆ๋Š”์ง€ ํ™•์ธ
179
- if any(ord('๊ฐ€') <= ord(char) <= ord('ํžฃ') for char in prompt):
180
- # ํ•œ๊ธ€์„ ์˜์–ด๋กœ ๋ฒˆ์—ญ
181
- translated = translator(prompt)[0]['translation_text']
182
- prompt = translated
183
-
184
- # Filter out inappropriate content
185
- inappropriate_keywords = [
186
- "nude", "naked", "nsfw", "porn", "sex", "explicit", "adult", "xxx",
187
- "erotic", "sensual", "seductive", "provocative", "intimate",
188
- "violence", "gore", "blood", "death", "kill", "murder", "torture",
189
- "drug", "suicide", "abuse", "hate", "discrimination"
190
- ]
191
-
192
- prompt_lower = prompt.lower()
193
- for keyword in inappropriate_keywords:
194
- if keyword in prompt_lower:
195
- return False, "๋ถ€์ ์ ˆํ•œ ๋‚ด์šฉ์ด ํฌํ•จ๋œ ํ”„๋กฌํ”„ํŠธ์ž…๋‹ˆ๋‹ค."
196
-
197
- return True, prompt
198
 
199
  class timer:
200
  def __init__(self, method_name="timed process"):
@@ -206,6 +397,15 @@ class timer:
206
  end = time.time()
207
  print(f"{self.method} took {str(round(end - self.start, 2))}s")
208
 
209
 
210
  def upload_to_catbox(image_path):
211
  """Upload an image using the catbox.moe API"""
@@ -257,415 +457,142 @@ def add_watermark(video_path):
257
  font_scale = height * 0.05 / 30
258
  thickness = 2
259
  color = (255, 255, 255)
260
-
261
- (text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, thickness)
262
- margin = int(height * 0.02)
263
- x_pos = width - text_width - margin
264
- y_pos = height - margin
265
-
266
- output_path = "watermarked_output.mp4"
267
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
268
- out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
269
-
270
- while cap.isOpened():
271
- ret, frame = cap.read()
272
- if not ret:
273
- break
274
- cv2.putText(frame, text, (x_pos, y_pos), font, font_scale, color, thickness)
275
- out.write(frame)
276
-
277
- cap.release()
278
- out.release()
279
-
280
- return output_path
281
-
282
- except Exception as e:
283
- logger.error(f"Error adding watermark: {str(e)}")
284
- return video_path
285
-
286
- def generate_video(image, prompt):
287
- logger.info("Starting video generation")
288
- try:
289
- if not check_api_key():
290
- return "Replicate API key not properly configured"
291
-
292
- if not image:
293
- logger.error("No image provided")
294
- return "Please upload an image"
295
-
296
- image_url = upload_to_catbox(image)
297
- if not image_url:
298
- return "Failed to upload image"
299
-
300
- input_data = {
301
- "prompt": prompt,
302
- "first_frame_image": image_url
303
- }
304
-
305
- try:
306
- replicate.Client(api_token=REPLICATE_API_TOKEN)
307
- output = replicate.run(
308
- "minimax/video-01-live",
309
- input=input_data
310
- )
311
-
312
- temp_file = "temp_output.mp4"
313
-
314
- if hasattr(output, 'read'):
315
- with open(temp_file, "wb") as file:
316
- file.write(output.read())
317
- elif isinstance(output, str):
318
- response = requests.get(output)
319
- with open(temp_file, "wb") as file:
320
- file.write(response.content)
321
-
322
- final_video = add_watermark(temp_file)
323
- return final_video
324
-
325
- except Exception as api_error:
326
- logger.error(f"API call failed: {str(api_error)}")
327
- return f"API call failed: {str(api_error)}"
328
-
329
- except Exception as e:
330
- logger.error(f"Unexpected error: {str(e)}")
331
- return f"Unexpected error: {str(e)}"
332
-
333
- def save_image(image):
334
- """Save the generated image in PNG format and return the path"""
335
- try:
336
- if not os.path.exists(gallery_path):
337
- os.makedirs(gallery_path, exist_ok=True)
338
-
339
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
340
- random_suffix = os.urandom(4).hex()
341
- filename = f"generated_{timestamp}_{random_suffix}.png"
342
- filepath = os.path.join(gallery_path, filename)
343
-
344
- # Convert to a PIL Image
345
- if not isinstance(image, Image.Image):
346
- image = Image.fromarray(image)
347
-
348
- # Convert to RGB mode (prevents issues that can occur with RGBA)
349
- if image.mode != 'RGB':
350
- image = image.convert('RGB')
351
-
352
- # Save explicitly in PNG format
353
- image.save(
354
- filepath,
355
- format='PNG',
356
- optimize=True,
357
- quality=100 # maximum quality
358
- )
359
-
360
- logger.info(f"Image saved successfully as PNG: {filepath}")
361
- return filepath
362
- except Exception as e:
363
- logger.error(f"Error in save_image: {str(e)}")
364
- return None
365
-
366
- def load_gallery():
367
- """Load all images from the gallery directory"""
368
- try:
369
- os.makedirs(gallery_path, exist_ok=True)
370
-
371
- image_files = []
372
- for f in os.listdir(gallery_path):
373
- if f.lower().endswith(('.png', '.jpg', '.jpeg')):
374
- full_path = os.path.join(gallery_path, f)
375
- image_files.append((full_path, os.path.getmtime(full_path)))
376
-
377
- image_files.sort(key=lambda x: x[1], reverse=True)
378
- return [f[0] for f in image_files]
379
- except Exception as e:
380
- print(f"Error loading gallery: {str(e)}")
381
- return []
382
-
383
-
384
- # Initialize the Korean-English translator
385
- translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
386
-
387
- MAX_SEED = np.iinfo(np.int32).max
388
-
389
- # Load Hugging Face token if needed
390
- hf_token = os.getenv("HF_TOKEN")
391
- openai_api_key = os.getenv("OPENAI_API_KEY")
392
- client = OpenAI(api_key=openai_api_key)
393
-
394
- system_prompt_t2v_path = "assets/system_prompt_t2v.txt"
395
- with open(system_prompt_t2v_path, "r") as f:
396
- system_prompt_t2v = f.read()
397
-
398
- # Set model download directory within Hugging Face Spaces
399
- model_path = "asset"
400
-
401
- commit_hash='c7c8ad4c2ddba847b94e8bfaefbd30bd8669fafc'
402
-
403
- if not os.path.exists(model_path):
404
- snapshot_download("Lightricks/LTX-Video", revision=commit_hash, local_dir=model_path, repo_type="model", token=hf_token)
405
-
406
- # Global variables to load components
407
- vae_dir = Path(model_path) / "vae"
408
- unet_dir = Path(model_path) / "unet"
409
- scheduler_dir = Path(model_path) / "scheduler"
410
-
411
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
412
-
413
- clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", cache_dir=model_path).to(torch.device("cuda:0"))
414
- clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", cache_dir=model_path)
415
-
416
- # Use a single, consistent set of CUDA settings
417
- torch.backends.cuda.matmul.allow_tf32 = False
418
- torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
419
- torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
420
- torch.backends.cudnn.allow_tf32 = False
421
- torch.backends.cudnn.deterministic = False
422
- torch.backends.cuda.preferred_blas_library = "cublas"
423
- torch.set_float32_matmul_precision("highest")
424
-
425
-
426
-
427
- def compute_clip_embedding(text=None):
428
- inputs = clip_processor(text=text, return_tensors="pt", padding=True).to(device)
429
- outputs = clip_model.get_text_features(**inputs)
430
- embedding = outputs.detach().cpu().numpy().flatten().tolist()
431
- return embedding
432
-
433
- def load_vae(vae_dir):
434
- vae_ckpt_path = vae_dir / "vae_diffusion_pytorch_model.safetensors"
435
- vae_config_path = vae_dir / "config.json"
436
- with open(vae_config_path, "r") as f:
437
- vae_config = json.load(f)
438
- vae = CausalVideoAutoencoder.from_config(vae_config)
439
- vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
440
- vae.load_state_dict(vae_state_dict)
441
- return vae.to(device).to(torch.bfloat16)
442
-
443
- def load_unet(unet_dir):
444
- unet_ckpt_path = unet_dir / "unet_diffusion_pytorch_model.safetensors"
445
- unet_config_path = unet_dir / "config.json"
446
- transformer_config = Transformer3DModel.load_config(unet_config_path)
447
- transformer = Transformer3DModel.from_config(transformer_config)
448
- unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
449
- transformer.load_state_dict(unet_state_dict, strict=True)
450
- return transformer.to(device).to(torch.bfloat16)
451
-
452
- def load_scheduler(scheduler_dir):
453
- scheduler_config_path = scheduler_dir / "scheduler_config.json"
454
- scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
455
- return RectifiedFlowScheduler.from_config(scheduler_config)
456
-
457
- # Preset options for resolution and frame configuration
458
- preset_options = [
459
- {"label": "1216x704, 41 frames", "width": 1216, "height": 704, "num_frames": 41},
460
- {"label": "1088x704, 49 frames", "width": 1088, "height": 704, "num_frames": 49},
461
- {"label": "1056x640, 57 frames", "width": 1056, "height": 640, "num_frames": 57},
462
- {"label": "448x448, 100 frames", "width": 448, "height": 448, "num_frames": 100},
463
- {"label": "448x448, 200 frames", "width": 448, "height": 448, "num_frames": 200},
464
- {"label": "448x448, 300 frames", "width": 448, "height": 448, "num_frames": 300},
465
- {"label": "640x640, 80 frames", "width": 640, "height": 640, "num_frames": 80},
466
- {"label": "640x640, 120 frames", "width": 640, "height": 640, "num_frames": 120},
467
- {"label": "768x768, 64 frames", "width": 768, "height": 768, "num_frames": 64},
468
- {"label": "768x768, 90 frames", "width": 768, "height": 768, "num_frames": 90},
469
- {"label": "720x720, 64 frames", "width": 768, "height": 768, "num_frames": 64},
470
- {"label": "720x720, 100 frames", "width": 768, "height": 768, "num_frames": 100},
471
- {"label": "768x512, 97 frames", "width": 768, "height": 512, "num_frames": 97},
472
- {"label": "512x512, 160 frames", "width": 512, "height": 512, "num_frames": 160},
473
- {"label": "512x512, 200 frames", "width": 512, "height": 512, "num_frames": 200},
474
- ]
475
-
476
- def preset_changed(preset):
477
- if preset != "Custom":
478
- selected = next(item for item in preset_options if item["label"] == preset)
479
- return (
480
- selected["height"],
481
- selected["width"],
482
- selected["num_frames"],
483
- gr.update(visible=False),
484
- gr.update(visible=False),
485
- gr.update(visible=False),
486
- )
487
- else:
488
- return (
489
- None,
490
- None,
491
- None,
492
- gr.update(visible=True),
493
- gr.update(visible=True),
494
- gr.update(visible=True),
495
- )
496
-
497
- # Load models
498
- vae = load_vae(vae_dir)
499
- unet = load_unet(unet_dir)
500
- scheduler = load_scheduler(scheduler_dir)
501
- patchifier = SymmetricPatchifier(patch_size=1)
502
- text_encoder = T5EncoderModel.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder").to(torch.device("cuda:0"))
503
- tokenizer = T5Tokenizer.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer")
504
-
505
- pipeline = XoraVideoPipeline(
506
- transformer=unet,
507
- patchifier=patchifier,
508
- text_encoder=text_encoder,
509
- tokenizer=tokenizer,
510
- scheduler=scheduler,
511
- vae=vae,
512
- ).to(torch.device("cuda:0"))
513
-
514
- def enhance_prompt_if_enabled(prompt, enhance_toggle):
515
- if not enhance_toggle:
516
- print("Enhance toggle is off, Prompt: ", prompt)
517
- return prompt
518
-
519
- messages = [
520
- {"role": "system", "content": system_prompt_t2v},
521
- {"role": "user", "content": prompt},
522
- ]
523
 
 
 
524
  try:
525
- response = client.chat.completions.create(
526
- model="gpt-4-mini",
527
- messages=messages,
528
- max_tokens=200,
529
- )
530
- print("Enhanced Prompt: ", response.choices[0].message.content.strip())
531
- return response.choices[0].message.content.strip()
532
- except Exception as e:
533
- print(f"Error: {e}")
534
- return prompt
535
 
536
- @spaces.GPU(duration=90)
537
- def generate_video_from_text_90(
538
- prompt="",
539
- enhance_prompt_toggle=False,
540
- negative_prompt="",
541
- frame_rate=25,
542
- seed=random.randint(0, MAX_SEED),
543
- num_inference_steps=30,
544
- guidance_scale=3.2,
545
- height=768,
546
- width=768,
547
- num_frames=60,
548
- progress=gr.Progress(),
549
- ):
550
- # Preprocess the prompt (Korean -> English)
551
- prompt = process_prompt(prompt)
552
- negative_prompt = process_prompt(negative_prompt)
553
 
554
- if len(prompt.strip()) < 50:
555
- raise gr.Error(
556
- "Prompt must be at least 50 characters long. Please provide more details for the best results.",
557
- duration=5,
558
- )
559
 
560
- prompt = enhance_prompt_if_enabled(prompt, enhance_prompt_toggle)
561
 
562
- sample = {
563
- "prompt": prompt,
564
- "prompt_attention_mask": None,
565
- "negative_prompt": negative_prompt,
566
- "negative_prompt_attention_mask": None,
567
- "media_items": None,
568
- }
569
 
570
- generator = torch.Generator(device="cuda").manual_seed(seed)
571
 
572
- def gradio_progress_callback(self, step, timestep, kwargs):
573
- progress((step + 1) / num_inference_steps)
 
574
 
575
- try:
576
- with torch.no_grad():
577
- images = pipeline(
578
- num_inference_steps=num_inference_steps,
579
- num_images_per_prompt=1,
580
- guidance_scale=guidance_scale,
581
- generator=generator,
582
- output_type="pt",
583
- height=height,
584
- width=width,
585
- num_frames=num_frames,
586
- frame_rate=frame_rate,
587
- **sample,
588
- is_video=True,
589
- vae_per_channel_normalize=True,
590
- conditioning_method=ConditioningMethod.UNCONDITIONAL,
591
- mixed_precision=True,
592
- callback_on_step_end=gradio_progress_callback,
593
- ).images
594
  except Exception as e:
595
- raise gr.Error(
596
- f"An error occurred while generating the video. Please try again. Error: {e}",
597
- duration=5,
598
- )
599
- finally:
600
- torch.cuda.empty_cache()
601
- gc.collect()
602
-
603
- output_path = tempfile.mktemp(suffix=".mp4")
604
- video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
605
- video_np = (video_np * 255).astype(np.uint8)
606
- height, width = video_np.shape[1:3]
607
- out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), frame_rate, (width, height))
608
- for frame in video_np[..., ::-1]:
609
- out.write(frame)
610
- out.release()
611
- del images
612
- del video_np
613
- torch.cuda.empty_cache()
614
- return output_path
615
 
616
- def create_advanced_options():
617
- with gr.Accordion("Step 4: Advanced Options (Optional)", open=False):
618
- seed = gr.Slider(label="4.1 Seed", minimum=0, maximum=1000000, step=1, value=646373)
619
- inference_steps = gr.Slider(label="4.2 Inference Steps", minimum=5, maximum=150, step=5, value=40)
620
- guidance_scale = gr.Slider(label="4.3 Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=4.2)
621
 
622
- height_slider = gr.Slider(
623
- label="4.4 Height",
624
- minimum=256,
625
- maximum=1024,
626
- step=64,
627
- value=768,
628
- visible=False,
629
- )
630
- width_slider = gr.Slider(
631
- label="4.5 Width",
632
- minimum=256,
633
- maximum=1024,
634
- step=64,
635
- value=768,
636
- visible=False,
637
- )
638
- num_frames_slider = gr.Slider(
639
- label="4.5 Number of Frames",
640
- minimum=1,
641
- maximum=500,
642
- step=1,
643
- value=60,
644
- visible=False,
645
  )
646
 
647
- return [
648
- seed,
649
- inference_steps,
650
- guidance_scale,
651
- height_slider,
652
- width_slider,
653
- num_frames_slider,
654
- ]
655
 
656
  # CSS style definitions
657
  css = """
658
  [์ด์ „์˜ CSS ์ฝ”๋“œ๋ฅผ ๊ทธ๋Œ€๋กœ ์œ ์ง€]
659
  """
660
 
 
 
661
 
662
 
663
- # Create the Gradio interface
664
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
665
  gr.HTML('<div class="title">AI Image & Video Generator</div>')
666
 
667
  with gr.Tabs():
668
- # First tab: Image Generation
669
  with gr.Tab("Image Generation"):
670
  with gr.Row():
671
  with gr.Column(scale=3):
@@ -708,9 +635,6 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
708
  value=3.5
709
  )
710
 
711
- def get_random_seed():
712
- return torch.randint(0, 1000000, (1,)).item()
713
-
714
  seed = gr.Number(
715
  label="Seed",
716
  value=get_random_seed(),
@@ -741,7 +665,6 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
741
  )
742
  img_gallery.value = load_gallery()
743
 
744
- # Second tab: Video Generation
745
  with gr.Tab("Video Generation"):
746
  with gr.Row():
747
  with gr.Column(scale=3):
@@ -770,53 +693,124 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
770
  object_fit="cover"
771
  )
772
 
773
- # Third tab: AI Video Generation
774
- with gr.Tab("AI Video Generation"):
775
- with gr.Column():
776
- txt2vid_prompt = gr.Textbox(
777
- label="Step 1: Enter Your Prompt (ํ•œ๊ธ€ ๋˜๋Š” ์˜์–ด)",
778
- placeholder="์ƒ์„ฑํ•˜๊ณ  ์‹ถ์€ ๋น„๋””์˜ค๋ฅผ ์„ค๋ช…ํ•˜์„ธ์š” (์ตœ์†Œ 50์ž)...",
779
- value="๊ธด ๊ฐˆ์ƒ‰ ๋จธ๋ฆฌ์™€ ๋ฐ์€ ํ”ผ๋ถ€๋ฅผ ๊ฐ€์ง„ ์—ฌ์„ฑ์ด ๊ธด ๊ธˆ๋ฐœ ๋จธ๋ฆฌ๋ฅผ ๊ฐ€์ง„ ๋‹ค๋ฅธ ์—ฌ์„ฑ์„ ํ–ฅํ•ด ๋ฏธ์†Œ ์ง“์Šต๋‹ˆ๋‹ค. ๊ฐˆ์ƒ‰ ๋จธ๋ฆฌ ์—ฌ์„ฑ์€ ๊ฒ€์€ ์žฌํ‚ท์„ ์ž…๊ณ  ์žˆ์œผ๋ฉฐ ์˜ค๋ฅธ์ชฝ ๋บจ์— ์ž‘๊ณ  ๊ฑฐ์˜ ๋ˆˆ์— ๋„์ง€ ์•Š๋Š” ์ ์ด ์žˆ์Šต๋‹ˆ๋‹ค. ์นด๋ฉ”๋ผ ์•ต๊ธ€์€ ๊ฐˆ์ƒ‰ ๋จธ๋ฆฌ ์—ฌ์„ฑ์˜ ์–ผ๊ตด์— ์ดˆ์ ์„ ๋งž์ถ˜ ํด๋กœ์ฆˆ์—…์ž…๋‹ˆ๋‹ค. ์กฐ๋ช…์€ ๋”ฐ๋œปํ•˜๊ณ  ์ž์—ฐ์Šค๋Ÿฌ์šฐ๋ฉฐ, ์•„๋งˆ๋„ ์ง€๋Š” ํ•ด์—์„œ ๋‚˜์˜ค๋Š” ๊ฒƒ ๊ฐ™์•„ ์žฅ๋ฉด์— ๋ถ€๋“œ๋Ÿฌ์šด ๋น›์„ ๋น„์ถฅ๋‹ˆ๋‹ค.",
780
- lines=5,
781
- )
782
-
783
- txt2vid_enhance_toggle = Toggle(
784
- label="Enhance Prompt",
785
- value=False,
786
- interactive=True,
787
- )
788
-
789
- txt2vid_negative_prompt = gr.Textbox(
790
- label="Step 2: Enter Negative Prompt",
791
- placeholder="๋น„๋””์˜ค์—์„œ ์›ํ•˜์ง€ ์•Š๋Š” ์š”์†Œ๋ฅผ ์„ค๋ช…ํ•˜์„ธ์š”...",
792
- value="low quality, worst quality, deformed, distorted, damaged, motion blur, motion artifacts, fused fingers, incorrect anatomy, strange hands, ugly",
793
- lines=2,
794
- )
795
-
796
- txt2vid_preset = gr.Dropdown(
797
- choices=[p["label"] for p in preset_options],
798
- value="512x512, 160 frames",
799
- label="Step 3.1: Choose Resolution Preset",
800
- )
801
-
802
- txt2vid_frame_rate = gr.Slider(
803
- label="Step 3.2: Frame Rate",
804
- minimum=6,
805
- maximum=60,
806
- step=1,
807
- value=20,
808
- )
809
-
810
- txt2vid_advanced = create_advanced_options()
811
- txt2vid_generate = gr.Button(
812
- "Step 5: Generate Video",
813
- variant="primary",
814
- size="lg",
815
- )
816
-
817
- txt2vid_output = gr.Video(label="Generated Output")
818
-
819
- # Connect event handlers
820
  generate_btn.click(
821
  process_and_save_image,
822
  inputs=[height, width, steps, scales, img_prompt, seed],
@@ -839,26 +833,5 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
839
  outputs=[seed]
840
  )
841
 
842
- txt2vid_preset.change(
843
- fn=preset_changed,
844
- inputs=[txt2vid_preset],
845
- outputs=txt2vid_advanced[3:],
846
- )
847
-
848
- txt2vid_generate.click(
849
- fn=generate_video_from_text_90,
850
- inputs=[
851
- txt2vid_prompt,
852
- txt2vid_enhance_toggle,
853
- txt2vid_negative_prompt,
854
- txt2vid_frame_rate,
855
- *txt2vid_advanced,
856
- ],
857
- outputs=txt2vid_output,
858
- concurrency_limit=1,
859
- concurrency_id="generate_video",
860
- queue=True,
861
- )
862
-
863
  if __name__ == "__main__":
864
- demo.launch(allowed_paths=[PERSISTENT_DIR])
 
25
  import csv
26
  from datetime import datetime
27
  from openai import OpenAI
 
 
28
 
29
+ # Initialize the Korean-English translator
30
+ translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
31
+
32
+ torch.backends.cuda.matmul.allow_tf32 = False
33
+ torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
34
+ torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
35
+ torch.backends.cudnn.allow_tf32 = False
36
+ torch.backends.cudnn.deterministic = False
37
+ torch.backends.cuda.preferred_blas_library="cublas"
38
+ torch.set_float32_matmul_precision("highest")
39
+
40
+ MAX_SEED = np.iinfo(np.int32).max
41
+
42
+ # Load Hugging Face token if needed
43
+ hf_token = os.getenv("HF_TOKEN")
44
+ openai_api_key = os.getenv("OPENAI_API_KEY")
45
+ client = OpenAI(api_key=openai_api_key)
46
+
47
+ system_prompt_t2v_path = "assets/system_prompt_t2v.txt"
48
+ with open(system_prompt_t2v_path, "r") as f:
49
+ system_prompt_t2v = f.read()
50
+
51
+ # Set model download directory within Hugging Face Spaces
52
+ model_path = "asset"
53
+
54
+ commit_hash='c7c8ad4c2ddba847b94e8bfaefbd30bd8669fafc'
55
+
56
+ if not os.path.exists(model_path):
57
+ snapshot_download("Lightricks/LTX-Video", revision=commit_hash, local_dir=model_path, repo_type="model", token=hf_token)
58
+
59
+ # Global variables to load components
60
+ vae_dir = Path(model_path) / "vae"
61
+ unet_dir = Path(model_path) / "unet"
62
+ scheduler_dir = Path(model_path) / "scheduler"
63
+
64
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
65
+
66
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", cache_dir=model_path).to(torch.device("cuda:0"))
67
+ clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", cache_dir=model_path)
68
+
69
+ def process_prompt(prompt):
70
+ # ํ•œ๊ธ€์ด ํฌํ•จ๋˜์–ด ์žˆ๋Š”์ง€ ํ™•์ธ
71
+ if any(ord('๊ฐ€') <= ord(char) <= ord('ํžฃ') for char in prompt):
72
+ # ํ•œ๊ธ€์„ ์˜์–ด๋กœ ๋ฒˆ์—ญ
73
+ translated = translator(prompt)[0]['translation_text']
74
+ return translated
75
+ return prompt
76
+
77
+ def compute_clip_embedding(text=None):
78
+ inputs = clip_processor(text=text, return_tensors="pt", padding=True).to(device)
79
+ outputs = clip_model.get_text_features(**inputs)
80
+ embedding = outputs.detach().cpu().numpy().flatten().tolist()
81
+ return embedding
82
+
83
+ def load_vae(vae_dir):
84
+ vae_ckpt_path = vae_dir / "vae_diffusion_pytorch_model.safetensors"
85
+ vae_config_path = vae_dir / "config.json"
86
+ with open(vae_config_path, "r") as f:
87
+ vae_config = json.load(f)
88
+ vae = CausalVideoAutoencoder.from_config(vae_config)
89
+ vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
90
+ vae.load_state_dict(vae_state_dict)
91
+ return vae.to(device).to(torch.bfloat16)
92
+
93
+ def load_unet(unet_dir):
94
+ unet_ckpt_path = unet_dir / "unet_diffusion_pytorch_model.safetensors"
95
+ unet_config_path = unet_dir / "config.json"
96
+ transformer_config = Transformer3DModel.load_config(unet_config_path)
97
+ transformer = Transformer3DModel.from_config(transformer_config)
98
+ unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
99
+ transformer.load_state_dict(unet_state_dict, strict=True)
100
+ return transformer.to(device).to(torch.bfloat16)
101
+
102
+ def load_scheduler(scheduler_dir):
103
+ scheduler_config_path = scheduler_dir / "scheduler_config.json"
104
+ scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
105
+ return RectifiedFlowScheduler.from_config(scheduler_config)
106
+
107
+ # Preset options for resolution and frame configuration
108
+ preset_options = [
109
+ {"label": "1216x704, 41 frames", "width": 1216, "height": 704, "num_frames": 41},
110
+ {"label": "1088x704, 49 frames", "width": 1088, "height": 704, "num_frames": 49},
111
+ {"label": "1056x640, 57 frames", "width": 1056, "height": 640, "num_frames": 57},
112
+ {"label": "448x448, 100 frames", "width": 448, "height": 448, "num_frames": 100},
113
+ {"label": "448x448, 200 frames", "width": 448, "height": 448, "num_frames": 200},
114
+ {"label": "448x448, 300 frames", "width": 448, "height": 448, "num_frames": 300},
115
+ {"label": "640x640, 80 frames", "width": 640, "height": 640, "num_frames": 80},
116
+ {"label": "640x640, 120 frames", "width": 640, "height": 640, "num_frames": 120},
117
+ {"label": "768x768, 64 frames", "width": 768, "height": 768, "num_frames": 64},
118
+ {"label": "768x768, 90 frames", "width": 768, "height": 768, "num_frames": 90},
119
+ {"label": "720x720, 64 frames", "width": 768, "height": 768, "num_frames": 64},
120
+ {"label": "720x720, 100 frames", "width": 768, "height": 768, "num_frames": 100},
121
+ {"label": "768x512, 97 frames", "width": 768, "height": 512, "num_frames": 97},
122
+ {"label": "512x512, 160 frames", "width": 512, "height": 512, "num_frames": 160},
123
+ {"label": "512x512, 200 frames", "width": 512, "height": 512, "num_frames": 200},
124
+ ]
125
+
126
+ def preset_changed(preset):
127
+ if preset != "Custom":
128
+ selected = next(item for item in preset_options if item["label"] == preset)
129
+ return (
130
+ selected["height"],
131
+ selected["width"],
132
+ selected["num_frames"],
133
+ gr.update(visible=False),
134
+ gr.update(visible=False),
135
+ gr.update(visible=False),
136
+ )
137
+ else:
138
+ return (
139
+ None,
140
+ None,
141
+ None,
142
+ gr.update(visible=True),
143
+ gr.update(visible=True),
144
+ gr.update(visible=True),
145
+ )
146
+
147
+ # Load models
148
+ vae = load_vae(vae_dir)
149
+ unet = load_unet(unet_dir)
150
+ scheduler = load_scheduler(scheduler_dir)
151
+ patchifier = SymmetricPatchifier(patch_size=1)
152
+ text_encoder = T5EncoderModel.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder").to(torch.device("cuda:0"))
153
+ tokenizer = T5Tokenizer.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer")
154
+
155
+ pipeline_video = XoraVideoPipeline(
156
+ transformer=unet,
157
+ patchifier=patchifier,
158
+ text_encoder=text_encoder,
159
+ tokenizer=tokenizer,
160
+ scheduler=scheduler,
161
+ vae=vae,
162
+ ).to(torch.device("cuda:0"))
163
+
164
+ def enhance_prompt_if_enabled(prompt, enhance_toggle):
165
+ if not enhance_toggle:
166
+ print("Enhance toggle is off, Prompt: ", prompt)
167
+ return prompt
168
+
169
+ messages = [
170
+ {"role": "system", "content": system_prompt_t2v},
171
+ {"role": "user", "content": prompt},
172
+ ]
173
+
174
+ try:
175
+ response = client.chat.completions.create(
176
+ model="gpt-4-mini",
177
+ messages=messages,
178
+ max_tokens=200,
179
+ )
180
+ print("Enhanced Prompt: ", response.choices[0].message.content.strip())
181
+ return response.choices[0].message.content.strip()
182
+ except Exception as e:
183
+ print(f"Error: {e}")
184
+ return prompt
185
+
186
+ @spaces.GPU(duration=90)
187
+ def generate_video_from_text_90(
188
+ prompt="",
189
+ enhance_prompt_toggle=False,
190
+ negative_prompt="",
191
+ frame_rate=25,
192
+ seed=random.randint(0, MAX_SEED),
193
+ num_inference_steps=30,
194
+ guidance_scale=3.2,
195
+ height=768,
196
+ width=768,
197
+ num_frames=60,
198
+ progress=gr.Progress(),
199
+ ):
200
+ # Preprocess the prompt (Korean -> English)
201
+ prompt = process_prompt(prompt)
202
+ negative_prompt = process_prompt(negative_prompt)
203
+
204
+ if len(prompt.strip()) < 50:
205
+ raise gr.Error(
206
+ "Prompt must be at least 50 characters long. Please provide more details for the best results.",
207
+ duration=5,
208
+ )
209
+
210
+ prompt = enhance_prompt_if_enabled(prompt, enhance_prompt_toggle)
211
+
212
+ sample = {
213
+ "prompt": prompt,
214
+ "prompt_attention_mask": None,
215
+ "negative_prompt": negative_prompt,
216
+ "negative_prompt_attention_mask": None,
217
+ "media_items": None,
218
+ }
219
+
220
+ generator = torch.Generator(device="cuda").manual_seed(seed)
221
+
222
+ def gradio_progress_callback(self, step, timestep, kwargs):
223
+ progress((step + 1) / num_inference_steps)
224
+
225
+ try:
226
+ with torch.no_grad():
227
+ images = pipeline_video(
228
+ num_inference_steps=num_inference_steps,
229
+ num_images_per_prompt=1,
230
+ guidance_scale=guidance_scale,
231
+ generator=generator,
232
+ output_type="pt",
233
+ height=height,
234
+ width=width,
235
+ num_frames=num_frames,
236
+ frame_rate=frame_rate,
237
+ **sample,
238
+ is_video=True,
239
+ vae_per_channel_normalize=True,
240
+ conditioning_method=ConditioningMethod.UNCONDITIONAL,
241
+ mixed_precision=True,
242
+ callback_on_step_end=gradio_progress_callback,
243
+ ).images
244
+ except Exception as e:
245
+ raise gr.Error(
246
+ f"An error occurred while generating the video. Please try again. Error: {e}",
247
+ duration=5,
248
+ )
249
+ finally:
250
+ torch.cuda.empty_cache()
251
+ gc.collect()
252
+
253
+ output_path = tempfile.mktemp(suffix=".mp4")
254
+ video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
255
+ video_np = (video_np * 255).astype(np.uint8)
256
+ height, width = video_np.shape[1:3]
257
+ out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), frame_rate, (width, height))
258
+ for frame in video_np[..., ::-1]:
259
+ out.write(frame)
260
+ out.release()
261
+ del images
262
+ del video_np
263
+ torch.cuda.empty_cache()
264
+ return output_path
265
+
266
+ def create_advanced_options():
267
+ with gr.Accordion("Step 4: Advanced Options (Optional)", open=False):
268
+ seed = gr.Slider(label="4.1 Seed", minimum=0, maximum=1000000, step=1, value=646373)
269
+ inference_steps = gr.Slider(label="4.2 Inference Steps", minimum=5, maximum=150, step=5, value=40)
270
+ guidance_scale = gr.Slider(label="4.3 Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=4.2)
271
+
272
+ height_slider = gr.Slider(
273
+ label="4.4 Height",
274
+ minimum=256,
275
+ maximum=1024,
276
+ step=64,
277
+ value=768,
278
+ visible=False,
279
+ )
280
+ width_slider = gr.Slider(
281
+ label="4.5 Width",
282
+ minimum=256,
283
+ maximum=1024,
284
+ step=64,
285
+ value=768,
286
+ visible=False,
287
+ )
288
+ num_frames_slider = gr.Slider(
289
+ label="4.5 Number of Frames",
290
+ minimum=1,
291
+ maximum=500,
292
+ step=1,
293
+ value=60,
294
+ visible=False,
295
+ )
296
+
297
+ return [
298
+ seed,
299
+ inference_steps,
300
+ guidance_scale,
301
+ height_slider,
302
+ width_slider,
303
+ num_frames_slider,
304
+ ]
305
+
306
+ ###############################################
307
+ # From here on, the second codebase is integrated
308
+ ###############################################
309
+
310
+ import argparse
311
  import time
312
  from os import path
313
  import shutil
 
314
  from safetensors.torch import load_file
315
  from diffusers import FluxPipeline
316
  from diffusers.pipelines.stable_diffusion import safety_checker
 
 
317
  import replicate
318
  import logging
319
  import requests
320
  from pathlib import Path
 
 
321
  import sys
322
  import io
323
 
324
+ # Logging setup
325
  logging.basicConfig(level=logging.INFO)
326
  logger = logging.getLogger(__name__)
327
 
 
343
  # CUDA settings
344
  torch.backends.cuda.matmul.allow_tf32 = True
345
 
346
+ # Translator initialization (translator is already declared above; this is a duplicate declaration)
347
+ translator2 = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en") # Also declared in the second codebase; kept so nothing is dropped in the merge.
348
+
349
  # Create directories
350
  for dir_path in [gallery_path, video_gallery_path]:
351
  if not path.exists(dir_path):
 
363
  def translate_if_korean(text):
364
  """Translate to English when the text contains Korean"""
365
  if any(ord(char) >= 0xAC00 and ord(char) <= 0xD7A3 for char in text):
366
+ translation = translator2(text)[0]['translation_text']
367
  return translation
368
  return text
369
 
 
381
  return False, "๋ถ€์ ์ ˆํ•œ ๋‚ด์šฉ์ด ํฌํ•จ๋œ ํ”„๋กฌํ”„ํŠธ์ž…๋‹ˆ๋‹ค."
382
  return True, prompt
383
 
384
+ def process_prompt_for_sd(prompt):
385
  """Preprocess the prompt (translation and filtering)"""
386
+ translated_prompt = translate_if_korean(prompt)
387
+ is_safe, filtered_prompt = filter_prompt(translated_prompt)
388
+ return is_safe, filtered_prompt
389
 
390
  class timer:
391
  def __init__(self, method_name="timed process"):
 
397
  end = time.time()
398
  print(f"{self.method} took {str(round(end - self.start, 2))}s")
399
 
400
+ # Model initialization
401
+ if not path.exists(cache_path):
402
+ os.makedirs(cache_path, exist_ok=True)
403
+
404
+ pipe_sd = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
405
+ pipe_sd.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
406
+ pipe_sd.fuse_lora(lora_scale=0.125)
407
+ pipe_sd.to(device="cuda", dtype=torch.bfloat16)
408
+ pipe_sd.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
409
 
410
  def upload_to_catbox(image_path):
411
  """Upload an image using the catbox.moe API"""
 
457
  font_scale = height * 0.05 / 30
458
  thickness = 2
459
  color = (255, 255, 255)
460
+
461
+ (text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, thickness)
462
+ margin = int(height * 0.02)
463
+ x_pos = width - text_width - margin
464
+ y_pos = height - margin
465
+
466
+ output_path = "watermarked_output.mp4"
467
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
468
+ out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
469
+
470
+ while cap.isOpened():
471
+ ret, frame = cap.read()
472
+ if not ret:
473
+ break
474
+ cv2.putText(frame, text, (x_pos, y_pos), font, font_scale, color, thickness)
475
+ out.write(frame)
476
+
477
+ cap.release()
478
+ out.release()
479
+
480
+ return output_path
481
+
482
+ except Exception as e:
483
+ logger.error(f"Error adding watermark: {str(e)}")
484
+ return video_path
485
 
486
+ def generate_video(image, prompt):
487
+ logger.info("Starting video generation")
488
  try:
489
+ if not check_api_key():
490
+ return "Replicate API key not properly configured"
491
 
492
+ if not image:
493
+ logger.error("No image provided")
494
+ return "Please upload an image"
495
 
496
+ image_url = upload_to_catbox(image)
497
+ if not image_url:
498
+ return "Failed to upload image"
 
 
499
 
500
+ input_data = {
501
+ "prompt": prompt,
502
+ "first_frame_image": image_url
503
+ }
504
 
505
+ try:
506
+ replicate.Client(api_token=REPLICATE_API_TOKEN)
507
+ output = replicate.run(
508
+ "minimax/video-01-live",
509
+ input=input_data
510
+ )
 
511
 
512
+ temp_file = "temp_output.mp4"
513
+
514
+ if hasattr(output, 'read'):
515
+ with open(temp_file, "wb") as file:
516
+ file.write(output.read())
517
+ elif isinstance(output, str):
518
+ response = requests.get(output)
519
+ with open(temp_file, "wb") as file:
520
+ file.write(response.content)
521
+
522
+ final_video = add_watermark(temp_file)
523
+ return final_video
524
 
525
+ except Exception as api_error:
526
+ logger.error(f"API call failed: {str(api_error)}")
527
+ return f"API call failed: {str(api_error)}"
528
 
529
  except Exception as e:
530
+ logger.error(f"Unexpected error: {str(e)}")
531
+ return f"Unexpected error: {str(e)}"
532
 
533
+ def save_image(image):
534
+ """Save the generated image in PNG format and return the path"""
535
+ try:
536
+ if not os.path.exists(gallery_path):
537
+ os.makedirs(gallery_path, exist_ok=True)
538
 
539
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
540
+ random_suffix = os.urandom(4).hex()
541
+ filename = f"generated_{timestamp}_{random_suffix}.png"
542
+ filepath = os.path.join(gallery_path, filename)
543
+
544
+ if not isinstance(image, Image.Image):
545
+ image = Image.fromarray(image)
546
+
547
+ if image.mode != 'RGB':
548
+ image = image.convert('RGB')
549
+
550
+ image.save(
551
+ filepath,
552
+ format='PNG',
553
+ optimize=True,
554
+ quality=100
555
  )
556
+
557
+ logger.info(f"Image saved successfully as PNG: {filepath}")
558
+ return filepath
559
+ except Exception as e:
560
+ logger.error(f"Error in save_image: {str(e)}")
561
+ return None
562
 
563
+ def load_gallery():
564
+ """Load all images from the gallery directory"""
565
+ try:
566
+ os.makedirs(gallery_path, exist_ok=True)
567
+
568
+ image_files = []
569
+ for f in os.listdir(gallery_path):
570
+ if f.lower().endswith(('.png', '.jpg', '.jpeg')):
571
+ full_path = os.path.join(gallery_path, f)
572
+ image_files.append((full_path, os.path.getmtime(full_path)))
573
+
574
+ image_files.sort(key=lambda x: x[1], reverse=True)
575
+ return [f[0] for f in image_files]
576
+ except Exception as e:
577
+ print(f"Error loading gallery: {str(e)}")
578
+ return []
579
 
580
  # CSS style definitions
581
  css = """
582
  [์ด์ „์˜ CSS ์ฝ”๋“œ๋ฅผ ๊ทธ๋Œ€๋กœ ์œ ์ง€]
583
  """
584
 
585
+ def get_random_seed():
586
+ return torch.randint(0, 1000000, (1,)).item()
587
 
588
+ ###############################################
589
+ # Gradio UI integration starts here
590
+ ###############################################
591
 
 
592
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
593
  gr.HTML('<div class="title">AI Image & Video Generator</div>')
594
 
595
  with gr.Tabs():
 
596
  with gr.Tab("Image Generation"):
597
  with gr.Row():
598
  with gr.Column(scale=3):
 
635
  value=3.5
636
  )
637
 
 
 
 
638
  seed = gr.Number(
639
  label="Seed",
640
  value=get_random_seed(),
 
665
  )
666
  img_gallery.value = load_gallery()
667
 
 
668
  with gr.Tab("Video Generation"):
669
  with gr.Row():
670
  with gr.Column(scale=3):
 
693
  object_fit="cover"
694
  )
695
 
696
+ # The txt2vid-related UI from the first codebase is integrated below
697
+ # The first codebase's txt2vid UI is added as an extra tab
698
+ with gr.Tab("Text-to-Video Generation"):
699
+ with gr.Column():
700
+ txt2vid_prompt = gr.Textbox(
701
+ label="Step 1: Enter Your Prompt (ํ•œ๊ธ€ ๋˜๋Š” ์˜์–ด)",
702
+ placeholder="์ƒ์„ฑํ•˜๊ณ  ์‹ถ์€ ๋น„๋””์˜ค๋ฅผ ์„ค๋ช…ํ•˜์„ธ์š” (์ตœ์†Œ 50์ž)...",
703
+ value="๊ธด ๊ฐˆ์ƒ‰ ๋จธ๋ฆฌ์™€ ๋ฐ์€ ํ”ผ๋ถ€๋ฅผ ๊ฐ€์ง„ ์—ฌ์„ฑ์ด ๊ธด ๊ธˆ๋ฐœ ๋จธ๋ฆฌ๋ฅผ ๊ฐ€์ง„ ๋‹ค๋ฅธ ์—ฌ์„ฑ์„ ํ–ฅํ•ด ๋ฏธ์†Œ ์ง“์Šต๋‹ˆ๋‹ค. ๊ฐˆ์ƒ‰ ๋จธ๋ฆฌ ์—ฌ์„ฑ์€ ๊ฒ€์€ ์žฌํ‚ท์„ ์ž…๊ณ  ์žˆ์œผ๋ฉฐ ์˜ค๋ฅธ์ชฝ ๋บจ์— ์ž‘๊ณ  ๊ฑฐ์˜ ๋ˆˆ์— ๋„์ง€ ์•Š๋Š” ์ ์ด ์žˆ์Šต๋‹ˆ๋‹ค. ์นด๋ฉ”๋ผ ์•ต๊ธ€์€ ๊ฐˆ์ƒ‰ ๋จธ๋ฆฌ ์—ฌ์„ฑ์˜ ์–ผ๊ตด์— ์ดˆ์ ์„ ๋งž์ถ˜ ํด๋กœ์ฆˆ์—…์ž…๋‹ˆ๋‹ค. ์กฐ๋ช…์€ ๋”ฐ๋œปํ•˜๊ณ  ์ž์—ฐ์Šค๋Ÿฌ์šฐ๋ฉฐ, ์•„๋งˆ๋„ ์ง€๋Š” ํ•ด์—์„œ ๋‚˜์˜ค๋Š” ๊ฒƒ ๊ฐ™์•„ ์žฅ๋ฉด์— ๋ถ€๋“œ๋Ÿฌ์šด ๋น›์„ ๋น„์ถฅ๋‹ˆ๋‹ค.",
704
+ lines=5,
705
+ )
706
+
707
+ txt2vid_enhance_toggle = Toggle(
708
+ label="Enhance Prompt",
709
+ value=False,
710
+ interactive=True,
711
+ )
712
+
713
+ txt2vid_negative_prompt = gr.Textbox(
714
+ label="Step 2: Enter Negative Prompt",
715
+ placeholder="๋น„๋””์˜ค์—์„œ ์›ํ•˜์ง€ ์•Š๋Š” ์š”์†Œ๋ฅผ ์„ค๋ช…ํ•˜์„ธ์š”...",
716
+ value="low quality, worst quality, deformed, distorted, damaged, motion blur, motion artifacts, fused fingers, incorrect anatomy, strange hands, ugly",
717
+ lines=2,
718
+ )
719
+
720
+ txt2vid_preset = gr.Dropdown(
721
+ choices=[p["label"] for p in preset_options],
722
+ value="512x512, 160 frames",
723
+ label="Step 3.1: Choose Resolution Preset",
724
+ )
725
+
726
+ txt2vid_frame_rate = gr.Slider(
727
+ label="Step 3.2: Frame Rate",
728
+ minimum=6,
729
+ maximum=60,
730
+ step=1,
731
+ value=20,
732
+ )
733
+
734
+ txt2vid_advanced = create_advanced_options()
735
+ txt2vid_generate = gr.Button(
736
+ "Step 5: Generate Video",
737
+ variant="primary",
738
+ size="lg",
739
+ )
740
+
741
+ txt2vid_output = gr.Video(label="Generated Output")
742
+
743
+ txt2vid_preset.change(
744
+ fn=preset_changed,
745
+ inputs=[txt2vid_preset],
746
+ outputs=txt2vid_advanced[3:],
747
+ )
748
+
749
+ txt2vid_generate.click(
750
+ fn=generate_video_from_text_90,
751
+ inputs=[
752
+ txt2vid_prompt,
753
+ txt2vid_enhance_toggle,
754
+ txt2vid_negative_prompt,
755
+ txt2vid_frame_rate,
756
+ *txt2vid_advanced,
757
+ ],
758
+ outputs=txt2vid_output,
759
+ concurrency_limit=1,
760
+ concurrency_id="generate_video",
761
+ queue=True,
762
+ )
763
+
764
+ @spaces.GPU
765
+ def process_and_save_image(height, width, steps, scales, prompt, seed):
766
+ is_safe, translated_prompt = process_prompt_for_sd(prompt)
767
+ if not is_safe:
768
+ gr.Warning("๋ถ€์ ์ ˆํ•œ ๋‚ด์šฉ์ด ํฌํ•จ๋œ ํ”„๋กฌํ”„ํŠธ์ž…๋‹ˆ๋‹ค.")
769
+ return None, load_gallery()
770
+
771
+ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
772
+ try:
773
+ generated_image = pipe_sd(
774
+ prompt=[translated_prompt],
775
+ generator=torch.Generator().manual_seed(int(seed)),
776
+ num_inference_steps=int(steps),
777
+ guidance_scale=float(scales),
778
+ height=int(height),
779
+ width=int(width),
780
+ max_sequence_length=256
781
+ ).images[0]
782
+
783
+ if not isinstance(generated_image, Image.Image):
784
+ generated_image = Image.fromarray(generated_image)
785
+
786
+ if generated_image.mode != 'RGB':
787
+ generated_image = generated_image.convert('RGB')
788
+
789
+ img_byte_arr = io.BytesIO()
790
+ generated_image.save(img_byte_arr, format='PNG')
791
+ img_byte_arr = img_byte_arr.getvalue()
792
+
793
+ saved_path = save_image(generated_image)
794
+ if saved_path is None:
795
+ logger.warning("Failed to save generated image")
796
+ return None, load_gallery()
797
+
798
+ return Image.open(io.BytesIO(img_byte_arr)), load_gallery()
799
+ except Exception as e:
800
+ logger.error(f"Error in image generation: {str(e)}")
801
+ return None, load_gallery()
802
+
803
+
804
+ def process_and_generate_video(image, prompt):
805
+ is_safe, translated_prompt = process_prompt_for_sd(prompt)
806
+ if not is_safe:
807
+ gr.Warning("๋ถ€์ ์ ˆํ•œ ๋‚ด์šฉ์ด ํฌํ•จ๋œ ํ”„๋กฌํ”„ํŠธ์ž…๋‹ˆ๋‹ค.")
808
+ return None
809
+ return generate_video(image, translated_prompt)
810
+
811
+ def update_seed():
812
+ return get_random_seed()
813
+
814
  generate_btn.click(
815
  process_and_save_image,
816
  inputs=[height, width, steps, scales, img_prompt, seed],
 
833
  outputs=[seed]
834
  )
835
 
 
836
  if __name__ == "__main__":
837
+ demo.queue(max_size=64, default_concurrency_limit=1, api_open=False).launch(share=True, show_api=False, allowed_paths=[PERSISTENT_DIR])