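# Page residue from the hosting site, kept as a comment so the file parses:
# uploader avatar "aiegoo's picture"; commit "Update app.py (#10)"; revision 83ce982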
import gradio as gr
import base64
import requests
import secrets
import os
import argparse
from io import BytesIO
from pydub import AudioSegment
# Base URL of the chat backend when it runs on this machine.
LOCAL_API_ENDPOINT: str = "http://localhost:5000"
# Base URL of the public (non-local) chat backend.
PUBLIC_API_ENDPOINT: str = "http://45.76.97.217:5000"
# Endpoint every request in this file is sent to; the script entry point
# replaces it with LOCAL_API_ENDPOINT when --local is passed.
API_ENDPOINT: str = PUBLIC_API_ENDPOINT
# Id of the active chat session; assigned by main() and by the
# session-switching handlers.
session_id: str = ""
# Module-level copy of the chatbot history; load_chat_history() hands it back
# to the UI when it is longer than what the UI currently holds.
chat_history: list = []
# Custom stylesheet passed to gr.Blocks: turns the audio widget into a compact
# microphone button and rebalances the input row on narrow screens.
#
# The two margin rules previously read "-30px; !important;" -- the stray
# semicolon ended the declaration early, so "!important" was a parse error and
# the margins were applied without the intended priority.
#
# NOTE(review): "flex-grow" takes a unitless number, so the "70%" / "30%"
# values in the media query are ignored by browsers. Left unchanged because
# the intended ratio is not evident from this file -- confirm and fix.
css = """
#audio_input {
margin-top: -30px !important;
margin-left: -15px !important;
width: 100% !important;
border-style: None !important;
background-color: transparent !important;
}
#audio_input button {
height:50px !important;
font-size: 0px !important;
width: 110% !important;
}
#audio_input button:after {
content: '🎤' !important;
font-size: 16px !important;
}
audio {
min-width: 200px !important;
}
@media (max-width : 480px) {
#audio_input {
width: 120% !important;
}
#audio_input button:after {
content: '' !important;
}
#txt_input_container {
flex-grow: 70% !important;
}
#audio_input_container {
flex-grow: 30% !important;
}
}
"""
# Client-side snippet run after each reply: marks the most recent <audio>
# element as autoplay. It bails out when the page has no <audio> element yet
# (e.g. the preceding step failed before producing one); without the guard
# `last_audio` would be undefined and setAttribute would throw.
js_audio_auto_play = """
() => {
    // select last audio element
    const audio = document.getElementsByTagName('audio');
    if (audio.length === 0) {
        return;
    }
    const last_audio = audio[audio.length - 1];
    // set autoplay attribute
    last_audio.setAttribute('autoplay', true);
}
"""
def create_chat_session():
    """Create a chat session on the backend and prepare its audio folder.

    Returns:
        The session id taken from the ``id`` field of the ``/create`` response.

    Raises:
        Exception: If the backend does not answer with HTTP 201.
    """
    r = requests.post(API_ENDPOINT + "/create")
    if r.status_code != 201:
        raise Exception("Failed to create chat session")
    # Named differently from the module-level ``session_id`` so it is obvious
    # this function does not touch the global.
    new_session_id = r.json()["id"]
    # create temp audio folder; exist_ok avoids a FileExistsError when the
    # folder is left over from an earlier run
    os.makedirs(f"./temp_audio/{new_session_id}", exist_ok=True)
    return new_session_id
def create_new_or_change_session(history, id):
    """Start a fresh session or switch to an existing one.

    Args:
        history: Current chatbot history (replaced in both branches).
        id: Session id typed by the user; an empty string means "create new".

    Returns:
        A ``(history, textbox_update)`` pair: the history to display and an
        update that clears and disables the session textbox.
    """
    global session_id
    global chat_history
    if id != "":
        # Existing session requested: load its history from the backend.
        history, _ = change_session(history, id)
    else:
        # No id given: open a brand-new session with an empty conversation.
        session_id = create_chat_session()
        history = []
    chat_history = history
    textbox_update = gr.update(value="", interactive=False)
    return history, textbox_update
def add_text(history, text):
    """Append the user's text as a new, not-yet-answered chat turn.

    Returns:
        The extended history and an update that clears and disables the
        text input.
    """
    pending_turn = (text, None)
    extended = history + [pending_turn]
    return extended, gr.update(value="", interactive=False)
def add_audio(history, audio):
    """Store a recorded clip, transcribe it, and add both to the history.

    Args:
        history: Current chatbot history.
        audio: Mapping whose ``'data'`` entry holds the base64 recording;
            anything up to the last comma (e.g. a data-URL prefix) is dropped.

    Returns:
        The history extended with an audio turn and a text turn containing
        the transcription, plus an update for the audio input component.

    Raises:
        Exception: If ``/transcribe`` does not answer with HTTP 200.
    """
    raw_bytes = base64.b64decode(audio['data'].split(',')[-1].encode('utf-8'))

    # Encode into a *fresh* buffer. Exporting back into the buffer that was
    # just decoded rewinds and overwrites it without truncating, so when the
    # MP3 is shorter than the recording the tail of the original bytes would
    # be left behind and sent along with the MP3.
    mp3_buffer = BytesIO()
    AudioSegment.from_file(BytesIO(raw_bytes)).export(mp3_buffer, format="mp3")
    mp3_bytes = mp3_buffer.getvalue()

    # save audio file temporary to disk; the same MP3 bytes are written
    # directly instead of being decoded and re-encoded a second time
    audio_id = secrets.token_hex(8)
    audio_dir = f"temp_audio/{session_id}"
    # The folder is missing when the active session was not created by this
    # process (e.g. after switching to an existing session id).
    os.makedirs(audio_dir, exist_ok=True)
    audio_path = f"{audio_dir}/audio_input_{audio_id}.mp3"
    with open(audio_path, "wb") as mp3_file:
        mp3_file.write(mp3_bytes)
    history = history + [((audio_path,), None)]

    response = requests.post(
        API_ENDPOINT + "/transcribe",
        files={'audio': mp3_bytes}
    )
    if response.status_code != 200:
        raise Exception(response.text)
    text = response.json()['text']
    history = history + [(text, None)]
    return history, gr.update(value="", interactive=False)
def reset_chat_session(history):
    """Ask the backend to reset the active session and clear local history.

    Args:
        history: Current chatbot history; it is discarded.

    Returns:
        A new empty history list.

    Raises:
        Exception: If the reset endpoint does not answer with HTTP 200.
    """
    global chat_history
    reset_url = API_ENDPOINT + f"/reset/{session_id}"
    response = requests.post(reset_url)
    if response.status_code != 200:
        raise Exception(response.text)
    # Both the module-level copy and the value shown in the UI start over.
    chat_history = []
    return []
def bot(history):
    """Send the latest user message to the backend and append its reply.

    The message is the first element of the last history entry when that is a
    string; otherwise the entry before it is used.

    Args:
        history: Current chatbot history; must contain the pending user turn.

    Returns:
        The history extended with an audio turn and a text turn holding the
        reply. The module-level ``chat_history`` is set to a copy of it.

    Raises:
        Exception: If the send endpoint does not answer with HTTP 200.
    """
    global chat_history
    if isinstance(history[-1][0], str):
        message = history[-1][0]
    else:
        message = history[-2][0]
    response = requests.post(
        API_ENDPOINT + f"/send/text/{session_id}",
        headers={'Content-type': 'application/json'},
        json={
            'message': message,
            'role': 'user'
        }
    )
    if response.status_code != 200:
        raise Exception(f"Failed to send message, {response.text}")
    payload = response.json()
    text, audio = payload['text'], payload['audio']

    audio_file = BytesIO(base64.b64decode(audio.encode('utf-8')))
    audio_id = secrets.token_hex(8)
    audio_dir = f"temp_audio/{session_id}"
    # The folder is missing when the active session was not created by this
    # process (e.g. after switching to an existing session id); without this
    # the export below fails with FileNotFoundError.
    os.makedirs(audio_dir, exist_ok=True)
    audio_path = f"{audio_dir}/audio_input_{audio_id}.mp3"
    AudioSegment.from_file(audio_file).export(audio_path, format="mp3")

    history = history + [(None, (audio_path,))]
    history = history + [(None, text)]
    chat_history = history.copy()
    return history
def _export_session_audio(session_dir, encoded_audio):
    """Decode base64 audio, save it as an MP3 in *session_dir*, return the path."""
    audio_file = BytesIO(base64.b64decode(encoded_audio.encode('utf-8')))
    audio_id = secrets.token_hex(8)
    audio_path = f"{session_dir}/audio_input_{audio_id}.mp3"
    AudioSegment.from_file(audio_file).export(audio_path, format="mp3")
    return audio_path


def change_session(history, id):
    """Switch to session *id* and rebuild the chat history from the backend.

    Args:
        history: Current chatbot history; it is discarded and rebuilt.
        id: Id of the session to load.

    Returns:
        A ``(history, textbox_update)`` pair: the rebuilt history and an
        update that clears and disables the session textbox.

    Raises:
        Exception: If the backend does not answer with HTTP 200, or if the
            returned conversation cannot be converted (the original error is
            chained as the cause).
    """
    global session_id
    global chat_history
    response = requests.get(
        API_ENDPOINT + f"/{id}"
    )
    if response.status_code != 200:
        raise Exception(response.text)
    response = response.json()
    session_id = id

    # A session that was not created by this process has no audio folder yet;
    # create it so the exports below (and later ones) do not fail.
    # NOTE(review): ``id`` is typed by the user and ends up in a filesystem
    # path; it only gets here after the backend accepted it -- consider
    # validating its format.
    session_dir = f"temp_audio/{id}"
    os.makedirs(session_dir, exist_ok=True)

    history = []
    try:
        for chat in response:
            role = chat['role']
            if role == 'user':
                if chat['audio'] != "":
                    history.append(((_export_session_audio(session_dir, chat['audio']),), None))
                history.append((chat['message'], None))
            elif role == 'assistant':
                # Skip empty audio here as well, matching the user branch;
                # decoding an empty payload would abort the whole load.
                if chat['audio'] != "":
                    history.append((None, (_export_session_audio(session_dir, chat['audio']),)))
                history.append((None, chat['message']))
            else:
                raise Exception("Invalid chat role")
    except Exception as e:
        # Keep the original error as the cause so the real failure is visible.
        raise Exception(f"Response: {response}") from e
    chat_history = history.copy()
    print(f"len(chat_history): {len(chat_history)}\nlen(history): {len(history)}\nlen(response): {len(response)}")
    return history, gr.update(value="", interactive=False)
def load_chat_history(history):
    """Return the stored history when it is longer than the one passed in.

    Args:
        history: History currently held by the chatbot component.

    Returns:
        The module-level ``chat_history`` if it has more entries than
        *history*, otherwise *history* itself.
    """
    global chat_history
    stored_is_longer = len(chat_history) > len(history)
    return chat_history if stored_is_longer else history
def main():
    """Create an initial chat session and build the Gradio interface.

    Returns:
        The assembled ``gr.Blocks`` app; launching it is left to the caller.
    """
    global session_id
    global chat_history
    session_id = create_chat_session()
    chat_history = []

    def _unlock_text():
        # Re-enable the text box once a reply has been handled.
        return gr.update(interactive=True)

    def _unlock_audio():
        # Re-enable the recorder and drop the clip that was just sent.
        return gr.update(interactive=True, value=None)

    def _unlock_session_box():
        # session_id is read when the event fires, so the placeholder shows
        # whichever session is active at that moment.
        return gr.update(interactive=True, placeholder=session_id)

    with gr.Blocks(css=css) as demo:
        with gr.Row():
            # change session id
            change_session_txt = gr.Textbox(
                show_label=False,
                placeholder=session_id,
            ).style(container=False)
        with gr.Row():
            # button to create new or change session id
            change_session_button = gr.Button(
                "Create new or change session", type='success', size="sm"
            ).style(container=False)

        chatbot = gr.Chatbot([], elem_id="chatbot").style(height=750)
        demo.load(load_chat_history, [chatbot], [chatbot], queue=False)

        with gr.Row():
            with gr.Column(scale=0.85, min_width=0, elem_id="txt_input_container"):
                txt = gr.Textbox(
                    show_label=False,
                    placeholder="Enter text and press enter, or record audio",
                    elem_id="txt_input"
                ).style(container=False)
            with gr.Column(scale=0.15, min_width=0, elem_id="audio_input_container"):
                audio = gr.Audio(
                    source="microphone", type="numpy", show_label=False, format="mp3", min_width=0, container=False, elem_id="audio_input"
                )
        with gr.Row():
            reset_button = gr.Button(
                "Reset Chat Session", type='stop', size="sm"
            ).style(container=False)

        # Text path: record the turn, get the reply, unlock the box, autoplay.
        text_added = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False)
        txt_msg = text_added.then(bot, chatbot, chatbot)
        text_unlocked = txt_msg.then(_unlock_text, None, [txt], queue=False)
        text_unlocked.then(None, [], [], queue=False, _js=js_audio_auto_play)

        # Audio path: same chain, fed by the recorder's raw (unprocessed) payload.
        audio_added = audio.change(
            add_audio, [chatbot, audio], [chatbot, audio],
            queue=False, preprocess=False, postprocess=False,
        )
        audio_msg = audio_added.then(bot, chatbot, chatbot)
        audio_unlocked = audio_msg.then(_unlock_audio, None, [audio], queue=False)
        audio_unlocked.then(None, [], [], queue=False, _js=js_audio_auto_play)

        reset_button.click(reset_chat_session, [chatbot], [chatbot], queue=False)

        # Session switching via Enter in the textbox ...
        chgn_msg = change_session_txt.submit(
            change_session, [chatbot, change_session_txt], [chatbot, change_session_txt], queue=False
        )
        chgn_msg.then(_unlock_session_box, None, [change_session_txt], queue=False)

        # ... or via the button, which also creates a session when the box is empty.
        create_new_or_change_session_btn = change_session_button.click(
            create_new_or_change_session, [chatbot, change_session_txt], [chatbot, change_session_txt], queue=False
        )
        create_new_or_change_session_btn.then(_unlock_session_box, None, [change_session_txt], queue=False)
    return demo
if __name__ == "__main__":
    # Command line: --local points the app at the backend on this machine
    # instead of the public one.
    cli = argparse.ArgumentParser()
    cli.add_argument("--local", action="store_true", help="Use local API endpoint")
    if cli.parse_args().local:
        API_ENDPOINT = LOCAL_API_ENDPOINT

    # Build the UI, then serve it on all interfaces with errors shown in the UI.
    demo = main()
    demo.launch(show_error=True, server_name="0.0.0.0")