# xtts / tain.py
# Uploaded by IAsistemofinteres
# Commit 001cd36 (verified): "Upload 2 files"
import argparse
import os
import sys
import tempfile
import librosa.display
import numpy as np
import os
import torch
import torchaudio
import traceback
from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list
from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
def clear_gpu_cache():
    """Call ``torch.cuda.empty_cache()`` when CUDA is available.

    On machines without CUDA this is a no-op.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
# Module-level handle to the loaded XTTS model; stays None until load_model()
# completes successfully. run_tts() reads it.
XTTS_MODEL = None


def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
    """Load an XTTS model and store it in the module-level ``XTTS_MODEL``.

    Args:
        xtts_checkpoint: Path to the model checkpoint.
        xtts_config: Path to the JSON config, read via ``XttsConfig.load_json``.
        xtts_vocab: Path to the vocab file passed to ``load_checkpoint``.

    Returns:
        A human-readable status message. If any path is missing, returns an
        instruction message and leaves ``XTTS_MODEL`` untouched.
    """
    global XTTS_MODEL
    clear_gpu_cache()
    if not xtts_checkpoint or not xtts_config or not xtts_vocab:
        return "You need to run the previous steps or manually set the `XTTS checkpoint path`, `XTTS config path`, and `XTTS vocab path` fields !!"
    config = XttsConfig()
    config.load_json(xtts_config)
    # Release the previous model before building a new one. If anything below
    # raises, XTTS_MODEL is then None (so run_tts asks for a reload) rather than
    # a half-initialised model.
    XTTS_MODEL = None
    model = Xtts.init_from_config(config)
    print("Loading XTTS model! ")
    model.load_checkpoint(config, checkpoint_path=xtts_checkpoint, vocab_path=xtts_vocab, use_deepspeed=False)
    if torch.cuda.is_available():
        model.cuda()
    # Publish only a fully loaded model.
    XTTS_MODEL = model
    print("Model Loaded!")
    return "Model Loaded!"
def run_tts(lang, tts_text, speaker_audio_file):
    """Synthesise ``tts_text`` in ``lang`` using the voice in ``speaker_audio_file``.

    Uses the model held in ``XTTS_MODEL``. The conditioning and sampling
    parameters are all read from that model's config.

    Returns:
        A 3-tuple ``(status message, path to the generated .wav, speaker_audio_file)``.
        If no model is loaded or no speaker file is given, returns
        ``(message, None, None)``.
    """
    model = XTTS_MODEL
    if model is None or not speaker_audio_file:
        return "You need to run the previous step to load the model !!", None, None

    cfg = model.config
    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
        audio_path=speaker_audio_file,
        gpt_cond_len=cfg.gpt_cond_len,
        max_ref_length=cfg.max_ref_len,
        sound_norm_refs=cfg.sound_norm_refs,
    )

    # Sampling knobs come from the config; override them here if needed.
    out = model.inference(
        text=tts_text,
        language=lang,
        gpt_cond_latent=gpt_cond_latent,
        speaker_embedding=speaker_embedding,
        temperature=cfg.temperature,
        length_penalty=cfg.length_penalty,
        repetition_penalty=cfg.repetition_penalty,
        top_k=cfg.top_k,
        top_p=cfg.top_p,
    )

    # delete=False keeps the file on disk after the handle is closed, so the
    # returned path stays valid for the caller.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
        out_path = fp.name
        # unsqueeze(0) adds a leading dimension (assumed to be the channel
        # axis torchaudio.save expects); the file is written at 24000 Hz.
        wav = torch.tensor(out["wav"]).unsqueeze(0)
        out["wav"] = wav
        torchaudio.save(out_path, wav, 24000)
    return "Speech generated !", out_path, speaker_audio_file
# Tee-style stream used to redirect sys.stdout/sys.stderr: everything written
# goes both to the original terminal stream and to a log file.
class Logger:
    """File-like object that duplicates writes to the terminal and a log file.

    Args:
        filename: Path of the log file. It is opened in ``"w"`` mode, so any
            existing content is truncated.
    """

    def __init__(self, filename="log.out"):
        self.log_file = filename
        # Whatever sys.stdout is at construction time becomes the terminal copy.
        self.terminal = sys.stdout
        # errors="backslashreplace": a character the log file's encoding cannot
        # represent is written as an escape sequence instead of raising
        # UnicodeEncodeError halfway through a print() (the terminal copy would
        # already have been written by then).
        self.log = open(self.log_file, "w", errors="backslashreplace")

    def write(self, message):
        """Write ``message`` to both streams and return its length in characters.

        Returning the count matches the ``io.TextIOBase.write`` contract, which
        some callers of sys.stdout/sys.stderr rely on.
        """
        self.terminal.write(message)
        self.log.write(message)
        return len(message)

    def flush(self):
        """Flush both the terminal stream and the log file."""
        self.terminal.flush()
        self.log.flush()

    def isatty(self):
        """Always report a non-interactive stream, regardless of the terminal."""
        return False
# Route both stdout and stderr through one Logger, so everything printed
# (including tracebacks written to stderr) is also saved to log.out.
sys.stdout = sys.stderr = Logger()

import logging

# Send logging records at INFO level and above to the tee'd stdout.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
def read_logs():
    """Return the full contents of the log file behind ``sys.stdout``.

    Flushes ``sys.stdout`` first so pending output is included. Expects
    ``sys.stdout`` to expose a ``log_file`` attribute, as the Logger tee does.
    """
    stream = sys.stdout
    stream.flush()
    with open(stream.log_file, "r") as handle:
        contents = handle.read()
    return contents
if __name__ == "__main__":
    # Command-line interface. The original Gradio demo UI is not used; the
    # parsed ``args`` are consumed by the entry point at the bottom of the file.
    parser = argparse.ArgumentParser(
        # argparse substitutes %(prog)s with the invoked script name.
        description="""XTTS fine-tuning (command line)\n\n"""
        """
Example runs:
    python3 %(prog)s --train_csv /path/to/train.csv --eval_csv /path/to/eval.csv
    python3 %(prog)s --train_csv train.csv --eval_csv eval.csv --lang es --num_epochs 6
""",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--port",
        type=int,
        help="Unused; kept for compatibility with the Gradio demo. Default: 5003",
        default=5003,
    )
    parser.add_argument(
        "--out_path",
        type=str,
        help="Output path (where data and checkpoints will be saved) Default: /tmp/xtts_ft/",
        default="/tmp/xtts_ft/",
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        help="Number of epochs to train. Default: 10",
        default=10,
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        help="Batch size. Default: 4",
        default=4,
    )
    parser.add_argument(
        "--grad_acumm",
        type=int,
        help="Grad accumulation steps. Default: 1",
        default=1,
    )
    parser.add_argument(
        "--max_audio_length",
        # float so fractional seconds (e.g. 8.5) are accepted; train_model
        # converts to an integer frame count anyway.
        type=float,
        help="Max permitted audio size in seconds. Default: 11",
        default=11,
    )
    parser.add_argument(
        "--lang",
        type=str,
        help="Dataset language. Default: en",
        default="en",
    )
    parser.add_argument(
        "--train_csv",
        type=str,
        help="Path to the train CSV file",
        required=True,
    )
    parser.add_argument(
        "--eval_csv",
        type=str,
        help="Path to the eval CSV file",
        required=True,
    )
    args = parser.parse_args()
def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length):
    """Run XTTS GPT fine-tuning through ``train_gpt`` and report the outcome.

    Args:
        language: Dataset language passed to ``train_gpt``.
        train_csv: Path to the training CSV.
        eval_csv: Path to the evaluation CSV.
        num_epochs: Number of training epochs.
        batch_size: Training batch size.
        grad_acumm: Gradient-accumulation steps.
        output_path: Directory where ``train_gpt`` writes its outputs.
        max_audio_length: Maximum clip length in seconds. It is converted to a
            frame count at 22050 frames per second before being passed on.

    Returns:
        A 5-tuple ``(status message, config path, vocab path, fine-tuned
        checkpoint path, speaker wav)``. The config and vocab paths are the
        ones returned by ``train_gpt``, not the copies placed in the
        experiment directory. On failure, every element except the message
        is ``""``.
    """
    import shutil  # only needed here, for copying files after training

    clear_gpu_cache()
    if not train_csv or not eval_csv:
        return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", "", "", "", ""
    try:
        # convert seconds to waveform frames
        max_audio_length = int(max_audio_length * 22050)
        config_path, _original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path, max_audio_length=max_audio_length)
    except Exception:
        # `except Exception` (not a bare except) so Ctrl-C / SystemExit still
        # stop the program instead of being reported as a training error.
        traceback.print_exc()
        error = traceback.format_exc()
        return f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}", "", "", "", ""

    # Copy the original config and vocab into exp_path to avoid parameter-change
    # issues. shutil.copy is used instead of os.system("cp ..."): no shell is
    # involved, so paths containing spaces or shell metacharacters are safe, and
    # it works where `cp` does not exist. The copy stays non-fatal, as before,
    # but failures are now reported through the log.
    for src in (config_path, vocab_file):
        try:
            shutil.copy(src, exp_path)
        except OSError as err:
            print(f"Warning: could not copy {src} to {exp_path}: {err}")

    ft_xtts_checkpoint = os.path.join(exp_path, "best_model.pth")
    print("Model training done!")
    clear_gpu_cache()
    return "Model training done!", config_path, vocab_file, ft_xtts_checkpoint, speaker_wav
# Entry point: run fine-tuning directly with the parsed command-line arguments
# (no Gradio UI).
if __name__ == "__main__":
    message, config_path, vocab_file, ft_checkpoint, speaker_wav = train_model(
        args.lang,
        args.train_csv,
        args.eval_csv,
        args.num_epochs,
        args.batch_size,
        args.grad_acumm,
        args.out_path,
        args.max_audio_length,
    )
    # On failure train_model returns "" for every path. Show its message and
    # exit non-zero so shell scripts / CI can detect the failure.
    if not ft_checkpoint:
        print(message)
        sys.exit(1)
    # On success train_model has already printed its status; show where the
    # outputs are.
    print(f"Fine-tuned checkpoint: {ft_checkpoint}")
    print(f"Config: {config_path}")
    print(f"Vocab: {vocab_file}")
    print(f"Speaker wav: {speaker_wav}")