Update RingFormer/meldataset.py

RingFormer/meldataset.py (CHANGED: +33 -16)
@@ -17,11 +17,24 @@ import soundfile as sf
 def normalize_audio(wav):
     return wav / torch.max(torch.abs(torch.from_numpy(wav)))  # Correct peak normalization
 
-def load_wav(full_path):
-    data, sampling_rate = librosa.load(full_path, sr=
+def load_wav_librosa(full_path):
+    data, sampling_rate = librosa.load(full_path, sr=44100)
     return data, sampling_rate
 
 
+
+def load_wav_scipy(full_path):
+    sampling_rate, data = read(full_path)
+    return data, sampling_rate
+
+def load_wav(full_path):
+    try:
+        return load_wav_scipy(full_path)
+    except:
+        # print('using librosa...')
+        return load_wav_librosa(full_path)
+
+
 def dynamic_range_compression(x, C=1, clip_val=1e-5):
     return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
 
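A note on the loader fallback above: scipy.io.wavfile.read returns raw integer PCM (e.g. int16) for typical WAV files, while librosa.load returns float32 already scaled to [-1, 1] and resampled to 44100 Hz, so the two branches hand back differently scaled arrays before the later `audio / MAX_WAV_VALUE` step (MAX_WAV_VALUE is typically 32768.0 in HiFi-GAN-style datasets). The sketch below is a hypothetical variant, not code from this commit, that normalizes both paths to float32 peak scale; the name load_wav_any and the example path are made up for illustration.

import numpy as np
import librosa
from scipy.io.wavfile import read

def load_wav_any(full_path, target_sr=44100):
    # Hypothetical helper (not in the commit): return float32 audio in [-1, 1]
    # regardless of which backend succeeded.
    try:
        sampling_rate, data = read(full_path)  # integer PCM for most WAV files
        if np.issubdtype(data.dtype, np.integer):
            data = data.astype(np.float32) / np.iinfo(data.dtype).max
        else:
            data = data.astype(np.float32)
    except Exception:
        # non-WAV container or unreadable header: librosa already returns float32
        data, sampling_rate = librosa.load(full_path, sr=target_sr)
    return data, sampling_rate

# Illustrative usage (path is hypothetical):
# wav, sr = load_wav_any("data/example.wav")
# print(wav.dtype, float(np.abs(wav).max()), sr)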
@@ -54,12 +67,16 @@ hann_window = {}
 
 def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
 
-    y = torch.clamp(y, min=-1, max=1)
 
-
-
-
-
+    # y = torch.clamp(y, -1, 1)
+
+
+    # if torch.min(y) < -1.:
+    #     # y = torch.clamp(y, min = -1)
+    #     # print('min value is ', torch.min(y))
+    # if torch.max(y) > 1.:
+    #     y = torch.clamp(y, max = -1)
+    #     print('max value is ', torch.max(y))
 
     global mel_basis, hann_window
     if fmax not in mel_basis:
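With the clamp and the min/max checks commented out above, out-of-range samples now pass straight into the STFT. If a guard is ever wanted again, a non-clamping range check at the call site is one option; the helper below is only a sketch under that assumption and is not part of the commit.

import torch

def check_unit_range(y, name="audio"):
    # Report, rather than clamp, samples outside [-1, 1] so data problems stay visible.
    lo, hi = y.min().item(), y.max().item()
    if lo < -1.0 or hi > 1.0:
        print(f"{name} out of range: min={lo:.4f}, max={hi:.4f}")
    return y

# y = check_unit_range(torch.FloatTensor(audio).unsqueeze(0))
# mel = mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax)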
@@ -87,12 +104,10 @@ to_mel = torchaudio.transforms.MelSpectrogram(
     sample_rate=44_100, n_mels=128, n_fft=2048, win_length=2048, hop_length=512)
 
 
-
-# to_mel = torchaudio.transforms.MelSpectrogram(
-#     sample_rate=24000, n_mels=80, n_fft=2048, win_length=1200, hop_length=300, center='center')
+# to_mel = torchaudio.transforms.MelSpectrogram(n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
 
 mean, std = -4, 4
-
+5
 def preproces(wave,to_mel=to_mel, device='cpu'):
 
     to_mel = to_mel.to(device)
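For orientation, the torchaudio transform kept above produces a (..., n_mels, frames) power mel spectrogram, and the fixed `mean, std = -4, 4` pair points at the usual log-compress-then-standardize step. The body of preproces is only partially shown in this diff, so the sketch below is an assumed approximation of that pipeline using the same transform parameters, not the repository's exact function.

import torch
import torchaudio

to_mel = torchaudio.transforms.MelSpectrogram(
    sample_rate=44_100, n_mels=128, n_fft=2048, win_length=2048, hop_length=512)
mean, std = -4, 4

def preprocess_sketch(wave, device='cpu'):
    # wave: float tensor in [-1, 1], shape (samples,) or (channels, samples)
    mel = to_mel.to(device)(wave.to(device))
    # log-compress with a small floor, then standardize with the fixed stats above
    log_mel = torch.log(torch.clamp(mel, min=1e-5))
    return (log_mel - mean) / std

# x = torch.randn(1, 44_100)          # one second of noise as a stand-in signal
# print(preprocess_sketch(x).shape)   # torch.Size([1, 128, 87]) with these settings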
@@ -142,17 +157,19 @@ class MelDataset(torch.utils.data.Dataset):
         filename = self.audio_files[index]
         if self._cache_ref_count == 0:
             audio, sampling_rate = load_wav(filename)
-
-
+            audio = audio / MAX_WAV_VALUE
+            if not self.fine_tuning:
+                audio = normalize(audio) * 0.95
             self.cached_wav = audio
             if sampling_rate != self.sampling_rate:
-
-
+                audio = librosa.resample(audio, orig_sr= sampling_rate, target_sr= self.sampling_rate)
+                # raise ValueError("{} SR doesn't match target {} SR, {}".format(
+                #     sampling_rate, self.sampling_rate, filename))
             self._cache_ref_count = self.n_cache_reuse
         else:
             audio = self.cached_wav
             self._cache_ref_count -= 1
-
+
         audio = torch.FloatTensor(audio)
         audio = audio.unsqueeze(0)
 
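The __getitem__ change above resamples mismatched files on the fly instead of raising. As a quick standalone illustration of how `librosa.resample` rescales the array length (the 22050 Hz source rate is just an example, not taken from the commit):

import numpy as np
import librosa

target_sr = 44100
orig_sr = 22050                                                # example source rate only

audio = np.random.uniform(-1, 1, orig_sr).astype(np.float32)   # one second of noise
resampled = librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)

# Length scales by target_sr / orig_sr; the duration stays about one second.
print(len(audio), len(resampled))                              # 22050 44100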