prithivMLmods committed
Commit 26f7b76 · verified · 1 Parent(s): 3a6718d

Update app.py

Files changed (1)
  1. app.py +76 -162
app.py CHANGED
@@ -6,34 +6,10 @@ import torch
 import edge_tts
 import asyncio
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
-from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
+from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
 from transformers.image_utils import load_image
+from huggingface_hub import InferenceClient
 import time
-from gradio_client import Client  # For image generation API
-
-DESCRIPTION = """
-# QwQ Edge 💬
-"""
-
-css = '''
-h1 {
-  text-align: center;
-  display: block;
-}
-
-#duplicate-button {
-  margin: auto;
-  color: #fff;
-  background: #1565c0;
-  border-radius: 100vh;
-}
-'''
-
-MAX_MAX_NEW_TOKENS = 2048
-DEFAULT_MAX_NEW_TOKENS = 1024
-MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
-
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 # Load text-only model and tokenizer
 model_id = "prithivMLmods/FastThink-0.5B-Tiny"
@@ -45,11 +21,6 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 model.eval()
 
-TTS_VOICES = [
-    "en-US-JennyNeural",  # @tts1
-    "en-US-GuyNeural",    # @tts2
-]
-
 # Load multimodal (OCR) model and processor
 MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
 processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
@@ -59,8 +30,19 @@ model_m = Qwen2VLForConditionalGeneration.from_pretrained(
     torch_dtype=torch.float16
 ).to("cuda").eval()
 
-# Image generation client
-image_gen_client = Client("prithivMLmods/STABLE-HAMSTER")
+TTS_VOICES = [
+    "en-US-JennyNeural",  # @tts1
+    "en-US-GuyNeural",    # @tts2
+]
+
+def image_gen(prompt):
+    """Generate an image via the Inference API, falling back to FLUX.1-schnell."""
+    try:
+        client = InferenceClient("prithivMLmods/STABLE-HAMSTER")
+        return client.text_to_image(prompt)
+    except Exception:
+        client_flux = InferenceClient("black-forest-labs/FLUX.1-schnell")
+        return client_flux.text_to_image(prompt)
 
 async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
     """Convert text to speech using Edge TTS and save as MP3"""
@@ -68,155 +50,87 @@ async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
     await communicate.save(output_file)
     return output_file
 
-def image_gen(prompt: str):
-    """Generate an image using the Stable Hamster API"""
-    result = image_gen_client.predict("Image Generation", None, prompt, api_name="/stable_hamster")
-    return result[1]  # Return the generated image
-
 def clean_chat_history(chat_history):
-    """
-    Filter out any chat entries whose "content" is not a string.
-    This helps prevent errors when concatenating previous messages.
-    """
-    cleaned = []
-    for msg in chat_history:
-        if isinstance(msg, dict) and isinstance(msg.get("content"), str):
-            cleaned.append(msg)
-    return cleaned
+    return [msg for msg in chat_history if isinstance(msg, dict) and isinstance(msg.get("content"), str)]
 
 @spaces.GPU
-def generate(
-    input_dict: dict,
-    chat_history: list[dict],
-    max_new_tokens: int = 1024,
-    temperature: float = 0.6,
-    top_p: float = 0.9,
-    top_k: int = 50,
-    repetition_penalty: float = 1.2,
-):
-    """
-    Generates chatbot responses with support for multimodal input, TTS, and image generation.
-    If the query starts with an @tts or @image command, previous chat history is cleared.
-    """
+def generate(input_dict: dict, chat_history: list[dict], max_new_tokens=1024, temperature=0.6, top_p=0.9, top_k=50, repetition_penalty=1.2):
+    """Generates chatbot responses with multimodal input, TTS, and image generation."""
     text = input_dict["text"]
     files = input_dict.get("files", [])
-
-    # Process image files if provided
-    if len(files) > 1:
-        images = [load_image(image) for image in files]
-    elif len(files) == 1:
-        images = [load_image(files[0])]
-    else:
-        images = []
-
-    # Check for TTS or Image Generation commands
-    tts_prefix = "@tts"
-    image_prefix = "@image"
-    is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
-    is_image = text.strip().lower().startswith(image_prefix)
+    images = [load_image(file) for file in files] if files else []
+    voice = None  # stays None unless an @tts command selects a voice
 
-    if is_tts:
-        voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
-        voice = TTS_VOICES[voice_index - 1] if voice_index else None
-        text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
-        # Clear any previous chat history to avoid concatenation issues
-        conversation = [{"role": "user", "content": text}]
-    elif is_image:
-        text = text.replace(image_prefix, "").strip()
-        conversation = [{"role": "user", "content": text}]
+    if text.startswith("@tts"):
+        voice_index = next((i for i in range(1, 3) if text.startswith(f"@tts{i}")), None)
+        if voice_index:
+            voice = TTS_VOICES[voice_index - 1]
+            text = text.replace(f"@tts{voice_index}", "").strip()
+        conversation = [{"role": "user", "content": text}]
+    elif text.startswith("@image"):
+        query = text.replace("@image", "").strip()
+        yield "Generating image, please wait..."
+        image = image_gen(query)
+        yield gr.Image(image)
+        return  # image requests produce no text response
     else:
-        voice = None
-        text = text.replace(tts_prefix, "").strip()
-        conversation = clean_chat_history(chat_history)
-        conversation.append({"role": "user", "content": text})
-
-    if is_image:
-        # Image generation branch
-        yield "Generating image, please wait..."
-        try:
-            image = image_gen(text)
-            yield gr.Image(image)
-        except Exception as e:
-            yield f"Failed to generate image: {str(e)}"
-    elif images:
-        # Multimodal branch using the OCR model
-        messages = [{
-            "role": "user",
-            "content": [
-                *[{"type": "image", "image": image} for image in images],
-                {"type": "text", "text": text},
-            ]
-        }]
-        prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-        inputs = processor(text=[prompt], images=images, return_tensors="pt", padding=True).to("cuda")
-        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
-        generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}
-        thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
-        thread.start()
-
-        buffer = ""
-        yield "Thinking..."
-        for new_text in streamer:
-            buffer += new_text
-            buffer = buffer.replace("<|im_end|>", "")
-            time.sleep(0.01)
-            yield buffer
-    else:
-        # Text-only branch using the text model
-        input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
-        if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-            input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-            gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-        input_ids = input_ids.to(model.device)
-        streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
-        generation_kwargs = {
-            "input_ids": input_ids,
-            "streamer": streamer,
-            "max_new_tokens": max_new_tokens,
-            "do_sample": True,
-            "top_p": top_p,
-            "top_k": top_k,
-            "temperature": temperature,
-            "num_beams": 1,
-            "repetition_penalty": repetition_penalty,
-        }
-        t = Thread(target=model.generate, kwargs=generation_kwargs)
-        t.start()
-
-        outputs = []
-        for new_text in streamer:
-            outputs.append(new_text)
-            yield "".join(outputs)
-
-        final_response = "".join(outputs)
-        yield final_response
-
-        if is_tts and voice:
-            output_file = asyncio.run(text_to_speech(final_response, voice))
-            yield gr.Audio(output_file, autoplay=True)
+        conversation = clean_chat_history(chat_history) + [{"role": "user", "content": text}]
+    if images:
+        messages = [{
+            "role": "user",
+            "content": [
+                *[{"type": "image", "image": img} for img in images],
+                {"type": "text", "text": text},
+            ]
+        }]
+        prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        inputs = processor(text=[prompt], images=images, return_tensors="pt", padding=True).to("cuda")
+        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+        thread = Thread(target=model_m.generate, kwargs={**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens})
+        thread.start()
+        buffer = ""
+        for new_text in streamer:
+            buffer += new_text.replace("<|im_end|>", "")
+            yield buffer
+    else:
+        input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
+        streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+        thread = Thread(target=model.generate, kwargs={
+            "input_ids": input_ids,
+            "streamer": streamer,
+            "max_new_tokens": max_new_tokens,
+            "do_sample": True,
+            "top_p": top_p,
+            "top_k": top_k,
+            "temperature": temperature,
+            "num_beams": 1,
+            "repetition_penalty": repetition_penalty,
+        })
+        thread.start()
+        response = "".join([new_text for new_text in streamer])
+        yield response
+        if voice:
+            output_file = asyncio.run(text_to_speech(response, voice))
+            yield gr.Audio(output_file, autoplay=True)
 
 demo = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
-        gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
+        gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=1024),
         gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
-        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
+        gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
         gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
         gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
     ],
     examples=[
-        ["@tts1 Who is Nikola Tesla, and why did he die?"],
-        ["@image A futuristic cityscape at sunset"],
+        ["@tts1 Who is Nikola Tesla?"],
         [{"text": "Extract JSON from the image", "files": ["examples/document.jpg"]}],
-        [{"text": "summarize the letter", "files": ["examples/1.png"]}],
-        ["A train travels 60 kilometers per hour. If it travels for 5 hours, how far will it travel in total?"],
-        ["Write a Python function to check if a number is prime."],
-        ["@tts2 What causes rainbows to form?"],
+        ["@image futuristic city at sunset"],
+        ["A train travels 60 kilometers per hour. How far will it travel in 5 hours?"],
     ],
     cache_examples=False,
     type="messages",
-    description=DESCRIPTION,
-    css=css,
+    description="# QwQ Edge 💬",
     fill_height=True,
     textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"),
     stop_btn="Stop Generation",
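
For quick verification of the new image path outside the Space, here is a minimal standalone sketch of the InferenceClient fallback pattern introduced above. The model IDs come from the diff; the loop form and the final RuntimeError are illustrative, not part of the commit.

from huggingface_hub import InferenceClient

def image_gen(prompt: str):
    """Try the Space-hosted model first, then fall back to FLUX.1-schnell."""
    for model_id in ("prithivMLmods/STABLE-HAMSTER", "black-forest-labs/FLUX.1-schnell"):
        try:
            # text_to_image returns a PIL.Image.Image on success
            return InferenceClient(model_id).text_to_image(prompt)
        except Exception:
            continue  # try the next backend
    raise RuntimeError("all image backends failed")

image_gen("futuristic city at sunset").save("preview.png")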
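The TTS helper itself is untouched by this commit; a minimal sketch of driving it directly, outside Gradio (sample text and output path are illustrative):

import asyncio
import edge_tts

async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
    """Convert text to speech with Edge TTS and save as MP3, as in app.py."""
    communicate = edge_tts.Communicate(text, voice)
    await communicate.save(output_file)
    return output_file

# "en-US-JennyNeural" is the voice the app maps to the @tts1 command
print(asyncio.run(text_to_speech("What causes rainbows to form?", "en-US-JennyNeural")))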