# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------

import itertools
import logging
import io
import os
import sys
import time
from pathlib import Path
from typing import Any, List, Optional, Union, Tuple

import numpy as np

import torch
import torch.nn.functional as F
from fairseq.data import data_utils, Dictionary
from fairseq.data.fairseq_dataset import FairseqDataset
from fairseq.data.audio.audio_utils import (
    read_from_stored_zip,
    is_sf_audio_data,
)

FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS = {".npy", ".wav", ".flac", ".ogg"}

logger = logging.getLogger(__name__)


def parse_path(path: str) -> Tuple[str, List[int]]:
    """Parse a data path, which is either a path to
    1. a .npy/.wav/.flac/.ogg file
    2. a stored ZIP file with slicing info: "[zip_path]:[offset]:[length]"

    Args:
        path (str): the data path to parse

    Returns:
        file_path (str): the file path
        slice_ptr (list of int): empty in case 1;
            byte offset and length for the slice in case 2
    """
    if Path(path).suffix in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS:
        _path, slice_ptr = path, []
    else:
        _path, *slice_ptr = path.split(":")
        if not Path(_path).is_file():
            raise FileNotFoundError(f"File not found: {_path}")
    assert len(slice_ptr) in {0, 1, 2}, f"Invalid path: {path}"
    slice_ptr = [int(i) for i in slice_ptr]
    return _path, slice_ptr


def load_audio(manifest_path, max_keep, min_keep, retry_times=5):
    n_long, n_short = 0, 0
    names, inds, sizes, chunk_names, chunk_indices = [], [], [], [], []
    for i in range(retry_times):
        with open(manifest_path) as f:
            root = f.readline().strip()
            for ind, line in enumerate(f):
                items = line.strip().split("\t")
                assert len(items) == 2, line
                sz = int(items[1])
                if min_keep is not None and sz < min_keep:
                    n_short += 1
                elif max_keep is not None and sz > max_keep:
                    n_long += 1
                else:
                    fname = items[0].split(":")
                    if len(fname) > 2:
                        if len(chunk_names) == 0 or fname[0] != chunk_names[-1]:
                            chunk_names.append(fname[0])
                            chunk_indices.append(len(names))
                    names.append(items[0])
                    inds.append(ind)
                    sizes.append(sz)
        if len(names) == 0:
            logger.warning(f"Failed to load manifest (attempt {i + 1}/{retry_times})")
            time.sleep(1)
            continue
        else:
            break
    tot = ind + 1
    logger.info(
        (
            f"max_keep={max_keep}, min_keep={min_keep}, "
            f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
            f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
        )
    )
    return root, names, inds, tot, sizes, chunk_names, chunk_indices


def load_label(label_path, inds, tot, retry_times=5):
    for i in range(retry_times):
        with open(label_path) as f:
            labels = [line.rstrip() for line in f]
        if len(labels) == 0:
            logger.warning(f"Failed to load labels (attempt {i + 1}/{retry_times})")
            time.sleep(1)
            continue
        else:
            break
    assert (
        len(labels) == tot
    ), f"number of labels does not match ({len(labels)} != {tot})"
    labels = [labels[i] for i in inds]
    return labels


def load_label_offset(label_path, inds, tot, retry_times=5):
    for i in range(retry_times):
        with open(label_path) as f:
            code_lengths = [len(line.encode("utf-8")) for line in f]
        if len(code_lengths) == 0:
            logger.warning(f"Failed to load labels (attempt {i + 1}/{retry_times})")
            time.sleep(1)
            continue
        else:
            break
    assert (
        len(code_lengths) == tot
    ), f"number of labels does not match ({len(code_lengths)} != {tot})"
    offsets = list(itertools.accumulate([0] + code_lengths))
    offsets = [(offsets[i], offsets[i + 1]) for i in inds]
    return offsets
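
# Illustrative parse_path examples (hypothetical paths; for the colon-separated
# forms, parse_path raises FileNotFoundError unless the file exists on disk):
#   parse_path("/data/utt1.wav")           -> ("/data/utt1.wav", [])
#   parse_path("/data/feats.ark:1024")     -> ("/data/feats.ark", [1024])
#   parse_path("/data/audio.zip:512:2048") -> ("/data/audio.zip", [512, 2048])
# A single offset is treated as a kaldi-style ark position, and two values as
# a byte (offset, length) slice into a stored ZIP (see HubertDataset.get_audio).
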
def verify_label_lengths(
    audio_sizes,
    audio_rate,
    label_path,
    label_rate,
    inds,
    tot,
    tol=0.1,  # tolerance in seconds
):
    if label_rate < 0:
        logger.info(f"{label_path} is sequence label. skipped")
        return

    with open(label_path) as f:
        lengths = [len(line.rstrip().split()) for line in f]
        assert len(lengths) == tot
        lengths = [lengths[i] for i in inds]
    num_invalid = 0
    for i, ind in enumerate(inds):
        dur_from_audio = audio_sizes[i] / audio_rate
        dur_from_label = lengths[i] / label_rate
        if abs(dur_from_audio - dur_from_label) > tol:
            logger.warning(
                (
                    f"audio and label duration differ too much "
                    f"(|{dur_from_audio} - {dur_from_label}| > {tol}) "
                    f"in line {ind + 1} of {label_path}. Check if `label_rate` "
                    f"is correctly set (currently {label_rate}). "
                    f"num. of samples = {audio_sizes[i]}; "
                    f"label length = {lengths[i]}"
                )
            )
            num_invalid += 1
    if num_invalid > 0:
        logger.warning(
            f"total {num_invalid} (audio, label) pairs with mismatched lengths"
        )


class HubertDataset(FairseqDataset):
    def __init__(
        self,
        manifest_path: str,
        sample_rate: float,
        label_paths: List[str],
        label_rates: Union[List[float], float],  # -1 for sequence labels
        pad_list: List[str],
        eos_list: List[str],
        label_processors: Optional[List[Any]] = None,
        max_keep_sample_size: Optional[int] = None,
        min_keep_sample_size: Optional[int] = None,
        max_sample_size: Optional[int] = None,
        shuffle: bool = True,
        pad_audio: bool = False,
        normalize: bool = False,
        store_labels: bool = True,
        random_crop: bool = False,
        single_target: bool = False,
        tgt_dict: Optional[Dictionary] = None,
        add_decoder_target: bool = False,
        fine_tuning: bool = False,
        tgt_lang_idx: Optional[int] = None,
        tokenizer=None,
        mbart_style_lang_id: bool = False,
        retry_times: int = 5,
        reduce_label_for_dec: bool = True,
    ):
        (
            self.audio_root,
            self.audio_names,
            inds,
            tot,
            self.wav_sizes,
            self.chunk_names,
            self.chunk_indices,
        ) = load_audio(
            manifest_path, max_keep_sample_size, min_keep_sample_size, retry_times
        )
        self.sample_rate = sample_rate
        self.shuffle = shuffle
        self.random_crop = random_crop
        self.tgt_dict = tgt_dict
        self.add_decoder_target = add_decoder_target
        self.fine_tuning = fine_tuning

        self.num_labels = len(label_paths)
        self.pad_list = pad_list
        self.eos_list = eos_list
        self.label_processors = label_processors
        self.single_target = single_target
        self.epoch = 0

        self.label_rates = (
            # broadcast a scalar rate to all label streams (accept float too,
            # since the annotated type is Union[List[float], float])
            [label_rates for _ in range(len(label_paths))]
            if isinstance(label_rates, (int, float))
            else label_rates
        )
        self.store_labels = store_labels
        if store_labels:
            self.label_list = [
                load_label(p, inds, tot, retry_times) for p in label_paths
            ]
        else:
            self.label_paths = label_paths
            self.label_offsets_list = [
                load_label_offset(p, inds, tot, retry_times) for p in label_paths
            ]
        assert label_processors is None or len(label_processors) == self.num_labels
        for label_path, label_rate in zip(label_paths, self.label_rates):
            verify_label_lengths(
                self.wav_sizes, sample_rate, label_path, label_rate, inds, tot
            )

        self.max_sample_size = (
            max_sample_size if max_sample_size is not None else sys.maxsize
        )
        self.pad_audio = pad_audio
        self.normalize = normalize
        self.tgt_lang_idx = tgt_lang_idx
        self.tokenizer = tokenizer
        self.mbart_style_lang_id = mbart_style_lang_id
        self.retry_times = retry_times
        self.reduce_label_for_dec = reduce_label_for_dec
        logger.info(
            f"pad_audio={pad_audio}, random_crop={random_crop}, "
            f"tgt_lang_idx={self.tgt_lang_idx}, "
            f"reduce_label_for_dec={reduce_label_for_dec}, "
            f"mbart_style_lang_id={mbart_style_lang_id}, normalize={normalize}, "
            f"max_sample_size={self.max_sample_size}"
        )

    def set_epoch(self, epoch):
        self.epoch = epoch
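
    # NOTE: batch_by_size / shuffle_batches below accept two index layouts:
    # a flat index array (the standard fairseq path), or a list of per-chunk
    # index arrays produced by ordered_indices when the manifest points into
    # stored ZIP/ark chunks. Batching within a chunk presumably keeps reads
    # local to one archive at a time.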
    def batch_by_size(
        self,
        indices,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
    ):
        self.max_tokens = max_tokens
        self.max_sentences = max_sentences
        self.required_batch_size_multiple = required_batch_size_multiple
        if isinstance(indices[0], np.ndarray):
            # chunked case: batch each per-chunk index array separately
            batch_list = []
            for indice in indices:
                batch = super(HubertDataset, self).batch_by_size(
                    indice, max_tokens, max_sentences, required_batch_size_multiple
                )
                batch_list.append(batch)
            return batch_list
        else:
            return super(HubertDataset, self).batch_by_size(
                indices, max_tokens, max_sentences, required_batch_size_multiple
            )

    def shuffle_batches(self, batches, seed):
        if isinstance(batches[0], list):
            new_batches = []
            with data_utils.numpy_seed(seed):
                np.random.shuffle(batches)
                for batch in batches:
                    np.random.shuffle(batch)
                    new_batches.extend(batch)
            return new_batches
        else:
            with data_utils.numpy_seed(seed):
                np.random.shuffle(batches)
            return batches

    def get_audio(self, index):
        import soundfile as sf

        wav_path = os.path.join(self.audio_root, self.audio_names[index])
        _path, slice_ptr = parse_path(wav_path)
        if len(slice_ptr) == 1:
            # kaldi-style "ark_path:offset" entry: load a feature matrix
            import kaldiio

            feat = kaldiio.load_mat(wav_path)
            feat = torch.from_numpy(feat).float()
            if self.normalize:
                with torch.no_grad():
                    # normalized_shape must be a sequence, not a bare int
                    feat = F.layer_norm(feat, feat.shape[-1:])
            return feat
        else:
            if len(slice_ptr) == 2:
                byte_data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
                assert is_sf_audio_data(byte_data)
                wav_path = io.BytesIO(byte_data)
            for i in range(self.retry_times):
                if i < self.retry_times - 1:
                    try:
                        wav, cur_sample_rate = sf.read(wav_path)
                        break
                    except Exception as e:
                        logger.warning(
                            f"Failed to load wav (attempt {i + 1}/{self.retry_times})"
                        )
                        logger.warning(e)
                        time.sleep(1)
                        continue
                else:
                    # last attempt: let any exception propagate
                    wav, cur_sample_rate = sf.read(wav_path)

            wav = torch.from_numpy(wav).float()
            wav = self.postprocess(wav, cur_sample_rate)
            return wav

    def get_label(self, index, label_idx):
        if self.store_labels:
            label = self.label_list[label_idx][index]
        else:
            with open(self.label_paths[label_idx]) as f:
                offset_s, offset_e = self.label_offsets_list[label_idx][index]
                f.seek(offset_s)
                label = f.read(offset_e - offset_s)

        if self.tokenizer is not None and self.fine_tuning:
            label = self.tokenizer.encode(label)

        if self.label_processors is not None:
            label = self.label_processors[label_idx](label)
        return label

    def get_labels(self, index):
        return [self.get_label(index, i) for i in range(self.num_labels)]

    def __getitem__(self, index):
        wav = self.get_audio(index)
        labels = self.get_labels(index)
        return {"id": index, "source": wav, "label_list": labels}

    def __len__(self):
        return len(self.wav_sizes)

    def crop_to_max_size(self, wav, target_size):
        size = len(wav)
        diff = size - target_size
        if diff <= 0:
            return wav, 0

        start, end = 0, target_size
        if self.random_crop:
            start = np.random.randint(0, diff + 1)
            end = size - diff + start
        return wav[start:end], start
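
    # Worked example for crop_to_max_size (illustrative): with a 12-sample wav
    # and target_size=8, diff=4; under random_crop a start is drawn uniformly
    # from [0, 4] and the 8-sample window wav[start:start+8] is returned along
    # with start, which collater_frm_label later reuses to crop matching labels.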
    def collater(self, samples):
        # target = max(sizes) -> random_crop not used
        # target = max_sample_size -> random_crop used for long
        samples = [s for s in samples if s["source"] is not None]
        if len(samples) == 0:
            return {}

        audios = [s["source"] for s in samples]
        audio_sizes = [len(s) for s in audios]
        if self.pad_audio:
            audio_size = min(max(audio_sizes), self.max_sample_size)
        else:
            audio_size = min(min(audio_sizes), self.max_sample_size)
        feat_dim = audios[0].size(-1) if audios[0].dim() > 1 else 1
        collated_audios, padding_mask, audio_starts = self.collater_audio(
            audios, audio_size, feat_dim
        )

        targets_by_label = [
            [s["label_list"][i] for s in samples] for i in range(self.num_labels)
        ]
        targets_list, lengths_list, ntokens_list = self.collater_label(
            targets_by_label, audio_size, audio_starts
        )

        if self.add_decoder_target:
            if self.fine_tuning:
                decoder_label = [
                    torch.cat(
                        (
                            targets_list[0][i, : lengths_list[0][i]],
                            torch.tensor([self.tgt_dict.eos()]),
                        ),
                        0,
                    ).long()
                    for i in range(targets_list[0].size(0))
                ]
            else:
                if self.tokenizer is not None:
                    decoder_label = [
                        # shift unit ids by 48 so each maps to a printable
                        # char (avoids '\n') before sentencepiece encoding
                        torch.cat(
                            (
                                torch.tensor(
                                    self.tokenizer.sp.Encode(
                                        "".join(
                                            [
                                                chr(j + 48)
                                                for j in (
                                                    targets_list[0][
                                                        i, : lengths_list[0][i]
                                                    ].unique_consecutive()
                                                    if self.reduce_label_for_dec
                                                    else targets_list[0][
                                                        i, : lengths_list[0][i]
                                                    ]
                                                ).tolist()
                                            ]
                                        ),
                                        out_type=int,
                                    )
                                ),
                                torch.tensor([self.tgt_dict.eos()]),
                            ),
                            dim=0,
                        ).long()
                        for i in range(targets_list[0].size(0))
                    ]
                else:
                    decoder_label = [
                        torch.cat(
                            (
                                targets_list[0][
                                    i, : lengths_list[0][i]
                                ].unique_consecutive()
                                if self.reduce_label_for_dec
                                else targets_list[0][i, : lengths_list[0][i]],
                                torch.tensor([self.tgt_dict.eos()]),
                            ),
                            0,
                        ).long()
                        for i in range(targets_list[0].size(0))
                    ]

            if self.mbart_style_lang_id:
                decoder_label = [
                    torch.cat(
                        (decoder_label[i], torch.tensor([self.tgt_lang_idx])), 0
                    ).long()
                    for i in range(targets_list[0].size(0))
                ]

            dec_ntokens = sum(x.size(0) for x in decoder_label)
            decoder_target = data_utils.collate_tokens(
                decoder_label,
                self.tgt_dict.pad(),
                self.tgt_dict.eos()
                if not self.mbart_style_lang_id
                else self.tgt_lang_idx,
                left_pad=False,
                move_eos_to_beginning=False,
            )
            decoder_target_lengths = torch.tensor(
                [x.size(0) for x in decoder_label], dtype=torch.long
            )
            prev_output_tokens = data_utils.collate_tokens(
                decoder_label,
                self.tgt_dict.pad(),
                self.tgt_dict.eos()
                if not self.mbart_style_lang_id
                else self.tgt_lang_idx,
                left_pad=False,
                move_eos_to_beginning=True,
            )

            if self.tgt_lang_idx is not None and not self.mbart_style_lang_id:
                assert (prev_output_tokens[:, 0] != self.tgt_dict.eos()).sum() == 0
                prev_output_tokens[:, 0] = self.tgt_lang_idx

            net_input = {
                "source": collated_audios,
                "padding_mask": padding_mask,
                "prev_output_tokens": prev_output_tokens,
            }
            batch = {
                "id": torch.LongTensor([s["id"] for s in samples]),
                "net_input": net_input,
                "decoder_target": decoder_target,
                "decoder_target_lengths": decoder_target_lengths,
                "dec_ntokens": dec_ntokens,
                "lang_idx": self.tgt_lang_idx,
            }
        else:
            net_input = {"source": collated_audios, "padding_mask": padding_mask}
            batch = {
                "id": torch.LongTensor([s["id"] for s in samples]),
                "net_input": net_input,
            }

        if self.single_target:
            batch["target_lengths"] = lengths_list[0]
            batch["ntokens"] = ntokens_list[0]
            batch["target"] = targets_list[0]
        else:
            batch["target_lengths_list"] = lengths_list
            batch["ntokens_list"] = ntokens_list
            batch["target_list"] = targets_list
        return batch

    def collater_audio(self, audios, audio_size, feat_dim=1):
        collated_audios = audios[0].new_zeros(len(audios), audio_size, feat_dim)
        padding_mask = (
            torch.BoolTensor(collated_audios.shape[0:2]).fill_(False)
            # if self.pad_audio else None
        )
        audio_starts = [0 for _ in audios]
        for i, audio in enumerate(audios):
            audio = audio.view(-1, feat_dim)
            diff = len(audio) - audio_size
            if diff == 0:
                collated_audios[i] = audio
            elif diff < 0:
                assert self.pad_audio
                collated_audios[i] = torch.cat(
                    [audio, audio.new_full((-diff, feat_dim), 0.0)]
                )
                padding_mask[i, diff:] = True
            else:
                collated_audios[i], audio_starts[i] = self.crop_to_max_size(
                    audio, audio_size
                )
        return collated_audios.squeeze(-1), padding_mask, audio_starts
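
    # Worked example for collater_frm_label (illustrative): with 16 kHz audio
    # and label_rate=50 Hz, s2f = 50 / 16000 = 0.003125, so an audio crop
    # starting at sample 32000 maps to label frame 100, and a 160000-sample
    # window spans 500 label frames.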
    def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
        assert label_rate > 0
        s2f = label_rate / self.sample_rate  # label frames per audio sample
        frm_starts = [int(round(s * s2f)) for s in audio_starts]
        frm_size = int(round(audio_size * s2f))
        if not self.pad_audio:
            rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
            frm_size = min(frm_size, *rem_size)
        targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)]
        logger.debug(f"audio_starts={audio_starts}")
        logger.debug(f"frame_starts={frm_starts}")
        logger.debug(f"frame_size={frm_size}")

        lengths = torch.LongTensor([len(t) for t in targets])
        ntokens = lengths.sum().item()
        targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
        return targets, lengths, ntokens

    def collater_seq_label(self, targets, pad):
        lengths = torch.LongTensor([len(t) for t in targets])
        ntokens = lengths.sum().item()
        targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
        return targets, lengths, ntokens

    def collater_label(self, targets_by_label, audio_size, audio_starts):
        targets_list, lengths_list, ntokens_list = [], [], []
        itr = zip(targets_by_label, self.label_rates, self.pad_list)
        for targets, label_rate, pad in itr:
            if label_rate == -1:
                targets, lengths, ntokens = self.collater_seq_label(targets, pad)
            else:
                targets, lengths, ntokens = self.collater_frm_label(
                    targets, audio_size, audio_starts, label_rate, pad
                )
            targets_list.append(targets)
            lengths_list.append(lengths)
            ntokens_list.append(ntokens)
        return targets_list, lengths_list, ntokens_list

    def num_tokens(self, index):
        return self.size(index)

    def size(self, index):
        if self.pad_audio:
            return self.wav_sizes[index]
        return min(self.wav_sizes[index], self.max_sample_size)

    @property
    def sizes(self):
        return np.array(self.wav_sizes)

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            if len(self.chunk_names) > 0:
                logger.info(f"ordered indices for epoch {self.epoch}")
                with data_utils.numpy_seed(self.epoch):
                    self.chunk_order = np.random.permutation(len(self.chunk_names))
                    chunk_count = 0
                    tmp_sizes = []
                    tmp_indices = []
                    indice = []
                    for i in self.chunk_order:
                        chunk_count += 1
                        start = self.chunk_indices[i]
                        end = (
                            self.chunk_indices[i + 1]
                            if i < len(self.chunk_names) - 1
                            else len(self)
                        )
                        size = list(self.sizes[start:end])
                        tmp_indices.extend(list(np.arange(start, end)))
                        tmp_sizes.extend(size)
                        # flush every 10 chunks, and at the final chunk so the
                        # remainder is not silently dropped
                        if chunk_count % 10 == 0 or i == self.chunk_order[-1]:
                            order = [np.random.permutation(len(tmp_indices))]
                            order.append(
                                np.minimum(
                                    np.array(tmp_sizes),
                                    self.max_sample_size,
                                )
                            )
                            sort_idx = np.lexsort(order)[::-1]
                            indice.append(
                                np.array([tmp_indices[k] for k in sort_idx])
                            )
                            tmp_indices = []
                            tmp_sizes = []
                return indice
            else:
                order = [np.random.permutation(len(self))]
                order.append(
                    np.minimum(
                        np.array(self.sizes),
                        self.max_sample_size,
                    )
                )
                return np.lexsort(order)[::-1]
        else:
            return np.arange(len(self))

    def postprocess(self, wav, cur_sample_rate):
        if wav.dim() == 2:
            wav = wav.mean(-1)
        assert wav.dim() == 1, wav.dim()

        if cur_sample_rate != self.sample_rate:
            raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")

        if self.normalize:
            with torch.no_grad():
                wav = F.layer_norm(wav, wav.shape)
        return wav
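
# Minimal usage sketch (illustrative; the manifest/label paths are hypothetical
# and `dictionary` would be a fairseq Dictionary loaded by the training task):
#
#   dataset = HubertDataset(
#       manifest_path="/data/train.tsv",
#       sample_rate=16000,
#       label_paths=["/data/train.km"],
#       label_rates=50.0,
#       pad_list=[dictionary.pad()],
#       eos_list=[dictionary.eos()],
#       max_keep_sample_size=320000,
#       min_keep_sample_size=32000,
#   )
#   batch = dataset.collater([dataset[i] for i in range(2)])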