import torchaudio
from torch import mean as _mean
from torch import hamming_window, log10, no_grad, exp

# Sample rate (Hz) every clip is resampled to, and the rate output is written at.
SAMPLE_RATE = 16000
# STFT parameters shared by the forward and inverse transforms; they must match
# so that inverse_fft undoes audio_to_spectrogram.
N_FFT = 512
HOP_LENGTH = N_FFT // 4
# Magnitude floor applied before log10 (identical to the original 10e-10 literal).
AMIN = 1e-9
# Dynamic range (dB) kept below the spectrogram's maximum; lower values are clamped.
TOP_DB = 100.0
# File that save_audio writes to when no explicit path is given.
DEFAULT_OUTPUT_PATH = "enhanced_audio.wav"


def return_input(user_input):
    """Return *user_input* unchanged (``None`` stays ``None``)."""
    return user_input


def stereo_to_mono_convertion(waveform):
    """Downmix a multi-channel waveform to one channel by averaging channels.

    Args:
        waveform: tensor whose first dimension is the channel axis.

    Returns:
        The channel mean with the channel axis kept (size 1). A waveform that
        already has a single channel is returned as-is.
    """
    if waveform.shape[0] > 1:
        return _mean(waveform, dim=0, keepdim=True)
    return waveform


def load_audio(audio_path):
    """Load *audio_path*, downmix it to mono and resample it to SAMPLE_RATE.

    Returns:
        Tensor of shape (1, frames). torchaudio.load yields (channels, frames),
        and the downmix keeps the channel axis.
    """
    audio_tensor, sr = torchaudio.load(audio_path)
    audio_tensor = stereo_to_mono_convertion(audio_tensor)
    return torchaudio.functional.resample(audio_tensor, sr, SAMPLE_RATE)


def load_audio_numpy(audio_path):
    """Load *audio_path* as ``(sample_rate, 1-D numpy array)``.

    The audio is downmixed to mono first. Without the downmix, ``ravel()`` on a
    multi-channel file would concatenate the channels into a clip of twice the
    length.
    """
    audio_tensor = load_audio(audio_path)
    return (SAMPLE_RATE, audio_tensor.numpy().ravel())


def audio_to_spectrogram(audio):
    """Return the complex STFT of *audio* (``power=None`` keeps the phase)."""
    transform_fn = torchaudio.transforms.Spectrogram(
        n_fft=N_FFT,
        hop_length=HOP_LENGTH,
        power=None,
        window_fn=hamming_window,
    )
    return transform_fn(audio)


def extract_magnitude_and_phase(spectrogram):
    """Split a complex spectrogram into ``(magnitude, phase)`` tensors."""
    return spectrogram.abs(), spectrogram.angle()


def amplitude_to_db(magnitude_spec):
    """Convert a magnitude spectrogram to dB relative to its own peak.

    Returns:
        ``(db_spectrogram, max_amplitude)``.
        - ``db_spectrogram`` holds 20*log10(|X| / max_amplitude), clamped to at
          most TOP_DB below its maximum.
        - ``max_amplitude`` is a 0-dim tensor that db_to_amplitude needs to undo
          the conversion.
    """
    # Clamp the reference so an all-zero (silent) input does not take
    # log10(0) = -inf, which would turn the output (and the reconstructed
    # audio) into inf/NaN.
    max_amplitude = magnitude_spec.max().clamp(min=AMIN)
    db_spectrogram = torchaudio.functional.amplitude_to_DB(
        magnitude_spec,
        multiplier=20.0,
        amin=AMIN,
        db_multiplier=log10(max_amplitude),
        top_db=TOP_DB,
    )
    return db_spectrogram, max_amplitude


def min_max_scaling(spectrogram, scaler):
    """Scale a dB spectrogram with ``scaler.transform``.

    *scaler* must return a torch tensor, because transform_audio calls
    ``.unsqueeze`` on the result. Presumably it is the scaler fitted at
    training time (confirm).

    NOTE: the dB values are relative to each clip's own maximum (see
    amplitude_to_db). Applying one fixed min-max scaler to them mixes per-clip
    references, so the soundness of this step is questionable.
    """
    return scaler.transform(spectrogram)


def inverse_min_max(spectrogram, scaler):
    """Undo min_max_scaling via ``scaler.inverse_transform``."""
    return scaler.inverse_transform(spectrogram)


def db_to_amplitude(db_spectrogram, max_amplitude):
    """Invert amplitude_to_db: map peak-relative dB back to linear magnitude."""
    return max_amplitude * 10 ** (db_spectrogram / 20)


def reconstruct_complex_spectrogram(magnitude, phase):
    """Combine magnitude and phase into a complex spectrogram (mag * e^{i*phase})."""
    return magnitude * exp(1j * phase)


def inverse_fft(spectrogram, length=None):
    """Inverse STFT matching audio_to_spectrogram.

    Args:
        spectrogram: complex spectrogram.
        length: optional number of output samples. Passing the original signal
            length avoids losing the trailing samples that do not fill a whole
            hop. ``None`` keeps the previous behaviour.
    """
    inverse_fn = torchaudio.transforms.InverseSpectrogram(
        n_fft=N_FFT,
        hop_length=HOP_LENGTH,
        window_fn=hamming_window,
    )
    return inverse_fn(spectrogram, length)


def transform_audio(audio, scaler):
    """Turn a waveform into model input plus what is needed to invert it.

    Returns:
        ``(scaled_db_spectrogram, phase, max_amplitude)``. A leading batch
        dimension is added to the spectrogram via ``unsqueeze(0)``.
    """
    spectrogram = audio_to_spectrogram(audio)
    magnitude, phase = extract_magnitude_and_phase(spectrogram)
    db_spectrogram, max_amplitude = amplitude_to_db(magnitude)
    db_spectrogram = min_max_scaling(db_spectrogram, scaler)
    return db_spectrogram.unsqueeze(0), phase, max_amplitude


def spectrogram_to_audio(db_spectrogram, scaler, phase, max_amplitude, length=None):
    """Invert transform_audio: scaled dB spectrogram + phase -> waveform.

    Args:
        db_spectrogram: model output carrying the batch dimension that
            transform_audio added (removed here with ``squeeze(0)``).
        scaler: the scaler used by transform_audio.
        phase: the phase returned by transform_audio.
        max_amplitude: the peak returned by transform_audio.
        length: optional output length in samples, forwarded to inverse_fft.
    """
    db_spectrogram = db_spectrogram.squeeze(0)
    db_spectrogram = inverse_min_max(db_spectrogram, scaler)
    magnitude = db_to_amplitude(db_spectrogram, max_amplitude)
    complex_spec = reconstruct_complex_spectrogram(magnitude, phase)
    return inverse_fft(complex_spec, length)


def save_audio(audio, path=DEFAULT_OUTPUT_PATH):
    """Write *audio* to *path* at SAMPLE_RATE and return *path*.

    The default path is fixed and relative to the working directory, so
    concurrent callers overwrite each other's output. Pass a distinct *path*
    to avoid that.
    """
    torchaudio.save(path, audio, SAMPLE_RATE)
    return path


def predict(user_input, model, scaler):
    """Run the enhancement model on the audio file *user_input*.

    Args:
        user_input: path to an audio file. ``None`` (no input) returns ``None``.
        model: presumably a torch.nn.Module that maps the (1, 1, freq, time)
            scaled dB spectrogram to one of the same shape (confirm).
        scaler: scaler used for min-max scaling and its inverse.

    Returns:
        Path of the written enhanced audio file, or ``None`` if no input was given.
    """
    if user_input is None:
        return None
    audio = load_audio(user_input)
    spectrogram, phase, max_amplitude = transform_audio(audio, scaler)
    with no_grad():
        # Call the module instead of .forward() so registered hooks still run.
        enhanced_spectrogram = model(spectrogram)
    enhanced_audio = spectrogram_to_audio(
        enhanced_spectrogram,
        scaler,
        phase,
        max_amplitude,
        length=audio.shape[-1],  # keep the output exactly as long as the input
    )
    return save_audio(enhanced_audio)