import argparse
import os
import sys
import tempfile
import traceback

import librosa.display
import numpy as np
import torch
import torchaudio

from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list
from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts


def clear_gpu_cache():
    """Empty PyTorch's CUDA cache; does nothing when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()


# Module-wide handle to the currently loaded XTTS model (None until load_model succeeds
# in creating one).
XTTS_MODEL = None


def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
    """Create the global XTTS model from a config, checkpoint and vocab file.

    Args:
        xtts_checkpoint: Path to the model checkpoint.
        xtts_config: Path to the JSON config read by ``XttsConfig.load_json``.
        xtts_vocab: Path to the vocab file.

    Returns:
        A status message: an instruction to fill in the missing paths when any
        argument is empty, otherwise ``"Model Loaded!"``.
    """
    global XTTS_MODEL
    clear_gpu_cache()

    # All three paths are mandatory; bail out with a user-facing hint otherwise.
    if not all((xtts_checkpoint, xtts_config, xtts_vocab)):
        return "You need to run the previous steps or manually set the `XTTS checkpoint path`, `XTTS config path`, and `XTTS vocab path` fields !!"

    model_config = XttsConfig()
    model_config.load_json(xtts_config)

    # The global is rebound before the weights are loaded, as in the original flow.
    XTTS_MODEL = Xtts.init_from_config(model_config)
    print("Loading XTTS model! ")
    XTTS_MODEL.load_checkpoint(
        model_config,
        checkpoint_path=xtts_checkpoint,
        vocab_path=xtts_vocab,
        use_deepspeed=False,
    )

    # Move the model to the GPU only when one is present.
    if torch.cuda.is_available():
        XTTS_MODEL.cuda()

    print("Model Loaded!")
    return "Model Loaded!"
def run_tts(lang, tts_text, speaker_audio_file):
    """Synthesize ``tts_text`` with the loaded XTTS model, conditioned on a reference audio.

    Args:
        lang: Language code forwarded to ``XTTS_MODEL.inference``.
        tts_text: Text to synthesize.
        speaker_audio_file: Path to the reference audio used to compute the
            conditioning latents.

    Returns:
        A ``(status message, generated wav path, reference audio path)`` tuple.
        Both paths are ``None`` when no model is loaded or no reference audio
        was supplied.
    """
    if XTTS_MODEL is None or not speaker_audio_file:
        return "You need to run the previous step to load the model !!", None, None

    gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
        audio_path=speaker_audio_file,
        gpt_cond_len=XTTS_MODEL.config.gpt_cond_len,
        max_ref_length=XTTS_MODEL.config.max_ref_len,
        sound_norm_refs=XTTS_MODEL.config.sound_norm_refs,
    )
    out = XTTS_MODEL.inference(
        text=tts_text,
        language=lang,
        gpt_cond_latent=gpt_cond_latent,
        speaker_embedding=speaker_embedding,
        temperature=XTTS_MODEL.config.temperature,  # Add custom parameters here
        length_penalty=XTTS_MODEL.config.length_penalty,
        repetition_penalty=XTTS_MODEL.config.repetition_penalty,
        top_k=XTTS_MODEL.config.top_k,
        top_p=XTTS_MODEL.config.top_p,
    )

    # Reserve a unique file name, then close the handle before torchaudio
    # writes to it: writing to a path that is still open fails on some platforms.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
        out_path = fp.name
    out["wav"] = torch.tensor(out["wav"]).unsqueeze(0)
    torchaudio.save(out_path, out["wav"], 24000)  # 24000 is the sample rate argument

    return "Speech generated !", out_path, speaker_audio_file


class Logger:
    """File-like object that duplicates every write to the terminal and to a log file."""

    def __init__(self, filename="log.out"):
        self.log_file = filename
        # Capture whatever stdout is at construction time so output still reaches it.
        self.terminal = sys.stdout
        # The handle is intentionally kept open: the instance replaces sys.stdout
        # for the remaining lifetime of the process.
        self.log = open(self.log_file, "w")

    def write(self, message):
        """Write ``message`` to both the terminal and the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        """Flush both underlying streams."""
        self.terminal.flush()
        self.log.flush()

    def isatty(self):
        """Always report a non-interactive stream."""
        return False


# Redirect stdout and stderr through the tee logger so all output also lands in the log file.
sys.stdout = Logger()
sys.stderr = sys.stdout

import logging

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)


def read_logs():
    """Flush stdout and return the full contents of the log file it writes to."""
    sys.stdout.flush()
    with open(sys.stdout.log_file, "r") as f:
        return f.read()


def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length):
    """Run GPT fine-tuning via ``train_gpt`` and collect the resulting artifact paths.

    Args:
        language: Dataset language forwarded to ``train_gpt``.
        train_csv: Path to the training CSV file.
        eval_csv: Path to the evaluation CSV file.
        num_epochs: Number of training epochs.
        batch_size: Training batch size.
        grad_acumm: Gradient accumulation steps.
        output_path: Directory forwarded to ``train_gpt`` as ``output_path``.
        max_audio_length: Maximum audio length in seconds.

    Returns:
        A 5-tuple ``(status message, config path, vocab path, fine-tuned
        checkpoint path, speaker wav)``. On failure the message describes the
        problem and the other four items are empty strings.
    """
    import shutil  # local import: only needed for the post-training file copies

    clear_gpu_cache()
    if not train_csv or not eval_csv:
        return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", "", "", "", ""

    try:
        # convert seconds to waveform frames
        max_audio_length = int(max_audio_length * 22050)
        config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(
            language,
            num_epochs,
            batch_size,
            grad_acumm,
            train_csv,
            eval_csv,
            output_path=output_path,
            max_audio_length=max_audio_length,
        )
    except Exception:
        # Catch Exception rather than everything so Ctrl-C / SystemExit still
        # abort the run instead of being reported as a training error.
        traceback.print_exc()
        error = traceback.format_exc()
        return f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}", "", "", "", ""

    # copy original files to avoid parameters changes issues
    # shutil.copy avoids building a shell command from paths (spaces, shell
    # metacharacters) and works on platforms without `cp`. The copy stays
    # best-effort: a failure is reported but does not discard the training result.
    for source_file in (config_path, vocab_file):
        try:
            shutil.copy(source_file, exp_path)
        except OSError as copy_error:
            print(f"Warning: could not copy {source_file} to {exp_path}: {copy_error}")

    ft_xtts_checkpoint = os.path.join(exp_path, "best_model.pth")
    print("Model training done!")
    clear_gpu_cache()
    return "Model training done!", config_path, vocab_file, ft_xtts_checkpoint, speaker_wav


def build_arg_parser():
    """Build the command-line parser for the fine-tuning script."""
    parser = argparse.ArgumentParser(
        description="""XTTS fine-tuning demo\n\n"""
        """ Example runs: python3 TTS/demos/xtts_ft_demo/xtts_demo.py --port """,
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--port",
        type=int,
        help="Port to run the gradio demo. Default: 5003",
        default=5003,
    )
    parser.add_argument(
        "--out_path",
        type=str,
        help="Output path (where data and checkpoints will be saved) Default: /tmp/xtts_ft/",
        default="/tmp/xtts_ft/",
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        help="Number of epochs to train. Default: 10",
        default=10,
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        help="Batch size. Default: 4",
        default=4,
    )
    parser.add_argument(
        "--grad_acumm",
        type=int,
        help="Grad accumulation steps. Default: 1",
        default=1,
    )
    parser.add_argument(
        "--max_audio_length",
        type=int,
        help="Max permitted audio size in seconds. Default: 11",
        default=11,
    )
    parser.add_argument(
        "--lang",
        type=str,
        help="Dataset Language",
        default="en",
    )
    parser.add_argument(
        "--train_csv",
        type=str,
        help="Path to the train CSV file",
        required=True,
    )
    parser.add_argument(
        "--eval_csv",
        type=str,
        help="Path to the eval CSV file",
        required=True,
    )
    return parser


def main(argv=None):
    """Parse the command line and run training directly (no Gradio UI).

    Args:
        argv: Optional argument list; ``None`` makes argparse read ``sys.argv``.

    Exits the process with a non-zero status and the failure message when
    ``train_model`` does not produce a fine-tuned checkpoint path.
    """
    args = build_arg_parser().parse_args(argv)
    status, _config_path, _vocab_file, ft_xtts_checkpoint, _speaker_wav = train_model(
        args.lang,
        args.train_csv,
        args.eval_csv,
        args.num_epochs,
        args.batch_size,
        args.grad_acumm,
        args.out_path,
        args.max_audio_length,
    )
    # train_model signals failure by returning empty strings for the artifact
    # paths; surface that to the caller's shell instead of exiting with 0.
    if not ft_xtts_checkpoint:
        sys.exit(status)


if __name__ == "__main__":
    main()