"""Audio loading, spectrogram pre/post-processing, and inference helpers for an audio enhancement model."""

import torchaudio

from torch import mean as _mean
from torch import hamming_window, log10, no_grad, exp


def return_input(user_input):
    """Return the user-supplied input unchanged (None simply passes through)."""
    return user_input


def stereo_to_mono_convertion(waveform):
    """Downmix multi-channel audio to mono by averaging over the channel dimension."""
    if waveform.shape[0] > 1:
        waveform = _mean(waveform, dim=0, keepdim=True)
    return waveform


def load_audio(audio_path):
    """Load an audio file as a tensor, downmix it to mono, and resample it to 16 kHz."""
    audio_tensor, sr = torchaudio.load(audio_path)
    audio_tensor = stereo_to_mono_convertion(audio_tensor)
    audio_tensor = torchaudio.functional.resample(audio_tensor, sr, 16000)
    return audio_tensor


def load_audio_numpy(audio_path):
    """Load an audio file, resample it to 16 kHz, and return (sample_rate, flattened numpy samples)."""
    audio_tensor, sr = torchaudio.load(audio_path)
    audio_tensor = torchaudio.functional.resample(audio_tensor, sr, 16000)
    audio_array = audio_tensor.numpy()
    return (16000, audio_array.ravel())
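
# Note on load_audio_numpy: the (sample_rate, 1-D numpy array) tuple matches the format
# commonly expected by UI audio components such as Gradio's gr.Audio; that consumer is
# an assumption about how the helper is used, not a dependency of this module.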


def audio_to_spectrogram(audio):
    """Compute a complex STFT (power=None) using a 512-point FFT, 128-sample hop, and a Hamming window."""
    transform_fn = torchaudio.transforms.Spectrogram(
        n_fft=512, hop_length=512 // 4, power=None, window_fn=hamming_window
    )
    spectrogram = transform_fn(audio)
    return spectrogram
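
# Rough sanity check for audio_to_spectrogram: one second of 16 kHz mono audio
# (a 1 x 16000 tensor) yields a complex tensor of shape (1, 257, 126), i.e.
# 512 // 2 + 1 frequency bins and 1 + 16000 // 128 time frames.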


def extract_magnitude_and_phase(spectrogram):
    """Split a complex spectrogram into its magnitude and phase components."""
    magnitude, phase = spectrogram.abs(), spectrogram.angle()
    return magnitude, phase


def amplitude_to_db(magnitude_spec):
    """Convert a magnitude spectrogram to decibels relative to its peak, clamped to a 100 dB range."""
    max_amplitude = magnitude_spec.max()
    db_spectrogram = torchaudio.functional.amplitude_to_DB(
        magnitude_spec,
        20,            # multiplier for amplitude spectrograms (20 * log10)
        1e-10,         # amin: floor applied before the log
        log10(max_amplitude),
        100.0,         # top_db: dynamic range kept below the peak
    )
    return db_spectrogram, max_amplitude
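
# amplitude_to_DB with multiplier=20 and db_multiplier=log10(max_amplitude) computes
#   db = 20 * log10(clamp(magnitude, min=amin)) - 20 * log10(max_amplitude)
# and then clamps values below db.max() - top_db, so the result sits roughly in
# [-100, 0] dB relative to the loudest bin.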


def min_max_scaling(spectrogram, scaler):
    """Scale the dB spectrogram with a pre-fitted min-max scaler."""
    spectrogram = scaler.transform(spectrogram)
    return spectrogram


def inverse_min_max(spectrogram, scaler):
    """Undo the min-max scaling applied by min_max_scaling."""
    spectrogram = scaler.inverse_transform(spectrogram)
    return spectrogram


def db_to_amplitude(db_spectrogram, max_amplitude):
    """Invert amplitude_to_db: map dB values back to linear magnitude, rescaled by the original peak."""
    return max_amplitude * 10 ** (db_spectrogram / 20)


def reconstruct_complex_spectrogram(magnitude, phase):
    """Recombine magnitude and phase into a complex spectrogram: magnitude * e^(j * phase)."""
    return magnitude * exp(1j * phase)
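
# Note: torch.polar(magnitude, phase) builds the same complex tensor directly,
# provided magnitude and phase are real tensors of the same floating dtype.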


def inverse_fft(spectrogram):
    """Invert a complex spectrogram back to a waveform using the same STFT parameters as audio_to_spectrogram."""
    inverse_fn = torchaudio.transforms.InverseSpectrogram(
        n_fft=512, hop_length=512 // 4, window_fn=hamming_window
    )
    return inverse_fn(spectrogram)


def transform_audio(audio, scaler):
    """Turn a waveform into a scaled dB spectrogram with a leading batch dimension, keeping the phase and peak needed to invert it."""
    spectrogram = audio_to_spectrogram(audio)
    magnitude, phase = extract_magnitude_and_phase(spectrogram)
    db_spectrogram, max_amplitude = amplitude_to_db(magnitude)
    db_spectrogram = min_max_scaling(db_spectrogram, scaler)
    return db_spectrogram.unsqueeze(0), phase, max_amplitude
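
# For a mono (1, n_samples) input, transform_audio hands the model a tensor of shape
# (1, 1, freq_bins, time_frames): batch, channel, frequency, time. This assumes the
# scaler operates on tensors and preserves the spectrogram's shape.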


def spectrogram_to_audio(db_spectrogram, scaler, phase, max_amplitude):
    """Undo every step of transform_audio and resynthesise the waveform."""
    db_spectrogram = db_spectrogram.squeeze(0)
    db_spectrogram = inverse_min_max(db_spectrogram, scaler)
    spectrogram = db_to_amplitude(db_spectrogram, max_amplitude)
    complex_spec = reconstruct_complex_spectrogram(spectrogram, phase)
    audio = inverse_fft(complex_spec)
    return audio


def save_audio(audio):
    """Write the enhanced waveform to disk at 16 kHz and return the output path."""
    output_path = "enhanced_audio.wav"
    torchaudio.save(output_path, audio, 16000)
    return output_path


def predict(user_input, model, scaler):
    """Run the full enhancement pipeline on an input file and return the path of the saved result."""
    audio = load_audio(user_input)
    spectrogram, phase, max_amplitude = transform_audio(audio, scaler)

    with no_grad():
        enhanced_spectrogram = model(spectrogram)
    enhanced_audio = spectrogram_to_audio(enhanced_spectrogram, scaler, phase, max_amplitude)
    enhanced_audio_path = save_audio(enhanced_audio)
    return enhanced_audio_path
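

if __name__ == "__main__":
    # Minimal usage sketch, not shipped functionality: it assumes a trained enhancement
    # model and a fitted min-max scaler already exist on disk. "model.pt" and
    # "scaler.pkl" are placeholder file names, and the command-line handling is only
    # illustrative.
    import pickle
    import sys

    from torch import load as _load

    # Assumes the checkpoint stores a full nn.Module; recent PyTorch versions may need
    # weights_only=False for that.
    model = _load("model.pt", map_location="cpu")
    model.eval()

    with open("scaler.pkl", "rb") as f:
        scaler = pickle.load(f)

    input_path = sys.argv[1] if len(sys.argv) > 1 else "noisy_audio.wav"
    print(predict(input_path, model, scaler))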