import gradio as gr
import torch
from functions import predict, return_input
from unet import UNet
from custom_scaler import min_max_scaler
# Check for CUDA availability
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load model
model = UNet().to(device)  # Move the model to the selected device
checkpoint = torch.load("model.pth", map_location=device, weights_only=True)
model.load_state_dict(checkpoint["model_state_dict"])
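model.eval()  # Switch to evaluation mode so dropout/batch-norm layers run deterministically at inference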
scaler = min_max_scaler()
scaler.fit()
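# The custom min_max_scaler appears to derive its fitting statistics
# internally, since fit() is called here without any training data.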
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown(
            """
            # Speech enhancement demonstration
            Hello!
            This is a demo of a speech enhancement model trained to reduce background noise and preserve the intelligibility of a single speaker.
            Feel free to upload your own audio file or try one of our example files to see how it works!
            """
        )
    with gr.Row():
        with gr.Column():
            audio_path = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Upload your audio here", format="wav")
        with gr.Column():
            enhanced_audio = gr.Audio(sources=None, label="Enhanced audio will appear here", format="wav")
    with gr.Row():
        files = gr.FileExplorer(label="Example files", file_count="single", root_dir="examples", interactive=True)
        # Selecting an example loads it into the input player and clears any previous output
        files.change(fn=return_input, inputs=files, outputs=audio_path)
        files.change(fn=lambda: None, inputs=None, outputs=enhanced_audio)
    with gr.Row():
        submit_audio = gr.Button(value="Submit audio for enhancement")
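        # predict (from functions) is assumed to take the input file path, the
        # model, and the scaler, and to return the enhanced audio for playback;
        # trigger_mode="once" ignores repeat clicks while a request is pending.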
        submit_audio.click(fn=lambda x: predict(x, model, scaler), inputs=audio_path, outputs=enhanced_audio, trigger_mode="once")
demo.launch()