Create ultravox_processing.py

#9
by reach-vb HF staff - opened
Files changed (1) hide show
  1. ultravox_processing.py +215 -0
ultravox_processing.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import transformers
6
+
7
+ from .ultravox_config import UltravoxConfig
8
+
9
+
10
class UltravoxProcessor(transformers.ProcessorMixin):
    """
    Constructs an Ultravox processor which wraps an audio processor and a tokenizer into a single processor.

    Args:
        audio_processor: The audio processor for the audio encoder.
        tokenizer: The tokenizer for the language model.
    """

    attributes = ["audio_processor", "tokenizer"]
    audio_processor_class = (
        "Wav2Vec2Processor",
        "SeamlessM4TFeatureExtractor",
        "WhisperProcessor",
    )
    tokenizer_class = (
        "PreTrainedTokenizer",
        "PreTrainedTokenizerFast",
    )

    tokenizer: transformers.PreTrainedTokenizerBase
    audio_processor: transformers.ProcessorMixin

    def __init__(
        self,
        audio_processor=None,
        tokenizer=None,
        audio_padding: str = "longest",
        encoder_ds_factor: int = 320,
        stack_factor: int = 8,
        audio_placeholder: str = "<|audio|>",
    ):
        """
        Args:
            audio_processor: The audio processor for the audio encoder.
            tokenizer: The tokenizer for the language model. Required: its EOS token is used as the
                per-frame replacement for the audio placeholder.
            audio_padding: The padding strategy for the audio encoder.
            encoder_ds_factor: The downsample factor of the audio encoder.
            stack_factor: The factor by which the audio encoder output is stacked in the multimodal projector.
            audio_placeholder: The placeholder for the audio in the text.

        Raises:
            ValueError: If `tokenizer` is `None` or has no EOS token.
        """
        # Fail with a clear message instead of an AttributeError on `None.eos_token`.
        if tokenizer is None:
            raise ValueError("A tokenizer must be provided to UltravoxProcessor.")

        self.audio_padding = audio_padding
        self.encoder_ds_factor = encoder_ds_factor
        self.stack_factor = stack_factor
        self.audio_placeholder = audio_placeholder
        self.audio_token_replacement = tokenizer.eos_token
        # Explicit raise (not `assert`) so the check survives `python -O`.
        if self.audio_token_replacement is None:
            raise ValueError("The tokenizer has no EOS token. Cannot recover.")
        # Fall back to EOS as the padding token when the tokenizer defines none.
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id

        super().__init__(audio_processor=audio_processor, tokenizer=tokenizer)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """
        Build a processor from a pretrained Ultravox checkpoint.

        The config is loaded from `pretrained_model_name_or_path`; the audio processor is loaded from
        `config.audio_model_id`, falling back to `config.audio_config._name_or_path` and finally to
        `"facebook/wav2vec2-base-960h"`. The tokenizer is loaded from `pretrained_model_name_or_path`,
        set to left padding, and its pad token is set to its EOS token.

        Args:
            pretrained_model_name_or_path: Identifier or path passed to `AutoConfig` and `AutoTokenizer`.
            **kwargs: Forwarded to `AutoConfig.from_pretrained` and `AutoTokenizer.from_pretrained`
                (not to the audio processor loader).

        Returns:
            An `UltravoxProcessor` using the loaded audio processor and tokenizer and the config's
            `stack_factor`; the remaining constructor arguments keep their defaults.
        """
        config: UltravoxConfig = transformers.AutoConfig.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )
        audio_processor = transformers.AutoProcessor.from_pretrained(
            config.audio_model_id
            or config.audio_config._name_or_path
            or "facebook/wav2vec2-base-960h"
        )

        tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )
        tokenizer.padding_side = "left"
        tokenizer.pad_token = tokenizer.eos_token

        return cls(
            audio_processor=audio_processor,
            tokenizer=tokenizer,
            stack_factor=config.stack_factor,
        )

    def __call__(
        self,
        text: Optional[str] = None,
        audio: Optional[Union[np.ndarray, torch.Tensor]] = None,
        sampling_rate: Optional[int] = None,
        return_tensors: Optional[
            Union[str, transformers.TensorType]
        ] = transformers.TensorType.PYTORCH,
        **kwargs,
    ) -> transformers.BatchFeature:
        """
        Main method to prepare for the model one text sequence and audio. This method forwards the `text`
        and `kwargs` arguments to the tokenizer's `__call__` if `text` is not `None` to encode
        the text. To prepare the audio, this method forwards the `audio`, `sampling_rate` and `kwargs` arguments to
        the audio processor's `__call__` if `audio` is not `None` and not empty. Please refer to the docstring
        of the above two methods for more information.

        Args:
            text (`str`, *optional*):
                The sequence to be encoded. Only a single string is supported; batch mode is not supported yet.
            audio (`np.ndarray` or `torch.Tensor`, *optional*):
                The audio to be prepared. Audio can be NumPy array or PyTorch tensor. In case of a
                NumPy array/PyTorch tensor, each audio should be of shape (C, T), where C is a number of channels, and T the
                sample length of the audio.
            sampling_rate (`int`, *optional*, defaults to `None`):
                Sampling rate of the input audio, forwarded to the audio processor. Required when the processor
                was built with `audio_padding="max_length"`, where it is used to compute a 30-second length.
            return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `TensorType.PYTORCH`):
                If set, will return tensors of a particular framework. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.
                - `'jax'`: Return JAX `jnp.ndarray` objects.
            **kwargs: Forwarded to both the audio processor call and the tokenizer call.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
              `None`).
            - **audio_values** -- Processed audio values to be fed to a model. Returned when `audio` is not `None`.
            - **audio_len** -- With `audio_padding="max_length"`, the audio processor's attention-mask sum minus one;
              otherwise the size of the last dimension of `audio_values`. Returned when `audio` is not `None`.
            - **audio_token_len** -- Predicted number of audio frames: this value is guaranteed to be a close upper bound.
              Returned when `audio` is not `None`.
            - **audio_token_start_idx** -- The index in the tokenized text where the audio starts. Returned when
              `text` contains the audio placeholder.

        Raises:
            ValueError: If `audio_padding="max_length"` and `sampling_rate` is `None`, or if `text` contains the
                audio placeholder but no audio was provided.
            TypeError: If `text` is given but is not a `str`.
        """
        # TODO: Add support for multiple audio and text inputs.
        data = {}
        audio_embed_frames = 0
        if audio is not None and len(audio) > 0:
            if self.audio_padding == "max_length":
                # 30 seconds is the expected length for Whisper
                if sampling_rate is None:
                    raise ValueError("Sampling rate must be provided.")
                audio_len = 30 * sampling_rate
            else:
                audio_len = audio.shape[-1]
            # It's guaranteed that the number of frames is less than or equal to this amount.
            # For Whisper this is exact AFAICT, but for Wav2Vec2 it's an upper bound.
            # Currently, StackAudioFrames makes sure an over-estimation won't cause issues by padding the audio embeddings.
            nb_encoder_frames = int(round(audio_len / self.encoder_ds_factor + 1e-4))
            audio_embed_frames = int(np.ceil(nb_encoder_frames / self.stack_factor))
            data["audio_token_len"] = [audio_embed_frames]

            # Main audio processing. The processor is model-specific.
            # NOTE(review): `padding` is hard-coded to "longest" here; `self.audio_padding` is only
            # used for the length bookkeeping above and below -- confirm this is intended.
            x = self.audio_processor(
                audio,
                sampling_rate=sampling_rate,
                padding="longest",
                max_length=audio_len,
                return_attention_mask=True,
                **kwargs,
            )
            if "input_features" in x:
                data["audio_values"] = x.input_features
            else:
                data["audio_values"] = x.input_values
            if self.audio_padding == "max_length":
                data["audio_len"] = x.attention_mask.sum(-1) - 1
            else:
                data["audio_len"] = [data["audio_values"].shape[-1]]

        if text is not None:
            if not isinstance(text, str):
                raise TypeError("Text must be a string. Batch mode not supported yet.")
            if self.audio_placeholder in text:
                if "audio_token_len" not in data:
                    raise ValueError(
                        f"audio must be provided when using audio placeholder ({self.audio_placeholder}) in text."
                    )

                # Token index of the first placeholder: the length of the tokenized text preceding it.
                start_idx = len(
                    self.tokenizer.encode(
                        text[: text.index(self.audio_placeholder)],
                        add_special_tokens=False,
                    )
                )
                data["audio_token_start_idx"] = [start_idx]

                # Replace the audio placeholder with the audio token.
                #   e.g. "Transcribe\n<|audio|>" -> "Transcribe </s></s></s></s></s></s></s></s>"
                # where the number of </s> is the number of audio frames.
                # NOTE(review): `str.replace` expands every occurrence of the placeholder, while only
                # the first occurrence's start index is recorded above.
                text = text.replace(
                    self.audio_placeholder,
                    self.audio_token_replacement * audio_embed_frames,
                )

            # Special tokens like BOS should already have been added by the caller.
            data.update(self.tokenizer([text], add_special_tokens=False, **kwargs))

        return transformers.BatchFeature(data=data, tensor_type=return_tensors)

    def batch_decode(self, *args, **kwargs):
        """Forward all arguments to the tokenizer's `batch_decode` and return its result."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Forward all arguments to the tokenizer's `decode` and return its result."""
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """
        Combined, de-duplicated model input names of the tokenizer and the audio processor.

        Tokenizer names come first; `dict.fromkeys` drops duplicates while keeping first-seen order,
        so the result is deterministic across runs (unlike going through a `set`).
        """
        tokenizer_input_names = self.tokenizer.model_input_names
        audio_processor_input_names = self.audio_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + audio_processor_input_names))
211
+
212
+
213
# Hook the processor into the transformers auto-class machinery: first call the
# class's own `register_for_auto_class`, then map `UltravoxConfig` to
# `UltravoxProcessor` in `AutoProcessor`.
UltravoxProcessor.register_for_auto_class()
transformers.AutoProcessor.register(UltravoxConfig, UltravoxProcessor)