prithivMLmods committed on
Commit
b06a87f
·
verified ·
1 Parent(s): 2c4a4a6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -58
app.py CHANGED
@@ -1,19 +1,31 @@
1
  import os
 
 
2
  import gradio as gr
 
3
  import torch
4
- import tempfile
5
- import asyncio
6
  import edge_tts
7
- import spaces
8
- from pydub import AudioSegment
9
- from threading import Thread
10
- from collections.abc import Iterator
11
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
12
 
13
  DESCRIPTION = """
14
- # QwQ Tiny with Edge TTS (MP3 Output)
15
  """
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  MAX_MAX_NEW_TOKENS = 2048
18
  DEFAULT_MAX_NEW_TOKENS = 1024
19
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
@@ -29,24 +41,14 @@ model = AutoModelForCausalLM.from_pretrained(
29
  )
30
  model.eval()
31
 
32
async def text_to_speech(text: str) -> str:
    """Convert *text* to speech with Edge TTS and return the path to an MP3 file.

    Edge TTS streams MP3-encoded audio by default (``audio-24khz-48kbitrate-mono-mp3``),
    so the bytes are written straight to a temporary ``.mp3`` file.

    Bug fixed: the previous version saved the MP3 byte stream to a
    ``.wav``-named temp file and then ran it through
    ``AudioSegment.from_wav`` — that decode fails because the data was never
    WAV. Writing directly to ``.mp3`` removes both the bug and the needless
    pydub re-encode.

    Args:
        text: The text to synthesize.

    Returns:
        Filesystem path of the generated MP3 file. The caller is
        responsible for deleting it.
    """
    # Create the destination file first so we own a unique path
    # (delete=False: the file must outlive this function for the caller).
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_mp3:
        mp3_path = tmp_mp3.name

    communicate = edge_tts.Communicate(text)
    await communicate.save(mp3_path)  # writes MP3 audio directly

    return mp3_path
49
-
50
  @spaces.GPU
51
  def generate(
52
  message: str,
@@ -56,55 +58,47 @@ def generate(
56
  top_p: float = 0.9,
57
  top_k: int = 50,
58
  repetition_penalty: float = 1.2,
59
- ) -> Iterator[str] | str:
60
-
61
- is_tts = message.strip().startswith("@tts")
62
- is_text_only = message.strip().startswith("@text")
63
-
64
- # Remove special tags
65
- if is_tts:
66
- message = message.replace("@tts", "").strip()
67
- elif is_text_only:
68
- message = message.replace("@text", "").strip()
69
 
70
  conversation = [*chat_history, {"role": "user", "content": message}]
71
- input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
72
 
 
73
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
74
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
75
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
76
-
77
  input_ids = input_ids.to(model.device)
78
 
79
  streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
80
- generate_kwargs = {
81
- "input_ids": input_ids,
82
- "streamer": streamer,
83
- "max_new_tokens": max_new_tokens,
84
- "do_sample": True,
85
- "top_p": top_p,
86
- "top_k": top_k,
87
- "temperature": temperature,
88
- "num_beams": 1,
89
- "repetition_penalty": repetition_penalty,
90
- }
91
  t = Thread(target=model.generate, kwargs=generate_kwargs)
92
  t.start()
93
 
94
  outputs = []
95
  for text in streamer:
96
  outputs.append(text)
 
97
 
98
- final_output = "".join(outputs)
99
 
100
- # If TTS requested, generate speech and return audio file
101
  if is_tts:
102
- loop = asyncio.new_event_loop()
103
- asyncio.set_event_loop(loop)
104
- audio_path = loop.run_until_complete(text_to_speech(final_output))
105
- return audio_path
106
 
107
- return final_output #
108
 
109
  demo = gr.ChatInterface(
110
  fn=generate,
@@ -118,13 +112,15 @@ demo = gr.ChatInterface(
118
  stop_btn=None,
119
  examples=[
120
  ["A train travels 60 kilometers per hour. If it travels for 5 hours, how far will it travel in total?"],
121
- ["@text What is AI?"],
122
- ["@tts Explain Newton's third law of motion."],
123
- ["@text Rewrite the following sentence in passive voice: 'The dog chased the cat.'"],
 
124
  ],
125
  cache_examples=False,
126
  type="messages",
127
  description=DESCRIPTION,
 
128
  fill_height=True,
129
  )
130
 
 
1
  import os
2
+ from collections.abc import Iterator
3
+ from threading import Thread
4
  import gradio as gr
5
+ import spaces
6
  import torch
 
 
7
  import edge_tts
8
+ import asyncio
 
 
 
9
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
10
 
11
  DESCRIPTION = """
12
+ # QwQ Tiny
13
  """
14
 
15
+ css = '''
16
+ h1 {
17
+ text-align: center;
18
+ display: block;
19
+ }
20
+
21
+ #duplicate-button {
22
+ margin: auto;
23
+ color: #fff;
24
+ background: #1565c0;
25
+ border-radius: 100vh;
26
+ }
27
+ '''
28
+
29
  MAX_MAX_NEW_TOKENS = 2048
30
  DEFAULT_MAX_NEW_TOKENS = 1024
31
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
41
  )
42
  model.eval()
43
 
44
+
45
async def text_to_speech(text: str, output_file="output.mp3"):
    """Synthesize *text* with Edge TTS (Jenny voice) and write it as MP3.

    Args:
        text: Text to convert to speech.
        output_file: Destination path for the MP3 audio.

    Returns:
        The path the audio was saved to (``output_file``).
    """
    # NOTE(review): the fixed default filename is shared by all callers, so
    # concurrent requests would overwrite each other's audio — consider a
    # per-call tempfile. Left unchanged here to preserve behavior.
    tts = edge_tts.Communicate(text, "en-US-JennyNeural")
    await tts.save(output_file)
    return output_file
50
+
51
+
 
 
 
 
 
 
 
 
 
 
52
  @spaces.GPU
53
  def generate(
54
  message: str,
 
58
  top_p: float = 0.9,
59
  top_k: int = 50,
60
  repetition_penalty: float = 1.2,
61
+ ):
62
+ """Generates chatbot response and handles TTS requests"""
63
+ is_tts = message.strip().lower().startswith("@tts")
64
+ message = message.replace("@tts", "").strip()
 
 
 
 
 
 
65
 
66
  conversation = [*chat_history, {"role": "user", "content": message}]
 
67
 
68
+ input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
69
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
70
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
71
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
 
72
  input_ids = input_ids.to(model.device)
73
 
74
  streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
75
+ generate_kwargs = dict(
76
+ {"input_ids": input_ids},
77
+ streamer=streamer,
78
+ max_new_tokens=max_new_tokens,
79
+ do_sample=True,
80
+ top_p=top_p,
81
+ top_k=top_k,
82
+ temperature=temperature,
83
+ num_beams=1,
84
+ repetition_penalty=repetition_penalty,
85
+ )
86
  t = Thread(target=model.generate, kwargs=generate_kwargs)
87
  t.start()
88
 
89
  outputs = []
90
  for text in streamer:
91
  outputs.append(text)
92
+ yield "".join(outputs)
93
 
94
+ final_response = "".join(outputs)
95
 
 
96
  if is_tts:
97
+ output_file = asyncio.run(text_to_speech(final_response))
98
+ yield output_file # Return MP3 file
99
+ else:
100
+ yield final_response # Return text response
101
 
 
102
 
103
  demo = gr.ChatInterface(
104
  fn=generate,
 
112
  stop_btn=None,
113
  examples=[
114
  ["A train travels 60 kilometers per hour. If it travels for 5 hours, how far will it travel in total?"],
115
+ ["Write a Python function to check if a number is prime."],
116
+ ["What causes rainbows to form?"],
117
+ ["Rewrite the following sentence in passive voice: 'The dog chased the cat.'"],
118
+ ["@tts What is the capital of France?"],
119
  ],
120
  cache_examples=False,
121
  type="messages",
122
  description=DESCRIPTION,
123
+ css=css,
124
  fill_height=True,
125
  )
126