prithivMLmods commited on
Commit
a592e13
·
verified ·
1 Parent(s): d6b5ac6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -83
app.py CHANGED
@@ -1,45 +1,23 @@
1
  import os
2
- from collections.abc import Iterator
3
  from threading import Thread
4
  import gradio as gr
5
- import spaces
6
  import torch
 
 
7
  import edge_tts
8
  import asyncio
9
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
10
-
11
- DESCRIPTION = """
12
- # QwQ Edge 💬
13
- """
14
-
15
- css = '''
16
- h1 {
17
- text-align: center;
18
- display: block;
19
- }
20
 
21
- #duplicate-button {
22
- margin: auto;
23
- color: #fff;
24
- background: #1565c0;
25
- border-radius: 100vh;
26
- }
27
- '''
28
 
29
- MAX_MAX_NEW_TOKENS = 2048
30
- DEFAULT_MAX_NEW_TOKENS = 1024
31
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
32
-
33
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
34
-
35
- model_id = "prithivMLmods/FastThink-0.5B-Tiny"
36
- tokenizer = AutoTokenizer.from_pretrained(model_id)
37
- model = AutoModelForCausalLM.from_pretrained(
38
- model_id,
39
- device_map="auto",
40
- torch_dtype=torch.bfloat16,
41
- )
42
- model.eval()
43
 
44
  TTS_VOICES = [
45
  "en-US-JennyNeural", # @tts1
@@ -52,6 +30,7 @@ TTS_VOICES = [
52
  "en-US-TonyNeural", # @tts8
53
  ]
54
 
 
55
  async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
56
  """Convert text to speech using Edge TTS and save as MP3"""
57
  communicate = edge_tts.Communicate(text, voice)
@@ -60,72 +39,109 @@ async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
60
 
61
  @spaces.GPU
62
  def generate(
63
- message: str,
64
- chat_history: list[dict],
65
- max_new_tokens: int = 1024,
66
- temperature: float = 0.6,
67
- top_p: float = 0.9,
68
- top_k: int = 50,
69
- repetition_penalty: float = 1.2,
70
  ):
71
- """Generates chatbot response and handles TTS requests"""
 
 
 
 
 
 
 
 
 
 
72
  tts_prefix = "@tts"
73
- is_tts = any(message.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 9))
74
- voice_index = next((i for i in range(1, 9) if message.strip().lower().startswith(f"{tts_prefix}{i}")), None)
75
 
76
  if is_tts and voice_index:
77
  voice = TTS_VOICES[voice_index - 1]
78
- message = message.replace(f"{tts_prefix}{voice_index}", "").strip()
79
  else:
80
  voice = None
81
- message = message.replace(tts_prefix, "").strip()
82
-
83
- conversation = [*chat_history, {"role": "user", "content": message}]
84
-
85
- input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
86
- if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
87
- input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
88
- gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
89
- input_ids = input_ids.to(model.device)
90
-
91
- streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
92
- generate_kwargs = dict(
93
- {"input_ids": input_ids},
94
- streamer=streamer,
95
- max_new_tokens=max_new_tokens,
96
- do_sample=True,
97
- top_p=top_p,
98
- top_k=top_k,
99
- temperature=temperature,
100
- num_beams=1,
101
- repetition_penalty=repetition_penalty,
102
- )
103
- t = Thread(target=model.generate, kwargs=generate_kwargs)
104
- t.start()
105
-
106
- outputs = []
107
- for text in streamer:
108
- outputs.append(text)
109
- yield "".join(outputs)
110
-
111
- final_response = "".join(outputs)
112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  if is_tts and voice:
114
  output_file = asyncio.run(text_to_speech(final_response, voice))
115
  yield gr.Audio(output_file, autoplay=True) # Return playable audio
116
  else:
117
  yield final_response # Return text response
118
 
119
- demo = gr.ChatInterface(
 
120
  fn=generate,
121
- additional_inputs=[
 
 
122
  gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
123
  gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
124
  gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
125
  gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
126
  gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
127
  ],
128
- stop_btn=None,
129
  examples=[
130
  ["@tts1 Who is Nikola Tesla, and why did he die?"],
131
  ["A train travels 60 kilometers per hour. If it travels for 5 hours, how far will it travel in total?"],
@@ -134,12 +150,11 @@ demo = gr.ChatInterface(
134
  ["Rewrite the following sentence in passive voice: 'The dog chased the cat.'"],
135
  ["@tts5 What is the capital of France?"],
136
  ],
137
- cache_examples=False,
138
- type="messages",
139
- description=DESCRIPTION,
140
  css=css,
141
  fill_height=True,
142
  )
143
 
144
  if __name__ == "__main__":
145
- demo.queue(max_size=20).launch()
 
1
  import os
2
+ import time
3
  from threading import Thread
4
  import gradio as gr
 
5
  import torch
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
7
+ from transformers.image_utils import load_image
8
  import edge_tts
9
  import asyncio
10
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
 
 
 
 
 
 
 
 
 
 
11
 
12
+ # Load models
13
+ MODEL_ID = "prithivMLmods/FastThink-0.5B-Tiny"
14
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
15
+ model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16).eval()
 
 
 
16
 
17
+ # For multimodal OCR processing
18
+ OCR_MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
19
+ ocr_processor = AutoProcessor.from_pretrained(OCR_MODEL_ID, trust_remote_code=True)
20
+ ocr_model = Qwen2VLForConditionalGeneration.from_pretrained(OCR_MODEL_ID, trust_remote_code=True, torch_dtype=torch.float16).to("cuda").eval()
 
 
 
 
 
 
 
 
 
 
21
 
22
  TTS_VOICES = [
23
  "en-US-JennyNeural", # @tts1
 
30
  "en-US-TonyNeural", # @tts8
31
  ]
32
 
33
+ # Handle text-to-speech conversion
34
  async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
35
  """Convert text to speech using Edge TTS and save as MP3"""
36
  communicate = edge_tts.Communicate(text, voice)
 
39
 
40
  @spaces.GPU
41
  def generate(
42
+ input_dict,
43
+ history,
44
+ max_new_tokens: int = 1024,
45
+ temperature: float = 0.6,
46
+ top_p: float = 0.9,
47
+ top_k: int = 50,
48
+ repetition_penalty: float = 1.2
49
  ):
50
+ """Generates chatbot response and handles TTS requests with multimodal support"""
51
+ text = input_dict.get("text", "")
52
+ files = input_dict.get("files", [])
53
+
54
+ # Handle multimodal OCR processing
55
+ if files:
56
+ images = [load_image(image) for image in files]
57
+ else:
58
+ images = []
59
+
60
+ # Check if the message is TTS request
61
  tts_prefix = "@tts"
62
+ is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 9))
63
+ voice_index = next((i for i in range(1, 9) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
64
 
65
  if is_tts and voice_index:
66
  voice = TTS_VOICES[voice_index - 1]
67
+ text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
68
  else:
69
  voice = None
70
+ text = text.replace(tts_prefix, "").strip()
71
+
72
+ # If images are provided, combine image and text for the prompt
73
+ if images:
74
+ # Prepare images as part of the conversation
75
+ messages = [
76
+ {
77
+ "role": "user",
78
+ "content": [
79
+ *[{"type": "image", "image": image} for image in images],
80
+ {"type": "text", "text": text},
81
+ ],
82
+ }
83
+ ]
84
+ prompt = ocr_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
85
+ inputs = ocr_processor(
86
+ text=[prompt],
87
+ images=images,
88
+ return_tensors="pt",
89
+ padding=True,
90
+ ).to("cuda")
 
 
 
 
 
 
 
 
 
 
91
 
92
+ else:
93
+ # Normal text-only input
94
+ conversation = [*history, {"role": "user", "content": text}]
95
+ input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
96
+ if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
97
+ input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
98
+ gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
99
+ input_ids = input_ids.to(model.device)
100
+
101
+ streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
102
+ generate_kwargs = dict(
103
+ {"input_ids": input_ids},
104
+ streamer=streamer,
105
+ max_new_tokens=max_new_tokens,
106
+ do_sample=True,
107
+ top_p=top_p,
108
+ top_k=top_k,
109
+ temperature=temperature,
110
+ num_beams=1,
111
+ repetition_penalty=repetition_penalty,
112
+ )
113
+
114
+ # Start generation in a separate thread
115
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
116
+ t.start()
117
+
118
+ # Collect generated text
119
+ outputs = []
120
+ for text in streamer:
121
+ outputs.append(text)
122
+ yield "".join(outputs)
123
+ final_response = "".join(outputs)
124
+
125
+ # Handle text-to-speech
126
  if is_tts and voice:
127
  output_file = asyncio.run(text_to_speech(final_response, voice))
128
  yield gr.Audio(output_file, autoplay=True) # Return playable audio
129
  else:
130
  yield final_response # Return text response
131
 
132
+ # Gradio Interface
133
+ demo = gr.Interface(
134
  fn=generate,
135
+ inputs=[
136
+ gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"), # Multimodal input
137
+ gr.Textbox(label="Chat History", value="", placeholder="Previous conversation history"),
138
  gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
139
  gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
140
  gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
141
  gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
142
  gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
143
  ],
144
+ outputs=["text", "audio"],
145
  examples=[
146
  ["@tts1 Who is Nikola Tesla, and why did he die?"],
147
  ["A train travels 60 kilometers per hour. If it travels for 5 hours, how far will it travel in total?"],
 
150
  ["Rewrite the following sentence in passive voice: 'The dog chased the cat.'"],
151
  ["@tts5 What is the capital of France?"],
152
  ],
153
+ stop_btn="Stop Generation",
154
+ description="QwQ Edge: A Chatbot with Text-to-Speech and Multimodal Support",
 
155
  css=css,
156
  fill_height=True,
157
  )
158
 
159
  if __name__ == "__main__":
160
+ demo.launch()