prithivMLmods committed on
Commit a01646a · verified · 1 Parent(s): c9c7955

Update app.py

Files changed (1): app.py (+160, −17)
app.py CHANGED
@@ -1,14 +1,31 @@
 import os
+import random
+import uuid
+import json
+import time
+import asyncio
 from threading import Thread
+
 import gradio as gr
 import spaces
 import torch
+import numpy as np
+from PIL import Image
 import edge_tts
-import asyncio
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
-from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
+
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    TextIteratorStreamer,
+    Qwen2VLForConditionalGeneration,
+    AutoProcessor,
+)
 from transformers.image_utils import load_image
-import time
+from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
+
+# ============================================
+# CHAT & TTS SETUP
+# ============================================
 
 DESCRIPTION = """
 # QwQ Edge 💬
@@ -44,6 +61,7 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 model.eval()
 
+# TTS voices
 TTS_VOICES = [
     "en-US-JennyNeural",  # @tts1
     "en-US-GuyNeural",    # @tts2
@@ -75,6 +93,93 @@ def clean_chat_history(chat_history):
         cleaned.append(msg)
     return cleaned
 
+# ============================================
+# IMAGE GENERATION SETUP
+# ============================================
+
+# Environment variables and parameters for Stable Diffusion XL
+MODEL_ID_SD = os.getenv("MODEL_VAL_PATH")  # Use SDXL model repo path via MODEL_VAL_PATH env var
+MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
+USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
+ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
+BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))  # For potential batched image generation
+
+# Load the SDXL pipeline
+sd_pipe = StableDiffusionXLPipeline.from_pretrained(
+    MODEL_ID_SD,
+    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+    use_safetensors=True,
+    add_watermarker=False,
+).to(device)
+sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
+
+# Optional: compile the model for speedup
+if USE_TORCH_COMPILE:
+    sd_pipe.compile()
+
+# Optional: offload parts of the model to CPU if needed
+if ENABLE_CPU_OFFLOAD:
+    sd_pipe.enable_model_cpu_offload()
+
+MAX_SEED = np.iinfo(np.int32).max
+
+def save_image(img: Image.Image) -> str:
+    """Save a PIL image with a unique filename and return the path."""
+    unique_name = str(uuid.uuid4()) + ".png"
+    img.save(unique_name)
+    return unique_name
+
+def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+    if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
+    return seed
+
+@spaces.GPU(duration=60, enable_queue=True)
+def generate_image_fn(
+    prompt: str,
+    negative_prompt: str = "",
+    use_negative_prompt: bool = False,
+    seed: int = 1,
+    width: int = 1024,
+    height: int = 1024,
+    guidance_scale: float = 3,
+    num_inference_steps: int = 25,
+    randomize_seed: bool = False,
+    use_resolution_binning: bool = True,
+    num_images: int = 1,
+    progress=gr.Progress(track_tqdm=True),
+):
+    """Generate images using the SDXL pipeline."""
+    seed = int(randomize_seed_fn(seed, randomize_seed))
+    generator = torch.Generator(device=device).manual_seed(seed)
+
+    options = {
+        "prompt": [prompt] * num_images,
+        "negative_prompt": [negative_prompt] * num_images if use_negative_prompt else None,
+        "width": width,
+        "height": height,
+        "guidance_scale": guidance_scale,
+        "num_inference_steps": num_inference_steps,
+        "generator": generator,
+        "output_type": "pil",
+    }
+    if use_resolution_binning:
+        options["use_resolution_binning"] = True
+
+    images = []
+    for i in range(0, num_images, BATCH_SIZE):
+        batch_options = options.copy()
+        batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
+        if "negative_prompt" in batch_options and batch_options["negative_prompt"] is not None:
+            batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
+        images.extend(sd_pipe(**batch_options).images)
+    image_paths = [save_image(img) for img in images]
+    return image_paths, seed
+
+# ============================================
+# MAIN GENERATION FUNCTION (CHAT)
+# ============================================
+
 @spaces.GPU
 def generate(
     input_dict: dict,
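Reviewer sketch, not part of the commit: one thing to note in the hunk above is that os.getenv("MODEL_VAL_PATH") returns None when the variable is unset, so the pipeline load fails at import time unless the Space defines MODEL_VAL_PATH. A minimal smoke test for generate_image_fn, assuming app.py's globals (sd_pipe, device, BATCH_SIZE, ...) are in scope and MODEL_VAL_PATH named a valid SDXL checkpoint at load time; the repo id in the comment is only an example, not taken from the commit:

# Reviewer sketch: direct call to generate_image_fn, assuming the module
# loaded with MODEL_VAL_PATH set, e.g. to
# "stabilityai/stable-diffusion-xl-base-1.0" (example id, not from the commit).
image_paths, seed = generate_image_fn(
    prompt="a watercolor fox in a snowy forest",
    randomize_seed=True,
    num_images=1,
)
print(image_paths, seed)  # e.g. ['3f2a....png'] and the randomized seed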
@@ -86,20 +191,41 @@ def generate(
     repetition_penalty: float = 1.2,
 ):
     """
-    Generates chatbot responses with support for multimodal input and TTS.
-    If the query starts with an @tts command (e.g. "@tts1"), previous chat history is cleared.
+    Generates chatbot responses with support for multimodal input, TTS, and now image generation.
+    If the query starts with:
+      - "@tts1" or "@tts2", it triggers text-to-speech.
+      - "@image", it triggers image generation using the SDXL pipeline.
     """
     text = input_dict["text"]
     files = input_dict.get("files", [])
 
-    # Process image files if provided
-    if len(files) > 1:
-        images = [load_image(image) for image in files]
-    elif len(files) == 1:
-        images = [load_image(files[0])]
-    else:
-        images = []
+    # ----------------------------
+    # NEW: IMAGE GENERATION BRANCH
+    # ----------------------------
+    if text.strip().lower().startswith("@image"):
+        # Remove the "@image" tag and use the rest as prompt
+        prompt = text[len("@image"):].strip()
+        yield "Generating image..."
+        image_paths, used_seed = generate_image_fn(
+            prompt=prompt,
+            negative_prompt="",
+            use_negative_prompt=False,
+            seed=1,
+            width=1024,
+            height=1024,
+            guidance_scale=3,
+            num_inference_steps=25,
+            randomize_seed=True,
+            use_resolution_binning=True,
+            num_images=1,
+        )
+        # Yield the generated image so that the chat interface displays it.
+        yield gr.Image(image_paths[0])
+        return  # Exit early
 
+    # ----------------------------
+    # TTS Branch (if query starts with @tts)
+    # ----------------------------
     tts_prefix = "@tts"
     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
     voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
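The routing added in this hunk is all literal prefix matching on the message text. A self-contained sketch of the same dispatch logic in isolation (route is a hypothetical helper for illustration, not a function in the commit):

# Reviewer sketch: the @image / @tts1 / @tts2 prefix dispatch, in isolation.
def route(text: str) -> tuple[str, str]:
    """Return (command, payload) for a chat message."""
    t = text.strip()
    low = t.lower()
    if low.startswith("@image"):
        return "image", t[len("@image"):].strip()
    for i in (1, 2):
        tag = f"@tts{i}"
        if low.startswith(tag):
            return f"tts{i}", t[len(tag):].strip()
    return "chat", t

assert route("@image a red panda") == ("image", "a red panda")
assert route("@tts2 hello") == ("tts2", "hello")
assert route("what is 2+2?") == ("chat", "what is 2+2?")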
@@ -107,16 +233,25 @@
     if is_tts and voice_index:
         voice = TTS_VOICES[voice_index - 1]
         text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
-        # Clear any previous chat history to avoid concatenation issues
+        # Clear previous chat history for a fresh TTS request.
        conversation = [{"role": "user", "content": text}]
     else:
         voice = None
+        # Remove any stray @tts tags and build the conversation history.
         text = text.replace(tts_prefix, "").strip()
         conversation = clean_chat_history(chat_history)
         conversation.append({"role": "user", "content": text})
 
-    if images:
-        # Multimodal branch using the OCR model
+    # ----------------------------
+    # Multimodal (image + text) branch
+    # ----------------------------
+    if files:
+        if len(files) > 1:
+            images = [load_image(image) for image in files]
+        elif len(files) == 1:
+            images = [load_image(files[0])]
+        else:
+            images = []
         messages = [{
             "role": "user",
             "content": [
@@ -139,7 +274,9 @@
             time.sleep(0.01)
             yield buffer
     else:
-        # Text-only branch using the text model
+        # ----------------------------
+        # Text-only branch
+        # ----------------------------
         input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
         if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
             input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
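For context, the text-only branch (largely unchanged below this hunk) streams tokens with TextIteratorStreamer while generation runs on a background thread. A minimal standalone sketch of that pattern, using gpt2 purely as a stand-in model for illustration:

# Reviewer sketch of the streaming pattern used in the text-only branch:
# generate() runs on a worker thread; decoded text chunks arrive via the streamer.
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")  # stand-in model, not the app's
lm = AutoModelForCausalLM.from_pretrained("gpt2")
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

inputs = tok("The capital of France is", return_tensors="pt")
Thread(target=lm.generate, kwargs={**inputs, "streamer": streamer, "max_new_tokens": 20}).start()

buffer = ""
for chunk in streamer:
    buffer += chunk          # accumulate, as the app does before each yield
    print(chunk, end="", flush=True)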
@@ -168,10 +305,15 @@
         final_response = "".join(outputs)
         yield final_response
 
+        # If TTS was requested, convert the final response to speech.
         if is_tts and voice:
             output_file = asyncio.run(text_to_speech(final_response, voice))
             yield gr.Audio(output_file, autoplay=True)
 
+# ============================================
+# GRADIO DEMO SETUP
+# ============================================
+
 demo = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
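The call site asyncio.run(text_to_speech(final_response, voice)) relies on a text_to_speech helper defined elsewhere in app.py, outside this diff. A plausible minimal shape for it using the edge_tts API (an assumption about the helper, not the committed code):

# Reviewer sketch: an edge_tts helper matching the call site above.
# The actual text_to_speech in app.py is outside this diff.
import asyncio
import edge_tts

async def text_to_speech(text: str, voice: str, output_file: str = "output.mp3") -> str:
    communicate = edge_tts.Communicate(text, voice)
    await communicate.save(output_file)
    return output_file

# Usage: asyncio.run(text_to_speech("Hello!", "en-US-JennyNeural"))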
@@ -188,6 +330,7 @@ demo = gr.ChatInterface(
         ["A train travels 60 kilometers per hour. If it travels for 5 hours, how far will it travel in total?"],
         ["Write a Python function to check if a number is prime."],
         ["@tts2 What causes rainbows to form?"],
+        ["@image A futuristic city skyline at dusk"],
     ],
     cache_examples=False,
     type="messages",