|
|
|
|
|
|
|
import sys, time, os, argparse, socket |
|
import yaml |
|
import numpy as np
|
|
import torch |
|
import glob |
|
import zipfile |
|
import warnings |
|
import datetime |
|
from tuneThreshold import * |
|
from SpeakerNet import * |
|
from DatasetLoader import * |
|
import torch.distributed as dist |
|
import torch.multiprocessing as mp |
|
from scipy.stats import norm |
|
from sklearn.mixture import GaussianMixture |
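
# Boolean command-line flags: argparse's type=bool treats any non-empty string
# as True, so truthy strings are parsed explicitly with this small helper
def str2bool(v):
    if isinstance(v, bool):
        return v
    return str(v).lower() in ('true', 't', 'yes', '1')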
|
|
|
|
|
|
|
|
|
|
|
|
|
parser = argparse.ArgumentParser(description="SpeakerNet")
|
|
|
parser.add_argument('--config', type=str, default=None, help='Config YAML file')
|
|
|
|
|
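## Data loader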
parser.add_argument('--max_frames', type=int, default=200, help='Input length to the network for training')

parser.add_argument('--eval_frames', type=int, default=300, help='Input length to the network for testing; 0 uses the whole files')

parser.add_argument('--batch_size', type=int, default=400, help='Batch size, number of speakers per batch')

parser.add_argument('--max_seg_per_spk', type=int, default=500, help='Maximum number of utterances per speaker per epoch')

parser.add_argument('--nDataLoaderThread', type=int, default=10, help='Number of loader threads')

parser.add_argument('--augment', type=str2bool, default=True, help='Augment input')

parser.add_argument('--seed', type=int, default=20211202, help='Seed for the random number generator')
|
|
|
|
|
|
|
|
|
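## Training details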
parser.add_argument('--test_interval', type=int, default=1, help='Test and save every [test_interval] epochs')

parser.add_argument('--max_epoch', type=int, default=50, help='Maximum number of epochs')

parser.add_argument('--trainfunc', type=str, default="aamsoftmax", help='Loss function')
|
|
|
|
|
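## Optimizer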
parser.add_argument('--optimizer', type=str, default="adamw", help='Optimizer: sgd, adam or adamw')

parser.add_argument('--scheduler', type=str, default="steplr", help='Learning rate scheduler')

parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')

parser.add_argument("--lr_decay", type=float, default=0.9, help='Learning rate decay every [test_interval] epochs')
|
|
|
|
|
|
|
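## Pre-trained model fine-tuning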
parser.add_argument('--pretrained_model_path', type=str, default="None", help='Absolute path to the pre-trained model')

parser.add_argument('--weight_finetuning_reg', type=float, default=0.001, help='L2 regularization towards the initial pre-trained model')

parser.add_argument('--LLRD_factor', type=float, default=1.0, help='Layer-wise Learning Rate Decay (LLRD) factor')

parser.add_argument('--LR_Transformer', type=float, default=2e-5, help='Learning rate of pre-trained model')

parser.add_argument('--LR_MHFA', type=float, default=5e-3, help='Learning rate of back-end attentive pooling model')
|
|
|
|
|
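## Loss function parameters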
parser.add_argument("--hard_prob", type=float, default=0.5, help='Hard negative mining probability, otherwise random, only for some loss functions')

parser.add_argument("--hard_rank", type=int, default=10, help='Hard negative mining rank in the batch, only for some loss functions')

parser.add_argument('--margin', type=float, default=0.2, help='Loss margin, only for some loss functions')

parser.add_argument('--scale', type=float, default=30, help='Loss scale, only for some loss functions')

parser.add_argument('--nPerSpeaker', type=int, default=1, help='Number of utterances per speaker per batch, only for metric learning based losses')

parser.add_argument('--nClasses', type=int, default=5994, help='Number of speakers in the softmax layer, only for softmax-based losses')
|
|
|
|
|
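## Detection cost function (DCF) parameters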
parser.add_argument('--dcf_p_target', type=float, default=0.05, help='A priori probability of the specified target speaker')

parser.add_argument('--dcf_c_miss', type=float, default=1, help='Cost of a missed detection')

parser.add_argument('--dcf_c_fa', type=float, default=1, help='Cost of a spurious detection')
|
|
|
|
|
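## Load and save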
parser.add_argument('--initial_model', type=str, default="", help='Initial model weights')

parser.add_argument('--save_path', type=str, default="exps/exp1", help='Path for model and logs')
|
|
|
|
|
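## Training and test data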
parser.add_argument('--train_list', type=str, default="data/train_list.txt", help='Train list')

parser.add_argument('--test_list', type=str, default="data/test_list.txt", help='Evaluation list')

parser.add_argument('--train_path', type=str, default="data/voxceleb2", help='Absolute path to the train set')

parser.add_argument('--test_path', type=str, default="data/voxceleb1", help='Absolute path to the test set')

parser.add_argument('--musan_path', type=str, default="data/musan_split", help='Absolute path to the MUSAN noise corpus')

parser.add_argument('--rir_path', type=str, default="data/simulated_rirs", help='Absolute path to the simulated RIRs')
|
|
|
|
|
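## Model definition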
parser.add_argument('--n_mels', type=int, default=80, help='Number of mel filterbanks')

parser.add_argument('--log_input', type=str2bool, default=False, help='Log input features')

parser.add_argument('--model', type=str, default="", help='Name of model definition')

parser.add_argument('--encoder_type', type=str, default="SAP", help='Type of encoder')

parser.add_argument('--nOut', type=int, default=192, help='Embedding size in the last FC layer')
|
|
|
|
|
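## For test only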
parser.add_argument('--eval', dest='eval', action='store_true', help='Eval only') |
|
|
|
|
|
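## Distributed and mixed precision training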
parser.add_argument('--port', type=str, default="7888", help='Port for distributed training, input as text')
|
parser.add_argument('--distributed', dest='distributed', action='store_true', help='Enable distributed training') |
|
parser.add_argument('--mixedprec', dest='mixedprec', action='store_true', help='Enable mixed precision training') |
|
|
|
args = parser.parse_args()
|
|
|
|
|
def find_option_type(key, parser):
    # Return the declared type of a command-line option so YAML values can be cast to it
    for opt in parser._get_optional_actions():
        if ('--' + key) in opt.option_strings:
            return opt.type
    raise ValueError('Unknown option: --{}'.format(key))
|
|
|
if args.config is not None: |
|
with open(args.config, "r") as f: |
|
yml_config = yaml.load(f, Loader=yaml.FullLoader) |
|
for k, v in yml_config.items(): |
|
if k in args.__dict__: |
|
typ = find_option_type(k, parser) |
|
args.__dict__[k] = typ(v) |
|
else: |
|
sys.stderr.write("Ignored unknown parameter {} in yaml.\n".format(k)) |
|
|
|
|
|
|
|
try: |
|
import nsml |
|
from nsml import HAS_DATASET, DATASET_PATH, PARALLEL_WORLD, PARALLEL_PORTS, MY_RANK |
|
from nsml import NSML_NFS_OUTPUT, SESSION_NAME |
|
except ImportError:
    pass
|
|
|
warnings.simplefilter("ignore") |
|
|
|
|
|
|
|
|
|
|
|
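# Loss-Gated Learning (LGL) threshold estimation: fit a two-component Gaussian
# mixture to the per-sample log-losses of the previous epoch and take the
# intersection point between the two modes (interpreted as clean and noisy
# samples) as the new loss threshold. Returns None when no unique
# intersection is found between the two component peaks.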
def LGL_threshold_update_gmm(loss_vals_path): |
|
with open(loss_vals_path, 'r') as f: |
|
lines = [line.strip().split() for line in f.readlines()] |
|
|
|
|
|
losses = [] |
|
errs = 0 |
|
for line in lines: |
|
try: |
|
losses.append(float(line[0])) |
|
        except ValueError:
            errs += 1
|
if errs > 0: |
|
print('Could not read %d lines' % errs) |
|
|
|
    # The GMM is fitted in log-loss space
    log_losses = np.log(losses)
|
|
|
gmm = GaussianMixture(n_components=2, random_state=0, covariance_type='full', tol=0.00001, max_iter=1000) |
|
gmm.fit(log_losses.reshape(-1, 1)) |
|
|
|
mean1 = gmm.means_[0, 0] |
|
covar1 = gmm.covariances_[0, 0] |
|
weight1 = gmm.weights_[0] |
|
x = np.linspace(min(log_losses), max(log_losses), 1000) |
|
g1 = weight1 * norm.pdf(x, mean1, np.sqrt(covar1)) |
|
|
|
mean2 = gmm.means_[1, 0] |
|
covar2 = gmm.covariances_[1, 0] |
|
weight2 = gmm.weights_[1] |
|
g2 = weight2 * norm.pdf(x, mean2, np.sqrt(covar2)) |
|
|
|
intersection = np.argwhere(np.diff(np.sign(g1 - g2))).flatten() |
|
|
|
max1 = x[np.argmax(g1)] |
|
max2 = x[np.argmax(g2)] |
|
    # Keep only intersections lying between the two component peaks
    good_intersection = x[intersection][(x[intersection] > min(max1, max2)) & (x[intersection] < max(max1, max2))]

    # There should be exactly one such intersection; otherwise the threshold
    # cannot be determined for this epoch and the caller keeps the previous one
    if len(good_intersection) != 1:
        return None

    # Note: the returned threshold lives in log-loss space
    return good_intersection[0]
|
|
|
# idr_torch exposes the process rank and world size from the SLURM environment
# (helper module provided on the IDRIS Jean Zay cluster)
import idr_torch
|
|
|
def main_worker(gpu, ngpus_per_node, args): |
|
|
|
    # The process rank and world size come from idr_torch (SLURM environment),
    # overriding the values passed in by the caller
    args.gpu = idr_torch.rank

    ngpus_per_node = idr_torch.size
|
|
|
|
|
    s = SpeakerNet(**vars(args))
|
|
|
if args.distributed: |
|
|
|
|
|
|
|
|
|
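        # NCCL rendezvous uses the default env:// method; MASTER_ADDR and
        # MASTER_PORT are expected to be set by the launch environment
        # (e.g. the SLURM script). The --port argument is parsed but not used here.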
dist.init_process_group(backend='nccl', world_size=ngpus_per_node, rank=args.gpu) |
|
|
|
torch.cuda.set_device(args.gpu) |
|
s.cuda(args.gpu) |
|
|
|
s = torch.nn.parallel.DistributedDataParallel(s, device_ids=[args.gpu]) |
|
|
|
print('Loaded the model on GPU {:d}'.format(args.gpu)) |
|
|
|
else: |
|
s = WrappedModel(s).cuda(args.gpu) |
|
|
|
it = 1 |
|
    eers = [100]
|
|
|
if args.gpu == 0: |
|
|
|
        scorefile = open(args.result_save_path + "/scores.txt", "a+")
|
|
|
|
|
train_dataset = train_dataset_loader(**vars(args)) |
|
|
|
train_sampler = train_dataset_sampler(train_dataset, **vars(args)) |
|
|
|
train_loader = torch.utils.data.DataLoader( |
|
train_dataset, |
|
batch_size=args.batch_size, |
|
num_workers=args.nDataLoaderThread, |
|
sampler=train_sampler, |
|
pin_memory=True, |
|
worker_init_fn=worker_init_fn, |
|
drop_last=True, |
|
) |
|
|
|
|
|
trainer = ModelTrainer(s, **vars(args)) |
|
|
|
|
|
    modelfiles = glob.glob('%s/model0*.model' % args.model_save_path)
|
modelfiles.sort() |
|
|
|
    if args.initial_model != "":
        trainer.loadParameters(args.initial_model)
        print("Model {} loaded!".format(args.initial_model))
    elif len(modelfiles) >= 1:
        trainer.loadParameters(modelfiles[-1])
        print("Model {} loaded from previous state!".format(modelfiles[-1]))
        it = int(os.path.splitext(os.path.basename(modelfiles[-1]))[0][5:]) + 1
|
|
|
    # Fast-forward the learning-rate scheduler to the resumed epoch
    for ii in range(1, it):
        trainer.__scheduler__.step()
|
|
|
|
|
pytorch_total_params = sum(p.numel() for p in s.module.__S__.parameters()) |
|
|
|
print('Total parameters: ',pytorch_total_params) |
|
|
|
    if args.eval:
|
|
|
|
|
print('Test list',args.test_list) |
|
|
|
        sc, lab, _, sc1, sc2 = trainer.evaluateFromList(**vars(args))
|
|
|
if args.gpu == 0: |
|
|
|
            result = tuneThresholdfromScore(sc, lab, [1, 0.1])

            result_s1 = tuneThresholdfromScore(sc1, lab, [1, 0.1])

            result_s2 = tuneThresholdfromScore(sc2, lab, [1, 0.1])
|
|
|
|
|
|
|
fnrs, fprs, thresholds = ComputeErrorRates(sc, lab) |
|
mindcf, threshold = ComputeMinDcf(fnrs, fprs, thresholds, args.dcf_p_target, args.dcf_c_miss, args.dcf_c_fa) |
|
|
|
            print('\n', time.strftime("%Y-%m-%d %H:%M:%S"),
                  "VEER {:2.4f}".format(result[1]),
                  "VEER_s1 {:2.4f}".format(result_s1[1]),
                  "VEER_s2 {:2.4f}".format(result_s2[1]),
                  "MinDCF {:2.5f}".format(mindcf))
|
|
|
        if ("nsml" in sys.modules) and args.gpu == 0:
            training_report = {}
            training_report["summary"] = True
            training_report["epoch"] = it
            training_report["step"] = it
            training_report["val_eer"] = result[1]
            training_report["val_dcf"] = mindcf

            nsml.report(**training_report)
|
|
|
return |
|
|
|
|
|
if args.gpu == 0: |
|
pyfiles = glob.glob('./*.py') |
|
strtime = datetime.datetime.now().strftime("%Y%m%d%H%M%S") |
|
|
|
        zipf = zipfile.ZipFile(args.result_save_path + '/run%s.zip' % strtime, 'w', zipfile.ZIP_DEFLATED)
|
for file in pyfiles: |
|
zipf.write(file) |
|
zipf.close() |
|
|
|
        with open(args.result_save_path + '/run%s.cmd' % strtime, 'w') as f:
            f.write('%s' % args)
|
|
|
|
|
|
|
for it in range(it,args.max_epoch+1): |
|
|
|
train_sampler.set_epoch(it) |
|
|
|
clr = [x['lr'] for x in trainer.__optimizer__.param_groups] |
|
|
|
        # Per-sample loss values for this run are written under exp/<run name>/loss_vals
        loss_vals_dir = 'exp/' + args.save_path.split('/')[-1] + '/loss_vals'

        os.makedirs(loss_vals_dir, exist_ok=True)

        loss_vals_path = os.path.join(loss_vals_dir, 'epoch_%d.txt' % it)
|
|
|
        if it >= 5:
            prev_loss_vals_path = os.path.join(loss_vals_dir, 'epoch_%d.txt' % (it - 1))

            LGL_threshold = LGL_threshold_update_gmm(prev_loss_vals_path)

            if LGL_threshold is not None:
                if args.gpu == 0:
                    print('Updated LGL threshold to %f' % LGL_threshold)
                trainer.update_lgl_threshold(LGL_threshold)
            elif args.gpu == 0:
                print('Wrong number of intersections, keeping the previous LGL threshold')
|
|
|
|
|
loss, traineer = trainer.train_network(train_loader, loss_vals_path, it, verbose=(args.gpu == 0)) |
|
|
|
        if args.distributed:
            # Wait for every rank to finish writing its per-rank loss file, then
            # let rank 0 merge them; the second barrier stops other ranks from
            # reading the merged file before it is complete
            dist.barrier()
            if args.gpu == 0:
                with open(loss_vals_path, 'w') as final_file:
                    for r in range(dist.get_world_size()):
                        part_file_path = f"{os.path.splitext(loss_vals_path)[0]}_rank{r}.txt"
                        with open(part_file_path, 'r') as part_file:
                            final_file.write(part_file.read())
            dist.barrier()
|
|
|
if args.gpu == 0: |
|
            print('\n', time.strftime("%Y-%m-%d %H:%M:%S"), "Epoch {:d}, TEER/TAcc {:2.2f}, TLOSS {:f}, LR {:f}".format(it, traineer.item(), loss.item(), max(clr)))

            scorefile.write("Epoch {:d}, TEER/TAcc {:2.2f}, TLOSS {:f}, LR {:f}\n".format(it, traineer.item(), loss.item(), max(clr)))
|
|
|
if it % args.test_interval == 0: |
|
|
|
|
|
|
|
if args.gpu == 0: |
|
                trainer.saveParameters(args.model_save_path + "/model%09d.model" % it)
|
|
|
scorefile.flush() |
|
|
|
            if ("nsml" in sys.modules) and args.gpu == 0:
                training_report = {}
                training_report["summary"] = True
                training_report["epoch"] = it
                training_report["step"] = it
                training_report["train_loss"] = loss.item()
                training_report["min_eer"] = min(eers)

                nsml.report(**training_report)
|
|
|
if args.gpu == 0: |
|
        scorefile.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
def main(): |
|
|
|
|
|
|
|
|
|
|
|
|
|
if ("nsml" in sys.modules) and not args.eval: |
|
args.save_path = os.path.join(args.save_path,SESSION_NAME.replace('/','_')) |
|
|
|
args.model_save_path = args.save_path+"/model" |
|
args.result_save_path = args.save_path+"/result" |
|
args.feat_save_path = "" |
|
|
|
os.makedirs(args.model_save_path, exist_ok=True) |
|
os.makedirs(args.result_save_path, exist_ok=True) |
|
|
|
    print('Python Version:', sys.version)

    print('PyTorch Version:', torch.__version__)

    print('Number of GPUs:', torch.cuda.device_count())

    print('Save path:', args.save_path)
|
|
|
    if args.distributed:
        # Rank and world size are supplied by idr_torch inside main_worker,
        # so no process spawning is needed here
        main_worker(None, None, args)
    else:
        main_worker(0, None, args)
|
|
|
|
|
if __name__ == '__main__': |
|
main() |
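
# Example invocation (hypothetical file and config names; adjust to your setup):
#   python trainSpeakerNet.py --config configs/train.yaml --distributed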
|
|