|
|
|
""" |
|
Recipe for "direct" (speech -> semantics) SLU with ASR-based transfer learning. |
|
|
|
We encode input waveforms into features using a model trained on LibriSpeech, |
|
then feed the features into a seq2seq model to map them to semantics. |
|
|
|
(Adapted from the LibriSpeech seq2seq ASR recipe written by Ju-Chieh Chou, Mirco Ravanelli, Abdel Heba, and Peter Plantinga.) |
|
|
|
Run using: |
|
> python train.py hparams/train.yaml |
|
|
|
Authors |
|
* Loren Lugosch 2020 |
|
* Mirco Ravanelli 2020 |
|
""" |
|
|
|
import sys |
|
import torch |
|
import speechbrain as sb |
|
import logging |
|
from hyperpyyaml import load_hyperpyyaml |
|
from speechbrain.utils.distributed import run_on_main |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
class SLU(sb.Brain):
    """Seq2seq SLU Brain: maps frozen ASR-encoder features to semantics.

    Expects ``self.tokenizer`` to be attached after construction (done in
    ``__main__``) and reads the module-level global ``show_results_every``
    to decide how often to run beam-search decoding during training.
    """

    def compute_forward(self, batch, stage):
        """Forward computations from the waveform batches to the output probabilities.

        Arguments
        ---------
        batch : PaddedBatch
            Batch providing at least ``sig`` and ``tokens_bos``.
        stage : sb.Stage
            TRAIN, VALID, or TEST.

        Returns
        -------
        tuple
            ``(p_seq, wav_lens)`` when beam search is skipped, else
            ``(p_seq, wav_lens, p_tokens)`` with the decoded hypotheses.
        """
        batch = batch.to(self.device)
        wavs, wav_lens = batch.sig
        tokens_bos, tokens_bos_lens = batch.tokens_bos

        if stage == sb.Stage.TRAIN:
            # Run every augmentation in the pipeline on the clean batch and
            # stack original + augmented copies along the batch dimension.
            wavs_aug_tot = [wavs]
            for augment in self.hparams.augment_pipeline:
                wavs_aug = augment(wavs, wav_lens)

                # Trim or zero-pad so every augmented copy keeps the
                # original number of samples.
                if wavs_aug.shape[1] > wavs.shape[1]:
                    wavs_aug = wavs_aug[:, 0 : wavs.shape[1]]
                else:
                    zero_sig = torch.zeros_like(wavs)
                    zero_sig[:, 0 : wavs_aug.shape[1]] = wavs_aug
                    wavs_aug = zero_sig

                wavs_aug_tot.append(wavs_aug)

            wavs = torch.cat(wavs_aug_tot, dim=0)
            self.n_augment = len(wavs_aug_tot)
            # Replicate lengths and targets to match the enlarged batch.
            wav_lens = torch.cat([wav_lens] * self.n_augment)
            tokens_bos = torch.cat([tokens_bos] * self.n_augment)

        # The pretrained ASR model is a frozen feature extractor: no grads.
        with torch.no_grad():
            ASR_encoder_out = self.hparams.asr_model.encode_batch(
                wavs.detach(), wav_lens
            )

        # SLU forward pass: encode features, embed shifted targets, decode.
        encoder_out = self.hparams.slu_enc(ASR_encoder_out)
        e_in = self.hparams.output_emb(tokens_bos)
        h, _ = self.hparams.dec(e_in, encoder_out, wav_lens)

        # Output layer for seq2seq log-probabilities.
        logits = self.hparams.seq_lin(h)
        p_seq = self.hparams.log_softmax(logits)

        # Beam search is slow, so during TRAIN it only runs every
        # `show_results_every` batches (module-level global set in __main__);
        # it always runs for VALID/TEST.
        if (
            stage == sb.Stage.TRAIN
            and self.batch_count % show_results_every != 0
        ):
            return p_seq, wav_lens
        else:
            p_tokens, _ = self.hparams.beam_searcher(encoder_out, wav_lens)
            return p_seq, wav_lens, p_tokens

    def compute_objectives(self, predictions, batch, stage):
        """Computes the loss (NLL) given predictions and targets."""
        # Unpack according to whether compute_forward ran beam search.
        if (
            stage == sb.Stage.TRAIN
            and self.batch_count % show_results_every != 0
        ):
            p_seq, wav_lens = predictions
        else:
            p_seq, wav_lens, predicted_tokens = predictions

        ids = batch.id
        tokens_eos, tokens_eos_lens = batch.tokens_eos

        # env_corrupt doubles the wave batch, so double the targets too.
        if hasattr(self.hparams, "env_corrupt") and stage == sb.Stage.TRAIN:
            tokens_eos = torch.cat([tokens_eos, tokens_eos], dim=0)
            tokens_eos_lens = torch.cat(
                [tokens_eos_lens, tokens_eos_lens], dim=0
            )

        # Replicate targets to match the augmented copies made in forward.
        if stage == sb.Stage.TRAIN:
            tokens_eos = torch.cat([tokens_eos] * self.n_augment, dim=0)
            tokens_eos_lens = torch.cat(
                [tokens_eos_lens] * self.n_augment, dim=0
            )

        loss_seq = self.hparams.seq_cost(
            p_seq, tokens_eos, length=tokens_eos_lens
        )
        loss = loss_seq

        if (stage != sb.Stage.TRAIN) or (
            self.batch_count % show_results_every == 0
        ):
            # Fix: use the tokenizer attached to this Brain (set in __main__
            # as `slu_brain.tokenizer = tokenizer`) rather than relying on a
            # module-level `tokenizer` global that breaks on import.
            predicted_semantics = [
                self.tokenizer.decode_ids(utt_seq).split(" ")
                for utt_seq in predicted_tokens
            ]

            target_semantics = [wrd.split(" ") for wrd in batch.semantics]

            # Print hypothesis / reference pairs ("|" rendered as ",").
            for i in range(len(target_semantics)):
                print(" ".join(predicted_semantics[i]).replace("|", ","))
                print(" ".join(target_semantics[i]).replace("|", ","))
                print("")

            if stage != sb.Stage.TRAIN:
                self.wer_metric.append(
                    ids, predicted_semantics, target_semantics
                )
                self.cer_metric.append(
                    ids, predicted_semantics, target_semantics
                )

        return loss

    def fit_batch(self, batch):
        """Train the parameters given a single batch in input"""
        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
        loss.backward()
        # Only step when gradients are finite/clipped successfully.
        if self.check_gradients(loss):
            self.optimizer.step()
        self.optimizer.zero_grad()
        self.batch_count += 1
        return loss.detach()

    def evaluate_batch(self, batch, stage):
        """Computations needed for validation/test batches"""
        predictions = self.compute_forward(batch, stage=stage)
        loss = self.compute_objectives(predictions, batch, stage=stage)
        return loss.detach()

    def on_stage_start(self, stage, epoch):
        """Gets called at the beginning of each epoch"""
        self.batch_count = 0

        if stage != sb.Stage.TRAIN:
            # Fresh error-rate trackers for each VALID/TEST pass.
            self.cer_metric = self.hparams.cer_computer()
            self.wer_metric = self.hparams.error_rate_computer()

    def on_stage_end(self, stage, stage_loss, epoch):
        """Gets called at the end of a epoch."""
        stage_stats = {"loss": stage_loss}
        if stage == sb.Stage.TRAIN:
            self.train_stats = stage_stats
        else:
            stage_stats["CER"] = self.cer_metric.summarize("error_rate")
            stage_stats["WER"] = self.wer_metric.summarize("error_rate")

        if stage == sb.Stage.VALID:
            # Anneal LR on validation WER, log, and checkpoint the best model.
            old_lr, new_lr = self.hparams.lr_annealing(stage_stats["WER"])
            sb.nnet.schedulers.update_learning_rate(self.optimizer, new_lr)
            self.hparams.train_logger.log_stats(
                stats_meta={"epoch": epoch, "lr": old_lr},
                train_stats=self.train_stats,
                valid_stats=stage_stats,
            )
            self.checkpointer.save_and_keep_only(
                meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
            )
        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stage_stats,
            )
            # Write the per-utterance WER breakdown to disk.
            with open(self.hparams.wer_file, "w") as w:
                self.wer_metric.write_stats(w)
|
|
|
def dataio_prepare(hparams):
    """Build the train/valid/test datasets used by the Brain class and
    register the user-defined audio/text processing pipelines on them.

    Arguments
    ---------
    hparams : dict
        Loaded hyperparameters; must provide the csv paths, sorting mode,
        dataloader options, tokenizer, and bos/eos indices.

    Returns
    -------
    tuple
        ``(train_data, valid_data, test_data, tokenizer)``
    """
    replacements = {"data_root": hparams["data_folder"]}

    def _from_csv(csv_key):
        # Load one manifest csv into a DynamicItemDataset.
        return sb.dataio.dataset.DynamicItemDataset.from_csv(
            csv_path=hparams[csv_key], replacements=replacements,
        )

    train_data = _from_csv("csv_train")

    sorting = hparams["sorting"]
    if sorting == "ascending":
        train_data = train_data.filtered_sorted(sort_key="duration")
        # Sorted data must not be reshuffled by the dataloader.
        hparams["dataloader_opts"]["shuffle"] = False
    elif sorting == "descending":
        train_data = train_data.filtered_sorted(
            sort_key="duration", reverse=True
        )
        hparams["dataloader_opts"]["shuffle"] = False
    elif sorting != "random":
        raise NotImplementedError(
            "sorting must be random, ascending or descending"
        )

    # Valid/test are always duration-sorted (fast, deterministic batching).
    valid_data = _from_csv("csv_valid").filtered_sorted(sort_key="duration")
    test_data = _from_csv("csv_test").filtered_sorted(sort_key="duration")

    datasets = [train_data, valid_data, test_data]
    tokenizer = hparams["tokenizer"]

    # Audio pipeline: csv "wav" path -> waveform tensor "sig".
    @sb.utils.data_pipeline.takes("wav")
    @sb.utils.data_pipeline.provides("sig")
    def audio_pipeline(wav):
        return sb.dataio.dataio.read_audio(wav)

    sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)

    # Text pipeline: raw semantics string -> token id variants with
    # BOS/EOS markers for seq2seq training.
    @sb.utils.data_pipeline.takes("semantics")
    @sb.utils.data_pipeline.provides(
        "semantics", "token_list", "tokens_bos", "tokens_eos", "tokens"
    )
    def text_pipeline(semantics):
        yield semantics
        token_ids = tokenizer.encode_as_ids(semantics)
        yield token_ids
        yield torch.LongTensor([hparams["bos_index"]] + token_ids)
        yield torch.LongTensor(token_ids + [hparams["eos_index"]])
        yield torch.LongTensor(token_ids)

    sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)

    # Declare which keys the collate function should expose on the batch.
    sb.dataio.dataset.set_output_keys(
        datasets,
        ["id", "sig", "semantics", "tokens_bos", "tokens_eos", "tokens"],
    )
    return train_data, valid_data, test_data, tokenizer
|
|
|
|
|
if __name__ == "__main__":

    # Load hyperparameters file with any command-line overrides applied.
    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
    with open(hparams_file) as fin:
        hparams = load_hyperpyyaml(fin, overrides)

    # Module-level global read by SLU.compute_forward/compute_objectives:
    # beam-search decoding results are printed every `show_results_every`
    # training batches.
    show_results_every = 100

    # Initialize the DDP process group (relevant only for multi-GPU runs).
    sb.utils.distributed.ddp_init_group(run_opts)

    # Create experiment directory and save the resolved hyperparameters.
    sb.create_experiment_directory(
        experiment_directory=hparams["output_folder"],
        hyperparams_to_save=hparams_file,
        overrides=overrides,
    )

    # Recipe-local dataset preparation script; imported after the hparams
    # are parsed so argument errors surface first.
    from prepare import prepare_FSC

    # Run data preparation on the main process only (DDP-safe).
    run_on_main(
        prepare_FSC,
        kwargs={
            "data_folder": hparams["data_folder"],
            "save_folder": hparams["output_folder"],
            "skip_prep": hparams["skip_prep"],
        },
    )

    # Build datasets and pipelines. NOTE: `tokenizer` also becomes a
    # module-level global that SLU.compute_objectives reads directly.
    (train_set, valid_set, test_set, tokenizer,) = dataio_prepare(hparams)

    # Collect (download/link) and load the pretrained ASR model parameters.
    run_on_main(hparams["pretrainer"].collect_files)
    hparams["pretrainer"].load_collected(device=run_opts["device"])

    # Instantiate the SLU Brain with modules/optimizer/checkpointer from yaml.
    slu_brain = SLU(
        modules=hparams["modules"],
        opt_class=hparams["opt_class"],
        hparams=hparams,
        run_opts=run_opts,
        checkpointer=hparams["checkpointer"],
    )

    # Attach the tokenizer so it is also reachable as `self.tokenizer`.
    slu_brain.tokenizer = tokenizer

    # Training/validation loop (runs until epoch_counter is exhausted).
    slu_brain.fit(
        slu_brain.hparams.epoch_counter,
        train_set,
        valid_set,
        train_loader_kwargs=hparams["dataloader_opts"],
        valid_loader_kwargs=hparams["dataloader_opts"],
    )

    # Final test evaluation; per-utterance WER details go to wer_test.txt.
    slu_brain.hparams.wer_file = hparams["output_folder"] + "/wer_test.txt"
    slu_brain.evaluate(test_set, test_loader_kwargs=hparams["dataloader_opts"])
|
|