jobsm committed
Commit 0c0dfca · verified · 1 Parent(s): 42b3635
Files changed (1)
  1. app.py +47 -136
app.py CHANGED
@@ -13,61 +13,54 @@ import tempfile
 # Load models
 whisper_model = whisper.load_model("base")
 sentiment_analysis = pipeline(
-    "sentiment-analysis", framework="pt", model="SamLowe/roberta-base-go_emotions")
-
+    "sentiment-analysis", framework="pt", model="SamLowe/roberta-base-go_emotions"
+)
 
 def load_sign_language_model():
-    return tf.keras.models.load_model('best_model.h5')
-
+    return tf.keras.models.load_model("best_model.h5")
 
 sign_language_model = load_sign_language_model()
 
-# Get all available voices
-
-
+# Get available voices asynchronously
 async def get_voices():
     voices = await edge_tts.list_voices()
-    return {f"{v['ShortName']} - {v['Locale']} ({v['Gender']})": v['ShortName'] for v in voices}
+    return {
+        f"{v['ShortName']} - {v['Locale']} ({v['Gender']})": v["ShortName"]
+        for v in voices
+    }
 
 # Audio-based functions
-
-
 def analyze_sentiment(text):
     results = sentiment_analysis(text)
-    sentiment_results = {result['label']: result['score']
-                         for result in results}
+    sentiment_results = {result["label"]: result["score"] for result in results}
     return sentiment_results
 
-
 def display_sentiment_results(sentiment_results, option):
     sentiment_text = ""
     for sentiment, score in sentiment_results.items():
         if option == "Sentiment Only":
             sentiment_text += f"{sentiment}\n"
         elif option == "Sentiment + Score":
-            sentiment_text += f"{sentiment}: {score}\n"
+            sentiment_text += f"{sentiment}: {score:.2f}\n"
     return sentiment_text
 
-
 def search_text(text, api_key):
     api_endpoint = "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:generateContent"
     headers = {"Content-Type": "application/json"}
     payload = {"contents": [{"parts": [{"text": text}]}]}
 
     try:
-        response = requests.post(
-            api_endpoint, headers=headers, json=payload, params={"key": api_key})
+        response = requests.post(api_endpoint, headers=headers, json=payload, params={"key": api_key})
         response.raise_for_status()
         response_json = response.json()
-        if 'candidates' in response_json and len(response_json['candidates']) > 0:
-            content_parts = response_json['candidates'][0]['content']['parts']
-            if len(content_parts) > 0:
-                return content_parts[0]['text'].strip()
+        if "candidates" in response_json and response_json["candidates"]:
+            content_parts = response_json["candidates"][0]["content"]["parts"]
+            if content_parts:
+                return content_parts[0]["text"].strip()
         return "No relevant content found."
     except requests.exceptions.RequestException as e:
         return {"error": str(e)}
 
-
 async def text_to_speech(text, voice, rate, pitch):
     if not text.strip():
         return None, gr.Warning("Please enter text to convert.")
@@ -77,20 +70,18 @@ async def text_to_speech(text, voice, rate, pitch):
     voice_short_name = voice.split(" - ")[0]
     rate_str = f"{rate:+d}%"
     pitch_str = f"{pitch:+d}Hz"
-    communicate = edge_tts.Communicate(
-        text, voice_short_name, rate=rate_str, pitch=pitch_str)
+    communicate = edge_tts.Communicate(text, voice_short_name, rate=rate_str, pitch=pitch_str)
+
     with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
         tmp_path = tmp_file.name
         await communicate.save(tmp_path)
+
     return tmp_path, None
 
-
 async def tts_interface(text, voice, rate, pitch):
-    audio, warning = await text_to_speech(text, voice, rate, pitch)
-    return audio, warning
-
+    return await text_to_speech(text, voice, rate, pitch)
 
-def inference_audio(audio, sentiment_option, api_key, tts_voice, tts_rate, tts_pitch):
+async def inference_audio(audio, sentiment_option, api_key, tts_voice, tts_rate, tts_pitch):
     if audio is None:
         return "No audio file provided.", "", "", "", None
 
@@ -105,49 +96,15 @@ def inference_audio(audio, sentiment_option, api_key, tts_voice, tts_rate, tts_p
     result = whisper.decode(whisper_model, mel, options)
 
     sentiment_results = analyze_sentiment(result.text)
-    sentiment_output = display_sentiment_results(
-        sentiment_results, sentiment_option)
+    sentiment_output = display_sentiment_results(sentiment_results, sentiment_option)
 
     search_results = search_text(result.text, api_key)
 
-    # Generate audio for explanation
-    explanation_audio, _ = asyncio.run(tts_interface(
-        search_results, tts_voice, tts_rate, tts_pitch))
+    explanation_audio, _ = await tts_interface(search_results, tts_voice, tts_rate, tts_pitch)
 
     return lang.upper(), result.text, sentiment_output, search_results, explanation_audio
 
-# Image-based functions
-
-
-def get_explanation(letter, api_key):
-    url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:generateContent"
-    headers = {"Content-Type": "application/json"}
-    data = {
-        "contents": [
-            {"parts": [{"text": f"Explain how the American Sign Language letter '{letter}' is shown, its significance, and why it is represented this way."}]}
-        ]
-    }
-    params = {"key": api_key}
-
-    try:
-        response = requests.post(url, headers=headers,
-                                 json=data, params=params)
-        response.raise_for_status()
-        response_data = response.json()
-        explanation = response_data.get("contents", [{}])[0].get("parts", [{}])[
-            0].get("text", "No explanation available.")
-        # Remove unnecessary symbols and formatting
-        explanation = explanation.replace(
-            "*", "").replace("#", "").replace("$", "").replace("\n", " ").strip()
-        # Remove additional special characters, if needed
-        explanation = explanation.translate(
-            str.maketrans('', '', string.punctuation))
-        return explanation
-    except requests.RequestException as e:
-        return f"Error fetching explanation: {e}"
-
-
-def classify_sign_language(image, api_key):
+async def classify_sign_language(image, api_key):
     img = np.array(image)
     gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
     gray_img = cv2.resize(gray_img, (28, 28))
@@ -160,22 +117,16 @@ def classify_sign_language(image, api_key):
     output = output + 1 if output > 7 else output
     pred = uppercase_alphabet[output]
 
-    explanation = get_explanation(pred, api_key)
-
-    return pred, explanation
-
-# Gradio interface
+    explanation = search_text(f"Explain the American Sign Language letter '{pred}'.", api_key)
+    explanation_audio, _ = await tts_interface(explanation, None, 0, 0)
 
+    return pred, explanation, explanation_audio
 
-def process_input(input_type, audio=None, image=None, sentiment_option=None, api_key=None, tts_voice=None, tts_rate=0, tts_pitch=0):
+async def process_input(input_type, audio=None, image=None, sentiment_option=None, api_key=None, tts_voice=None, tts_rate=0, tts_pitch=0):
     if input_type == "Audio":
-        return inference_audio(audio, sentiment_option, api_key, tts_voice, tts_rate, tts_pitch)
+        return await inference_audio(audio, sentiment_option, api_key, tts_voice, tts_rate, tts_pitch)
     elif input_type == "Image":
-        pred, explanation = classify_sign_language(image, api_key)
-        explanation_audio, _ = asyncio.run(tts_interface(
-            explanation, tts_voice, tts_rate, tts_pitch))
-        return "N/A", pred, "N/A", explanation, explanation_audio
-
+        return await classify_sign_language(image, api_key)
 
 async def main():
     voices = await get_voices()
@@ -183,74 +134,34 @@ async def main():
     with gr.Blocks() as demo:
         gr.Markdown("# Speak & Sign AI Assistant")
 
-        # Layout: Split user input and bot response sides
         with gr.Row():
-            # User Input Side
             with gr.Column():
                 gr.Markdown("### User Input")
-                # Input selection
-                input_type = gr.Radio(label="Choose Input Type", choices=[
-                                      "Audio", "Image"], value="Audio")
-
-                # API key input
-                api_key_input = gr.Textbox(
-                    label="API Key", placeholder="Your API key here", type="password")
-
-                # Audio input
-                audio_input = gr.Audio(
-                    label="Upload or Record Audio", type="filepath", visible=True)
-                sentiment_option = gr.Radio(choices=[
-                    "Sentiment Only", "Sentiment + Score"], label="Sentiment Output", value="Sentiment Only", visible=True)
-
-                # Image input
-                image_input = gr.Image(
-                    label="Upload Image", type="pil", visible=False)
-
-                # TTS settings for explanation
-                tts_voice = gr.Dropdown(label="Select Voice", choices=[
-                ] + list(voices.keys()), value="")
-                tts_rate = gr.Slider(
-                    minimum=-50, maximum=50, value=0, label="Speech Rate Adjustment (%)", step=1)
-                tts_pitch = gr.Slider(
-                    minimum=-20, maximum=20, value=0, label="Pitch Adjustment (Hz)", step=1)
-
-                # Change input visibility based on selection
-                def update_visibility(input_type):
-                    if input_type == "Audio":
-                        return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)
-                    else:
-                        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
+                input_type = gr.Radio(label="Choose Input Type", choices=["Audio", "Image"], value="Audio")
+                api_key_input = gr.Textbox(label="API Key", placeholder="Your API key here", type="password")
+                audio_input = gr.Audio(label="Upload or Record Audio", type="filepath")
+                sentiment_option = gr.Radio(choices=["Sentiment Only", "Sentiment + Score"], label="Sentiment Output", value="Sentiment Only")
+                image_input = gr.Image(label="Upload Image", type="pil", visible=False)
+                tts_voice = gr.Dropdown(label="Select Voice", choices=[""] + list(voices.keys()), value="")
+                tts_rate = gr.Slider(minimum=-50, maximum=50, value=0, label="Speech Rate Adjustment (%)", step=1)
+                tts_pitch = gr.Slider(minimum=-20, maximum=20, value=0, label="Pitch Adjustment (Hz)", step=1)
 
-                input_type.change(update_visibility, inputs=input_type, outputs=[
-                                  audio_input, sentiment_option, image_input])
+                def update_visibility(input_type):
+                    return gr.update(visible=input_type == "Audio"), gr.update(visible=input_type == "Image")
 
-                # Submit button
+                input_type.change(update_visibility, inputs=[input_type], outputs=[audio_input, image_input])
                 submit_btn = gr.Button("Submit")
 
-            # Bot Response Side
             with gr.Column():
                 gr.Markdown("### Bot Response")
+                lang_str = gr.Textbox(label="Detected Language", interactive=False)
+                text = gr.Textbox(label="Transcription or Prediction", interactive=False)
+                sentiment_output = gr.Textbox(label="Sentiment Analysis Results", interactive=False)
+                search_results = gr.Textbox(label="Explanation", interactive=False)
+                audio_output = gr.Audio(label="Generated Explanation Audio", type="filepath", interactive=False)
 
-                lang_str = gr.Textbox(
-                    label="Detected Language", interactive=False)
-                text = gr.Textbox(
-                    label="Transcription or Prediction", interactive=False)
-                sentiment_output = gr.Textbox(
-                    label="Sentiment Analysis Results", interactive=False)
-                search_results = gr.Textbox(
-                    label="Explanation or Search Results", interactive=False)
-                audio_output = gr.Audio(
-                    label="Generated Explanation Audio", type="filepath", interactive=False)
-
-                # Submit button action
-                submit_btn.click(
-                    process_input,
-                    inputs=[input_type, audio_input, image_input, sentiment_option,
-                            api_key_input, tts_voice, tts_rate, tts_pitch],
-                    outputs=[lang_str, text, sentiment_output,
-                             search_results, audio_output]
-                )
+                submit_btn.click(process_input, inputs=[input_type, audio_input, image_input, sentiment_option, api_key_input, tts_voice, tts_rate, tts_pitch], outputs=[lang_str, text, sentiment_output, search_results, audio_output])
 
     demo.launch(share=True)
 
-asyncio.run(main())
+asyncio.create_task(main())
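
For context, a minimal sketch (not part of this commit) of the pattern the new code relies on: Gradio accepts coroutine functions as event handlers, so an async callback can await work such as edge_tts synthesis directly instead of wrapping it in asyncio.run(). The handler and component names below are illustrative only.

import asyncio
import gradio as gr

async def slow_echo(message):  # illustrative async handler, not from app.py
    await asyncio.sleep(1)     # stands in for awaited work such as edge_tts synthesis
    return f"echo: {message}"

with gr.Blocks() as demo:
    box_in = gr.Textbox(label="Input")
    box_out = gr.Textbox(label="Output")
    # An async function can be wired to an event exactly like a sync one;
    # Gradio awaits it on its own event loop.
    gr.Button("Run").click(slow_echo, inputs=box_in, outputs=box_out)

if __name__ == "__main__":
    demo.launch()

As a design note, asyncio.create_task(main()) schedules main() on an already-running event loop, while asyncio.run(main()) starts a fresh one; which entry point applies depends on the runtime that imports app.py.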