from dataclasses import dataclass |
|
from functools import partial |
|
import logging |
|
import typing as tp |
|
|
|
import torch |
|
from torch import nn |
|
|
|
from moshi.utils.sampling import sample_token |
|
from moshi.utils.compile import CUDAGraphed |
|
from moshi.modules.streaming import StreamingContainer, StreamingModule |
|
from moshi.modules.transformer import ( |
|
StreamingTransformer, |
|
create_norm_fn, |
|
) |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class ScaledEmbedding(nn.Embedding): |
|
"""Boost learning rate for embeddings (with `scale`). |
|
|
|
Args: |
|
norm (bool): if True, uses a layer norm after the embedding. |
|
zero_idx (int): special value indicating that the output should be exactly 0. |
|
""" |
|
|
|
def __init__(self, *args, norm: bool = False, zero_idx: int = -1, **kwargs): |
|
super().__init__(*args, **kwargs) |
|
self.norm = None |
|
if norm: |
|
self.norm = create_norm_fn("layer_norm", self.embedding_dim) |
|
assert zero_idx < 0, "Please use negative values for the zero_idx." |
|
self.zero_idx = zero_idx |
|
|
|
    def forward(self, input, *args, **kwargs):
        is_zero = input == self.zero_idx
        zero = torch.zeros(1, dtype=input.dtype, device=input.device)
        # Clamp the special negative index to a valid row; the corresponding
        # outputs are masked back to zero after the lookup.
        input = input.clamp(min=0)
        y = super().forward(input, *args, **kwargs)
        if self.norm is not None:
            y = self.norm(y)
        y = torch.where(is_zero[..., None], zero, y)
        return y
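
# Illustrative sketch (not part of the original module) of the `zero_idx` convention:
# inputs equal to `zero_idx` produce an exact zero vector, which lets callers "mute"
# a stream without reserving a dedicated embedding row.
#
#   emb = ScaledEmbedding(card + 1, dim, zero_idx=-1)
#   tokens = torch.tensor([[3, -1, 7]])
#   out = emb(tokens)          # out[0, 1] is exactly torch.zeros(dim)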
|
|
|
|
|
class LMModel(StreamingContainer): |
|
"""Transformer-based language model on multiple streams of codes. |
|
|
|
Args: |
|
        delays (list of int): Relative delay, in steps, applied to each codebook
            (text stream first, then the audio streams).
        n_q (int): Number of parallel streams to model as input.
|
dep_q (int): Number of parallel streams to model in the depformer. |
|
card (int): Cardinality, vocabulary size. |
|
text_card (int): Cardinality of the text vocabulary. |
|
dim (int): Dimension of the transformer encoder. |
|
num_heads (int): Number of heads for the transformer encoder. |
|
hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder. |
|
norm (str): Normalization method. |
|
norm_emb (bool): Whether to normalize embeddings. |
|
bias_proj (bool): Use bias for output projections. |
|
        depformer_*: params used for the Depformer Transformer, all the others are shared
            with the main transformer.
        depformer_multi_linear (bool): if True, uses one linear layer per codebook to project the
            output of the main transformer to the Depformer latent space.
        depformer_dim_feedforward (int | list[int] | None): If None, defaults to hidden_scale * depformer_dim.
        existing_text_padding_id (int, optional): if provided, id of an existing text token to
            use for text padding; if None, an extra token (shared with the initial text token)
            is appended to the text vocabulary for this purpose.
|
**kwargs: Additional parameters for the transformer encoder. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
delays: tp.List[int] = [0], |
|
n_q: int = 8, |
|
dep_q: int = 8, |
|
card: int = 1024, |
|
text_card: int = 32000, |
|
dim: int = 128, |
|
num_heads: int = 8, |
|
hidden_scale: int = 4, |
|
norm: str = "layer_norm", |
|
norm_emb: bool = False, |
|
bias_proj: bool = False, |
|
depformer_dim: int = 256, |
|
depformer_dim_feedforward: int | list[int] | None = None, |
|
depformer_multi_linear: bool = False, |
|
depformer_weights_per_step: bool = False, |
|
depformer_pos_emb: str = "sin", |
|
existing_text_padding_id: tp.Optional[int] = None, |
|
context: tp.Optional[int] = None, |
|
device=None, |
|
dtype=None, |
|
**kwargs, |
|
): |
|
super().__init__() |
|
self.n_q = n_q |
|
self.dep_q = dep_q |
|
self.card = card |
|
self.text_card = text_card |
|
assert len(delays) == self.num_codebooks, "unexpected number of delays" |
|
self.delays = delays |
|
self.dim = dim |
|
self.existing_text_padding_id = existing_text_padding_id |
|
self.context = context |
|
kwargs["context"] = context |
|
EmbeddingFactory = partial( |
|
ScaledEmbedding, |
|
norm=norm_emb, |
|
device=device, |
|
dtype=dtype, |
|
zero_idx=self.zero_token_id, |
|
) |
|
self.emb = nn.ModuleList( |
|
[EmbeddingFactory(self.card + 1, dim) for _ in range(n_q)] |
|
) |
|
|
|
        # When no pre-existing padding token id is provided, reserve one extra output
        # logit for the dedicated text padding / initial token.
        extra_text = self.existing_text_padding_id is None
        self.text_emb = EmbeddingFactory(text_card + 1, dim)
        self.text_linear = nn.Linear(dim, text_card + extra_text, bias=bias_proj)
|
depformer_prefix = "depformer_" |
|
main_kwargs = { |
|
k: v for k, v in kwargs.items() if not k.startswith(depformer_prefix) |
|
} |
|
self.transformer = StreamingTransformer( |
|
d_model=dim, |
|
num_heads=num_heads, |
|
dim_feedforward=int(hidden_scale * dim), |
|
norm=norm, |
|
device=device, |
|
dtype=dtype, |
|
**main_kwargs, |
|
) |
|
self.out_norm = create_norm_fn(norm, dim) |
|
self.depformer_multi_linear = depformer_multi_linear |
|
kwargs_dep = main_kwargs.copy() |
|
kwargs_dep.update( |
|
{ |
|
k.removeprefix(depformer_prefix): v |
|
for k, v in kwargs.items() |
|
if k.startswith(depformer_prefix) |
|
} |
|
) |
|
kwargs_dep["positional_embedding"] = depformer_pos_emb |
|
kwargs_dep["context"] = None |
|
if depformer_weights_per_step: |
|
kwargs_dep["weights_per_step"] = dep_q |
|
        if depformer_multi_linear:
            # One linear layer per codebook to project the main transformer output
            # into the Depformer latent space.
            self.depformer_in = nn.ModuleList(
                [nn.Linear(dim, depformer_dim, bias=False) for _ in range(dep_q)]
            )
        else:
            # A single shared projection for all codebooks.
            self.depformer_in = nn.ModuleList(
                [nn.Linear(dim, depformer_dim, bias=False)]
            )
|
|
|
self.depformer_emb = nn.ModuleList( |
|
[EmbeddingFactory(self.card + 1, depformer_dim) for _ in range(dep_q - 1)] |
|
) |
|
self.depformer_text_emb = EmbeddingFactory(text_card + 1, depformer_dim) |
|
if depformer_dim_feedforward is None: |
|
depformer_dim_feedforward = int(hidden_scale * depformer_dim) |
|
self.depformer = StreamingTransformer( |
|
d_model=depformer_dim, |
|
dim_feedforward=depformer_dim_feedforward, |
|
norm=norm, |
|
device=device, |
|
dtype=dtype, |
|
**kwargs_dep, |
|
) |
|
self.depformer.set_streaming_propagate(False) |
|
dim = depformer_dim |
|
|
|
self.linears = nn.ModuleList( |
|
[nn.Linear(dim, self.card, bias=bias_proj) for _ in range(dep_q)] |
|
) |
|
|
|
@property |
|
def initial_token_id(self) -> int: |
|
"""Token id for the start of sequence (audio).""" |
|
return self.card |
|
|
|
@property |
|
def text_initial_token_id(self) -> int: |
|
"""Token id for the start of sequence (text).""" |
|
return self.text_card |
|
|
|
@property |
|
def text_padding_token_id(self) -> int: |
|
"""Token id for text padding.""" |
|
if self.existing_text_padding_id is None: |
|
return self.text_card |
|
else: |
|
return self.existing_text_padding_id |
|
|
|
@property |
|
def end_of_text_padding_id(self) -> int: |
|
"""Token id for optionally marking the last padding step for a word.""" |
|
return 0 |
|
|
|
@property |
|
def zero_token_id(self) -> int: |
|
"""Special value in the input tokens, indicating that no sampling should |
|
happen for that value, and no input should be given to the model.""" |
|
return -1 |
|
|
|
@property |
|
def ungenerated_token_id(self) -> int: |
|
"""Special value that can be provided in the prompt to indicate that this specific |
|
value should be predicted and sampled. This allows for partial teacher forcing, by generating |
|
one modality, with the other one fixed. |
|
""" |
|
return -2 |
|
|
|
@property |
|
def device(self): |
|
first_param = next(iter(self.parameters())) |
|
return first_param.device |
|
|
|
@property |
|
def num_codebooks(self) -> int: |
|
return self.n_q + 1 |
|
|
|
@property |
|
def num_audio_codebooks(self) -> int: |
|
return self.n_q |
|
|
|
@property |
|
def audio_offset(self) -> int: |
|
return 1 |
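
    # Layout sketch, drawn from the properties above: a full sequence tensor is
    # `[B, K, S]` with `K == num_codebooks == 1 + n_q`; index 0 is the text stream
    # and indices `audio_offset .. audio_offset + n_q - 1` are the audio codebooks.
    #
    #   seq[:, 0]               # text tokens
    #   seq[:, 1 : 1 + n_q]     # audio tokens, one row per codebook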
|
|
|
    def _get_initial_token(self) -> torch.Tensor:
        """Returns the initial token that will be fed to the model to predict the very
        first time step. Shape is `[1, K, 1]`, text stream first, then audio codebooks."""
        device = next(iter(self.parameters())).device
|
zero = torch.full( |
|
[1, 1, 1], self.zero_token_id, device=device, dtype=torch.long |
|
) |
|
special = torch.full_like(zero, self.initial_token_id) |
|
|
|
text_special = torch.full_like(zero, self.text_initial_token_id) |
|
audio_token = special |
|
text_token = text_special |
|
audio_token = audio_token.expand(-1, self.num_audio_codebooks, -1) |
|
token = torch.cat([text_token, audio_token], dim=1) |
|
return token |
|
|
|
    def forward_text(
        self,
        sequence: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the main temporal transformer on `sequence` of shape `[B, K, S]`
        (with `K == num_codebooks`, text stream first). Returns the transformer hidden
        states `[B, S, dim]` and the text logits `[B, 1, S, text_vocab]`."""
        B, K, S = sequence.shape
|
assert ( |
|
K == self.num_codebooks |
|
), f"Sequence shape {sequence.shape} must match the number of codebooks." |
|
input_sequence = sequence |
|
input_ = None |
|
for cb_index in range(self.num_audio_codebooks): |
|
audio_emb = self.emb[cb_index]( |
|
input_sequence[:, cb_index + self.audio_offset] |
|
) |
|
input_ = audio_emb if input_ is None else input_ + audio_emb |
|
text_emb = self.text_emb(input_sequence[:, 0]) |
|
input_ = text_emb if input_ is None else input_ + text_emb |
|
transformer_out = self.transformer(input_) |
|
|
|
if self.out_norm: |
|
transformer_out = self.out_norm(transformer_out) |
|
assert isinstance(transformer_out, torch.Tensor) |
|
text_logits = self.text_linear(transformer_out) |
|
text_logits = text_logits[:, None] |
|
return transformer_out, text_logits |
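
    # Minimal shape walk-through (illustrative numbers): with B=2, n_q=16 and S=1,
    # `sequence` is [2, 17, 1]; `transformer_out` comes back as [2, 1, dim] and
    # `text_logits` as [2, 1, 1, text_card] (plus one when no existing padding id).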
|
|
|
    def forward_depformer(
        self,
        depformer_cb_index: int,
        sequence: torch.Tensor,
        transformer_out: torch.Tensor,
    ) -> torch.Tensor:
        """Run one Depformer step for codebook `depformer_cb_index`. `sequence` holds the
        previously sampled token (the text token for index 0, otherwise the audio token of
        the previous codebook) with shape `[B, 1, 1]`; returns logits `[B, 1, 1, card]`."""
        B, K, S = sequence.shape
|
assert ( |
|
K == 1 |
|
), f"Codebooks for Depformer streaming should be passed 1 by 1, got {K}." |
|
assert ( |
|
S == 1 |
|
), f"Steps for Depformer streaming should be passed 1 by 1, got {S}." |
|
        assert (
            transformer_out.shape[1] == 1
        ), "Transformer out should be for a single step."
|
last_token_input: tp.Optional[torch.Tensor] = None |
|
depformer_input = transformer_out |
|
if self.depformer_multi_linear: |
|
depformer_input = self.depformer_in[depformer_cb_index](depformer_input) |
|
else: |
|
depformer_input = self.depformer_in[0](depformer_input) |
|
if depformer_cb_index == 0: |
|
last_token_input = self.depformer_text_emb(sequence[:, 0]) |
|
else: |
|
last_token_input = self.depformer_emb[depformer_cb_index - 1]( |
|
sequence[:, 0] |
|
) |
|
        depformer_input = depformer_input + last_token_input
        assert depformer_input.shape[1] == 1
        # depformer_input is [B, 1, depformer_dim]; the streaming state of the depformer
        # keeps track of which codebook step we are at.
        dep_output = self.depformer(depformer_input)
|
logits = self.linears[depformer_cb_index](dep_output) |
|
logits = logits[:, None] |
|
assert logits.dim() == 4, logits.shape |
|
return logits |
|
|
|
|
|
@dataclass |
|
class _LMGenState: |
|
cache: torch.Tensor |
|
initial: torch.Tensor |
|
graphed_main: CUDAGraphed |
|
graphed_depth: CUDAGraphed |
|
offset: int = 0 |
|
|
|
def reset(self): |
|
self.offset = 0 |
|
|
|
|
|
class LMGen(StreamingModule[_LMGenState]): |
|
def __init__( |
|
self, |
|
lm_model: LMModel, |
|
use_sampling: bool = True, |
|
temp: float = 0.8, |
|
temp_text: float = 0.7, |
|
top_k: int = 250, |
|
top_k_text: int = 25, |
|
check: bool = False, |
|
): |
|
assert not lm_model.training, "generation shouldn't be used in training mode." |
|
super().__init__() |
|
|
|
self.lm_model = lm_model |
|
self.use_sampling = use_sampling |
|
self.temp = temp |
|
self.temp_text = temp_text |
|
self.top_k = top_k |
|
self.top_k_text = top_k_text |
|
self.check = check |
|
        self.max_delay = max(lm_model.delays)
|
self.delays_cuda = torch.tensor( |
|
lm_model.delays, device=lm_model.device, dtype=torch.long |
|
) |
|
|
|
def _init_streaming_state(self, batch_size: int) -> _LMGenState: |
|
lm_model = self.lm_model |
|
initial = lm_model._get_initial_token() |
|
cache = torch.full( |
|
(batch_size, self.lm_model.num_codebooks, self.max_delay + 2), |
|
lm_model.ungenerated_token_id, |
|
device=lm_model.device, |
|
dtype=torch.long, |
|
) |
|
|
|
        # CUDA graph capture is only possible (and only useful) on CUDA devices; the
        # per-step forwards below have fixed shapes, which is what makes capture valid.
        disable = lm_model.device.type != 'cuda'
        graphed_main = CUDAGraphed(lm_model.forward_text, disable=disable)
        graphed_depth = CUDAGraphed(self.depformer_step, disable=disable)
|
|
|
return _LMGenState(cache, initial, graphed_main, graphed_depth) |
|
|
|
@torch.no_grad() |
|
def step(self, input_tokens: torch.Tensor) -> torch.Tensor | None: |
|
state = self._streaming_state |
|
if state is None: |
|
raise RuntimeError( |
|
"You should wrap those calls with a `with lm_gen.streaming(): ...`." |
|
) |
|
lm_model = self.lm_model |
|
|
|
assert input_tokens.dim() == 3, "Shape should be [B, K, T]." |
|
B, Ki, S = input_tokens.shape |
|
assert S == 1, "Only support being given steps one by one." |
|
needed_tokens = lm_model.num_codebooks - lm_model.dep_q - 1 |
|
assert ( |
|
Ki == needed_tokens |
|
), f"We expect {needed_tokens} tokens from the user stream, got {Ki}." |
|
|
|
CT = state.cache.shape[2] |
|
for q_other in range(input_tokens.shape[1]): |
|
k = lm_model.dep_q + 1 + q_other |
|
delay = lm_model.delays[k] |
|
write_position = (state.offset + delay) % CT |
|
state.cache[:, k, write_position : write_position + 1] = input_tokens[ |
|
:, q_other |
|
] |
|
|
|
position = state.offset % CT |
|
        for k, delay in enumerate(lm_model.delays):
            # During the very first steps, the delayed positions have not received a
            # generated (or user provided) token yet, so feed the initial token instead.
            if state.offset <= delay:
                state.cache[:, k, position] = state.initial[:, k, 0]
        input_ = state.cache[:, :, position : position + 1]
|
|
|
        if self.check:
            # Check that we are not feeding in any value that has not been generated yet.
            assert not (input_ == lm_model.ungenerated_token_id).any(), (
                state.offset,
                input_,
            )
            assert (input_[:, lm_model.audio_offset :] <= lm_model.card).all(), input_
            assert (input_[:, :1] <= lm_model.text_card).all()
|
|
|
transformer_out, text_logits = state.graphed_main(input_) |
|
|
|
text_token = sample_token( |
|
text_logits.float(), |
|
self.use_sampling, |
|
self.temp_text, |
|
self.top_k_text, |
|
) |
|
assert text_token.dim() == 3, text_token.shape |
|
assert text_token.shape[2] == 1 |
|
assert text_token.shape[1] == 1, "Only one text stream supported." |
|
text_token = text_token[:, 0, 0] |
|
audio_tokens = state.graphed_depth(text_token, transformer_out) |
|
|
|
|
|
        state.offset += 1
        position = state.offset % CT
        # Store the freshly sampled tokens at the next slot of the rolling cache.
        state.cache[:, 0, position] = text_token
        state.cache[:, 1 : lm_model.dep_q + 1, position] = audio_tokens
|
|
|
if state.offset <= self.max_delay: |
|
return None |
|
B = state.cache.shape[0] |
|
gen_delays_cuda = self.delays_cuda[: lm_model.dep_q + 1] |
|
index = ( |
|
((state.offset - self.max_delay + gen_delays_cuda) % CT) |
|
.view(1, -1, 1) |
|
.expand(B, -1, 1) |
|
) |
|
out = state.cache.gather(dim=2, index=index) |
|
return out |
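
        # Reading of the gather above (illustrative values): with delays [0, 0, 1, ..., 1],
        # max_delay == 1 and CT == 3. For stream k the returned token comes from the cache
        # slot `(state.offset - max_delay + delay[k]) % CT`, i.e. each stream is read back
        # `max_delay - delay[k]` steps after it was written, which realigns all generated
        # codebooks onto a single underlying time step; this is also why `None` is returned
        # during the first `max_delay` steps.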
|
|
|
def depformer_step( |
|
self, |
|
text_token: torch.Tensor, |
|
transformer_out: torch.Tensor, |
|
) -> torch.Tensor: |
|
(B,) = text_token.shape |
|
prev_token = text_token |
|
lm_model = self.lm_model |
|
depformer_tokens: list[torch.Tensor] = [] |
|
assert not lm_model.depformer.is_streaming |
|
with lm_model.depformer.streaming(B): |
|
for cb_index in range(lm_model.dep_q): |
|
input_ = prev_token[:, None, None] |
|
logits = lm_model.forward_depformer(cb_index, input_, transformer_out) |
|
next_token = sample_token( |
|
logits.float(), |
|
self.use_sampling, |
|
self.temp, |
|
self.top_k, |
|
) |
|
assert next_token.shape == (B, 1, 1) |
|
next_token = next_token[:, 0, 0] |
|
depformer_tokens.append(next_token) |
|
prev_token = next_token |
|
|
|
assert len(depformer_tokens) == lm_model.dep_q, ( |
|
len(depformer_tokens), |
|
lm_model.dep_q, |
|
) |
|
out = torch.stack(depformer_tokens, dim=1) |
|
assert out.shape == (B, lm_model.dep_q), out.shape |
|
return out |
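
# End-to-end usage sketch (assumptions: a loaded `lm_model`, per-step user tokens of
# shape [B, num_codebooks - dep_q - 1, 1], and the streaming-context signature implied
# by `_init_streaming_state` above; illustrative only):
#
#   lm_gen = LMGen(lm_model, temp=0.8, top_k=250)
#   with lm_gen.streaming(batch_size):
#       for t in range(num_steps):
#           out = lm_gen.step(user_tokens[:, :, t : t + 1])
#           if out is None:
#               continue            # still within the initial max_delay warm-up
#           text = out[:, 0]        # [B, 1] text token for this step
#           audio = out[:, 1:]      # [B, dep_q, 1] generated audio codebooks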
|
|