amir22010 committed on
Commit
f0dd428
·
verified ·
1 Parent(s): b96fdc3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -16
app.py CHANGED
@@ -26,6 +26,27 @@ for name in list_repo_files(repo_id="balacoon/tts"):
26
  local_dir=os.getcwd(),
27
  )
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  #client
30
  client = Groq(
31
  api_key=os.getenv("GROQ_API_KEY"),
@@ -98,16 +119,8 @@ def greet(product,description):
98
  response = client.chat.completions.create(model=guard_llm, messages=messages, temperature=0)
99
  if response.choices[0].message.content != "not moderated":
100
  a_list = ["Sorry, I can't proceed for generating marketing email. Your content needs to be moderated first. Thank you!"]
101
- with locker:
102
- tts = TTS(os.path.join(os.getcwd(), tts_model_str))
103
- speakers = tts.get_speakers()
104
- if len(a_list[0]) > 1024:
105
- # truncate the text
106
- text_str = a_list[0][:1024]
107
- else:
108
- text_str = a_list[0]
109
- samples = tts.synthesize(text_str, speakers[-1])
110
- yield gr.Audio(value=(tts.get_sampling_rate(), samples)), text_str
111
  else:
112
  output = llm.create_chat_completion(
113
  messages=[
@@ -122,15 +135,14 @@ def greet(product,description):
122
  stream=True
123
  )
124
  partial_message = ""
 
125
  for chunk in output:
126
  delta = chunk['choices'][0]['delta']
127
  if 'content' in delta:
128
- with locker:
129
- tts = TTS(os.path.join(os.getcwd(), tts_model_str))
130
- speakers = tts.get_speakers()
131
- samples = tts.synthesize(delta.get('content', ''), speakers[-1])
132
- partial_message = partial_message + delta.get('content', '')
133
- yield gr.Audio(value=(tts.get_sampling_rate(), samples)), partial_message
134
 
135
  audio = gr.Audio()
136
  demo = gr.Interface(fn=greet, inputs=["text","text"], concurrency_limit=10, outputs=[audio,"text"])
 
26
  local_dir=os.getcwd(),
27
  )
28
 
29
def text_to_speech(text):
    """Synthesize ``text`` to speech and store it in a temporary WAV file.

    The text is cut to its first 1024 characters before synthesis. The file
    is created with ``delete=False``, so the caller is responsible for
    removing it.

    Args:
        text: The text to synthesize.

    Returns:
        The path of the temporary ``.wav`` file that holds the audio.
    """
    # Local import keeps this block self-contained: ``wave`` is not known to
    # be imported at the top of the file.
    import wave

    # Slicing is a no-op for short input, so no length check is needed. The
    # previous else-branch read ``a_list[0]``, a name that does not exist in
    # this function, and raised NameError for any text of <= 1024 characters.
    text_str = text[:1024]

    with locker:
        tts = TTS(os.path.join(os.getcwd(), tts_model_str))
        audio_data = tts.synthesize(text_str, "92")
        sampling_rate = tts.get_sampling_rate()

    samples = np.ascontiguousarray(audio_data)

    # The temporary file is created only after synthesis succeeded, so a
    # failing synthesis does not leave an empty file behind.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_file:
        # Write a real WAV container: dumping the bare sample buffer produces
        # a file without a RIFF header, which a WAV reader cannot parse.
        with wave.open(temp_file, "wb") as wav_file:
            # NOTE(review): assumes the engine returns mono integer PCM
            # samples (int16 expected) -- confirm against the TTS library.
            wav_file.setnchannels(1)
            wav_file.setsampwidth(samples.dtype.itemsize)
            wav_file.setframerate(sampling_rate)
            wav_file.writeframes(samples.tobytes())
    return temp_file.name
41
+
42
def combine_audio_files(audio_files):
    """Concatenate WAV files into one audio segment and delete the inputs.

    Args:
        audio_files: Iterable of paths to WAV files, in playback order. Every
            file that is visited is removed from disk.

    Returns:
        An ``AudioSegment`` holding all inputs appended in order; the empty
        segment when ``audio_files`` is empty.
    """
    combined = AudioSegment.empty()
    for audio_file in audio_files:
        try:
            combined += AudioSegment.from_wav(audio_file)
        finally:
            # The inputs are temporary files: remove each one even when it
            # cannot be decoded, so a failed decode does not leak it on disk.
            os.remove(audio_file)
    return combined
49
+
50
  #client
51
  client = Groq(
52
  api_key=os.getenv("GROQ_API_KEY"),
 
119
  response = client.chat.completions.create(model=guard_llm, messages=messages, temperature=0)
120
  if response.choices[0].message.content != "not moderated":
121
  a_list = ["Sorry, I can't proceed for generating marketing email. Your content needs to be moderated first. Thank you!"]
122
+ processed_audio = combine_audio_files([text_to_speech(a_list[0])])
123
+ yield (processed_audio.sample_rate,processed_audio) a_list[0]
 
 
 
 
 
 
 
 
124
  else:
125
  output = llm.create_chat_completion(
126
  messages=[
 
135
  stream=True
136
  )
137
  partial_message = ""
138
+ audio_list = []
139
  for chunk in output:
140
  delta = chunk['choices'][0]['delta']
141
  if 'content' in delta:
142
+ audio_list = audio_list + [text_to_speech(delta.get('content', ''))]
143
+ processed_audio = combine_audio_files(audio_list)
144
+ partial_message = partial_message + delta.get('content', '')
145
+ yield (processed_audio.sample_rate,processed_audio), partial_message
 
 
146
 
147
  audio = gr.Audio()
148
  demo = gr.Interface(fn=greet, inputs=["text","text"], concurrency_limit=10, outputs=[audio,"text"])