prithivMLmods committed on
Commit
ea9ba29
·
verified ·
1 Parent(s): bce38cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -38
app.py CHANGED
@@ -4,6 +4,8 @@ from threading import Thread
4
  import gradio as gr
5
  import spaces
6
  import torch
 
 
7
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
8
 
9
  DESCRIPTION = """
@@ -40,6 +42,14 @@ model = AutoModelForCausalLM.from_pretrained(
40
  model.eval()
41
 
42
 
 
 
 
 
 
 
 
 
43
  @spaces.GPU
44
  def generate(
45
  message: str,
@@ -49,7 +59,11 @@ def generate(
49
  top_p: float = 0.9,
50
  top_k: int = 50,
51
  repetition_penalty: float = 1.2,
52
- ) -> Iterator[str]:
 
 
 
 
53
  conversation = [*chat_history, {"role": "user", "content": message}]
54
 
55
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
@@ -78,52 +92,31 @@ def generate(
78
  outputs.append(text)
79
  yield "".join(outputs)
80
 
 
 
 
 
 
 
 
 
81
 
82
  demo = gr.ChatInterface(
83
  fn=generate,
84
  additional_inputs=[
85
- gr.Slider(
86
- label="Max new tokens",
87
- minimum=1,
88
- maximum=MAX_MAX_NEW_TOKENS,
89
- step=1,
90
- value=DEFAULT_MAX_NEW_TOKENS,
91
- ),
92
- gr.Slider(
93
- label="Temperature",
94
- minimum=0.1,
95
- maximum=4.0,
96
- step=0.1,
97
- value=0.6,
98
- ),
99
- gr.Slider(
100
- label="Top-p (nucleus sampling)",
101
- minimum=0.05,
102
- maximum=1.0,
103
- step=0.05,
104
- value=0.9,
105
- ),
106
- gr.Slider(
107
- label="Top-k",
108
- minimum=1,
109
- maximum=1000,
110
- step=1,
111
- value=50,
112
- ),
113
- gr.Slider(
114
- label="Repetition penalty",
115
- minimum=1.0,
116
- maximum=2.0,
117
- step=0.05,
118
- value=1.2,
119
- ),
120
  ],
121
  stop_btn=None,
122
  examples=[
123
  ["A train travels 60 kilometers per hour. If it travels for 5 hours, how far will it travel in total?"],
124
- ["Write a Python function to check if a number is prime. "],
125
  ["What causes rainbows to form?"],
126
  ["Rewrite the following sentence in passive voice: 'The dog chased the cat.'"],
 
127
  ],
128
  cache_examples=False,
129
  type="messages",
@@ -132,6 +125,5 @@ demo = gr.ChatInterface(
132
  fill_height=True,
133
  )
134
 
135
-
136
  if __name__ == "__main__":
137
  demo.queue(max_size=20).launch()
 
4
  import gradio as gr
5
  import spaces
6
  import torch
7
+ import edge_tts
8
+ import asyncio
9
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
10
 
11
  DESCRIPTION = """
 
42
  model.eval()
43
 
44
 
45
+ async def text_to_speech(text: str, output_file="output.mp3"):
46
+ """Convert text to speech using Edge TTS and save as MP3"""
47
+ voice = "en-US-JennyNeural" # Change this to your preferred voice
48
+ communicate = edge_tts.Communicate(text, voice)
49
+ await communicate.save(output_file)
50
+ return output_file
51
+
52
+
53
  @spaces.GPU
54
  def generate(
55
  message: str,
 
59
  top_p: float = 0.9,
60
  top_k: int = 50,
61
  repetition_penalty: float = 1.2,
62
+ ):
63
+ """Generates chatbot response and handles TTS requests"""
64
+ is_tts = message.strip().lower().startswith("@tts")
65
+ message = message.replace("@tts", "").strip()
66
+
67
  conversation = [*chat_history, {"role": "user", "content": message}]
68
 
69
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
 
92
  outputs.append(text)
93
  yield "".join(outputs)
94
 
95
+ final_response = "".join(outputs)
96
+
97
+ if is_tts:
98
+ output_file = asyncio.run(text_to_speech(final_response))
99
+ return output_file # Return MP3 file
100
+
101
+ return final_response # Return text response
102
+
103
 
104
  demo = gr.ChatInterface(
105
  fn=generate,
106
  additional_inputs=[
107
+ gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
108
+ gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
109
+ gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
110
+ gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
111
+ gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  ],
113
  stop_btn=None,
114
  examples=[
115
  ["A train travels 60 kilometers per hour. If it travels for 5 hours, how far will it travel in total?"],
116
+ ["Write a Python function to check if a number is prime."],
117
  ["What causes rainbows to form?"],
118
  ["Rewrite the following sentence in passive voice: 'The dog chased the cat.'"],
119
+ ["@tts What is the capital of France?"],
120
  ],
121
  cache_examples=False,
122
  type="messages",
 
125
  fill_height=True,
126
  )
127
 
 
128
  if __name__ == "__main__":
129
  demo.queue(max_size=20).launch()