DurreSudoku committed on
Commit 68d7781 · verified · 1 Parent(s): d3ebde8

Upload folder using huggingface_hub

README.md CHANGED
@@ -1,12 +1,6 @@
  ---
- title: Speech Enhancement Demo
- emoji: 🌍
- colorFrom: blue
- colorTo: yellow
+ title: Speech_Enhancement_Demo
+ app_file: app.py
  sdk: gradio
  sdk_version: 5.5.0
- app_file: app.py
- pinned: false
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

__pycache__/custom_scaler.cpython-312.pyc ADDED
Binary file (2.67 kB).
 
__pycache__/functions.cpython-312.pyc ADDED
Binary file (4.85 kB).
 
__pycache__/unet.cpython-312.pyc ADDED
Binary file (5.6 kB).
 
app.py ADDED
@@ -0,0 +1,42 @@
+ import gradio as gr
+ import torch
+ from functions import *
+ from unet import UNet
+ from custom_scaler import min_max_scaler
+
+ model = UNet()
+ model_state_dict = torch.load("huggingface/model.pth", map_location="cpu")
+ model.load_state_dict(model_state_dict["model_state_dict"])
+
+ scaler = min_max_scaler()
+ scaler.fit()
+
+ with gr.Blocks() as demo:
+     with gr.Row():
+         gr.Markdown(
+             """
+             # Speech enhancement demonstration
+
+             Hello!
+
+             This is a demo of a speech enhancement model trained to reduce background noise and ensure the intelligibility of a single speaker.
+
+             Feel free to upload your own audio file or try one of our example files to see how it works!
+             """
+         )
+     with gr.Row():
+         with gr.Column():
+             audio_path = gr.Audio(sources="upload", type="filepath", label="Upload your audio here", format="wav")
+         with gr.Column():
+             enhanced_audio = gr.Audio(sources=None, label="Enhanced audio will appear here", format="wav")
+     with gr.Row():
+         files = gr.FileExplorer(label="Example files", file_count="single", root_dir="huggingface/examples", interactive=True)
+         files.change(fn=return_input, inputs=files, outputs=audio_path)
+         files.change(fn=return_input, inputs=None, outputs=enhanced_audio)
+     with gr.Row():
+         submit_audio = gr.Button(value="Submit audio for enhancement")
+     submit_audio.click(fn=lambda x: predict(x, model, scaler), inputs=audio_path, outputs=enhanced_audio, trigger_mode="once")
+
+ demo.launch(share=True)
+
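For reference, the same wiring can be exercised without the Gradio UI. A minimal sketch (not part of the commit), assuming it is run from the repository root so that model.pth and the examples/ folder resolve via relative paths (app.py itself uses huggingface/-prefixed paths):

    import torch
    from functions import predict
    from unet import UNet
    from custom_scaler import min_max_scaler

    # Load the checkpoint the same way app.py does (dict with a "model_state_dict" key).
    model = UNet()
    checkpoint = torch.load("model.pth", map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    # Default fit() maps the dB range [-100, 0] to [0, 1].
    scaler = min_max_scaler()
    scaler.fit()

    # predict() writes enhanced_audio.wav to the working directory and returns its path.
    out_path = predict("examples/durim_test_sample_1.wav", model, scaler)
    print("Enhanced audio written to:", out_path)
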
custom_scaler.py ADDED
@@ -0,0 +1,35 @@
+ import torch
+
+ class min_max_scaler():
+     def __init__(self, upper_bound=1, lower_bound=0):
+
+         self.upper = upper_bound
+         self.lower = lower_bound
+         self.minimum = torch.ones(1) * torch.inf
+         self.maximum = -torch.ones(1) * torch.inf
+
+     def fit(self, set_maximum=0.0, set_minimum=-100.0):
+         """Set the min and max values used for scaling.
+         Since dB spectrograms lie on the scale [-100, 0] by default, the defaults are set to those values.
+
+         Args:
+             set_maximum (float, optional): set the maximum value manually. Defaults to 0.0.
+             set_minimum (float, optional): set the minimum value manually. Defaults to -100.0.
+
+         Returns:
+             None: None
+         """
+         if set_minimum is not None and set_maximum is not None:
+             self.minimum = set_minimum
+             self.maximum = set_maximum
+         return None
+
+     def transform(self, spectrogram):
+         if self.minimum == torch.inf:
+             raise ValueError("Cannot transform before the scaler is fitted with min-max values")
+         return (self.upper - self.lower) * (spectrogram - self.minimum) / (self.maximum - self.minimum) + self.lower
+
+     def inverse_transform(self, spectrogram):
+         if self.minimum == torch.inf:
+             raise ValueError("Cannot inverse transform before the scaler is fitted with min-max values")
+         return (spectrogram - self.lower) * (self.maximum - self.minimum) / (self.upper - self.lower) + self.minimum
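
A quick sanity check of the scaler (a sketch, not part of the commit): with the default fit(), transform() and inverse_transform() should round-trip a dB spectrogram.

    import torch
    from custom_scaler import min_max_scaler

    scaler = min_max_scaler()   # maps to [lower, upper] = [0, 1] by default
    scaler.fit()                # default min/max: -100.0 and 0.0 dB

    db_spec = -100.0 * torch.rand(257, 100)       # fake dB spectrogram in [-100, 0]
    scaled = scaler.transform(db_spec)            # values now in [0, 1]
    restored = scaler.inverse_transform(scaled)   # back to dB

    assert torch.allclose(db_spec, restored, atol=1e-5)
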
examples/VoiceBank+DEMAND_test_sample_male_1.wav ADDED
Binary file (112 kB).
 
examples/durim_test_sample_1.wav ADDED
Binary file (300 kB).
 
examples/durim_test_sample_2.wav ADDED
Binary file (347 kB).
 
functions.py ADDED
@@ -0,0 +1,83 @@
+ import torchaudio
+ from torch import hamming_window, log10, no_grad, exp
+
+
+ def return_input(user_input):
+     if user_input is None:
+         return None
+     return user_input
+
+ def load_audio(audio_path):
+
+     audio_tensor, sr = torchaudio.load(audio_path)
+     audio_tensor = torchaudio.functional.resample(audio_tensor, sr, 16000)
+     # torchaudio.load already returns float32, so no dtype conversion is needed here
+     return audio_tensor
+
+ def load_audio_numpy(audio_path):
+     audio_tensor, sr = torchaudio.load(audio_path)
+     audio_tensor = torchaudio.functional.resample(audio_tensor, sr, 16000)
+     audio_array = audio_tensor.numpy()
+     return (16000, audio_array.ravel())
+
+ def audio_to_spectrogram(audio):
+     transform_fn = torchaudio.transforms.Spectrogram(n_fft=512, hop_length=512//4, power=None, window_fn=hamming_window)
+     spectrogram = transform_fn(audio)
+     return spectrogram
+
+ def extract_magnitude_and_phase(spectrogram):
+     magnitude, phase = spectrogram.abs(), spectrogram.angle()
+     return magnitude, phase
+
+ def amplitude_to_db(magnitude_spec):
+     max_amplitude = magnitude_spec.max()
+     db_spectrogram = torchaudio.functional.amplitude_to_DB(magnitude_spec, 20, 10e-10, log10(max_amplitude), 100.0)
+     return db_spectrogram, max_amplitude
+
+ def min_max_scaling(spectrogram, scaler):
+     # Min-max scaling (the soundness of the math is questionable because each spectrogram's own max value is used during decibel scaling)
+     spectrogram = scaler.transform(spectrogram)
+     return spectrogram
+
+ def inverse_min_max(spectrogram, scaler):
+     spectrogram = scaler.inverse_transform(spectrogram)
+     return spectrogram
+
+ def db_to_amplitude(db_spectrogram, max_amplitude):
+     return max_amplitude * 10**(db_spectrogram/20)
+
+ def reconstruct_complex_spectrogram(magnitude, phase):
+     return magnitude * exp(1j*phase)
+
+ def inverse_fft(spectrogram):
+     inverse_fn = torchaudio.transforms.InverseSpectrogram(n_fft=512, hop_length=512//4, window_fn=hamming_window)
+     return inverse_fn(spectrogram)
+
+ def transform_audio(audio, scaler):
+     spectrogram = audio_to_spectrogram(audio)
+     magnitude, phase = extract_magnitude_and_phase(spectrogram)
+     db_spectrogram, max_amplitude = amplitude_to_db(magnitude)
+     db_spectrogram = min_max_scaling(db_spectrogram, scaler)
+     return db_spectrogram.unsqueeze(0), phase, max_amplitude
+
+ def spectrogram_to_audio(db_spectrogram, scaler, phase, max_amplitude):
+     db_spectrogram = db_spectrogram.squeeze(0)
+     db_spectrogram = inverse_min_max(db_spectrogram, scaler)
+     spectrogram = db_to_amplitude(db_spectrogram, max_amplitude)
+     complex_spec = reconstruct_complex_spectrogram(spectrogram, phase)
+     audio = inverse_fft(complex_spec)
+     return audio
+
+ def save_audio(audio):
+     torchaudio.save(r"enhanced_audio.wav", audio, 16000)
+     return r"enhanced_audio.wav"
+
+ def predict(user_input, model, scaler):
+     audio = load_audio(user_input)
+     spectrogram, phase, max_amplitude = transform_audio(audio, scaler)
+
+     with no_grad():
+         enhanced_spectrogram = model.forward(spectrogram)
+     enhanced_audio = spectrogram_to_audio(enhanced_spectrogram, scaler, phase, max_amplitude)
+     enhanced_audio_path = save_audio(enhanced_audio)
+     return enhanced_audio_path
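
The helpers above chain into an analysis/synthesis round trip. A small sketch (not part of the commit) that skips the U-Net and simply resynthesises one of the committed example files, assuming it is run from the repository root:

    from functions import load_audio, transform_audio, spectrogram_to_audio, save_audio
    from custom_scaler import min_max_scaler

    scaler = min_max_scaler()
    scaler.fit()

    # load_audio resamples to 16 kHz; transform_audio returns a scaled dB spectrogram
    # of shape (1, 1, freq, frames) plus the phase and the reference amplitude.
    audio = load_audio("examples/durim_test_sample_1.wav")
    db_spec, phase, max_amp = transform_audio(audio, scaler)

    # Identity "enhancement": undo the scaling, dB and STFT steps directly.
    restored = spectrogram_to_audio(db_spec, scaler, phase, max_amp)
    print(audio.shape, restored.shape)  # lengths agree up to STFT framing

    save_audio(restored)  # writes enhanced_audio.wav
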
model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1b7f1b84b4e0520ab3d07e3dad89f1dee2f8e74a845a1c6c5ced5e482227e4b1
+ size 137041542

requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch==2.3.0
+ torchaudio==2.3.0
+ gradio==5.5.0

test.py ADDED
@@ -0,0 +1,18 @@
+ import torchaudio
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ from functions import *
+ from unet import UNet
+
+ """
+ model = UNet()
+ model_state_dict = torch.load("huggingface/model.pth", map_location="cpu")
+ model.load_state_dict(model_state_dict["model_state_dict"])
+ print("# of trainable parameters =", sum(p.numel() for p in model.parameters() if p.requires_grad))
+
+ audio = load_audio("huggingface/p232_001.wav")
+ enhanced = predict("huggingface/p232_001.wav", model)
+ print(enhanced.shape)"""
+ string = "C:/Users/durim/Documents/KTH/Master_År2/DT2119-SSR/project_feature_extraction.ipynb"
+ print("/".join(string.split("/")[:-1]))

unet.py ADDED
@@ -0,0 +1,75 @@
+ import torch
+ import torch.nn as nn
+
+ class EncodingBlock(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(EncodingBlock, self).__init__()
+         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+         self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
+         self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+         self.activation = nn.ReLU(inplace=True)
+
+     def forward(self, x):
+         x = self.conv1(x)
+         x = self.activation(x)
+         x = self.conv2(x)
+         x = self.activation(x)
+         skip_connection = x
+         x = self.pool(x)
+         return x, skip_connection
+
+ class DecodingBlock(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(DecodingBlock, self).__init__()
+         self.conv_transpose = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=5, stride=2, padding=2)
+         self.conv1 = nn.Conv2d(out_channels * 2, out_channels, kernel_size=3, padding=1)
+         self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
+         self.activation = nn.ReLU(inplace=True)
+
+     def forward(self, x, skip_connection):
+         x = self.conv_transpose(x)
+         pd = (0, skip_connection.size(-1) - x.size(-1), 0, skip_connection.size(-2) - x.size(-2))
+         x = nn.functional.pad(x, pd, mode='constant', value=0)
+         x = torch.cat((x, skip_connection), dim=1)
+         x = self.conv1(x)
+         x = self.activation(x)
+         x = self.conv2(x)
+         x = self.activation(x)
+         return x
+
+ class UNet(nn.Module):
+     def __init__(self, init_features=32, bottleneck_size=512):
+         super(UNet, self).__init__()
+         self.encoding_block1 = EncodingBlock(1, init_features)
+         self.encoding_block2 = EncodingBlock(init_features, init_features*2)
+         self.encoding_block3 = EncodingBlock(init_features*2, init_features*4)
+         self.encoding_block4 = EncodingBlock(init_features*4, init_features*8)
+
+         self.bottleneck_conv1 = nn.Conv2d(init_features*8, bottleneck_size, kernel_size=3, padding=1)
+         self.bottleneck_conv2 = nn.Conv2d(bottleneck_size, bottleneck_size, kernel_size=3, padding=1)
+
+         self.decoding_block4 = DecodingBlock(bottleneck_size, init_features*8)
+         self.decoding_block3 = DecodingBlock(init_features*8, init_features*4)
+         self.decoding_block2 = DecodingBlock(init_features*4, init_features*2)
+         self.decoding_block1 = DecodingBlock(init_features*2, init_features)
+
+         self.final_conv = nn.Conv2d(init_features, 1, kernel_size=1)
+
+     def forward(self, x):
+         x, skip1 = self.encoding_block1(x)
+         x, skip2 = self.encoding_block2(x)
+         x, skip3 = self.encoding_block3(x)
+         x, skip4 = self.encoding_block4(x)
+
+         x = self.bottleneck_conv1(x)
+         x = nn.ReLU(inplace=True)(x)
+         x = self.bottleneck_conv2(x)
+         x = nn.ReLU(inplace=True)(x)
+
+         x = self.decoding_block4(x, skip4)
+         x = self.decoding_block3(x, skip3)
+         x = self.decoding_block2(x, skip2)
+         x = self.decoding_block1(x, skip1)
+
+         x = self.final_conv(x)
+         return x
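
A shape sanity check for the network (a sketch, not part of the commit): with n_fft=512 the pipeline in functions.py produces 257 frequency bins, and the padded skip connections make the output match the input shape.

    import torch
    from unet import UNet

    model = UNet()
    model.eval()

    # (batch, channel, freq bins, frames) — 257 bins matches n_fft=512 in functions.py
    x = torch.randn(1, 1, 257, 240)
    with torch.no_grad():
        y = model(x)

    print(x.shape, y.shape)  # expected: both torch.Size([1, 1, 257, 240])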