Harumiiii committed on
Commit 0b06de3 · verified · 1 Parent(s): ebb1079

Update app.py

Files changed (1)
app.py +37 -137
app.py CHANGED
@@ -1,143 +1,43 @@
-import gradio as gr
-import numpy as np
-import random
-from diffusers import DiffusionPipeline
 import torch
+from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
+from diffusers.utils import export_to_gif
+from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file
+import gradio as gr
 
-# Set the device based on availability
 device = "cuda" if torch.cuda.is_available() else "cpu"
-
-# Use the ByteDance/AnimateDiff-Lightning model
-model_repo_id = "ByteDance/AnimateDiff-Lightning"
-
-# Set the torch dtype based on available hardware
-if torch.cuda.is_available():
-    torch_dtype = torch.float16
-else:
-    torch_dtype = torch.float32
-
-# Load the pipeline from the pretrained model repository
-pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
-pipe = pipe.to(device)
-
-# Maximum values for seed and image size
-MAX_SEED = np.iinfo(np.int32).max
-MAX_IMAGE_SIZE = 1024
-
-# Define the inference function
-def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress(track_tqdm=True)):
-
-    # Randomize seed if the checkbox is selected
-    if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
-
-    generator = torch.Generator(device=device).manual_seed(seed)
-
-    # Generate the animation using the pipeline
-    animation = pipe(
-        prompt=prompt,
-        negative_prompt=negative_prompt,
-        guidance_scale=guidance_scale,
-        num_inference_steps=num_inference_steps,
-        width=width,
-        height=height,
-        generator=generator
-    ).images[0]  # Assuming the model generates images in the `.images` property
+dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+
+step = 4  # Options: [1,2,4,8]
+repo = "ByteDance/AnimateDiff-Lightning"
+ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
+base = "emilianJR/epiCRealism"
+
+adapter = MotionAdapter().to(device, dtype)
+adapter.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
+pipe = AnimateDiffPipeline.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
+pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
+
+def animate_image(prompt, guidance_scale, num_inference_steps):
+    output = pipe(prompt=prompt, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps)
+    gif_path = "animation.gif"
+    export_to_gif(output.frames[0], gif_path)
+    return gif_path
+
+# Define the Gradio Interface
+with gr.Blocks() as demo:
+    gr.Markdown("# AnimateDiff API")
 
-    return animation, seed
-
-# Sample prompts for the UI
-examples = [
-    "A cat playing with a ball in a garden",
-    "A dancing astronaut in space",
-    "A flying dragon in the sky at sunset",
-]
-
-# Define CSS for styling
-css = """
-#col-container {
-    margin: 0 auto;
-    max-width: 640px;
-}
-"""
-
-# Build the Gradio UI
-with gr.Blocks(css=css) as demo:
+    with gr.Row():
+        prompt = gr.Textbox(label="Prompt", placeholder="A girl smiling", value="A girl smiling")
+        guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=10.0, value=1.0, step=0.1)
+        num_inference_steps = gr.Slider(label="Steps", minimum=1, maximum=8, value=step, step=1)
 
-    with gr.Column(elem_id="col-container"):
-        gr.Markdown(f"""
-        # AnimateDiff Lightning Model Text-to-Animation
-        """)
-
-        with gr.Row():
-            prompt = gr.Text(
-                label="Prompt",
-                show_label=False,
-                max_lines=1,
-                placeholder="Enter your prompt",
-                container=False,
-            )
-            run_button = gr.Button("Run", scale=0)
-
-        result = gr.Image(label="Generated Animation", show_label=False)
-
-        with gr.Accordion("Advanced Settings", open=False):
-
-            negative_prompt = gr.Text(
-                label="Negative prompt",
-                max_lines=1,
-                placeholder="Enter a negative prompt",
-                visible=True,
-            )
-
-            seed = gr.Slider(
-                label="Seed",
-                minimum=0,
-                maximum=MAX_SEED,
-                step=1,
-                value=0,
-            )
-
-            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-            with gr.Row():
-                width = gr.Slider(
-                    label="Width",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=1024,
-                )
-                height = gr.Slider(
-                    label="Height",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=1024,
-                )
-
-            with gr.Row():
-                guidance_scale = gr.Slider(
-                    label="Guidance scale",
-                    minimum=0.0,
-                    maximum=10.0,
-                    step=0.1,
-                    value=7.5,
-                )
-                num_inference_steps = gr.Slider(
-                    label="Number of inference steps",
-                    minimum=1,
-                    maximum=50,
-                    step=1,
-                    value=30,
-                )
-
-        # Example prompts for user selection
-        gr.Examples(
-            examples=examples,
-            inputs=[prompt]
-        )
+    gif_output = gr.Image(label="Generated Animation")
+
+    # Button to run the pipeline
+    run_button = gr.Button("Generate Animation")
+    run_button.click(animate_image, inputs=[prompt, guidance_scale, num_inference_steps], outputs=gif_output)
 
-# Create an API endpoint for the model
-demo.api(fn=infer, inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps], outputs=[result, seed])
-demo.launch()
+# Launch the interface
+demo.launch()
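
Since the new Blocks app wires the button through run_button.click, Gradio also exposes that handler as an API endpoint once the Space is running, so it can be called programmatically. Below is a minimal sketch using gradio_client; the Space id "Harumiiii/animatediff-lightning-api" and the "/animate_image" endpoint name are assumptions (the default endpoint name is derived from the handler function), so adjust both to the actual deployment.

# Minimal sketch: querying the updated Space via gradio_client.
# The Space id is a placeholder, not the real deployment name.
from gradio_client import Client

client = Client("Harumiiii/animatediff-lightning-api")  # hypothetical Space id
gif_path = client.predict(
    "A girl smiling",           # prompt
    1.0,                        # guidance_scale
    4,                          # num_inference_steps
    api_name="/animate_image",  # assumed default name for the click handler
)
print(gif_path)  # local path to the generated GIF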