File size: 3,444 Bytes
89b26ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68d7781
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torchaudio
from torch import mean as _mean
from torch import hamming_window, log10, no_grad, exp


def return_input(user_input):
    """Return *user_input* unchanged (``None`` passes through as ``None``)."""
    return user_input


def stereo_to_mono_convertion(waveform):
    """Average all channels of *waveform* into a single channel.

    The first dimension is treated as the channel axis (presumably the
    ``(channels, time)`` layout returned by ``torchaudio.load`` — confirm for
    other callers). A single-channel input is returned as the same object;
    otherwise the mean over dim 0 is returned with that dim kept (size 1).
    """
    channel_count = waveform.shape[0]
    if channel_count <= 1:
        return waveform
    return _mean(waveform, dim=0, keepdim=True)
        
def load_audio(audio_path):
    """Load *audio_path* as a mono waveform resampled to 16 kHz.

    Multi-channel files are averaged down to one channel with
    ``stereo_to_mono_convertion`` before resampling.
    """
    waveform, source_rate = torchaudio.load(audio_path)
    mono_waveform = stereo_to_mono_convertion(waveform)
    return torchaudio.functional.resample(mono_waveform, source_rate, 16000)

def load_audio_numpy(audio_path):
    """Load *audio_path* as a ``(16000, samples)`` tuple of rate and 1-D array.

    The audio is down-mixed to mono (like ``load_audio``) and resampled to
    16 kHz, then converted to a flat NumPy array.

    Returns:
        tuple: ``(16000, numpy.ndarray)`` with a one-dimensional sample array.
    """
    audio_tensor, sr = torchaudio.load(audio_path)
    # Down-mix first: ravel() on a (channels, time) array would otherwise
    # concatenate the channels end to end, doubling the length for stereo
    # files and playing one channel after the other.
    audio_tensor = stereo_to_mono_convertion(audio_tensor)
    audio_tensor = torchaudio.functional.resample(audio_tensor, sr, 16000)
    audio_array = audio_tensor.numpy()
    return (16000, audio_array.ravel())

def audio_to_spectrogram(audio):
    """Compute the STFT of *audio* (512-point FFT, hop 128, Hamming window).

    ``power=None`` asks torchaudio for the complex spectrum rather than a
    power spectrogram, so the phase is available to later steps.
    """
    fft_size = 512
    stft = torchaudio.transforms.Spectrogram(
        n_fft=fft_size,
        hop_length=fft_size // 4,
        power=None,
        window_fn=hamming_window,
    )
    return stft(audio)

def extract_magnitude_and_phase(spectrogram):
    """Split a complex spectrogram into ``(magnitude, phase)`` tensors.

    The magnitude is ``spectrogram.abs()``; the phase is ``spectrogram.angle()``.
    """
    magnitude = spectrogram.abs()
    phase = spectrogram.angle()
    return magnitude, phase

def amplitude_to_db(magnitude_spec):
    """Convert a magnitude spectrogram to decibels relative to its own peak.

    Computes ``20 * log10(magnitude / max_amplitude)`` via
    ``torchaudio.functional.amplitude_to_DB``. Magnitudes are floored at
    ``amin`` before the log, and the result is clipped to ``top_db`` = 100 dB
    below the peak.

    Args:
        magnitude_spec: non-negative magnitude tensor, for example from
            ``extract_magnitude_and_phase``.

    Returns:
        tuple: ``(db_spectrogram, max_amplitude)``. ``max_amplitude`` is the
        unclamped peak magnitude that ``db_to_amplitude`` needs to undo the
        conversion.
    """
    amin = 10e-10  # numerically 1e-9; floor applied before taking log10
    max_amplitude = magnitude_spec.max()
    # Clamp the reference the same way amplitude_to_DB clamps its input. For an
    # all-zero (silent) spectrogram, log10(0) = -inf would otherwise turn every
    # bin into +inf, and db_to_amplitude would then produce NaN (0 * inf).
    reference = max_amplitude.clamp(min=amin)
    db_spectrogram = torchaudio.functional.amplitude_to_DB(
        magnitude_spec, 20, amin, log10(reference), 100.0
    )
    return db_spectrogram, max_amplitude

def min_max_scaling(spectrogram, scaler):
    """Return ``scaler.transform(spectrogram)``.

    *scaler* is expected to be a fitted min-max scaler.

    NOTE: the mathematical soundness of this step is questionable, because
    each spectrogram is decibel-scaled against its own maximum value before
    this global scaling is applied.
    """
    scaled = scaler.transform(spectrogram)
    return scaled

def inverse_min_max(spectrogram, scaler):
    """Undo ``min_max_scaling`` by returning ``scaler.inverse_transform(spectrogram)``."""
    restored = scaler.inverse_transform(spectrogram)
    return restored

def db_to_amplitude(db_spectrogram, max_amplitude):
    """Map decibels relative to *max_amplitude* back to linear magnitudes.

    This is the inverse of ``20 * log10(x / max_amplitude)``, i.e.
    ``max_amplitude * 10 ** (db / 20)``.
    """
    exponent = db_spectrogram / 20
    return max_amplitude * 10 ** exponent

def reconstruct_complex_spectrogram(magnitude, phase):
    """Rebuild a complex spectrogram from polar form: ``magnitude * e^(i*phase)``."""
    unit_phasor = exp(1j * phase)
    return magnitude * unit_phasor
     
def inverse_fft(spectrogram):
    """Invert a complex STFT back to a waveform (inverse short-time FFT).

    Uses the same parameters as ``audio_to_spectrogram``: a 512-point FFT,
    hop 128 and a Hamming window, so the two functions round-trip.
    """
    fft_size = 512
    istft = torchaudio.transforms.InverseSpectrogram(
        n_fft=fft_size,
        hop_length=fft_size // 4,
        window_fn=hamming_window,
    )
    return istft(spectrogram)

def transform_audio(audio, scaler):
    """Convert a waveform into the scaled dB spectrogram fed to the model.

    Returns:
        tuple: ``(model_input, phase, max_amplitude)``. ``model_input`` is the
        min-max-scaled dB spectrogram with an extra leading dimension added by
        ``unsqueeze(0)``. ``phase`` and ``max_amplitude`` are kept so that
        ``spectrogram_to_audio`` can invert the transform.
    """
    magnitude, phase = extract_magnitude_and_phase(audio_to_spectrogram(audio))
    db_spectrogram, max_amplitude = amplitude_to_db(magnitude)
    model_input = min_max_scaling(db_spectrogram, scaler).unsqueeze(0)
    return model_input, phase, max_amplitude

def spectrogram_to_audio(db_spectrogram, scaler, phase, max_amplitude):
    """Invert ``transform_audio``: turn a scaled dB spectrogram back into audio.

    The leading dimension added by ``transform_audio`` is removed. The steps
    are: undo the min-max scaling, convert dB back to linear magnitude
    relative to *max_amplitude*, recombine with *phase*, then apply the
    inverse STFT.
    """
    unscaled_db = inverse_min_max(db_spectrogram.squeeze(0), scaler)
    linear_magnitude = db_to_amplitude(unscaled_db, max_amplitude)
    complex_spectrogram = reconstruct_complex_spectrogram(linear_magnitude, phase)
    return inverse_fft(complex_spectrogram)

def save_audio(audio, path=r"enhanced_audio.wav", sample_rate=16000):
    """Write *audio* to *path* with ``torchaudio.save`` and return *path*.

    Args:
        audio: waveform tensor to write. Presumably ``(channels, time)`` as
            ``torchaudio.save`` expects; confirm for new callers.
        path: destination file. Defaults to ``enhanced_audio.wav`` in the
            current working directory. Every caller that uses the default
            writes the same file, so concurrent requests overwrite each
            other's output. Pass a unique path (e.g. from ``tempfile``) when
            serving several users.
        sample_rate: rate to record in the file. Defaults to 16000, the rate
            ``load_audio`` resamples to.

    Returns:
        The path the audio was written to.
    """
    torchaudio.save(path, audio, sample_rate)
    return path

def predict(user_input, model, scaler):
    """Run *model* on the audio file at *user_input* and save the result.

    The pipeline is: load the file as 16 kHz mono, convert it to the scaled
    dB spectrogram, run ``model.forward`` without gradient tracking, invert
    the spectrogram using the original phase and peak amplitude, and write
    the WAV with ``save_audio``.

    Returns:
        The path of the saved output file.
    """
    waveform = load_audio(user_input)
    model_input, phase, max_amplitude = transform_audio(waveform, scaler)

    # NOTE(review): no_grad only disables autograd; it does not put the model
    # in eval mode. Presumably the caller already did that; confirm.
    # NOTE(review): calling .forward directly skips nn.Module hooks; model(...)
    # would run them, if model is an nn.Module.
    with no_grad():
        prediction = model.forward(model_input)

    enhanced_waveform = spectrogram_to_audio(prediction, scaler, phase, max_amplitude)
    return save_audio(enhanced_waveform)