Prof-Hunt committed
Commit 4eeb369 · verified · 1 Parent(s): a33f1ad

Create app.py

Files changed (1): app.py +540 -0
app.py ADDED
@@ -0,0 +1,540 @@
+ import gradio as gr
+ from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
+ from transformers import AutoProcessor, AutoModelForVision2Seq, AutoModelForCausalLM, AutoTokenizer
+ import torch
+ from PIL import Image, ImageDraw, ImageFont
+ import numpy as np
+ import textwrap
+ import os
+ import gc
+ import re
+ from datetime import datetime
+ import spaces
+ from kokoro import KPipeline
+ import soundfile as sf
+
+ # Initialize models at startup - outside of functions
+ print("Loading models...")
+
+ # Load SmolVLM for image analysis
+ processor_vlm = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-500M-Instruct")
+ model_vlm = AutoModelForVision2Seq.from_pretrained(
+     "HuggingFaceTB/SmolVLM-500M-Instruct",
+     torch_dtype=torch.bfloat16,
+     use_safetensors=True
+ )
+
+ # Load SmolLM2 for story and prompt generation
+ checkpoint = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
+ tokenizer_lm = AutoTokenizer.from_pretrained(checkpoint)
+ model_lm = AutoModelForCausalLM.from_pretrained(
+     checkpoint,
+     use_safetensors=True
+ )
+
+ # Load Stable Diffusion pipeline
+ pipe = StableDiffusionPipeline.from_pretrained(
+     "runwayml/stable-diffusion-v1-5",
+     torch_dtype=torch.float16,
+     use_safetensors=True
+ )
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+
+ # Move models to GPU if available
+ if torch.cuda.is_available():
+     model_vlm = model_vlm.to("cuda")
+     model_lm = model_lm.to("cuda")
+     pipe = pipe.to("cuda")
+
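+ # The functions below run under Hugging Face ZeroGPU: @spaces.GPU(duration=N)
+ # requests a GPU slot for up to N seconds, and each call empties the CUDA
+ # cache on entry so earlier steps don't hold on to VRAM.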
+ @torch.inference_mode()
+ @spaces.GPU(duration=30)
+ def generate_image():
+     """Generate a random landscape image."""
+     torch.cuda.empty_cache()
+
+     default_prompt = "a beautiful, professional landscape photograph"
+     default_negative_prompt = "blurry, bad quality, distorted, deformed"
+     default_steps = 30
+     default_guidance = 7.5
+     default_seed = torch.randint(0, 2**32 - 1, (1,)).item()
+
+     generator = torch.Generator("cuda").manual_seed(default_seed)
+
+     image = pipe(
+         prompt=default_prompt,
+         negative_prompt=default_negative_prompt,
+         num_inference_steps=default_steps,
+         guidance_scale=default_guidance,
+         generator=generator,
+     ).images[0]
+
+     return image
+
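+ # Caption the generated landscape with SmolVLM: the chat template wraps an
+ # image placeholder plus the instruction, and the decoded reply is trimmed
+ # to the text after the "Assistant:" marker.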
+ @torch.inference_mode()
+ @spaces.GPU(duration=30)
+ def analyze_image(image):
+     if image is None:
+         return "Please generate an image first."
+
+     torch.cuda.empty_cache()
+
+     if isinstance(image, np.ndarray):
+         image = Image.fromarray(image)
+
+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "image"},
+                 {"type": "text", "text": "Describe this image very briefly in five sentences or less."}
+             ]
+         }
+     ]
+
+     prompt = processor_vlm.apply_chat_template(messages, add_generation_prompt=True)
+
+     inputs = processor_vlm(
+         text=prompt,
+         images=[image],
+         return_tensors="pt"
+     ).to('cuda')
+
+     outputs = model_vlm.generate(
+         input_ids=inputs.input_ids,
+         pixel_values=inputs.pixel_values,
+         attention_mask=inputs.attention_mask,
+         num_return_sequences=1,
+         no_repeat_ngram_size=2,
+         max_new_tokens=500,
+         min_new_tokens=10
+     )
+
+     description = processor_vlm.decode(outputs[0], skip_special_tokens=True)
+     description = re.sub(r".*?Assistant:\s*", "", description, flags=re.DOTALL).strip()
+
+     return description
+
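+ # Turn the caption into a one-chapter children's story about Champ the
+ # bulldog using SmolLM2-1.7B-Instruct, then strip the echoed prompt and
+ # chat markup with clean_story_output().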
+ @torch.inference_mode()
+ @spaces.GPU(duration=30)
+ def generate_story(image_description):
+     torch.cuda.empty_cache()
+
+     story_prompt = f"""Write a short children's story (one chapter, about 500 words) based on this scene: {image_description}
+
+ Requirements:
+ 1. Main character: An English bulldog named Champ
+ 2. Include these values: confidence, teamwork, caring, and hope
+ 3. Theme: "We are stronger together than as individuals"
+ 4. Keep it simple and engaging for young children
+ 5. End with a simple moral lesson"""
+
+     messages = [{"role": "user", "content": story_prompt}]
+     input_text = tokenizer_lm.apply_chat_template(messages, tokenize=False)
+
+     inputs = tokenizer_lm.encode(input_text, return_tensors="pt").to("cuda")
+
+     outputs = model_lm.generate(
+         inputs,
+         max_new_tokens=750,
+         temperature=0.7,
+         top_p=0.9,
+         do_sample=True,
+         repetition_penalty=1.2
+     )
+
+     story = tokenizer_lm.decode(outputs[0])
+     story = clean_story_output(story)
+
+     return story
+
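+ # For each story paragraph, ask SmolLM2 for a one-line "Watercolor bulldog ..."
+ # scene prompt. Sections are joined with a 50-character '=' divider, which the
+ # downstream parsing in generate_all_scenes and add_text_to_scenes relies on.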
+ @torch.inference_mode()
+ @spaces.GPU(duration=30)
+ def generate_image_prompts(story_text):
+     torch.cuda.empty_cache()
+     paragraphs = split_into_paragraphs(story_text)
+
+     all_prompts = []
+     prompt_instruction = '''Here is a story paragraph: {paragraph}
+
+ Start your response with "Watercolor bulldog" and describe what Champ is doing in this scene. Add where it takes place and one mood detail. Keep it short.'''
+
+     for i, paragraph in enumerate(paragraphs, 1):
+         messages = [{"role": "user", "content": prompt_instruction.format(paragraph=paragraph)}]
+         input_text = tokenizer_lm.apply_chat_template(messages, tokenize=False)
+
+         inputs = tokenizer_lm.encode(input_text, return_tensors="pt").to("cuda")
+
+         outputs = model_lm.generate(
+             inputs,
+             max_new_tokens=30,
+             temperature=0.5,
+             top_p=0.9,
+             do_sample=True,
+             repetition_penalty=1.2
+         )
+
+         prompt = process_generated_prompt(tokenizer_lm.decode(outputs[0]), paragraph)
+         section = f"Paragraph {i}:\n{paragraph}\n\nScenery Prompt {i}:\n{prompt}\n\n{'='*50}"
+         all_prompts.append(section)
+
+     return '\n'.join(all_prompts)
+
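+ # Render one scene with the bulldog LoRA on top of SD 1.5. Note the LoRA is
+ # (re)loaded on every call; hoisting pipe.load_lora_weights() to startup would
+ # avoid repeating that work per scene, but the original per-call flow is kept.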
+ @torch.inference_mode()
+ @spaces.GPU(duration=60)
+ def generate_story_image(prompt):
+     torch.cuda.empty_cache()
+
+     pipe.load_lora_weights("Prof-Hunt/lora-bulldog")
+     enhanced_prompt = f"{prompt}, watercolor style, children's book illustration, soft colors"
+
+     image = pipe(
+         prompt=enhanced_prompt,
+         negative_prompt="deformed, ugly, blurry, bad art, poor quality, distorted",
+         num_inference_steps=50,
+         guidance_scale=15,
+     ).images[0]
+
+     return image
+
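+ # Walk the '='-delimited sections from generate_image_prompts, pull the line
+ # after each "Scenery Prompt N:" header, and render it. Failures are logged
+ # and skipped so one bad scene doesn't abort the whole batch.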
+ @torch.inference_mode()
+ @spaces.GPU(duration=180)  # Longer duration for multiple image generation
+ def generate_all_scenes(prompts_text):
+     generated_images = []
+     formatted_prompts = []
+
+     sections = prompts_text.split('='*50)
+
+     for section in sections:
+         if not section.strip():
+             continue
+
+         lines = [line.strip() for line in section.split('\n') if line.strip()]
+
+         scene_prompt = None
+         for i, line in enumerate(lines):
+             if 'Scenery Prompt' in line:
+                 scene_num = line.split('Scenery Prompt')[1].split(':')[0].strip()
+                 if i + 1 < len(lines):
+                     scene_prompt = lines[i + 1]
+                     formatted_prompts.append(f"Scene {scene_num}: {scene_prompt}")
+                 break
+
+         if scene_prompt:
+             try:
+                 torch.cuda.empty_cache()
+                 image = generate_story_image(scene_prompt)
+                 if image is not None:
+                     generated_images.append(np.array(image))
+             except Exception as e:
+                 print(f"Error generating image: {str(e)}")
+                 continue
+
+     return generated_images, "\n\n".join(formatted_prompts)
+
+ # Helper functions without GPU usage
+ def clean_story_output(story):
+     story = story.replace("<|im_end|>", "")
+
+     story_start = story.find("Once upon")
+     if story_start == -1:
+         possible_starts = ["One day", "In a", "There was", "Champ"]
+         for marker in possible_starts:
+             story_start = story.find(marker)
+             if story_start != -1:
+                 break
+
+     if story_start != -1:
+         story = story[story_start:]
+
+     lines = story.split('\n')
+     cleaned_lines = []
+     for line in lines:
+         line = line.strip()
+         if line and not any(skip in line.lower() for skip in ['requirement', 'include these values', 'theme:', 'keep it simple', 'end with', 'write a']):
+             if not line.startswith(('1.', '2.', '3.', '4.', '5.')):
+                 cleaned_lines.append(line)
+
+     return '\n\n'.join(cleaned_lines).strip()
+
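+ # Split text on blank lines, then drop any paragraph that still echoes the
+ # story-prompt instructions (the same skip list used in clean_story_output).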
+ def split_into_paragraphs(text):
+     paragraphs = []
+     current_paragraph = []
+
+     for line in text.split('\n'):
+         line = line.strip()
+         if not line:
+             if current_paragraph:
+                 paragraphs.append(' '.join(current_paragraph))
+                 current_paragraph = []
+         else:
+             current_paragraph.append(line)
+
+     if current_paragraph:
+         paragraphs.append(' '.join(current_paragraph))
+
+     return [p for p in paragraphs if not any(skip in p.lower()
+             for skip in ['requirement', 'include these values', 'theme:',
+                          'keep it simple', 'end with', 'write a'])]
+
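+ # Clean a raw SmolLM2 completion down to a single "Watercolor bulldog ..."
+ # line; if the model ignored the requested format, fall back to a prompt
+ # built from simple keyword checks on the source paragraph.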
+ def process_generated_prompt(prompt, paragraph):
+     prompt = prompt.replace("<|im_start|>", "").replace("<|im_end|>", "")
+     prompt = prompt.replace("assistant", "").replace("system", "").replace("user", "")
+
+     cleaned_lines = [line.strip() for line in prompt.split('\n')
+                      if line.strip().lower().startswith("watercolor bulldog")]
+
+     if cleaned_lines:
+         prompt = cleaned_lines[0]
+     else:
+         setting = "quiet town" if "quiet town" in paragraph.lower() else "park"
+         mood = "hopeful" if "wished" in paragraph.lower() else "peaceful"
+         prompt = f"Watercolor bulldog watching friends play in {setting}, {mood} atmosphere."
+
+     if not prompt.endswith('.'):
+         prompt = prompt + '.'
+
+     return prompt
+
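+ # Overlay the paragraph near the top of the page. The white "outline" is
+ # drawn by stamping the text at small x/y offsets before the final black
+ # pass, which keeps it readable over light and dark regions alike.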
+ def overlay_text_on_image(image, text):
+     if isinstance(image, np.ndarray):
+         image = Image.fromarray(image)
+
+     img = image.convert('RGB')
+     draw = ImageDraw.Draw(img)
+
+     try:
+         font_size = int(img.width * 0.025)
+         font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", font_size)
+     except OSError:  # font file not available; fall back to PIL's default
+         font = ImageFont.load_default()
+
+     y_position = int(img.height * 0.005)
+     x_margin = int(img.width * 0.005)
+     available_width = img.width - (2 * x_margin)
+
+     wrapped_text = textwrap.fill(text, width=int(available_width / (font_size * 0.6)))
+
+     outline_color = (255, 255, 255)
+     text_color = (0, 0, 0)
+     offsets = [-2, -1, 1, 2]
+
+     for dx in offsets:
+         for dy in offsets:
+             draw.multiline_text(
+                 (x_margin + dx, y_position + dy),
+                 wrapped_text,
+                 font=font,
+                 fill=outline_color
+             )
+
+     draw.multiline_text(
+         (x_margin, y_position),
+         wrapped_text,
+         font=font,
+         fill=text_color
+     )
+
+     return img
+
+ # Initialize Kokoro TTS pipeline
+ pipeline = KPipeline(lang_code='a')  # 'a' for American English
+
+ def generate_combined_audio_from_story(story_text, voice='af_heart', speed=1):
+     """Generate a single audio file for all paragraphs in the story."""
+     if not story_text:
+         return None
+
+     # Split story into paragraphs
+     paragraphs = []
+     current_paragraph = []
+
+     for line in story_text.split('\n'):
+         line = line.strip()
+         if not line:  # Empty line indicates paragraph break
+             if current_paragraph:
+                 paragraphs.append(' '.join(current_paragraph))
+                 current_paragraph = []
+         else:
+             current_paragraph.append(line)
+
+     if current_paragraph:
+         paragraphs.append(' '.join(current_paragraph))
+
+     # Combine audio for all paragraphs
+     combined_audio = []
+     for paragraph in paragraphs:
+         if not paragraph.strip():
+             continue  # Skip empty paragraphs
+
+         generator = pipeline(
+             paragraph,
+             voice=voice,
+             speed=speed,
+             split_pattern=r'\n+'  # Split on newlines
+         )
+         for _, _, audio in generator:
+             combined_audio.extend(audio)  # Append audio data
+
+     # Convert combined audio to NumPy array and save
+     combined_audio = np.array(combined_audio)
+     filename = "combined_story.wav"
+     sf.write(filename, combined_audio, 24000)  # Save audio as .wav
+     return filename
+
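+ # Pair each generated scene with its source paragraph (parsed back out of the
+ # prompts text) and overlay the text to make book pages. Note: depending on
+ # the Gradio version, gr.Gallery may hand back (image, caption) tuples rather
+ # than raw arrays, in which case the image would need unpacking first.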
+ def add_text_to_scenes(gallery_images, prompts_text):
+     if not isinstance(gallery_images, list):
+         return [], []
+
+     sections = prompts_text.split('='*50)
+     overlaid_images = []
+     output_files = []
+
+     temp_dir = "temp_book_pages"
+     os.makedirs(temp_dir, exist_ok=True)
+
+     for i, (image_data, section) in enumerate(zip(gallery_images, sections)):
+         if not section.strip():
+             continue
+
+         lines = [line.strip() for line in section.split('\n') if line.strip()]
+         paragraph = None
+         for j, line in enumerate(lines):
+             if line.startswith('Paragraph'):
+                 if j + 1 < len(lines):
+                     paragraph = lines[j + 1]
+                 break
+
+         if paragraph and image_data is not None:
+             try:
+                 overlaid_img = overlay_text_on_image(image_data, paragraph)
+                 if overlaid_img is not None:
+                     overlaid_array = np.array(overlaid_img)
+                     overlaid_images.append(overlaid_array)
+
+                     output_path = os.path.join(temp_dir, f"panel_{i+1}.png")
+                     overlaid_img.save(output_path)
+                     output_files.append(output_path)
+             except Exception as e:
+                 print(f"Error processing image: {str(e)}")
+                 continue
+
+     return overlaid_images, output_files
+
+ def create_interface():
+     theme = gr.themes.Soft(
+         primary_hue="sky",  # Gradio has no "lightblue" hue name; "sky" is the closest valid one
+         secondary_hue="red",
+         neutral_hue="gray"
+     ).set(
+         button_primary_background_fill="rgb(173, 216, 230)",  # light blue
+         button_secondary_background_fill="rgb(255, 182, 193)",  # light red
+         button_primary_background_fill_dark="rgb(135, 206, 235)",  # slightly darker blue for hover
+         button_secondary_background_fill_dark="rgb(255, 160, 180)",  # slightly darker red for hover
+     )
+
+     with gr.Blocks(theme=theme) as demo:
+         gr.Markdown("# Tech Tales: Story Creation")
+
+         with gr.Row():
+             generate_btn = gr.Button("1. Generate Random Landscape")
+             image_output = gr.Image(label="Generated Image", type="pil")
+
+         with gr.Row():
+             analyze_btn = gr.Button("2. Get Brief Description")
+             analysis_output = gr.Textbox(label="Image Description", lines=3)
+
+         with gr.Row():
+             story_btn = gr.Button("3. Create Children's Story")
+             story_output = gr.Textbox(label="Generated Story", lines=10)
+
+         with gr.Row():
+             prompts_btn = gr.Button("4. Generate Scene Prompts")
+             prompts_output = gr.Textbox(label="Generated Scene Prompts", lines=20)
+
+         with gr.Row():
+             generate_scenes_btn = gr.Button("5. Generate Story Scenes", variant="primary")
+
+         with gr.Row():
+             scene_prompts_display = gr.Textbox(
+                 label="Scenes Being Generated",
+                 lines=8,
+                 interactive=False
+             )
+
+         with gr.Row():
+             gallery = gr.Gallery(
+                 label="Story Scenes",
+                 show_label=True,
+                 columns=2,
+                 height="auto"
+             )
+
+         with gr.Row():
+             add_text_btn = gr.Button("6. Add Text to Scenes", variant="primary")
+
+         with gr.Row():
+             final_gallery = gr.Gallery(
+                 label="Story Book Pages",
+                 show_label=True,
+                 columns=2,
+                 height="auto"
+             )
+
+         with gr.Row():
+             download_btn = gr.File(
+                 label="Download Story Book",
+                 file_count="multiple",
+                 interactive=False
+             )
+
+         with gr.Row():
+             tts_btn = gr.Button("7. Read Story Aloud")
+             audio_output = gr.Audio(label="Story Audio")
+
+         # Event handlers
+         generate_btn.click(
+             fn=generate_image,
+             outputs=image_output
+         )
+
+         analyze_btn.click(
+             fn=analyze_image,
+             inputs=[image_output],
+             outputs=analysis_output
+         )
+
+         story_btn.click(
+             fn=generate_story,
+             inputs=[analysis_output],
+             outputs=story_output
+         )
+
+         prompts_btn.click(
+             fn=generate_image_prompts,
+             inputs=[story_output],
+             outputs=prompts_output
+         )
+
+         generate_scenes_btn.click(
+             fn=generate_all_scenes,
+             inputs=[prompts_output],
+             outputs=[gallery, scene_prompts_display]
+         )
+
+         add_text_btn.click(
+             fn=add_text_to_scenes,
+             inputs=[gallery, prompts_output],
+             outputs=[final_gallery, download_btn]
+         )
+
+         tts_btn.click(
+             fn=generate_combined_audio_from_story,
+             inputs=[story_output],
+             outputs=audio_output
+         )
+
+     return demo
+
+ if __name__ == "__main__":
+     demo = create_interface()
+     demo.launch()