import torch
import numpy
import random
import os
import glob

from scipy import signal
import soundfile
from torch.utils.data import Dataset
import torch.distributed as dist
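
# Data loading utilities for speaker-embedding training: waveform cropping
# (loadWAV), MUSAN/RIR augmentation (AugmentWAV), train/test Datasets, and a
# speaker-aware batch sampler (train_dataset_sampler).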


def round_down(num, divisor):
    # Largest multiple of `divisor` not exceeding `num`, e.g. round_down(17, 5) -> 15.
    return num - (num % divisor)


def worker_init_fn(worker_id):
    # Re-seed numpy in each DataLoader worker so random augmentation choices
    # differ across workers instead of being cloned from the parent process.
    numpy.random.seed(numpy.random.get_state()[1][0] + worker_id)


def loadWAV(filename, max_frames, evalmode=True, num_eval=5):

    # Segment length in samples: 10 ms hop (160 samples) per frame plus a
    # 240-sample window tail, assuming 16 kHz audio.
    max_audio = max_frames * 160 + 240

    audio, sample_rate = soundfile.read(filename)

    audiosize = audio.shape[0]

    # Pad short files by wrapping the waveform around.
    if audiosize <= max_audio:
        shortage = max_audio - audiosize + 1
        audio = numpy.pad(audio, (0, shortage), 'wrap')
        audiosize = audio.shape[0]

    if evalmode:
        # Evaluation: num_eval segments evenly spaced across the file.
        startframe = numpy.linspace(0, audiosize - max_audio, num=num_eval)
    else:
        # Training: a single segment at a random offset.
        startframe = numpy.array([numpy.int64(random.random() * (audiosize - max_audio))])

    feats = []
    if evalmode and max_frames == 0:
        # max_frames == 0 means return the whole utterance.
        feats.append(audio)
    else:
        for asf in startframe:
            feats.append(audio[int(asf):int(asf) + max_audio])

    feat = numpy.stack(feats, axis=0).astype(float)

    return feat
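
# Usage sketch ('utt.wav' is a hypothetical 16 kHz mono file):
#   train_seg = loadWAV('utt.wav', max_frames=200, evalmode=False)   # shape (1, 200*160+240)
#   eval_segs = loadWAV('utt.wav', max_frames=300, num_eval=5)       # shape (5, 300*160+240)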


class AugmentWAV(object):

    def __init__(self, musan_path, rir_path, max_frames):

        self.max_frames = max_frames
        self.max_audio = max_frames * 160 + 240

        self.noisetypes = ['noise', 'speech', 'music']

        # SNR ranges (dB) and number of simultaneous noise files per category.
        self.noisesnr = {'noise': [0, 15], 'speech': [13, 20], 'music': [5, 15]}
        self.numnoise = {'noise': [1, 1], 'speech': [3, 8], 'music': [1, 1]}
        self.noiselist = {}

        # MUSAN layout is <category>/<subdir>/<file>.wav, so the category name
        # sits three path levels up from each file.
        augment_files = glob.glob(os.path.join(musan_path, '*/*/*.wav'))

        for file in augment_files:
            if file.split('/')[-3] not in self.noiselist:
                self.noiselist[file.split('/')[-3]] = []
            self.noiselist[file.split('/')[-3]].append(file)

        self.rir_files = glob.glob(os.path.join(rir_path, '*/*/*.wav'))

    def additive_noise(self, noisecat, audio):

        clean_db = 10 * numpy.log10(numpy.mean(audio ** 2) + 1e-4)

        numnoise = self.numnoise[noisecat]
        noiselist = random.sample(self.noiselist[noisecat], random.randint(numnoise[0], numnoise[1]))

        noises = []

        for noise in noiselist:
            noiseaudio = loadWAV(noise, self.max_frames, evalmode=False)
            noise_snr = random.uniform(self.noisesnr[noisecat][0], self.noisesnr[noisecat][1])
            noise_db = 10 * numpy.log10(numpy.mean(noiseaudio[0] ** 2) + 1e-4)
            # Scale the noise so the clean-to-noise ratio equals noise_snr dB.
            noises.append(numpy.sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10)) * noiseaudio)

        return numpy.sum(numpy.concatenate(noises, axis=0), axis=0, keepdims=True) + audio

    def reverberate(self, audio):

        rir_file = random.choice(self.rir_files)

        rir, fs = soundfile.read(rir_file)
        rir = numpy.expand_dims(rir.astype(float), 0)
        # Normalise the impulse response to unit energy.
        rir = rir / numpy.sqrt(numpy.sum(rir ** 2))

        return signal.convolve(audio, rir, mode='full')[:, :self.max_audio]
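
# Usage sketch, assuming MUSAN and simulated-RIR corpora at hypothetical paths:
#   aug = AugmentWAV(musan_path='/data/musan_split', rir_path='/data/simulated_rirs', max_frames=200)
#   noisy = aug.additive_noise('music', clean)   # clean: numpy array of shape (1, max_audio)
#   reverb = aug.reverberate(clean)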


class train_dataset_loader(Dataset):

    def __init__(self, train_list, augment, musan_path, rir_path, max_frames, train_path, **kwargs):

        self.augment_wav = AugmentWAV(musan_path=musan_path, rir_path=rir_path, max_frames=max_frames)

        self.train_list = train_list
        self.max_frames = max_frames
        self.musan_path = musan_path
        self.rir_path = rir_path
        self.augment = augment

        with open(train_list) as dataset_file:
            lines = dataset_file.readlines()

        # Map speaker IDs to integer labels, sorted for a deterministic mapping.
        dictkeys = list(set([x.split()[0] for x in lines]))
        dictkeys.sort()
        dictkeys = {key: ii for ii, key in enumerate(dictkeys)}

        self.data_list = []
        self.data_label = []

        # Each line of the train list is "<speaker_id> <relative_path>".
        for lidx, line in enumerate(lines):
            data = line.strip().split()

            speaker_label = dictkeys[data[0]]
            filename = os.path.join(train_path, data[1])

            self.data_label.append(speaker_label)
            self.data_list.append(filename)

    def __getitem__(self, indices):

        # `indices` is a list: train_dataset_sampler yields nPerSpeaker segment
        # indices per example, all belonging to the same speaker.
        feat_clean = []
        feat = []

        for index in indices:
            try:
                audio_clean = loadWAV(self.data_list[index], self.max_frames, evalmode=False)
            except Exception:
                # Surface the offending file, then re-raise rather than
                # continuing with an undefined variable.
                print(self.data_list[index])
                raise

            # Debug check: flag files that did not load as mono (the stacked
            # array should be 2-D).
            if len(audio_clean.shape) == 3:
                print(self.data_list[index])

            # Randomly pick one augmentation (or none) per segment.
            audio = audio_clean
            if self.augment:
                augtype = random.randint(0, 5)
                if augtype == 1:
                    audio = self.augment_wav.reverberate(audio_clean)
                elif augtype == 2:
                    audio = self.augment_wav.additive_noise('music', audio_clean)
                elif augtype == 3:
                    audio = self.augment_wav.additive_noise('speech', audio_clean)
                elif augtype == 4:
                    audio = self.augment_wav.additive_noise('noise', audio_clean)
                elif augtype == 5:
                    # Babble plus music: add music on top of the speech-corrupted
                    # signal rather than overwriting it.
                    audio = self.augment_wav.additive_noise('speech', audio_clean)
                    audio = self.augment_wav.additive_noise('music', audio)

            feat_clean.append(audio_clean)
            feat.append(audio)

        feat_clean = numpy.concatenate(feat_clean, axis=0)
        feat = numpy.concatenate(feat, axis=0)

        # All indices share one speaker, so the last index's label stands for the group.
        return torch.FloatTensor(feat_clean), torch.FloatTensor(feat), self.data_label[index], self.data_list[index]

    def __len__(self):
        return len(self.data_list)
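
# Note (usage): with the default collate function each dataset item contributes
# an (nPerSpeaker, T) array, so a loader batch has shape (batch, nPerSpeaker, T);
# the __main__ demo below transposes this to (nPerSpeaker, batch, T).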


class test_dataset_loader(Dataset):

    def __init__(self, test_list, test_path, eval_frames, num_eval, **kwargs):
        self.max_frames = eval_frames
        self.num_eval = num_eval
        self.test_path = test_path
        self.test_list = test_list

    def __getitem__(self, index):

        # Fixed-length evaluation segments plus the full utterance (max_frames=0).
        audio = loadWAV(os.path.join(self.test_path, self.test_list[index]), self.max_frames, evalmode=True, num_eval=self.num_eval)

        audio2 = loadWAV(os.path.join(self.test_path, self.test_list[index]), 0, evalmode=True, num_eval=self.num_eval)

        return torch.FloatTensor(audio), torch.FloatTensor(audio2), self.test_list[index]

    def __len__(self):
        return len(self.test_list)
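
# Because the second output is the full (variable-length) utterance, this
# dataset is typically wrapped in a DataLoader with batch_size=1.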


class train_dataset_sampler(torch.utils.data.Sampler):

    def __init__(self, data_source, nPerSpeaker, max_seg_per_spk, batch_size, distributed, seed, **kwargs):

        self.data_label = data_source.data_label
        self.nPerSpeaker = nPerSpeaker
        self.max_seg_per_spk = max_seg_per_spk
        self.batch_size = batch_size
        self.epoch = 0
        self.seed = seed
        self.distributed = distributed

    def __iter__(self):

        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        indices = torch.randperm(len(self.data_label), generator=g).tolist()

        # Group segment indices by speaker, in shuffled order.
        data_dict = {}
        for index in indices:
            speaker_label = self.data_label[index]
            if speaker_label not in data_dict:
                data_dict[speaker_label] = []
            data_dict[speaker_label].append(index)

        dictkeys = list(data_dict.keys())
        dictkeys.sort()

        lol = lambda lst, sz: [lst[i:i + sz] for i in range(0, len(lst), sz)]

        flattened_list = []
        flattened_label = []

        # Chunk each speaker's segments into groups of nPerSpeaker, capped at
        # max_seg_per_spk segments per speaker.
        for findex, key in enumerate(dictkeys):
            data = data_dict[key]
            numSeg = round_down(min(len(data), self.max_seg_per_spk), self.nPerSpeaker)

            rp = lol(numpy.arange(numSeg), self.nPerSpeaker)
            flattened_label.extend([findex] * (len(rp)))
            for indices in rp:
                flattened_list.append([data[i] for i in indices])

        # Shuffle the groups while ensuring no speaker appears twice in the
        # same batch.
        mixid = torch.randperm(len(flattened_label), generator=g).tolist()
        mixlabel = []
        mixmap = []

        for ii in mixid:
            startbatch = round_down(len(mixlabel), self.batch_size)
            if flattened_label[ii] not in mixlabel[startbatch:]:
                mixlabel.append(flattened_label[ii])
                mixmap.append(ii)

        mixed_list = [flattened_list[i] for i in mixmap]

        if self.distributed:
            # Give each rank an equal, batch-aligned slice of the epoch.
            total_size = round_down(len(mixed_list), self.batch_size * dist.get_world_size())
            start_index = int(dist.get_rank() / dist.get_world_size() * total_size)
            end_index = int((dist.get_rank() + 1) / dist.get_world_size() * total_size)
            self.num_samples = end_index - start_index
            return iter(mixed_list[start_index:end_index])
        else:
            total_size = round_down(len(mixed_list), self.batch_size)
            self.num_samples = total_size
            return iter(mixed_list[:total_size])

    def __len__(self) -> int:
        # Only valid after __iter__ has been called for the current epoch.
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch
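
# Epoch hook: __iter__ seeds its generator with (seed + epoch), so callers
# should invoke set_epoch once per epoch to reshuffle, e.g.
#   for epoch in range(num_epochs):
#       train_sampler.set_epoch(epoch)
#       for batch in train_loader:
#           ...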


if __name__ == '__main__':
    train_dataset = train_dataset_loader(train_list='/mnt/proj3/open-24-5/pengjy_new/WavLM_Adapter/CNCeleb_lst/CNCeleb_trainlist_200spk.txt',
                                         augment=False,
                                         musan_path='/mnt/proj3/open-24-5/pengjy_new/musan_split/',
                                         rir_path='/mnt/proj3/open-24-5/plchot/data_augment/16kHz/simulated_rirs/',
                                         max_frames=300,
                                         train_path='/mnt/proj3/open-24-5/pengjy_new/Data/CN-Celeb_flac/data',
                                         )

    train_sampler = train_dataset_sampler(train_dataset, nPerSpeaker=1, max_seg_per_spk=500, batch_size=100, distributed=False, seed=120)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=100,
        num_workers=10,
        sampler=train_sampler,
        pin_memory=True,
        drop_last=True,
        worker_init_fn=worker_init_fn,  # distinct numpy seed per worker
    )

    # The dataset returns (clean, augmented, label, path); unpack all four.
    for data_clean, data, data_label, data_path in train_loader:
        print(data.shape)
        # (batch, nPerSpeaker, T) -> (nPerSpeaker, batch, T)
        data = data.transpose(1, 0)
        print(data.shape)
        quit()