prithivMLmods committed on
Commit
d6b5ac6
·
verified ·
1 Parent(s): 37efd95

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -49
app.py CHANGED
@@ -8,22 +8,6 @@ import edge_tts
8
  import asyncio
9
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
10
 
11
- MAX_MAX_NEW_TOKENS = 2048
12
- DEFAULT_MAX_NEW_TOKENS = 1024
13
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
14
-
15
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
16
-
17
- model_id = "prithivMLmods/FastThink-0.5B-Tiny"
18
- tokenizer = AutoTokenizer.from_pretrained(model_id)
19
- model = AutoModelForCausalLM.from_pretrained(
20
- model_id,
21
- device_map="auto",
22
- torch_dtype=torch.bfloat16,
23
- )
24
- model.eval()
25
-
26
-
27
  DESCRIPTION = """
28
  # QwQ Edge 💬
29
  """
@@ -42,25 +26,38 @@ h1 {
42
  }
43
  '''
44
 
45
- # List of voices
46
- voices = [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  "en-US-JennyNeural", # @tts1
48
  "en-US-GuyNeural", # @tts2
49
  "en-US-AriaNeural", # @tts3
50
- "en-US-JaneNeural", # @tts4
51
- "en-US-JasonNeural", # @tts5
52
- "en-US-NancyNeural", # @tts6
53
- "en-US-TonyNeural", # @tts7
 
54
  ]
55
 
56
-
57
  async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
58
  """Convert text to speech using Edge TTS and save as MP3"""
59
  communicate = edge_tts.Communicate(text, voice)
60
  await communicate.save(output_file)
61
  return output_file
62
 
63
-
64
  @spaces.GPU
65
  def generate(
66
  message: str,
@@ -72,25 +69,16 @@ def generate(
72
  repetition_penalty: float = 1.2,
73
  ):
74
  """Generates chatbot response and handles TTS requests"""
75
- is_tts = message.strip().lower().startswith("@tts")
76
- tts_index = None
77
-
78
- if is_tts:
79
- # Extract the number after @tts
80
- tts_part = message.strip().lower().split()[0] # Get the @ttsX part
81
- if len(tts_part) > 8: # Check if it's @ttsX (e.g., @tts1, @tts2, etc.)
82
- try:
83
- tts_index = int(tts_part[8:]) - 1 # Convert to 0-based index
84
- if tts_index < 0 or tts_index >= len(voices):
85
- gr.Warning(f"Invalid TTS voice index. Using default voice.")
86
- tts_index = 0
87
- except ValueError:
88
- gr.Warning(f"Invalid TTS voice index. Using default voice.")
89
- tts_index = 0
90
- else:
91
- tts_index = 0 # Default to the first voice if no number is provided
92
-
93
- message = message.replace(tts_part, "").strip() # Remove @ttsX from the message
94
 
95
  conversation = [*chat_history, {"role": "user", "content": message}]
96
 
@@ -122,14 +110,12 @@ def generate(
122
 
123
  final_response = "".join(outputs)
124
 
125
- if is_tts:
126
- voice = voices[tts_index] # Select the voice based on the index
127
  output_file = asyncio.run(text_to_speech(final_response, voice))
128
  yield gr.Audio(output_file, autoplay=True) # Return playable audio
129
  else:
130
  yield final_response # Return text response
131
 
132
-
133
  demo = gr.ChatInterface(
134
  fn=generate,
135
  additional_inputs=[
@@ -141,12 +127,12 @@ demo = gr.ChatInterface(
141
  ],
142
  stop_btn=None,
143
  examples=[
144
- ["@tts7 Who is Nikola Tesla, and why did he die?"],
145
  ["A train travels 60 kilometers per hour. If it travels for 5 hours, how far will it travel in total?"],
146
  ["Write a Python function to check if a number is prime."],
147
- ["@tts6 What causes rainbows to form?"],
148
  ["Rewrite the following sentence in passive voice: 'The dog chased the cat.'"],
149
- ["@tts4 What is the capital of France?"],
150
  ],
151
  cache_examples=False,
152
  type="messages",
 
8
  import asyncio
9
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  DESCRIPTION = """
12
  # QwQ Edge 💬
13
  """
 
26
  }
27
  '''
28
 
29
+ MAX_MAX_NEW_TOKENS = 2048
30
+ DEFAULT_MAX_NEW_TOKENS = 1024
31
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
32
+
33
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
34
+
35
+ model_id = "prithivMLmods/FastThink-0.5B-Tiny"
36
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
37
+ model = AutoModelForCausalLM.from_pretrained(
38
+ model_id,
39
+ device_map="auto",
40
+ torch_dtype=torch.bfloat16,
41
+ )
42
+ model.eval()
43
+
44
+ TTS_VOICES = [
45
  "en-US-JennyNeural", # @tts1
46
  "en-US-GuyNeural", # @tts2
47
  "en-US-AriaNeural", # @tts3
48
+ "en-US-DavisNeural", # @tts4
49
+ "en-US-JaneNeural", # @tts5
50
+ "en-US-JasonNeural", # @tts6
51
+ "en-US-NancyNeural", # @tts7
52
+ "en-US-TonyNeural", # @tts8
53
  ]
54
 
 
55
  async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
56
  """Convert text to speech using Edge TTS and save as MP3"""
57
  communicate = edge_tts.Communicate(text, voice)
58
  await communicate.save(output_file)
59
  return output_file
60
 
 
61
  @spaces.GPU
62
  def generate(
63
  message: str,
 
69
  repetition_penalty: float = 1.2,
70
  ):
71
  """Generates chatbot response and handles TTS requests"""
72
+ tts_prefix = "@tts"
73
+ is_tts = any(message.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 9))
74
+ voice_index = next((i for i in range(1, 9) if message.strip().lower().startswith(f"{tts_prefix}{i}")), None)
75
+
76
+ if is_tts and voice_index:
77
+ voice = TTS_VOICES[voice_index - 1]
78
+ message = message.replace(f"{tts_prefix}{voice_index}", "").strip()
79
+ else:
80
+ voice = None
81
+ message = message.replace(tts_prefix, "").strip()
 
 
 
 
 
 
 
 
 
82
 
83
  conversation = [*chat_history, {"role": "user", "content": message}]
84
 
 
110
 
111
  final_response = "".join(outputs)
112
 
113
+ if is_tts and voice:
 
114
  output_file = asyncio.run(text_to_speech(final_response, voice))
115
  yield gr.Audio(output_file, autoplay=True) # Return playable audio
116
  else:
117
  yield final_response # Return text response
118
 
 
119
  demo = gr.ChatInterface(
120
  fn=generate,
121
  additional_inputs=[
 
127
  ],
128
  stop_btn=None,
129
  examples=[
130
+ ["@tts1 Who is Nikola Tesla, and why did he die?"],
131
  ["A train travels 60 kilometers per hour. If it travels for 5 hours, how far will it travel in total?"],
132
  ["Write a Python function to check if a number is prime."],
133
+ ["@tts2 What causes rainbows to form?"],
134
  ["Rewrite the following sentence in passive voice: 'The dog chased the cat.'"],
135
+ ["@tts5 What is the capital of France?"],
136
  ],
137
  cache_examples=False,
138
  type="messages",