Dionyssos committed
Commit c5e1f80 · 1 Parent(s): 531e776
Files changed (3)
  1. audiocraft/builders.py +1 -32
  2. audiocraft/conditioners.py +12 -226
  3. demo.py +15 -0
audiocraft/builders.py CHANGED
@@ -137,25 +137,9 @@ def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> Cond
         model_args = cond_cfg[model_type]
         if model_type == 't5':
             conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args)
-        elif model_type == 'lut':
-            conditioners[str(cond)] = LUTConditioner(output_dim=output_dim, **model_args)
-        # elif model_type == 'chroma_stem':
-        #     conditioners[str(cond)] = ChromaStemConditioner(
-        #         output_dim=output_dim,
-        #         duration=duration,
-        #         device=device,
-        #         **model_args
-        #     )
-        elif model_type == 'clap':
-            conditioners[str(cond)] = CLAPEmbeddingConditioner(
-                output_dim=output_dim,
-                device=device,
-                **model_args
-            )
         else:
             raise ValueError(f"Unrecognized conditioning model: {model_type}")
     conditioner = ConditioningProvider(conditioners, device=device, **condition_provider_args)
-    print(' COND\n',conditioner)
     return conditioner
 
 
@@ -229,22 +213,7 @@ def get_processor(cfg, sample_rate: int = 24000):
     return sample_processor
 
 
-def get_debug_lm_model(device='cpu'):
-    """Instantiate a debug LM to be used for unit tests."""
-    pattern = DelayedPatternProvider(n_q=4)
-    dim = 16
-    providers = {
-        'description': LUTConditioner(n_bins=128, dim=dim, output_dim=dim, tokenizer="whitespace"),
-    }
-    condition_provider = ConditioningProvider(providers)
-    fuser = ConditionFuser(
-        {'cross': ['description'], 'prepend': [],
-         'sum': [], 'input_interpolate': []})
-    lm = LMModel(
-        pattern, condition_provider, fuser,
-        n_q=4, card=400, dim=dim, num_heads=4, custom=True, num_layers=2,
-        cross_attention=True, causal=True)
-    return lm.to(device).eval()
+
 
 
  def get_wrapped_compression_model(
 
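Note: after this hunk, only the T5 branch of the conditioner factory survives. The sketch below is a hypothetical condensed view of the remaining path, not code from this commit; the loop structure and the names dict_cfg and condition_provider_args are assumed unchanged from upstream audiocraft, and the import path follows this fork's audiocraft/conditioners.py layout.

# Hypothetical condensed view of get_conditioner_provider after this commit.
# Any conditioner config whose model type is not 't5' now raises immediately.
from audiocraft.conditioners import T5Conditioner, ConditioningProvider  # assumed import path

def build_t5_only_provider(output_dim, device, dict_cfg, condition_provider_args):
    conditioners = {}
    for cond, cond_cfg in dict_cfg.items():
        model_type = cond_cfg['model']
        model_args = cond_cfg[model_type]
        if model_type == 't5':
            # the only conditioner type this fork still builds
            conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args)
        else:
            raise ValueError(f"Unrecognized conditioning model: {model_type}")
    return ConditioningProvider(conditioners, device=device, **condition_provider_args)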
audiocraft/conditioners.py CHANGED
@@ -1,9 +1,3 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
 from collections import defaultdict
 from copy import deepcopy
 from dataclasses import dataclass, field
@@ -16,7 +10,6 @@ import re
 import typing as tp
 import warnings
 import soundfile
-import einops
 from num2words import num2words
 import spacy
 from transformers import T5EncoderModel, T5Tokenizer  # type: ignore
@@ -42,12 +35,7 @@ TextCondition = tp.Optional[str] # a text condition can be a string or None (if
 ConditionType = tp.Tuple[torch.Tensor, torch.Tensor]  # condition, mask
 
 
-class WavCondition(tp.NamedTuple):
-    wav: torch.Tensor
-    length: torch.Tensor
-    sample_rate: tp.List[int]
-    path: tp.List[tp.Optional[str]] = []
-    seek_time: tp.List[tp.Optional[float]] = []
+
 
 
 class JointEmbedCondition(tp.NamedTuple):
@@ -62,7 +50,7 @@ class JointEmbedCondition(tp.NamedTuple):
 @dataclass
 class ConditioningAttributes:
     text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
-    wav: tp.Dict[str, WavCondition] = field(default_factory=dict)
+    wav: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
     joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
 
     def __getitem__(self, item):
@@ -107,67 +95,13 @@ class ConditioningAttributes:
 
 
 
-def nullify_condition(condition: ConditionType, dim: int = 1):
-    """Transform an input condition to a null condition.
-    The way it is done by converting it to a single zero vector similarly
-    to how it is done inside WhiteSpaceTokenizer and NoopTokenizer.
-
-    Args:
-        condition (ConditionType): A tuple of condition and mask (tuple[torch.Tensor, torch.Tensor])
-        dim (int): The dimension that will be truncated (should be the time dimension)
-            WARNING!: dim should not be the batch dimension!
-    Returns:
-        ConditionType: A tuple of null condition and mask
-    """
-    assert dim != 0, "dim cannot be the batch dimension!"
-    assert isinstance(condition, tuple) and \
-        isinstance(condition[0], torch.Tensor) and \
-        isinstance(condition[1], torch.Tensor), "'nullify_condition' got an unexpected input type!"
-    cond, mask = condition
-    B = cond.shape[0]
-    last_dim = cond.dim() - 1
-    out = cond.transpose(dim, last_dim)
-    out = 0. * out[..., :1]
-    out = out.transpose(dim, last_dim)
-    mask = torch.zeros((B, 1), device=out.device).int()
-    assert cond.dim() == out.dim()
-    return out, mask
-
-
-def nullify_wav(cond: WavCondition) -> WavCondition:
-    """Transform a WavCondition to a nullified WavCondition.
-    It replaces the wav by a null tensor, forces its length to 0, and replaces metadata by dummy attributes.
-
-    Args:
-        cond (WavCondition): Wav condition with wav, tensor of shape [B, T].
-    Returns:
-        WavCondition: Nullified wav condition.
-    """
-    null_wav, _ = nullify_condition((cond.wav, torch.zeros_like(cond.wav)), dim=cond.wav.dim() - 1)
-    return WavCondition(
-        wav=null_wav,
-        length=torch.tensor([0] * cond.wav.shape[0], device=cond.wav.device),
-        sample_rate=cond.sample_rate,
-        path=[None] * cond.wav.shape[0],
-        seek_time=[None] * cond.wav.shape[0],
-    )
 
 
-def nullify_joint_embed(embed: JointEmbedCondition) -> JointEmbedCondition:
-    """Nullify the joint embedding condition by replacing it by a null tensor, forcing its length to 0,
-    and replacing metadata by dummy attributes.
-
-    Args:
-        cond (JointEmbedCondition): Joint embedding condition with wav and text, wav tensor of shape [B, C, T].
-    """
-    null_wav, _ = nullify_condition((embed.wav, torch.zeros_like(embed.wav)), dim=embed.wav.dim() - 1)
-    return JointEmbedCondition(
-        wav=null_wav, text=[None] * len(embed.text),
-        length=torch.LongTensor([0]).to(embed.wav.device),
-        sample_rate=embed.sample_rate,
-        path=[None] * embed.wav.shape[0],
-        seek_time=[0] * embed.wav.shape[0],
-    )
+
+
 
 
 class Tokenizer:
@@ -419,129 +353,7 @@ class T5Conditioner(TextConditioner):
 
 
 
-def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str) -> ConditioningAttributes:
-    """Utility function for nullifying an attribute inside an ConditioningAttributes object.
-    If the condition is of type "wav", then nullify it using `nullify_condition` function.
-    If the condition is of any other type, set its value to None.
-    Works in-place.
-    """
-    if condition_type not in ['text', 'wav', 'joint_embed']:
-        raise ValueError(
-            "dropout_condition got an unexpected condition type!"
-            f" expected 'text', 'wav' or 'joint_embed' but got '{condition_type}'"
-        )
-
-    if condition not in getattr(sample, condition_type):
-        raise ValueError(
-            "dropout_condition received an unexpected condition!"
-            f" expected wav={sample.wav.keys()} and text={sample.text.keys()}"
-            f" but got '{condition}' of type '{condition_type}'!"
-        )
-
-    if condition_type == 'wav':
-        wav_cond = sample.wav[condition]
-        sample.wav[condition] = nullify_wav(wav_cond)
-    elif condition_type == 'joint_embed':
-        embed = sample.joint_embed[condition]
-        sample.joint_embed[condition] = nullify_joint_embed(embed)
-    else:
-        sample.text[condition] = None
-
-    return sample
-
-
-class DropoutModule(nn.Module):
-    """Base module for all dropout modules."""
-    def __init__(self, seed: int = 1234):
-        super().__init__()
-        self.rng = torch.Generator()
-        self.rng.manual_seed(seed)
-
-
-class AttributeDropout(DropoutModule):
-    """Dropout with a given probability per attribute.
-    This is different from the behavior of ClassifierFreeGuidanceDropout as this allows for attributes
-    to be dropped out separately. For example, "artist" can be dropped while "genre" remains.
-    This is in contrast to ClassifierFreeGuidanceDropout where if "artist" is dropped "genre"
-    must also be dropped.
-
-    Args:
-        p (tp.Dict[str, float]): A dict mapping between attributes and dropout probability. For example:
-            ...
-            "genre": 0.1,
-            "artist": 0.5,
-            "wav": 0.25,
-            ...
-        active_on_eval (bool, optional): Whether the dropout is active at eval. Default to False.
-        seed (int, optional): Random seed.
-    """
-    def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eval: bool = False, seed: int = 1234):
-        super().__init__(seed=seed)
-        self.active_on_eval = active_on_eval
-        # construct dict that return the values from p otherwise 0
-        self.p = {}
-        for condition_type, probs in p.items():
-            self.p[condition_type] = defaultdict(lambda: 0, probs)
-
-    def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
-        """
-        Args:
-            samples (list[ConditioningAttributes]): List of conditions.
-        Returns:
-            list[ConditioningAttributes]: List of conditions after certain attributes were set to None.
-        """
-        if not self.training and not self.active_on_eval:
-            return samples
-
-        samples = deepcopy(samples)
-        for condition_type, ps in self.p.items():  # for condition types [text, wav]
-            for condition, p in ps.items():  # for attributes of each type (e.g., [artist, genre])
-                if torch.rand(1, generator=self.rng).item() < p:
-                    for sample in samples:
-                        dropout_condition(sample, condition_type, condition)
-        return samples
-
-    def __repr__(self):
-        return f"AttributeDropout({dict(self.p)})"
-
-
-class ClassifierFreeGuidanceDropout(DropoutModule):
-    """Classifier Free Guidance dropout.
-    All attributes are dropped with the same probability.
-
-    Args:
-        p (float): Probability to apply condition dropout during training.
-        seed (int): Random seed.
-    """
-    def __init__(self, p: float, seed: int = 1234):
-        super().__init__(seed=seed)
-        self.p = p
-
-    def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
-        """
-        Args:
-            samples (list[ConditioningAttributes]): List of conditions.
-        Returns:
-            list[ConditioningAttributes]: List of conditions after all attributes were set to None.
-        """
-        if not self.training:
-            return samples
-
-        # decide on which attributes to drop in a batched fashion
-        drop = torch.rand(1, generator=self.rng).item() < self.p
-        if not drop:
-            return samples
-
-        # nullify conditions of all attributes
-        samples = deepcopy(samples)
-        for condition_type in ["wav", "text"]:
-            for sample in samples:
-                for condition in sample.attributes[condition_type]:
-                    dropout_condition(sample, condition_type, condition)
-        return samples
-
-    def __repr__(self):
-        return f"ClassifierFreeGuidanceDropout(p={self.p})"
 
 
 class ConditioningProvider(nn.Module):
@@ -696,43 +508,17 @@ class ConditionFuser(StreamingModule):
         """
         B, T, _ = input.shape
 
-        if 'offsets' in self._streaming_state:
-            first_step = False
-            offsets = self._streaming_state['offsets']
-        else:
-            first_step = True
-            offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device)
-
-        assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \
-            f"given conditions contain unknown attributes for fuser, " \
-            f"expected {self.cond2fuse.keys()}, got {conditions.keys()}"
+        first_step = True
+        offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device)
+
 
         cross_attention_output = None
         for cond_type, (cond, cond_mask) in conditions.items():
-            op = self.cond2fuse[cond_type]
-            if op == 'sum':
-                input += cond
-            elif op == 'input_interpolate':
-                cond = einops.rearrange(cond, "b t d -> b d t")
-                cond = F.interpolate(cond, size=input.shape[1])
-                input += einops.rearrange(cond, "b d t -> b t d")
-            elif op == 'prepend':
-                if first_step:
-                    input = torch.cat([cond, input], dim=1)
-            elif op == 'cross':
-                if cross_attention_output is not None:
-                    cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
-                else:
-                    cross_attention_output = cond
-            else:
-                raise ValueError(f"unknown op ({op})")
-
-        if self.cross_attention_pos_emb and cross_attention_output is not None:
-            positions = torch.arange(
-                cross_attention_output.shape[1],
-                device=cross_attention_output.device
-            ).view(1, -1, 1)
-            pos_emb = create_sin_embedding(positions, cross_attention_output.shape[-1])
-            cross_attention_output = cross_attention_output + self.cross_attention_pos_emb_scale * pos_emb
+            # print(f'{self.cond2fuse=}') - self.cond2fuse={'description': 'cross'}
+
+            cross_attention_output = cond
+
 
         if self._is_streaming:
             self._streaming_state['offsets'] = offsets + T
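Note: with WavCondition, the nullify helpers, and the dropout modules removed, conditioning in this fork is text-only in practice, and the fuser hands every remaining condition straight to cross-attention (per the leftover debug comment, the sole condition is 'description' mapped to 'cross'). A minimal sketch of the payload the provider now consumes; the attribute name 'description' and the prompt string are taken from this commit, the rest is illustrative.

# Minimal conditioning payload after this commit: only the text field is
# populated; `wav` is now a plain dict of optional strings and stays empty.
from audiocraft.conditioners import ConditioningAttributes  # assumed import path (this fork)

attrs = ConditioningAttributes(text={'description': 'australian music'})
assert attrs.wav == {} and attrs.joint_embed == {}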
demo.py ADDED
@@ -0,0 +1,15 @@
+from audiocraft.audiogen import AudioGen  #, audio_write
+import audiofile
+import numpy as np
+
+print('\n\n\n\n___________________')
+
+txt = 'australian music'
+
+sound_generator = AudioGen.get_pretrained('facebook/audiogen-medium')
+sound_generator.set_generation_params(duration=1)  # why is generating so long at 14 seconds
+
+x = sound_generator.generate([txt])[0].detach().cpu().numpy()[0, :]
+x /= np.abs(x).max() + 1e-7
+
+audiofile.write('_audio_.wav', x, 16000)
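Note: the demo writes the generated clip at a hard-coded 16 kHz, which is the native rate of facebook/audiogen-medium. The sketch below is a variant that avoids the magic number; it assumes this fork keeps upstream audiocraft's sample_rate attribute on the generator, which is not verified against this commit.

# Variant of demo.py that reads the output rate from the model instead of
# hard-coding 16000. Assumes AudioGen.sample_rate exists as in upstream audiocraft.
from audiocraft.audiogen import AudioGen
import audiofile
import numpy as np

sound_generator = AudioGen.get_pretrained('facebook/audiogen-medium')
sound_generator.set_generation_params(duration=1)

wav = sound_generator.generate(['australian music'])[0]   # tensor of shape [channels, time]
x = wav.detach().cpu().numpy()[0, :]
x /= np.abs(x).max() + 1e-7                               # peak-normalise to avoid clipping

audiofile.write('_audio_.wav', x, sound_generator.sample_rate)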