# ----------------------------------------------------------------------------
# SpeechLM: Enhanced Speech Pre-Training with Unpaired Textual Data (https://arxiv.org/abs/2209.15329)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechLM
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
import torch
import numpy as np
import logging
from pathlib import Path
from argparse import Namespace

from fairseq.tasks import LegacyFairseqTask, register_task
from fairseq.data import Dictionary, encoders
from fairseq.data.audio.speech_to_text_joint_dataset import S2TJointDataConfig

from speechlm.unit_generator import NonAutoregressiveUnitGenerator
from speechlm.data.text_to_unit_dataset import Text2UnitDatasetCreator

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


@register_task("fast_text_to_unit")  # registration name inferred from the class name
class FastTextToUnitTask(LegacyFairseqTask):
    @staticmethod
    def add_args(parser):
        parser.add_argument("data", help="manifest root path")
        parser.add_argument(
            "--config-yaml",
            type=str,
            default="config.yaml",
            help="Configuration YAML filename (under manifest root)",
        )
        parser.add_argument(
            "--max-source-positions",
            default=2048,
            type=int,
            metavar="N",
            help="max number of tokens in the source sequence",
        )
        parser.add_argument(
            "--max-target-positions",
            default=1024,
            type=int,
            metavar="N",
            help="max number of tokens in the target sequence",
        )
        parser.add_argument("--n-frames-per-step", type=int, default=1)
        parser.add_argument("--eos-prob-threshold", type=float, default=0.5)
        parser.add_argument("--eval-inference", action="store_true")
        parser.add_argument("--eval-tb-nsample", type=int, default=8)
        parser.add_argument("--vocoder", type=str, default="griffin_lim")
        parser.add_argument("--spec-bwd-max-iter", type=int, default=8)

    def __init__(self, args, src_dict, tgt_dict):
        super().__init__(args)
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict
        self.data_cfg = S2TJointDataConfig(Path(args.data) / args.config_yaml)
        self.speaker_to_id = self._get_speaker_to_id()

    @classmethod
    def setup_task(cls, args, **kwargs):
        data_cfg = S2TJointDataConfig(Path(args.data) / args.config_yaml)
        src_dict_path = Path(args.data) / data_cfg.src_vocab_filename
        if not src_dict_path.is_file():
            raise FileNotFoundError(f"Dict not found: {src_dict_path.as_posix()}")
        src_dict = Dictionary.load(src_dict_path.as_posix())
        logger.info(
            f"Source dictionary size ({data_cfg.src_vocab_filename}): {len(src_dict):,}"
        )
        tgt_dict_path = Path(args.data) / data_cfg.vocab_filename
        if not tgt_dict_path.is_file():
            raise FileNotFoundError(f"Dict not found: {tgt_dict_path.as_posix()}")
        tgt_dict = Dictionary.load(tgt_dict_path.as_posix())
        logger.info(
            f"Target dictionary size ({data_cfg.vocab_filename}): {len(tgt_dict):,}"
        )
        if getattr(args, "train_subset", None) is not None:
            if not all(s.startswith("train") for s in args.train_subset.split(",")):
                raise ValueError('Train splits should be named like "train*".')
        return cls(args, src_dict, tgt_dict)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        is_train_split = split.startswith("train")
        pre_tokenizer = self.build_tokenizer(self.args)
        bpe_tokenizer = self.build_bpe(self.args)
        self.datasets[split] = Text2UnitDatasetCreator.from_tsv(
            self.args.data,
            self.data_cfg,
            split,
            self.src_dict,
            pre_tokenizer,
            bpe_tokenizer,
            is_train_split=is_train_split,
            epoch=epoch,
            seed=self.args.seed,
            n_frames_per_step=self.args.n_frames_per_step,
            speaker_to_id=self.speaker_to_id,
        )

    @property
    def target_dictionary(self):
        return self.tgt_dict

    @property
    def source_dictionary(self):
        return self.src_dict

    def max_positions(self):
        return self.args.max_source_positions, self.args.max_target_positions

    def _get_speaker_to_id(self):
        speaker_to_id = None
        speaker_set_filename = self.data_cfg.config.get("speaker_set_filename")
        if speaker_set_filename is not None:
            speaker_set_path = Path(self.args.data) / speaker_set_filename
            with open(speaker_set_path) as f:
                # one speaker name per line -> consecutive integer ids
                speaker_to_id = {r.strip(): i for i, r in enumerate(f)}
        return speaker_to_id
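
    # Sketch of the expected speaker-set file, assuming the data config sets
    # speaker_set_filename (e.g. to "speakers.txt", one speaker name per line):
    #
    #   spk_0001
    #   spk_0002
    #
    # which _get_speaker_to_id maps to {"spk_0001": 0, "spk_0002": 1}.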

    @classmethod
    def get_speaker_embeddings(cls, args):
        # Used by the FastText2UnitModel: instead of an nn.Embedding over speaker ids,
        # we default to x-vectors extracted in advance. This makes it possible to vary
        # the speaker information when generating units from text.
        if args.speaker_to_id is not None:
            embed_speaker = torch.nn.Embedding(
                len(args.speaker_to_id), args.speaker_embed_dim
            )
        elif args.speaker_embedding_type == "x-vector":
            # return LayerNorm(args.speaker_embed_dim)
            return lambda x: x.unsqueeze(1)
        elif args.speaker_embedding_type == "i-vector":
            # return LayerNorm(args.speaker_embed_dim)
            return lambda x: x
        else:
            embed_speaker = None
        return embed_speaker
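
    # Minimal shape sketch for the x-vector branch above, assuming
    # args.speaker_to_id is None and args.speaker_embedding_type == "x-vector":
    #
    #   embed_speaker = FastTextToUnitTask.get_speaker_embeddings(args)
    #   spk = torch.randn(4, 256)       # placeholder x-vectors, (batch, speaker_embed_dim)
    #   spk = embed_speaker(spk)        # -> (4, 1, 256), a time axis the model can
    #                                   #    broadcast/expand over the encoder output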

    def build_model(self, cfg):
        cfg.pitch_min = self.data_cfg.config["features"].get("pitch_min", None)
        cfg.pitch_max = self.data_cfg.config["features"].get("pitch_max", None)
        cfg.energy_min = self.data_cfg.config["features"].get("energy_min", None)
        cfg.energy_max = self.data_cfg.config["features"].get("energy_max", None)
        cfg.speaker_to_id = self.speaker_to_id
        cfg.speaker_embedding_type = self.data_cfg.config.get("speaker_embedding_type", None)
        model = super().build_model(cfg)
        self.generator = None
        if getattr(cfg, "eval_inference", False):
            self.generator = self.build_generator([model], cfg)
        return model

    def build_generator(self, models, cfg, vocoder=None, **unused):
        model = models[0]
        assert getattr(model, "NON_AUTOREGRESSIVE") is True
        return NonAutoregressiveUnitGenerator(model, vocoder, self.data_cfg)

    def build_tokenizer(self, args):
        logger.info(f"pre-tokenizer: {self.data_cfg.pre_tokenizer}")
        return encoders.build_tokenizer(Namespace(**self.data_cfg.pre_tokenizer))

    def build_bpe(self, args):
        logger.info(f"tokenizer: {self.data_cfg.bpe_tokenizer}")
        return encoders.build_bpe(Namespace(**self.data_cfg.bpe_tokenizer))
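

# Minimal usage sketch (placeholder paths; assumes the manifest root contains
# config.yaml plus the source/target vocab files it references):
#
#   from argparse import Namespace
#   args = Namespace(
#       data="/path/to/manifest_root",
#       config_yaml="config.yaml",
#       train_subset="train",
#       seed=1,
#       n_frames_per_step=1,
#   )
#   task = FastTextToUnitTask.setup_task(args)
#   task.load_dataset("train")
#   print(len(task.datasets["train"]),
#         len(task.source_dictionary),
#         len(task.target_dictionary))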