amupd's picture
SpeechT5 upload
62e9ca6
raw
history blame
15.9 kB
# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------
import math
from typing import Dict, List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.distributed import fsdp_wrap
from fairseq.models import FairseqEncoder
from fairseq.modules import (
FairseqDropout,
LayerDropModuleList,
LayerNorm,
SinusoidalPositionalEmbedding,
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
from torch import Tensor
from fairseq.models.transformer import (
TransformerConfig,
)
from speechut.modules import transformer_layer, LearnedPositionalEmbedding
from speechut.modules import RelativePositionalEncoding
# rewrite name for backward compatibility in `make_generation_fast_`
def module_name_fordropout(module_name: str) -> str:
    """Return the module name to register with ``FairseqDropout``.

    The class name ``"TransformerEncoderBase"`` is reported under its legacy
    name ``"TransformerEncoder"``; every other name is returned unchanged.
    """
    legacy_name = "TransformerEncoder"
    return legacy_name if module_name == "TransformerEncoderBase" else module_name
class TransformerEncoderBase(FairseqEncoder):
    """
    Transformer encoder consisting of *cfg.encoder.layers* layers. Each layer
    is a :class:`transformer_layer.TransformerEncoderLayerBase` created by
    :meth:`build_encoder_layer`.

    Args:
        cfg: encoder configuration object; this class reads attributes such
            as ``cfg.dropout``, ``cfg.encoder.layers``,
            ``cfg.encoder.layerdrop`` and ``cfg.max_source_positions``
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
        use_rel_pos_enc (bool, optional): if True, a
            :class:`RelativePositionalEncoding` module is created and its
            first output is passed to every layer as ``pos_bias``
            (default: False)
        scaling_for_att (float, optional): forwarded unchanged to every
            encoder layer (default: 1.0)
    """

    def __init__(self, cfg, dictionary, embed_tokens, use_rel_pos_enc=False, scaling_for_att=1.0):
        self.cfg = cfg
        super().__init__(dictionary)
        self.register_buffer("version", torch.Tensor([3]))

        self.dropout_module = FairseqDropout(
            cfg.dropout, module_name=module_name_fordropout(self.__class__.__name__)
        )
        self.encoder_layerdrop = cfg.encoder.layerdrop

        embed_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = cfg.max_source_positions

        self.embed_tokens = embed_tokens

        self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim)

        self.embed_positions = (
            PositionalEmbedding(
                cfg.max_source_positions,
                embed_dim,
                self.padding_idx,
                learned=cfg.encoder.learned_pos,
            )
            if not cfg.no_token_positional_embeddings
            else None
        )

        if cfg.layernorm_embedding:
            self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export)
        else:
            self.layernorm_embedding = None

        if not cfg.adaptive_input and cfg.quant_noise.pq > 0:
            self.quant_noise = apply_quant_noise_(
                nn.Linear(embed_dim, embed_dim, bias=False),
                cfg.quant_noise.pq,
                cfg.quant_noise.pq_block_size,
            )
        else:
            self.quant_noise = None

        if self.encoder_layerdrop > 0.0:
            self.layers = LayerDropModuleList(p=self.encoder_layerdrop)
        else:
            self.layers = nn.ModuleList([])

        # These two attributes are read by build_encoder_layer(), so they must
        # be assigned before the layers are constructed.
        self.use_rel_pos_enc = use_rel_pos_enc
        self.scaling_for_att = scaling_for_att
        self.layers.extend(
            [self.build_encoder_layer(cfg) for _ in range(cfg.encoder.layers)]
        )
        self.num_layers = len(self.layers)

        if cfg.encoder.normalize_before:
            self.layer_norm = LayerNorm(embed_dim, export=cfg.export)
        else:
            self.layer_norm = None

        if self.use_rel_pos_enc:
            # First argument is the per-head dimension. The literal 160 is the
            # second positional argument of RelativePositionalEncoding --
            # presumably the maximum relative distance; TODO confirm.
            self.pos_emb = RelativePositionalEncoding(embed_dim // cfg.encoder.attention_heads, 160)

    def build_encoder_layer(self, cfg):
        """Build one encoder layer, optionally checkpointed, then FSDP-wrapped."""
        layer = transformer_layer.TransformerEncoderLayerBase(
            cfg,
            has_relative_attention_bias=self.use_rel_pos_enc,
            scaling_for_att=self.scaling_for_att,
        )
        checkpoint = cfg.checkpoint_activations
        if checkpoint:
            offload_to_cpu = cfg.offload_activations
            layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
        # if we are checkpointing, enforce that FSDP always wraps the
        # checkpointed layer, regardless of layer size
        min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
        layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
        return layer

    def forward_embedding(
        self, src_tokens, token_embedding: Optional[torch.Tensor] = None
    ):
        """Embed tokens and positions.

        Args:
            src_tokens: source tokens; used for the embedding lookup (when
                *token_embedding* is None) and for the positional embedding
            token_embedding (torch.Tensor, optional): precomputed token
                embeddings; looked up from *src_tokens* when None

        Returns:
            tuple ``(x, embed)`` where *embed* is the scaled token embedding
            and *x* is *embed* after the optional positional embedding,
            embedding layer norm, dropout and quant-noise projection.
        """
        # embed tokens and positions
        if token_embedding is None:
            token_embedding = self.embed_tokens(src_tokens)
        x = embed = self.embed_scale * token_embedding
        if self.embed_positions is not None:
            x = embed + self.embed_positions(src_tokens)
        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)
        x = self.dropout_module(x)
        if self.quant_noise is not None:
            x = self.quant_noise(x)
        return x, embed

    def forward(
        self,
        src_tokens,
        src_lengths: Optional[torch.Tensor] = None,
        return_all_hiddens: bool = False,
        token_embeddings: Optional[torch.Tensor] = None,
        uniformity_layers: Optional[List[int]] = None,
    ):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`; not used by the computation, the returned
                lengths are recomputed from *src_tokens*
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).
            token_embeddings (torch.Tensor, optional): precomputed embeddings
                default `None` will recompute embeddings
            uniformity_layers (List[int], optional): layer indices whose
                hidden state is L2-normalised over the last dimension and
                collected; 0 refers to the embedding output and ``i`` to the
                output of the i-th layer iterated. The normalised tensor
                replaces the hidden state, so it is also what the following
                layer receives.

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                  of shape `(batch, src_len, embed_dim)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
                - **uniformity_hiddens** (List[Tensor]): the normalised hidden
                  states selected by *uniformity_layers*.
                - **src_tokens**: always an empty list.
                - **src_lengths** (Tensor): number of non-padding tokens per
                  sentence, int32, of shape `(batch, 1)`
        """
        return self.forward_scriptable(
            src_tokens, src_lengths, return_all_hiddens, token_embeddings, uniformity_layers
        )

    # TorchScript doesn't support super() method so that the scriptable Subclass
    # can't access the base class model in Torchscript.
    # Current workaround is to add a helper function with different name and
    # call the helper function from scriptable Subclass.
    def forward_scriptable(
        self,
        src_tokens,
        src_lengths: Optional[torch.Tensor] = None,
        return_all_hiddens: bool = False,
        token_embeddings: Optional[torch.Tensor] = None,
        uniformity_layers: Optional[List[int]] = None,
    ):
        """
        Implementation of :meth:`forward`; takes the same arguments and
        returns the same dictionary.

        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`; not used by the computation, the returned
                lengths are recomputed from *src_tokens*
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).
            token_embeddings (torch.Tensor, optional): precomputed embeddings
                default `None` will recompute embeddings
            uniformity_layers (List[int], optional): layer indices whose
                hidden state is L2-normalised over the last dimension and
                collected; 0 refers to the embedding output and ``i`` to the
                output of the i-th layer iterated. The normalised tensor
                replaces the hidden state, so it is also what the following
                layer receives.

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                  of shape `(batch, src_len, embed_dim)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
                - **uniformity_hiddens** (List[Tensor]): the normalised hidden
                  states selected by *uniformity_layers*.
                - **src_tokens**: always an empty list.
                - **src_lengths** (Tensor): number of non-padding tokens per
                  sentence, int32, of shape `(batch, 1)`
        """
        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        has_pads = src_tokens.device.type == "xla" or encoder_padding_mask.any()

        x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)

        # account for padding while computing the representation
        if has_pads:
            x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        if self.use_rel_pos_enc:
            x_len = x.shape[0]
            # Build the index directly on the target device instead of
            # creating it on the default device and copying it over.
            pos_seq = torch.arange(x_len, dtype=torch.long, device=x.device)
            # (T x T) matrix of signed offsets: entry [i, j] is i - j.
            pos_seq = pos_seq[:, None] - pos_seq[None, :]
            # Only the first output is consumed by the layers.
            pos_k, _ = self.pos_emb(pos_seq)
        else:
            pos_k = None

        encoder_states = []
        uniformity_hiddens = []

        if return_all_hiddens:
            encoder_states.append(x)

        if uniformity_layers is not None and 0 in uniformity_layers:
            # Normalise in float32, then cast back to the working dtype.
            x = F.normalize(x.float(), dim=-1).type_as(x)
            uniformity_hiddens.append(x)

        # encoder layers
        for i, layer in enumerate(self.layers):
            x = layer(
                x,
                encoder_padding_mask=encoder_padding_mask if has_pads else None,
                pos_bias=pos_k,
            )
            if uniformity_layers is not None and i + 1 in uniformity_layers:
                x = F.normalize(x.float(), dim=-1).type_as(x)
                uniformity_hiddens.append(x)
            if return_all_hiddens:
                assert encoder_states is not None
                encoder_states.append(x)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in
        # `forward` so we use a dictionary instead.
        # TorchScript does not support mixed values so the values are all lists.
        # The empty list is equivalent to None.
        src_lengths = (
            src_tokens.ne(self.padding_idx)
            .sum(dim=1, dtype=torch.int32)
            .reshape(-1, 1)
            .contiguous()
        )
        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [encoder_padding_mask],  # B x T
            "encoder_embedding": [encoder_embedding],  # B x T x C
            "encoder_states": encoder_states,  # List[T x B x C]
            "uniformity_hiddens": uniformity_hiddens,  # List[T x B x C]
            "src_tokens": [],
            "src_lengths": [src_lengths],
        }

    @torch.jit.export
    def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
        """
        Reorder encoder output according to *new_order*.

        The input dictionary and the lists it holds are left unmodified; every
        entry of the result is a freshly built list.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*. The
            ``uniformity_hiddens`` entry is always present in the result and
            is empty when the input does not carry that key.
        """
        if len(encoder_out["encoder_out"]) == 0:
            new_encoder_out = []
        else:
            new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)]

        if len(encoder_out["encoder_padding_mask"]) == 0:
            new_encoder_padding_mask = []
        else:
            new_encoder_padding_mask = [
                encoder_out["encoder_padding_mask"][0].index_select(0, new_order)
            ]

        if len(encoder_out["encoder_embedding"]) == 0:
            new_encoder_embedding = []
        else:
            new_encoder_embedding = [
                encoder_out["encoder_embedding"][0].index_select(0, new_order)
            ]

        if len(encoder_out["src_tokens"]) == 0:
            src_tokens = []
        else:
            src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)]

        if len(encoder_out["src_lengths"]) == 0:
            src_lengths = []
        else:
            src_lengths = [(encoder_out["src_lengths"][0]).index_select(0, new_order)]

        # Build a new list rather than overwriting the caller's list in place,
        # which would leave the input dict only partially reordered.
        encoder_states = [
            state.index_select(1, new_order) for state in encoder_out["encoder_states"]
        ]

        # forward() also emits "uniformity_hiddens" (T x B x C tensors); keep
        # it so the schema survives reordering. Dicts produced elsewhere may
        # lack the key, hence the membership test.
        uniformity_hiddens: List[Tensor] = []
        if "uniformity_hiddens" in encoder_out:
            uniformity_hiddens = [
                hidden.index_select(1, new_order)
                for hidden in encoder_out["uniformity_hiddens"]
            ]

        return {
            "encoder_out": new_encoder_out,  # T x B x C
            "encoder_padding_mask": new_encoder_padding_mask,  # B x T
            "encoder_embedding": new_encoder_embedding,  # B x T x C
            "encoder_states": encoder_states,  # List[T x B x C]
            "uniformity_hiddens": uniformity_hiddens,  # List[T x B x C]
            "src_tokens": src_tokens,  # B x T
            "src_lengths": src_lengths,  # B x 1
        }

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        if self.embed_positions is None:
            return self.max_source_positions
        return min(self.max_source_positions, self.embed_positions.max_positions)

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            weights_key = "{}.embed_positions.weights".format(name)
            if weights_key in state_dict:
                print("deleting {0}".format(weights_key))
                del state_dict[weights_key]
            state_dict[
                "{}.embed_positions._float_tensor".format(name)
            ] = torch.FloatTensor(1)
        # Index explicitly instead of iterating self.layers: it may be a
        # LayerDropModuleList, whose iteration presumably skips layers at
        # random -- TODO confirm.
        for i in range(self.num_layers):
            # update layer norms
            self.layers[i].upgrade_state_dict_named(
                state_dict, "{}.layers.{}".format(name, i)
            )

        version_key = "{}.version".format(name)
        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])
        return state_dict
class TransformerEncoder(TransformerEncoderBase):
    """Encoder configured from a legacy ``args`` namespace.

    The namespace is converted with ``TransformerConfig.from_namespace`` and
    handed to :class:`TransformerEncoderBase`. ``use_rel_pos_enc`` and
    ``scaling_for_att`` are read from the namespace when present and default
    to ``False`` and ``1.0`` respectively.
    """

    def __init__(self, args, dictionary, embed_tokens):
        # Keep the raw namespace available on the instance.
        self.args = args
        cfg = TransformerConfig.from_namespace(args)
        rel_pos_enabled = getattr(args, "use_rel_pos_enc", False)
        attention_scale = getattr(args, "scaling_for_att", 1.0)
        super().__init__(
            cfg,
            dictionary,
            embed_tokens,
            use_rel_pos_enc=rel_pos_enabled,
            scaling_for_att=attention_scale,
        )

    def build_encoder_layer(self, args):
        """Convert *args* with ``TransformerConfig.from_namespace`` and build a layer."""
        cfg = TransformerConfig.from_namespace(args)
        return super().build_encoder_layer(cfg)
def PositionalEmbedding(
    num_embeddings: int,
    embedding_dim: int,
    padding_idx: Optional[int],
    learned: bool = False,
):
    """Build a learned or sinusoidal positional-embedding module.

    Args:
        num_embeddings: number of positions to support, before the padding
            offset is added
        embedding_dim: size of each positional embedding vector
        padding_idx: index of the padding symbol, or ``None``. When given,
            the table is enlarged by ``padding_idx + 1`` rows; when ``None``
            no offset is applied, in either branch.
        learned: build a ``LearnedPositionalEmbedding`` when True, otherwise
            a ``SinusoidalPositionalEmbedding`` (default: False)

    Returns:
        The constructed positional-embedding module.
    """
    # if padding_idx is specified then offset the embedding ids by
    # this index and adjust num_embeddings appropriately
    # TODO: The right place for this offset would be inside
    # LearnedPositionalEmbedding. Move this there for a cleaner implementation.
    # Computed once for both branches: the sinusoidal branch used to add
    # padding_idx unconditionally and raised TypeError for padding_idx=None.
    if padding_idx is not None:
        num_embeddings = num_embeddings + padding_idx + 1

    if learned:
        m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx)
        nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
        if padding_idx is not None:
            # The padding position carries a zero embedding.
            nn.init.constant_(m.weight[padding_idx], 0)
    else:
        m = SinusoidalPositionalEmbedding(
            embedding_dim,
            padding_idx,
            init_size=num_embeddings,
        )
    return m