# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from functools import partial
import logging
import typing as tp
import torch
from torch import nn
from moshi.utils.sampling import sample_token
from moshi.utils.compile import CUDAGraphed
from moshi.modules.streaming import StreamingContainer, StreamingModule
from moshi.modules.transformer import (
StreamingTransformer,
create_norm_fn,
)
logger = logging.getLogger(__name__)
class ScaledEmbedding(nn.Embedding):
"""Boost learning rate for embeddings (with `scale`).
Args:
norm (bool): if True, uses a layer norm after the embedding.
zero_idx (int): special value indicating that the output should be exactly 0.
"""
def __init__(self, *args, norm: bool = False, zero_idx: int = -1, **kwargs):
super().__init__(*args, **kwargs)
self.norm = None
if norm:
self.norm = create_norm_fn("layer_norm", self.embedding_dim)
assert zero_idx < 0, "Please use negative values for the zero_idx."
self.zero_idx = zero_idx
def forward(self, input, *args, **kwargs):
is_zero = input == self.zero_idx
zero = torch.zeros(1, dtype=input.dtype, device=input.device)
input = input.clamp(min=0)
y = super().forward(input, *args, **kwargs)
if self.norm is not None:
y = self.norm(y)
y = torch.where(is_zero[..., None], zero, y)
return y
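# A minimal usage sketch for the zero_idx behaviour (toy sizes, not part of the model):
#
#     emb = ScaledEmbedding(16, 4, zero_idx=-1)
#     out = emb(torch.tensor([[0, -1, 3]]))
#     # out[0, 1] is exactly zero, the other positions are regular embeddings.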
class LMModel(StreamingContainer):
"""Transformer-based language model on multiple streams of codes.
Args:
n_q (int): Number of parallel streams to model as input.
dep_q (int): Number of parallel streams to model in the depformer.
card (int): Cardinality, vocabulary size.
text_card (int): Cardinality of the text vocabulary.
dim (int): Dimension of the transformer encoder.
num_heads (int): Number of heads for the transformer encoder.
hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder.
norm (str): Normalization method.
norm_emb (bool): Whether to normalize embeddings.
bias_proj (bool): Use bias for output projections.
        depformer_*: parameters used for the Depformer Transformer; all the others will be shared with the main transformer.
depformer_multi_linear (bool): if True, uses one linear layer per codebook to project the
output of the main transformer to the Depformer latent space.
        depformer_dim_feedforward (int | list[int] | None): If None, defaults to hidden_scale * depformer_dim.
        existing_text_padding_id (int or None): if provided, id of the text padding token already
            present in the text tokenizer, so that the initial text token and the padding token
            differ; if None, the extra appended token is used for both.
**kwargs: Additional parameters for the transformer encoder.
"""
def __init__(
self,
delays: tp.List[int] = [0],
n_q: int = 8,
dep_q: int = 8,
card: int = 1024,
text_card: int = 32000,
dim: int = 128,
num_heads: int = 8,
hidden_scale: int = 4,
norm: str = "layer_norm",
norm_emb: bool = False,
bias_proj: bool = False,
depformer_dim: int = 256,
depformer_dim_feedforward: int | list[int] | None = None,
depformer_multi_linear: bool = False,
depformer_weights_per_step: bool = False,
depformer_pos_emb: str = "sin",
existing_text_padding_id: tp.Optional[int] = None,
context: tp.Optional[int] = None,
device=None,
dtype=None,
**kwargs,
):
super().__init__()
self.n_q = n_q
self.dep_q = dep_q
self.card = card
self.text_card = text_card
assert len(delays) == self.num_codebooks, "unexpected number of delays"
self.delays = delays
self.dim = dim
self.existing_text_padding_id = existing_text_padding_id
self.context = context
kwargs["context"] = context
EmbeddingFactory = partial(
ScaledEmbedding,
norm=norm_emb,
device=device,
dtype=dtype,
zero_idx=self.zero_token_id,
)
self.emb = nn.ModuleList(
[EmbeddingFactory(self.card + 1, dim) for _ in range(n_q)]
)
# Text card + padding token (if not in the original tokenizer)
extra_text = self.existing_text_padding_id is None
# Unlike for audio, here we authorize the model to output the special token.
self.text_emb = EmbeddingFactory(text_card + 1, dim)
self.text_linear = nn.Linear(dim, text_card + extra_text, bias=bias_proj)
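        # In contrast, the audio output projections (`self.linears` below) only produce
        # `card` logits: the extra audio token (index `card`) is input-only and used as
        # the initial token, whereas the text projection may also emit the appended
        # padding token when `extra_text` is True.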
depformer_prefix = "depformer_"
main_kwargs = {
k: v for k, v in kwargs.items() if not k.startswith(depformer_prefix)
}
self.transformer = StreamingTransformer(
d_model=dim,
num_heads=num_heads,
dim_feedforward=int(hidden_scale * dim),
norm=norm,
device=device,
dtype=dtype,
**main_kwargs,
)
self.out_norm = create_norm_fn(norm, dim)
self.depformer_multi_linear = depformer_multi_linear
kwargs_dep = main_kwargs.copy()
kwargs_dep.update(
{
k.removeprefix(depformer_prefix): v
for k, v in kwargs.items()
if k.startswith(depformer_prefix)
}
)
kwargs_dep["positional_embedding"] = depformer_pos_emb
kwargs_dep["context"] = None
if depformer_weights_per_step:
kwargs_dep["weights_per_step"] = dep_q
if depformer_multi_linear:
            # One linear layer per codebook to project different information from the main model.
self.depformer_in = nn.ModuleList(
[nn.Linear(dim, depformer_dim, bias=False) for _ in range(dep_q)]
)
else:
self.depformer_in = nn.ModuleList(
[nn.Linear(dim, depformer_dim, bias=False)]
)
# Only using up to dep_q - 1 because the last codebook is never an input to Depformer.
self.depformer_emb = nn.ModuleList(
[EmbeddingFactory(self.card + 1, depformer_dim) for _ in range(dep_q - 1)]
)
self.depformer_text_emb = EmbeddingFactory(text_card + 1, depformer_dim)
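        # Inside the Depformer, step 0 is conditioned on the sampled text token (via
        # `depformer_text_emb`) and step k > 0 on the audio token sampled at step k - 1
        # (via `depformer_emb[k - 1]`); the last audio codebook is only ever an output,
        # hence the `dep_q - 1` embeddings above.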
if depformer_dim_feedforward is None:
depformer_dim_feedforward = int(hidden_scale * depformer_dim)
self.depformer = StreamingTransformer(
d_model=depformer_dim,
dim_feedforward=depformer_dim_feedforward,
norm=norm,
device=device,
dtype=dtype,
**kwargs_dep,
)
self.depformer.set_streaming_propagate(False)
dim = depformer_dim # we will directly apply the next linears to the output of the Depformer.
self.linears = nn.ModuleList(
[nn.Linear(dim, self.card, bias=bias_proj) for _ in range(dep_q)]
)
@property
def initial_token_id(self) -> int:
"""Token id for the start of sequence (audio)."""
return self.card
@property
def text_initial_token_id(self) -> int:
"""Token id for the start of sequence (text)."""
return self.text_card
@property
def text_padding_token_id(self) -> int:
"""Token id for text padding."""
if self.existing_text_padding_id is None:
return self.text_card
else:
return self.existing_text_padding_id
@property
def end_of_text_padding_id(self) -> int:
"""Token id for optionally marking the last padding step for a word."""
return 0
@property
def zero_token_id(self) -> int:
"""Special value in the input tokens, indicating that no sampling should
happen for that value, and no input should be given to the model."""
return -1
@property
def ungenerated_token_id(self) -> int:
"""Special value that can be provided in the prompt to indicate that this specific
        value should be predicted and sampled. This allows for partial teacher forcing, generating
        one modality while keeping the other one fixed.
"""
return -2
@property
def device(self):
first_param = next(iter(self.parameters()))
return first_param.device
@property
def num_codebooks(self) -> int:
return self.n_q + 1
@property
def num_audio_codebooks(self) -> int:
return self.n_q
@property
def audio_offset(self) -> int:
return 1
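    # Stream layout used throughout this file: index 0 along the K axis is the text
    # stream and indices 1 .. n_q are the audio codebooks, hence `audio_offset == 1`
    # and `num_codebooks == n_q + 1`.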
def _get_initial_token(self) -> torch.Tensor:
# Returns the initial token that will be fed to the model to predict the very first timestep.
# The output shape will be [B, K, 1].
device = next(iter(self.parameters())).device
zero = torch.full(
[1, 1, 1], self.zero_token_id, device=device, dtype=torch.long
)
special = torch.full_like(zero, self.initial_token_id)
text_special = torch.full_like(zero, self.text_initial_token_id)
audio_token = special
text_token = text_special
audio_token = audio_token.expand(-1, self.num_audio_codebooks, -1)
token = torch.cat([text_token, audio_token], dim=1)
return token
def forward_text(
self,
sequence: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
B, K, S = sequence.shape
assert (
K == self.num_codebooks
), f"Sequence shape {sequence.shape} must match the number of codebooks."
input_sequence = sequence
input_ = None
for cb_index in range(self.num_audio_codebooks):
audio_emb = self.emb[cb_index](
input_sequence[:, cb_index + self.audio_offset]
)
input_ = audio_emb if input_ is None else input_ + audio_emb
text_emb = self.text_emb(input_sequence[:, 0])
input_ = text_emb if input_ is None else input_ + text_emb
transformer_out = self.transformer(input_)
if self.out_norm:
transformer_out = self.out_norm(transformer_out)
assert isinstance(transformer_out, torch.Tensor)
text_logits = self.text_linear(transformer_out)
text_logits = text_logits[:, None]
return transformer_out, text_logits
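    # Shape summary for `forward_text`: `sequence` is [B, K, S] token ids with
    # K == num_codebooks (S == 1 in streaming generation); it returns the transformer
    # output of shape [B, S, dim] and text logits of shape [B, 1, S, text_card (+1 when
    # an extra padding token was appended)].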
def forward_depformer(
self,
depformer_cb_index: int,
sequence: torch.Tensor,
transformer_out: torch.Tensor,
) -> torch.Tensor:
B, K, S = sequence.shape
assert (
K == 1
), f"Codebooks for Depformer streaming should be passed 1 by 1, got {K}."
assert (
S == 1
), f"Steps for Depformer streaming should be passed 1 by 1, got {S}."
assert (
transformer_out.shape[1] == 1
), "Transformer out should be a for a single step."
last_token_input: tp.Optional[torch.Tensor] = None
depformer_input = transformer_out
if self.depformer_multi_linear:
depformer_input = self.depformer_in[depformer_cb_index](depformer_input)
else:
depformer_input = self.depformer_in[0](depformer_input)
if depformer_cb_index == 0:
last_token_input = self.depformer_text_emb(sequence[:, 0])
else:
last_token_input = self.depformer_emb[depformer_cb_index - 1](
sequence[:, 0]
)
depformer_input = depformer_input + last_token_input
assert depformer_input.shape[1] == 1
# depformer_input is [B, 1, depformer_dim].
# The streaming state of the depformer ensures that the proper layer is run.
dep_output = self.depformer(depformer_input)
logits = self.linears[depformer_cb_index](dep_output)
logits = logits[:, None]
assert logits.dim() == 4, logits.shape # [B, Ka, S, card]
return logits
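    # `forward_depformer` is meant to be called `dep_q` times per time step inside an
    # active `self.depformer.streaming(...)` context (see `LMGen.depformer_step` below);
    # the Depformer streaming state tracks the current codebook index so the proper
    # per-step weights and caches are used, and it is kept independent from the outer
    # model thanks to `set_streaming_propagate(False)` in `__init__`.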
@dataclass
class _LMGenState:
cache: torch.Tensor
initial: torch.Tensor
graphed_main: CUDAGraphed
graphed_depth: CUDAGraphed
offset: int = 0
def reset(self):
self.offset = 0
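# The cache in `_LMGenState` is a circular buffer over time of length max_delay + 2:
# user-provided streams are written `delays[k]` steps ahead of the current offset,
# generated streams are written at the next offset, and the final gather in `step`
# reads each stream back with its own delay so that the returned frame is aligned.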
class LMGen(StreamingModule[_LMGenState]):
def __init__(
self,
lm_model: LMModel,
use_sampling: bool = True,
temp: float = 0.8,
temp_text: float = 0.7,
top_k: int = 250,
top_k_text: int = 25,
check: bool = False,
):
assert not lm_model.training, "generation shouldn't be used in training mode."
super().__init__()
self.lm_model = lm_model
self.use_sampling = use_sampling
self.temp = temp
self.temp_text = temp_text
self.top_k = top_k
self.top_k_text = top_k_text
self.check = check
self.max_delay = max(
lm_model.delays
) # with delays, we need to generate a few more time steps.
self.delays_cuda = torch.tensor(
lm_model.delays, device=lm_model.device, dtype=torch.long
)
def _init_streaming_state(self, batch_size: int) -> _LMGenState:
lm_model = self.lm_model
initial = lm_model._get_initial_token()
cache = torch.full(
(batch_size, self.lm_model.num_codebooks, self.max_delay + 2),
lm_model.ungenerated_token_id,
device=lm_model.device,
dtype=torch.long,
)
disable = lm_model.device.type != 'cuda'
graphed_main = CUDAGraphed(lm_model.forward_text, disable=disable)
graphed_depth = CUDAGraphed(self.depformer_step, disable=disable)
return _LMGenState(cache, initial, graphed_main, graphed_depth)
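    # The graphed wrappers are created with `disable=True` on non-CUDA devices, in
    # which case they are expected to reduce to plain function calls.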
@torch.no_grad()
def step(self, input_tokens: torch.Tensor) -> torch.Tensor | None:
state = self._streaming_state
if state is None:
raise RuntimeError(
"You should wrap those calls with a `with lm_gen.streaming(): ...`."
)
lm_model = self.lm_model
assert input_tokens.dim() == 3, "Shape should be [B, K, T]."
B, Ki, S = input_tokens.shape
assert S == 1, "Only support being given steps one by one."
needed_tokens = lm_model.num_codebooks - lm_model.dep_q - 1
assert (
Ki == needed_tokens
), f"We expect {needed_tokens} tokens from the user stream, got {Ki}."
CT = state.cache.shape[2]
for q_other in range(input_tokens.shape[1]):
k = lm_model.dep_q + 1 + q_other
delay = lm_model.delays[k]
write_position = (state.offset + delay) % CT
state.cache[:, k, write_position : write_position + 1] = input_tokens[
:, q_other
]
position = state.offset % CT
for k, delay in enumerate(lm_model.delays):
            # Only at the very beginning, we extend the initial token for the acoustic
            # tokens that are delayed, and thus have no good value to take.
if state.offset <= delay:
state.cache[:, k, position] = state.initial[:, k, 0]
input_ = state.cache[:, :, position : position + 1]
if self.check:
# Check that we are not feeding in any value that is not generated yet.
assert not (input_ == lm_model.ungenerated_token_id).any(), (
state.offset,
input_,
)
assert (input_[:, lm_model.audio_offset :] <= lm_model.card).all(), input_
assert (input_[:, :1] <= lm_model.text_card).all()
transformer_out, text_logits = state.graphed_main(input_)
# Shape of text_logits should be [B, K_text=1, T=1, Card_text]
text_token = sample_token(
text_logits.float(),
self.use_sampling,
self.temp_text,
self.top_k_text,
)
assert text_token.dim() == 3, text_token.shape
assert text_token.shape[2] == 1
assert text_token.shape[1] == 1, "Only one text stream supported."
text_token = text_token[:, 0, 0] # shape is [B]
audio_tokens = state.graphed_depth(text_token, transformer_out)
        # Advance the offset and write the newly generated text and audio tokens into the cache.
state.offset += 1
position = state.offset % CT
state.cache[:, 0, position] = text_token
state.cache[:, 1 : lm_model.dep_q + 1, position] = audio_tokens
if state.offset <= self.max_delay:
return None
B = state.cache.shape[0]
gen_delays_cuda = self.delays_cuda[: lm_model.dep_q + 1]
index = (
((state.offset - self.max_delay + gen_delays_cuda) % CT)
.view(1, -1, 1)
.expand(B, -1, 1)
)
out = state.cache.gather(dim=2, index=index)
return out
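    # `step` consumes one frame of user tokens of shape [B, K_user, 1] (with
    # K_user == num_codebooks - dep_q - 1) and returns None for the first `max_delay`
    # steps; afterwards it returns a [B, dep_q + 1, 1] tensor whose first row is the
    # text token and whose remaining rows are the generated audio codebooks, each
    # re-aligned according to its delay.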
def depformer_step(
self,
text_token: torch.Tensor,
transformer_out: torch.Tensor,
) -> torch.Tensor:
(B,) = text_token.shape
prev_token = text_token
lm_model = self.lm_model
depformer_tokens: list[torch.Tensor] = []
assert not lm_model.depformer.is_streaming
with lm_model.depformer.streaming(B):
for cb_index in range(lm_model.dep_q):
input_ = prev_token[:, None, None]
logits = lm_model.forward_depformer(cb_index, input_, transformer_out)
next_token = sample_token(
logits.float(),
self.use_sampling,
self.temp,
self.top_k,
)
assert next_token.shape == (B, 1, 1)
next_token = next_token[:, 0, 0] # shape is B
depformer_tokens.append(next_token)
prev_token = next_token
assert len(depformer_tokens) == lm_model.dep_q, (
len(depformer_tokens),
lm_model.dep_q,
)
out = torch.stack(depformer_tokens, dim=1)
assert out.shape == (B, lm_model.dep_q), out.shape
return out
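# A minimal generation sketch (the way `lm_model` is loaded and the `user_frames`
# iterable are hypothetical, not part of this file):
#
#     lm_gen = LMGen(lm_model, temp=0.8, top_k=250)
#     with lm_gen.streaming(1):
#         for user_frame in user_frames:   # each frame is [1, K_user, 1]
#             out = lm_gen.step(user_frame)
#             if out is None:              # still inside the initial delay window
#                 continue
#             text_token = out[:, 0]       # [1, 1]
#             audio_tokens = out[:, 1:]    # [1, dep_q, 1]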