Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,776 Bytes
0c1b8f7 fca22b9 b06a87f 0c1b8f7 fca22b9 0c1b8f7 32d8e74 b06a87f fca22b9 b06a87f fca22b9 b06a87f fca22b9 d6b5ac6 d7f29b6 d6b5ac6 d7f29b6 47473ae fca22b9 d7f29b6 7f471f2 b06a87f 48a6837 ab6b5e5 fca22b9 b06a87f fca22b9 a592e13 fca22b9 a592e13 fca22b9 a592e13 fca22b9 d6b5ac6 a592e13 d6b5ac6 a592e13 d6b5ac6 a592e13 fca22b9 a592e13 fca22b9 a592e13 fca22b9 a592e13 fca22b9 ea9ba29 a592e13 fca22b9 a592e13 fca22b9 a592e13 fca22b9 ea9ba29 fca22b9 0ba4242 fca22b9 ea9ba29 0ba4242 d6b5ac6 c1f7ac1 b06a87f d6b5ac6 b06a87f d6b5ac6 0ba4242 fca22b9 b06a87f 0ba4242 fca22b9 0ba4242 47473ae 0c1b8f7 fca22b9 |
|
import os
from collections.abc import Iterator
from threading import Thread
import gradio as gr
import spaces
import torch
import edge_tts
import asyncio
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
from transformers.image_utils import load_image
import time
# Markdown header rendered above the chat interface.
DESCRIPTION = "\n# QwQ Edge 💬\n"

# Custom stylesheet handed to gr.ChatInterface: centres the page title and
# styles the element with id "duplicate-button".
css = (
    "\n"
    "h1 {\n"
    "text-align: center;\n"
    "display: block;\n"
    "}\n"
    "#duplicate-button {\n"
    "margin: auto;\n"
    "color: #fff;\n"
    "background: #1565c0;\n"
    "border-radius: 100vh;\n"
    "}\n"
)
# Bounds for the "Max new tokens" slider.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
# Prompt-length cap for the text-only model, overridable through the
# MAX_INPUT_TOKEN_LENGTH environment variable (default "4096").
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Small chat model that answers text-only messages (no attached images).
model_id = "prithivMLmods/FastThink-0.5B-Tiny"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
).eval()
# Edge TTS voices selected by the "@ttsN" chat prefix: "@tts1" picks the first
# entry and "@tts8" the last (lookup is TTS_VOICES[N - 1]).
TTS_VOICES = [
    "en-US-JennyNeural", "en-US-GuyNeural", "en-US-AriaNeural", "en-US-DavisNeural",
    "en-US-JaneNeural", "en-US-JasonNeural", "en-US-NancyNeural", "en-US-TonyNeural",
]
# Vision-language (OCR) model used when the user attaches images.
MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
# nn.Module.to() needs a concrete device. "auto" is only meaningful as a
# from_pretrained(device_map=...) value; passed to .to() it is not a valid
# torch device string and raises at import time. Use the module-level `device`.
model_m = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float16,
).to(device).eval()
async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
    """Synthesize ``text`` with the Edge TTS ``voice`` and write it to an MP3.

    Args:
        text: Text to speak.
        voice: Edge TTS voice name (e.g. an entry of ``TTS_VOICES``).
        output_file: Destination path; defaults to ``"output.mp3"`` in the
            current working directory.

    Returns:
        ``output_file``, the path the audio was saved to.
    """
    await edge_tts.Communicate(text, voice).save(output_file)
    return output_file
def _parse_tts_request(text: str):
    """Detect and remove a leading ``@ttsN`` voice tag from a chat message.

    Args:
        text: The raw message text typed by the user.

    Returns:
        A ``(voice, cleaned_text)`` pair. ``voice`` is ``TTS_VOICES[N - 1]`` when
        the whitespace-stripped message starts with ``@tts1`` ... ``@tts8``
        (case-insensitive); otherwise it is ``None``.
    """
    tts_prefix = "@tts"
    stripped = text.strip()
    lowered = stripped.lower()
    voice_index = next(
        (i for i in range(1, 9) if lowered.startswith(f"{tts_prefix}{i}")),
        None,
    )
    if voice_index is None:
        # Not a TTS request; as before, drop any bare "@tts" marker.
        return None, text.replace(tts_prefix, "").strip()
    # Slice the tag off by length instead of str.replace(). Detection is
    # case-insensitive, so "@TTS1 ..." must be stripped too, and a later
    # occurrence of the tag inside the message must be left untouched.
    return TTS_VOICES[voice_index - 1], stripped[len(tts_prefix) + 1:].strip()


def _text_history(chat_history: list[dict]) -> list[dict]:
    """Keep only the plain-text turns of a messages-format chat history.

    With ``multimodal=True`` the history may also hold file uploads and the
    ``gr.Audio`` replies yielded for TTS requests. Their ``content`` is not a
    string, so the text model's chat template cannot render them.
    """
    return [
        {"role": message["role"], "content": message["content"]}
        for message in chat_history
        if isinstance(message.get("content"), str)
    ]


def _stream_multimodal(text: str, images: list, max_new_tokens: int):
    """Run the vision model on ``text`` plus ``images``; yield the growing reply."""
    messages = [{
        "role": "user",
        "content": [
            *[{"type": "image", "image": image} for image in images],
            {"type": "text", "text": text},
        ],
    }]
    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Follow the model's actual device instead of hard-coding "cuda".
    inputs = processor(text=[prompt], images=images, return_tensors="pt", padding=True).to(model_m.device)
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    # NOTE(review): the sampling sliders (temperature, top-p, top-k, repetition
    # penalty) are not forwarded on this path; only max_new_tokens is.
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
    thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer = (buffer + new_text).replace("<|im_end|>", "")
        time.sleep(0.01)
        yield buffer
    # The streamer is exhausted once generate() has finished.
    thread.join()


def _stream_text(
    conversation: list[dict],
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
):
    """Run the text chat model on ``conversation``; yield the growing reply."""
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep the most recent tokens; the oldest part of the prompt is dropped.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    outputs = []
    for chunk in streamer:
        outputs.append(chunk)
        yield "".join(outputs)
    thread.join()


@spaces.GPU
def generate(
    input_dict: dict,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
):
    """Stream a chatbot reply, optionally followed by a spoken (TTS) version.

    Messages with attached images go to the vision model (``model_m``).
    Text-only messages go to the chat model (``model``) together with the text
    turns of ``chat_history``. A leading ``@tts1`` ... ``@tts8`` tag selects a
    voice from ``TTS_VOICES``. The final reply is then synthesized and yielded
    as audio, for either path.

    Args:
        input_dict: Multimodal textbox value with a ``"text"`` key and an
            optional ``"files"`` list of image files.
        chat_history: Previous turns as messages-format dicts.
        max_new_tokens: Upper bound on generated tokens (both models).
        temperature: Sampling temperature (text-only model).
        top_p: Nucleus-sampling threshold (text-only model).
        top_k: Top-k sampling cutoff (text-only model).
        repetition_penalty: Repetition penalty (text-only model).

    Yields:
        The accumulated reply text while streaming. Then a ``gr.Audio``
        component when TTS was requested, or the final reply text otherwise.
    """
    files = input_dict.get("files") or []
    images = [load_image(image) for image in files]
    voice, text = _parse_tts_request(input_dict["text"])

    if images:
        yield "Thinking..."
        stream = _stream_multimodal(text, images, max_new_tokens)
    else:
        conversation = [*_text_history(chat_history), {"role": "user", "content": text}]
        stream = _stream_text(conversation, max_new_tokens, temperature, top_p, top_k, repetition_penalty)

    final_response = ""
    for partial in stream:
        final_response = partial
        yield partial

    if voice and final_response.strip():
        output_file = asyncio.run(text_to_speech(final_response, voice))
        yield gr.Audio(output_file, autoplay=True)  # Return playable audio
    else:
        yield final_response  # Return text response
# Generation controls shown under the chat box. Their order matches the extra
# parameters of generate() after (input_dict, chat_history).
_generation_controls = [
    gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
    gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
    gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
    gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
    gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
]

# Clickable sample prompts; those starting with "@ttsN" request a spoken reply.
_example_prompts = [
    ["@tts1 Who is Nikola Tesla, and why did he die?"],
    ["A train travels 60 kilometers per hour. If it travels for 5 hours, how far will it travel in total?"],
    ["Write a Python function to check if a number is prime."],
    ["@tts2 What causes rainbows to form?"],
    ["Rewrite the following sentence in passive voice: 'The dog chased the cat.'"],
    ["@tts5 What is the capital of France?"],
]

demo = gr.ChatInterface(
    fn=generate,
    additional_inputs=_generation_controls,
    examples=_example_prompts,
    cache_examples=False,
    type="messages",
    description=DESCRIPTION,
    css=css,
    fill_height=True,
    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"),
    stop_btn="Stop Generation",
    multimodal=True,
)
if __name__ == "__main__":
    # Enable the request queue with at most 20 pending requests, then start
    # the web server. (The stray trailing " |" in the original was a syntax error.)
    demo.queue(max_size=20).launch()