Spaces: Runtime error
Commit · b5d5c28
Parent(s): 85b8db7
Add advanced option
app.py CHANGED
@@ -1,5 +1,7 @@
+from contextlib import nullcontext
 import gradio as gr
 import torch
+from torch import autocast
 from diffusers import DiffusionPipeline
 import streamlit as st
 from transformers import (
@@ -10,6 +12,8 @@ from transformers import (
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 device_dict = {"cuda": 0, "cpu": -1}
+context = autocast if device == "cuda" else nullcontext
+dtype = torch.float16 if device == "cuda" else torch.float32
 
 # Add language detection pipeline
 language_detection_model_ckpt = "papluca/xlm-roberta-base-language-detection"
@@ -30,19 +34,21 @@ pipe = DiffusionPipeline.from_pretrained(
     detection_pipeline=language_detection_pipeline,
     translation_model=trans_model,
     translation_tokenizer=trans_tokenizer,
-
-
+    revision="fp16",
+    torch_dtype=dtype,
 )
 
-pipe.enable_attention_slicing()
 pipe = pipe.to(device)
 
 #torch.backends.cudnn.benchmark = True
 num_samples = 2
 
-def infer(prompt):
-
-
+def infer(prompt, scale, steps):
+
+    with context("cuda"):
+        images = pipe(num_samples*[prompt], guidance_scale=scale, num_inference_steps=steps).images
+
+    return images
 
 css = """
 .gradio-container {
@@ -100,7 +106,6 @@ css = """
     border-radius: 14px !important;
 }
 #advanced-options {
-    display: none;
     margin-bottom: 20px;
 }
 .footer {
@@ -167,13 +172,19 @@ block = gr.Blocks(css=css)
 
 examples = [
     [
-        'Una casa en la playa en un atardecer lluvioso'
+        'Una casa en la playa en un atardecer lluvioso',
+        45,
+        7.5,
     ],
     [
-        'Ein Hund, der Orange isst'
+        'Ein Hund, der Orange isst',
+        45,
+        7.5,
    ],
     [
-        "Photo d'un restaurant parisien"
+        "Photo d'un restaurant parisien",
+        45,
+        7.5,
     ],
 ]
 
@@ -216,14 +227,20 @@ with block as demo:
     )
 
     gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="generated_id").style(
-        grid=[
+        grid=[1], height="auto"
    )
 
-
+    with gr.Row(elem_id="advanced-options"):
+        steps = gr.Slider(label="Steps", minimum=5, maximum=50, value=45, step=5)
+        scale = gr.Slider(
+            label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
+        )
+
+    ex = gr.Examples(examples=examples, fn=infer, inputs=[text, steps, scale], outputs=gallery, cache_examples=False)
     ex.dataset.headers = [""]
 
-    text.submit(infer, inputs=[text], outputs=gallery)
-    btn.click(infer, inputs=[text], outputs=gallery)
+    text.submit(infer, inputs=[text, steps, scale], outputs=gallery)
+    btn.click(infer, inputs=[text, steps, scale], outputs=gallery)
 
     gr.HTML(
         """
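
For context on what the diff adds: the commit switches the pipeline to fp16 on GPU, wraps generation in torch.autocast (with contextlib.nullcontext as the CPU fallback), and threads new Steps / Guidance Scale sliders through infer. The sketch below is a minimal, self-contained reconstruction of that pattern, not the Space's actual code: fake_pipe, the reduced Blocks layout, and the component labels are stand-ins invented for illustration.

```python
from contextlib import nullcontext

import gradio as gr
import numpy as np
import torch
from torch import autocast

# Same device-conditional setup as the commit: autocast and fp16 only make
# sense on CUDA; on CPU we fall back to a no-op context and full precision.
device = "cuda" if torch.cuda.is_available() else "cpu"
context = autocast if device == "cuda" else nullcontext
dtype = torch.float16 if device == "cuda" else torch.float32  # would be passed as torch_dtype= when loading the real pipeline

num_samples = 2


def fake_pipe(prompts, guidance_scale, num_inference_steps):
    """Stand-in for the diffusers pipeline: one random RGB image per prompt.

    The real Space calls its translated Stable Diffusion pipeline here; the
    arguments are kept only to mirror that call signature.
    """
    rng = np.random.default_rng()
    return [(rng.random((64, 64, 3)) * 255).astype(np.uint8) for _ in prompts]


def infer(prompt, scale, steps):
    # nullcontext() simply ignores the "cuda" argument on CPU, so the same
    # call site works unchanged on both devices.
    with context("cuda"):
        images = fake_pipe(num_samples * [prompt],
                           guidance_scale=scale,
                           num_inference_steps=steps)
    return images


with gr.Blocks() as demo:
    text = gr.Textbox(label="Prompt")
    with gr.Row():
        steps = gr.Slider(label="Steps", minimum=5, maximum=50, value=45, step=5)
        scale = gr.Slider(label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1)
    btn = gr.Button("Generate")
    gallery = gr.Gallery(label="Generated images")

    # Slider values are passed to infer in the same order as its parameters.
    text.submit(infer, inputs=[text, scale, steps], outputs=gallery)
    btn.click(infer, inputs=[text, scale, steps], outputs=gallery)

if __name__ == "__main__":
    demo.launch()
```

Keeping the GPU/CPU branching in `context` and `dtype` rather than inside `infer` is what lets the handler body stay identical on both devices; swapping `fake_pipe` for the real `DiffusionPipeline` instance would not change the wiring.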