#!/usr/bin/env python3
"""
Recipe for "direct" (speech -> semantics) SLU with ASR-based transfer learning.
We encode input waveforms into features using a model trained on LibriSpeech,
then feed the features into a seq2seq model to map them to semantics.
(Adapted from the LibriSpeech seq2seq ASR recipe written by Ju-Chieh Chou, Mirco Ravanelli, Abdel Heba, and Peter Plantinga.)
Run using:
> python train.py hparams/train.yaml
Authors
* Loren Lugosch 2020
* Mirco Ravanelli 2020
"""
import logging
import sys

import torch
from hyperpyyaml import load_hyperpyyaml

import speechbrain as sb
from speechbrain.utils.distributed import run_on_main
logger = logging.getLogger(__name__)
# Define training procedure
class SLU(sb.Brain):
def compute_forward(self, batch, stage):
"""Forward computations from the waveform batches to the output probabilities."""
batch = batch.to(self.device)
wavs, wav_lens = batch.sig
tokens_bos, tokens_bos_lens = batch.tokens_bos
# Add augmentation if specified
if stage == sb.Stage.TRAIN:
# Applying the augmentation pipeline
wavs_aug_tot = []
wavs_aug_tot.append(wavs)
            for augment in self.hparams.augment_pipeline:
                # Apply the augmentation
wavs_aug = augment(wavs, wav_lens)
                # Manage length changes (e.g., from speed perturbation):
                # trim or zero-pad back to the original length
if wavs_aug.shape[1] > wavs.shape[1]:
wavs_aug = wavs_aug[:, 0 : wavs.shape[1]]
else:
zero_sig = torch.zeros_like(wavs)
zero_sig[:, 0 : wavs_aug.shape[1]] = wavs_aug
wavs_aug = zero_sig
wavs_aug_tot.append(wavs_aug)
wavs = torch.cat(wavs_aug_tot, dim=0)
self.n_augment = len(wavs_aug_tot)
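            # Replicate lengths and BOS-prefixed targets once per augmented
            # copy so the labels line up with the concatenated waveforms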
wav_lens = torch.cat([wav_lens] * self.n_augment)
tokens_bos = torch.cat([tokens_bos] * self.n_augment)
# ASR encoder forward pass
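        # The pretrained ASR encoder is kept frozen: no_grad() avoids
        # building a graph through it, so only the SLU layers are trained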
with torch.no_grad():
ASR_encoder_out = self.hparams.asr_model.encode_batch(
wavs.detach(), wav_lens
)
# SLU forward pass
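        # The SLU encoder re-encodes the frozen ASR features; the seq2seq
        # decoder then consumes the BOS-shifted target token embeddings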
encoder_out = self.hparams.slu_enc(ASR_encoder_out)
e_in = self.hparams.output_emb(tokens_bos)
h, _ = self.hparams.dec(e_in, encoder_out, wav_lens)
# Output layer for seq2seq log-probabilities
logits = self.hparams.seq_lin(h)
p_seq = self.hparams.log_softmax(logits)
# Compute outputs
if (
stage == sb.Stage.TRAIN
and self.batch_count % show_results_every != 0
):
return p_seq, wav_lens
else:
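            # At validation/test time (and periodically during training),
            # run beam search to obtain discrete token predictions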
            p_tokens, _ = self.hparams.beam_searcher(encoder_out, wav_lens)
return p_seq, wav_lens, p_tokens
def compute_objectives(self, predictions, batch, stage):
"""Computes the loss (NLL) given predictions and targets."""
if (
stage == sb.Stage.TRAIN
and self.batch_count % show_results_every != 0
):
p_seq, wav_lens = predictions
else:
p_seq, wav_lens, predicted_tokens = predictions
ids = batch.id
tokens_eos, tokens_eos_lens = batch.tokens_eos
tokens, tokens_lens = batch.tokens
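        # env_corrupt (if configured) appends a noise-corrupted copy of the
        # batch in the forward pass, so the targets must be duplicated too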
if hasattr(self.hparams, "env_corrupt") and stage == sb.Stage.TRAIN:
tokens_eos = torch.cat([tokens_eos, tokens_eos], dim=0)
tokens_eos_lens = torch.cat(
[tokens_eos_lens, tokens_eos_lens], dim=0
)
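        # Match the n-fold replication done for the augmentation pipeline
        # in compute_forward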
if stage == sb.Stage.TRAIN:
tokens_eos = torch.cat([tokens_eos] * self.n_augment, dim=0)
tokens_eos_lens = torch.cat(
[tokens_eos_lens] * self.n_augment, dim=0
)
loss_seq = self.hparams.seq_cost(
p_seq, tokens_eos, length=tokens_eos_lens
)
        # (No CTC loss)
loss = loss_seq
if (stage != sb.Stage.TRAIN) or (
self.batch_count % show_results_every == 0
):
# Decode token terms to words
            predicted_semantics = [
                self.tokenizer.decode_ids(utt_seq).split(" ")
                for utt_seq in predicted_tokens
            ]
target_semantics = [wrd.split(" ") for wrd in batch.semantics]
            for pred, target in zip(predicted_semantics, target_semantics):
                print(" ".join(pred).replace("|", ","))
                print(" ".join(target).replace("|", ","))
                print("")
if stage != sb.Stage.TRAIN:
self.wer_metric.append(
ids, predicted_semantics, target_semantics
)
self.cer_metric.append(
ids, predicted_semantics, target_semantics
)
return loss
def fit_batch(self, batch):
"""Train the parameters given a single batch in input"""
predictions = self.compute_forward(batch, sb.Stage.TRAIN)
loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
loss.backward()
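        # check_gradients() clips large gradients and returns False (skipping
        # the optimizer step) when the loss or gradients are not finite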
if self.check_gradients(loss):
self.optimizer.step()
self.optimizer.zero_grad()
self.batch_count += 1
return loss.detach()
def evaluate_batch(self, batch, stage):
"""Computations needed for validation/test batches"""
predictions = self.compute_forward(batch, stage=stage)
loss = self.compute_objectives(predictions, batch, stage=stage)
return loss.detach()
def on_stage_start(self, stage, epoch):
"""Gets called at the beginning of each epoch"""
self.batch_count = 0
if stage != sb.Stage.TRAIN:
self.cer_metric = self.hparams.cer_computer()
self.wer_metric = self.hparams.error_rate_computer()
def on_stage_end(self, stage, stage_loss, epoch):
"""Gets called at the end of a epoch."""
# Compute/store important stats
stage_stats = {"loss": stage_loss}
if stage == sb.Stage.TRAIN:
self.train_stats = stage_stats
else:
stage_stats["CER"] = self.cer_metric.summarize("error_rate")
stage_stats["WER"] = self.wer_metric.summarize("error_rate")
# Perform end-of-iteration things, like annealing, logging, etc.
if stage == sb.Stage.VALID:
old_lr, new_lr = self.hparams.lr_annealing(stage_stats["WER"])
sb.nnet.schedulers.update_learning_rate(self.optimizer, new_lr)
self.hparams.train_logger.log_stats(
stats_meta={"epoch": epoch, "lr": old_lr},
train_stats=self.train_stats,
valid_stats=stage_stats,
)
self.checkpointer.save_and_keep_only(
meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
)
elif stage == sb.Stage.TEST:
self.hparams.train_logger.log_stats(
stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
test_stats=stage_stats,
)
with open(self.hparams.wer_file, "w") as w:
self.wer_metric.write_stats(w)
def dataio_prepare(hparams):
"""This function prepares the datasets to be used in the brain class.
It also defines the data processing pipeline through user-defined functions."""
data_folder = hparams["data_folder"]
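    # 1. Define the training/validation/test datasets and their sorting: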
train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
csv_path=hparams["csv_train"], replacements={"data_root": data_folder},
)
if hparams["sorting"] == "ascending":
        # We sort the training data to speed up training and get better results.
train_data = train_data.filtered_sorted(sort_key="duration")
        # When sorting, do not shuffle in the dataloader; otherwise sorting is pointless.
        hparams["dataloader_opts"]["shuffle"] = False
elif hparams["sorting"] == "descending":
train_data = train_data.filtered_sorted(
sort_key="duration", reverse=True
)
        # When sorting, do not shuffle in the dataloader; otherwise sorting is pointless.
        hparams["dataloader_opts"]["shuffle"] = False
elif hparams["sorting"] == "random":
pass
else:
raise NotImplementedError(
"sorting must be random, ascending or descending"
)
valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
csv_path=hparams["csv_valid"], replacements={"data_root": data_folder},
)
valid_data = valid_data.filtered_sorted(sort_key="duration")
test_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
csv_path=hparams["csv_test"], replacements={"data_root": data_folder},
)
test_data = test_data.filtered_sorted(sort_key="duration")
datasets = [train_data, valid_data, test_data]
tokenizer = hparams["tokenizer"]
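    # The SentencePiece tokenizer comes from the hparams YAML; its weights
    # are fetched by the pretrainer in the main block below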
# 2. Define audio pipeline:
@sb.utils.data_pipeline.takes("wav")
@sb.utils.data_pipeline.provides("sig")
def audio_pipeline(wav):
sig = sb.dataio.dataio.read_audio(wav)
return sig
sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
# 3. Define text pipeline:
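    # Each successive yield below maps, in order, onto one of the keys
    # declared in @provides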
@sb.utils.data_pipeline.takes("semantics")
@sb.utils.data_pipeline.provides(
"semantics", "token_list", "tokens_bos", "tokens_eos", "tokens"
)
def text_pipeline(semantics):
yield semantics
tokens_list = tokenizer.encode_as_ids(semantics)
yield tokens_list
        tokens_bos = torch.LongTensor([hparams["bos_index"]] + tokens_list)
yield tokens_bos
tokens_eos = torch.LongTensor(tokens_list + [hparams["eos_index"]])
yield tokens_eos
tokens = torch.LongTensor(tokens_list)
yield tokens
sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
# 4. Set output:
sb.dataio.dataset.set_output_keys(
datasets,
["id", "sig", "semantics", "tokens_bos", "tokens_eos", "tokens"],
)
return train_data, valid_data, test_data, tokenizer
if __name__ == "__main__":
# Load hyperparameters file with command-line overrides
hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
with open(hparams_file) as fin:
hparams = load_hyperpyyaml(fin, overrides)
    # The SLU Brain class reads this as a module-level global
    show_results_every = 100  # print decoding results every N training batches
# If distributed_launch=True then
# create ddp_group with the right communication protocol
sb.utils.distributed.ddp_init_group(run_opts)
# Create experiment directory
sb.create_experiment_directory(
experiment_directory=hparams["output_folder"],
hyperparams_to_save=hparams_file,
overrides=overrides,
)
# Dataset prep
from prepare import prepare_FSC # noqa
    # Multi-GPU (DDP): run data preparation on the main process only
run_on_main(
prepare_FSC,
kwargs={
"data_folder": hparams["data_folder"],
"save_folder": hparams["output_folder"],
"skip_prep": hparams["skip_prep"],
},
)
    # Here we create the dataset objects as well as the tokenization pipeline
    train_set, valid_set, test_set, tokenizer = dataio_prepare(hparams)
    # Download the pretrained tokenizer and ASR model and load their parameters
run_on_main(hparams["pretrainer"].collect_files)
hparams["pretrainer"].load_collected(device=run_opts["device"])
# Brain class initialization
slu_brain = SLU(
modules=hparams["modules"],
opt_class=hparams["opt_class"],
hparams=hparams,
run_opts=run_opts,
checkpointer=hparams["checkpointer"],
)
    # Attach the tokenizer to the trainer:
slu_brain.tokenizer = tokenizer
# Training
slu_brain.fit(
slu_brain.hparams.epoch_counter,
train_set,
valid_set,
train_loader_kwargs=hparams["dataloader_opts"],
valid_loader_kwargs=hparams["dataloader_opts"],
)
# Test
slu_brain.hparams.wer_file = hparams["output_folder"] + "/wer_test.txt"
slu_brain.evaluate(test_set, test_loader_kwargs=hparams["dataloader_opts"])