import gradio as gr
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from transformers import AutoProcessor, AutoModelForVision2Seq, AutoModelForCausalLM, AutoTokenizer
import torch
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import textwrap
import os
import gc
import re
import tempfile
import psutil
from datetime import datetime
import spaces
from kokoro import KPipeline
import soundfile as sf


def clear_memory():
    """Run the garbage collector, release cached CUDA memory, and log memory usage.

    All ``torch.cuda`` calls are guarded by ``torch.cuda.is_available()`` so this
    is also safe to call on a machine without CUDA.
    """
    gc.collect()
    process = psutil.Process(os.getpid())
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        print(f"GPU Memory allocated: {torch.cuda.memory_allocated()/1024**2:.2f} MB")
        print(f"GPU Memory cached: {torch.cuda.memory_reserved()/1024**2:.2f} MB")
    print(f"CPU RAM used: {process.memory_info().rss/1024**2:.2f} MB")


# Initialize models at startup - only the lightweight ones
print("Loading models...")

# Load SmolVLM for image analysis
processor_vlm = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-500M-Instruct")
model_vlm = AutoModelForVision2Seq.from_pretrained(
    "HuggingFaceTB/SmolVLM-500M-Instruct",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Load SmolLM2 for story and prompt generation
checkpoint = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
tokenizer_lm = AutoTokenizer.from_pretrained(checkpoint)
model_lm = AutoModelForCausalLM.from_pretrained(checkpoint).to("cuda")

# Initialize Kokoro TTS pipeline
pipeline = KPipeline(lang_code='a')  # 'a' for American English

# Sample rate (Hz) used when writing the Kokoro audio to disk.
TTS_SAMPLE_RATE = 24000

# Line that separates per-paragraph sections in the text built by generate_image_prompts.
SECTION_SEPARATOR = '=' * 50

# Lowercase substrings that mark echoed prompt/instruction text rather than story prose.
_PROMPT_ECHO_MARKERS = (
    'requirement',
    'include these values',
    'theme:',
    'keep it simple',
    'end with',
    'write a',
)


def load_sd_model():
    """Load a fresh Stable Diffusion v1.5 pipeline on CUDA (fp16, DPM-Solver++ scheduler).

    Called per request rather than at startup to keep it out of memory when idle.
    """
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float16,
    )
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.to("cuda")
    pipe.enable_attention_slicing()
    return pipe


def _generate_text(user_prompt, **generate_kwargs):
    """Run SmolLM2 on a single user message and return only the newly generated text.

    Args:
        user_prompt: content of the single user turn.
        **generate_kwargs: forwarded to ``model_lm.generate`` (sampling settings etc.).
    """
    messages = [{"role": "user", "content": user_prompt}]
    # add_generation_prompt=True appends the assistant header so the model
    # answers the request instead of continuing the user turn.
    input_text = tokenizer_lm.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    input_ids = tokenizer_lm.encode(input_text, return_tensors="pt").to("cuda")
    outputs = model_lm.generate(input_ids, **generate_kwargs)
    # Slice off the prompt tokens so the instructions (and the scene text they
    # embed) never leak into what callers parse.
    new_tokens = outputs[0][input_ids.shape[-1]:]
    return tokenizer_lm.decode(new_tokens, skip_special_tokens=True)


@torch.inference_mode()
@spaces.GPU(duration=30)
def generate_image():
    """Generate a random landscape image with Stable Diffusion.

    Returns:
        The generated PIL image, or None if loading or generation failed.
    """
    clear_memory()
    pipe = None
    try:
        pipe = load_sd_model()
        seed = torch.randint(0, 2**32 - 1, (1,)).item()
        generator = torch.Generator("cuda").manual_seed(seed)
        return pipe(
            prompt="a beautiful, professional landscape photograph",
            negative_prompt="blurry, bad quality, distorted, deformed",
            num_inference_steps=30,
            guidance_scale=7.5,
            generator=generator,
        ).images[0]
    except Exception as e:
        print(f"Error generating image: {e}")
        return None
    finally:
        # Drop the pipeline reference before clearing memory so its weights can be freed.
        del pipe
        clear_memory()


@torch.inference_mode()
@spaces.GPU(duration=30)
def analyze_image(image):
    """Describe *image* in at most three sentences using SmolVLM.

    Args:
        image: PIL image or numpy array; None yields a prompt to generate one first.

    Returns:
        The description, or a user-facing error message.
    """
    if image is None:
        return "Please generate an image first."

    clear_memory()

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "Describe this image and Be brief but descriptive."},
            ],
        }
    ]

    try:
        prompt = processor_vlm.apply_chat_template(messages, add_generation_prompt=True)
        inputs = processor_vlm(
            text=prompt,
            images=[image],
            return_tensors="pt",
        ).to('cuda')

        outputs = model_vlm.generate(
            input_ids=inputs.input_ids,
            pixel_values=inputs.pixel_values,
            attention_mask=inputs.attention_mask,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            max_new_tokens=500,
            min_new_tokens=10,
        )

        description = processor_vlm.decode(outputs[0], skip_special_tokens=True)
        # The decoded sequence still contains the chat prompt; keep what follows "Assistant:".
        description = re.sub(r".*?Assistant:\s*", "", description, flags=re.DOTALL).strip()

        # Split into sentences and take only the first three
        sentences = re.split(r'(?<=[.!?])\s+', description)
        return ' '.join(sentences[:3])
    except Exception as e:
        print(f"Error analyzing image: {e}")
        return "Error analyzing image. Please try again."
    finally:
        clear_memory()


@torch.inference_mode()
@spaces.GPU(duration=30)
def generate_story(image_description):
    """Write a short children's story about Champ the bulldog set in the described scene.

    Returns:
        The cleaned story text (paragraphs separated by blank lines), or an error message.
    """
    clear_memory()

    story_prompt = f"""Write a short children's story (one chapter, about 500 words) based on this scene: {image_description}

Requirements:
1. Main character: An English bulldog named Champ
2. Include these values: confidence, teamwork, caring, and hope
3. Theme: "We are stronger together than as individuals"
4. Keep it simple and engaging for young children
5. End with a simple moral lesson"""

    try:
        raw_story = _generate_text(
            story_prompt,
            max_new_tokens=750,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            repetition_penalty=1.2,
        )
        return clean_story_output(raw_story)
    except Exception as e:
        print(f"Error generating story: {e}")
        return "Error generating story. Please try again."
    finally:
        clear_memory()


@torch.inference_mode()
@spaces.GPU(duration=30)
def generate_image_prompts(story_text):
    """Produce one watercolor scene prompt per story paragraph.

    Returns:
        Sections of the form "Paragraph i:\\n...\\n\\nScenery Prompt i:\\n...\\n\\n<separator>"
        joined by newlines (this is the format generate_all_scenes and
        add_text_to_scenes parse), or an error message.
    """
    clear_memory()

    paragraphs = split_into_paragraphs(story_text or "")

    prompt_instruction = '''Here is a story paragraph: {paragraph}

Start your response with "Watercolor bulldog" and describe what Champ is doing in this scene. Add where it takes place and one mood detail. Keep it short.'''

    try:
        all_prompts = []
        for i, paragraph in enumerate(paragraphs, 1):
            raw_prompt = _generate_text(
                prompt_instruction.format(paragraph=paragraph),
                max_new_tokens=30,
                temperature=0.5,
                top_p=0.9,
                do_sample=True,
                repetition_penalty=1.2,
            )
            prompt = process_generated_prompt(raw_prompt, paragraph)
            all_prompts.append(
                f"Paragraph {i}:\n{paragraph}\n\nScenery Prompt {i}:\n{prompt}\n\n{SECTION_SEPARATOR}"
            )
        return '\n'.join(all_prompts)
    except Exception as e:
        print(f"Error generating prompts: {e}")
        return "Error generating prompts. Please try again."
    finally:
        clear_memory()


@torch.inference_mode()
@spaces.GPU(duration=60)
def generate_story_image(prompt, seed=-1):
    """Render one watercolor story illustration with the bulldog LoRA applied.

    Args:
        prompt: scene description; style keywords are appended to it.
        seed: RNG seed, or -1 to pick a random one.

    Returns:
        The generated PIL image, or None if loading or generation failed.
    """
    clear_memory()
    pipe = None
    try:
        pipe = load_sd_model()
        pipe.load_lora_weights("Prof-Hunt/lora-bulldog")

        if seed == -1:
            seed = torch.randint(0, 2**32 - 1, (1,)).item()
        generator = torch.Generator("cuda").manual_seed(seed)

        enhanced_prompt = f"{prompt}, watercolor style, children's book illustration, soft colors"

        return pipe(
            prompt=enhanced_prompt,
            negative_prompt="deformed, ugly, blurry, bad art, poor quality, distorted",
            num_inference_steps=50,
            guidance_scale=15,
            generator=generator,
        ).images[0]
    except Exception as e:
        print(f"Error generating image: {e}")
        return None
    finally:
        # load_sd_model() builds a new pipeline on every call, so the LoRA does
        # not need unloading; dropping the reference is enough.
        del pipe
        clear_memory()


def _extract_scene_prompt(section):
    """Return ``(scene_number, prompt)`` from one generate_image_prompts section.

    ``prompt`` is the stripped line right after the "Scenery Prompt N:" header,
    or None when the section has no such header or nothing follows it.
    """
    lines = section.split('\n')
    for idx, line in enumerate(lines):
        if line.startswith('Scenery Prompt'):
            scene_num = line.partition('Scenery Prompt')[2].split(':')[0].strip()
            if idx + 1 < len(lines):
                return scene_num, lines[idx + 1].strip()
            return scene_num, None
    return None, None


@torch.inference_mode()
@spaces.GPU(duration=180)
def generate_all_scenes(prompts_text):
    """Generate one illustration per scene prompt found in *prompts_text*.

    Returns:
        ``(images, summary)``: a list of HxWx3 RGB numpy arrays for the scenes
        that succeeded, and a "Scene N: prompt" summary of every prompt found.
    """
    clear_memory()

    generated_images = []
    formatted_prompts = []

    for section in (prompts_text or "").split(SECTION_SEPARATOR):
        if not section.strip():
            continue

        scene_num, scene_prompt = _extract_scene_prompt(section)
        if scene_prompt is None:
            continue
        formatted_prompts.append(f"Scene {scene_num}: {scene_prompt}")
        if not scene_prompt:
            continue

        try:
            clear_memory()
            print(f"Generating image for scene: {scene_prompt}")
            image = generate_story_image(scene_prompt)

            if image is None:
                print(f"Failed to generate image for scene {scene_num}")
            else:
                # Normalize to an RGB numpy array, which the gallery displays directly.
                pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)
                img_array = np.array(pil_image.convert('RGB'))
                if img_array.ndim == 3 and img_array.shape[2] == 3:
                    generated_images.append(img_array)
                    print(f"Successfully added image for scene {scene_num}")
                else:
                    print(f"Invalid image array shape: {img_array.shape}")

            clear_memory()
        except Exception as e:
            print(f"Error generating image for scene {scene_num}: {str(e)}")
            clear_memory()

    prompts_summary = "\n\n".join(formatted_prompts)
    if not generated_images:
        print("No images were successfully generated")
        return [], prompts_summary

    print(f"Successfully generated {len(generated_images)} images")
    return generated_images, prompts_summary


def overlay_text_on_image(image, text):
    """Draw *text* in the top-left corner of a copy of *image*.

    The text is black with a white outline and wrapped to fit the image width.

    Returns:
        A new RGB PIL image, or None if *image* is None or drawing fails.
    """
    if image is None:
        return None

    try:
        # convert() returns a copy, so the caller's image is left untouched.
        img = image.convert('RGB')
        draw = ImageDraw.Draw(img)

        # Clamp to >= 1 so very small images don't produce a zero font size
        # (and a division by zero in the wrap-width estimate below).
        font_size = max(1, int(img.width * 0.025))
        try:
            font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", font_size)
        except (OSError, ImportError):
            # Font file (or FreeType support) unavailable; use Pillow's built-in font.
            font = ImageFont.load_default()

        y_position = int(img.height * 0.005)
        x_margin = int(img.width * 0.005)

        available_width = img.width - (2 * x_margin)
        # Rough characters-per-line estimate assuming a glyph is ~0.6x the font size wide;
        # textwrap requires a positive width.
        chars_per_line = max(1, int(available_width / (font_size * 0.6)))
        wrapped_text = textwrap.fill(text, width=chars_per_line)

        outline_color = (255, 255, 255)
        text_color = (0, 0, 0)

        # Outline: draw the text offset in several directions before the main pass.
        offsets = [-2, -1, 1, 2]
        for dx in offsets:
            for dy in offsets:
                draw.multiline_text(
                    (x_margin + dx, y_position + dy),
                    wrapped_text,
                    font=font,
                    fill=outline_color,
                )

        draw.multiline_text(
            (x_margin, y_position),
            wrapped_text,
            font=font,
            fill=text_color,
        )

        return img
    except Exception as e:
        print(f"Error overlaying text: {e}")
        return None


def _gallery_item_to_pil(item):
    """Convert one ``gr.Gallery`` input entry to a PIL image, or None if unrecognized.

    Depending on the Gradio version and the gallery's ``type``, entries may
    arrive as ``(image, caption)`` tuples whose image is a file path, numpy
    array, or PIL image; bare values of those types are accepted too.
    """
    if isinstance(item, (tuple, list)):
        item = item[0] if item else None
    if isinstance(item, Image.Image):
        return item
    if isinstance(item, np.ndarray):
        return Image.fromarray(item)
    if isinstance(item, (str, os.PathLike)):
        with Image.open(item) as opened:
            return opened.convert('RGB')
    return None


def _extract_paragraph(section):
    """Return the first non-blank line after a "Paragraph N:" header in *section*, or None."""
    lines = [line.strip() for line in section.split('\n') if line.strip()]
    for j, line in enumerate(lines):
        if line.startswith('Paragraph'):
            return lines[j + 1] if j + 1 < len(lines) else None
    return None


def add_text_to_scenes(gallery_images, prompts_text):
    """Overlay each scene's story paragraph on its image and save the pages as PNGs.

    Args:
        gallery_images: value of the scenes gallery (a list; see _gallery_item_to_pil).
        prompts_text: text produced by generate_image_prompts.

    Returns:
        ``(page_arrays, page_paths)``: numpy arrays for the output gallery and the
        saved PNG file paths for download.
    """
    if not isinstance(gallery_images, list):
        return [], []

    clear_memory()

    # Drop blank sections (e.g. after the trailing separator) before pairing
    # with images, so they don't consume an image in zip().
    sections = [s for s in (prompts_text or "").split(SECTION_SEPARATOR) if s.strip()]
    overlaid_images = []
    output_files = []

    # A fresh directory per call so concurrent users don't overwrite each other's pages.
    temp_dir = tempfile.mkdtemp(prefix="temp_book_pages_")

    for i, (image_data, section) in enumerate(zip(gallery_images, sections)):
        paragraph = _extract_paragraph(section)
        if not paragraph or image_data is None:
            continue

        try:
            image = _gallery_item_to_pil(image_data)
            overlaid_img = overlay_text_on_image(image, paragraph)
            if overlaid_img is not None:
                overlaid_images.append(np.array(overlaid_img))
                output_path = os.path.join(temp_dir, f"panel_{i+1}.png")
                overlaid_img.save(output_path)
                output_files.append(output_path)
        except Exception as e:
            print(f"Error processing image: {str(e)}")

    clear_memory()
    return overlaid_images, output_files


def generate_combined_audio_from_story(story_text, voice='af_heart', speed=1):
    """Narrate *story_text* with Kokoro and save it as a single WAV file.

    Args:
        story_text: story with paragraphs separated by blank lines.
        voice: Kokoro voice name.
        speed: speaking-rate multiplier passed to Kokoro.

    Returns:
        Path to the WAV file, or None if there is no text, no audio was
        produced, or synthesis failed.
    """
    clear_memory()

    if not story_text:
        return None

    paragraphs = split_into_paragraphs(story_text)
    audio_chunks = []

    try:
        for paragraph in paragraphs:
            if not paragraph.strip():
                continue

            generator = pipeline(
                paragraph,
                voice=voice,
                speed=speed,
                split_pattern=r'\n+',
            )

            for _, _, audio in generator:
                if audio is None:
                    continue
                # Collect whole chunks; extending a Python list sample-by-sample
                # is very slow for tensors.
                if isinstance(audio, torch.Tensor):
                    audio = audio.detach().cpu().numpy()
                audio_chunks.append(np.asarray(audio).reshape(-1))

        if not audio_chunks:
            print("No audio was generated")
            return None

        combined_audio = np.concatenate(audio_chunks)
        # Unique directory per call so concurrent requests can't overwrite each other.
        filename = os.path.join(tempfile.mkdtemp(prefix="story_audio_"), "combined_story.wav")
        sf.write(filename, combined_audio, TTS_SAMPLE_RATE)  # Save audio as .wav
        return filename
    except Exception as e:
        print(f"Error generating audio: {e}")
        return None
    finally:
        clear_memory()


# Helper functions
def clean_story_output(story):
    """Clean up generated story text.

    Starts the text at "Once upon" (or, failing that, the first of a few other
    likely openings), drops blank lines, lines echoing the prompt's
    requirements, and numbered-list lines, then joins the remaining lines with
    blank lines between them.
    """
    story = story.replace("<|im_end|>", "")

    story_start = story.find("Once upon")
    if story_start == -1:
        possible_starts = ["One day", "In a", "There was", "Champ"]
        for marker in possible_starts:
            story_start = story.find(marker)
            if story_start != -1:
                break

    if story_start != -1:
        story = story[story_start:]

    cleaned_lines = []
    for line in story.split('\n'):
        line = line.strip()
        if not line:
            continue
        if any(marker in line.lower() for marker in _PROMPT_ECHO_MARKERS):
            continue
        if line.startswith(('1.', '2.', '3.', '4.', '5.')):
            continue
        cleaned_lines.append(line)

    return '\n\n'.join(cleaned_lines).strip()


def split_into_paragraphs(text):
    """Split *text* on blank lines into paragraphs, each joined onto a single line.

    Paragraphs that contain echoed prompt instructions are dropped.
    """
    paragraphs = []
    current_paragraph = []

    for line in text.split('\n'):
        line = line.strip()
        if not line:
            if current_paragraph:
                paragraphs.append(' '.join(current_paragraph))
                current_paragraph = []
        else:
            current_paragraph.append(line)

    if current_paragraph:
        paragraphs.append(' '.join(current_paragraph))

    return [p for p in paragraphs if not any(marker in p.lower() for marker in _PROMPT_ECHO_MARKERS)]


def process_generated_prompt(prompt, paragraph):
    """Pick the "Watercolor bulldog ..." line from model output, with a fallback.

    If no line starts with "Watercolor bulldog" (case-insensitive), a generic
    prompt is built from keywords in *paragraph*. The result always ends with '.'.
    """
    prompt = prompt.replace("<|im_start|>", "").replace("<|im_end|>", "")
    # Chat role headers, if present, sit on their own lines and never match the
    # "watercolor bulldog" prefix, so they need no special handling. They are NOT
    # removed with str.replace: that would also mangle words such as "ecosystem".
    cleaned_lines = [
        line.strip()
        for line in prompt.split('\n')
        if line.strip().lower().startswith("watercolor bulldog")
    ]

    if cleaned_lines:
        prompt = cleaned_lines[0]
    else:
        setting = "quiet town" if "quiet town" in paragraph.lower() else "park"
        mood = "hopeful" if "wished" in paragraph.lower() else "peaceful"
        prompt = f"Watercolor bulldog watching friends play in {setting}, {mood} atmosphere."

    if not prompt.endswith('.'):
        prompt = prompt + '.'

    return prompt


# Create the interface
def create_interface():
    """Build the Gradio Blocks UI wiring the seven story-creation steps together."""
    with gr.Blocks() as demo:
        gr.Markdown("# Tech Tales: Story Creation")

        with gr.Row():
            generate_btn = gr.Button("1. Generate Random Landscape")
            image_output = gr.Image(label="Generated Image", type="pil", interactive=False)

        with gr.Row():
            analyze_btn = gr.Button("2. Get Brief Description")
            analysis_output = gr.Textbox(label="Image Description", lines=3)

        with gr.Row():
            story_btn = gr.Button("3. Create Children's Story")
            story_output = gr.Textbox(label="Generated Story", lines=10)

        with gr.Row():
            prompts_btn = gr.Button("4. Generate Scene Prompts")
            prompts_output = gr.Textbox(label="Generated Scene Prompts", lines=20)

        with gr.Row():
            generate_scenes_btn = gr.Button("5. Generate Story Scenes", variant="primary")

        with gr.Row():
            scene_prompts_display = gr.Textbox(
                label="Scenes Being Generated",
                lines=8,
                interactive=False,
            )

        with gr.Row():
            gallery = gr.Gallery(
                label="Story Scenes",
                show_label=True,
                columns=2,
                height="auto",
                interactive=False,
            )

        with gr.Row():
            add_text_btn = gr.Button("6. Add Text to Scenes", variant="primary")

        with gr.Row():
            final_gallery = gr.Gallery(
                label="Story Book Pages",
                show_label=True,
                columns=2,
                height="auto",
                interactive=False,
            )

        with gr.Row():
            download_btn = gr.File(
                label="Download Story Book",
                file_count="multiple",
                interactive=False,
            )

        with gr.Row():
            tts_btn = gr.Button("7. Read Story Aloud")
            audio_output = gr.Audio(label="Story Audio")

        # Event handlers
        generate_btn.click(
            fn=generate_image,
            outputs=image_output,
        )

        analyze_btn.click(
            fn=analyze_image,
            inputs=[image_output],
            outputs=analysis_output,
        )

        story_btn.click(
            fn=generate_story,
            inputs=[analysis_output],
            outputs=story_output,
        )

        prompts_btn.click(
            fn=generate_image_prompts,
            inputs=[story_output],
            outputs=prompts_output,
        )

        generate_scenes_btn.click(
            fn=generate_all_scenes,
            inputs=[prompts_output],
            outputs=[gallery, scene_prompts_display],
        )

        add_text_btn.click(
            fn=add_text_to_scenes,
            inputs=[gallery, prompts_output],
            outputs=[final_gallery, download_btn],
        )

        tts_btn.click(
            fn=generate_combined_audio_from_story,
            inputs=[story_output],
            outputs=audio_output,
        )

    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch()