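"""RingFormer-style iSTFT vocoder training, ported from the original rank-based DDP
script to HuggingFace Accelerate.

The generator predicts (magnitude, phase) frames that TorchSTFT inverts back to a
waveform; it is trained adversarially against a multi-period discriminator (MPD) and
a multi-scale sub-band CQT discriminator (MSD), with TPRLS regularization,
feature-matching, and an L1 mel-spectrogram loss.
"""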
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import itertools
import os
import time
import argparse
import json

import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from accelerate import Accelerator, DistributedDataParallelKwargs

from env import AttrDict, build_env
from meldataset import MelDataset, mel_spectrogram, get_dataset_filelist
from models import Generator, MultiPeriodDiscriminator, feature_loss, generator_loss, \
    discriminator_loss, discriminator_TPRLS_loss, generator_TPRLS_loss, MultiScaleSubbandCQTDiscriminator
from utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint
from stft import TorchSTFT
from Utils.JDC.model import JDCNet

torch.backends.cudnn.benchmark = True


def train(accelerator, a, h):
    if torch.cuda.is_available():
        torch.cuda.manual_seed(h.seed)
    device = accelerator.device

    # Pretrained F0 (pitch) estimator used to condition the generator.
    F0_model = JDCNet(num_class=1, seq_len=192)
    params = torch.load(h.F0_path, map_location='cpu')['model']
    F0_model.load_state_dict(params)

    generator = Generator(h, F0_model).to(device)
    mpd = MultiPeriodDiscriminator().to(device)
    msd = MultiScaleSubbandCQTDiscriminator().to(device)
    # iSTFT head: the generator predicts (magnitude, phase) frames and
    # TorchSTFT.inverse reconstructs the time-domain waveform from them.
    stft = TorchSTFT(filter_length=h.gen_istft_n_fft, hop_length=h.gen_istft_hop_size,
                     win_length=h.gen_istft_n_fft).to(device)

    with accelerator.main_process_first():
        accelerator.print(generator)
        os.makedirs(a.checkpoint_path, exist_ok=True)
        accelerator.print("checkpoints directory : ", a.checkpoint_path)

    cp_g = cp_do = None
    if os.path.isdir(a.checkpoint_path):
        cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
        cp_do = scan_checkpoint(a.checkpoint_path, 'do_')

    steps = 0
    if cp_g is None or cp_do is None:
        state_dict_do = None
        last_epoch = -1
    else:
        state_dict_g = load_checkpoint(cp_g, device)
        state_dict_do = load_checkpoint(cp_do, device)
        generator.load_state_dict(state_dict_g['generator'])
        mpd.load_state_dict(state_dict_do['mpd'])
        msd.load_state_dict(state_dict_do['msd'])
        steps = state_dict_do['steps'] + 1
        last_epoch = state_dict_do['epoch']
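    # Checkpoint layout (restored above, written in the training loop below):
    #   g_XXXXXXXX  -> {'generator': <state_dict>}
    #   do_XXXXXXXX -> {'mpd': ..., 'msd': ..., 'optim_g': ..., 'optim_d': ...,
    #                   'steps': <int>, 'epoch': <int>}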
    generator = accelerator.prepare(generator)
    mpd = accelerator.prepare(mpd)
    msd = accelerator.prepare(msd)

    optim_g = accelerator.prepare(
        torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2]))
    optim_d = accelerator.prepare(
        torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()),
                          h.learning_rate, betas=[h.adam_b1, h.adam_b2]))

    if state_dict_do is not None:
        optim_g.load_state_dict(state_dict_do['optim_g'])
        optim_d.load_state_dict(state_dict_do['optim_d'])

    scheduler_g = accelerator.prepare(
        torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch))
    scheduler_d = accelerator.prepare(
        torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch))

    training_filelist, validation_filelist = get_dataset_filelist(a)

    trainset = MelDataset(training_filelist, h.segment_size, h.n_fft, h.num_mels,
                          h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0,
                          shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss,
                          device=device, fine_tuning=a.fine_tuning, base_mels_path=a.input_mels_dir)

    accelerator.print("batch size per process:", h.batch_size)
    train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
                              batch_size=h.batch_size, pin_memory=True, drop_last=True)
    train_loader = accelerator.prepare(train_loader)  # Accelerate shards batches across processes

    validset = MelDataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels,
                          h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False,
                          n_cache_reuse=0, fmax_loss=h.fmax_for_loss, device=device,
                          fine_tuning=a.fine_tuning, base_mels_path=a.input_mels_dir)
    # Deliberately not prepared with Accelerate: validation runs on the main process only.
    validation_loader = DataLoader(validset, num_workers=1, shuffle=False, sampler=None,
                                   batch_size=1, pin_memory=True, drop_last=True)

    sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs')) if accelerator.is_main_process else None

    generator.train()
    mpd.train()
    msd.train()

    for epoch in range(max(0, last_epoch), a.training_epochs):
        start = time.time()
        accelerator.print("Epoch: {}".format(epoch + 1))

        for i, batch in enumerate(train_loader):
            start_b = time.time()
            x, y, _, y_mel = batch
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            y_mel = y_mel.to(device, non_blocking=True)
            y = y.unsqueeze(1)

            spec, phase = generator(x)
            y_g_hat = stft.inverse(spec, phase)
            y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
                                          h.hop_size, h.win_size, h.fmin, h.fmax_for_loss)

            optim_d.zero_grad()

            # MPD
            y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
            loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
            loss_disc_f += discriminator_TPRLS_loss(y_df_hat_r, y_df_hat_g)

            # MSD
            y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
            loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
            loss_disc_s += discriminator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)

            loss_disc_all = loss_disc_s + loss_disc_f

            accelerator.backward(loss_disc_all)
            # accelerator.clip_grad_norm_(mpd.parameters(), max_norm=10.0)
            # accelerator.clip_grad_norm_(msd.parameters(), max_norm=10.0)
            optim_d.step()

            # Generator
            optim_g.zero_grad()

            # L1 Mel-Spectrogram Loss
            loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45

            y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
            y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
            loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
            loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
            loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
            loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
            loss_gen_f += generator_TPRLS_loss(y_df_hat_r, y_df_hat_g)
            loss_gen_s += generator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)
            loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel

            accelerator.backward(loss_gen_all)
            # accelerator.clip_grad_norm_(generator.parameters(), max_norm=10.0)
            optim_g.step()
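            # Loss recap for this step (all computed above):
            #   discriminators: LS-GAN loss + TPRLS term, for both MPD and MSD;
            #   generator: adversarial + TPRLS + feature matching + 45 * L1 mel loss.
            # accelerator.backward() stands in for loss.backward() so Accelerate can
            # handle gradient synchronization (and loss scaling under mixed precision).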
            if accelerator.is_main_process:
                # STDOUT logging
                if steps % a.stdout_interval == 0:
                    with torch.no_grad():
                        mel_error = F.l1_loss(y_mel, y_g_hat_mel).item()
                    accelerator.print(
                        'Steps : {:d}, Gen Loss Total : {:4.3f}, Disc Loss Total : {:4.3f}, '
                        'Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'.format(
                            steps, loss_gen_all, loss_disc_all, mel_error, time.time() - start_b))

                # Checkpointing: unwrap the Accelerate/DDP wrappers so the saved
                # state dicts load back into bare modules.
                if steps % a.checkpoint_interval == 0 and steps != 0:
                    unwrapped_generator = accelerator.unwrap_model(generator)
                    checkpoint_path = "{}/g_{:08d}".format(a.checkpoint_path, steps)
                    save_checkpoint(checkpoint_path,
                                    {'generator': unwrapped_generator.state_dict()})

                    unwrapped_mpd = accelerator.unwrap_model(mpd)
                    unwrapped_msd = accelerator.unwrap_model(msd)
                    checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
                    save_checkpoint(checkpoint_path,
                                    {'mpd': unwrapped_mpd.state_dict(),
                                     'msd': unwrapped_msd.state_dict(),
                                     'optim_g': optim_g.state_dict(),
                                     'optim_d': optim_d.state_dict(),
                                     'steps': steps,
                                     'epoch': epoch})

                # Tensorboard summary logging
                if steps % a.summary_interval == 0:
                    sw.add_scalar("training/gen_loss_total", loss_gen_all, steps)
                    sw.add_scalar("training/mel_spec_error", mel_error, steps)

                # Validation
                if steps % a.validation_interval == 0:  # and steps != 0:
                    accelerator.print('evaluating...')
                    generator.eval()
                    torch.cuda.empty_cache()
                    val_err_tot = 0
                    with torch.no_grad():
                        for j, batch in enumerate(validation_loader):
                            x, y, _, y_mel = batch
                            spec, phase = generator(x.to(device))
                            y_g_hat = stft.inverse(spec, phase)
                            y_mel = y_mel.to(device, non_blocking=True)
                            y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels,
                                                          h.sampling_rate, h.hop_size, h.win_size,
                                                          h.fmin, h.fmax_for_loss)
                            val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()

                            if j <= 4:
                                if steps == 0:
                                    sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate)
                                    sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps)

                                sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate)
                                y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels,
                                                             h.sampling_rate, h.hop_size, h.win_size,
                                                             h.fmin, h.fmax)
                                sw.add_figure('generated/y_hat_spec_{}'.format(j),
                                              plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps)

                        val_err = val_err_tot / (j + 1)
                        sw.add_scalar("validation/mel_spec_error", val_err, steps)

                    generator.train()

            steps += 1

        scheduler_g.step()
        scheduler_d.step()

        accelerator.print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))
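
# Illustrative sketch of the hyperparameter fields this script reads from --config.
# The key names match what `h.<name>` accesses above; the values are placeholders,
# not the project's actual config_v1.json:
#
# {
#   "seed": 1234, "batch_size": 16, "num_workers": 4,
#   "learning_rate": 0.0002, "adam_b1": 0.8, "adam_b2": 0.99, "lr_decay": 0.999,
#   "segment_size": 8192, "n_fft": 1024, "num_mels": 80, "hop_size": 256, "win_size": 1024,
#   "sampling_rate": 22050, "fmin": 0, "fmax": 8000, "fmax_for_loss": null,
#   "gen_istft_n_fft": 16, "gen_istft_hop_size": 4,
#   "F0_path": "<path to a JDCNet checkpoint whose dict has a 'model' key>"
# }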

def main():
    # find_unused_parameters is needed because not every generator parameter
    # receives gradients on every step under DDP.
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

    parser = argparse.ArgumentParser()
    parser.add_argument('--group_name', default=None)
    parser.add_argument('--input_wavs_dir', default='')
    parser.add_argument('--input_mels_dir', default='ft_dataset')
    parser.add_argument('--input_training_file', default='/home/ubuntu/respair/audio_files_over_2sec_SUB.txt')
    parser.add_argument('--input_validation_file', default='/home/ubuntu/RINGFORMER/LJSpeech-1.1/validation.txt')
    parser.add_argument('--checkpoint_path', default='cp_ringformer_44.1khz_NUM2')
    parser.add_argument('--config', default='config_v1.json')
    parser.add_argument('--training_epochs', default=3100, type=int)
    parser.add_argument('--stdout_interval', default=10, type=int)
    parser.add_argument('--checkpoint_interval', default=1000, type=int)
    parser.add_argument('--summary_interval', default=100, type=int)
    parser.add_argument('--validation_interval', default=1000, type=int)
    # argparse's type=bool treats any non-empty string (including "False") as True,
    # so expose fine-tuning as a flag instead.
    parser.add_argument('--fine_tuning', action='store_true')

    a = parser.parse_args()

    accelerator = Accelerator(project_dir=a.checkpoint_path, split_batches=False,
                              kwargs_handlers=[ddp_kwargs])
    accelerator.print('Initializing Training Process..')

    with open(a.config) as f:
        data = f.read()

    json_config = json.loads(data)
    h = AttrDict(json_config)
    build_env(a.config, 'config.json', a.checkpoint_path)

    torch.manual_seed(h.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(h.seed)
        h.num_gpus = torch.cuda.device_count()
    else:
        h.num_gpus = 0  # keep the attribute defined for CPU-only runs

    # Accelerate owns process launch and device placement, so train() receives
    # the accelerator instead of a rank.
    train(accelerator, a, h)


if __name__ == '__main__':
    main()
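
# Typical invocation (Accelerate's launcher handles single- vs multi-GPU; "train.py"
# stands for whatever this file is named in the repository):
#   accelerate config                      # one-time interactive environment setup
#   accelerate launch train.py --config config_v1.json --checkpoint_path cp_ringformer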