# CoSTA / ST / inference / demo.py
# (Hugging Face Hub viewer residue preserved as comments so the file parses:
#  uploader: bhavanishankarpullela — "Upload 360 files", commit b817ab5 verified,
#  raw / history / blame, 4.2 kB)
import requests
import gradio as gr
import soundfile as sf
import time
def speech_translation(audio, language):
    """Transcribe recorded speech and translate the transcript to English.

    Pipeline: IITM SSL-ASR service for speech recognition, then the
    Bhashini/ULCA pipeline for machine translation to English.

    Parameters
    ----------
    audio : str | None
        Filepath of the recording (Gradio ``Audio(type="filepath")``).
    language : str
        Source language; one of "telugu", "hindi", "marathi", "bengali".

    Returns
    -------
    str
        The English translation on success, otherwise a human-readable
        error message (the interface declares a single text output).
    """
    if audio is None:
        # NOTE: the original returned a 2-tuple here, which breaks a
        # single-output Gradio interface — return one string instead.
        return "No audio input provided!"

    # The ASR endpoint expects a .wav file; transcode anything else.
    if audio.endswith(".wav"):
        audio_file = audio
    else:
        wav_data, samplerate = sf.read(audio)
        sf.write("temp_audio.wav", wav_data, samplerate)
        audio_file = "temp_audio.wav"

    # --- ASR ---
    # Use a context manager so the file handle is closed (the original
    # leaked the handle opened for the multipart upload).
    with open(audio_file, "rb") as audio_fh:
        files = {
            'file': audio_fh,
            'language': (None, language),
            'vtt': (None, 'true'),
        }
        response = requests.post('https://asr.iitm.ac.in/ssl_asr/decode', files=files)
    try:
        print(response.json())
        asr_output = response.json()['transcript']
    except (ValueError, KeyError):
        # ValueError: body was not JSON; KeyError: no 'transcript' field.
        asr_output = "Error in ASR processing"
    # Strip sentence terminators (danda and period) before translation.
    asr_output = asr_output.replace("।", "").replace(".", "")
    time.sleep(1)  # brief pause between the two remote calls, as before

    # Map the UI language name to the ISO 639-1 code used by Bhashini.
    lang = {"telugu": "te", "hindi": "hi", "marathi": "mr", "bengali": "bn"}.get(language)
    if lang is None:
        # The original raised NameError on an unknown language.
        return "Unsupported language!"

    # --- Translation: step 1, resolve the pipeline/service config ---
    payload = {
        "pipelineTasks": [
            {
                "taskType": "translation",
                "config": {
                    "language": {
                        "sourceLanguage": lang,
                        "targetLanguage": "en",
                    },
                },
            }
        ],
        "pipelineRequestConfig": {
            "pipelineId": "64392f96daac500b55c543cd"
        },
    }
    # SECURITY: credentials are hard-coded (as in the original) — these
    # should be moved to environment variables / secrets.
    headers = {
        "Content-Type": "application/json",
        "userID": "2aeef589f4584eb08aa0b9c49761aeb8",
        "ulcaApiKey": "02ed10445a-66b0-4061-9030-9b0b8b37a4f1",
    }
    response = requests.post(
        'https://meity-auth.ulcacontrib.org/ulca/apis/v0/model/getModelsPipeline',
        json=payload,
        headers=headers,
    )
    if response.status_code != 200:
        # The original fell through and implicitly returned None here.
        return "Error in translation pipeline configuration"

    response_data = response.json()
    print(response_data)
    service_id = response_data["pipelineResponseConfig"][0]["config"][0]["serviceId"]

    # --- Translation: step 2, run the compute call on the ASR transcript ---
    compute_payload = {
        "pipelineTasks": [
            {
                "taskType": "translation",
                "config": {
                    "language": {
                        "sourceLanguage": lang,
                        "targetLanguage": "en",
                    },
                },
            }
        ],
        "inputData": {"input": [{"source": asr_output}]},
    }
    callback_url = response_data["pipelineInferenceAPIEndPoint"]["callbackUrl"]
    # The auth header name itself is provided by the config response.
    headers2 = {
        "Content-Type": "application/json",
        response_data["pipelineInferenceAPIEndPoint"]["inferenceApiKey"]["name"]:
            response_data["pipelineInferenceAPIEndPoint"]["inferenceApiKey"]["value"],
    }
    compute_response = requests.post(callback_url, json=compute_payload, headers=headers2)
    if compute_response.status_code != 200:
        # The original printed the status and then hit NameError on an
        # unbound translated_content — return an error message instead.
        print("status_code", compute_response.status_code)
        return "Error in translation"

    compute_response_data = compute_response.json()
    print(compute_response_data)
    translated_content = compute_response_data["pipelineResponse"][0]["output"][0]["target"]
    print("Translation successful", translated_content)
    return translated_content
# Gradio front-end: record speech, choose the source language, and read
# the English translation back as a single text output.
iface = gr.Interface(
    fn=speech_translation,
    inputs=[
        gr.Audio(type="filepath", label="Record your speech"),
        gr.Dropdown(
            ["telugu", "hindi", "marathi", "bengali"],
            label="Select Language",
        ),
    ],
    outputs=["text"],
    title="Speech Translation",
    description="Record your speech and get the English translation.",
)

iface.launch()