|
import warnings |
|
warnings.filterwarnings("ignore") |
|
|
|
import os |
|
import numpy as np |
|
import pandas as pd |
|
from typing import Iterable |
|
|
|
import gradio as gr |
|
from gradio.themes.base import Base |
|
from gradio.themes.utils import colors, fonts, sizes |
|
|
|
import torch |
|
import librosa |
|
import torch.nn.functional as F |
|
|
|
|
|
from audio_class_predictor import predict_class |
|
from bird_ast_model import birdast_preprocess, birdast_inference |
|
from bird_ast_seq_model import birdast_seq_preprocess, birdast_seq_inference |
|
|
|
from utils import plot_wave, plot_mel, download_model, bandpass_filter |
|
|
|
|
|
ASSET_DIR = "./assets"

DEFAULT_SR = 16_000  # target sample rate (Hz) expected by the models

DEFAULT_HIGH_CUT = 8_000  # bandpass upper cutoff (Hz)

DEFAULT_LOW_CUT = 1_000  # bandpass lower cutoff (Hz)

DEVICE = "cpu"
|
|
|
print(f"Device: {DEVICE}") |
|
|
|
if not os.path.exists(ASSET_DIR): |
|
os.makedirs(ASSET_DIR) |
|
|
|
|
|
|
|
birdast_assets = { |
|
"model_weights": [ |
|
f"https://huggingface.co/shiyi-li/BirdAST/resolve/main/BirdAST_Baseline_GroupKFold_fold_{i}.pth" |
|
for i in range(5) |
|
], |
|
"label_mapping": "https://huggingface.co/shiyi-li/BirdAST/resolve/main/BirdAST_Baseline_GroupKFold_label_map.csv", |
|
"preprocess_fn": birdast_preprocess, |
|
"inference_fn": birdast_inference, |
|
} |
|
|
|
birdast_seq_assets = { |
|
"model_weights": [ |
|
f"https://huggingface.co/shiyi-li/BirdAST_Seq/resolve/main/BirdAST_SeqPool_GroupKFold_fold_{i}.pth" |
|
for i in range(5) |
|
], |
|
"label_mapping": "https://huggingface.co/shiyi-li/BirdAST_Seq/resolve/main/BirdAST_SeqPool_GroupKFold_label_map.csv", |
|
"preprocess_fn": birdast_seq_preprocess, |
|
"inference_fn": birdast_seq_inference, |
|
} |
|
|
|
|
|
ASSET_DICT = { |
|
"BirdAST": birdast_assets, |
|
"BirdAST_Seq": birdast_seq_assets, |
|
} |
|
|
|
|
|
def run_inference_with_model(audio_clip, sr, model_name): |
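    """Run ensemble inference for one audio clip with the selected model.

    Downloads the five fold checkpoints and the label map into ASSET_DIR if they
    are not already cached, preprocesses the clip, averages the fold predictions
    and returns the top-10 [species name, probability in %] pairs.
    """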
|
|
|
|
|
assets = ASSET_DICT[model_name] |
|
model_weights_url = assets["model_weights"] |
|
label_map_url = assets["label_mapping"] |
|
preprocess_fn = assets["preprocess_fn"] |
|
inference_fn = assets["inference_fn"] |
|
|
|
|
|
model_weights = [] |
|
for model_weight in model_weights_url: |
|
weight_file = os.path.join(ASSET_DIR, model_weight.split("/")[-1]) |
|
if not os.path.exists(weight_file): |
|
download_model(model_weight, weight_file) |
|
model_weights.append(weight_file) |
|
|
|
|
|
label_map_csv = os.path.join(ASSET_DIR, label_map_url.split("/")[-1]) |
|
if not os.path.exists(label_map_csv): |
|
download_model(label_map_url, label_map_csv) |
|
|
|
|
|
label_mapping = pd.read_csv(label_map_csv) |
|
species_id_to_name = {row["species_id"]: row["scientific_name"] for _, row in label_mapping.iterrows()} |
|
|
|
|
|
spectrogram = preprocess_fn(audio_clip, sr=sr) |
|
|
|
|
|
predictions = inference_fn(model_weights, spectrogram, device=DEVICE) |
|
|
|
|
|
    # average probabilities across the fold ensemble, then take the 10 highest-scoring species
    final_predicts = predictions.mean(axis=0)
    topk_values, topk_indices = torch.topk(torch.from_numpy(final_predicts), 10)
|
|
|
results = [] |
|
for idx, scores in zip(topk_indices, topk_values): |
|
species_name = species_id_to_name[idx.item()] |
|
probability = scores.item() * 100 |
|
results.append([species_name, probability]) |
|
|
|
return results |
|
|
|
|
|
def predict(audio, start, end, model_name="BirdAST_Seq"): |
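    """Gradio callback: classify the selected segment of the uploaded recording.

    `audio` is the (sample_rate, samples) tuple returned by gr.Audio. The segment
    [start, end] (in seconds) is cropped, bandpass-filtered and resampled to
    DEFAULT_SR when needed, then passed to the audio-class predictor and the
    selected species model. Returns the class table, species table, waveform plot
    and mel-spectrogram plot.
    """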
|
|
|
raw_sr, audio_array = audio |
|
|
|
if audio_array.ndim > 1: |
|
audio_array = audio_array.mean(axis=1) |
|
|
|
print(f"Audio shape raw: {audio_array.shape}, sr: {raw_sr}") |
|
|
|
|
|
len_audio = audio_array.shape[0] / raw_sr |
|
    if start >= end:
        raise gr.Error(f"`start` ({start}s) must be smaller than `end` ({end}s)")

    if audio_array.shape[0] < start * raw_sr:
        raise gr.Error(f"`start` ({start}s) must be smaller than the audio duration ({len_audio:.0f}s)")

    # clamp `end` to the actual audio duration if the clip is shorter than requested
    if audio_array.shape[0] < end * raw_sr:
        end = audio_array.shape[0] / (1.0 * raw_sr)
|
|
|
    # assume int16 PCM from gr.Audio: scale to float32 in [-1, 1], then crop to the selected segment
    audio_array = np.array(audio_array, dtype=np.float32) / 32768.0
    audio_array = audio_array[int(start * raw_sr) : int(end * raw_sr)]
|
|
|
    if raw_sr != DEFAULT_SR:
        # bandpass (1-8 kHz) to suppress out-of-band noise, then resample to the model sample rate
        audio_array = bandpass_filter(audio_array, DEFAULT_LOW_CUT, DEFAULT_HIGH_CUT, raw_sr)
        audio_array = librosa.resample(audio_array, orig_sr=raw_sr, target_sr=DEFAULT_SR)
        print(f"Resampled Audio shape: {audio_array.shape}")
|
|
|
audio_array = audio_array.astype(np.float32) |
|
|
|
|
|
audio_class = predict_class(audio_array) |
|
|
|
    fig_spectrogram = plot_mel(DEFAULT_SR, audio_array)
    fig_waveform = plot_wave(DEFAULT_SR, audio_array)
|
|
|
|
|
print(f"Running inference with model: {model_name}") |
|
    species_class = run_inference_with_model(audio_array, DEFAULT_SR, model_name)
|
|
|
return audio_class, species_class, fig_waveform, fig_spectrogram |
|
|
|
|
|
DESCRIPTION = """
# Introduction

It is estimated that 50% of the global economy is threatened by biodiversity loss [2]. Intensive efforts have therefore gone into estimating bird biodiversity, since birds are a leading indicator of the biodiversity of a region. One such effort is identifying the bird species present in a region through audio classification of their calls.

# Solution

To tackle this problem, we propose VOJ. It first preprocesses the audio signal with a bandpass filter (1-8 kHz) and then downsamples it to 16 kHz. The signal is then fed into AudioMAE (Audio Masked Autoencoder by Meta [1]), which extracts relevant features even when the spectrogram is corrupted.
AudioMAE is also trained on 527 audio classes covering bird, silence, environmental noise and other sounds. This initial inference stage gives a first impression of the recording: if AudioMAE predicts silence, we expect low species-prediction confidence, and if it predicts insect, the clip may not be worth labelling.
Next, we train BirdAST, which uses the Audio Spectrogram Transformer (AST) as its backbone, followed by attention pooling and a dense layer. We also train EfficientNet-B0 on the mel spectrogram, and finally we train a model using Wav2Vec pretrained on 50 bird species [3].
"""
|
|
|
|
|
css = """ |
|
#gradio-animation { |
|
font-size: 2em; |
|
font-weight: bold; |
|
text-align: center; |
|
margin-bottom: 20px; |
|
} |
|
|
|
.logo-container img { |
|
width: 14%; /* Adjust width as necessary */ |
|
display: block; |
|
margin: auto; |
|
} |
|
|
|
.number-input { |
|
height: 100%; |
|
    padding-bottom: 60px; /* Adjust the value as needed for more or less space */
|
} |
|
.full-height { |
|
height: 100%; |
|
} |
|
.column-container { |
|
height: 100%; |
|
} |
|
""" |
|
|
|
|
|
|
|
class Seafoam(Base): |
|
def __init__( |
|
self, |
|
*, |
|
primary_hue: colors.Color | str = colors.emerald, |
|
secondary_hue: colors.Color | str = colors.blue, |
|
neutral_hue: colors.Color | str = colors.gray, |
|
spacing_size: sizes.Size | str = sizes.spacing_md, |
|
radius_size: sizes.Size | str = sizes.radius_md, |
|
text_size: sizes.Size | str = sizes.text_lg, |
|
font: fonts.Font |
|
| str |
|
| Iterable[fonts.Font | str] = ( |
|
fonts.GoogleFont("Quicksand"), |
|
"ui-sans-serif", |
|
"sans-serif", |
|
), |
|
font_mono: fonts.Font |
|
| str |
|
| Iterable[fonts.Font | str] = ( |
|
fonts.GoogleFont("IBM Plex Mono"), |
|
"ui-monospace", |
|
"monospace", |
|
), |
|
): |
|
super().__init__( |
|
primary_hue=primary_hue, |
|
secondary_hue=secondary_hue, |
|
neutral_hue=neutral_hue, |
|
spacing_size=spacing_size, |
|
radius_size=radius_size, |
|
text_size=text_size, |
|
font=font, |
|
font_mono=font_mono, |
|
) |
|
|
|
|
|
seafoam = Seafoam() |
|
|
|
|
|
js = """ |
|
function createGradioAnimation() { |
|
var container = document.getElementById('gradio-animation'); |
|
var text = 'Voice of Jungle'; |
|
for (var i = 0; i < text.length; i++) { |
|
(function(i){ |
|
setTimeout(function(){ |
|
var letter = document.createElement('span'); |
|
letter.style.opacity = '0'; |
|
letter.style.transition = 'opacity 0.5s'; |
|
letter.innerText = text[i]; |
|
container.appendChild(letter); |
|
setTimeout(function() { |
|
letter.style.opacity = '1'; |
|
}, 50); |
|
}, i * 250); |
|
})(i); |
|
} |
|
} |
|
""" |
|
|
|
REFERENCES = """ |
|
# References
|
|
|
[1] Huang, P.-Y., Xu, H., Li, J., Baevski, A., Auli, M., Galuba, W., Metze, F., & Feichtenhofer, C. (2022). Masked Autoencoders that Listen. In NeurIPS. |
|
|
|
[2] Torkington, S. (2023, February 7). 50% of the global economy is under threat from biodiversity loss. World Economic Forum. Retrieved from https://www.weforum.org/agenda/2023/02/biodiversity-nature-loss-cop15/. |
|
|
|
[3] https://www.kaggle.com/code/dima806/bird-species-by-sound-detection |
|
""" |
|
|
|
|
|
def handle_model_selection(model_name, download_status): |
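    """Download the fold checkpoints for `model_name` into ASSET_DIR and return a status message."""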
|
|
|
|
|
print(f"Downloading model weights for {model_name}...") |
|
assets = ASSET_DICT[model_name] |
|
model_weights_url = assets["model_weights"] |
|
download_flag = True |
|
try: |
|
total_files = len(model_weights_url) |
|
for idx, model_weight in enumerate(model_weights_url): |
|
weight_file = os.path.join(ASSET_DIR, model_weight.split("/")[-1]) |
|
print(weight_file) |
|
if not os.path.exists(weight_file): |
|
download_status = f"Downloading {idx + 1} of {total_files}" |
|
download_model(model_weight, weight_file) |
|
|
|
if not os.path.exists(weight_file): |
|
download_flag = False |
|
break |
|
|
|
        if download_flag:
            download_status = f"Model {model_name} is ready for prediction!"
        else:
            download_status = "An error occurred while downloading model weights."

    except Exception as e:
        download_status = f"An error occurred while downloading model weights: {e}"
|
|
|
return download_status |
|
|
|
|
|
with gr.Blocks(theme = seafoam, css = css, js = js) as demo: |
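    # UI layout: logo and animated title, description, model selector with download status,
    # segment inputs and audio upload, prediction tables, waveform/spectrogram plots, references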
|
|
|
gr.Markdown('<div class="logo-container"><img src="https://i.ibb.co/vcG9kr0/vojlogo.jpg" width="50px" alt="vojlogo"></div>') |
|
gr.Markdown('<div id="gradio-animation"></div>') |
|
gr.Markdown(DESCRIPTION) |
|
|
|
|
|
    # only models registered in ASSET_DICT have downloadable assets; keep the dropdown in sync
    model_names = list(ASSET_DICT.keys())
|
model_dropdown = gr.Dropdown(label="Choose a model", choices=model_names) |
|
download_status = gr.Textbox(label="Model Status", lines=3, value='', interactive=False) |
|
|
|
model_dropdown.change(handle_model_selection, inputs=[model_dropdown, download_status], outputs=download_status) |
|
|
|
|
|
with gr.Row(): |
|
with gr.Column(elem_classes="column-container"): |
|
start_time_input = gr.Number(label="Start Time", value=0, elem_classes="number-input full-height") |
|
end_time_input = gr.Number(label="End Time", value=10, elem_classes="number-input full-height") |
|
with gr.Column(): |
|
audio_input = gr.Audio(label="Input Audio", elem_classes="full-height") |
|
|
|
with gr.Row(): |
|
raw_class_output = gr.Dataframe(headers=["Class", "Score [%]"], row_count=10, label="Class Prediction") |
|
        species_output = gr.Dataframe(headers=["Species", "Score [%]"], row_count=10, label="Species Prediction")
|
|
|
with gr.Row(): |
|
waveform_output = gr.Plot(label="Waveform") |
|
spectrogram_output = gr.Plot(label="Spectrogram") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
gr.Button("Predict").click(predict, [audio_input, start_time_input, end_time_input, model_dropdown], [raw_class_output, species_output, waveform_output, spectrogram_output]) |
|
|
|
gr.Markdown(REFERENCES) |
|
|
|
demo.launch(share = True) |
|
|
|
|
|
|