Spaces: Runtime error
Commit 5a9b731
Parent(s): 98a3a53
uploading audio diffusion attacks
This view is limited to 50 files because it contains too many changes. See raw diff
- .DS_Store +0 -0
- audio_diffusion_attacks +0 -1
- audio_diffusion_attacks_forhf/.DS_Store +0 -0
- audio_diffusion_attacks_forhf/README.md +37 -0
- audio_diffusion_attacks_forhf/assets/.DS_Store +0 -0
- audio_diffusion_attacks_forhf/assets/audios/.DS_Store +0 -0
- audio_diffusion_attacks_forhf/assets/audios/hyperpop.wav +0 -0
- audio_diffusion_attacks_forhf/assets/example_MAS.png +0 -0
- audio_diffusion_attacks_forhf/assets/example_duration.png +0 -0
- audio_diffusion_attacks_forhf/assets/example_mel.png +0 -0
- audio_diffusion_attacks_forhf/assets/example_untrained_phone_encoding.png +0 -0
- audio_diffusion_attacks_forhf/assets/gradtts_system.png +0 -0
- audio_diffusion_attacks_forhf/audio_ethics.yml +0 -0
- audio_diffusion_attacks_forhf/config.yml +2 -0
- audio_diffusion_attacks_forhf/gen_audio_ethics_3.10.yml +8 -0
- audio_diffusion_attacks_forhf/models/.DS_Store +0 -0
- audio_diffusion_attacks_forhf/models/__init__.py +0 -0
- audio_diffusion_attacks_forhf/models/__pycache__/__init__.cpython-310.pyc +0 -0
- audio_diffusion_attacks_forhf/models/__pycache__/phoneme_encoder.cpython-310.pyc +0 -0
- audio_diffusion_attacks_forhf/models/__pycache__/style_diffusion.cpython-310.pyc +0 -0
- audio_diffusion_attacks_forhf/models/__pycache__/utils.cpython-310.pyc +0 -0
- audio_diffusion_attacks_forhf/models/datasets/__pycache__/music_datasets.cpython-310.pyc +0 -0
- audio_diffusion_attacks_forhf/models/datasets/music_datasets.py +65 -0
- audio_diffusion_attacks_forhf/models/monotonic_align/.DS_Store +0 -0
- audio_diffusion_attacks_forhf/models/monotonic_align/__init__.py +23 -0
- audio_diffusion_attacks_forhf/models/monotonic_align/__pycache__/__init__.cpython-310.pyc +0 -0
- audio_diffusion_attacks_forhf/models/monotonic_align/build/temp.linux-x86_64-cpython-310/core.o +0 -0
- audio_diffusion_attacks_forhf/models/monotonic_align/core.c +0 -0
- audio_diffusion_attacks_forhf/models/monotonic_align/core.cpython-310-x86_64-linux-gnu.so +0 -0
- audio_diffusion_attacks_forhf/models/monotonic_align/core.pyx +45 -0
- audio_diffusion_attacks_forhf/models/monotonic_align/setup.py +11 -0
- audio_diffusion_attacks_forhf/models/phoneme_encoder.py +363 -0
- audio_diffusion_attacks_forhf/models/style_diffusion.py +111 -0
- audio_diffusion_attacks_forhf/models/utils.py +77 -0
- audio_diffusion_attacks_forhf/notebooks/data_exploration/00_fma_exploration.ipynb +0 -0
- audio_diffusion_attacks_forhf/resources/cmu_dictionary +0 -0
- audio_diffusion_attacks_forhf/scripts/.DS_Store +0 -0
- audio_diffusion_attacks_forhf/scripts/data_processing/process_music_mels.py +106 -0
- audio_diffusion_attacks_forhf/scripts/data_processing/process_music_numpy.py +74 -0
- audio_diffusion_attacks_forhf/scripts/train/music_models/train_music_completion.py +243 -0
- audio_diffusion_attacks_forhf/scripts/train/train_tts.py +430 -0
- audio_diffusion_attacks_forhf/src/.DS_Store +0 -0
- audio_diffusion_attacks_forhf/src/__pycache__/losses.cpython-310.pyc +0 -0
- audio_diffusion_attacks_forhf/src/__pycache__/music_gen.cpython-310.pyc +0 -0
- audio_diffusion_attacks_forhf/src/__pycache__/test_encoder_attack.cpython-310.pyc +0 -0
- audio_diffusion_attacks_forhf/src/balancer.py +137 -0
- audio_diffusion_attacks_forhf/src/losses.py +329 -0
- audio_diffusion_attacks_forhf/src/music_gen.py +100 -0
- audio_diffusion_attacks_forhf/src/speech_inference.py +94 -0
- audio_diffusion_attacks_forhf/src/test_audio/.Il Sogno Del Marinaio - Nanos' Waltz.mp3.icloud +0 -0
.DS_Store
CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
audio_diffusion_attacks
DELETED
@@ -1 +0,0 @@
-Subproject commit 1aaf4563762c407f31436ad452a72dd5af929443
audio_diffusion_attacks_forhf/.DS_Store
ADDED
Binary file (6.15 kB).
audio_diffusion_attacks_forhf/README.md
ADDED
@@ -0,0 +1,37 @@
## Audio Data Ownership

## Installation

conda env create -n audio_ethics --file gen_audio_ethics_3.10.yml

To set up wandb, please check out the following link: [https://docs.wandb.ai/quickstart](https://docs.wandb.ai/quickstart)

## Run Encoder Attack

cd src

python test_encoder_attack.py

## Overview

## Task 1: Audio Completion with Diffusion Models

For this task, we use the [Free Music Archive (FMA)](https://github.com/mdeff/fma), which is a collection of royalty-free music. You can use any version of the dataset you wish, but we'll use the `fma_large` partition for training an initial system.

Note: If your librosa version is too high, you have to edit the corresponding line in audioldm to be `fft_window = pad_center(fft_window, size=filter_length)`.

To preprocess FMA, configure the file with your corresponding path and run the correct preprocessing script to convert the `.mp3` files to numpy (loading audio files during training is prohibitively slow).
- Preprocessing for ArchiSound encoders: `nohup python -u scripts/data_processing/process_music_numpy.py > logs/process_48k_music.out &`

## Task 2: TTS with Diffusion Models

TTS with diffusion (or flow) models is currently one of several approaches to SOTA TTS. In this repo, we have a model similar to [Grad-TTS](https://grad-tts.github.io/), with the example inference for Grad-TTS below:

![Inference Figure for Grad-TTS](./assets/gradtts_system.png)

To run, first you need to build the `monotonic_align` code:

`cd models/monotonic_align; python setup.py build_ext --inplace; cd ../..`

You may have to move the generated .so file to the `monotonic_align/` directory if it is generated in `monotonic_align/build/`.
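The librosa note in the README above is easiest to see as code. A minimal, self-contained sketch (the hann/1024 values are illustrative, taken from the STFT config used later in this commit): newer librosa releases make `size` a keyword-only argument of `librosa.util.pad_center`, so the older positional call inside audioldm fails until it is rewritten like this.

# Hedged sketch of the edit described in the README note; window sizes are illustrative.
import numpy as np
from scipy.signal import get_window
from librosa.util import pad_center

win_length, filter_length = 1024, 1024            # values matching the STFT config in this commit
fft_window = get_window("hann", win_length, fftbins=True)
fft_window = pad_center(fft_window, size=filter_length)  # keyword form required by newer librosa
print(fft_window.shape)                           # (1024,)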
audio_diffusion_attacks_forhf/assets/.DS_Store
ADDED
Binary file (6.15 kB).
audio_diffusion_attacks_forhf/assets/audios/.DS_Store
ADDED
Binary file (6.15 kB).
audio_diffusion_attacks_forhf/assets/audios/hyperpop.wav
ADDED
Binary file (640 kB).
audio_diffusion_attacks_forhf/assets/example_MAS.png
ADDED
audio_diffusion_attacks_forhf/assets/example_duration.png
ADDED
audio_diffusion_attacks_forhf/assets/example_mel.png
ADDED
audio_diffusion_attacks_forhf/assets/example_untrained_phone_encoding.png
ADDED
audio_diffusion_attacks_forhf/assets/gradtts_system.png
ADDED
audio_diffusion_attacks_forhf/audio_ethics.yml
ADDED
File without changes
audio_diffusion_attacks_forhf/config.yml
ADDED
@@ -0,0 +1,2 @@
wandb_settings:
  project_name: audio_attacks
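For context, a hedged sketch of how a script could consume this config. The reading side is not shown in this commit, so the PyYAML pattern below is an assumption; only the key names come from the file above.

# Hypothetical reader for config.yml (not taken from this commit).
import yaml
import wandb

with open("audio_diffusion_attacks_forhf/config.yml") as f:
    cfg = yaml.safe_load(f)

wandb.init(project=cfg["wandb_settings"]["project_name"])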
audio_diffusion_attacks_forhf/gen_audio_ethics_3.10.yml
ADDED
@@ -0,0 +1,8 @@
name: gen_audio_ethics_3.10
channels:
  - defaults
dependencies:
  - python=3.10
  - conda-forge::libsndfile
  - librosa
prefix: /home/willie/anaconda3/envs/gen_audio_ethics_3.10
audio_diffusion_attacks_forhf/models/.DS_Store
ADDED
Binary file (6.15 kB).
audio_diffusion_attacks_forhf/models/__init__.py
ADDED
File without changes
audio_diffusion_attacks_forhf/models/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (173 Bytes).
audio_diffusion_attacks_forhf/models/__pycache__/phoneme_encoder.cpython-310.pyc
ADDED
Binary file (12 kB).
audio_diffusion_attacks_forhf/models/__pycache__/style_diffusion.cpython-310.pyc
ADDED
Binary file (4.38 kB).
audio_diffusion_attacks_forhf/models/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (2.54 kB).
audio_diffusion_attacks_forhf/models/datasets/__pycache__/music_datasets.cpython-310.pyc
ADDED
Binary file (1.89 kB).
audio_diffusion_attacks_forhf/models/datasets/music_datasets.py
ADDED
@@ -0,0 +1,65 @@
"""
music_datasets.py
Desc: Contains the code for the music datasets.
"""

import torch
from torch.utils.data import Dataset
import torchaudio
import numpy as np
import pandas as pd


"""
MusicMelDataset:
Given pre-processed mel-spectrograms, return a chunk of audio from the mel, with a masked version of a defined length
Args:
    audio_files: List of .npy files consisting of mel-specs
    audio_len: length in seconds (roughly) of audio to be returned
    mask_ratio: Size of mask as a ratio of audio_len
    mask_start: Where the mask starts for learning
        "midpoint": always mask out the second half of the mel-spec
    crop_start: Where the starting point for the sample of audio is taken
        "random": Random valid starting point from audio is taken
"""
class MusicMelDataset(Dataset):
    def __init__(self, audio_files, audio_len = 6, mask_ratio = 0.5, mask_start = "midpoint", crop_start = "random"):
        self.audio_files = audio_files

        # Convert length to number of frames
        self.audio_len = int(audio_len * 100) # 100 is heuristic conversion made
        self.mask_ratio = mask_ratio
        self.mask_len = int(np.floor(self.audio_len * mask_ratio))
        self.mask_start = mask_start
        self.crop_start = crop_start

    def __len__(self):
        return len(self.audio_files)

    # Get a random crop using audio_length
    def get_random_crop(self, mel):
        crop_start = torch.randint(0, mel.shape[0] - self.audio_len - 1, (1,))
        return mel[crop_start:crop_start + self.audio_len, :]

    def __getitem__(self, idx):
        mel = torch.Tensor(np.load(self.audio_files[idx]))

        if self.crop_start == "random":
            mel = self.get_random_crop(mel)
        else:
            raise NotImplementedError(f"{self.crop_start} is not an implemented parameter for crop_start")

        mask = torch.ones_like(mel)
        if self.mask_start == "midpoint":
            if self.mask_ratio == 0.5:
                mask[self.mask_len:, :] = 0
            else:
                # Zero out mask_len frames starting at the midpoint
                mask[self.audio_len // 2 : self.audio_len // 2 + self.mask_len, :] = 0
        else:
            raise NotImplementedError(f"{self.mask_start} is not an implemented parameter for mask_start")

        mel_mask = mel*mask

        return mel, mel_mask
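A hedged usage sketch for the dataset class above (the glob path is a placeholder, not from this commit); it shows the (mel, masked_mel) pairs that the midpoint masking produces.

# Hypothetical usage of MusicMelDataset; path and batch size are placeholders.
import glob
import torch
from models.datasets.music_datasets import MusicMelDataset

files = glob.glob("/path/to/fma_processed/*/*.npy")   # placeholder location of preprocessed mels
dataset = MusicMelDataset(files, audio_len=5.12, mask_ratio=0.5)
loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)

mel, mel_masked = next(iter(loader))
print(mel.shape, mel_masked.shape)   # e.g. torch.Size([4, 512, 64]) for 64-bin mels at ~100 frames/s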
audio_diffusion_attacks_forhf/models/monotonic_align/.DS_Store
ADDED
Binary file (6.15 kB).
audio_diffusion_attacks_forhf/models/monotonic_align/__init__.py
ADDED
@@ -0,0 +1,23 @@
""" from https://github.com/jaywalnut310/glow-tts """

import numpy as np
import torch
from .core import maximum_path_c


def maximum_path(value, mask):
    """ Cython optimised version.
    value: [b, t_x, t_y]
    mask: [b, t_x, t_y]
    """
    value = value * mask
    device = value.device
    dtype = value.dtype
    value = value.data.cpu().numpy().astype(np.float32)
    path = np.zeros_like(value).astype(np.int32)
    mask = mask.data.cpu().numpy()

    t_x_max = mask.sum(1)[:, 0].astype(np.int32)
    t_y_max = mask.sum(2)[:, 0].astype(np.int32)
    maximum_path_c(path, value, t_x_max, t_y_max)
    return torch.from_numpy(path).to(device=device, dtype=dtype)
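A hedged usage sketch for `maximum_path` with toy tensors (shapes are illustrative); it assumes the Cython extension built by the `setup.py` further down is already compiled.

# Hypothetical call to the monotonic alignment search; value holds alignment scores
# [b, t_x, t_y], mask marks valid positions, and the result is a hard monotonic path.
import torch
from models.monotonic_align import maximum_path

b, t_x, t_y = 2, 5, 9
value = torch.randn(b, t_x, t_y)
mask = torch.ones(b, t_x, t_y)              # all positions valid in this toy example
path = maximum_path(value, mask)
print(path.shape, path.sum(dim=1).max())    # torch.Size([2, 5, 9]); each frame aligns to exactly one phoneme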
audio_diffusion_attacks_forhf/models/monotonic_align/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (881 Bytes).
audio_diffusion_attacks_forhf/models/monotonic_align/build/temp.linux-x86_64-cpython-310/core.o
ADDED
Binary file (236 kB).
audio_diffusion_attacks_forhf/models/monotonic_align/core.c
ADDED
The diff for this file is too large to render.
See raw diff
audio_diffusion_attacks_forhf/models/monotonic_align/core.cpython-310-x86_64-linux-gnu.so
ADDED
Binary file (178 kB).
audio_diffusion_attacks_forhf/models/monotonic_align/core.pyx
ADDED
@@ -0,0 +1,45 @@
import numpy as np
cimport numpy as np
cimport cython
from cython.parallel import prange


@cython.boundscheck(False)
@cython.wraparound(False)
cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil:
    cdef int x
    cdef int y
    cdef float v_prev
    cdef float v_cur
    cdef float tmp
    cdef int index = t_x - 1

    for y in range(t_y):
        for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
            if x == y:
                v_cur = max_neg_val
            else:
                v_cur = value[x, y-1]
            if x == 0:
                if y == 0:
                    v_prev = 0.
                else:
                    v_prev = max_neg_val
            else:
                v_prev = value[x-1, y-1]
            value[x, y] = max(v_cur, v_prev) + value[x, y]

    for y in range(t_y - 1, -1, -1):
        path[index, y] = 1
        if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]):
            index = index - 1


@cython.boundscheck(False)
@cython.wraparound(False)
cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil:
    cdef int b = values.shape[0]

    cdef int i
    for i in prange(b, nogil=True):
        maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val)
audio_diffusion_attacks_forhf/models/monotonic_align/setup.py
ADDED
@@ -0,0 +1,11 @@
""" from https://github.com/jaywalnut310/glow-tts """

from distutils.core import setup
from Cython.Build import cythonize
import numpy

setup(
    name = 'monotonic_align',
    ext_modules = cythonize("core.pyx"),
    include_dirs=[numpy.get_include()]
)
audio_diffusion_attacks_forhf/models/phoneme_encoder.py
ADDED
@@ -0,0 +1,363 @@
""" from https://github.com/jaywalnut310/glow-tts and https://github.com/huawei-noah/Speech-Backbones/blob/main/Grad-TTS/"""

import math

import numpy as np  # np.prod is used in BaseModule.nparams
import torch

from models.utils import sequence_mask, convert_pad_shape


# def sequence_mask(length, max_length=None):
#     if max_length is None:
#         max_length = length.max()
#     x = torch.arange(int(max_length), dtype=length.dtype, device=length.device)
#     return x.unsqueeze(0) < length.unsqueeze(1)

# def convert_pad_shape(pad_shape):
#     l = pad_shape[::-1]
#     pad_shape = [item for sublist in l for item in sublist]
#     return pad_shape

class BaseModule(torch.nn.Module):
    def __init__(self):
        super(BaseModule, self).__init__()

    @property
    def nparams(self):
        """
        Returns number of trainable parameters of the module.
        """
        num_params = 0
        for name, param in self.named_parameters():
            if param.requires_grad:
                num_params += np.prod(param.detach().cpu().numpy().shape)
        return num_params


    def relocate_input(self, x: list):
        """
        Relocates provided tensors to the same device set for the module.
        """
        device = next(self.parameters()).device
        for i in range(len(x)):
            if isinstance(x[i], torch.Tensor) and x[i].device != device:
                x[i] = x[i].to(device)
        return x

class LayerNorm(BaseModule):
    def __init__(self, channels, eps=1e-4):
        super(LayerNorm, self).__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = torch.nn.Parameter(torch.ones(channels))
        self.beta = torch.nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        n_dims = len(x.shape)
        mean = torch.mean(x, 1, keepdim=True)
        variance = torch.mean((x - mean)**2, 1, keepdim=True)

        x = (x - mean) * torch.rsqrt(variance + self.eps)

        shape = [1, -1] + [1] * (n_dims - 2)
        x = x * self.gamma.view(*shape) + self.beta.view(*shape)
        return x


class ConvReluNorm(BaseModule):
    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size,
                 n_layers, p_dropout):
        super(ConvReluNorm, self).__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        self.conv_layers = torch.nn.ModuleList()
        self.norm_layers = torch.nn.ModuleList()
        self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels,
                                                kernel_size, padding=kernel_size//2))
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(torch.nn.Conv1d(hidden_channels, hidden_channels,
                                                    kernel_size, padding=kernel_size//2))
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        x_org = x
        for i in range(self.n_layers):
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask


class DurationPredictor(BaseModule):
    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
        super(DurationPredictor, self).__init__()
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.p_dropout = p_dropout

        self.drop = torch.nn.Dropout(p_dropout)
        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels,
                                      kernel_size, padding=kernel_size//2)
        self.norm_1 = LayerNorm(filter_channels)
        self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels,
                                      kernel_size, padding=kernel_size//2)
        self.norm_2 = LayerNorm(filter_channels)
        self.proj = torch.nn.Conv1d(filter_channels, 1, 1)

    def forward(self, x, x_mask):
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x)
        x = self.drop(x)
        x = self.proj(x * x_mask)
        return x * x_mask


class MultiHeadAttention(BaseModule):
    def __init__(self, channels, out_channels, n_heads, window_size=None,
                 heads_share=True, p_dropout=0.0, proximal_bias=False,
                 proximal_init=False):
        super(MultiHeadAttention, self).__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.window_size = window_size
        self.heads_share = heads_share
        self.proximal_bias = proximal_bias
        self.p_dropout = p_dropout
        self.attn = None

        self.k_channels = channels // n_heads
        self.conv_q = torch.nn.Conv1d(channels, channels, 1)
        self.conv_k = torch.nn.Conv1d(channels, channels, 1)
        self.conv_v = torch.nn.Conv1d(channels, channels, 1)
        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            self.emb_rel_k = torch.nn.Parameter(torch.randn(n_heads_rel,
                             window_size * 2 + 1, self.k_channels) * rel_stddev)
            self.emb_rel_v = torch.nn.Parameter(torch.randn(n_heads_rel,
                             window_size * 2 + 1, self.k_channels) * rel_stddev)
        self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
        self.drop = torch.nn.Dropout(p_dropout)

        torch.nn.init.xavier_uniform_(self.conv_q.weight)
        torch.nn.init.xavier_uniform_(self.conv_k.weight)
        if proximal_init:
            self.conv_k.weight.data.copy_(self.conv_q.weight.data)
            self.conv_k.bias.data.copy_(self.conv_q.bias.data)
        torch.nn.init.xavier_uniform_(self.conv_v.weight)

    def forward(self, x, c, attn_mask=None):
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
        if self.window_size is not None:
            assert t_s == t_t, "Relative attention is only available for self-attention."
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
            rel_logits = self._relative_position_to_absolute_position(rel_logits)
            scores_local = rel_logits / math.sqrt(self.k_channels)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device,
                                                                    dtype=scores.dtype)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
        p_attn = torch.nn.functional.softmax(scores, dim=-1)
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
            output = output + self._matmul_with_relative_values(relative_weights,
                                                                value_relative_embeddings)
        output = output.transpose(2, 3).contiguous().view(b, d, t_t)
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        ret = torch.matmul(x, y.unsqueeze(0))
        return ret

    def _matmul_with_relative_keys(self, x, y):
        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
        return ret

    def _get_relative_embeddings(self, relative_embeddings, length):
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = torch.nn.functional.pad(
                relative_embeddings, convert_pad_shape([[0, 0],
                [pad_length, pad_length], [0, 0]]))
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[:,
            slice_start_position:slice_end_position]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        batch, heads, length, _ = x.size()
        x = torch.nn.functional.pad(x, convert_pad_shape([[0,0],[0,0],[0,0],[0,1]]))
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0,0],[0,0],[0,length-1]]))
        x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        batch, heads, length, _ = x.size()
        x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]]))
        x_flat = x.view([batch, heads, length**2 + length*(length - 1)])
        x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:]
        return x_final

    def _attention_bias_proximal(self, length):
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)


class FFN(BaseModule):
    def __init__(self, in_channels, out_channels, filter_channels, kernel_size,
                 p_dropout=0.0):
        super(FFN, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size,
                                      padding=kernel_size//2)
        self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size,
                                      padding=kernel_size//2)
        self.drop = torch.nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        return x * x_mask


class Encoder(BaseModule):
    def __init__(self, hidden_channels, filter_channels, n_heads, n_layers,
                 kernel_size=1, p_dropout=0.0, window_size=None, **kwargs):
        super(Encoder, self).__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size

        self.drop = torch.nn.Dropout(p_dropout)
        self.attn_layers = torch.nn.ModuleList()
        self.norm_layers_1 = torch.nn.ModuleList()
        self.ffn_layers = torch.nn.ModuleList()
        self.norm_layers_2 = torch.nn.ModuleList()
        for _ in range(self.n_layers):
            self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels,
                                    n_heads, window_size=window_size, p_dropout=p_dropout))
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(FFN(hidden_channels, hidden_channels,
                                   filter_channels, kernel_size, p_dropout=p_dropout))
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask):
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        for i in range(self.n_layers):
            x = x * x_mask
            y = self.attn_layers[i](x, x, attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)
            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x


class TextEncoder(BaseModule):
    def __init__(self, n_vocab, n_feats, n_channels, filter_channels,
                 filter_channels_dp, n_heads, n_layers, kernel_size,
                 p_dropout, window_size=None, spk_emb_dim=64, n_spks=1):
        super(TextEncoder, self).__init__()
        self.n_vocab = n_vocab
        self.n_feats = n_feats
        self.n_channels = n_channels
        self.filter_channels = filter_channels
        self.filter_channels_dp = filter_channels_dp
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.spk_emb_dim = spk_emb_dim
        self.n_spks = n_spks

        self.emb = torch.nn.Embedding(n_vocab, n_channels)
        torch.nn.init.normal_(self.emb.weight, 0.0, n_channels**-0.5)

        self.prenet = ConvReluNorm(n_channels, n_channels, n_channels,
                                   kernel_size=5, n_layers=3, p_dropout=0.5)

        self.encoder = Encoder(n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels, n_heads, n_layers,
                               kernel_size, p_dropout, window_size=window_size)

        self.proj_m = torch.nn.Conv1d(n_channels + (spk_emb_dim if n_spks > 1 else 0), n_feats, 1)
        self.proj_w = DurationPredictor(n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels_dp,
                                        kernel_size, p_dropout)

    def forward(self, x, x_lengths, spk=None):
        x = self.emb(x) * math.sqrt(self.n_channels)
        x = torch.transpose(x, 1, -1)
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)

        x = self.prenet(x, x_mask)
        if self.n_spks > 1:
            x = torch.cat([x, spk.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
        x = self.encoder(x, x_mask)
        mu = self.proj_m(x) * x_mask

        x_dp = torch.detach(x)
        logw = self.proj_w(x_dp, x_mask)

        return mu, logw, x_mask
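A hedged sketch of driving `TextEncoder` with dummy phoneme ids; the hyperparameter values below are illustrative guesses (roughly Grad-TTS-like), not values taken from this commit, and it assumes the script is run from the repo root so the `models` package resolves.

# Hypothetical TextEncoder smoke test with made-up hyperparameters and random phoneme ids.
import torch
from models.phoneme_encoder import TextEncoder

encoder = TextEncoder(n_vocab=149, n_feats=80, n_channels=192, filter_channels=768,
                      filter_channels_dp=256, n_heads=2, n_layers=6, kernel_size=3,
                      p_dropout=0.1, window_size=4)
x = torch.randint(1, 149, (2, 37))            # batch of 2 phoneme sequences of length 37
x_lengths = torch.tensor([37, 30])
mu, logw, x_mask = encoder(x, x_lengths)
print(mu.shape, logw.shape, x_mask.shape)     # [2, 80, 37], [2, 1, 37], [2, 1, 37]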
audio_diffusion_attacks_forhf/models/style_diffusion.py
ADDED
@@ -0,0 +1,111 @@
"""
style_diffusion.py
Desc: Contains StyleVDiffusion models for training style transfer/editing models. These are essentially slight modifications of the original VDiffusion classes.
"""

from math import pi
from typing import Any, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import Tensor
from tqdm import tqdm

from audio_diffusion_pytorch.utils import default
from audio_diffusion_pytorch import Diffusion, Sampler, VDiffusion, VSampler, LinearSchedule, Schedule, Distribution, UniformDistribution

def pad_dims(x: Tensor, ndim: int) -> Tensor:
    # Pads additional ndims to the right of the tensor
    return x.view(*x.shape, *((1,) * ndim))


def clip(x: Tensor, dynamic_threshold: float = 0.0):
    if dynamic_threshold == 0.0:
        return x.clamp(-1.0, 1.0)
    else:
        # Dynamic thresholding
        # Find dynamic threshold quantile for each batch
        x_flat = rearrange(x, "b ... -> b (...)")
        scale = torch.quantile(x_flat.abs(), dynamic_threshold, dim=-1)
        # Clamp to a min of 1.0
        scale.clamp_(min=1.0)
        # Clamp all values and scale
        scale = pad_dims(scale, ndim=x.ndim - scale.ndim)
        x = x.clamp(-scale, scale) / scale
        return x


def extend_dim(x: Tensor, dim: int):
    # e.g. if dim = 4: shape [b] => [b, 1, 1, 1],
    return x.view(*x.shape + (1,) * (dim - x.ndim))

class StyleVDiffusion(Diffusion):
    def __init__(
        self, net: nn.Module, sigma_distribution: Distribution = UniformDistribution()
    ):
        super().__init__()
        self.net = net
        self.sigma_distribution = sigma_distribution

    def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
        angle = sigmas * pi / 2
        alpha, beta = torch.cos(angle), torch.sin(angle)
        return alpha, beta

    def forward(self, x: Tensor, y: Tensor, **kwargs) -> Tensor:  # type: ignore
        batch_size, device = x.shape[0], x.device
        # Sample amount of noise to add for each batch element
        sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
        sigmas_batch = extend_dim(sigmas, dim=y.ndim)
        # Get noise
        noise = torch.randn_like(y)
        # Combine input and noise weighted by half-circle
        alphas, betas = self.get_alpha_beta(sigmas_batch)
        y_noisy = alphas * y + betas * noise
        y_noisy = torch.concat((y_noisy, x), dim=1)
        v_target = alphas * noise - betas * y
        # Predict velocity and return loss
        v_pred = self.net(y_noisy, sigmas, **kwargs)
        return F.mse_loss(v_pred, v_target)


class StyleVSampler(Sampler):

    diffusion_types = [VDiffusion]

    def __init__(self, net: nn.Module, schedule: Schedule = LinearSchedule()):
        super().__init__()
        self.net = net
        self.schedule = schedule

    def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
        angle = sigmas * pi / 2
        alpha, beta = torch.cos(angle), torch.sin(angle)
        return alpha, beta

    @torch.no_grad()
    def forward(  # type: ignore
        self, x: Tensor, x_noisy: Tensor, num_steps: int, show_progress: bool = False, **kwargs
    ) -> Tensor:
        b = x_noisy.shape[0]
        x = x[None, ...]
        sigmas = self.schedule(num_steps + 1, device=x_noisy.device)
        sigmas = repeat(sigmas, "i -> i b", b=b)
        sigmas_batch = extend_dim(sigmas, dim=x_noisy.ndim + 1)
        alphas, betas = self.get_alpha_beta(sigmas_batch)
        progress_bar = tqdm(range(num_steps), disable=not show_progress)

        for i in progress_bar:
            x_mix = torch.cat((x_noisy, x), dim=1)
            v_pred = self.net(x_mix, sigmas[i], **kwargs)
            x_pred = alphas[i] * x_noisy - betas[i] * v_pred
            noise_pred = betas[i] * x_noisy + alphas[i] * v_pred
            x_noisy = alphas[i + 1] * x_pred + betas[i + 1] * noise_pred
            progress_bar.set_description(f"Sampling (noise={sigmas[i+1,0]:.2f})")

        return x_noisy

if __name__ == "__main__":
    print("Loaded dependencies correctly.")
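A hedged sketch of the `StyleVDiffusion` calling convention with a tiny stand-in network (not the repo's UNetV0; shapes are illustrative): `forward(x, y)` concatenates the conditioning `x` onto the noised target `y` along the channel dimension, so the wrapped net must accept the summed channel count and predict velocity for the target channels.

# Hypothetical minimal training step; TinyNet is a stand-in for the U-Net.
import torch
import torch.nn as nn
from models.style_diffusion import StyleVDiffusion

class TinyNet(nn.Module):
    # Stand-in net: maps 2 input channels (noisy target + condition) to 1 output channel.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size=3, padding=1)
    def forward(self, x, sigmas):
        return self.conv(x)                       # ignores sigmas in this toy example

diffusion = StyleVDiffusion(net=TinyNet())
cond = torch.randn(4, 1, 512, 64)                 # masked mel (conditioning)
target = torch.randn(4, 1, 512, 64)               # full mel (target)
loss = diffusion(cond, target)                    # MSE on the predicted velocity
loss.backward()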
audio_diffusion_attacks_forhf/models/utils.py
ADDED
@@ -0,0 +1,77 @@
"""
utils.py
Desc: A file for miscellaneous util functions
"""

import numpy as np
import torch


# MonoTransform, this does not exist in PyTorch anymore since it is a simple mean calculation. We provide an implementation here
class MonoTransform(object):
    """
    Convert audio sample to mono channel

    Args for __call__:
        audio_sample with shape (C, T) or (B, C, T), where C is the number of channels.

    TODO: IMPLEMENT __call__
    """
    def __init__(self):
        pass

    def __call__(self, sample):
        pass

"""
Below: Helper functions for Grad-TTS
"""

## Duration Loss
## Desc: A function for computing the duration loss for the duration predictor
def duration_loss(logw, logw_, lengths):
    loss = torch.sum((logw - logw_)**2) / torch.sum(lengths)
    return loss

def intersperse(lst, item):
    # Adds blank symbol
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(int(max_length), dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


def fix_len_compatibility(length, num_downsamplings_in_unet=2):
    while True:
        if length % (2**num_downsamplings_in_unet) == 0:
            return length
        length += 1


def convert_pad_shape(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape


def generate_path(duration, mask):
    device = duration.device

    b, t_x, t_y = mask.shape
    cum_duration = torch.cumsum(duration, 1)
    path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0],
        [1, 0], [0, 0]]))[:, :-1]
    path = path * mask
    return path
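A hedged sketch exercising the Grad-TTS helpers above with toy tensors, showing what `sequence_mask`, `generate_path`, and `intersperse` return.

# Hypothetical toy usage of the helpers defined in utils.py.
import torch
from models.utils import sequence_mask, generate_path, intersperse

lengths = torch.tensor([3, 5])
print(sequence_mask(lengths))            # [[True,True,True,False,False],[True,True,True,True,True]]

duration = torch.tensor([[1, 2, 2]])     # one sequence: 3 phonemes covering 5 frames in total
mask = torch.ones(1, 3, 5)
print(generate_path(duration, mask))     # each row marks which frames that phoneme covers

print(intersperse([4, 7, 9], 0))         # [0, 4, 0, 7, 0, 9, 0] -- blank token between phonemes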
audio_diffusion_attacks_forhf/notebooks/data_exploration/00_fma_exploration.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
audio_diffusion_attacks_forhf/resources/cmu_dictionary
ADDED
The diff for this file is too large to render.
See raw diff
audio_diffusion_attacks_forhf/scripts/.DS_Store
ADDED
Binary file (6.15 kB).
audio_diffusion_attacks_forhf/scripts/data_processing/process_music_mels.py
ADDED
@@ -0,0 +1,106 @@
"""
process_music_mels.py
Desc: Run this script with the appropriate data paths to extract mels
Command: `python -u scripts/data_processing/process_music_mels.py`
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import IPython

import scipy
import torch
import torchaudio
import os
import ast
import soundfile as sf
import glob

# Old Code for Importing AudioLDM
# from audioldm.pipeline import build_model

# HF Code for AudioLDM2
# from diffusers import AudioLDM2Pipeline
from audioldm.audio import wav_to_fbank, TacotronSTFT
try:
    from audioldm2 import build_model
except:
    from audioldm2 import build_model

# TODO: Replace these with args
audio_path = "/data/robbizorg/music_datasets/fma/data/fma_large/"
target_audio_path = "/data/robbizorg/music_datasets/fma/data/fma_processed/"
remake = False

## AudioLDM Mel Spec
default_mel_config = {
    "preprocessing": {
        "audio": {
            "sampling_rate": 16000,
            "max_wav_value": 32768,
            "duration": 10.24,
        },
        "stft": {"filter_length": 1024, "hop_length": 160, "win_length": 1024},
        "mel": {"n_mel_channels": 64, "mel_fmin": 0, "mel_fmax": 8000},
    }}

fn_STFT = TacotronSTFT(
    default_mel_config["preprocessing"]["stft"]["filter_length"],
    default_mel_config["preprocessing"]["stft"]["hop_length"],
    default_mel_config["preprocessing"]["stft"]["win_length"],
    default_mel_config["preprocessing"]["mel"]["n_mel_channels"],
    default_mel_config["preprocessing"]["audio"]["sampling_rate"],
    default_mel_config["preprocessing"]["mel"]["mel_fmin"],
    default_mel_config["preprocessing"]["mel"]["mel_fmax"],
)

if __name__ == "__main__":
    audio_files = glob.glob(os.path.join(audio_path, "*/*.mp3"))
    failed_files = []

    # Preprocess all mel_specs
    for i, f in enumerate(audio_files):
        if i % 1000 == 0:
            print(f"{i} of {len(audio_files)} files have been processed.")
        dir_info = f.split("/")
        filename = dir_info[-1].split(".")[0]
        parent_dir = dir_info[-2]

        # Skip the file if it's already generated
        if not remake and os.path.exists(os.path.join(target_audio_path, parent_dir, filename + '.npy')):
            continue

        try:
            audio, sr = torchaudio.load(f)
        except:
            failed_files.append(f)
            print(f"Failed on File {f}")
            continue

        if audio.shape[0] == 2:
            mono_audio = torch.mean(audio, axis = 0) # Convert to Mono
        else:
            mono_audio = audio[0, :] # remove channel info

        # Resample Audio
        resamp_16k = torchaudio.functional.resample(mono_audio, sr, 16000)

        duration = resamp_16k.shape[0]/16000
        target_length = int(duration * 100) # int(duration * 102.4)

        mel, _, _ = wav_to_fbank(resamp_16k.cpu(), target_length=target_length, fn_STFT=fn_STFT)

        # Make parent dir
        if not os.path.exists(os.path.join(target_audio_path, parent_dir)):
            os.mkdir(os.path.join(target_audio_path, parent_dir))

        with open(os.path.join(target_audio_path, parent_dir, filename + '.npy'), 'wb') as numpy_f:
            np.save(numpy_f, mel.numpy())

    print("Failed_Files:", len(failed_files))
    for x in failed_files:
        print(x)
audio_diffusion_attacks_forhf/scripts/data_processing/process_music_numpy.py
ADDED
@@ -0,0 +1,74 @@
"""
process_music_numpy.py
Desc: Run this script with the appropriate data paths to preprocess audio files and convert to 48k for the ArchiSound encoders
Command: `python -u scripts/data_processing/process_music_numpy.py`
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import IPython

import scipy
import torch
import torchaudio
import os
import ast
import soundfile as sf
import glob

# Old Code for Importing AudioLDM
# from audioldm.pipeline import build_model

# HF Code for AudioLDM2
# from diffusers import AudioLDM2Pipeline
from audioldm.audio import wav_to_fbank, TacotronSTFT
try:
    from audioldm2 import build_model
except:
    from audioldm2 import build_model

# TODO: Replace these with args
audio_path = "/data/robbizorg/music_datasets/fma/data/fma_large/"
target_audio_path = "/data/robbizorg/music_datasets/fma/data/fma_processed_48k/"
remake = False

if __name__ == "__main__":
    audio_files = glob.glob(os.path.join(audio_path, "*/*.mp3"))
    failed_files = []

    # Preprocess all mel_specs
    for i, f in enumerate(audio_files):
        if i % 1000 == 0:
            print(f"{i} of {len(audio_files)} files have been processed.")
        dir_info = f.split("/")
        filename = dir_info[-1].split(".")[0]
        parent_dir = dir_info[-2]

        # Skip the file if it's already generated
        if not remake and os.path.exists(os.path.join(target_audio_path, parent_dir, filename + '.npy')):
            continue

        try:
            audio, sr = torchaudio.load(f)
        except:
            failed_files.append(f)
            print(f"Failed on File {f}")
            continue

        # Resample Audio--Don't need to make mono since archisound encoders take in stereo
        resamp_48k = torchaudio.functional.resample(audio, sr, 48000)

        # Make parent dir
        if not os.path.exists(os.path.join(target_audio_path, parent_dir)):
            os.mkdir(os.path.join(target_audio_path, parent_dir))

        with open(os.path.join(target_audio_path, parent_dir, filename + '.npy'), 'wb') as numpy_f:
            np.save(numpy_f, resamp_48k.numpy())

    print("Failed_Files:", len(failed_files))
    for x in failed_files:
        print(x)
audio_diffusion_attacks_forhf/scripts/train/music_models/train_music_completion.py
ADDED
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
train_music_completion.py
|
3 |
+
Desc: Train a model for completing a 3 seconds of audio given 3 seconds of music as input.
|
4 |
+
Note: There are two possible approaches for this task
|
5 |
+
1. Perform masking and try to get the model to fill in the blank with StyleVDiffusion
|
6 |
+
2. Condition on the mel-spec with VDiffusion
|
7 |
+
"""
|
8 |
+
|
9 |
+
import sys
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import torchaudio
|
15 |
+
import gc
|
16 |
+
import argparse
|
17 |
+
import os
|
18 |
+
from tqdm import tqdm
|
19 |
+
import wandb
|
20 |
+
from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler
|
21 |
+
import soundfile as sf
|
22 |
+
|
23 |
+
sys.path.append(".")
|
24 |
+
from models.style_diffusion import StyleVDiffusion, StyleVSampler
|
25 |
+
from models.datasets.music_datasets import MusicMelDataset
|
26 |
+
|
27 |
+
import logging
|
28 |
+
|
29 |
+
from audioldm.audio import wav_to_fbank, TacotronSTFT
|
30 |
+
from audioldm2 import build_model
|
31 |
+
|
32 |
+
|
33 |
+
# Uncomment out below if wanting to supress
|
34 |
+
import warnings
|
35 |
+
warnings.filterwarnings("ignore")
|
36 |
+
|
37 |
+
# Set Sample Rate if like so if desired
|
38 |
+
SAMPLE_RATE = 16000
|
39 |
+
BATCH_SIZE = 16
|
40 |
+
|
41 |
+
|
42 |
+
# Function for creating a model that acts on mel-specs
|
43 |
+
def create_mel_model():
|
44 |
+
return DiffusionModel(
|
45 |
+
net_t=UNetV0, # The model type used for diffusion (U-Net V0 in this case)
|
46 |
+
# dim=2, # for spectrogram we can use 2D-CNN, but not going to for now
|
47 |
+
in_channels=600, # U-Net: number of input (time) channels
|
48 |
+
out_channels=300, # U-Net: number of output (time) channels
|
49 |
+
channels=[8, 32, 64, 128, 256, 512], # U-Net: channels at each layer
|
50 |
+
factors=[2, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
|
51 |
+
items=[2, 2, 2, 2, 2, 2], # U-Net: number of repeating items at each layer
|
52 |
+
attentions=[0, 0, 1, 1, 1, 1], # U-Net: attention enabled/disabled at each layer
|
53 |
+
attention_heads=8, # U-Net: number of attention heads per attention item
|
54 |
+
attention_features=64, # U-Net: number of attention features per attention item
|
55 |
+
diffusion_t=StyleVDiffusion, # The diffusion method used
|
56 |
+
sampler_t=StyleVSampler, # The diffusion sampler used
|
57 |
+
# embedding_features = 7, # Embedding Features for when conditioned
|
58 |
+
# cross_attentions=[0, 0, 0, 0, 1, 1, 1, 1]
|
59 |
+
)
|
60 |
+
|
61 |
+
def create_2Dmel_model():
|
62 |
+
return DiffusionModel(
|
63 |
+
net_t=UNetV0, # The model type used for diffusion (U-Net V0 in this case)
|
64 |
+
dim=2, # for spectrogram we can use 2D-CNN
|
65 |
+
in_channels=2, # U-Net: number of input (time) channels
|
66 |
+
out_channels=1, # U-Net: number of output (time) channels
|
67 |
+
channels=[8, 32, 64, 128, 256, 512], # U-Net: channels at each layer
|
68 |
+
factors=[2, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
|
69 |
+
items=[2, 2, 2, 2, 2, 2], # U-Net: number of repeating items at each layer
|
70 |
+
attentions=[0, 0, 1, 1, 1, 1], # U-Net: attention enabled/disabled at each layer
|
71 |
+
attention_heads=8, # U-Net: number of attention heads per attention item
|
72 |
+
attention_features=64, # U-Net: number of attention features per attention item
|
73 |
+
diffusion_t=StyleVDiffusion, # The diffusion method used
|
74 |
+
sampler_t=StyleVSampler, # The diffusion sampler used
|
75 |
+
# embedding_features = 7, # Embedding Features for when conditioned
|
76 |
+
# cross_attentions=[0, 0, 0, 0, 1, 1, 1, 1]
|
77 |
+
)
|
78 |
+
|
79 |
+
def parse_args():
|
80 |
+
parser = argparse.ArgumentParser()
|
81 |
+
parser.add_argument("--checkpoint", type=str, default='/data/robbizorg/attacksanddefenses/checkpoints/')
|
82 |
+
parser.add_argument("--resume", action="store_true")
|
83 |
+
parser.add_argument("--run_id", type=str, default='')
|
84 |
+
parser.add_argument("--debug", action="store_true")
|
85 |
+
    parser.add_argument("--data_path", type = str, default = "./data/fma_valid_files.npy")
    parser.add_argument("--epoch_num", type = int, default = 101)
    args = parser.parse_args()
    return vars(args)


if __name__ == "__main__":

    args = parse_args()

    if args['run_id'] == '':
        raise ValueError(f"Please provide a run_id for this training session.")

    cuda_ids = [phy_id for phy_id in range(len(os.environ["CUDA_VISIBLE_DEVICES"].split(",")))]

    if len(cuda_ids) > 1:
        raise NotImplementedError("Currently training is only allowed on a single GPU")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=os.environ.get("LOGLEVEL", "INFO").upper(),
        stream=sys.stdout,
        filemode='w',
    )
    logger = logging.getLogger("")

    audio_files = list(np.load(args['data_path'], allow_pickle = True).item())

    dataset = MusicMelDataset(audio_files, audio_len=5.12)

    print(f"Dataset length: {len(dataset)}")

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=16,
        pin_memory=False,
    )

    # Use this model for vocoder
    vae_model = build_model().to(device)

    diff_model = create_2Dmel_model().to(device)

    optimizer = torch.optim.AdamW(params=list(diff_model.parameters()), lr=1e-4, betas=(0.95, 0.999), eps=1e-6, weight_decay=1e-3)

    print(f"Number of parameters: {sum(p.numel() for p in diff_model.parameters() if p.requires_grad)}")

    if not args['debug']:
        run_id = wandb.util.generate_id()
        if args["run_id"] is not None:
            run_id = args["run_id"]
        print(f"Run ID: {run_id}")

        wandb.init(project="music-completion", resume=args["resume"], id=run_id)

    epoch = 0
    step = 0

    checkpoint_path = os.path.join(args["checkpoint"], args["run_id"])

    if not os.path.exists(checkpoint_path):
        os.makedirs(checkpoint_path)
        os.makedirs(os.path.join(checkpoint_path, "mels"))
        os.makedirs(os.path.join(checkpoint_path, "wavs"))

    if not args['debug'] and wandb.run.resumed:
        if os.path.exists(checkpoint_path):
            checkpoint = torch.load(checkpoint_path)
        else:
            checkpoint = torch.load(wandb.restore(checkpoint_path))
        diff_model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        step = epoch * len(dataloader)

    scaler = torch.cuda.amp.GradScaler()

    diff_model.train()

    while epoch < args['epoch_num']:
        avg_loss = 0
        avg_loss_step = 0
        progress = tqdm(dataloader)
        for i, (audio, masked_audio) in enumerate(progress):
            optimizer.zero_grad()
            # audio = torch.swapaxes(audio.to(device), 1, 2)
            # masked_audio = torch.swapaxes(masked_audio.to(device), 1, 2)
            # audio = audio.to(device)
            # masked_audio = masked_audio.to(device)
            audio = audio.to(device).unsqueeze(1)
            masked_audio = masked_audio.to(device).unsqueeze(1)

            with torch.cuda.amp.autocast():
                loss = diff_model(masked_audio, audio)
            avg_loss += loss.item()
            avg_loss_step += 1
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            progress.set_postfix(
                # loss=loss.item(),
                loss=avg_loss / avg_loss_step,
                epoch=epoch + i / len(dataloader),
            )

            if step % 500 == 0:
                # if step % 1 == 0:
                # Turn noise into new audio sample with diffusion
                # noise = torch.randn(1, 300, 64, device=device)
                # noise = torch.randn(1, 64, 300, device=device)
                noise = torch.randn(1, 1, 512, 64, device=device)  # 2D example

                with torch.cuda.amp.autocast():
                    sample = diff_model.sample(masked_audio[0], noise, num_steps=200)

                orig_wav = vae_model.mel_spectrogram_to_waveform(audio[0].unsqueeze(0), save = False)[0][0].astype(np.float32)  # 1, 1, len
                gen_wav = vae_model.mel_spectrogram_to_waveform(sample, save = False)[0][0].astype(np.float32)  # 1, 1, len

                orig_dir = os.path.join(checkpoint_path, 'wavs', f'target_{step}0.wav')
                gen_dir = os.path.join(checkpoint_path, 'wavs', f'gen_{step}0.wav')
                sf.write(orig_dir, orig_wav, samplerate = 16000)
                sf.write(gen_dir, gen_wav, samplerate = 16000)

                if not args['debug']:
                    wandb.log({
                        "step": step,
                        "epoch": epoch + i / len(dataloader),
                        "loss": avg_loss / avg_loss_step,
                        "target_audio": wandb.Audio(orig_dir, caption="Target audio", sample_rate=SAMPLE_RATE),
                        "generated_audio": wandb.Audio(gen_dir, caption="Generated audio", sample_rate=SAMPLE_RATE)
                    })

            if not args['debug'] and step % 100 == 0:
                wandb.log({
                    "step": step,
                    "epoch": epoch + i / len(dataloader),
                    "loss": avg_loss / avg_loss_step,
                })
                avg_loss = 0
                avg_loss_step = 0

            step += 1

        epoch += 1

        if not args['debug'] and epoch % 100 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': diff_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, os.path.join(checkpoint_path, f"epoch-{epoch}.pt"))
            wandb.save(checkpoint_path, base_path=args["checkpoint"])
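
For reference, a minimal sketch (not part of this commit) of how a checkpoint written by the loop above could be reloaded before resuming training; the dict keys mirror the torch.save call, while the helper name and the epoch-based file name argument are illustrative assumptions.

# Hypothetical helper, not in the repository: reload a checkpoint saved by the loop above.
import os
import torch

def load_checkpoint(checkpoint_path, epoch, model, optimizer, device="cuda"):
    # File name pattern matches the torch.save call above: epoch-{epoch}.pt
    ckpt_file = os.path.join(checkpoint_path, f"epoch-{epoch}.pt")
    checkpoint = torch.load(ckpt_file, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["epoch"]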
audio_diffusion_attacks_forhf/scripts/train/train_tts.py
ADDED
@@ -0,0 +1,430 @@
"""
train_tts.py
Desc: An example script for training a Diffusion-based TTS model with a speaker encoder.
"""

import sys

import torch
import torch.nn as nn
import torchaudio
import gc
import argparse
import os
from tqdm import tqdm
import wandb
from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler

sys.path.append(".")
from models.style_diffusion import StyleVDiffusion, StyleVSampler
# from models.utils import MonoTransform

# from util import calculate_codebook_bitrate, extract_melspectrogram, get_audio_file_bitrate, get_duration, load_neural_audio_codec
from audioldm.pipeline import build_model
import torch.multiprocessing as mp

# Needed for Instruction/Prompt Models
# from transformers import AutoTokenizer, T5EncoderModel

import logging

# Uncomment below to suppress warnings
# import warnings
# warnings.filterwarnings("ignore")

# Set the sample rate like so if desired
SAMPLE_RATE = 16000
BATCH_SIZE = 16
NUM_SAMPLES = int(2.56 * SAMPLE_RATE)
# NUM_SAMPLES = 2 ** 15


def create_model():
    return DiffusionModel(
        net_t=UNetV0,  # The model type used for diffusion (U-Net V0 in this case)
        # dim=2,  # for spectrogram we use 2D-CNN
        in_channels=314,  # U-Net: number of input (audio) channels
        out_channels=157,  # U-Net: number of output (audio) channels
        channels=[256, 256, 512, 512, 768, 768, 1280, 1280],  # U-Net: channels at each layer
        factors=[2, 2, 2, 2, 2, 2, 2, 1],  # U-Net: downsampling and upsampling factors at each layer
        items=[2, 2, 2, 2, 2, 2, 2, 2],  # U-Net: number of repeating items at each layer
        attentions=[0, 0, 0, 0, 1, 1, 1, 1],  # U-Net: attention enabled/disabled at each layer
        attention_heads=8,  # U-Net: number of attention heads per attention item
        attention_features=64,  # U-Net: number of attention features per attention item
        diffusion_t=StyleVDiffusion,  # The diffusion method used
        sampler_t=StyleVSampler,  # The diffusion sampler used
        # embedding_features=8,
        # embedding_features=2,  # Embedding for when it's just res and weight
        embedding_features=7,  # Embedding features for when severity is dropped
        cross_attentions=[0, 0, 0, 0, 1, 1, 1, 1]
    )


def main():
    pass
    # args = parse_args()

    # os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
    # os.environ["CUDA_VISIBLE_DEVICES"] = args['cuda_ids']
    # cuda_ids = [phy_id for phy_id in range(len(args['cuda_ids'].split(",")))]

    # logging.basicConfig(
    #     format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    #     datefmt="%Y-%m-%d %H:%M:%S",
    #     level=os.environ.get("LOGLEVEL", "INFO").upper(),
    #     stream=sys.stdout,
    #     filemode='w',
    # )
    # logger = logging.getLogger("")

    # # mp.set_start_method('spawn')
    # # mp.set_sharing_strategy('file_system')

    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # # Load in text model
    # # tokenizer = AutoTokenizer.from_pretrained("t5-small")
    # # text_model = T5EncoderModel.from_pretrained("t5-small")
    # # text_model.eval()  # Don't want to train it!

    # dataset = DSVAE_CondStyleWAVDataset(
    #     path="/data/robbizorg/pqvd_gen_w_conditioning/speech_non_speech_timesteps_VCTK.json",
    #     random_crop_size=NUM_SAMPLES,
    #     sample_rate=SAMPLE_RATE,
    #     transforms=AllTransform(
    #         mono=True,
    #     ),
    #     reconstructive = False,  # Make this true to just train a reconstructive model
    #     identity_limit = 1  # Affects how often we learn identity mapping
    # )

    # print(f"Dataset length: {len(dataset)}")

    # dataloader = torch.utils.data.DataLoader(
    #     dataset,
    #     batch_size=BATCH_SIZE,
    #     shuffle=True,
    #     num_workers=16,
    #     pin_memory=False,
    # )

    # vae_model = DSVAE(logger, **args).cuda()

    # if not os.path.exists(args['model_path']):
    #     logger.warning("model not exist and we just create the new model......")
    # else:
    #     logger.info("Model Exists")
    #     logger.info("Model Path is " + args['model_path'])
    #     vae_model.loadParameters(args['model_path'])
    # vae_model = torch.nn.DataParallel(vae_model, device_ids = cuda_ids, output_device=cuda_ids[0])
    # vae_model = vae_model.cuda()
    # vae_model.eval()
    # vae_model.module.eer = True

    # diff_model = create_model().to(device)
    # # audio_codec = build_model().to(device)
    # # audio_codec.latent_t_size = 157
    # # config, audio_codec, vocoder = load_neural_audio_codec('2021-05-19T22-16-54_vggsound_codebook', './logs', device)

    # # optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    # optimizer = torch.optim.AdamW(params=list(diff_model.parameters()), lr=1e-4, betas= (0.95, 0.999), eps=1e-6, weight_decay=1e-3)

    # print(f"Number of parameters: {sum(p.numel() for p in diff_model.parameters() if p.requires_grad)}")

    # run_id = wandb.util.generate_id()
    # if args["run_id"] is not None:
    #     run_id = args["run_id"]
    # print(f"Run ID: {run_id}")

    # wandb.init(project="audio-diffusion-no-condition", resume=args["resume"], id=run_id)

    # epoch = 0
    # step = 0

    # checkpoint_path = os.path.join(args["checkpoint"], args["run_id"])

    # if not os.path.exists(checkpoint_path):
    #     os.makedirs(checkpoint_path)
    #     os.makedirs(os.path.join(checkpoint_path, "mels"))
    #     os.makedirs(os.path.join(checkpoint_path, "wavs"))

    # if wandb.run.resumed:
    #     if os.path.exists(checkpoint_path):
    #         checkpoint = torch.load(checkpoint_path)
    #     else:
    #         checkpoint = torch.load(wandb.restore(checkpoint_path))
    #     diff_model.load_state_dict(checkpoint['model_state_dict'])
    #     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    #     epoch = checkpoint['epoch']
    #     step = epoch * len(dataloader)

    # scaler = torch.cuda.amp.GradScaler()

    # diff_model.train()
    # while epoch < 101:
    #     avg_loss = 0
    #     avg_loss_step = 0
    #     progress = tqdm(dataloader)
    #     for i, (audio, target, embedding) in enumerate(progress):
    #         optimizer.zero_grad()
    #         audio = audio.to(device)
    #         target = target.to(device)
    #         embedding = embedding.to(device)

    #         with torch.no_grad():
    #             embedding = embedding.float()  # Make it float like the others

    #             speaker_embed_source, content_embed_source = vae_model(audio)
    #             speaker_embed_source = speaker_embed_source.unsqueeze(1).expand(-1, 157, -1)

    #             audio_embed = torch.cat((speaker_embed_source, content_embed_source), axis = -1)

    #             # zeroes = torch.zeros(16, 3, 128, dtype=audio_embed.dtype, device = audio_embed.device)
    #             # audio_embed = torch.cat((audio_embed, zeroes), dim=1)

    #             speaker_embed, content_embed = vae_model(target)
    #             speaker_embed = speaker_embed.unsqueeze(1).expand(-1, 157, -1)

    #             # in order to simulate paired data, do (naive) voice conversion first
    #             target_embed = torch.cat((speaker_embed, content_embed_source), axis = -1)
    #             # target_embed = torch.cat((target_embed, zeroes), dim = 1)

    #         with torch.cuda.amp.autocast():
    #             loss = diff_model(audio_embed, target_embed, embedding=embedding)
    #         avg_loss += loss.item()
    #         avg_loss_step += 1
    #         scaler.scale(loss).backward()
    #         scaler.step(optimizer)
    #         scaler.update()
    #         progress.set_postfix(
    #             # loss=loss.item(),
    #             loss=avg_loss / avg_loss_step,
    #             epoch=epoch + i / len(dataloader),
    #         )

    #         if step % 500 == 0:
    #             # if step % 1 == 0:
    #             # Turn noise into new audio sample with diffusion
    #             noise = torch.randn(1, 157, 128, device=device)

    #             with torch.cuda.amp.autocast():
    #                 sample = diff_model.sample(audio_embed[0], noise, embedding=embedding[0][None, :], num_steps=200)

    #             # Save the melspecs
    #             audio_sub = torch.swapaxes(audio[0].unsqueeze(0), 1, 2)
    #             # target_sub = torch.swapaxes(target[0].unsqueeze(0), 1, 2)  # This is the original target audio, not what we want
    #             target_sub = vae_model.module.share_decoder(target_embed).loc
    #             gen_mel = vae_model.module.share_decoder(sample).loc

    #             vae_model.module.draw_mel(audio_sub, mode=f"source_{step}", file_path = os.path.join(checkpoint_path, "mels"))
    #             vae_model.module.draw_mel(target_sub, mode=f"target_{step}", file_path = os.path.join(checkpoint_path, "mels"))
    #             vae_model.module.draw_mel(gen_mel, mode=f"gen_{step}", file_path = os.path.join(checkpoint_path, "mels"))

    #             vae_model.module.mel2wav(audio_sub, mode=f"source_{step}", task="vc", file_path = os.path.join(checkpoint_path, "wavs"))
    #             vae_model.module.mel2wav(target_sub, mode=f"target_{step}", task="vc", file_path = os.path.join(checkpoint_path, "wavs"))
    #             vae_model.module.mel2wav(gen_mel, mode=f"gen_{step}", task="vc", file_path = os.path.join(checkpoint_path, "wavs"))

    #             # torchaudio.save(os.path.join(checkpoint_path, 'wavs', f'test_input_sound_{step}.wav'), torch.from_numpy(audio_codec.mel_spectrogram_to_waveform(audio_codec.decode_first_stage(audio[0].unsqueeze(0))))[0], SAMPLE_RATE)
    #             # torchaudio.save(os.path.join(checkpoint_path, 'wavs', f'test_generated_sound_{step}.wav'), torch.from_numpy(audio_codec.mel_spectrogram_to_waveform(audio_codec.decode_first_stage(sample[0].unsqueeze(0))))[0], SAMPLE_RATE)
    #             # torchaudio.save(os.path.join(checkpoint_path, 'wavs', f'test_target_sound_{step}.wav'), torch.from_numpy(audio_codec.mel_spectrogram_to_waveform(audio_codec.decode_first_stage(target[0].unsqueeze(0))))[0], SAMPLE_RATE)

    #             wandb.log({
    #                 "step": step,
    #                 "epoch": epoch + i / len(dataloader),
    #                 "loss": avg_loss / avg_loss_step,
    #                 "input_mel": wandb.Image(os.path.join(checkpoint_path, "mels", f"source_{step}_mel_0.png"), caption="Input Mel"),
    #                 "target_mel": wandb.Image(os.path.join(checkpoint_path, "mels", f"target_{step}_mel_0.png"), caption="Target Mel"),
    #                 "gen_mel": wandb.Image(os.path.join(checkpoint_path, "mels", f"gen_{step}_mel_0.png"), caption="Gen Mel"),
    #                 "input_audio": wandb.Audio(os.path.join(checkpoint_path, 'wavs', f'source_{step}0.wav'), caption="Input audio", sample_rate=SAMPLE_RATE),
    #                 "target_audio": wandb.Audio(os.path.join(checkpoint_path, 'wavs', f'target_{step}0.wav'), caption="Target audio", sample_rate=SAMPLE_RATE),
    #                 "generated_audio": wandb.Audio(os.path.join(checkpoint_path, 'wavs', f'gen_{step}0.wav'), caption="Generated audio", sample_rate=SAMPLE_RATE)
    #             })

    #         if step % 100 == 0:
    #             wandb.log({
    #                 "step": step,
    #                 "epoch": epoch + i / len(dataloader),
    #                 "loss": avg_loss / avg_loss_step,
    #             })
    #             avg_loss = 0
    #             avg_loss_step = 0

    #         step += 1

    #     epoch += 1

    #     if epoch % 100 == 0:
    #         torch.save({
    #             'epoch': epoch,
    #             'model_state_dict': diff_model.state_dict(),
    #             'optimizer_state_dict': optimizer.state_dict(),
    #         }, os.path.join(checkpoint_path, f"epoch-{epoch}.pt"))
    #         wandb.save(checkpoint_path, base_path=args["checkpoint"])


# def parse_args():
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--checkpoint", type=str, default='/data/robbizorg/pqvd_gen_w_dsvae/checkpoints/')
#     parser.add_argument("--resume", action="store_true")
#     parser.add_argument("--run_id", type=str, default='condition_ldm')

#     ## Params from DSVAE
#     parser.add_argument('--dataset', type=str, default="VCTK", help='VCTK, LibriTTS')
#     parser.add_argument('--encoder', type=str, default='dsvae', help='dsvae. tdnn')
#     parser.add_argument('--vocoder', type=str, default='hifigan', help='wavenet, hifigan')
#     parser.add_argument('--save_tsne', dest='save_tsne', action='store_true', help='save_tsne')
#     parser.add_argument('--mel_tsne', dest='mel_tsne', action='store_true', help='mel_tsne')
#     parser.add_argument('--feature', type=str, default='mel_spec', help='stft, mel_spec, mfcc')
#     parser.add_argument('--model_path', type=str, default='/home/robbizorg/research/dsvae/save_models/dsvae/best699.pth')
#     # parser.add_argument('--model_path', type=str, default='/data/andreaguz/save_models/dsvae_003_03/best699.pth')  # Using the fine-tuned dsvae
#     # parser.add_argument('--model_path', type=str, default='/data/andreaguz/save_models/dsvae_0001_0005/best.pth')  # Using the fine-tuned dsvae
#     parser.add_argument('--save_path', type=str, default='save_models/dsvae')
#     parser.add_argument('--cuda_ids', type=str, default='0')
#     parser.add_argument('--tsne_mode', type=str, default='test')
#     parser.add_argument("--optimizer", type=str, default='adam', help='sgd, adam')
#     parser.add_argument("--path_vc_1", type=str, default='', help='')
#     parser.add_argument("--path_vc_2", type=str, default='', help='')
#     parser.add_argument('--max_frames', type=int, default=100, help='1frame~10ms')
#     parser.add_argument("--hop_size", type=int, default=256, help='hop_size')
#     parser.add_argument("--win_length", type=int, default=1024, help='win_length')
#     parser.add_argument("--spk_dim", type=int, default=64, help='spk_embed')
#     parser.add_argument("--ecapa_spk_dim", type=int, default=128, help='ecapa spk_embed')
#     parser.add_argument("--content_dim", type=int, default=64, help="content_embed")
#     parser.add_argument("--conformer_hidden_dim", type=int, default=256, help="content_embed")
#     parser.add_argument('--n_epochs', type=int, default=700, help='n_epochs')
#     parser.add_argument('--eval_epoch', type=int, default=5, help='eval_epoch')
#     parser.add_argument('--step_size', type=int, default=5, help='step_size')
#     parser.add_argument('--num_workers', type=int, default=16, help='num_workers')
#     parser.add_argument('--lr_decay_rate', type=float, default=0.95, help='lr_decay_rate')
#     parser.add_argument('--lr', type=float, default=3e-4, help='lr_rate')
#     # parser.add_argument('--klf_factor', type=float, default=3e-3, help='klf_factor')
#     # parser.add_argument('--klt_factor', type=float, default=5, help='klt_factor')
#     parser.add_argument('--klf_factor', type=float, default=3e-4, help='klf_factor')  # Changed for the fine-tuned version
#     parser.add_argument('--klt_factor', type=float, default=3e-3, help='klt_factor')
#     parser.add_argument('--rec_factor', type=float, default=1, help='rec_factor')
#     parser.add_argument('--vq_factor', type=float, default=1000, help='vq_factor')
#     parser.add_argument('--zf_vq_factor', type=float, default=1000, help='vq_factor')
#     parser.add_argument('--klf_std', type=float, default=0.5, help='klf_std')
#     parser.add_argument('--rec_std', type=float, default=0.04, help='rec_std')
#     parser.add_argument('--clip', type=float, default=1, help='rec_std')
#     parser.add_argument('--phoneme_factor', type=float, default=1, help='phoneme_factor')
#     parser.add_argument('--r_vq_factor', type=float, default=10, help='r_vq_factor')
#     parser.add_argument('--compute_speaker_eer', dest='compute_speaker_eer', action='store_true', help='ASV EER')
#     parser.add_argument('--eval_phoneme', dest='eval_phoneme', action='store_true', help='ASV EER')
#     parser.add_argument('--num_eval', type=int, default=20, help='num of segments for eval')
#     parser.add_argument('--batch_size', type=int, default=256, help='batch_size')
#     parser.add_argument('--num_phonemes', type=int, default=100, help='num_phonemes')
#     parser.add_argument('--with_phoneme', dest='with_phoneme', action='store_true', help='')
#     parser.add_argument("--conversion", action='store_true', help='for conversion text')
#     parser.add_argument("--conversion2", action='store_true', help='for conversion text')
#     parser.add_argument("--conversion3", action='store_true', help='for conversion text')
#     parser.add_argument("--mel2npy", action='store_true', help='mel2npy')
#     parser.add_argument("--unconditional", action='store_true', help='unconditional')
#     parser.add_argument('--zt_norm_mean', action='store_true', help='instancenorm1d on zt prior and post')
#     parser.add_argument('--zf_norm_mean', action='store_true', help='instancenorm1d on zf prior and post')
#     parser.add_argument('--freeze_encoder', action='store_true', help='if or not to freeze encoder')
#     parser.add_argument('--freeze_decoder', action='store_true', help='if or not to freeze decoder')
#     parser.add_argument("--sample_rate", type=int, default=16000, help='16000 or 48000')
#     parser.add_argument('--noise_path', type=str, default='datasets/noise_list.scp', help='nosie invariant')
#     parser.add_argument('--wav_aug_train', action='store_true', help='with data augmentation')
#     parser.add_argument('--spec_aug_train', action='store_true', help='with data augmentation')
#     parser.add_argument('--noise_train', action='store_true', help='noise')
#     parser.add_argument('--triphn', action='store_true', help='with triphn')
#     parser.add_argument('--train_hifigan', action='store_true', help='train hifigan')
#     parser.add_argument("--prior_alignment", action='store_true', help='')
#     parser.add_argument("--zf_vq", action='store_true', help='')
#     parser.add_argument("--vq_prior_independent", action='store_true', help='')
#     parser.add_argument("--vq_prior_regressive", action='store_true', help='')
#     parser.add_argument("--vq_prior_pseudo", action='store_true', help='')
#     parser.add_argument("--vq_size_zt", type=int, default=200, help='')
#     parser.add_argument("--vq_size_zf", type=int, default=200, help='')
#     parser.add_argument("--ignore_index", type=int, default=0, help='')
#     parser.add_argument("--hidden_dim", type=int, default=256, help='')

#     parser.add_argument("--share_encoder", type=str, default='cnn', help='')
#     parser.add_argument("--share_decoder", type=str, default='cnn_lstm', help='cnn_lstm, cnn_transformer')
#     parser.add_argument("--zt_encoder", type=str, default='lstm', help='lstm, conformer_encoder, transformer_encoder')
#     parser.add_argument("--zf_encoder", type=str, default='lstm', help='lstm, transformer_encoder, ecapa_tdnn')
#     parser.add_argument("--zt_prior_model", type=str, default='lstm', help='lstm, vqvae, transformer')
#     parser.add_argument("--prior_signal", type=str, default='None', help='alignment_triphn, alignment_mono, melspec_pseudo, wavlm_pseudo, vq_embeds, vq_pseudo')
#     parser.add_argument("--multi_scale_add", action='store_true', help='')
#     parser.add_argument("--multi_scale_cat", action='store_true', help='')
#     parser.add_argument("--num_scales", type=int, default=1, help='')

#     parser.add_argument("--kmeans_num_clusters", type=int, default=50, help='')
#     parser.add_argument("--wavlm_dim", type=int, default=768, help='')

#     parser.add_argument("--ema_zt", action='store_true', help='')
#     parser.add_argument("--ema_zf", action='store_true', help='')

#     parser.add_argument("--r_vqvae", action='store_true', help='')
#     parser.add_argument("--masked_mel", action='store_true', help='')

#     parser.add_argument("--rec_noise", action='store_true', help='')
#     parser.add_argument("--rec_mask", action='store_true', help='')

#     parser.add_argument("--mel_classification", action='store_true', help='')
#     parser.add_argument("--test_script", action='store_true', help='')

#     parser.add_argument("--no_klt", action='store_true', help='')

#     parser.add_argument("--zt_prior_ce_r_vq", action='store_true', help='')
#     parser.add_argument('--zt_prior_ce_r_vq_factor', type=float, default=1000, help='factor')

#     parser.add_argument("--zt_post_ce_r_vq", action='store_true', help='')

#     parser.add_argument("--zt_prior_ce_kmeans", action='store_true', help='')
#     parser.add_argument('--zt_prior_ce_kmeans_factor', type=float, default=1000, help='factor')

#     parser.add_argument("--zt_post_ce_kmeans", action='store_true', help='')
#     parser.add_argument('--zt_post_ce_kmeans_factor', type=float, default=10, help='factor')

#     parser.add_argument("--zt_prior_ce_alignment", action='store_true', help='')
#     parser.add_argument('--zt_prior_ce_alignment_factor', type=float, default=1000, help='factor')

#     parser.add_argument("--prior_type", type=str, default='None', help='normal, condition, lm')
#     parser.add_argument("--prior_embedding", type=str, default='one-hot', help='one-hot, embedding')
#     parser.add_argument("--prior_mask", action='store_true', help='')

#     parser.add_argument("--wavlm", action='store_true', help='')
#     parser.add_argument("--wavlm_type", type=str, default='base', help='')

#     parser.add_argument("--tts_phn_wav_path", type=str, default='', help='')

#     parser.add_argument("--sr", type=str, default="16000", help='')

#     parser.add_argument("--text", type=str, default="your tts", help='')
#     parser.add_argument("--tts_align", action='store_true', help='')
#     parser.add_argument("--tts_wavlm", action='store_true', help='')
#     parser.add_argument("--tts", action='store_true', help='')
#     parser.add_argument("--tts_config", type=str, default="conf/LibriTTS/preprocess.yaml", help='')
#     parser.add_argument("--tts_target_wav_path", type=str, default='', help='')
#     parser.add_argument("--speed", type=float, default='1.0', help='')

#     parser.add_argument("--train_mapping", action='store_true', help='')
#     parser.add_argument("--mapping_encoder", type=str, default='lstm', help='')
#     parser.add_argument("--mapping_model_path", type=str, default='lstm', help='')
#     parser.add_argument("--mask_mapping", action='store_true', help='')
#     parser.add_argument("--mask_mapping_factor", type=float, default=1, help='')
#     parser.add_argument("--l1_mapping_factor", type=float, default=1, help='')
#     parser.add_argument("--mapping_ratio", type=float, default=1.0, help='')

#     parser.add_argument("--condition2", action='store_true', help='')

#     args = parser.parse_args()
#     return update_args(**vars(args))


# if __name__ == "__main__":
#     # torch.cuda.empty_cache()
#     main()
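
Since the training body of train_tts.py is commented out, a quick way to sanity-check the file is to instantiate the factory above on its own. The snippet below is a sketch only: it assumes the imports at the top of train_tts.py resolve (audio_diffusion_pytorch and models.style_diffusion) and simply reports the U-Net size, mirroring the parameter count printed in the commented-out loop.

# Sketch: instantiate the diffusion U-Net from create_model() and report its size.
model = create_model()
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of parameters: {n_params}")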
audio_diffusion_attacks_forhf/src/.DS_Store
ADDED
Binary file (6.15 kB)

audio_diffusion_attacks_forhf/src/__pycache__/losses.cpython-310.pyc
ADDED
Binary file (10.2 kB)

audio_diffusion_attacks_forhf/src/__pycache__/music_gen.cpython-310.pyc
ADDED
Binary file (3.49 kB)

audio_diffusion_attacks_forhf/src/__pycache__/test_encoder_attack.cpython-310.pyc
ADDED
Binary file (5.47 kB)
audio_diffusion_attacks_forhf/src/balancer.py
ADDED
@@ -0,0 +1,137 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

import flashy
import torch
from torch import autograd


class Balancer:
    """Loss balancer.

    The loss balancer combines losses together to compute gradients for the backward.
    Given `y = f(...)`, and a number of losses `l1(y, ...)`, `l2(y, ...)`, with `...`
    not having any dependence on `f`, the balancer can efficiently normalize the partial gradients
    `d l1 / d y`, `d l2 / dy` before summing them in order to achieve a desired ratio between
    the losses. For instance if `weights = {'l1': 2, 'l2': 1}`, 66% of the gradient
    going into `f(...)` will come from `l1` on average, and 33% from `l2`. This allows for an easy
    interpretation of the weights even if the intrinsic scale of `l1`, `l2` ... is unknown.

    Noting `g1 = d l1 / dy`, etc., the balanced gradient `G` will be
    (with `avg` an exponential moving average over the updates),

        G = sum_i total_norm * g_i / avg(||g_i||) * w_i / sum(w_i)

    If `balance_grads` is False, this is deactivated, and instead the gradient will just be the
    standard sum of the partial gradients with the given weights.

    A call to the backward method of the balancer will compute the partial gradients,
    combining all the losses and potentially rescaling the gradients,
    which can help stabilize the training and reason about multiple losses with varying scales.
    The obtained gradient with respect to `y` is then back-propagated to `f(...)`.

    Expected usage:

        weights = {'loss_a': 1, 'loss_b': 4}
        balancer = Balancer(weights, ...)
        losses: dict = {}
        losses['loss_a'] = compute_loss_a(x, y)
        losses['loss_b'] = compute_loss_b(x, y)
        if model.training():
            effective_loss = balancer.backward(losses, x)

    Args:
        weights (dict[str, float]): Weight coefficient for each loss. The balancer expects the keys
            passed to the backward method to match the weights keys, so each provided loss is assigned its weight.
        balance_grads (bool): Whether to rescale gradients so that weights reflect the fraction of the
            overall gradient, rather than a constant multiplier.
        total_norm (float): Reference norm when rescaling gradients, ignored otherwise.
        ema_decay (float): EMA decay for averaging the norms.
        per_batch_item (bool): Whether to compute the averaged norm per batch item or not. This only holds
            when rescaling the gradients.
        epsilon (float): Epsilon value for numerical stability.
        monitor (bool): If True, stores in `self.metrics` the relative ratio between the norm of the gradients
            coming from each loss, when calling `backward()`.
    """
    def __init__(self, weights: tp.Dict[str, float], balance_grads: bool = True, total_norm: float = 1.,
                 ema_decay: float = 0.999, per_batch_item: bool = True, epsilon: float = 1e-12,
                 monitor: bool = False):
        self.weights = weights
        self.per_batch_item = per_batch_item
        self.total_norm = total_norm or 1.
        self.averager = flashy.averager(ema_decay or 1.)
        self.epsilon = epsilon
        self.monitor = monitor
        self.balance_grads = balance_grads
        self._metrics: tp.Dict[str, tp.Any] = {}

    @property
    def metrics(self):
        return self._metrics

    def backward(self, losses: tp.Dict[str, torch.Tensor], input: torch.Tensor) -> torch.Tensor:
        """Compute the backward and return the effective train loss, e.g. the loss obtained from
        computing the effective weights. If `balance_grads` is True, the effective weights
        are the ones that need to be applied to each gradient to respect the desired relative
        scale of gradients coming from each loss.

        Args:
            losses (Dict[str, torch.Tensor]): dictionary with the same keys as `self.weights`.
            input (torch.Tensor): the input of the losses, typically the output of the model.
                This should be the single point of dependence between the losses
                and the model being trained.
        """
        norms = {}
        grads = {}
        for name, loss in losses.items():
            # Compute partial derivative of the loss with respect to the input.
            grad, = autograd.grad(loss, [input], retain_graph=True)
            if self.per_batch_item:
                # We do not average the gradient over the batch dimension.
                dims = tuple(range(1, grad.dim()))
                norm = grad.norm(dim=dims, p=2).mean()
            else:
                norm = grad.norm(p=2)
            norms[name] = norm
            grads[name] = grad

        count = 1
        if self.per_batch_item:
            count = len(grad)
        # Average norms across workers. Theoretically we should average the
        # squared norm, then take the sqrt, but it worked fine like that.
        avg_norms = flashy.distrib.average_metrics(self.averager(norms), count)
        # We approximate the total norm of the gradient as the sums of the norms.
        # Obviously this can be very incorrect if all gradients are aligned, but it works fine.
        total = sum(avg_norms.values())

        self._metrics = {}
        if self.monitor:
            # Store the ratio of the total gradient represented by each loss.
            for k, v in avg_norms.items():
                self._metrics[f'ratio_{k}'] = v / total

        total_weights = sum([self.weights[k] for k in avg_norms])
        assert total_weights > 0.
        desired_ratios = {k: w / total_weights for k, w in self.weights.items()}

        out_grad = torch.zeros_like(input)
        effective_loss = torch.tensor(0., device=input.device, dtype=input.dtype)
        for name, avg_norm in avg_norms.items():
            if self.balance_grads:
                # g_balanced = g / avg(||g||) * total_norm * desired_ratio
                scale = desired_ratios[name] * self.total_norm / (self.epsilon + avg_norm)
            else:
                # We just do regular weighted sum of the gradients.
                scale = self.weights[name]
            out_grad.add_(grads[name], alpha=scale)
            effective_loss += scale * losses[name].detach()
        # Send the computed partial derivative with respect to the output of the model to the model.
        input.backward(out_grad)
        return effective_loss
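
The "Expected usage" snippet in the docstring can be expanded into a runnable toy example. Everything below (the linear model, the two losses, the weights) is illustrative only, and it assumes flashy's distributed helpers fall back gracefully when torch.distributed is not initialized.

# Toy example of Balancer.backward on a dummy model (illustrative only).
import torch
from torch import nn

model = nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
x = torch.randn(4, 8)
target = torch.randn(4, 8)

optimizer.zero_grad()
y = model(x)  # single point of dependence between the losses and the model
losses = {
    "loss_a": nn.functional.l1_loss(y, target),
    "loss_b": nn.functional.mse_loss(y, target),
}
balancer = Balancer(weights={"loss_a": 1.0, "loss_b": 4.0}, monitor=True)
effective_loss = balancer.backward(losses, y)  # back-propagates the balanced gradient into the model
optimizer.step()
print(effective_loss.item(), balancer.metrics)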
audio_diffusion_attacks_forhf/src/losses.py
ADDED
@@ -0,0 +1,329 @@
# https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/loss.py

import typing
from typing import List

import torch
import torch.nn.functional as F
from audiotools import AudioSignal
from audiotools import STFTParams
from torch import nn


class L1Loss(nn.L1Loss):
    """L1 Loss between AudioSignals. Defaults
    to comparing ``audio_data``, but any
    attribute of an AudioSignal can be used.

    Parameters
    ----------
    attribute : str, optional
        Attribute of signal to compare, defaults to ``audio_data``.
    weight : float, optional
        Weight of this loss, defaults to 1.0.

    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
    """

    def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
        self.attribute = attribute
        self.weight = weight
        super().__init__(**kwargs)

    def forward(self, x: AudioSignal, y: AudioSignal):
        """
        Parameters
        ----------
        x : AudioSignal
            Estimate AudioSignal
        y : AudioSignal
            Reference AudioSignal

        Returns
        -------
        torch.Tensor
            L1 loss between AudioSignal attributes.
        """
        if isinstance(x, AudioSignal):
            x = getattr(x, self.attribute)
            y = getattr(y, self.attribute)
        return super().forward(x, y)


class SISDRLoss(nn.Module):
    """
    Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
    of estimated and reference audio signals or aligned features.

    Parameters
    ----------
    scaling : bool, optional
        Whether to use scale-invariant (True) or
        signal-to-noise ratio (False), by default True
    reduction : str, optional
        How to reduce across the batch (either 'mean',
        'sum', or 'none'), by default 'mean'
    zero_mean : bool, optional
        Zero mean the references and estimates before
        computing the loss, by default True
    clip_min : float, optional
        The minimum possible loss value. Helps network
        to not focus on making already good examples better, by default None
    weight : float, optional
        Weight of this loss, defaults to 1.0.

    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
    """

    def __init__(
        self,
        scaling: int = True,
        reduction: str = "mean",
        zero_mean: int = True,
        clip_min: int = None,
        weight: float = 1.0,
    ):
        self.scaling = scaling
        self.reduction = reduction
        self.zero_mean = zero_mean
        self.clip_min = clip_min
        self.weight = weight
        super().__init__()

    def forward(self, x: AudioSignal, y: AudioSignal):
        eps = 1e-8
        # nb, nc, nt
        if isinstance(x, AudioSignal):
            references = x.audio_data
            estimates = y.audio_data
        else:
            references = x
            estimates = y

        nb = references.shape[0]
        references = references.reshape(nb, 1, -1).permute(0, 2, 1)
        estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)

        # samples now on axis 1
        if self.zero_mean:
            mean_reference = references.mean(dim=1, keepdim=True)
            mean_estimate = estimates.mean(dim=1, keepdim=True)
        else:
            mean_reference = 0
            mean_estimate = 0

        _references = references - mean_reference
        _estimates = estimates - mean_estimate

        references_projection = (_references**2).sum(dim=-2) + eps
        references_on_estimates = (_estimates * _references).sum(dim=-2) + eps

        scale = (
            (references_on_estimates / references_projection).unsqueeze(1)
            if self.scaling
            else 1
        )

        e_true = scale * _references
        e_res = _estimates - e_true

        signal = (e_true**2).sum(dim=1)
        noise = (e_res**2).sum(dim=1)
        sdr = -10 * torch.log10(signal / noise + eps)

        if self.clip_min is not None:
            sdr = torch.clamp(sdr, min=self.clip_min)

        if self.reduction == "mean":
            sdr = sdr.mean()
        elif self.reduction == "sum":
            sdr = sdr.sum()
        return sdr


class MultiScaleSTFTLoss(nn.Module):
    """Computes the multi-scale STFT loss from [1].

    Parameters
    ----------
    window_lengths : List[int], optional
        Length of each window of each STFT, by default [2048, 512]
    loss_fn : typing.Callable, optional
        How to compare each loss, by default nn.L1Loss()
    clamp_eps : float, optional
        Clamp on the log magnitude, below, by default 1e-5
    mag_weight : float, optional
        Weight of raw magnitude portion of loss, by default 1.0
    log_weight : float, optional
        Weight of log magnitude portion of loss, by default 1.0
    pow : float, optional
        Power to raise magnitude to before taking log, by default 2.0
    weight : float, optional
        Weight of this loss, by default 1.0
    match_stride : bool, optional
        Whether to match the stride of convolutional layers, by default False

    References
    ----------

    1.  Engel, Jesse, Chenjie Gu, and Adam Roberts.
        "DDSP: Differentiable Digital Signal Processing."
        International Conference on Learning Representations. 2019.

    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
    """

    def __init__(
        self,
        window_lengths: List[int] = [2048, 512],
        loss_fn: typing.Callable = nn.L1Loss(),
        clamp_eps: float = 1e-5,
        mag_weight: float = 1.0,
        log_weight: float = 1.0,
        pow: float = 2.0,
        weight: float = 1.0,
        match_stride: bool = False,
        window_type: str = None,
    ):
        super().__init__()
        self.stft_params = [
            STFTParams(
                window_length=w,
                hop_length=w // 4,
                match_stride=match_stride,
                window_type=window_type,
            )
            for w in window_lengths
        ]
        self.loss_fn = loss_fn
        self.log_weight = log_weight
        self.mag_weight = mag_weight
        self.clamp_eps = clamp_eps
        self.weight = weight
        self.pow = pow

    def forward(self, x: AudioSignal, y: AudioSignal):
        """Computes multi-scale STFT between an estimate and a reference
        signal.

        Parameters
        ----------
        x : AudioSignal
            Estimate signal
        y : AudioSignal
            Reference signal

        Returns
        -------
        torch.Tensor
            Multi-scale STFT loss.
        """
        loss = 0.0
        for s in self.stft_params:
            x.stft(s.window_length, s.hop_length, s.window_type)
            y.stft(s.window_length, s.hop_length, s.window_type)
            loss += self.log_weight * self.loss_fn(
                x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
                y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
            )
            loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
        return loss


class MelSpectrogramLoss(nn.Module):
    """Compute distance between mel spectrograms. Can be used
    in a multi-scale way.

    Parameters
    ----------
    n_mels : List[int]
        Number of mels per STFT, by default [150, 80],
    window_lengths : List[int], optional
        Length of each window of each STFT, by default [2048, 512]
    loss_fn : typing.Callable, optional
        How to compare each loss, by default nn.L1Loss()
    clamp_eps : float, optional
        Clamp on the log magnitude, below, by default 1e-5
    mag_weight : float, optional
        Weight of raw magnitude portion of loss, by default 1.0
    log_weight : float, optional
        Weight of log magnitude portion of loss, by default 1.0
    pow : float, optional
        Power to raise magnitude to before taking log, by default 2.0
    weight : float, optional
        Weight of this loss, by default 1.0
    match_stride : bool, optional
        Whether to match the stride of convolutional layers, by default False

    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
    """

    def __init__(
        self,
        n_mels: List[int] = [150, 80],
        window_lengths: List[int] = [2048, 512],
        loss_fn: typing.Callable = nn.MSELoss(),
        clamp_eps: float = 1e-5,
        mag_weight: float = 1.0,
        log_weight: float = 1.0,
        pow: float = 2.0,
        weight: float = 1.0,
        match_stride: bool = False,
        mel_fmin: List[float] = [0.0, 0.0],
        mel_fmax: List[float] = [None, None],
        window_type: str = None,
    ):
        super().__init__()
        self.stft_params = [
            STFTParams(
                window_length=w,
                hop_length=w // 4,
                match_stride=match_stride,
                window_type=window_type,
            )
            for w in window_lengths
        ]
        self.n_mels = n_mels
        self.loss_fn = loss_fn
        self.clamp_eps = clamp_eps
        self.log_weight = log_weight
        self.mag_weight = mag_weight
        self.weight = weight
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.pow = pow

    def forward(self, x: AudioSignal, y: AudioSignal):
        """Computes mel loss between an estimate and a reference
        signal.

        Parameters
        ----------
        x : AudioSignal
            Estimate signal
        y : AudioSignal
            Reference signal

        Returns
        -------
        torch.Tensor
            Mel loss.
        """
        loss = 0.0
        for n_mels, fmin, fmax, s in zip(
            self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
        ):
            kwargs = {
                "window_length": s.window_length,
                "hop_length": s.hop_length,
                "window_type": s.window_type,
            }
            x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
            y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)

            loss += self.log_weight * self.loss_fn(
                x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
                y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
            )
            loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
        return loss
audio_diffusion_attacks_forhf/src/music_gen.py
ADDED
@@ -0,0 +1,100 @@
from transformers import AutoProcessor, MusicgenForConditionalGeneration
# Andy removed: from datasets import load_dataset
import torchaudio
import torch
# Andy edited: import losses
import audio_diffusion_attacks_forhf.src.losses as losses
from audiotools import AudioSignal

class MusicGenEval:

    def __init__(self, input_sample_rate, audio_steps):
        model_name="facebook/musicgen-stereo-small"
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
        self.model=self.model.to(device='cuda')
        self.input_sample_rate=input_sample_rate
        self.audio_steps=audio_steps
        self.mel_loss = losses.MelSpectrogramLoss(n_mels=[5, 10, 20, 40, 80, 160, 320],
                                                  window_lengths=[32, 64, 128, 256, 512, 1024, 2048],
                                                  mel_fmin=[0, 0, 0, 0, 0, 0, 0],
                                                  pow=1.0,
                                                  clamp_eps=1.0e-5,
                                                  mag_weight=0.0)

    def eval(self, original_audio, protected_audio):
        original_audio=original_audio[:, :, :self.audio_steps]
        protected_audio=protected_audio[:, :, :self.audio_steps]
        input_len=original_audio.shape[-1]

        unprotected_gen=self.generate_audio(original_audio)[0].to(device='cuda')
        protected_gen=self.generate_audio(protected_audio)[0].to(device='cuda')

        eval_dict={}
        # Difference between original and unprotected gen
        eval_dict["original_unprotectedgen_l1"]=torch.mean(torch.abs(original_audio-unprotected_gen[:, :input_len]))
        eval_dict["original_unprotectedgen_mel"]=self.mel_loss(AudioSignal(original_audio, self.input_sample_rate), AudioSignal(unprotected_gen[:, :input_len], self.input_sample_rate))
        # Difference between original and protected gen
        eval_dict["original_protectedgen_l1"]=torch.mean(torch.abs(original_audio-protected_gen[:, :input_len]))
        eval_dict["original_protectedgen_mel"]=self.mel_loss(AudioSignal(original_audio, self.input_sample_rate), AudioSignal(protected_gen[:, :input_len], self.input_sample_rate))
        # Difference between protected and protected gen
        eval_dict["protected_protectedgen_l1"]=torch.mean(torch.abs(protected_audio-protected_gen[:, :input_len]))
        eval_dict["protected_protectedgen_mel"]=self.mel_loss(AudioSignal(protected_audio, self.input_sample_rate), AudioSignal(protected_gen[:, :input_len], self.input_sample_rate))
        # Difference between unprotected gen and protected gen
        eval_dict["protectedgen_unprotectedgen_l1"]=torch.mean(torch.abs(protected_gen-unprotected_gen))
        eval_dict["protectedgen_unprotectedgen_mel"]=self.mel_loss(AudioSignal(protected_gen, self.input_sample_rate), AudioSignal(unprotected_gen, self.input_sample_rate))
        return eval_dict, unprotected_gen, protected_gen

    def generate_audio(self, audio):
        torch.manual_seed(0)

        transform = torchaudio.transforms.Resample(self.input_sample_rate, 32000).to(device='cuda')
        waveform=transform(audio[0]).detach().cpu()
        # waveform.clamp_(0,1)
        a=torch.min(waveform)
        b=torch.max(waveform)
        c=waveform.isnan().any()
        # sample = processor(raw_audio=waveform, sampling_rate=48000, return_tensors="pt")

        inputs = self.processor(
            audio=waveform,
            sampling_rate=32000,
            text=["music"],
            padding=True,
            return_tensors="pt",
        )
        for d in inputs.data:
            inputs.data[d]=inputs.data[d].to(device='cuda')
        audio_values = self.model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=1024)

        transform = torchaudio.transforms.Resample(32000, self.input_sample_rate).to(device='cuda')
        audio_values=transform(audio_values)
        return audio_values

model_name="facebook/musicgen-stereo-small"
processor = AutoProcessor.from_pretrained(model_name)
model = MusicgenForConditionalGeneration.from_pretrained(model_name).to(device='cuda')

'''Andy commented:
song_name="Texas Sun"
waveform, sample_rate = torchaudio.load(f"test_audio/{song_name}.mp3")
waveform=waveform[:, :500000]
torch.manual_seed(0)
transform = torchaudio.transforms.Resample(sample_rate, 32000)
waveform=transform(waveform)
# sample = processor(raw_audio=waveform, sampling_rate=48000, return_tensors="pt")

inputs = processor(
    audio=waveform,
    sampling_rate=32000,
    text=["music"],
    padding=True,
    return_tensors="pt",
)
for d in inputs.data:
    inputs.data[d]=inputs.data[d].to(device='cuda')
audio_values = model.generate(**inputs, do_sample=False, guidance_scale=3, max_new_tokens=512, top_k=0, top_p=250)
torchaudio.save(f"test_audio/perturbed/{model_name[9:]}_{song_name}.mp3", audio_values.detach().cpu()[0], 32000)

u=0
'''
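
A hedged sketch of how MusicGenEval might be driven end to end; the audio path is hypothetical, the (1, channels, samples) layout is inferred from the slicing in eval(), and a CUDA device plus the Hugging Face MusicGen weights are required.

# Sketch: compare MusicGen continuations of an original clip and a perturbed ("protected") copy.
import torch
import torchaudio

waveform, sample_rate = torchaudio.load("test_audio/example.mp3")  # hypothetical file, shape (channels, samples)
original = waveform.unsqueeze(0).to("cuda")                        # (1, channels, samples), as indexed in eval()
protected = (original + 0.001 * torch.randn_like(original)).clamp(-1, 1)  # stand-in for a protected version

evaluator = MusicGenEval(input_sample_rate=sample_rate, audio_steps=5 * sample_rate)  # first five seconds
eval_dict, unprotected_gen, protected_gen = evaluator.eval(original, protected)
for name, value in eval_dict.items():
    print(name, float(value))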
audio_diffusion_attacks_forhf/src/speech_inference.py
ADDED
@@ -0,0 +1,94 @@
import torch
from TTS.api import TTS
#Andy edited: import losses
import audio_diffusion_attacks_forhf.src.losses as losses
from audiotools import AudioSignal
import numpy as np
import torchaudio
import random
import string
import os

class XTTS_Eval:

    def __init__(self, input_sample_rate, text="The quick brown fox jumps over the lazy dog."):
        self.model = TTS("tts_models/multilingual/multi-dataset/xtts_v2")
        self.model = self.model.to(device='cuda')
        self.text = text
        self.input_sample_rate = input_sample_rate
        # Multi-scale mel-spectrogram loss used for all perceptual comparisons below
        self.mel_loss = losses.MelSpectrogramLoss(n_mels=[5, 10, 20, 40, 80, 160, 320],
                                                  window_lengths=[32, 64, 128, 256, 512, 1024, 2048],
                                                  mel_fmin=[0, 0, 0, 0, 0, 0, 0],
                                                  pow=1.0,
                                                  clamp_eps=1.0e-5,
                                                  mag_weight=0.0)

    def eval(self, original_audio, protected_audio):

        original_audio = original_audio[0]
        protected_audio = protected_audio[0]

        unprotected_gen = self.generate_audio(original_audio).to(device='cuda')
        protected_gen = self.generate_audio(protected_audio).to(device='cuda')

        # Trim references and generations to a common length before comparing
        match_len = min(original_audio.shape[1], unprotected_gen.shape[1])
        if original_audio.shape[1] < unprotected_gen.shape[1]:
            s_unprotected_gen = unprotected_gen[:, :match_len]
            s_protected_gen = protected_gen[:, :match_len]
            s_original_audio = original_audio
            s_protected_audio = protected_audio
        else:
            s_unprotected_gen = unprotected_gen
            s_protected_gen = protected_gen
            s_original_audio = original_audio[:, :match_len]
            s_protected_audio = protected_audio[:, :match_len]

        match_len = min(protected_gen.shape[1], unprotected_gen.shape[1])
        protected_gen = protected_gen[:, :match_len]
        unprotected_gen = unprotected_gen[:, :match_len]

        eval_dict = {}
        # Difference between original and unprotected gen
        eval_dict["original_unprotectedgen_l1"] = torch.mean(torch.abs(s_original_audio - s_unprotected_gen))
        eval_dict["original_unprotectedgen_mel"] = self.mel_loss(AudioSignal(s_original_audio, self.input_sample_rate), AudioSignal(s_unprotected_gen, self.input_sample_rate))
        # Difference between original and protected gen
        eval_dict["original_protectedgen_l1"] = torch.mean(torch.abs(s_original_audio - s_protected_gen))
        eval_dict["original_protectedgen_mel"] = self.mel_loss(AudioSignal(s_original_audio, self.input_sample_rate), AudioSignal(s_protected_gen, self.input_sample_rate))
        # Difference between protected and protected gen
        eval_dict["protected_protectedgen_l1"] = torch.mean(torch.abs(s_protected_audio - s_protected_gen))
        eval_dict["protected_protectedgen_mel"] = self.mel_loss(AudioSignal(s_protected_audio, self.input_sample_rate), AudioSignal(s_protected_gen, self.input_sample_rate))
        # Difference between unprotected gen and protected gen
        eval_dict["protectedgen_unprotectedgen_l1"] = torch.mean(torch.abs(protected_gen - unprotected_gen))
        eval_dict["protectedgen_unprotectedgen_mel"] = self.mel_loss(AudioSignal(protected_gen, self.input_sample_rate), AudioSignal(unprotected_gen, self.input_sample_rate))
        return eval_dict, unprotected_gen, protected_gen

    def generate_audio(self, audio):
        # XTTS voice cloning needs a speaker reference on disk, so write a temporary wav
        random_str = ''.join(random.choices(string.ascii_uppercase + string.digits, k=50))
        torchaudio.save(f"test_audio/{random_str}.wav", torch.reshape(audio.detach().cpu(), (2, audio.shape[1])), self.input_sample_rate, format="wav")
        torch.manual_seed(0)

        wav = self.model.tts(text=self.text,
                             speaker_wav=f"test_audio/{random_str}.wav",
                             language="en")
        os.remove(f"test_audio/{random_str}.wav")
        wav = torch.from_numpy(np.array(wav))
        # Duplicate the mono XTTS output into two channels
        stereo_wave = torch.zeros((2, wav.shape[0]))
        stereo_wave[:, :] = wav

        # XTTS v2 outputs 24 kHz audio; resample back to the input sample rate
        transform = torchaudio.transforms.Resample(24000, self.input_sample_rate)
        stereo_wave = transform(stereo_wave)
        return stereo_wave

# # Init TTS
# tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)
#
# # Run TTS
# # ❗ Since this is a multi-lingual voice cloning model, we must set the target speaker_wav and language
# # Text to speech list of amplitude values as output
# # wav = tts.tts(text="Hello world!", speaker_wav=, language="en")
# # Text to speech to a file
# tts.tts_to_file(text="Hello world!",
#                 speaker_wav="/media/willie/1caf5422-4135-4f2c-9619-c44041b51146/audio_data/DS_10283_3443/VCTK-Corpus-0.92/wav48_silence_trimmed/p227/p227_023_mic1.flac",
#                 language="en",
#                 file_path="/home/willie/eclipse-workspace/audio_diffusion_attacks/src/test_audio/speech/output.wav")
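A minimal sketch of how XTTS_Eval might be driven end to end: the wav file names below are placeholders, the batched (1, 2, T) input shape is inferred from the [0] indexing in eval(), and a writable test_audio/ directory is assumed because generate_audio writes a temporary speaker-reference file there.

# Illustrative usage sketch (assumptions: placeholder files "original.wav" and "protected.wav",
# both stereo at the same sample rate, and an existing test_audio/ directory).
import torchaudio
from audio_diffusion_attacks_forhf.src.speech_inference import XTTS_Eval

original, sr = torchaudio.load("original.wav")    # shape: (2, T)
protected, _ = torchaudio.load("protected.wav")   # shape: (2, T)

evaluator = XTTS_Eval(input_sample_rate=sr)

# eval() indexes [0] on its arguments, so pass batched tensors of shape (1, 2, T) on the GPU
eval_dict, unprotected_gen, protected_gen = evaluator.eval(
    original.unsqueeze(0).to("cuda"),
    protected.unsqueeze(0).to("cuda"),
)

for name, value in eval_dict.items():
    print(f"{name}: {float(value):.4f}")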
audio_diffusion_attacks_forhf/src/test_audio/.Il Sogno Del Marinaio - Nanos' Waltz.mp3.icloud
ADDED
Binary file (192 Bytes).