Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -10,43 +10,6 @@ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, TextIte
|
|
10 |
from transformers.image_utils import load_image
|
11 |
import time
|
12 |
|
13 |
-
# =============================================================================
|
14 |
-
# New imports and helper classes for image generation
|
15 |
-
# =============================================================================
|
16 |
-
try:
|
17 |
-
# We use Hugging Face’s InferenceClient as a generic image-generation API client.
|
18 |
-
from huggingface_hub import InferenceClient as HFInferenceClient
|
19 |
-
except ImportError:
|
20 |
-
HFInferenceClient = None
|
21 |
-
|
22 |
-
# A simple wrapper client for our primary image-generation space.
|
23 |
-
class Client:
    """Minimal wrapper around the Hugging Face inference client for one repo.

    The wrapper exposes a ``predict`` method with a four-argument signature
    so call sites can keep passing a task name and an API route, but only
    the prompt is actually forwarded to the underlying client.
    """

    def __init__(self, repo_id):
        """Bind the wrapper to ``repo_id``.

        ``self.client`` is left as ``None`` when ``HFInferenceClient`` is
        ``None`` (i.e. the optional ``huggingface_hub`` import failed), so
        construction never raises for a missing dependency; the failure is
        deferred to ``predict``.
        """
        self.repo_id = repo_id
        # NOTE(review): InferenceClient expects a model id or endpoint URL;
        # confirm that the repo ids passed here are served that way.
        self.client = (
            HFInferenceClient(repo_id) if HFInferenceClient is not None else None
        )

    def predict(self, task, arg2, prompt, api_name):
        """Generate an image for ``prompt`` and return the client's result.

        Args:
            task: Accepted for call compatibility; not used.
            arg2: Accepted for call compatibility; not used.
            prompt: Text description forwarded to the inference client.
            api_name: Accepted for call compatibility; not used.

        Returns:
            Whatever ``text_to_image`` returns (per the huggingface_hub docs
            this is a PIL image).

        Raises:
            RuntimeError: If no inference client could be created.
        """
        if self.client is None:
            raise RuntimeError("HFInferenceClient not available")
        # InferenceClient instances are not callable; the text-to-image task
        # has to be requested through its dedicated method.
        return self.client.text_to_image(prompt)
|
38 |
-
|
39 |
-
def image_gen(prompt):
    """Request an image for ``prompt`` from the STABLE-HAMSTER space.

    A fresh ``Client`` is created on every call and its ``predict`` result
    is returned unchanged.
    """
    hamster = Client("prithivMLmods/STABLE-HAMSTER")
    result = hamster.predict(
        "Image Generation",
        None,
        prompt,
        api_name="/stable_hamster",
    )
    return result
|
45 |
-
|
46 |
-
# =============================================================================
|
47 |
-
# Original Code (with modifications below)
|
48 |
-
# =============================================================================
|
49 |
-
|
50 |
DESCRIPTION = """
|
51 |
# QwQ Edge 💬
|
52 |
"""
|
@@ -123,46 +86,13 @@ def generate(
|
|
123 |
repetition_penalty: float = 1.2,
|
124 |
):
|
125 |
"""
|
126 |
-
Generates chatbot responses with support for multimodal input
|
127 |
If the query starts with an @tts command (e.g. "@tts1"), previous chat history is cleared.
|
128 |
-
If the query starts with an @image command, the image generation branch is used.
|
129 |
"""
|
130 |
text = input_dict["text"]
|
131 |
files = input_dict.get("files", [])
|
132 |
|
133 |
-
#
|
134 |
-
# NEW: Check for image generation command (@image)
|
135 |
-
# -------------------------------------------------------------------------
|
136 |
-
image_prefix = "@image"
|
137 |
-
if text.strip().lower().startswith(image_prefix):
|
138 |
-
# Remove the prefix and any extra whitespace
|
139 |
-
query = text[len(image_prefix):].strip()
|
140 |
-
yield "Generating Image, Please wait 10 sec..."
|
141 |
-
try:
|
142 |
-
image = image_gen(query)
|
143 |
-
# If the API returns a tuple (as in the snippet) use the second element;
|
144 |
-
# otherwise assume it returns an image directly.
|
145 |
-
if isinstance(image, (list, tuple)) and len(image) > 1:
|
146 |
-
yield gr.Image(image[1])
|
147 |
-
else:
|
148 |
-
yield gr.Image(image)
|
149 |
-
except Exception as e:
|
150 |
-
yield "Error in primary image generation, trying fallback..."
|
151 |
-
try:
|
152 |
-
# Use the fallback image generation client.
|
153 |
-
if HFInferenceClient is not None:
|
154 |
-
client_flux = HFInferenceClient("black-forest-labs/FLUX.1-schnell")
|
155 |
-
image = client_flux.text_to_image(query)
|
156 |
-
yield gr.Image(image)
|
157 |
-
else:
|
158 |
-
yield "Fallback client not available."
|
159 |
-
except Exception as fallback_error:
|
160 |
-
yield f"Error in image generation: {str(fallback_error)}"
|
161 |
-
return # End execution after processing the image-generation request.
|
162 |
-
|
163 |
-
# -------------------------------------------------------------------------
|
164 |
-
# Continue with the original processing (image files, TTS, or text conversation)
|
165 |
-
# -------------------------------------------------------------------------
|
166 |
if len(files) > 1:
|
167 |
images = [load_image(image) for image in files]
|
168 |
elif len(files) == 1:
|
@@ -173,7 +103,7 @@ def generate(
|
|
173 |
tts_prefix = "@tts"
|
174 |
is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
|
175 |
voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
|
176 |
-
|
177 |
if is_tts and voice_index:
|
178 |
voice = TTS_VOICES[voice_index - 1]
|
179 |
text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
|
@@ -258,7 +188,6 @@ demo = gr.ChatInterface(
|
|
258 |
["A train travels 60 kilometers per hour. If it travels for 5 hours, how far will it travel in total?"],
|
259 |
["Write a Python function to check if a number is prime."],
|
260 |
["@tts2 What causes rainbows to form?"],
|
261 |
-
["@image A beautiful sunset over a mountain range"],
|
262 |
],
|
263 |
cache_examples=False,
|
264 |
type="messages",
|
|
|
10 |
from transformers.image_utils import load_image
|
11 |
import time
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
DESCRIPTION = """
|
14 |
# QwQ Edge 💬
|
15 |
"""
|
|
|
86 |
repetition_penalty: float = 1.2,
|
87 |
):
|
88 |
"""
|
89 |
+
Generates chatbot responses with support for multimodal input and TTS.
|
90 |
If the query starts with an @tts command (e.g. "@tts1"), previous chat history is cleared.
|
|
|
91 |
"""
|
92 |
text = input_dict["text"]
|
93 |
files = input_dict.get("files", [])
|
94 |
|
95 |
+
# Process image files if provided
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
if len(files) > 1:
|
97 |
images = [load_image(image) for image in files]
|
98 |
elif len(files) == 1:
|
|
|
103 |
tts_prefix = "@tts"
|
104 |
is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
|
105 |
voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
|
106 |
+
|
107 |
if is_tts and voice_index:
|
108 |
voice = TTS_VOICES[voice_index - 1]
|
109 |
text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
|
|
|
188 |
["A train travels 60 kilometers per hour. If it travels for 5 hours, how far will it travel in total?"],
|
189 |
["Write a Python function to check if a number is prime."],
|
190 |
["@tts2 What causes rainbows to form?"],
|
|
|
191 |
],
|
192 |
cache_examples=False,
|
193 |
type="messages",
|