IAsistemofinteres committed on
Commit
001cd36
·
verified ·
1 Parent(s): 44b799e

Upload 2 files

Browse files
Files changed (2) hide show
  1. dataset.py +66 -0
  2. tain.py +220 -0
dataset.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import sys
4
+ import tempfile
5
+
6
+ import librosa.display
7
+ import numpy as np
8
+
9
+ import os
10
+ import torch
11
+ import torchaudio
12
+ import traceback
13
+ from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list
14
+ from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt
15
+
16
+ from TTS.tts.configs.xtts_config import XttsConfig
17
+ from TTS.tts.models.xtts import Xtts
18
+
19
+
20
+ def clear_gpu_cache():
21
+ # clear the GPU cache
22
+ if torch.cuda.is_available():
23
+ torch.cuda.empty_cache()
24
+
25
+
26
+ def preprocess_dataset(audio_path, language, out_path):
27
+ """
28
+ Prepara los datos de audio para el entrenamiento del modelo.
29
+
30
+ Args:
31
+ audio_path (list): Lista de rutas de los archivos de audio.
32
+ language (str): Código del idioma del dataset.
33
+ out_path (str): Ruta de salida para el dataset procesado.
34
+
35
+ Returns:
36
+ tuple: Tupla con las rutas de los archivos CSV de entrenamiento y evaluación.
37
+ """
38
+ out_path = os.path.join(out_path, "dataset")
39
+ os.makedirs(out_path, exist_ok=True)
40
+ train_meta, eval_meta, _ = format_audio_list(audio_path, target_language=language, out_path=out_path)
41
+ train_csv = os.path.join(out_path, "train.csv")
42
+ eval_csv = os.path.join(out_path, "eval.csv")
43
+ return train_csv, eval_csv
44
+
45
+ def main(dataset_path, output_path, language):
46
+ # Obtener información del usuario
47
+ audio_path = dataset_path #input("Ingresa la ruta de los archivos de audio (separados por espacio): ")
48
+ language = language #input("Ingresa el idioma del dataset: ")
49
+ out_path = output_path #input("Ingresa la ruta de salida para el dataset procesado: ")
50
+
51
+ # Prepara los datos
52
+ train_csv, eval_csv = preprocess_dataset(audio_path.split(), language, out_path)
53
+
54
+ print(f"Los archivos CSV se han creado en: {out_path}")
55
+ print(f"train.csv: {train_csv}")
56
+ print(f"eval.csv: {eval_csv}")
57
+
58
+
59
+ if __name__ == "__main__":
60
+ parser = argparse.ArgumentParser()
61
+ parser.add_argument("--dataset_path", type=str, required=True, help="Ruta del dataset de audio")
62
+ parser.add_argument("--output_path", type=str, required=True, help="Ruta de salida para el dataset procesado")
63
+ parser.add_argument("--language", type=str, required=True, help="Idioma del dataset")
64
+ args = parser.parse_args()
65
+
66
+ main(args.dataset_path, args.output_path, args.language)
tain.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import sys
4
+ import tempfile
5
+
6
+ import librosa.display
7
+ import numpy as np
8
+
9
+ import os
10
+ import torch
11
+ import torchaudio
12
+ import traceback
13
+ from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list
14
+ from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt
15
+
16
+ from TTS.tts.configs.xtts_config import XttsConfig
17
+ from TTS.tts.models.xtts import Xtts
18
+
19
+
20
+ def clear_gpu_cache():
21
+ # clear the GPU cache
22
+ if torch.cuda.is_available():
23
+ torch.cuda.empty_cache()
24
+
25
+ XTTS_MODEL = None
26
+ def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
27
+ global XTTS_MODEL
28
+ clear_gpu_cache()
29
+ if not xtts_checkpoint or not xtts_config or not xtts_vocab:
30
+ return "You need to run the previous steps or manually set the `XTTS checkpoint path`, `XTTS config path`, and `XTTS vocab path` fields !!"
31
+ config = XttsConfig()
32
+ config.load_json(xtts_config)
33
+ XTTS_MODEL = Xtts.init_from_config(config)
34
+ print("Loading XTTS model! ")
35
+ XTTS_MODEL.load_checkpoint(config, checkpoint_path=xtts_checkpoint, vocab_path=xtts_vocab, use_deepspeed=False)
36
+ if torch.cuda.is_available():
37
+ XTTS_MODEL.cuda()
38
+
39
+ print("Model Loaded!")
40
+ return "Model Loaded!"
41
+
42
+ def run_tts(lang, tts_text, speaker_audio_file):
43
+ if XTTS_MODEL is None or not speaker_audio_file:
44
+ return "You need to run the previous step to load the model !!", None, None
45
+
46
+ gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(audio_path=speaker_audio_file, gpt_cond_len=XTTS_MODEL.config.gpt_cond_len, max_ref_length=XTTS_MODEL.config.max_ref_len, sound_norm_refs=XTTS_MODEL.config.sound_norm_refs)
47
+ out = XTTS_MODEL.inference(
48
+ text=tts_text,
49
+ language=lang,
50
+ gpt_cond_latent=gpt_cond_latent,
51
+ speaker_embedding=speaker_embedding,
52
+ temperature=XTTS_MODEL.config.temperature, # Add custom parameters here
53
+ length_penalty=XTTS_MODEL.config.length_penalty,
54
+ repetition_penalty=XTTS_MODEL.config.repetition_penalty,
55
+ top_k=XTTS_MODEL.config.top_k,
56
+ top_p=XTTS_MODEL.config.top_p,
57
+ )
58
+
59
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
60
+ out["wav"] = torch.tensor(out["wav"]).unsqueeze(0)
61
+ out_path = fp.name
62
+ torchaudio.save(out_path, out["wav"], 24000)
63
+
64
+ return "Speech generated !", out_path, speaker_audio_file
65
+
66
+
67
+
68
+
69
+ # define a logger to redirect
70
+ class Logger:
71
+ def __init__(self, filename="log.out"):
72
+ self.log_file = filename
73
+ self.terminal = sys.stdout
74
+ self.log = open(self.log_file, "w")
75
+
76
+ def write(self, message):
77
+ self.terminal.write(message)
78
+ self.log.write(message)
79
+
80
+ def flush(self):
81
+ self.terminal.flush()
82
+ self.log.flush()
83
+
84
+ def isatty(self):
85
+ return False
86
+
87
+ # redirect stdout and stderr to a file
88
+ sys.stdout = Logger()
89
+ sys.stderr = sys.stdout
90
+
91
+
92
+ # logging.basicConfig(stream=sys.stdout, level=logging.INFO)
93
+ import logging
94
+ logging.basicConfig(
95
+ level=logging.INFO,
96
+ format="%(asctime)s [%(levelname)s] %(message)s",
97
+ handlers=[
98
+ logging.StreamHandler(sys.stdout)
99
+ ]
100
+ )
101
+
102
+ def read_logs():
103
+ sys.stdout.flush()
104
+ with open(sys.stdout.log_file, "r") as f:
105
+ return f.read()
106
+
107
+
108
+ if __name__ == "__main__":
109
+
110
+ parser = argparse.ArgumentParser(
111
+ description="""XTTS fine-tuning demo\n\n"""
112
+ """
113
+ Example runs:
114
+ python3 TTS/demos/xtts_ft_demo/xtts_demo.py --port
115
+ """,
116
+ formatter_class=argparse.RawTextHelpFormatter,
117
+ )
118
+ parser.add_argument(
119
+ "--port",
120
+ type=int,
121
+ help="Port to run the gradio demo. Default: 5003",
122
+ default=5003,
123
+ )
124
+ parser.add_argument(
125
+ "--out_path",
126
+ type=str,
127
+ help="Output path (where data and checkpoints will be saved) Default: /tmp/xtts_ft/",
128
+ default="/tmp/xtts_ft/",
129
+ )
130
+
131
+ parser.add_argument(
132
+ "--num_epochs",
133
+ type=int,
134
+ help="Number of epochs to train. Default: 10",
135
+ default=10,
136
+ )
137
+ parser.add_argument(
138
+ "--batch_size",
139
+ type=int,
140
+ help="Batch size. Default: 4",
141
+ default=4,
142
+ )
143
+ parser.add_argument(
144
+ "--grad_acumm",
145
+ type=int,
146
+ help="Grad accumulation steps. Default: 1",
147
+ default=1,
148
+ )
149
+ parser.add_argument(
150
+ "--max_audio_length",
151
+ type=int,
152
+ help="Max permitted audio size in seconds. Default: 11",
153
+ default=11,
154
+ )
155
+
156
+ # Add the new arguments
157
+ parser.add_argument(
158
+ "--lang",
159
+ type=str,
160
+ help="Dataset Language",
161
+ default="en",
162
+ )
163
+ parser.add_argument(
164
+ "--train_csv",
165
+ type=str,
166
+ help="Path to the train CSV file",
167
+ required=True,
168
+ )
169
+ parser.add_argument(
170
+ "--eval_csv",
171
+ type=str,
172
+ help="Path to the eval CSV file",
173
+ required=True,
174
+ )
175
+
176
+ args = parser.parse_args()
177
+
178
+ # ... (rest of your code)
179
+
180
+ def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length):
181
+ clear_gpu_cache()
182
+ if not train_csv or not eval_csv:
183
+ return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", "", "", "", ""
184
+ try:
185
+ # convert seconds to waveform frames
186
+ max_audio_length = int(max_audio_length * 22050)
187
+ config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path, max_audio_length=max_audio_length)
188
+ except:
189
+ traceback.print_exc()
190
+ error = traceback.format_exc()
191
+ return f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}", "", "", "", ""
192
+
193
+ # copy original files to avoid parameters changes issues
194
+ os.system(f"cp {config_path} {exp_path}")
195
+ os.system(f"cp {vocab_file} {exp_path}")
196
+
197
+ ft_xtts_checkpoint = os.path.join(exp_path, "best_model.pth")
198
+ print("Model training done!")
199
+ clear_gpu_cache()
200
+ return "Model training done!", config_path, vocab_file, ft_xtts_checkpoint, speaker_wav
201
+
202
+ # ... (rest of your code)
203
+
204
+ # The following section is the only part to be changed:
205
+ # It now directly calls the train_model function instead of using Gradio
206
+
207
+ if __name__ == "__main__":
208
+ # ... (argparse setup)
209
+
210
+ # Call the function directly
211
+ train_model(
212
+ args.lang,
213
+ args.train_csv,
214
+ args.eval_csv,
215
+ args.num_epochs,
216
+ args.batch_size,
217
+ args.grad_acumm,
218
+ args.out_path,
219
+ args.max_audio_length,
220
+ )