# Speech_Enhancement_Demo / functions.py
# Author: DurreSudoku
# Commit 89b26ec (verified): "Added stereo to mono conversion"
import torchaudio
from torch import mean as _mean
from torch import hamming_window, log10, no_grad, exp
def return_input(user_input):
    """Echo the user's input back unchanged.

    Args:
        user_input: Any value, including ``None``.

    Returns:
        ``user_input`` itself; ``None`` passes through like any other value,
        so no separate ``None`` check is needed.
    """
    return user_input
def stereo_to_mono_convertion(waveform):
    """Down-mix a multi-channel waveform to mono by averaging its channels.

    Args:
        waveform: Tensor whose first dimension indexes channels, e.g. the
            ``(channels, frames)`` tensor returned by ``torchaudio.load``.

    Returns:
        A tensor with a single channel (first dimension of size 1) when the
        input has more than one channel; otherwise ``waveform`` itself,
        unchanged.
    """
    if waveform.shape[0] <= 1:
        return waveform
    # ``keepdim`` is torch.mean's documented keyword (``keepdims`` is only a
    # NumPy-compat alias). Keeping the channel axis leaves the result with
    # the same rank as the input.
    return _mean(waveform, dim=0, keepdim=True)
def load_audio(audio_path):
    """Read an audio file as a mono waveform resampled to 16 kHz.

    Args:
        audio_path: Anything ``torchaudio.load`` accepts (e.g. a file path).

    Returns:
        The mono waveform tensor at a 16000 Hz sample rate.
    """
    waveform, original_rate = torchaudio.load(audio_path)
    mono_waveform = stereo_to_mono_convertion(waveform)
    return torchaudio.functional.resample(mono_waveform, original_rate, 16000)
def load_audio_numpy(audio_path):
    """Read an audio file and return it as a ``(sample_rate, samples)`` pair.

    The waveform is down-mixed to mono and resampled to 16 kHz, the same
    preprocessing ``load_audio`` applies, then flattened into NumPy.

    Args:
        audio_path: Anything ``torchaudio.load`` accepts (e.g. a file path).

    Returns:
        ``(16000, samples)`` where ``samples`` is a 1-D NumPy array.
    """
    audio_tensor, sr = torchaudio.load(audio_path)
    # Down-mix first: ravel() on a (channels, frames) array would otherwise
    # concatenate the channels end-to-end, giving a clip twice as long for
    # stereo input instead of a mono signal.
    audio_tensor = stereo_to_mono_convertion(audio_tensor)
    audio_tensor = torchaudio.functional.resample(audio_tensor, sr, 16000)
    return (16000, audio_tensor.numpy().ravel())
def audio_to_spectrogram(audio):
    """Compute the complex short-time Fourier transform of ``audio``.

    Uses a 512-point FFT, a hop of ``512 // 4`` = 128 samples and a Hamming
    window. ``power=None`` keeps the complex values, so the phase remains
    available for reconstruction.

    Args:
        audio: Waveform tensor (torchaudio expects time on the last axis).

    Returns:
        Complex spectrogram tensor.
    """
    fft_size = 512
    stft = torchaudio.transforms.Spectrogram(
        n_fft=fft_size,
        hop_length=fft_size // 4,
        power=None,
        window_fn=hamming_window,
    )
    return stft(audio)
def extract_magnitude_and_phase(spectrogram):
    """Split a complex spectrogram into magnitude and phase.

    Args:
        spectrogram: Complex-valued tensor.

    Returns:
        ``(magnitude, phase)``: the element-wise absolute value and angle
        (in radians).
    """
    magnitude = spectrogram.abs()
    phase = spectrogram.angle()
    return magnitude, phase
def amplitude_to_db(magnitude_spec):
    """Convert a magnitude spectrogram to decibels relative to its peak.

    Computes ``20 * log10(clamp(x, amin)) - 20 * log10(peak)``, so the
    loudest bin maps to 0 dB. Values more than 100 dB below the maximum are
    raised to that floor (``top_db``).

    Args:
        magnitude_spec: Non-negative magnitude tensor.

    Returns:
        ``(db_spectrogram, max_amplitude)``. ``max_amplitude`` is the raw
        peak magnitude that ``db_to_amplitude`` multiplies back in.
    """
    amin = 1e-9  # floor applied to magnitudes before log10 (== 10e-10)
    max_amplitude = magnitude_spec.max()
    # Clamp the reference before its log. For an all-zero (silent) input,
    # log10(0) = -inf would turn every dB value into +inf and the
    # reconstructed audio into NaN (0 * inf). The returned peak is left
    # untouched, so a silent input still reconstructs to zeros.
    reference_db = log10(max_amplitude.clamp(min=amin))
    db_spectrogram = torchaudio.functional.amplitude_to_DB(
        magnitude_spec, 20, amin, reference_db, 100.0
    )
    return db_spectrogram, max_amplitude
def min_max_scaling(spectrogram, scaler):
    """Apply the min-max ``scaler`` to a dB spectrogram.

    Caveat: each spectrogram is converted to dB relative to its own peak
    before scaling, so the soundness of applying one shared min/max across
    clips is questionable.

    Args:
        spectrogram: dB spectrogram to scale.
        scaler: Object exposing ``transform`` (presumably fitted on training
            data — TODO confirm).

    Returns:
        Whatever ``scaler.transform`` returns.
    """
    scaled = scaler.transform(spectrogram)
    return scaled
def inverse_min_max(spectrogram, scaler):
    """Undo min-max scaling with ``scaler.inverse_transform``.

    Args:
        spectrogram: Scaled spectrogram to map back to dB.
        scaler: Object exposing ``inverse_transform``.

    Returns:
        Whatever ``scaler.inverse_transform`` returns.
    """
    restored = scaler.inverse_transform(spectrogram)
    return restored
def db_to_amplitude(db_spectrogram, max_amplitude):
    """Map peak-relative decibel values back to linear magnitudes.

    Args:
        db_spectrogram: Decibel values relative to ``max_amplitude``.
        max_amplitude: Reference peak magnitude.

    Returns:
        ``max_amplitude * 10 ** (db_spectrogram / 20)``.
    """
    linear_ratio = 10 ** (db_spectrogram / 20)
    return max_amplitude * linear_ratio
def reconstruct_complex_spectrogram(magnitude, phase):
    """Rebuild a complex spectrogram from its polar components.

    Args:
        magnitude: Magnitude tensor.
        phase: Phase tensor in radians, broadcastable with ``magnitude``.

    Returns:
        The complex tensor ``magnitude * e^(i * phase)``.
    """
    unit_phasor = exp(1j * phase)
    return magnitude * unit_phasor
def inverse_fft(spectrogram):
    """Convert a complex spectrogram back into a waveform (inverse STFT).

    The parameters (512-point FFT, hop of 128 samples, Hamming window) must
    match those in ``audio_to_spectrogram`` for a faithful reconstruction.

    Args:
        spectrogram: Complex spectrogram tensor.

    Returns:
        Waveform tensor.
    """
    fft_size = 512
    istft = torchaudio.transforms.InverseSpectrogram(
        n_fft=fft_size,
        hop_length=fft_size // 4,
        window_fn=hamming_window,
    )
    return istft(spectrogram)
def transform_audio(audio, scaler):
    """Turn a waveform into the scaled dB spectrogram fed to the model.

    Pipeline: STFT -> magnitude/phase split -> peak-relative dB -> scaler.

    Args:
        audio: Waveform tensor.
        scaler: Object exposing ``transform`` (used via ``min_max_scaling``).

    Returns:
        ``(model_input, phase, max_amplitude)``:
        - ``model_input``: the scaled dB spectrogram with an extra leading
          axis (presumably the batch axis the model expects).
        - ``phase``: the STFT phase.
        - ``max_amplitude``: the peak magnitude.
        ``phase`` and ``max_amplitude`` are what ``spectrogram_to_audio``
        needs to invert the transform.
    """
    complex_spec = audio_to_spectrogram(audio)
    magnitude, phase = extract_magnitude_and_phase(complex_spec)
    db_spec, peak = amplitude_to_db(magnitude)
    scaled_spec = min_max_scaling(db_spec, scaler)
    return scaled_spec.unsqueeze(0), phase, peak
def spectrogram_to_audio(db_spectrogram, scaler, phase, max_amplitude):
    """Invert ``transform_audio``: turn a scaled dB spectrogram into audio.

    Args:
        db_spectrogram: Scaled dB spectrogram with a leading axis of size 1.
        scaler: Object exposing ``inverse_transform``.
        phase: STFT phase to reattach to the magnitudes.
        max_amplitude: Peak magnitude used as the dB reference.

    Returns:
        The reconstructed waveform tensor.
    """
    unbatched = db_spectrogram.squeeze(0)
    db_values = inverse_min_max(unbatched, scaler)
    magnitude = db_to_amplitude(db_values, max_amplitude)
    return inverse_fft(reconstruct_complex_spectrogram(magnitude, phase))
def save_audio(audio, path="enhanced_audio.wav", sample_rate=16000):
    """Write ``audio`` to a WAV file and return the file's path.

    Args:
        audio: Waveform tensor in ``torchaudio.save``'s default
            channels-first layout ``(channels, frames)``.
        path: Destination file. Defaults to ``enhanced_audio.wav`` in the
            current working directory. Calls that share the default path
            overwrite each other's output.
        sample_rate: Sample rate recorded in the file. Defaults to 16 kHz,
            the rate ``load_audio`` resamples to.

    Returns:
        ``path``.
    """
    torchaudio.save(path, audio, sample_rate)
    return path
def predict(user_input, model, scaler):
    """Enhance one audio file and return the path of the result.

    Args:
        user_input: Noisy audio (anything ``torchaudio.load`` accepts, e.g.
            a file path), or ``None`` when nothing was provided.
        model: Object with a ``forward`` method (presumably a
            ``torch.nn.Module``); it is run without gradient tracking.
        scaler: Min-max scaler used for both the forward and the inverse
            scaling.

    Returns:
        Path of the written enhanced WAV file, or ``None`` when
        ``user_input`` is ``None``.
    """
    if user_input is None:
        # Same None pass-through as return_input, rather than handing None
        # to torchaudio.load.
        return None
    audio = load_audio(user_input)
    spectrogram, phase, max_amplitude = transform_audio(audio, scaler)
    with no_grad():
        enhanced_spectrogram = model.forward(spectrogram)
    enhanced_audio = spectrogram_to_audio(
        enhanced_spectrogram, scaler, phase, max_amplitude
    )
    return save_audio(enhanced_audio)