ALeLacheur commited on
Commit
5a9b731
·
1 Parent(s): 98a3a53

uploading audio diffusion attacks

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .DS_Store +0 -0
  2. audio_diffusion_attacks +0 -1
  3. audio_diffusion_attacks_forhf/.DS_Store +0 -0
  4. audio_diffusion_attacks_forhf/README.md +37 -0
  5. audio_diffusion_attacks_forhf/assets/.DS_Store +0 -0
  6. audio_diffusion_attacks_forhf/assets/audios/.DS_Store +0 -0
  7. audio_diffusion_attacks_forhf/assets/audios/hyperpop.wav +0 -0
  8. audio_diffusion_attacks_forhf/assets/example_MAS.png +0 -0
  9. audio_diffusion_attacks_forhf/assets/example_duration.png +0 -0
  10. audio_diffusion_attacks_forhf/assets/example_mel.png +0 -0
  11. audio_diffusion_attacks_forhf/assets/example_untrained_phone_encoding.png +0 -0
  12. audio_diffusion_attacks_forhf/assets/gradtts_system.png +0 -0
  13. audio_diffusion_attacks_forhf/audio_ethics.yml +0 -0
  14. audio_diffusion_attacks_forhf/config.yml +2 -0
  15. audio_diffusion_attacks_forhf/gen_audio_ethics_3.10.yml +8 -0
  16. audio_diffusion_attacks_forhf/models/.DS_Store +0 -0
  17. audio_diffusion_attacks_forhf/models/__init__.py +0 -0
  18. audio_diffusion_attacks_forhf/models/__pycache__/__init__.cpython-310.pyc +0 -0
  19. audio_diffusion_attacks_forhf/models/__pycache__/phoneme_encoder.cpython-310.pyc +0 -0
  20. audio_diffusion_attacks_forhf/models/__pycache__/style_diffusion.cpython-310.pyc +0 -0
  21. audio_diffusion_attacks_forhf/models/__pycache__/utils.cpython-310.pyc +0 -0
  22. audio_diffusion_attacks_forhf/models/datasets/__pycache__/music_datasets.cpython-310.pyc +0 -0
  23. audio_diffusion_attacks_forhf/models/datasets/music_datasets.py +65 -0
  24. audio_diffusion_attacks_forhf/models/monotonic_align/.DS_Store +0 -0
  25. audio_diffusion_attacks_forhf/models/monotonic_align/__init__.py +23 -0
  26. audio_diffusion_attacks_forhf/models/monotonic_align/__pycache__/__init__.cpython-310.pyc +0 -0
  27. audio_diffusion_attacks_forhf/models/monotonic_align/build/temp.linux-x86_64-cpython-310/core.o +0 -0
  28. audio_diffusion_attacks_forhf/models/monotonic_align/core.c +0 -0
  29. audio_diffusion_attacks_forhf/models/monotonic_align/core.cpython-310-x86_64-linux-gnu.so +0 -0
  30. audio_diffusion_attacks_forhf/models/monotonic_align/core.pyx +45 -0
  31. audio_diffusion_attacks_forhf/models/monotonic_align/setup.py +11 -0
  32. audio_diffusion_attacks_forhf/models/phoneme_encoder.py +363 -0
  33. audio_diffusion_attacks_forhf/models/style_diffusion.py +111 -0
  34. audio_diffusion_attacks_forhf/models/utils.py +77 -0
  35. audio_diffusion_attacks_forhf/notebooks/data_exploration/00_fma_exploration.ipynb +0 -0
  36. audio_diffusion_attacks_forhf/resources/cmu_dictionary +0 -0
  37. audio_diffusion_attacks_forhf/scripts/.DS_Store +0 -0
  38. audio_diffusion_attacks_forhf/scripts/data_processing/process_music_mels.py +106 -0
  39. audio_diffusion_attacks_forhf/scripts/data_processing/process_music_numpy.py +74 -0
  40. audio_diffusion_attacks_forhf/scripts/train/music_models/train_music_completion.py +243 -0
  41. audio_diffusion_attacks_forhf/scripts/train/train_tts.py +430 -0
  42. audio_diffusion_attacks_forhf/src/.DS_Store +0 -0
  43. audio_diffusion_attacks_forhf/src/__pycache__/losses.cpython-310.pyc +0 -0
  44. audio_diffusion_attacks_forhf/src/__pycache__/music_gen.cpython-310.pyc +0 -0
  45. audio_diffusion_attacks_forhf/src/__pycache__/test_encoder_attack.cpython-310.pyc +0 -0
  46. audio_diffusion_attacks_forhf/src/balancer.py +137 -0
  47. audio_diffusion_attacks_forhf/src/losses.py +329 -0
  48. audio_diffusion_attacks_forhf/src/music_gen.py +100 -0
  49. audio_diffusion_attacks_forhf/src/speech_inference.py +94 -0
  50. audio_diffusion_attacks_forhf/src/test_audio/.Il Sogno Del Marinaio - Nanos' Waltz.mp3.icloud +0 -0
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
audio_diffusion_attacks DELETED
@@ -1 +0,0 @@
- Subproject commit 1aaf4563762c407f31436ad452a72dd5af929443
 
 
audio_diffusion_attacks_forhf/.DS_Store ADDED
Binary file (6.15 kB)
 
audio_diffusion_attacks_forhf/README.md ADDED
@@ -0,0 +1,37 @@
+ ## Audio Data Ownership
+
+ ## Installation
+
+ conda env create -n audio_ethics --file gen_audio_ethics_3.10.yml
+
+ To set up wandb, please check out the following link: [https://docs.wandb.ai/quickstart](https://docs.wandb.ai/quickstart)
+
+ ## Run Encoder Attack
+
+ cd src
+
+ python test_encoder_attack.py
+
+ ## Overview
+
+ ## Task 1: Audio Completion with Diffusion Models
+
+ For this task, we use the [Free Music Archive (FMA)](https://github.com/mdeff/fma), a collection of royalty-free music. You can use any version of the dataset you wish, but we use the `fma_large` partition for training an initial system.
+
+ Note: If your librosa version is too recent, you will have to edit the corresponding line in audioldm to read `fft_window = pad_center(fft_window, size=filter_length)`.
+
+ To preprocess FMA, configure the script with your corresponding paths and run the appropriate preprocessing script to convert the `.mp3` files to numpy (loading audio files during training is prohibitively slow).
+ - Preprocessing for ArchiSound encoders: `nohup python -u scripts/data_processing/process_music_numpy.py > logs/process_48k_music.out &`
+
+ ## Task 2: TTS with Diffusion Models
+
+ TTS with diffusion (or flow) models is one of several approaches currently used for state-of-the-art TTS. In this repo, we have a model similar to
+ [Grad-TTS](https://grad-tts.github.io/), with example inference for Grad-TTS shown below:
+
+ ![Inference Figure for Grad-TTS](./assets/gradtts_system.png)
+
+ To run, first build the `monotonic_align` code:
+
+ `cd models/monotonic_align; python setup.py build_ext --inplace; cd ../..`
+
+ You may have to move the generated .so file to the `monotonic_align/` directory if it is generated in `monotonic_align/build/`.
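Note on the librosa fix mentioned in the README above: recent librosa releases made the `size` argument of `pad_center` keyword-only, which is why the positional call inside audioldm breaks. A minimal sketch of the patched call, with an illustrative window and filter length (the real values come from audioldm's STFT config):

```python
# Sketch of the pad_center fix referenced in the README; the sizes below are placeholders.
import numpy as np
from librosa.util import pad_center

fft_window = np.hanning(1024)   # illustrative analysis window
filter_length = 2048            # illustrative FFT size

# The old positional form `pad_center(fft_window, filter_length)` raises a TypeError on newer
# librosa; the keyword form works on both old and new versions.
fft_window = pad_center(fft_window, size=filter_length)
print(fft_window.shape)  # (2048,)
```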
audio_diffusion_attacks_forhf/assets/.DS_Store ADDED
Binary file (6.15 kB)
 
audio_diffusion_attacks_forhf/assets/audios/.DS_Store ADDED
Binary file (6.15 kB)
 
audio_diffusion_attacks_forhf/assets/audios/hyperpop.wav ADDED
Binary file (640 kB)
 
audio_diffusion_attacks_forhf/assets/example_MAS.png ADDED
audio_diffusion_attacks_forhf/assets/example_duration.png ADDED
audio_diffusion_attacks_forhf/assets/example_mel.png ADDED
audio_diffusion_attacks_forhf/assets/example_untrained_phone_encoding.png ADDED
audio_diffusion_attacks_forhf/assets/gradtts_system.png ADDED
audio_diffusion_attacks_forhf/audio_ethics.yml ADDED
File without changes
audio_diffusion_attacks_forhf/config.yml ADDED
@@ -0,0 +1,2 @@
+ wandb_settings:
+   project_name: audio_attacks
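A minimal sketch of reading this config to pick up the wandb project name; the `yaml` usage and the relative path are assumptions for illustration, not taken from the repo's scripts:

```python
# Sketch (assumption): load config.yml and initialise wandb with its project name.
import yaml
import wandb

with open("config.yml") as f:
    config = yaml.safe_load(f)

project_name = config["wandb_settings"]["project_name"]  # "audio_attacks"
# wandb.init(project=project_name)  # after `wandb login`, per the quickstart linked in the README
```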
audio_diffusion_attacks_forhf/gen_audio_ethics_3.10.yml ADDED
@@ -0,0 +1,8 @@
+ name: gen_audio_ethics_3.10
+ channels:
+   - defaults
+ dependencies:
+   - python=3.10
+   - conda-forge::libsndfile
+   - librosa
+ prefix: /home/willie/anaconda3/envs/gen_audio_ethics_3.10
audio_diffusion_attacks_forhf/models/.DS_Store ADDED
Binary file (6.15 kB)
 
audio_diffusion_attacks_forhf/models/__init__.py ADDED
File without changes
audio_diffusion_attacks_forhf/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (173 Bytes)
 
audio_diffusion_attacks_forhf/models/__pycache__/phoneme_encoder.cpython-310.pyc ADDED
Binary file (12 kB)
 
audio_diffusion_attacks_forhf/models/__pycache__/style_diffusion.cpython-310.pyc ADDED
Binary file (4.38 kB)
 
audio_diffusion_attacks_forhf/models/__pycache__/utils.cpython-310.pyc ADDED
Binary file (2.54 kB)
 
audio_diffusion_attacks_forhf/models/datasets/__pycache__/music_datasets.cpython-310.pyc ADDED
Binary file (1.89 kB)
 
audio_diffusion_attacks_forhf/models/datasets/music_datasets.py ADDED
@@ -0,0 +1,65 @@
+ """
+ music_datasets.py
+ Desc: Contains the code for the music datasets.
+ """
+
+ import torch
+ from torch.utils.data import Dataset
+ import torchaudio
+ import numpy as np
+ import pandas as pd
+
+
+ """
+ MusicMelDataset:
+ Given pre-processed mel-spectrograms, return a chunk of audio from the mel along with a masked version of a defined length.
+ Args:
+     audio_files: List of .npy files consisting of mel-specs
+     audio_len: Length in seconds (roughly) of the audio to be returned
+     mask_ratio: Size of the mask as a ratio of audio_len
+     mask_start: Where the mask starts for learning
+         "midpoint": always mask out the second half of the mel-spec
+     crop_start: Where the starting point for the sample of audio is taken
+         "random": A random valid starting point from the audio is taken
+
+ """
+ class MusicMelDataset(Dataset):
+     def __init__(self, audio_files, audio_len=6, mask_ratio=0.5, mask_start="midpoint", crop_start="random"):
+         self.audio_files = audio_files
+
+         # Convert length to number of frames
+         self.audio_len = int(audio_len * 100)  # 100 frames per second is a heuristic conversion
+         self.mask_ratio = mask_ratio
+         self.mask_len = int(np.floor(self.audio_len * mask_ratio))
+         self.mask_start = mask_start
+         self.crop_start = crop_start
+
+     def __len__(self):
+         return len(self.audio_files)
+
+     # Get a random crop using audio_len
+     def get_random_crop(self, mel):
+         crop_start = torch.randint(0, mel.shape[0] - self.audio_len - 1, (1,))
+         return mel[crop_start:crop_start + self.audio_len, :]
+
+     def __getitem__(self, idx):
+         mel = torch.Tensor(np.load(self.audio_files[idx]))
+
+         if self.crop_start == "random":
+             mel = self.get_random_crop(mel)
+         else:
+             raise NotImplementedError(f"{self.crop_start} is not an implemented parameter for crop_start")
+
+         mask = torch.ones_like(mel)
+         if self.mask_start == "midpoint":
+             if self.mask_ratio == 0.5:
+                 mask[self.mask_len:, :] = 0
+             else:
+                 # Mask a span of mask_len frames starting at the midpoint
+                 mask[self.audio_len // 2:self.audio_len // 2 + self.mask_len, :] = 0
+         else:
+             raise NotImplementedError(f"{self.mask_start} is not an implemented parameter for mask_start")
+
+         mel_mask = mel * mask
+
+         return mel, mel_mask
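A minimal usage sketch for `MusicMelDataset` with a PyTorch `DataLoader`; the glob pattern and mel directory below are illustrative placeholders that mirror the preprocessing scripts later in this commit:

```python
# Sketch (assumptions: path to preprocessed .npy mel-spectrograms and batch size are placeholders).
import glob
import torch
from models.datasets.music_datasets import MusicMelDataset

mel_files = glob.glob("/path/to/fma_processed/*/*.npy")  # hypothetical output of process_music_mels.py
dataset = MusicMelDataset(mel_files, audio_len=5.12, mask_ratio=0.5)

loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)
mel, mel_masked = next(iter(loader))
print(mel.shape, mel_masked.shape)  # (16, 512, n_mels): audio_len=5.12 -> 512 frames per crop
```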
audio_diffusion_attacks_forhf/models/monotonic_align/.DS_Store ADDED
Binary file (6.15 kB)
 
audio_diffusion_attacks_forhf/models/monotonic_align/__init__.py ADDED
@@ -0,0 +1,23 @@
+ """ from https://github.com/jaywalnut310/glow-tts """
+
+ import numpy as np
+ import torch
+ from .core import maximum_path_c
+
+
+ def maximum_path(value, mask):
+     """ Cython optimised version.
+     value: [b, t_x, t_y]
+     mask: [b, t_x, t_y]
+     """
+     value = value * mask
+     device = value.device
+     dtype = value.dtype
+     value = value.data.cpu().numpy().astype(np.float32)
+     path = np.zeros_like(value).astype(np.int32)
+     mask = mask.data.cpu().numpy()
+
+     t_x_max = mask.sum(1)[:, 0].astype(np.int32)
+     t_y_max = mask.sum(2)[:, 0].astype(np.int32)
+     maximum_path_c(path, value, t_x_max, t_y_max)
+     return torch.from_numpy(path).to(device=device, dtype=dtype)
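Once the Cython extension is built (see `setup.py` and the README in this commit), `maximum_path` can be sanity-checked on dummy tensors; the shapes below are illustrative:

```python
# Sketch (assumption: the extension was built with `python setup.py build_ext --inplace`
# and this is run from the repo root so `models` is importable).
import torch
from models import monotonic_align

b, t_x, t_y = 2, 5, 12                    # batch, text steps, mel frames (illustrative)
value = torch.randn(b, t_x, t_y)          # e.g. per-(phoneme, frame) log-likelihoods
mask = torch.ones(b, t_x, t_y)            # valid-position mask

path = monotonic_align.maximum_path(value, mask)
print(path.shape)       # torch.Size([2, 5, 12]), a hard monotonic alignment
print(path.sum(dim=1))  # every frame is assigned to exactly one text step, so all ones
```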
audio_diffusion_attacks_forhf/models/monotonic_align/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (881 Bytes)
 
audio_diffusion_attacks_forhf/models/monotonic_align/build/temp.linux-x86_64-cpython-310/core.o ADDED
Binary file (236 kB)
 
audio_diffusion_attacks_forhf/models/monotonic_align/core.c ADDED
The diff for this file is too large to render. See raw diff
 
audio_diffusion_attacks_forhf/models/monotonic_align/core.cpython-310-x86_64-linux-gnu.so ADDED
Binary file (178 kB)
 
audio_diffusion_attacks_forhf/models/monotonic_align/core.pyx ADDED
@@ -0,0 +1,45 @@
+ import numpy as np
+ cimport numpy as np
+ cimport cython
+ from cython.parallel import prange
+
+
+ @cython.boundscheck(False)
+ @cython.wraparound(False)
+ cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil:
+     cdef int x
+     cdef int y
+     cdef float v_prev
+     cdef float v_cur
+     cdef float tmp
+     cdef int index = t_x - 1
+
+     for y in range(t_y):
+         for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
+             if x == y:
+                 v_cur = max_neg_val
+             else:
+                 v_cur = value[x, y-1]
+             if x == 0:
+                 if y == 0:
+                     v_prev = 0.
+                 else:
+                     v_prev = max_neg_val
+             else:
+                 v_prev = value[x-1, y-1]
+             value[x, y] = max(v_cur, v_prev) + value[x, y]
+
+     for y in range(t_y - 1, -1, -1):
+         path[index, y] = 1
+         if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]):
+             index = index - 1
+
+
+ @cython.boundscheck(False)
+ @cython.wraparound(False)
+ cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil:
+     cdef int b = values.shape[0]
+
+     cdef int i
+     for i in prange(b, nogil=True):
+         maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val)
audio_diffusion_attacks_forhf/models/monotonic_align/setup.py ADDED
@@ -0,0 +1,11 @@
+ """ from https://github.com/jaywalnut310/glow-tts """
+
+ from distutils.core import setup
+ from Cython.Build import cythonize
+ import numpy
+
+ setup(
+     name = 'monotonic_align',
+     ext_modules = cythonize("core.pyx"),
+     include_dirs=[numpy.get_include()]
+ )
audio_diffusion_attacks_forhf/models/phoneme_encoder.py ADDED
@@ -0,0 +1,363 @@
1
+ """ from https://github.com/jaywalnut310/glow-tts and https://github.com/huawei-noah/Speech-Backbones/blob/main/Grad-TTS/"""
2
+
3
+ import math
4
+
5
+ import torch
+ import numpy as np  # needed for np.prod in BaseModule.nparams below
6
+
7
+
8
+ from models.utils import sequence_mask, convert_pad_shape
9
+
10
+
11
+ # def sequence_mask(length, max_length=None):
12
+ # if max_length is None:
13
+ # max_length = length.max()
14
+ # x = torch.arange(int(max_length), dtype=length.dtype, device=length.device)
15
+ # return x.unsqueeze(0) < length.unsqueeze(1)
16
+
17
+ # def convert_pad_shape(pad_shape):
18
+ # l = pad_shape[::-1]
19
+ # pad_shape = [item for sublist in l for item in sublist]
20
+ # return pad_shape
21
+
22
+ class BaseModule(torch.nn.Module):
23
+ def __init__(self):
24
+ super(BaseModule, self).__init__()
25
+
26
+ @property
27
+ def nparams(self):
28
+ """
29
+ Returns number of trainable parameters of the module.
30
+ """
31
+ num_params = 0
32
+ for name, param in self.named_parameters():
33
+ if param.requires_grad:
34
+ num_params += np.prod(param.detach().cpu().numpy().shape)
35
+ return num_params
36
+
37
+
38
+ def relocate_input(self, x: list):
39
+ """
40
+ Relocates provided tensors to the same device set for the module.
41
+ """
42
+ device = next(self.parameters()).device
43
+ for i in range(len(x)):
44
+ if isinstance(x[i], torch.Tensor) and x[i].device != device:
45
+ x[i] = x[i].to(device)
46
+ return x
47
+
48
+ class LayerNorm(BaseModule):
49
+ def __init__(self, channels, eps=1e-4):
50
+ super(LayerNorm, self).__init__()
51
+ self.channels = channels
52
+ self.eps = eps
53
+
54
+ self.gamma = torch.nn.Parameter(torch.ones(channels))
55
+ self.beta = torch.nn.Parameter(torch.zeros(channels))
56
+
57
+ def forward(self, x):
58
+ n_dims = len(x.shape)
59
+ mean = torch.mean(x, 1, keepdim=True)
60
+ variance = torch.mean((x - mean)**2, 1, keepdim=True)
61
+
62
+ x = (x - mean) * torch.rsqrt(variance + self.eps)
63
+
64
+ shape = [1, -1] + [1] * (n_dims - 2)
65
+ x = x * self.gamma.view(*shape) + self.beta.view(*shape)
66
+ return x
67
+
68
+
69
+ class ConvReluNorm(BaseModule):
70
+ def __init__(self, in_channels, hidden_channels, out_channels, kernel_size,
71
+ n_layers, p_dropout):
72
+ super(ConvReluNorm, self).__init__()
73
+ self.in_channels = in_channels
74
+ self.hidden_channels = hidden_channels
75
+ self.out_channels = out_channels
76
+ self.kernel_size = kernel_size
77
+ self.n_layers = n_layers
78
+ self.p_dropout = p_dropout
79
+
80
+ self.conv_layers = torch.nn.ModuleList()
81
+ self.norm_layers = torch.nn.ModuleList()
82
+ self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels,
83
+ kernel_size, padding=kernel_size//2))
84
+ self.norm_layers.append(LayerNorm(hidden_channels))
85
+ self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
86
+ for _ in range(n_layers - 1):
87
+ self.conv_layers.append(torch.nn.Conv1d(hidden_channels, hidden_channels,
88
+ kernel_size, padding=kernel_size//2))
89
+ self.norm_layers.append(LayerNorm(hidden_channels))
90
+ self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
91
+ self.proj.weight.data.zero_()
92
+ self.proj.bias.data.zero_()
93
+
94
+ def forward(self, x, x_mask):
95
+ x_org = x
96
+ for i in range(self.n_layers):
97
+ x = self.conv_layers[i](x * x_mask)
98
+ x = self.norm_layers[i](x)
99
+ x = self.relu_drop(x)
100
+ x = x_org + self.proj(x)
101
+ return x * x_mask
102
+
103
+
104
+ class DurationPredictor(BaseModule):
105
+ def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
106
+ super(DurationPredictor, self).__init__()
107
+ self.in_channels = in_channels
108
+ self.filter_channels = filter_channels
109
+ self.p_dropout = p_dropout
110
+
111
+ self.drop = torch.nn.Dropout(p_dropout)
112
+ self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels,
113
+ kernel_size, padding=kernel_size//2)
114
+ self.norm_1 = LayerNorm(filter_channels)
115
+ self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels,
116
+ kernel_size, padding=kernel_size//2)
117
+ self.norm_2 = LayerNorm(filter_channels)
118
+ self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
119
+
120
+ def forward(self, x, x_mask):
121
+ x = self.conv_1(x * x_mask)
122
+ x = torch.relu(x)
123
+ x = self.norm_1(x)
124
+ x = self.drop(x)
125
+ x = self.conv_2(x * x_mask)
126
+ x = torch.relu(x)
127
+ x = self.norm_2(x)
128
+ x = self.drop(x)
129
+ x = self.proj(x * x_mask)
130
+ return x * x_mask
131
+
132
+
133
+ class MultiHeadAttention(BaseModule):
134
+ def __init__(self, channels, out_channels, n_heads, window_size=None,
135
+ heads_share=True, p_dropout=0.0, proximal_bias=False,
136
+ proximal_init=False):
137
+ super(MultiHeadAttention, self).__init__()
138
+ assert channels % n_heads == 0
139
+
140
+ self.channels = channels
141
+ self.out_channels = out_channels
142
+ self.n_heads = n_heads
143
+ self.window_size = window_size
144
+ self.heads_share = heads_share
145
+ self.proximal_bias = proximal_bias
146
+ self.p_dropout = p_dropout
147
+ self.attn = None
148
+
149
+ self.k_channels = channels // n_heads
150
+ self.conv_q = torch.nn.Conv1d(channels, channels, 1)
151
+ self.conv_k = torch.nn.Conv1d(channels, channels, 1)
152
+ self.conv_v = torch.nn.Conv1d(channels, channels, 1)
153
+ if window_size is not None:
154
+ n_heads_rel = 1 if heads_share else n_heads
155
+ rel_stddev = self.k_channels**-0.5
156
+ self.emb_rel_k = torch.nn.Parameter(torch.randn(n_heads_rel,
157
+ window_size * 2 + 1, self.k_channels) * rel_stddev)
158
+ self.emb_rel_v = torch.nn.Parameter(torch.randn(n_heads_rel,
159
+ window_size * 2 + 1, self.k_channels) * rel_stddev)
160
+ self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
161
+ self.drop = torch.nn.Dropout(p_dropout)
162
+
163
+ torch.nn.init.xavier_uniform_(self.conv_q.weight)
164
+ torch.nn.init.xavier_uniform_(self.conv_k.weight)
165
+ if proximal_init:
166
+ self.conv_k.weight.data.copy_(self.conv_q.weight.data)
167
+ self.conv_k.bias.data.copy_(self.conv_q.bias.data)
168
+ torch.nn.init.xavier_uniform_(self.conv_v.weight)
169
+
170
+ def forward(self, x, c, attn_mask=None):
171
+ q = self.conv_q(x)
172
+ k = self.conv_k(c)
173
+ v = self.conv_v(c)
174
+
175
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
176
+
177
+ x = self.conv_o(x)
178
+ return x
179
+
180
+ def attention(self, query, key, value, mask=None):
181
+ b, d, t_s, t_t = (*key.size(), query.size(2))
182
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
183
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
184
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
185
+
186
+ scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
187
+ if self.window_size is not None:
188
+ assert t_s == t_t, "Relative attention is only available for self-attention."
189
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
190
+ rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
191
+ rel_logits = self._relative_position_to_absolute_position(rel_logits)
192
+ scores_local = rel_logits / math.sqrt(self.k_channels)
193
+ scores = scores + scores_local
194
+ if self.proximal_bias:
195
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
196
+ scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device,
197
+ dtype=scores.dtype)
198
+ if mask is not None:
199
+ scores = scores.masked_fill(mask == 0, -1e4)
200
+ p_attn = torch.nn.functional.softmax(scores, dim=-1)
201
+ p_attn = self.drop(p_attn)
202
+ output = torch.matmul(p_attn, value)
203
+ if self.window_size is not None:
204
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
205
+ value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
206
+ output = output + self._matmul_with_relative_values(relative_weights,
207
+ value_relative_embeddings)
208
+ output = output.transpose(2, 3).contiguous().view(b, d, t_t)
209
+ return output, p_attn
210
+
211
+ def _matmul_with_relative_values(self, x, y):
212
+ ret = torch.matmul(x, y.unsqueeze(0))
213
+ return ret
214
+
215
+ def _matmul_with_relative_keys(self, x, y):
216
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
217
+ return ret
218
+
219
+ def _get_relative_embeddings(self, relative_embeddings, length):
220
+ pad_length = max(length - (self.window_size + 1), 0)
221
+ slice_start_position = max((self.window_size + 1) - length, 0)
222
+ slice_end_position = slice_start_position + 2 * length - 1
223
+ if pad_length > 0:
224
+ padded_relative_embeddings = torch.nn.functional.pad(
225
+ relative_embeddings, convert_pad_shape([[0, 0],
226
+ [pad_length, pad_length], [0, 0]]))
227
+ else:
228
+ padded_relative_embeddings = relative_embeddings
229
+ used_relative_embeddings = padded_relative_embeddings[:,
230
+ slice_start_position:slice_end_position]
231
+ return used_relative_embeddings
232
+
233
+ def _relative_position_to_absolute_position(self, x):
234
+ batch, heads, length, _ = x.size()
235
+ x = torch.nn.functional.pad(x, convert_pad_shape([[0,0],[0,0],[0,0],[0,1]]))
236
+ x_flat = x.view([batch, heads, length * 2 * length])
237
+ x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0,0],[0,0],[0,length-1]]))
238
+ x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:]
239
+ return x_final
240
+
241
+ def _absolute_position_to_relative_position(self, x):
242
+ batch, heads, length, _ = x.size()
243
+ x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]]))
244
+ x_flat = x.view([batch, heads, length**2 + length*(length - 1)])
245
+ x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
246
+ x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:]
247
+ return x_final
248
+
249
+ def _attention_bias_proximal(self, length):
250
+ r = torch.arange(length, dtype=torch.float32)
251
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
252
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
253
+
254
+
255
+ class FFN(BaseModule):
256
+ def __init__(self, in_channels, out_channels, filter_channels, kernel_size,
257
+ p_dropout=0.0):
258
+ super(FFN, self).__init__()
259
+ self.in_channels = in_channels
260
+ self.out_channels = out_channels
261
+ self.filter_channels = filter_channels
262
+ self.kernel_size = kernel_size
263
+ self.p_dropout = p_dropout
264
+
265
+ self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size,
266
+ padding=kernel_size//2)
267
+ self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size,
268
+ padding=kernel_size//2)
269
+ self.drop = torch.nn.Dropout(p_dropout)
270
+
271
+ def forward(self, x, x_mask):
272
+ x = self.conv_1(x * x_mask)
273
+ x = torch.relu(x)
274
+ x = self.drop(x)
275
+ x = self.conv_2(x * x_mask)
276
+ return x * x_mask
277
+
278
+
279
+ class Encoder(BaseModule):
280
+ def __init__(self, hidden_channels, filter_channels, n_heads, n_layers,
281
+ kernel_size=1, p_dropout=0.0, window_size=None, **kwargs):
282
+ super(Encoder, self).__init__()
283
+ self.hidden_channels = hidden_channels
284
+ self.filter_channels = filter_channels
285
+ self.n_heads = n_heads
286
+ self.n_layers = n_layers
287
+ self.kernel_size = kernel_size
288
+ self.p_dropout = p_dropout
289
+ self.window_size = window_size
290
+
291
+ self.drop = torch.nn.Dropout(p_dropout)
292
+ self.attn_layers = torch.nn.ModuleList()
293
+ self.norm_layers_1 = torch.nn.ModuleList()
294
+ self.ffn_layers = torch.nn.ModuleList()
295
+ self.norm_layers_2 = torch.nn.ModuleList()
296
+ for _ in range(self.n_layers):
297
+ self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels,
298
+ n_heads, window_size=window_size, p_dropout=p_dropout))
299
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
300
+ self.ffn_layers.append(FFN(hidden_channels, hidden_channels,
301
+ filter_channels, kernel_size, p_dropout=p_dropout))
302
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
303
+
304
+ def forward(self, x, x_mask):
305
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
306
+ for i in range(self.n_layers):
307
+ x = x * x_mask
308
+ y = self.attn_layers[i](x, x, attn_mask)
309
+ y = self.drop(y)
310
+ x = self.norm_layers_1[i](x + y)
311
+ y = self.ffn_layers[i](x, x_mask)
312
+ y = self.drop(y)
313
+ x = self.norm_layers_2[i](x + y)
314
+ x = x * x_mask
315
+ return x
316
+
317
+
318
+ class TextEncoder(BaseModule):
319
+ def __init__(self, n_vocab, n_feats, n_channels, filter_channels,
320
+ filter_channels_dp, n_heads, n_layers, kernel_size,
321
+ p_dropout, window_size=None, spk_emb_dim=64, n_spks=1):
322
+ super(TextEncoder, self).__init__()
323
+ self.n_vocab = n_vocab
324
+ self.n_feats = n_feats
325
+ self.n_channels = n_channels
326
+ self.filter_channels = filter_channels
327
+ self.filter_channels_dp = filter_channels_dp
328
+ self.n_heads = n_heads
329
+ self.n_layers = n_layers
330
+ self.kernel_size = kernel_size
331
+ self.p_dropout = p_dropout
332
+ self.window_size = window_size
333
+ self.spk_emb_dim = spk_emb_dim
334
+ self.n_spks = n_spks
335
+
336
+ self.emb = torch.nn.Embedding(n_vocab, n_channels)
337
+ torch.nn.init.normal_(self.emb.weight, 0.0, n_channels**-0.5)
338
+
339
+ self.prenet = ConvReluNorm(n_channels, n_channels, n_channels,
340
+ kernel_size=5, n_layers=3, p_dropout=0.5)
341
+
342
+ self.encoder = Encoder(n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels, n_heads, n_layers,
343
+ kernel_size, p_dropout, window_size=window_size)
344
+
345
+ self.proj_m = torch.nn.Conv1d(n_channels + (spk_emb_dim if n_spks > 1 else 0), n_feats, 1)
346
+ self.proj_w = DurationPredictor(n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels_dp,
347
+ kernel_size, p_dropout)
348
+
349
+ def forward(self, x, x_lengths, spk=None):
350
+ x = self.emb(x) * math.sqrt(self.n_channels)
351
+ x = torch.transpose(x, 1, -1)
352
+ x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
353
+
354
+ x = self.prenet(x, x_mask)
355
+ if self.n_spks > 1:
356
+ x = torch.cat([x, spk.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
357
+ x = self.encoder(x, x_mask)
358
+ mu = self.proj_m(x) * x_mask
359
+
360
+ x_dp = torch.detach(x)
361
+ logw = self.proj_w(x_dp, x_mask)
362
+
363
+ return mu, logw, x_mask
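A shape-check sketch for the `TextEncoder` defined above; every hyperparameter here is an illustrative placeholder rather than a value taken from the repo's training configs:

```python
# Sketch (assumptions: run from the repo root; all sizes below are placeholders).
import torch
from models.phoneme_encoder import TextEncoder

encoder = TextEncoder(
    n_vocab=148,          # phoneme vocabulary size (placeholder)
    n_feats=80,           # mel channels per frame (placeholder)
    n_channels=192, filter_channels=768, filter_channels_dp=256,
    n_heads=2, n_layers=6, kernel_size=3, p_dropout=0.1, window_size=4,
)

x = torch.randint(0, 148, (2, 40))       # two padded phoneme sequences of length 40
x_lengths = torch.LongTensor([40, 35])   # true lengths before padding

mu, logw, x_mask = encoder(x, x_lengths)
print(mu.shape, logw.shape, x_mask.shape)  # (2, 80, 40), (2, 1, 40), (2, 1, 40)
```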
audio_diffusion_attacks_forhf/models/style_diffusion.py ADDED
@@ -0,0 +1,111 @@
1
+ """
2
+ style_diffusion.py
3
+ Desc: Contains StyleVDiffusion models for training style transfer/editing models. These are essentially slight modifications of the original VDiffusion classes.
4
+ """
5
+
6
+ from math import pi
7
+ from typing import Any, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from einops import rearrange, repeat
13
+ from torch import Tensor
14
+ from tqdm import tqdm
15
+
16
+ from audio_diffusion_pytorch.utils import default
17
+ from audio_diffusion_pytorch import Diffusion, Sampler, VDiffusion, VSampler, LinearSchedule, Schedule, Distribution, UniformDistribution
18
+
19
+ def pad_dims(x: Tensor, ndim: int) -> Tensor:
20
+ # Pads additional ndims to the right of the tensor
21
+ return x.view(*x.shape, *((1,) * ndim))
22
+
23
+
24
+ def clip(x: Tensor, dynamic_threshold: float = 0.0):
25
+ if dynamic_threshold == 0.0:
26
+ return x.clamp(-1.0, 1.0)
27
+ else:
28
+ # Dynamic thresholding
29
+ # Find dynamic threshold quantile for each batch
30
+ x_flat = rearrange(x, "b ... -> b (...)")
31
+ scale = torch.quantile(x_flat.abs(), dynamic_threshold, dim=-1)
32
+ # Clamp to a min of 1.0
33
+ scale.clamp_(min=1.0)
34
+ # Clamp all values and scale
35
+ scale = pad_dims(scale, ndim=x.ndim - scale.ndim)
36
+ x = x.clamp(-scale, scale) / scale
37
+ return x
38
+
39
+
40
+ def extend_dim(x: Tensor, dim: int):
41
+ # e.g. if dim = 4: shape [b] => [b, 1, 1, 1],
42
+ return x.view(*x.shape + (1,) * (dim - x.ndim))
43
+
44
+ class StyleVDiffusion(Diffusion):
45
+ def __init__(
46
+ self, net: nn.Module, sigma_distribution: Distribution = UniformDistribution()
47
+ ):
48
+ super().__init__()
49
+ self.net = net
50
+ self.sigma_distribution = sigma_distribution
51
+
52
+ def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
53
+ angle = sigmas * pi / 2
54
+ alpha, beta = torch.cos(angle), torch.sin(angle)
55
+ return alpha, beta
56
+
57
+ def forward(self, x: Tensor, y: Tensor, **kwargs) -> Tensor: # type: ignore
58
+ batch_size, device = x.shape[0], x.device
59
+ # Sample amount of noise to add for each batch element
60
+ sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
61
+ sigmas_batch = extend_dim(sigmas, dim=y.ndim)
62
+ # Get noise
63
+ noise = torch.randn_like(y)
64
+ # Combine input and noise weighted by half-circle
65
+ alphas, betas = self.get_alpha_beta(sigmas_batch)
66
+ y_noisy = alphas * y + betas * noise
67
+ y_noisy = torch.concat((y_noisy, x), dim=1)
68
+ v_target = alphas * noise - betas * y
69
+ # Predict velocity and return loss
70
+ v_pred = self.net(y_noisy, sigmas, **kwargs)
71
+ return F.mse_loss(v_pred, v_target)
72
+
73
+
74
+ class StyleVSampler(Sampler):
75
+
76
+ diffusion_types = [VDiffusion]
77
+
78
+ def __init__(self, net: nn.Module, schedule: Schedule = LinearSchedule()):
79
+ super().__init__()
80
+ self.net = net
81
+ self.schedule = schedule
82
+
83
+ def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
84
+ angle = sigmas * pi / 2
85
+ alpha, beta = torch.cos(angle), torch.sin(angle)
86
+ return alpha, beta
87
+
88
+ @torch.no_grad()
89
+ def forward( # type: ignore
90
+ self, x:Tensor, x_noisy: Tensor, num_steps: int, show_progress: bool = False, **kwargs
91
+ ) -> Tensor:
92
+ b = x_noisy.shape[0]
93
+ x = x[None, ...]
94
+ sigmas = self.schedule(num_steps + 1, device=x_noisy.device)
95
+ sigmas = repeat(sigmas, "i -> i b", b=b)
96
+ sigmas_batch = extend_dim(sigmas, dim=x_noisy.ndim + 1)
97
+ alphas, betas = self.get_alpha_beta(sigmas_batch)
98
+ progress_bar = tqdm(range(num_steps), disable=not show_progress)
99
+
100
+ for i in progress_bar:
101
+ x_mix = torch.cat((x_noisy, x), dim=1)
102
+ v_pred = self.net(x_mix, sigmas[i], **kwargs)
103
+ x_pred = alphas[i] * x_noisy - betas[i] * v_pred
104
+ noise_pred = betas[i] * x_noisy + alphas[i] * v_pred
105
+ x_noisy = alphas[i + 1] * x_pred + betas[i + 1] * noise_pred
106
+ progress_bar.set_description(f"Sampling (noise={sigmas[i+1,0]:.2f})")
107
+
108
+ return x_noisy
109
+
110
+ if __name__ == "__main__":
111
+ print("Loaded dependencies correctly.")
audio_diffusion_attacks_forhf/models/utils.py ADDED
@@ -0,0 +1,77 @@
+ """
+ utils.py
+ Desc: A file for miscellaneous util functions
+ """
+
+ import numpy as np
+ import torch
+
+
+
+ # MonoTransform, this does not exist in PyTorch anymore since it is a simple mean calculation. We provide an implementation here
+ class MonoTransform(object):
+     """
+     Convert audio sample to mono channel
+
+     Args for __call__:
+         audio_sample with shape (C, T) or (B, C, T), where C is the number of channels.
+
+     TODO: IMPLEMENT __call__
+     """
+     def __init__(self):
+         pass
+
+     def __call__(self, sample):
+         pass
+
+ """
+ Below: Helper functions for Grad-TTS
+ """
+
+ ## Duration Loss
+ ## Desc: A function for computing the duration loss for the duration predictor
+ def duration_loss(logw, logw_, lengths):
+     loss = torch.sum((logw - logw_)**2) / torch.sum(lengths)
+     return loss
+
+ def intersperse(lst, item):
+     # Adds blank symbol
+     result = [item] * (len(lst) * 2 + 1)
+     result[1::2] = lst
+     return result
+
+
+ def sequence_mask(length, max_length=None):
+     if max_length is None:
+         max_length = length.max()
+     x = torch.arange(int(max_length), dtype=length.dtype, device=length.device)
+     return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+ def fix_len_compatibility(length, num_downsamplings_in_unet=2):
+     while True:
+         if length % (2**num_downsamplings_in_unet) == 0:
+             return length
+         length += 1
+
+
+ def convert_pad_shape(pad_shape):
+     l = pad_shape[::-1]
+     pad_shape = [item for sublist in l for item in sublist]
+     return pad_shape
+
+
+ def generate_path(duration, mask):
+     device = duration.device
+
+     b, t_x, t_y = mask.shape
+     cum_duration = torch.cumsum(duration, 1)
+     path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)
+
+     cum_duration_flat = cum_duration.view(b * t_x)
+     path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
+     path = path.view(b, t_x, t_y)
+     path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0],
+                                                                    [1, 0], [0, 0]]))[:, :-1]
+     path = path * mask
+     return path
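`MonoTransform.__call__` above is left as a TODO; a hedged sketch of what its docstring describes (a mean over the channel axis), offered as an assumption rather than the repo's eventual implementation:

```python
# Hypothetical completion of MonoTransform.__call__ (not part of this commit): average over channels.
import torch

class MonoTransformSketch(object):
    """Convert an audio sample of shape (C, T) or (B, C, T) to mono by averaging the channel axis."""
    def __call__(self, sample: torch.Tensor) -> torch.Tensor:
        channel_dim = 0 if sample.dim() == 2 else 1  # (C, T) vs (B, C, T)
        return sample.mean(dim=channel_dim, keepdim=True)

stereo = torch.randn(2, 48000)         # (C, T)
mono = MonoTransformSketch()(stereo)   # (1, 48000)
print(mono.shape)
```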
audio_diffusion_attacks_forhf/notebooks/data_exploration/00_fma_exploration.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
audio_diffusion_attacks_forhf/resources/cmu_dictionary ADDED
The diff for this file is too large to render. See raw diff
 
audio_diffusion_attacks_forhf/scripts/.DS_Store ADDED
Binary file (6.15 kB)
 
audio_diffusion_attacks_forhf/scripts/data_processing/process_music_mels.py ADDED
@@ -0,0 +1,106 @@
1
+ """
2
+ process_music_mels.py
3
+ Desc: Run this script with the appropriate data paths to extract mels
4
+ Command: `python -u scripts/data_processing/process_music_mels.py`
5
+ """
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+ import matplotlib.pyplot as plt
10
+ import seaborn as sns
11
+ import IPython
12
+
13
+ import scipy
14
+ import torch
15
+ import torchaudio
16
+ import os
17
+ import ast
18
+ import soundfile as sf
19
+ import glob
20
+
21
+ # Old Code for Importing AudioLDM
22
+ # from audioldm.pipeline import build_model
23
+
24
+ # HF Code for AudioLDM2
25
+ # from diffusers import AudioLDM2Pipeline
26
+ from audioldm.audio import wav_to_fbank, TacotronSTFT
27
+ try:
28
+ from audioldm2 import build_model
29
+ except:
30
+ from audioldm2 import build_model
31
+
32
+ # TODO: Replace these with args
33
+ audio_path = "/data/robbizorg/music_datasets/fma/data/fma_large/"
34
+ target_audio_path = "/data/robbizorg/music_datasets/fma/data/fma_processed/"
35
+ remake = False
36
+
37
+ ## AudioLDM Mel Spec
38
+ default_mel_config = {
39
+ "preprocessing": {
40
+ "audio": {
41
+ "sampling_rate": 16000,
42
+ "max_wav_value": 32768,
43
+ "duration": 10.24,
44
+ },
45
+ "stft": {"filter_length": 1024, "hop_length": 160, "win_length": 1024},
46
+ "mel": {"n_mel_channels": 64, "mel_fmin": 0, "mel_fmax": 8000},
47
+ }}
48
+
49
+ fn_STFT = TacotronSTFT(
50
+ default_mel_config["preprocessing"]["stft"]["filter_length"],
51
+ default_mel_config["preprocessing"]["stft"]["hop_length"],
52
+ default_mel_config["preprocessing"]["stft"]["win_length"],
53
+ default_mel_config["preprocessing"]["mel"]["n_mel_channels"],
54
+ default_mel_config["preprocessing"]["audio"]["sampling_rate"],
55
+ default_mel_config["preprocessing"]["mel"]["mel_fmin"],
56
+ default_mel_config["preprocessing"]["mel"]["mel_fmax"],
57
+ )
58
+
59
+ if __name__ == "__main__":
60
+ audio_files = glob.glob(os.path.join(audio_path, "*/*.mp3"))
61
+ failed_files = []
62
+
63
+ # Preprocess all mel_specs
64
+ for i, f in enumerate(audio_files):
65
+ if i % 1000 == 0:
66
+ print(f"{i} of {len(audio_files)} files have been processed.")
67
+ dir_info = f.split("/")
68
+ filename = dir_info[-1].split(".")[0]
69
+ parent_dir = dir_info[-2]
70
+
71
+ # Skip the file if it's already generated
72
+ if not remake and os.path.exists(os.path.join(target_audio_path, parent_dir, filename + '.npy')):
73
+ continue
74
+
75
+ try:
76
+ audio, sr = torchaudio.load(f)
77
+ except:
78
+ failed_files.append(f)
79
+ print(f"Failed on File {f}")
80
+ continue
81
+
82
+
83
+ if audio.shape[0] == 2:
84
+ mono_audio = torch.mean(audio, axis = 0) # Convert to Mono
85
+ else:
86
+ mono_audio = audio[0, :] # remove channel info
87
+
88
+ # Resample Audio
89
+ resamp_16k = torchaudio.functional.resample(mono_audio, sr, 16000)
90
+
91
+ duration = resamp_16k.shape[0]/16000
92
+ target_length = int(duration * 100) # int(duration * 102.4)
93
+
94
+ mel, _, _ = wav_to_fbank(resamp_16k.cpu(), target_length=target_length, fn_STFT=fn_STFT)
95
+
96
+ # Make parent dir
97
+ if not os.path.exists(os.path.join(target_audio_path, parent_dir)):
98
+ os.mkdir(os.path.join(target_audio_path, parent_dir))
99
+
100
+ with open(os.path.join(target_audio_path, parent_dir, filename + '.npy'), 'wb') as numpy_f:
101
+ np.save(numpy_f, mel.numpy())
102
+
103
+ print("Failed_Files:", len(failed_files))
104
+ for x in failed_files:
105
+ print(x)
106
+
audio_diffusion_attacks_forhf/scripts/data_processing/process_music_numpy.py ADDED
@@ -0,0 +1,74 @@
1
+ """
2
+ process_music_numpy.py
3
+ Desc: Run this script with the appropriate data paths to preprocess audio files and convert to 48k for the ArchiSound encoders
4
+ Command: `python -u scripts/data_processing/process_music_numpy.py`
5
+ """
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+ import matplotlib.pyplot as plt
10
+ import seaborn as sns
11
+ import IPython
12
+
13
+ import scipy
14
+ import torch
15
+ import torchaudio
16
+ import os
17
+ import ast
18
+ import soundfile as sf
19
+ import glob
20
+
21
+ # Old Code for Importing AudioLDM
22
+ # from audioldm.pipeline import build_model
23
+
24
+ # HF Code for AudioLDM2
25
+ # from diffusers import AudioLDM2Pipeline
26
+ from audioldm.audio import wav_to_fbank, TacotronSTFT
27
+ try:
28
+ from audioldm2 import build_model
29
+ except:
30
+ from audioldm2 import build_model
31
+
32
+ # TODO: Replace these with args
33
+ audio_path = "/data/robbizorg/music_datasets/fma/data/fma_large/"
34
+ target_audio_path = "/data/robbizorg/music_datasets/fma/data/fma_processed_48k/"
35
+ remake = False
36
+
37
+ if __name__ == "__main__":
38
+ audio_files = glob.glob(os.path.join(audio_path, "*/*.mp3"))
39
+ failed_files = []
40
+
41
+ # Preprocess all mel_specs
42
+ for i, f in enumerate(audio_files):
43
+ if i % 1000 == 0:
44
+ print(f"{i} of {len(audio_files)} files have been processed.")
45
+ dir_info = f.split("/")
46
+ filename = dir_info[-1].split(".")[0]
47
+ parent_dir = dir_info[-2]
48
+
49
+ # Skip the file if it's already generated
50
+ if not remake and os.path.exists(os.path.join(target_audio_path, parent_dir, filename + '.npy')):
51
+ continue
52
+
53
+ try:
54
+ audio, sr = torchaudio.load(f)
55
+ except:
56
+ failed_files.append(f)
57
+ print(f"Failed on File {f}")
58
+ continue
59
+
60
+
61
+ # Resample Audio--Don't need to make mono since archisound encoders take in stereo
62
+ resamp_48k = torchaudio.functional.resample(audio, sr, 48000)
63
+
64
+ # Make parent dir
65
+ if not os.path.exists(os.path.join(target_audio_path, parent_dir)):
66
+ os.mkdir(os.path.join(target_audio_path, parent_dir))
67
+
68
+ with open(os.path.join(target_audio_path, parent_dir, filename + '.npy'), 'wb') as numpy_f:
69
+ np.save(numpy_f, resamp_48k.numpy())
70
+
71
+ print("Failed_Files:", len(failed_files))
72
+ for x in failed_files:
73
+ print(x)
74
+
audio_diffusion_attacks_forhf/scripts/train/music_models/train_music_completion.py ADDED
@@ -0,0 +1,243 @@
1
+ """
2
+ train_music_completion.py
3
+ Desc: Train a model for completing a 3 seconds of audio given 3 seconds of music as input.
4
+ Note: There are two possible approaches for this task
5
+ 1. Perform masking and try to get the model to fill in the blank with StyleVDiffusion
6
+ 2. Condition on the mel-spec with VDiffusion
7
+ """
8
+
9
+ import sys
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import torchaudio
15
+ import gc
16
+ import argparse
17
+ import os
18
+ from tqdm import tqdm
19
+ import wandb
20
+ from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler
21
+ import soundfile as sf
22
+
23
+ sys.path.append(".")
24
+ from models.style_diffusion import StyleVDiffusion, StyleVSampler
25
+ from models.datasets.music_datasets import MusicMelDataset
26
+
27
+ import logging
28
+
29
+ from audioldm.audio import wav_to_fbank, TacotronSTFT
30
+ from audioldm2 import build_model
31
+
32
+
33
+ # Uncomment out below if wanting to supress
34
+ import warnings
35
+ warnings.filterwarnings("ignore")
36
+
37
+ # Set Sample Rate if like so if desired
38
+ SAMPLE_RATE = 16000
39
+ BATCH_SIZE = 16
40
+
41
+
42
+ # Function for creating a model that acts on mel-specs
43
+ def create_mel_model():
44
+ return DiffusionModel(
45
+ net_t=UNetV0, # The model type used for diffusion (U-Net V0 in this case)
46
+ # dim=2, # for spectrogram we can use 2D-CNN, but not going to for now
47
+ in_channels=600, # U-Net: number of input (time) channels
48
+ out_channels=300, # U-Net: number of output (time) channels
49
+ channels=[8, 32, 64, 128, 256, 512], # U-Net: channels at each layer
50
+ factors=[2, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
51
+ items=[2, 2, 2, 2, 2, 2], # U-Net: number of repeating items at each layer
52
+ attentions=[0, 0, 1, 1, 1, 1], # U-Net: attention enabled/disabled at each layer
53
+ attention_heads=8, # U-Net: number of attention heads per attention item
54
+ attention_features=64, # U-Net: number of attention features per attention item
55
+ diffusion_t=StyleVDiffusion, # The diffusion method used
56
+ sampler_t=StyleVSampler, # The diffusion sampler used
57
+ # embedding_features = 7, # Embedding Features for when conditioned
58
+ # cross_attentions=[0, 0, 0, 0, 1, 1, 1, 1]
59
+ )
60
+
61
+ def create_2Dmel_model():
62
+ return DiffusionModel(
63
+ net_t=UNetV0, # The model type used for diffusion (U-Net V0 in this case)
64
+ dim=2, # for spectrogram we can use 2D-CNN
65
+ in_channels=2, # U-Net: number of input (time) channels
66
+ out_channels=1, # U-Net: number of output (time) channels
67
+ channels=[8, 32, 64, 128, 256, 512], # U-Net: channels at each layer
68
+ factors=[2, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
69
+ items=[2, 2, 2, 2, 2, 2], # U-Net: number of repeating items at each layer
70
+ attentions=[0, 0, 1, 1, 1, 1], # U-Net: attention enabled/disabled at each layer
71
+ attention_heads=8, # U-Net: number of attention heads per attention item
72
+ attention_features=64, # U-Net: number of attention features per attention item
73
+ diffusion_t=StyleVDiffusion, # The diffusion method used
74
+ sampler_t=StyleVSampler, # The diffusion sampler used
75
+ # embedding_features = 7, # Embedding Features for when conditioned
76
+ # cross_attentions=[0, 0, 0, 0, 1, 1, 1, 1]
77
+ )
78
+
79
+ def parse_args():
80
+ parser = argparse.ArgumentParser()
81
+ parser.add_argument("--checkpoint", type=str, default='/data/robbizorg/attacksanddefenses/checkpoints/')
82
+ parser.add_argument("--resume", action="store_true")
83
+ parser.add_argument("--run_id", type=str, default='')
84
+ parser.add_argument("--debug", action="store_true")
85
+ parser.add_argument("--data_path", type = str, default = "./data/fma_valid_files.npy")
86
+ parser.add_argument("--epoch_num", type = int, default = 101)
87
+ args = parser.parse_args()
88
+ return vars(args)
89
+
90
+ if __name__ == "__main__":
91
+
92
+ args = parse_args()
93
+
94
+ if args['run_id'] == '':
95
+ raise ValueError(f"Please provide a run_id for this training session.")
96
+
97
+ cuda_ids = [phy_id for phy_id in range(len(os.environ["CUDA_VISIBLE_DEVICES"].split(",")))]
98
+
99
+ if len(cuda_ids) > 1:
100
+ raise NotImplementedError("Currently training is only allowed on a single GPU")
101
+
102
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
103
+
104
+ logging.basicConfig(
105
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
106
+ datefmt="%Y-%m-%d %H:%M:%S",
107
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
108
+ stream=sys.stdout,
109
+ filemode='w',
110
+ )
111
+ logger = logging.getLogger("")
112
+
113
+ audio_files = list(np.load(args['data_path'], allow_pickle = True).item())
114
+
115
+ dataset = MusicMelDataset(audio_files, audio_len=5.12)
116
+
117
+ print(f"Dataset length: {len(dataset)}")
118
+
119
+ dataloader = torch.utils.data.DataLoader(
120
+ dataset,
121
+ batch_size=BATCH_SIZE,
122
+ shuffle=True,
123
+ num_workers=16,
124
+ pin_memory=False,
125
+ )
126
+
127
+ # Use this model for vocoder
128
+ vae_model = build_model().to(device)
129
+
130
+ diff_model = create_2Dmel_model().to(device)
131
+
132
+
133
+ optimizer = torch.optim.AdamW(params=list(diff_model.parameters()), lr=1e-4, betas= (0.95, 0.999), eps=1e-6, weight_decay=1e-3)
134
+
135
+ print(f"Number of parameters: {sum(p.numel() for p in diff_model.parameters() if p.requires_grad)}")
136
+
137
+ if not args['debug']:
138
+ run_id = wandb.util.generate_id()
139
+ if args["run_id"] is not None:
140
+ run_id = args["run_id"]
141
+ print(f"Run ID: {run_id}")
142
+
143
+ wandb.init(project="music-completion", resume=args["resume"], id=run_id)
144
+
145
+ epoch = 0
146
+ step = 0
147
+
148
+ checkpoint_path = os.path.join(args["checkpoint"], args["run_id"])
149
+
150
+ if not os.path.exists(checkpoint_path):
151
+ os.makedirs(checkpoint_path)
152
+ os.makedirs(os.path.join(checkpoint_path, "mels"))
153
+ os.makedirs(os.path.join(checkpoint_path, "wavs"))
154
+
155
+
156
+ if not args['debug'] and wandb.run.resumed:
157
+ if os.path.exists(checkpoint_path):
158
+ checkpoint = torch.load(checkpoint_path)
159
+ else:
160
+ checkpoint = torch.load(wandb.restore(checkpoint_path))
161
+ diff_model.load_state_dict(checkpoint['model_state_dict'])
162
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
163
+ epoch = checkpoint['epoch']
164
+ step = epoch * len(dataloader)
165
+
166
+ scaler = torch.cuda.amp.GradScaler()
167
+
168
+ diff_model.train()
169
+
170
+ while epoch < args['epoch_num']:
171
+ avg_loss = 0
172
+ avg_loss_step = 0
173
+ progress = tqdm(dataloader)
174
+ for i, (audio, masked_audio) in enumerate(progress):
175
+ optimizer.zero_grad()
176
+ # audio = torch.swapaxes(audio.to(device), 1, 2)
177
+ # masked_audio = torch.swapaxes(masked_audio.to(device), 1, 2)
178
+ # audio = audio.to(device)
179
+ # masked_audio = masked_audio.to(device)
180
+ audio = audio.to(device).unsqueeze(1)
181
+ masked_audio = masked_audio.to(device).unsqueeze(1)
182
+
183
+ with torch.cuda.amp.autocast():
184
+ loss = diff_model(masked_audio, audio)
185
+ avg_loss += loss.item()
186
+ avg_loss_step += 1
187
+ scaler.scale(loss).backward()
188
+ scaler.step(optimizer)
189
+ scaler.update()
190
+ progress.set_postfix(
191
+ # loss=loss.item(),
192
+ loss=avg_loss / avg_loss_step,
193
+ epoch=epoch + i / len(dataloader),
194
+ )
195
+
196
+ if step % 500 == 0:
197
+ # if step % 1 == 0:
198
+ # Turn noise into new audio sample with diffusion
199
+ # noise = torch.randn(1, 300, 64, device=device)
200
+ # noise = torch.randn(1, 64, 300, device=device)
201
+ noise = torch.randn(1, 1, 512, 64, device=device) # 2D example
202
+
203
+
204
+ with torch.cuda.amp.autocast():
205
+ sample = diff_model.sample(masked_audio[0], noise, num_steps=200)
206
+
207
+ orig_wav = vae_model.mel_spectrogram_to_waveform(audio[0].unsqueeze(0), save = False)[0][0].astype(np.float32) # 1, 1, len
208
+ gen_wav = vae_model.mel_spectrogram_to_waveform(sample, save = False)[0][0].astype(np.float32) # 1 1 len
209
+
210
+ orig_dir = os.path.join(checkpoint_path, 'wavs', f'target_{step}0.wav')
211
+ gen_dir = os.path.join(checkpoint_path, 'wavs', f'gen_{step}0.wav')
212
+ sf.write(orig_dir, orig_wav, samplerate = 16000)
213
+ sf.write(gen_dir, gen_wav, samplerate = 16000)
214
+
215
+ if not args['debug']:
216
+ wandb.log({
217
+ "step": step,
218
+ "epoch": epoch + i / len(dataloader),
219
+ "loss": avg_loss / avg_loss_step,
220
+ "target_audio": wandb.Audio(orig_dir, caption="Target audio", sample_rate=SAMPLE_RATE),
221
+ "generated_audio": wandb.Audio(gen_dir, caption="Generated audio", sample_rate=SAMPLE_RATE)
222
+ })
223
+
224
+ if not args['debug'] and step % 100 == 0:
225
+ wandb.log({
226
+ "step": step,
227
+ "epoch": epoch + i / len(dataloader),
228
+ "loss": avg_loss / avg_loss_step,
229
+ })
230
+ avg_loss = 0
231
+ avg_loss_step = 0
232
+
233
+ step += 1
234
+
235
+ epoch += 1
236
+
237
+ if not args['debug'] and epoch % 100 == 0:
238
+ torch.save({
239
+ 'epoch': epoch,
240
+ 'model_state_dict': diff_model.state_dict(),
241
+ 'optimizer_state_dict': optimizer.state_dict(),
242
+ }, os.path.join(checkpoint_path, f"epoch-{epoch}.pt"))
243
+ wandb.save(checkpoint_path, base_path=args["checkpoint"])
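The training loop above writes checkpoints containing `model_state_dict`, `optimizer_state_dict`, and `epoch` every 100 epochs. A hedged sketch of reloading one of those checkpoints for sampling, reusing `create_2Dmel_model` from this script (the checkpoint path and the random masked mel are placeholders):

```python
# Sketch (assumptions: placeholder checkpoint path; a real masked mel would come from MusicMelDataset).
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
diff_model = create_2Dmel_model().to(device)

checkpoint = torch.load("/path/to/checkpoints/run_id/epoch-100.pt", map_location=device)
diff_model.load_state_dict(checkpoint["model_state_dict"])
diff_model.eval()

masked_mel = torch.randn(1, 1, 512, 64, device=device)  # stand-in for a (batch, 1, frames, n_mels) masked mel
noise = torch.randn(1, 1, 512, 64, device=device)
with torch.no_grad():
    completed = diff_model.sample(masked_mel[0], noise, num_steps=200)  # mirrors the sampling call above
```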
audio_diffusion_attacks_forhf/scripts/train/train_tts.py ADDED
@@ -0,0 +1,430 @@
1
+ """
2
+ train_tts.py
3
+ Desc: An example script for training a Diffusion-based TTS model with a speaker encoder.
4
+ """
5
+
6
+ import sys
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torchaudio
11
+ import gc
12
+ import argparse
13
+ import os
14
+ from tqdm import tqdm
15
+ import wandb
16
+ from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler
17
+
18
+ sys.path.append(".")
19
+ from models.style_diffusion import StyleVDiffusion, StyleVSampler
20
+ # from models.utils import MonoTransform
21
+
22
+ # from util import calculate_codebook_bitrate, extract_melspectrogram, get_audio_file_bitrate, get_duration, load_neural_audio_codec
23
+ from audioldm.pipeline import build_model
24
+ import torch.multiprocessing as mp
25
+
26
+ # Needed for Instruction/Prompt Models
27
+ # from transformers import AutoTokenizer, T5EncoderModel
28
+
29
+ import logging
30
+
31
+ # Uncomment out below if wanting to supress
32
+ # import warnings
33
+ # warnings.filterwarnings("ignore")
34
+
35
+ # Set Sample Rate if like so if desired
36
+ SAMPLE_RATE = 16000
37
+ BATCH_SIZE = 16
38
+ NUM_SAMPLES = int(2.56 * SAMPLE_RATE)
39
+ # NUM_SAMPLES = 2 ** 15
40
+
41
+
42
+ def create_model():
43
+ return DiffusionModel(
44
+ net_t=UNetV0, # The model type used for diffusion (U-Net V0 in this case)
45
+ # dim=2, # for spectrogram we use 2D-CNN
46
+ in_channels=314, # U-Net: number of input (audio) channels
47
+ out_channels=157, # U-Net: number of output (audio) channels
48
+ channels=[256, 256, 512, 512, 768, 768, 1280, 1280], # U-Net: channels at each layer
49
+ factors=[2, 2, 2, 2, 2, 2, 2, 1], # U-Net: downsampling and upsampling factors at each layer
50
+ items=[2, 2, 2, 2, 2, 2, 2, 2], # U-Net: number of repeating items at each layer
51
+ attentions=[0, 0, 0, 0, 1, 1, 1, 1], # U-Net: attention enabled/disabled at each layer
52
+ attention_heads=8, # U-Net: number of attention heads per attention item
53
+ attention_features=64, # U-Net: number of attention features per attention item
54
+ diffusion_t=StyleVDiffusion, # The diffusion method used
55
+ sampler_t=StyleVSampler, # The diffusion sampler used
56
+ # embedding_features = 8,
57
+ # embedding_features = 2, # Embedding for when it's just res and weight
58
+ embedding_features = 7, # Embedding Features for when Severity is Dropped
59
+ cross_attentions=[0, 0, 0, 0, 1, 1, 1, 1]
60
+ )
61
+
62
+ def main():
63
+ pass
64
+ # args = parse_args()
65
+
66
+ # os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
67
+ # os.environ["CUDA_VISIBLE_DEVICES"] = args['cuda_ids']
68
+ # cuda_ids = [phy_id for phy_id in range(len(args['cuda_ids'].split(",")))]
69
+
70
+ # logging.basicConfig(
71
+ # format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
72
+ # datefmt="%Y-%m-%d %H:%M:%S",
73
+ # level=os.environ.get("LOGLEVEL", "INFO").upper(),
74
+ # stream=sys.stdout,
75
+ # filemode='w',
76
+ # )
77
+ # logger = logging.getLogger("")
78
+
79
+ # # mp.set_start_method('spawn')
80
+ # # mp.set_sharing_strategy('file_system')
81
+
82
+ # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
83
+
84
+ # # Load in text model
85
+ # # tokenizer = AutoTokenizer.from_pretrained("t5-small")
86
+ # # text_model = T5EncoderModel.from_pretrained("t5-small")
87
+ # # text_model.eval() # Don't want to train it!
88
+
89
+
90
+ # dataset = DSVAE_CondStyleWAVDataset(
91
+ # path="/data/robbizorg/pqvd_gen_w_conditioning/speech_non_speech_timesteps_VCTK.json",
92
+ # random_crop_size=NUM_SAMPLES,
93
+ # sample_rate=SAMPLE_RATE,
94
+ # transforms=AllTransform(
95
+ # mono=True,
96
+ # ),
97
+ # reconstructive = False, # Make this true to just train a reconstructive model
98
+ # identity_limit = 1 # Affects how often we learn identity mapping
99
+ # )
100
+
101
+ # print(f"Dataset length: {len(dataset)}")
102
+
103
+ # dataloader = torch.utils.data.DataLoader(
104
+ # dataset,
105
+ # batch_size=BATCH_SIZE,
106
+ # shuffle=True,
107
+ # num_workers=16,
108
+ # pin_memory=False,
109
+ # )
110
+
111
+ # vae_model = DSVAE(logger, **args).cuda()
112
+
113
+ # if not os.path.exists(args['model_path']):
114
+ # logger.warning("model not exist and we just create the new model......")
115
+ # else:
116
+ # logger.info("Model Exists")
117
+ # logger.info("Model Path is " + args['model_path'])
118
+ # vae_model.loadParameters(args['model_path'])
119
+ # vae_model = torch.nn.DataParallel(vae_model, device_ids = cuda_ids, output_device=cuda_ids[0])
120
+ # vae_model = vae_model.cuda()
121
+ # vae_model.eval()
122
+ # vae_model.module.eer = True
123
+
124
+ # diff_model = create_model().to(device)
125
+ # # audio_codec = build_model().to(device)
126
+ # # audio_codec.latent_t_size = 157
127
+ # # config, audio_codec, vocoder = load_neural_audio_codec('2021-05-19T22-16-54_vggsound_codebook', './logs', device)
128
+
129
+
130
+ # # optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
131
+ # optimizer = torch.optim.AdamW(params=list(diff_model.parameters()), lr=1e-4, betas= (0.95, 0.999), eps=1e-6, weight_decay=1e-3)
132
+
133
+ # print(f"Number of parameters: {sum(p.numel() for p in diff_model.parameters() if p.requires_grad)}")
134
+
135
+ # run_id = wandb.util.generate_id()
136
+ # if args["run_id"] is not None:
137
+ # run_id = args["run_id"]
138
+ # print(f"Run ID: {run_id}")
139
+
140
+ # wandb.init(project="audio-diffusion-no-condition", resume=args["resume"], id=run_id)
141
+
142
+ # epoch = 0
143
+ # step = 0
144
+
145
+ # checkpoint_path = os.path.join(args["checkpoint"], args["run_id"])
146
+
147
+ # if not os.path.exists(checkpoint_path):
148
+ # os.makedirs(checkpoint_path)
149
+ # os.makedirs(os.path.join(checkpoint_path, "mels"))
150
+ # os.makedirs(os.path.join(checkpoint_path, "wavs"))
151
+
152
+
153
+ # if wandb.run.resumed:
154
+ # if os.path.exists(checkpoint_path):
155
+ # checkpoint = torch.load(checkpoint_path)
156
+ # else:
157
+ # checkpoint = torch.load(wandb.restore(checkpoint_path))
158
+ # diff_model.load_state_dict(checkpoint['model_state_dict'])
159
+ # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
160
+ # epoch = checkpoint['epoch']
161
+ # step = epoch * len(dataloader)
162
+
163
+ # scaler = torch.cuda.amp.GradScaler()
164
+
165
+ # diff_model.train()
166
+ # while epoch < 101:
167
+ # avg_loss = 0
168
+ # avg_loss_step = 0
169
+ # progress = tqdm(dataloader)
170
+ # for i, (audio, target, embedding) in enumerate(progress):
171
+ # optimizer.zero_grad()
172
+ # audio = audio.to(device)
173
+ # target = target.to(device)
174
+ # embedding = embedding.to(device)
175
+
176
+ # with torch.no_grad():
177
+ # embedding = embedding.float() # Make it float like the others
178
+
179
+ # speaker_embed_source, content_embed_source = vae_model(audio)
180
+ # speaker_embed_source = speaker_embed_source.unsqueeze(1).expand(-1, 157, -1)
181
+
182
+ # audio_embed = torch.cat((speaker_embed_source, content_embed_source), axis = -1)
183
+
184
+ # # zeroes = torch.zeros(16, 3, 128, dtype=audio_embed.dtype, device = audio_embed.device)
185
+ # # audio_embed = torch.cat((audio_embed, zeroes), dim=1)
186
+
187
+ # speaker_embed, content_embed = vae_model(target)
188
+ # speaker_embed = speaker_embed.unsqueeze(1).expand(-1, 157, -1)
189
+
190
+ # # in order to simulate paired data, do (naive) voice conversion first
191
+ # target_embed = torch.cat((speaker_embed, content_embed_source), axis = -1)
192
+ # # target_embed = torch.cat((target_embed, zeroes), dim = 1)
193
+
194
+ # with torch.cuda.amp.autocast():
195
+ # loss = diff_model(audio_embed, target_embed, embedding=embedding)
196
+ # avg_loss += loss.item()
197
+ # avg_loss_step += 1
198
+ # scaler.scale(loss).backward()
199
+ # scaler.step(optimizer)
200
+ # scaler.update()
201
+ # progress.set_postfix(
202
+ # # loss=loss.item(),
203
+ # loss=avg_loss / avg_loss_step,
204
+ # epoch=epoch + i / len(dataloader),
205
+ # )
206
+
207
+ # if step % 500 == 0:
208
+ # # if step % 1 == 0:
209
+ # # Turn noise into new audio sample with diffusion
210
+ # noise = torch.randn(1, 157, 128, device=device)
211
+
212
+
213
+ # with torch.cuda.amp.autocast():
214
+ # sample = diff_model.sample(audio_embed[0], noise, embedding=embedding[0][None, :], num_steps=200)
215
+
216
+
217
+
218
+ # # Save the melspecs
219
+ # audio_sub = torch.swapaxes(audio[0].unsqueeze(0), 1, 2)
220
+ # # target_sub = torch.swapaxes(target[0].unsqueeze(0), 1, 2) # This is the original target audio, not what we want
221
+ # target_sub = vae_model.module.share_decoder(target_embed).loc
222
+ # gen_mel = vae_model.module.share_decoder(sample).loc
223
+
224
+ # vae_model.module.draw_mel(audio_sub, mode=f"source_{step}", file_path = os.path.join(checkpoint_path, "mels"))
225
+ # vae_model.module.draw_mel(target_sub, mode=f"target_{step}", file_path = os.path.join(checkpoint_path, "mels"))
226
+ # vae_model.module.draw_mel(gen_mel, mode=f"gen_{step}", file_path = os.path.join(checkpoint_path, "mels"))
227
+
228
+ # vae_model.module.mel2wav(audio_sub, mode=f"source_{step}", task="vc", file_path = os.path.join(checkpoint_path, "wavs"))
229
+ # vae_model.module.mel2wav(target_sub, mode=f"target_{step}", task="vc", file_path = os.path.join(checkpoint_path, "wavs"))
230
+ # vae_model.module.mel2wav(gen_mel, mode=f"gen_{step}", task="vc", file_path = os.path.join(checkpoint_path, "wavs"))
231
+
232
+ # # torchaudio.save(os.path.join(checkpoint_path, 'wavs', f'test_input_sound_{step}.wav'), torch.from_numpy(audio_codec.mel_spectrogram_to_waveform(audio_codec.decode_first_stage(audio[0].unsqueeze(0))))[0], SAMPLE_RATE)
233
+ # # torchaudio.save(os.path.join(checkpoint_path, 'wavs', f'test_generated_sound_{step}.wav'), torch.from_numpy(audio_codec.mel_spectrogram_to_waveform(audio_codec.decode_first_stage(sample[0].unsqueeze(0))))[0], SAMPLE_RATE)
234
+ # # torchaudio.save(os.path.join(checkpoint_path, 'wavs', f'test_target_sound_{step}.wav'), torch.from_numpy(audio_codec.mel_spectrogram_to_waveform(audio_codec.decode_first_stage(target[0].unsqueeze(0))))[0], SAMPLE_RATE)
235
+
236
+
237
+
238
+ # wandb.log({
239
+ # "step": step,
240
+ # "epoch": epoch + i / len(dataloader),
241
+ # "loss": avg_loss / avg_loss_step,
242
+ # "input_mel": wandb.Image(os.path.join(checkpoint_path, "mels", f"source_{step}_mel_0.png"), caption="Input Mel"),
243
+ # "target_mel": wandb.Image(os.path.join(checkpoint_path, "mels", f"target_{step}_mel_0.png"), caption="Target Mel"),
244
+ # "gen_mel": wandb.Image(os.path.join(checkpoint_path, "mels", f"gen_{step}_mel_0.png"), caption="Gen Mel"),
245
+ # "input_audio": wandb.Audio(os.path.join(checkpoint_path, 'wavs', f'source_{step}0.wav'), caption="Input audio", sample_rate=SAMPLE_RATE),
246
+ # "target_audio": wandb.Audio(os.path.join(checkpoint_path, 'wavs', f'target_{step}0.wav'), caption="Target audio", sample_rate=SAMPLE_RATE),
247
+ # "generated_audio": wandb.Audio(os.path.join(checkpoint_path, 'wavs', f'gen_{step}0.wav'), caption="Generated audio", sample_rate=SAMPLE_RATE)
248
+ # })
249
+
250
+ # if step % 100 == 0:
251
+ # wandb.log({
252
+ # "step": step,
253
+ # "epoch": epoch + i / len(dataloader),
254
+ # "loss": avg_loss / avg_loss_step,
255
+ # })
256
+ # avg_loss = 0
257
+ # avg_loss_step = 0
258
+
259
+ # step += 1
260
+
261
+ # epoch += 1
262
+
263
+ # if epoch % 100 == 0:
264
+ # torch.save({
265
+ # 'epoch': epoch,
266
+ # 'model_state_dict': diff_model.state_dict(),
267
+ # 'optimizer_state_dict': optimizer.state_dict(),
268
+ # }, os.path.join(checkpoint_path, f"epoch-{epoch}.pt"))
269
+ # wandb.save(checkpoint_path, base_path=args["checkpoint"])
270
+
271
+
272
+ # def parse_args():
273
+ # parser = argparse.ArgumentParser()
274
+ # parser.add_argument("--checkpoint", type=str, default='/data/robbizorg/pqvd_gen_w_dsvae/checkpoints/')
275
+ # parser.add_argument("--resume", action="store_true")
276
+ # parser.add_argument("--run_id", type=str, default='condition_ldm')
277
+
278
+ # ## Params from DSVAE
279
+ # parser.add_argument('--dataset', type=str, default="VCTK", help='VCTK, LibriTTS')
280
+ # parser.add_argument('--encoder', type=str, default='dsvae', help='dsvae. tdnn')
281
+ # parser.add_argument('--vocoder', type=str, default='hifigan', help='wavenet, hifigan')
282
+ # parser.add_argument('--save_tsne', dest='save_tsne', action='store_true', help='save_tsne')
283
+ # parser.add_argument('--mel_tsne', dest='mel_tsne', action='store_true', help='mel_tsne')
284
+ # parser.add_argument('--feature', type=str, default='mel_spec', help='stft, mel_spec, mfcc')
285
+ # parser.add_argument('--model_path', type=str, default='/home/robbizorg/research/dsvae/save_models/dsvae/best699.pth')
286
+ # # parser.add_argument('--model_path', type=str, default='/data/andreaguz/save_models/dsvae_003_03/best699.pth') # Using the fine-tuned dsvae
287
+ # # parser.add_argument('--model_path', type=str, default='/data/andreaguz/save_models/dsvae_0001_0005/best.pth') # Using the fine-tuned dsvae
288
+ # parser.add_argument('--save_path', type=str, default='save_models/dsvae')
289
+ # parser.add_argument('--cuda_ids', type=str, default='0')
290
+ # parser.add_argument('--tsne_mode', type=str, default='test')
291
+ # parser.add_argument("--optimizer", type=str, default='adam', help='sgd, adam')
292
+ # parser.add_argument("--path_vc_1", type=str, default='', help='')
293
+ # parser.add_argument("--path_vc_2", type=str, default='', help='')
294
+ # parser.add_argument('--max_frames', type=int, default=100, help='1frame~10ms')
295
+ # parser.add_argument("--hop_size", type=int, default=256, help='hop_size')
296
+ # parser.add_argument("--win_length", type=int, default=1024, help='win_length')
297
+ # parser.add_argument("--spk_dim", type=int, default=64, help='spk_embed')
298
+ # parser.add_argument("--ecapa_spk_dim", type=int, default=128, help='ecapa spk_embed')
299
+ # parser.add_argument("--content_dim", type=int, default=64, help="content_embed")
300
+ # parser.add_argument("--conformer_hidden_dim", type=int, default=256, help="content_embed")
301
+ # parser.add_argument('--n_epochs', type=int, default=700, help='n_epochs')
302
+ # parser.add_argument('--eval_epoch', type=int, default=5, help='eval_epoch')
303
+ # parser.add_argument('--step_size', type=int, default=5, help='step_size')
304
+ # parser.add_argument('--num_workers', type=int, default=16, help='num_workers')
305
+ # parser.add_argument('--lr_decay_rate',type=float, default=0.95, help='lr_decay_rate')
306
+ # parser.add_argument('--lr',type=float, default=3e-4, help='lr_rate')
307
+ # # parser.add_argument('--klf_factor', type=float, default=3e-3, help='klf_factor')
308
+ # # parser.add_argument('--klt_factor', type=float, default=5, help='klt_factor')
309
+ # parser.add_argument('--klf_factor', type=float, default=3e-4, help='klf_factor') # Changed for the Fine-tuned Version
310
+ # parser.add_argument('--klt_factor', type=float, default=3e-3, help='klt_factor')
311
+ # parser.add_argument('--rec_factor', type=float, default=1, help='rec_factor')
312
+ # parser.add_argument('--vq_factor', type=float, default=1000, help='vq_factor')
313
+ # parser.add_argument('--zf_vq_factor', type=float, default=1000, help='vq_factor')
314
+ # parser.add_argument('--klf_std', type=float, default=0.5, help='klf_std')
315
+ # parser.add_argument('--rec_std', type=float, default=0.04, help='rec_std')
316
+ # parser.add_argument('--clip', type=float, default=1, help='rec_std')
317
+ # parser.add_argument('--phoneme_factor', type=float, default=1, help='phoneme_factor')
318
+ # parser.add_argument('--r_vq_factor', type=float, default=10, help='r_vq_factor')
319
+ # parser.add_argument('--compute_speaker_eer', dest='compute_speaker_eer', action='store_true', help='ASV EER')
320
+ # parser.add_argument('--eval_phoneme', dest='eval_phoneme', action='store_true', help='ASV EER')
321
+ # parser.add_argument('--num_eval', type=int, default=20, help='num of segments for eval')
322
+ # parser.add_argument('--batch_size', type=int, default=256, help='batch_size')
323
+ # parser.add_argument('--num_phonemes', type=int, default=100, help='num_phonemes')
324
+ # parser.add_argument('--with_phoneme', dest='with_phoneme', action='store_true', help='')
325
+ # parser.add_argument("--conversion", action='store_true', help='for conversion text')
326
+ # parser.add_argument("--conversion2", action='store_true', help='for conversion text')
327
+ # parser.add_argument("--conversion3", action='store_true', help='for conversion text')
328
+ # parser.add_argument("--mel2npy", action='store_true', help='mel2npy')
329
+ # parser.add_argument("--unconditional", action='store_true', help='unconditional')
330
+ # parser.add_argument('--zt_norm_mean', action='store_true', help='instancenorm1d on zt prior and post')
331
+ # parser.add_argument('--zf_norm_mean', action='store_true', help='instancenorm1d on zf prior and post')
332
+ # parser.add_argument('--freeze_encoder', action='store_true', help='if or not to freeze encoder')
333
+ # parser.add_argument('--freeze_decoder', action='store_true', help='if or not to freeze decoder')
334
+ # parser.add_argument("--sample_rate",type=int, default=16000, help='16000 or 48000')
335
+ # parser.add_argument('--noise_path', type=str, default='datasets/noise_list.scp', help='nosie invariant')
336
+ # parser.add_argument('--wav_aug_train', action='store_true', help='with data augmentation')
337
+ # parser.add_argument('--spec_aug_train', action='store_true', help='with data augmentation')
338
+ # parser.add_argument('--noise_train', action='store_true', help='noise')
339
+ # parser.add_argument('--triphn', action='store_true', help='with triphn')
340
+ # parser.add_argument('--train_hifigan', action='store_true', help='train hifigan')
341
+ # parser.add_argument("--prior_alignment", action='store_true', help='')
342
+ # parser.add_argument("--zf_vq", action='store_true', help='')
343
+ # parser.add_argument("--vq_prior_independent", action='store_true', help='')
344
+ # parser.add_argument("--vq_prior_regressive", action='store_true', help='')
345
+ # parser.add_argument("--vq_prior_pseudo", action='store_true', help='')
346
+ # parser.add_argument("--vq_size_zt",type=int, default=200, help='')
347
+ # parser.add_argument("--vq_size_zf",type=int, default=200, help='')
348
+ # parser.add_argument("--ignore_index",type=int, default=0, help='')
349
+ # parser.add_argument("--hidden_dim",type=int, default=256, help='')
350
+
351
+ # parser.add_argument("--share_encoder", type=str, default='cnn', help='')
352
+ # parser.add_argument("--share_decoder", type=str, default='cnn_lstm', help='cnn_lstm, cnn_transformer')
353
+ # parser.add_argument("--zt_encoder", type=str, default='lstm', help='lstm, conformer_encoder, transformer_encoder')
354
+ # parser.add_argument("--zf_encoder", type=str, default='lstm', help='lstm, transformer_encoder, ecapa_tdnn')
355
+ # parser.add_argument("--zt_prior_model", type=str, default='lstm', help='lstm, vqvae, transformer')
356
+ # parser.add_argument("--prior_signal", type=str, default='None', help='alignment_triphn, alignment_mono, melspec_pseudo, wavlm_pseudo, vq_embeds, vq_pseudo')
357
+ # parser.add_argument("--multi_scale_add", action='store_true', help='')
358
+ # parser.add_argument("--multi_scale_cat", action='store_true', help='')
359
+ # parser.add_argument("--num_scales",type=int, default=1, help='')
360
+
361
+ # parser.add_argument("--kmeans_num_clusters",type=int, default=50, help='')
362
+ # parser.add_argument("--wavlm_dim", type=int, default=768, help='')
363
+
364
+ # parser.add_argument("--ema_zt", action='store_true', help='')
365
+ # parser.add_argument("--ema_zf", action='store_true', help='')
366
+
367
+ # parser.add_argument("--r_vqvae", action='store_true', help='')
368
+ # parser.add_argument("--masked_mel", action='store_true', help='')
369
+
370
+ # parser.add_argument("--rec_noise", action='store_true', help='')
371
+ # parser.add_argument("--rec_mask", action='store_true', help='')
372
+
373
+ # parser.add_argument("--mel_classification", action='store_true', help='')
374
+ # parser.add_argument("--test_script", action='store_true', help='')
375
+
376
+ # parser.add_argument("--no_klt", action='store_true', help='')
377
+
378
+ # parser.add_argument("--zt_prior_ce_r_vq", action='store_true', help='')
379
+ # parser.add_argument('--zt_prior_ce_r_vq_factor', type=float, default=1000, help='factor')
380
+
381
+ # parser.add_argument("--zt_post_ce_r_vq", action='store_true', help='')
382
+
383
+
384
+ # parser.add_argument("--zt_prior_ce_kmeans", action='store_true', help='')
385
+ # parser.add_argument('--zt_prior_ce_kmeans_factor', type=float, default=1000, help='factor')
386
+
387
+ # parser.add_argument("--zt_post_ce_kmeans", action='store_true', help='')
388
+ # parser.add_argument('--zt_post_ce_kmeans_factor', type=float, default=10, help='factor')
389
+
390
+
391
+ # parser.add_argument("--zt_prior_ce_alignment", action='store_true', help='')
392
+ # parser.add_argument('--zt_prior_ce_alignment_factor', type=float, default=1000, help='factor')
393
+
394
+ # parser.add_argument("--prior_type", type=str, default='None', help='normal, condition, lm')
395
+ # parser.add_argument("--prior_embedding", type=str, default='one-hot', help='one-hot, embedding')
396
+ # parser.add_argument("--prior_mask", action='store_true', help='')
397
+
398
+ # parser.add_argument("--wavlm", action='store_true', help='')
399
+ # parser.add_argument("--wavlm_type", type=str, default='base', help='')
400
+
401
+
402
+ # parser.add_argument("--tts_phn_wav_path", type=str, default='', help='')
403
+
404
+ # parser.add_argument("--sr", type=str, default="16000", help='')
405
+
406
+ # parser.add_argument("--text", type=str, default="your tts", help='')
407
+ # parser.add_argument("--tts_align", action='store_true', help='')
408
+ # parser.add_argument("--tts_wavlm", action='store_true', help='')
409
+ # parser.add_argument("--tts", action='store_true', help='')
410
+ # parser.add_argument("--tts_config", type=str, default="conf/LibriTTS/preprocess.yaml", help='')
411
+ # parser.add_argument("--tts_target_wav_path", type=str, default='', help='')
412
+ # parser.add_argument("--speed", type=float, default='1.0', help='')
413
+
414
+ # parser.add_argument("--train_mapping", action='store_true', help='')
415
+ # parser.add_argument("--mapping_encoder", type=str, default='lstm', help='')
416
+ # parser.add_argument("--mapping_model_path", type=str, default='lstm', help='')
417
+ # parser.add_argument("--mask_mapping", action='store_true', help='')
418
+ # parser.add_argument("--mask_mapping_factor", type=float, default=1, help='')
419
+ # parser.add_argument("--l1_mapping_factor", type=float, default=1, help='')
420
+ # parser.add_argument("--mapping_ratio", type=float, default=1.0, help='')
421
+
422
+ # parser.add_argument("--condition2", action='store_true', help='')
423
+
424
+ # args = parser.parse_args()
425
+ # return update_args(**vars(args))
426
+
427
+
428
+ # if __name__ == "__main__":
429
+ # # torch.cuda.empty_cache()
430
+ # main()
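
Since main() above is a stub and the full training loop is commented out, the following is a minimal, hedged sketch of how create_model() could be exercised for one training step and one sampling call. It only mirrors the call patterns visible in the commented-out loop; the tensor shapes (157 latent frames by 128 features, a 7-dimensional conditioning embedding) and the assumption that create_model() is in scope are illustrative, not a tested configuration.

# Hedged sketch, not part of the script: one training step and one sampling
# call for the DiffusionModel returned by create_model(). Shapes are assumed
# from the commented-out loop above.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
diff_model = create_model().to(device)          # assumes create_model() from this script is in scope
optimizer = torch.optim.AdamW(diff_model.parameters(), lr=1e-4,
                              betas=(0.95, 0.999), eps=1e-6, weight_decay=1e-3)

batch = 2
audio_embed = torch.randn(batch, 157, 128, device=device)    # source speaker+content latents (assumed shape)
target_embed = torch.randn(batch, 157, 128, device=device)   # target latents (assumed shape)
embedding = torch.randn(batch, 7, device=device)              # 7 conditioning features (assumed shape)

loss = diff_model(audio_embed, target_embed, embedding=embedding)  # StyleVDiffusion training loss
loss.backward()
optimizer.step()

# Sampling: denoise from Gaussian noise, conditioned on one source latent.
noise = torch.randn(1, 157, 128, device=device)
sample = diff_model.sample(audio_embed[0], noise, embedding=embedding[0][None, :], num_steps=50)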
audio_diffusion_attacks_forhf/src/.DS_Store ADDED
Binary file (6.15 kB). View file
 
audio_diffusion_attacks_forhf/src/__pycache__/losses.cpython-310.pyc ADDED
Binary file (10.2 kB). View file
 
audio_diffusion_attacks_forhf/src/__pycache__/music_gen.cpython-310.pyc ADDED
Binary file (3.49 kB). View file
 
audio_diffusion_attacks_forhf/src/__pycache__/test_encoder_attack.cpython-310.pyc ADDED
Binary file (5.47 kB). View file
 
audio_diffusion_attacks_forhf/src/balancer.py ADDED
@@ -0,0 +1,137 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import typing as tp
8
+
9
+ import flashy
10
+ import torch
11
+ from torch import autograd
12
+
13
+
14
+ class Balancer:
15
+ """Loss balancer.
16
+
17
+ The loss balancer combines losses together to compute gradients for the backward.
18
+ Given `y = f(...)`, and a number of losses `l1(y, ...)`, `l2(y, ...)`, with `...`
19
+ not having any dependence on `f`, the balancer can efficiently normalize the partial gradients
20
+ `d l1 / d y`, `d l2 / dy` before summing them in order to achieve a desired ratio between
21
+ the losses. For instance if `weights = {'l1': 2, 'l2': 1}`, 66% of the gradient
22
+ going into `f(...)` will come from `l1` on average, and 33% from `l2`. This allows for an easy
23
+ interpretation of the weights even if the intrinsic scale of `l1`, `l2` ... is unknown.
24
+
25
+ Noting `g1 = d l1 / dy`, etc., the balanced gradient `G` will be
26
+ (with `avg` an exponential moving average over the updates),
27
+
28
+ G = sum_i total_norm * g_i / avg(||g_i||) * w_i / sum(w_i)
29
+
30
+ If `balance_grads` is False, this is deactivated, and instead the gradient will just be the
31
+ standard sum of the partial gradients with the given weights.
32
+
33
+ A call to the backward method of the balancer will compute the partial gradients,
34
+ combining all the losses and potentially rescaling the gradients,
35
+ which can help stabilize the training and reason about multiple losses with varying scales.
36
+ The obtained gradient with respect to `y` is then back-propagated to `f(...)`.
37
+
38
+ Expected usage:
39
+
40
+ weights = {'loss_a': 1, 'loss_b': 4}
41
+ balancer = Balancer(weights, ...)
42
+ losses: dict = {}
43
+ losses['loss_a'] = compute_loss_a(x, y)
44
+ losses['loss_b'] = compute_loss_b(x, y)
45
+ if model.training():
46
+ effective_loss = balancer.backward(losses, x)
47
+
48
+ Args:
49
+ weights (dict[str, float]): Weight coefficient for each loss. The balancer expects the losses keys
50
+ from the backward method to match the weights keys to assign a weight to each of the provided losses.
51
+ balance_grads (bool): Whether to rescale gradients so that weights reflect the fraction of the
52
+ overall gradient, rather than a constant multiplier.
53
+ total_norm (float): Reference norm when rescaling gradients, ignored otherwise.
54
+ ema_decay (float): EMA decay for averaging the norms.
55
+ per_batch_item (bool): Whether to compute the averaged norm per batch item or not. This only applies
56
+ when rescaling the gradients.
57
+ epsilon (float): Epsilon value for numerical stability.
58
+ monitor (bool): If True, stores in `self.metrics` the relative ratio between the norm of the gradients
59
+ coming from each loss, when calling `backward()`.
60
+ """
61
+ def __init__(self, weights: tp.Dict[str, float], balance_grads: bool = True, total_norm: float = 1.,
62
+ ema_decay: float = 0.999, per_batch_item: bool = True, epsilon: float = 1e-12,
63
+ monitor: bool = False):
64
+ self.weights = weights
65
+ self.per_batch_item = per_batch_item
66
+ self.total_norm = total_norm or 1.
67
+ self.averager = flashy.averager(ema_decay or 1.)
68
+ self.epsilon = epsilon
69
+ self.monitor = monitor
70
+ self.balance_grads = balance_grads
71
+ self._metrics: tp.Dict[str, tp.Any] = {}
72
+
73
+ @property
74
+ def metrics(self):
75
+ return self._metrics
76
+
77
+ def backward(self, losses: tp.Dict[str, torch.Tensor], input: torch.Tensor) -> torch.Tensor:
78
+ """Compute the backward and return the effective train loss, e.g. the loss obtained from
79
+ computing the effective weights. If `balance_grads` is True, the effective weights
80
+ are the one that needs to be applied to each gradient to respect the desired relative
81
+ scale of gradients coming from each loss.
82
+
83
+ Args:
84
+ losses (Dict[str, torch.Tensor]): dictionary with the same keys as `self.weights`.
85
+ input (torch.Tensor): the input of the losses, typically the output of the model.
86
+ This should be the single point of dependence between the losses
87
+ and the model being trained.
88
+ """
89
+ norms = {}
90
+ grads = {}
91
+ for name, loss in losses.items():
92
+ # Compute the partial derivative of the loss with respect to the input.
93
+ grad, = autograd.grad(loss, [input], retain_graph=True)
94
+ if self.per_batch_item:
95
+ # We do not average the gradient over the batch dimension.
96
+ dims = tuple(range(1, grad.dim()))
97
+ norm = grad.norm(dim=dims, p=2).mean()
98
+ else:
99
+ norm = grad.norm(p=2)
100
+ norms[name] = norm
101
+ grads[name] = grad
102
+
103
+ count = 1
104
+ if self.per_batch_item:
105
+ count = len(grad)
106
+ # Average norms across workers. Theoretically we should average the
107
+ # squared norm, then take the sqrt, but it worked fine like that.
108
+ avg_norms = flashy.distrib.average_metrics(self.averager(norms), count)
109
+ # We approximate the total norm of the gradient as the sums of the norms.
110
+ # Obviously this can be very incorrect if all gradients are aligned, but it works fine.
111
+ total = sum(avg_norms.values())
112
+
113
+ self._metrics = {}
114
+ if self.monitor:
115
+ # Store the ratio of the total gradient represented by each loss.
116
+ for k, v in avg_norms.items():
117
+ self._metrics[f'ratio_{k}'] = v / total
118
+
119
+ total_weights = sum([self.weights[k] for k in avg_norms])
120
+ assert total_weights > 0.
121
+ desired_ratios = {k: w / total_weights for k, w in self.weights.items()}
122
+
123
+ out_grad = torch.zeros_like(input)
124
+ effective_loss = torch.tensor(0., device=input.device, dtype=input.dtype)
125
+ for name, avg_norm in avg_norms.items():
126
+ if self.balance_grads:
127
+ # g_balanced = g / avg(||g||) * total_norm * desired_ratio
128
+ scale = desired_ratios[name] * self.total_norm / (self.epsilon + avg_norm)
129
+ else:
130
+ # We just do regular weighted sum of the gradients.
131
+ scale = self.weights[name]
132
+ out_grad.add_(grads[name], alpha=scale)
133
+ effective_loss += scale * losses[name].detach()
134
+ # Send the computed partial derivative with respect to the output of the model to the model.
135
+ input.backward(out_grad)
136
+ return effective_loss
137
+
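
For reference, a minimal, hedged usage sketch of the Balancer above. The toy model, loss names, and weights are illustrative assumptions; only the constructor and backward(losses, input) signature come from the class itself, and the import path mirrors this repository layout (assuming the repo root is on sys.path).

# Hedged sketch: balance an L1 and an MSE loss so that, on average, 80% of the
# gradient reaching the model comes from 'l1' and 20% from 'mse'.
import torch
from torch import nn
from torch.nn import functional as F

from audio_diffusion_attacks_forhf.src.balancer import Balancer  # assumed import path

model = nn.Linear(16, 16)                       # toy stand-in for f(...)
x = torch.randn(8, 16)
target = torch.randn(8, 16)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

balancer = Balancer(weights={"l1": 4.0, "mse": 1.0}, monitor=True)

optimizer.zero_grad()
y = model(x)                                    # single point of dependence between losses and model
losses = {"l1": F.l1_loss(y, target), "mse": F.mse_loss(y, target)}
effective_loss = balancer.backward(losses, y)   # rescales per-loss grads w.r.t. y, then calls y.backward(...)
optimizer.step()
print(float(effective_loss), balancer.metrics)  # metrics hold 'ratio_l1' / 'ratio_mse' when monitor=True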
audio_diffusion_attacks_forhf/src/losses.py ADDED
@@ -0,0 +1,329 @@
1
+ #https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/loss.py
2
+
3
+ import typing
4
+ from typing import List
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from audiotools import AudioSignal
9
+ from audiotools import STFTParams
10
+ from torch import nn
11
+
12
+
13
+ class L1Loss(nn.L1Loss):
14
+ """L1 Loss between AudioSignals. Defaults
15
+ to comparing ``audio_data``, but any
16
+ attribute of an AudioSignal can be used.
17
+
18
+ Parameters
19
+ ----------
20
+ attribute : str, optional
21
+ Attribute of signal to compare, defaults to ``audio_data``.
22
+ weight : float, optional
23
+ Weight of this loss, defaults to 1.0.
24
+
25
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
26
+ """
27
+
28
+ def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
29
+ self.attribute = attribute
30
+ self.weight = weight
31
+ super().__init__(**kwargs)
32
+
33
+ def forward(self, x: AudioSignal, y: AudioSignal):
34
+ """
35
+ Parameters
36
+ ----------
37
+ x : AudioSignal
38
+ Estimate AudioSignal
39
+ y : AudioSignal
40
+ Reference AudioSignal
41
+
42
+ Returns
43
+ -------
44
+ torch.Tensor
45
+ L1 loss between AudioSignal attributes.
46
+ """
47
+ if isinstance(x, AudioSignal):
48
+ x = getattr(x, self.attribute)
49
+ y = getattr(y, self.attribute)
50
+ return super().forward(x, y)
51
+
52
+
53
+ class SISDRLoss(nn.Module):
54
+ """
55
+ Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
56
+ of estimated and reference audio signals or aligned features.
57
+
58
+ Parameters
59
+ ----------
60
+ scaling : int, optional
61
+ Whether to use scale-invariant (True) or
62
+ signal-to-noise ratio (False), by default True
63
+ reduction : str, optional
64
+ How to reduce across the batch (either 'mean',
65
+ 'sum', or 'none'), by default 'mean'
66
+ zero_mean : int, optional
67
+ Zero mean the references and estimates before
68
+ computing the loss, by default True
69
+ clip_min : int, optional
70
+ The minimum possible loss value. Helps network
71
+ to not focus on making already good examples better, by default None
72
+ weight : float, optional
73
+ Weight of this loss, defaults to 1.0.
74
+
75
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
76
+ """
77
+
78
+ def __init__(
79
+ self,
80
+ scaling: int = True,
81
+ reduction: str = "mean",
82
+ zero_mean: int = True,
83
+ clip_min: int = None,
84
+ weight: float = 1.0,
85
+ ):
86
+ self.scaling = scaling
87
+ self.reduction = reduction
88
+ self.zero_mean = zero_mean
89
+ self.clip_min = clip_min
90
+ self.weight = weight
91
+ super().__init__()
92
+
93
+ def forward(self, x: AudioSignal, y: AudioSignal):
94
+ eps = 1e-8
95
+ # nb, nc, nt
96
+ if isinstance(x, AudioSignal):
97
+ references = x.audio_data
98
+ estimates = y.audio_data
99
+ else:
100
+ references = x
101
+ estimates = y
102
+
103
+ nb = references.shape[0]
104
+ references = references.reshape(nb, 1, -1).permute(0, 2, 1)
105
+ estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)
106
+
107
+ # samples now on axis 1
108
+ if self.zero_mean:
109
+ mean_reference = references.mean(dim=1, keepdim=True)
110
+ mean_estimate = estimates.mean(dim=1, keepdim=True)
111
+ else:
112
+ mean_reference = 0
113
+ mean_estimate = 0
114
+
115
+ _references = references - mean_reference
116
+ _estimates = estimates - mean_estimate
117
+
118
+ references_projection = (_references**2).sum(dim=-2) + eps
119
+ references_on_estimates = (_estimates * _references).sum(dim=-2) + eps
120
+
121
+ scale = (
122
+ (references_on_estimates / references_projection).unsqueeze(1)
123
+ if self.scaling
124
+ else 1
125
+ )
126
+
127
+ e_true = scale * _references
128
+ e_res = _estimates - e_true
129
+
130
+ signal = (e_true**2).sum(dim=1)
131
+ noise = (e_res**2).sum(dim=1)
132
+ sdr = -10 * torch.log10(signal / noise + eps)
133
+
134
+ if self.clip_min is not None:
135
+ sdr = torch.clamp(sdr, min=self.clip_min)
136
+
137
+ if self.reduction == "mean":
138
+ sdr = sdr.mean()
139
+ elif self.reduction == "sum":
140
+ sdr = sdr.sum()
141
+ return sdr
142
+
143
+
144
+ class MultiScaleSTFTLoss(nn.Module):
145
+ """Computes the multi-scale STFT loss from [1].
146
+
147
+ Parameters
148
+ ----------
149
+ window_lengths : List[int], optional
150
+ Length of each window of each STFT, by default [2048, 512]
151
+ loss_fn : typing.Callable, optional
152
+ How to compare each loss, by default nn.L1Loss()
153
+ clamp_eps : float, optional
154
+ Clamp on the log magnitude, below, by default 1e-5
155
+ mag_weight : float, optional
156
+ Weight of raw magnitude portion of loss, by default 1.0
157
+ log_weight : float, optional
158
+ Weight of log magnitude portion of loss, by default 1.0
159
+ pow : float, optional
160
+ Power to raise magnitude to before taking log, by default 2.0
161
+ weight : float, optional
162
+ Weight of this loss, by default 1.0
163
+ match_stride : bool, optional
164
+ Whether to match the stride of convolutional layers, by default False
165
+
166
+ References
167
+ ----------
168
+
169
+ 1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
170
+ "DDSP: Differentiable Digital Signal Processing."
171
+ International Conference on Learning Representations. 2019.
172
+
173
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
174
+ """
175
+
176
+ def __init__(
177
+ self,
178
+ window_lengths: List[int] = [2048, 512],
179
+ loss_fn: typing.Callable = nn.L1Loss(),
180
+ clamp_eps: float = 1e-5,
181
+ mag_weight: float = 1.0,
182
+ log_weight: float = 1.0,
183
+ pow: float = 2.0,
184
+ weight: float = 1.0,
185
+ match_stride: bool = False,
186
+ window_type: str = None,
187
+ ):
188
+ super().__init__()
189
+ self.stft_params = [
190
+ STFTParams(
191
+ window_length=w,
192
+ hop_length=w // 4,
193
+ match_stride=match_stride,
194
+ window_type=window_type,
195
+ )
196
+ for w in window_lengths
197
+ ]
198
+ self.loss_fn = loss_fn
199
+ self.log_weight = log_weight
200
+ self.mag_weight = mag_weight
201
+ self.clamp_eps = clamp_eps
202
+ self.weight = weight
203
+ self.pow = pow
204
+
205
+ def forward(self, x: AudioSignal, y: AudioSignal):
206
+ """Computes multi-scale STFT between an estimate and a reference
207
+ signal.
208
+
209
+ Parameters
210
+ ----------
211
+ x : AudioSignal
212
+ Estimate signal
213
+ y : AudioSignal
214
+ Reference signal
215
+
216
+ Returns
217
+ -------
218
+ torch.Tensor
219
+ Multi-scale STFT loss.
220
+ """
221
+ loss = 0.0
222
+ for s in self.stft_params:
223
+ x.stft(s.window_length, s.hop_length, s.window_type)
224
+ y.stft(s.window_length, s.hop_length, s.window_type)
225
+ loss += self.log_weight * self.loss_fn(
226
+ x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
227
+ y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
228
+ )
229
+ loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
230
+ return loss
231
+
232
+
233
+ class MelSpectrogramLoss(nn.Module):
234
+ """Compute distance between mel spectrograms. Can be used
235
+ in a multi-scale way.
236
+
237
+ Parameters
238
+ ----------
239
+ n_mels : List[int]
240
+ Number of mels per STFT, by default [150, 80],
241
+ window_lengths : List[int], optional
242
+ Length of each window of each STFT, by default [2048, 512]
243
+ loss_fn : typing.Callable, optional
244
+ How to compare each loss, by default nn.L1Loss()
245
+ clamp_eps : float, optional
246
+ Clamp on the log magnitude, below, by default 1e-5
247
+ mag_weight : float, optional
248
+ Weight of raw magnitude portion of loss, by default 1.0
249
+ log_weight : float, optional
250
+ Weight of log magnitude portion of loss, by default 1.0
251
+ pow : float, optional
252
+ Power to raise magnitude to before taking log, by default 2.0
253
+ weight : float, optional
254
+ Weight of this loss, by default 1.0
255
+ match_stride : bool, optional
256
+ Whether to match the stride of convolutional layers, by default False
257
+
258
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
259
+ """
260
+
261
+ def __init__(
262
+ self,
263
+ n_mels: List[int] = [150, 80],
264
+ window_lengths: List[int] = [2048, 512],
265
+ loss_fn: typing.Callable = nn.MSELoss(),
266
+ clamp_eps: float = 1e-5,
267
+ mag_weight: float = 1.0,
268
+ log_weight: float = 1.0,
269
+ pow: float = 2.0,
270
+ weight: float = 1.0,
271
+ match_stride: bool = False,
272
+ mel_fmin: List[float] = [0.0, 0.0],
273
+ mel_fmax: List[float] = [None, None],
274
+ window_type: str = None,
275
+ ):
276
+ super().__init__()
277
+ self.stft_params = [
278
+ STFTParams(
279
+ window_length=w,
280
+ hop_length=w // 4,
281
+ match_stride=match_stride,
282
+ window_type=window_type,
283
+ )
284
+ for w in window_lengths
285
+ ]
286
+ self.n_mels = n_mels
287
+ self.loss_fn = loss_fn
288
+ self.clamp_eps = clamp_eps
289
+ self.log_weight = log_weight
290
+ self.mag_weight = mag_weight
291
+ self.weight = weight
292
+ self.mel_fmin = mel_fmin
293
+ self.mel_fmax = mel_fmax
294
+ self.pow = pow
295
+
296
+ def forward(self, x: AudioSignal, y: AudioSignal):
297
+ """Computes mel loss between an estimate and a reference
298
+ signal.
299
+
300
+ Parameters
301
+ ----------
302
+ x : AudioSignal
303
+ Estimate signal
304
+ y : AudioSignal
305
+ Reference signal
306
+
307
+ Returns
308
+ -------
309
+ torch.Tensor
310
+ Mel loss.
311
+ """
312
+ loss = 0.0
313
+ for n_mels, fmin, fmax, s in zip(
314
+ self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
315
+ ):
316
+ kwargs = {
317
+ "window_length": s.window_length,
318
+ "hop_length": s.hop_length,
319
+ "window_type": s.window_type,
320
+ }
321
+ x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
322
+ y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
323
+
324
+ loss += self.log_weight * self.loss_fn(
325
+ x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
326
+ y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
327
+ )
328
+ loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
329
+ return loss
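
A short, hedged sketch of calling these losses on AudioSignal objects. The random waveforms and sample rate are placeholders; the constructors and forward(x, y) signatures come from the classes above, and the import path mirrors this repository layout.

# Hedged sketch: compare an estimated waveform against a reference with the
# losses defined in this file. Waveforms here are random placeholders.
import torch
from audiotools import AudioSignal

from audio_diffusion_attacks_forhf.src.losses import (  # assumed import path
    L1Loss, MultiScaleSTFTLoss, MelSpectrogramLoss,
)

sr = 16000
estimate = AudioSignal(torch.randn(1, 1, sr), sr)    # (batch, channels, samples)
reference = AudioSignal(torch.randn(1, 1, sr), sr)

stft_loss = MultiScaleSTFTLoss(window_lengths=[2048, 512])
mel_loss = MelSpectrogramLoss(n_mels=[150, 80], window_lengths=[2048, 512])
l1_loss = L1Loss()

total = (stft_loss(estimate, reference)
         + mel_loss(estimate, reference)
         + l1_loss(estimate, reference))
print(float(total))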
audio_diffusion_attacks_forhf/src/music_gen.py ADDED
@@ -0,0 +1,100 @@
1
+ from transformers import AutoProcessor, MusicgenForConditionalGeneration
2
+ #Andy removed: from datasets import load_dataset
3
+ import torchaudio
4
+ import torch
5
+ #Andy edited: import losses
6
+ from audio_diffusion_attacks_forhf.src import losses
7
+ from audiotools import AudioSignal
8
+
9
+ class MusicGenEval:
10
+
11
+ def __init__(self, input_sample_rate, audio_steps):
12
+ model_name="facebook/musicgen-stereo-small"
13
+ self.processor = AutoProcessor.from_pretrained(model_name)
14
+ self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
15
+ self.model=self.model.to(device='cuda')
16
+ self.input_sample_rate=input_sample_rate
17
+ self.audio_steps=audio_steps
18
+ self.mel_loss = losses.MelSpectrogramLoss(n_mels=[5, 10, 20, 40, 80, 160, 320],
19
+ window_lengths=[32, 64, 128, 256, 512, 1024, 2048],
20
+ mel_fmin=[0, 0, 0, 0, 0, 0, 0],
21
+ pow=1.0,
22
+ clamp_eps=1.0e-5,
23
+ mag_weight=0.0)
24
+
25
+ def eval(self, original_audio, protected_audio):
26
+ original_audio=original_audio[:, :, :self.audio_steps]
27
+ protected_audio=protected_audio[:, :, :self.audio_steps]
28
+ input_len=original_audio.shape[-1]
29
+
30
+ unprotected_gen=self.generate_audio(original_audio)[0].to(device='cuda')
31
+ protected_gen=self.generate_audio(protected_audio)[0].to(device='cuda')
32
+
33
+ eval_dict={}
34
+ # Difference between original and unprotected gen
35
+ eval_dict["original_unprotectedgen_l1"]=torch.mean(torch.abs(original_audio-unprotected_gen[:, :input_len]))
36
+ eval_dict["original_unprotectedgen_mel"]=self.mel_loss(AudioSignal(original_audio, self.input_sample_rate), AudioSignal(unprotected_gen[:, :input_len], self.input_sample_rate))
37
+ # Difference between original and protected gen
38
+ eval_dict["original_protectedgen_l1"]=torch.mean(torch.abs(original_audio-protected_gen[:, :input_len]))
39
+ eval_dict["original_protectedgen_mel"]=self.mel_loss(AudioSignal(original_audio, self.input_sample_rate), AudioSignal(protected_gen[:, :input_len], self.input_sample_rate))
40
+ # Difference between protected and protected gen
41
+ eval_dict["protected_protectedgen_l1"]=torch.mean(torch.abs(protected_audio-protected_gen[:, :input_len]))
42
+ eval_dict["protected_protectedgen_mel"]=self.mel_loss(AudioSignal(protected_audio, self.input_sample_rate), AudioSignal(protected_gen[:, :input_len], self.input_sample_rate))
43
+ # Difference between unprotected gen and protected gen
44
+ eval_dict["protectedgen_unprotectedgen_l1"]=torch.mean(torch.abs(protected_gen-unprotected_gen))
45
+ eval_dict["protectedgen_unprotectedgen_mel"]=self.mel_loss(AudioSignal(protected_gen, self.input_sample_rate), AudioSignal(unprotected_gen, self.input_sample_rate))
46
+ return eval_dict, unprotected_gen, protected_gen
47
+
48
+ def generate_audio(self, audio):
49
+ torch.manual_seed(0)
50
+
51
+ transform = torchaudio.transforms.Resample(self.input_sample_rate, 32000).to(device='cuda')
52
+ waveform=transform(audio[0]).detach().cpu()
53
+ # waveform.clamp_(0,1)
54
+ a=torch.min(waveform)
55
+ b=torch.max(waveform)
56
+ c=waveform.isnan().any()
57
+ # sample = processor(raw_audio=waveform, sampling_rate=48000, return_tensors="pt")
58
+
59
+ inputs = self.processor(
60
+ audio=waveform,
61
+ sampling_rate=32000,
62
+ text=["music"],
63
+ padding=True,
64
+ return_tensors="pt",
65
+ )
66
+ for d in inputs.data:
67
+ inputs.data[d]=inputs.data[d].to(device='cuda')
68
+ audio_values = self.model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=1024)
69
+
70
+ transform = torchaudio.transforms.Resample(32000, self.input_sample_rate).to(device='cuda')
71
+ audio_values=transform(audio_values)
72
+ return audio_values
73
+
74
+ model_name="facebook/musicgen-stereo-small"
75
+ processor = AutoProcessor.from_pretrained(model_name)
76
+ model = MusicgenForConditionalGeneration.from_pretrained(model_name).to(device='cuda')
77
+
78
+ '''Andy commented:
79
+ song_name="Texas Sun"
80
+ waveform, sample_rate = torchaudio.load(f"test_audio/{song_name}.mp3")
81
+ waveform=waveform[:, :500000]
82
+ torch.manual_seed(0)
83
+ transform = torchaudio.transforms.Resample(sample_rate, 32000)
84
+ waveform=transform(waveform)
85
+ # sample = processor(raw_audio=waveform, sampling_rate=48000, return_tensors="pt")
86
+
87
+ inputs = processor(
88
+ audio=waveform,
89
+ sampling_rate=32000,
90
+ text=["music"],
91
+ padding=True,
92
+ return_tensors="pt",
93
+ )
94
+ for d in inputs.data:
95
+ inputs.data[d]=inputs.data[d].to(device='cuda')
96
+ audio_values = model.generate(**inputs, do_sample=False, guidance_scale=3, max_new_tokens=512, top_k=0, top_p=250)
97
+ torchaudio.save(f"test_audio/perturbed/{model_name[9:]}_{song_name}.mp3", audio_values.detach().cpu()[0], 32000)
98
+
99
+ u=0
100
+ '''
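
A hedged sketch of driving MusicGenEval end to end. The file path, input sample rate, clip length, and the noise-based "protected" stand-in are placeholders; only the constructor and eval(original, protected) signature come from the class above. Note that importing the module also loads a MusicGen model at module level, so a CUDA device is assumed.

# Hedged sketch: compare MusicGen continuations of an original clip and a
# placeholder "protected" clip. Requires a CUDA device and downloads MusicGen.
import torch
import torchaudio

from audio_diffusion_attacks_forhf.src.music_gen import MusicGenEval  # assumed import path

input_sample_rate = 44100                        # assumed input rate
audio_steps = 10 * input_sample_rate             # evaluate the first 10 seconds

waveform, sr = torchaudio.load("test_audio/example_song.mp3")           # placeholder path
waveform = torchaudio.transforms.Resample(sr, input_sample_rate)(waveform)
original = waveform.unsqueeze(0).to("cuda")      # (batch, channels, samples)
protected = original + 1e-3 * torch.randn_like(original)                # stand-in for a real perturbation

evaluator = MusicGenEval(input_sample_rate, audio_steps)
metrics, unprotected_gen, protected_gen = evaluator.eval(original, protected)
for name, value in metrics.items():
    print(name, float(value))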
audio_diffusion_attacks_forhf/src/speech_inference.py ADDED
@@ -0,0 +1,94 @@
1
+ import torch
2
+ from TTS.api import TTS
3
+ #Andy edited: import losses
4
+ from audio_diffusion_attacks_forhf.src import losses
5
+ from audiotools import AudioSignal
6
+ import numpy as np
7
+ import torchaudio
8
+ import random
9
+ import string
10
+ import os
11
+
12
+ class XTTS_Eval:
13
+
14
+ def __init__(self, input_sample_rate, text="The quick brown fox jumps over the lazy dog."):
15
+ self.model = TTS("tts_models/multilingual/multi-dataset/xtts_v2")
16
+ self.model=self.model.to(device='cuda')
17
+ self.text=text
18
+ self.input_sample_rate=input_sample_rate
19
+ self.mel_loss = losses.MelSpectrogramLoss(n_mels=[5, 10, 20, 40, 80, 160, 320],
20
+ window_lengths=[32, 64, 128, 256, 512, 1024, 2048],
21
+ mel_fmin=[0, 0, 0, 0, 0, 0, 0],
22
+ pow=1.0,
23
+ clamp_eps=1.0e-5,
24
+ mag_weight=0.0)
25
+
26
+ def eval(self, original_audio, protected_audio):
27
+
28
+ original_audio=original_audio[0]
29
+ protected_audio=protected_audio[0]
30
+
31
+ unprotected_gen=self.generate_audio(original_audio).to(device='cuda')
32
+ protected_gen=self.generate_audio(protected_audio).to(device='cuda')
33
+
34
+ match_len=min(original_audio.shape[1], unprotected_gen.shape[1])
35
+ if original_audio.shape[1]<unprotected_gen.shape[1]:
36
+ s_unprotected_gen=unprotected_gen[:, :match_len]
37
+ s_protected_gen=protected_gen[:, :match_len]
38
+ s_original_audio=original_audio
39
+ s_protected_audio=protected_audio
40
+ else:
41
+ s_unprotected_gen=unprotected_gen
42
+ s_protected_gen=protected_gen
43
+ s_original_audio=original_audio[:, :match_len]
44
+ s_protected_audio=protected_audio[:, :match_len]
45
+
46
+ match_len=min(protected_gen.shape[1], unprotected_gen.shape[1])
47
+ protected_gen=protected_gen[:,:match_len]
48
+ unprotected_gen=unprotected_gen[:,:match_len]
49
+
50
+
51
+ eval_dict={}
52
+ # Difference between original and unprotected gen
53
+ eval_dict["original_unprotectedgen_l1"]=torch.mean(torch.abs(s_original_audio-s_unprotected_gen))
54
+ eval_dict["original_unprotectedgen_mel"]=self.mel_loss(AudioSignal(s_original_audio, self.input_sample_rate), AudioSignal(s_unprotected_gen, self.input_sample_rate))
55
+ # Difference between original and protected gen
56
+ eval_dict["original_protectedgen_l1"]=torch.mean(torch.abs(s_original_audio-s_protected_gen))
57
+ eval_dict["original_protectedgen_mel"]=self.mel_loss(AudioSignal(s_original_audio, self.input_sample_rate), AudioSignal(s_protected_gen, self.input_sample_rate))
58
+ # Difference between protected and protected gen
59
+ eval_dict["protected_protectedgen_l1"]=torch.mean(torch.abs(s_protected_audio-s_protected_gen))
60
+ eval_dict["protected_protectedgen_mel"]=self.mel_loss(AudioSignal(s_protected_audio, self.input_sample_rate), AudioSignal(s_protected_gen, self.input_sample_rate))
61
+ # Difference between unprotected gen and protected gen
62
+ eval_dict["protectedgen_unprotectedgen_l1"]=torch.mean(torch.abs(protected_gen-unprotected_gen))
63
+ eval_dict["protectedgen_unprotectedgen_mel"]=self.mel_loss(AudioSignal(protected_gen, self.input_sample_rate), AudioSignal(unprotected_gen, self.input_sample_rate))
64
+ return eval_dict, unprotected_gen, protected_gen
65
+
66
+ def generate_audio(self, audio):
67
+ random_str=''.join(random.choices(string.ascii_uppercase + string.digits, k=50))
68
+ torchaudio.save(f"test_audio/{random_str}.wav", torch.reshape(audio.detach().cpu(), (2, audio.shape[1])), self.input_sample_rate, format="wav")
69
+ torch.manual_seed(0)
70
+
71
+ wav = self.model.tts(text=self.text,
72
+ speaker_wav=f"test_audio/{random_str}.wav",
73
+ language="en")
74
+ os.remove(f"test_audio/{random_str}.wav")
75
+ wav=torch.from_numpy(np.array(wav))
76
+ stereo_wave=torch.zeros((2, wav.shape[0]))
77
+ stereo_wave[:,:]=wav
78
+
79
+ transform = torchaudio.transforms.Resample(24000, self.input_sample_rate)
80
+ stereo_wave=transform(stereo_wave)
81
+ return stereo_wave
82
+
83
+ # # Init TTS
84
+ # tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)
85
+ #
86
+ # # Run TTS
87
+ # # ❗ Since this model is multi-lingual voice cloning model, we must set the target speaker_wav and language
88
+ # # Text to speech list of amplitude values as output
89
+ # # wav = tts.tts(text="Hello world!", speaker_wav=, language="en")
90
+ # # Text to speech to a file
91
+ # tts.tts_to_file(text="Hello world!",
92
+ # speaker_wav="/media/willie/1caf5422-4135-4f2c-9619-c44041b51146/audio_data/DS_10283_3443/VCTK-Corpus-0.92/wav48_silence_trimmed/p227/p227_023_mic1.flac",
93
+ # language="en",
94
+ # file_path="/home/willie/eclipse-workspace/audio_diffusion_attacks/src/test_audio/speech/output.wav")
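
And a parallel hedged sketch for XTTS_Eval above. The paths and the perturbation are placeholders; the class writes a temporary speaker-reference WAV under test_audio/, so that directory is assumed to exist, and stereo input at input_sample_rate is expected.

# Hedged sketch: clone a voice from an original and a placeholder "protected"
# reference clip, then compare the two XTTS syntheses. Requires a CUDA device.
import torch
import torchaudio

from audio_diffusion_attacks_forhf.src.speech_inference import XTTS_Eval  # assumed import path

input_sample_rate = 44100
waveform, sr = torchaudio.load("test_audio/example_speaker.wav")        # placeholder path
waveform = torchaudio.transforms.Resample(sr, input_sample_rate)(waveform)
if waveform.shape[0] == 1:                       # the class assumes 2-channel audio
    waveform = waveform.repeat(2, 1)
original = waveform.unsqueeze(0).to("cuda")      # (batch, 2, samples)
protected = original + 1e-3 * torch.randn_like(original)                # stand-in for a real perturbation

evaluator = XTTS_Eval(input_sample_rate, text="The quick brown fox jumps over the lazy dog.")
metrics, unprotected_gen, protected_gen = evaluator.eval(original, protected)
for name, value in metrics.items():
    print(name, float(value))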
audio_diffusion_attacks_forhf/src/test_audio/.Il Sogno Del Marinaio - Nanos' Waltz.mp3.icloud ADDED
Binary file (192 Bytes). View file