cleanup2

- audiocraft/builders.py   +1 -32
- audiocraft/conditioners.py   +12 -226
- demo.py   +15 -0
audiocraft/builders.py
CHANGED
@@ -137,25 +137,9 @@ def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> Cond
         model_args = cond_cfg[model_type]
         if model_type == 't5':
             conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args)
-        elif model_type == 'lut':
-            conditioners[str(cond)] = LUTConditioner(output_dim=output_dim, **model_args)
-        # elif model_type == 'chroma_stem':
-        #     conditioners[str(cond)] = ChromaStemConditioner(
-        #         output_dim=output_dim,
-        #         duration=duration,
-        #         device=device,
-        #         **model_args
-        #     )
-        elif model_type == 'clap':
-            conditioners[str(cond)] = CLAPEmbeddingConditioner(
-                output_dim=output_dim,
-                device=device,
-                **model_args
-            )
         else:
             raise ValueError(f"Unrecognized conditioning model: {model_type}")
     conditioner = ConditioningProvider(conditioners, device=device, **condition_provider_args)
-    print(' COND\n',conditioner)
     return conditioner
 
 
@@ -229,22 +213,7 @@ def get_processor(cfg, sample_rate: int = 24000):
     return sample_processor
 
 
-def get_debug_lm_model(device='cpu'):
-    """Instantiate a debug LM to be used for unit tests."""
-    pattern = DelayedPatternProvider(n_q=4)
-    dim = 16
-    providers = {
-        'description': LUTConditioner(n_bins=128, dim=dim, output_dim=dim, tokenizer="whitespace"),
-    }
-    condition_provider = ConditioningProvider(providers)
-    fuser = ConditionFuser(
-        {'cross': ['description'], 'prepend': [],
-         'sum': [], 'input_interpolate': []})
-    lm = LMModel(
-        pattern, condition_provider, fuser,
-        n_q=4, card=400, dim=dim, num_heads=4, custom=True, num_layers=2,
-        cross_attention=True, causal=True)
-    return lm.to(device).eval()
+
 
 
 def get_wrapped_compression_model(
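Note: after this hunk, get_conditioner_provider only wires up T5 text conditioning; the 'lut' and 'clap' branches (and the commented-out chroma_stem block) are gone, along with the debug print of the provider. A minimal sketch of the surviving path, assuming the upstream audiocraft signatures for T5Conditioner and ConditioningProvider (the argument values below are illustrative, not taken from this commit):

import torch
from audiocraft.conditioners import T5Conditioner, ConditioningProvider

# Build the only kind of conditioner this branch still supports: a T5 text
# encoder keyed by condition name, wrapped in a ConditioningProvider.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
conditioners = {
    'description': T5Conditioner(name='t5-base', output_dim=1536,
                                 finetune=False, device=device),
}
provider = ConditioningProvider(conditioners, device=device)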
audiocraft/conditioners.py
CHANGED
@@ -1,9 +1,3 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
 from collections import defaultdict
 from copy import deepcopy
 from dataclasses import dataclass, field
@@ -16,7 +10,6 @@ import re
 import typing as tp
 import warnings
 import soundfile
-import einops
 from num2words import num2words
 import spacy
 from transformers import T5EncoderModel, T5Tokenizer  # type: ignore
@@ -42,12 +35,7 @@ TextCondition = tp.Optional[str]  # a text condition can be a string or None (if
 ConditionType = tp.Tuple[torch.Tensor, torch.Tensor]  # condition, mask
 
 
-class WavCondition(tp.NamedTuple):
-    wav: torch.Tensor
-    length: torch.Tensor
-    sample_rate: tp.List[int]
-    path: tp.List[tp.Optional[str]] = []
-    seek_time: tp.List[tp.Optional[float]] = []
+
 
 
 class JointEmbedCondition(tp.NamedTuple):
@@ -62,7 +50,7 @@ class JointEmbedCondition(tp.NamedTuple):
 @dataclass
 class ConditioningAttributes:
     text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
-    wav: tp.Dict[str, WavCondition] = field(default_factory=dict)
+    wav: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
     joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
 
     def __getitem__(self, item):
@@ -107,67 +95,13 @@ class ConditioningAttributes:
 
 
 
-def nullify_condition(condition: ConditionType, dim: int = 1):
-    """Transform an input condition to a null condition.
-    The way it is done by converting it to a single zero vector similarly
-    to how it is done inside WhiteSpaceTokenizer and NoopTokenizer.
 
-    Args:
-        condition (ConditionType): A tuple of condition and mask (tuple[torch.Tensor, torch.Tensor])
-        dim (int): The dimension that will be truncated (should be the time dimension)
-            WARNING!: dim should not be the batch dimension!
-    Returns:
-        ConditionType: A tuple of null condition and mask
-    """
-    assert dim != 0, "dim cannot be the batch dimension!"
-    assert isinstance(condition, tuple) and \
-        isinstance(condition[0], torch.Tensor) and \
-        isinstance(condition[1], torch.Tensor), "'nullify_condition' got an unexpected input type!"
-    cond, mask = condition
-    B = cond.shape[0]
-    last_dim = cond.dim() - 1
-    out = cond.transpose(dim, last_dim)
-    out = 0. * out[..., :1]
-    out = out.transpose(dim, last_dim)
-    mask = torch.zeros((B, 1), device=out.device).int()
-    assert cond.dim() == out.dim()
-    return out, mask
-
-
-def nullify_wav(cond: WavCondition) -> WavCondition:
-    """Transform a WavCondition to a nullified WavCondition.
-    It replaces the wav by a null tensor, forces its length to 0, and replaces metadata by dummy attributes.
 
-    Args:
-        cond (WavCondition): Wav condition with wav, tensor of shape [B, T].
-    Returns:
-        WavCondition: Nullified wav condition.
-    """
-    null_wav, _ = nullify_condition((cond.wav, torch.zeros_like(cond.wav)), dim=cond.wav.dim() - 1)
-    return WavCondition(
-        wav=null_wav,
-        length=torch.tensor([0] * cond.wav.shape[0], device=cond.wav.device),
-        sample_rate=cond.sample_rate,
-        path=[None] * cond.wav.shape[0],
-        seek_time=[None] * cond.wav.shape[0],
-    )
 
 
-def nullify_joint_embed(embed: JointEmbedCondition) -> JointEmbedCondition:
-    """Nullify the joint embedding condition by replacing it by a null tensor, forcing its length to 0,
-    and replacing metadata by dummy attributes.
 
-
-
-    """
-    null_wav, _ = nullify_condition((embed.wav, torch.zeros_like(embed.wav)), dim=embed.wav.dim() - 1)
-    return JointEmbedCondition(
-        wav=null_wav, text=[None] * len(embed.text),
-        length=torch.LongTensor([0]).to(embed.wav.device),
-        sample_rate=embed.sample_rate,
-        path=[None] * embed.wav.shape[0],
-        seek_time=[0] * embed.wav.shape[0],
-    )
+
+
 
 
 class Tokenizer:
@@ -419,129 +353,7 @@ class T5Conditioner(TextConditioner):
 
 
 
-def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str) -> ConditioningAttributes:
-    """Utility function for nullifying an attribute inside an ConditioningAttributes object.
-    If the condition is of type "wav", then nullify it using `nullify_condition` function.
-    If the condition is of any other type, set its value to None.
-    Works in-place.
-    """
-    if condition_type not in ['text', 'wav', 'joint_embed']:
-        raise ValueError(
-            "dropout_condition got an unexpected condition type!"
-            f" expected 'text', 'wav' or 'joint_embed' but got '{condition_type}'"
-        )
-
-    if condition not in getattr(sample, condition_type):
-        raise ValueError(
-            "dropout_condition received an unexpected condition!"
-            f" expected wav={sample.wav.keys()} and text={sample.text.keys()}"
-            f" but got '{condition}' of type '{condition_type}'!"
-        )
-
-    if condition_type == 'wav':
-        wav_cond = sample.wav[condition]
-        sample.wav[condition] = nullify_wav(wav_cond)
-    elif condition_type == 'joint_embed':
-        embed = sample.joint_embed[condition]
-        sample.joint_embed[condition] = nullify_joint_embed(embed)
-    else:
-        sample.text[condition] = None
-
-    return sample
-
-
-class DropoutModule(nn.Module):
-    """Base module for all dropout modules."""
-    def __init__(self, seed: int = 1234):
-        super().__init__()
-        self.rng = torch.Generator()
-        self.rng.manual_seed(seed)
-
-
-class AttributeDropout(DropoutModule):
-    """Dropout with a given probability per attribute.
-    This is different from the behavior of ClassifierFreeGuidanceDropout as this allows for attributes
-    to be dropped out separately. For example, "artist" can be dropped while "genre" remains.
-    This is in contrast to ClassifierFreeGuidanceDropout where if "artist" is dropped "genre"
-    must also be dropped.
-
-    Args:
-        p (tp.Dict[str, float]): A dict mapping between attributes and dropout probability. For example:
-            ...
-            "genre": 0.1,
-            "artist": 0.5,
-            "wav": 0.25,
-            ...
-        active_on_eval (bool, optional): Whether the dropout is active at eval. Default to False.
-        seed (int, optional): Random seed.
-    """
-    def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eval: bool = False, seed: int = 1234):
-        super().__init__(seed=seed)
-        self.active_on_eval = active_on_eval
-        # construct dict that return the values from p otherwise 0
-        self.p = {}
-        for condition_type, probs in p.items():
-            self.p[condition_type] = defaultdict(lambda: 0, probs)
-
-    def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
-        """
-        Args:
-            samples (list[ConditioningAttributes]): List of conditions.
-        Returns:
-            list[ConditioningAttributes]: List of conditions after certain attributes were set to None.
-        """
-        if not self.training and not self.active_on_eval:
-            return samples
-
-        samples = deepcopy(samples)
-        for condition_type, ps in self.p.items():  # for condition types [text, wav]
-            for condition, p in ps.items():  # for attributes of each type (e.g., [artist, genre])
-                if torch.rand(1, generator=self.rng).item() < p:
-                    for sample in samples:
-                        dropout_condition(sample, condition_type, condition)
-        return samples
-
-    def __repr__(self):
-        return f"AttributeDropout({dict(self.p)})"
-
-
-class ClassifierFreeGuidanceDropout(DropoutModule):
-    """Classifier Free Guidance dropout.
-    All attributes are dropped with the same probability.
 
-    Args:
-        p (float): Probability to apply condition dropout during training.
-        seed (int): Random seed.
-    """
-    def __init__(self, p: float, seed: int = 1234):
-        super().__init__(seed=seed)
-        self.p = p
-
-    def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
-        """
-        Args:
-            samples (list[ConditioningAttributes]): List of conditions.
-        Returns:
-            list[ConditioningAttributes]: List of conditions after all attributes were set to None.
-        """
-        if not self.training:
-            return samples
-
-        # decide on which attributes to drop in a batched fashion
-        drop = torch.rand(1, generator=self.rng).item() < self.p
-        if not drop:
-            return samples
-
-        # nullify conditions of all attributes
-        samples = deepcopy(samples)
-        for condition_type in ["wav", "text"]:
-            for sample in samples:
-                for condition in sample.attributes[condition_type]:
-                    dropout_condition(sample, condition_type, condition)
-        return samples
-
-    def __repr__(self):
-        return f"ClassifierFreeGuidanceDropout(p={self.p})"
 
 
 class ConditioningProvider(nn.Module):
@@ -696,43 +508,17 @@ class ConditionFuser(StreamingModule):
         """
         B, T, _ = input.shape
 
-        if 'offsets' in self._streaming_state:
-            first_step = False
-            offsets = self._streaming_state['offsets']
-        else:
-            first_step = True
-            offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device)
 
-
-
-
+        first_step = True
+        offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device)
+
+
         cross_attention_output = None
         for cond_type, (cond, cond_mask) in conditions.items():
-
-
-
-
-            cond = einops.rearrange(cond, "b t d -> b d t")
-            cond = F.interpolate(cond, size=input.shape[1])
-            input += einops.rearrange(cond, "b d t -> b t d")
-        elif op == 'prepend':
-            if first_step:
-                input = torch.cat([cond, input], dim=1)
-        elif op == 'cross':
-            if cross_attention_output is not None:
-                cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
-            else:
-                cross_attention_output = cond
-        else:
-            raise ValueError(f"unknown op ({op})")
-
-        if self.cross_attention_pos_emb and cross_attention_output is not None:
-            positions = torch.arange(
-                cross_attention_output.shape[1],
-                device=cross_attention_output.device
-            ).view(1, -1, 1)
-            pos_emb = create_sin_embedding(positions, cross_attention_output.shape[-1])
-            cross_attention_output = cross_attention_output + self.cross_attention_pos_emb_scale * pos_emb
+            # print(f'{self.cond2fuse=}') - self.cond2fuse={'description': 'cross'}
+
+            cross_attention_output = cond
+
 
         if self._is_streaming:
             self._streaming_state['offsets'] = offsets + T
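Note: with this hunk the ConditionFuser no longer distinguishes the 'sum', 'input_interpolate', 'prepend' and 'cross' ops (which also explains dropping the einops import), the streaming-offset bookkeeping is reduced to a fixed first step, and the nullify_*/dropout utilities above are deleted outright. Every remaining condition is routed straight to cross-attention, as the new cond2fuse debug comment suggests. A rough sketch of the fusion logic that is left, under the assumption that the surrounding method signature is otherwise unchanged:

import typing as tp
import torch

def fuse(input: torch.Tensor,
         conditions: tp.Dict[str, tp.Tuple[torch.Tensor, torch.Tensor]]):
    # conditions maps a name (e.g. 'description') to (embedding [B, T_c, D], mask).
    # Only the cross-attention path survives: the last condition tensor wins.
    cross_attention_output = None
    for _cond_type, (cond, _cond_mask) in conditions.items():
        cross_attention_output = cond
    return input, cross_attention_output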
demo.py
ADDED
@@ -0,0 +1,15 @@
+from audiocraft.audiogen import AudioGen #, audio_write
+import audiofile
+import numpy as np
+
+print('\n\n\n\n___________________')
+
+txt = 'australian music'
+
+sound_generator = AudioGen.get_pretrained('facebook/audiogen-medium')
+sound_generator.set_generation_params(duration=1)  # why is generating so long at 14 seconds
+
+x = sound_generator.generate([txt])[0].detach().cpu().numpy()[0, :]
+x /= np.abs(x).max() + 1e-7
+
+audiofile.write('_audio_.wav', x, 16000)
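Note: demo.py loads facebook/audiogen-medium, generates one second of audio for the prompt, peak-normalizes it (x /= np.abs(x).max() + 1e-7, so the waveform fits in [-1, 1] before writing), and saves a 16 kHz WAV. If the wrapper still exposes the model's sample rate as upstream audiocraft's AudioGen does (an assumption, not shown in this commit), the hard-coded 16000 on the last line could be replaced:

# assumes sound_generator.sample_rate exists, as on upstream audiocraft's AudioGen
audiofile.write('_audio_.wav', x, sound_generator.sample_rate)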