AkitoP committed on
Commit 0840da9 · verified
1 Parent(s): 7b8f80b

Update app.py

Files changed (1)
  1. app.py +379 -365
app.py CHANGED
@@ -1,366 +1,380 @@
1
- import os
2
- import sys
3
- import spaces
4
- import tqdm
5
- cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
6
- bert_path = "GPT_SoVITS\pretrained_models\chinese-roberta-wwm-ext-large"
7
- os.environ["version"] = 'v2'
8
- now_dir = os.path.dirname(os.path.abspath(__file__))
9
- sys.path.insert(0, now_dir)
10
- sys.path.insert(0, os.path.join(now_dir, "GPT_SoVITS"))
11
- sys.path.insert(0, os.path.join(now_dir, "GPT_SoVITS",'text'))
12
- import site
13
- site_packages_roots = []
14
- for site_packages_root in site_packages_roots:
15
- if os.path.exists(site_packages_root):
16
- try:
17
- with open("%s/users.pth" % (site_packages_root), "w") as f:
18
- f.write(
19
- "%s\n%s/tools\n%s/tools/damo_asr\n%s/GPT_SoVITS\n%s/tools/uvr5"
20
- % (now_dir, now_dir, now_dir, now_dir, now_dir)
21
- )
22
- break
23
- except PermissionError:
24
- pass
25
- import re
26
- import gradio as gr
27
- from transformers import AutoModelForMaskedLM, AutoTokenizer
28
- import numpy as np
29
- import os,librosa,torch, audiosegment
30
- from GPT_SoVITS.feature_extractor import cnhubert
31
- cnhubert.cnhubert_base_path=cnhubert_base_path
32
- from GPT_SoVITS.module.models import SynthesizerTrn
33
- from GPT_SoVITS.AR.models.t2s_lightning_module import Text2SemanticLightningModule
34
- from GPT_SoVITS.text import cleaned_text_to_sequence
35
- from GPT_SoVITS.text.cleaner import clean_text
36
- from time import time as ttime
37
- from GPT_SoVITS.module.mel_processing import spectrogram_torch
38
- import tempfile
39
- from tools.my_utils import load_audio
40
- import os
41
- import json
42
- # import pyopenjtalk
43
- # cwd = os.getcwd()
44
- # if os.path.exists(os.path.join(cwd,'user.dic')):
45
- # pyopenjtalk.update_global_jtalk_with_user_dict(os.path.join(cwd, 'user.dic'))
46
-
47
-
48
- import logging
49
- logging.getLogger('httpx').setLevel(logging.WARNING)
50
- logging.getLogger('httpcore').setLevel(logging.WARNING)
51
- logging.getLogger('multipart').setLevel(logging.WARNING)
52
-
53
- device = "cuda" if torch.cuda.is_available() else "cpu"
54
- #device = "cpu"
55
- is_half = False
56
- # bert_model=bert_model.to(device)
57
-
58
- loaded_sovits_model = [] # [(path, dict, model)]
59
- loaded_gpt_model = []
60
- ssl_model = cnhubert.get_model()
61
- if (is_half == True):
62
- ssl_model = ssl_model.half().to(device)
63
- else:
64
- ssl_model = ssl_model.to(device)
65
-
66
-
67
- def load_model(sovits_path, gpt_path):
68
- global ssl_model
69
- global loaded_sovits_model
70
- global loaded_gpt_model
71
- vq_model = None
72
- t2s_model = None
73
- dict_s2 = None
74
- dict_s1 = None
75
- hps = None
76
- for path, dict_s2_, model in loaded_sovits_model:
77
- if path == sovits_path:
78
- vq_model = model
79
- dict_s2 = dict_s2_
80
- break
81
- for path, dict_s1_, model in loaded_gpt_model:
82
- if path == gpt_path:
83
- t2s_model = model
84
- dict_s1 = dict_s1_
85
- break
86
-
87
- if dict_s2 is None:
88
- dict_s2 = torch.load(sovits_path, map_location="cpu")
89
- hps = dict_s2["config"]
90
-
91
- if dict_s1 is None:
92
- dict_s1 = torch.load(gpt_path, map_location="cpu")
93
- config = dict_s1["config"]
94
- class DictToAttrRecursive:
95
- def __init__(self, input_dict):
96
- for key, value in input_dict.items():
97
- if isinstance(value, dict):
98
- # 如果值是字典,递归调用构造函数
99
- setattr(self, key, DictToAttrRecursive(value))
100
- else:
101
- setattr(self, key, value)
102
-
103
- hps = DictToAttrRecursive(hps)
104
- hps.model.semantic_frame_rate = "25hz"
105
-
106
-
107
- if not vq_model:
108
- vq_model = SynthesizerTrn(
109
- hps.data.filter_length // 2 + 1,
110
- hps.train.segment_size // hps.data.hop_length,
111
- n_speakers=hps.data.n_speakers,
112
- **hps.model)
113
- if (is_half == True):
114
- vq_model = vq_model.half().to(device)
115
- else:
116
- vq_model = vq_model.to(device)
117
- vq_model.eval()
118
- vq_model.load_state_dict(dict_s2["weight"], strict=False)
119
- loaded_sovits_model.append((sovits_path, dict_s2, vq_model))
120
- hz = 50
121
- max_sec = config['data']['max_sec']
122
- if not t2s_model:
123
- t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
124
- t2s_model.load_state_dict(dict_s1["weight"])
125
- if (is_half == True): t2s_model = t2s_model.half()
126
- t2s_model = t2s_model.to(device)
127
- t2s_model.eval()
128
- total = sum([param.nelement() for param in t2s_model.parameters()])
129
- print("Number of parameter: %.2fM" % (total / 1e6))
130
- loaded_gpt_model.append((gpt_path, dict_s1, t2s_model))
131
- return vq_model, ssl_model, t2s_model, hps, config, hz, max_sec
132
-
133
-
134
- def get_spepc(hps, filename):
135
- audio=load_audio(filename,int(hps.data.sampling_rate))
136
- audio = audio / np.max(np.abs(audio))
137
- audio=torch.FloatTensor(audio)
138
- audio_norm = audio
139
- # audio_norm = audio / torch.max(torch.abs(audio))
140
- audio_norm = audio_norm.unsqueeze(0)
141
- spec = spectrogram_torch(audio_norm, hps.data.filter_length,hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,center=False)
142
- return spec
143
-
144
- def create_tts_fn(vq_model, ssl_model, t2s_model, hps, config, hz, max_sec):
145
- @spaces.GPU
146
- def tts_fn(ref_wav_path, prompt_text, prompt_language, target_phone, text_language, target_text = None):
147
- t0 = ttime()
148
- prompt_text=prompt_text.strip()
149
- prompt_language=prompt_language
150
- with torch.no_grad():
151
- wav16k, sr = librosa.load(ref_wav_path, sr=16000) # 派蒙
152
- # maxx=0.95
153
- # tmp_max = np.abs(wav16k).max()
154
- # alpha=0.5
155
- # wav16k = (wav16k / tmp_max * (maxx * alpha*32768)) + ((1 - alpha)*32768) * wav16k
156
- #在这里归一化
157
- #print(max(np.abs(wav16k)))
158
- #wav16k = wav16k / np.max(np.abs(wav16k))
159
- #print(max(np.abs(wav16k)))
160
- # 添加0.3s的静音
161
- wav16k = np.concatenate([wav16k, np.zeros(int(hps.data.sampling_rate * 0.3)),])
162
- wav16k = torch.from_numpy(wav16k)
163
- wav16k = wav16k.float()
164
- if(is_half==True):wav16k=wav16k.half().to(device)
165
- else:wav16k=wav16k.to(device)
166
- ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)#.float()
167
- codes = vq_model.extract_latent(ssl_content)
168
- prompt_semantic = codes[0, 0]
169
- t1 = ttime()
170
- phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
171
- phones1=cleaned_text_to_sequence(phones1)
172
- #texts=text.split("\n")
173
- audio_opt = []
174
- zero_wav=np.zeros(int(hps.data.sampling_rate*0.3),dtype=np.float16 if is_half==True else np.float32)
175
- phones = get_phone_from_str_list(target_phone, text_language)
176
- for phones2 in phones:
177
- if(len(phones2) == 0):
178
- continue
179
- if(len(phones2) == 1 and phones2[0] == ""):
180
- continue
181
- #phones2, word2ph2, norm_text2 = clean_text(text, text_language)
182
- phones2 = cleaned_text_to_sequence(phones2)
183
- #if(prompt_language=="zh"):bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
184
- bert1 = torch.zeros((1024, len(phones1)),dtype=torch.float16 if is_half==True else torch.float32).to(device)
185
- #if(text_language=="zh"):bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
186
- bert2 = torch.zeros((1024, len(phones2))).to(bert1)
187
- bert = torch.cat([bert1, bert2], 1)
188
-
189
- all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
190
- bert = bert.to(device).unsqueeze(0)
191
- all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
192
- prompt = prompt_semantic.unsqueeze(0).to(device)
193
- t2 = ttime()
194
- idx = 0
195
- cnt = 0
196
- while idx == 0 and cnt < 2:
197
- with torch.no_grad():
198
- # pred_semantic = t2s_model.model.infer
199
- pred_semantic,idx = t2s_model.model.infer_panel(
200
- all_phoneme_ids,
201
- all_phoneme_len,
202
- prompt,
203
- bert,
204
- # prompt_phone_len=ph_offset,
205
- top_k=config['inference']['top_k'],
206
- early_stop_num=hz * max_sec)
207
- t3 = ttime()
208
- cnt+=1
209
- if idx == 0:
210
- return "Error: Generation failure: bad zero prediction.", None
211
- pred_semantic = pred_semantic[:,-idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次
212
- refer = get_spepc(hps, ref_wav_path)#.to(device)
213
- if(is_half==True):refer=refer.half().to(device)
214
- else:refer=refer.to(device)
215
- # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
216
- audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer).detach().cpu().numpy()[0, 0]###试试重建不带上prompt部分
217
- audio_opt.append(audio)
218
- audio_opt.append(zero_wav)
219
- t4 = ttime()
220
- print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
221
-
222
- audio = (hps.data.sampling_rate,(np.concatenate(audio_opt,0)*32768).astype(np.int16))
223
-
224
- filename = tempfile.mktemp(suffix=".wav",prefix=f"{prompt_text[:8]}_{target_text[:8]}_")
225
- audiosegment.from_numpy_array(audio[1], framerate=audio[0]).export(filename, format="WAV")
226
- return "Success", (hps.data.sampling_rate,(np.concatenate(audio_opt,0)*32768).astype(np.int16)), filename
227
- return tts_fn
228
-
229
-
230
- def get_str_list_from_phone(text, text_language):
231
- # raw文本过g2p得到音素列表,再转成字符串
232
- # 注意,这里的text是一个段落,可能包含多个句子
233
- # 段落间\n分割,音素间空格分割
234
- print(text)
235
- texts=text.split("\n")
236
- phone_list = []
237
- for text in texts:
238
- phones2, word2ph2, norm_text2 = clean_text(text, text_language)
239
- phone_list.append(" ".join(phones2))
240
- return "\n".join(phone_list)
241
-
242
- def get_phone_from_str_list(str_list:str, language:str = 'ja'):
243
- # 从音素字符串中得到音素列表
244
- # 注意,这里的text是一个段落,可能包含多个句子
245
- # 段落间\n分割,音素间空格分割
246
- sentences = str_list.split("\n")
247
- phones = []
248
- for sentence in sentences:
249
- phones.append(sentence.split(" "))
250
- return phones
251
-
252
- splits={",","。","?","!",",",".","?","!","~",":",":","—","…",}#不考虑省略号
253
- def split(todo_text):
254
- todo_text = todo_text.replace("……", "。").replace("——", ",")
255
- if (todo_text[-1] not in splits): todo_text += "。"
256
- i_split_head = i_split_tail = 0
257
- len_text = len(todo_text)
258
- todo_texts = []
259
- while (1):
260
- if (i_split_head >= len_text): break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
261
- if (todo_text[i_split_head] in splits):
262
- i_split_head += 1
263
- todo_texts.append(todo_text[i_split_tail:i_split_head])
264
- i_split_tail = i_split_head
265
- else:
266
- i_split_head += 1
267
- return todo_texts
268
-
269
-
270
- def change_reference_audio(prompt_text, transcripts):
271
- return transcripts[prompt_text]
272
-
273
-
274
- models = []
275
- models_info = json.load(open("./models/models_info.json", "r", encoding="utf-8"))
276
-
277
-
278
-
279
- for i, info in tqdm.tqdm(models_info.items()):
280
- title = info['title']
281
- cover = info['cover']
282
- gpt_weight = info['gpt_weight']
283
- sovits_weight = info['sovits_weight']
284
- example_reference = info['example_reference']
285
- transcripts = {}
286
- transcript_path = info["transcript_path"]
287
- path = os.path.dirname(transcript_path)
288
- with open(transcript_path, 'r', encoding='utf-8') as file:
289
- for line in file:
290
- line = line.strip().replace("\\", "/")
291
- wav,_,_, t = line.split("|")
292
- wav = os.path.basename(wav)
293
- transcripts[t] = os.path.join(os.path.join(path,"reference_audio"), wav)
294
-
295
- vq_model, ssl_model, t2s_model, hps, config, hz, max_sec = load_model(sovits_weight, gpt_weight)
296
-
297
-
298
- models.append(
299
- (
300
- i,
301
- title,
302
- cover,
303
- transcripts,
304
- example_reference,
305
- create_tts_fn(
306
- vq_model, ssl_model, t2s_model, hps, config, hz, max_sec
307
- )
308
- )
309
- )
310
- with gr.Blocks() as app:
311
- gr.Markdown(
312
- "# <center> GPT-SoVITS-V2-Gakuen Idolmaster\n"
313
- )
314
- with gr.Tabs():
315
- for (name, title, cover, transcripts, example_reference, tts_fn) in models:
316
- with gr.TabItem(name):
317
- with gr.Row():
318
- gr.Markdown(
319
- '<div align="center">'
320
- f'<a><strong>{title}</strong></a>'
321
- '</div>')
322
- with gr.Row():
323
- with gr.Column():
324
- prompt_text = gr.Dropdown(
325
- label="Transcript of the Reference Audio",
326
- value=example_reference if example_reference in transcripts else list(transcripts.keys())[0],
327
- choices=list(transcripts.keys())
328
- )
329
- inp_ref_audio = gr.Audio(
330
- label="Reference Audio",
331
- type="filepath",
332
- interactive=False,
333
- value=transcripts[example_reference] if example_reference in transcripts else list(transcripts.values())[0]
334
- )
335
- transcripts_state = gr.State(value=transcripts)
336
- prompt_text.change(
337
- fn=change_reference_audio,
338
- inputs=[prompt_text, transcripts_state],
339
- outputs=[inp_ref_audio]
340
- )
341
- prompt_language = gr.State(value="ja")
342
- with gr.Column():
343
- text = gr.Textbox(label="Input Text", value="学園アイドルマスター!")
344
- text_language = gr.Dropdown(
345
- label="Language",
346
- choices=["ja"],
347
- value="ja"
348
- )
349
- clean_button = gr.Button("Clean Text", variant="primary")
350
- inference_button = gr.Button("Generate", variant="primary")
351
- cleaned_text = gr.Textbox(label="Cleaned Text")
352
- output = gr.Audio(label="Output Audio")
353
- output_file = gr.File(label="Output Audio File")
354
- om = gr.Textbox(label="Output Message")
355
- clean_button.click(
356
- fn=get_str_list_from_phone,
357
- inputs=[text, text_language],
358
- outputs=[cleaned_text]
359
- )
360
- inference_button.click(
361
- fn=tts_fn,
362
- inputs=[inp_ref_audio, prompt_text, prompt_language, cleaned_text, text_language, text],
363
- outputs=[om, output, output_file]
364
- )
365
-
366
  app.launch(share=True)
 
1
+ import os
2
+ import sys
3
+ import spaces
4
+ import tqdm
5
+ cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
6
+ bert_path = "GPT_SoVITS\pretrained_models\chinese-roberta-wwm-ext-large"
7
+ os.environ["version"] = 'v2'
8
+ now_dir = os.path.dirname(os.path.abspath(__file__))
9
+ sys.path.insert(0, now_dir)
10
+ sys.path.insert(0, os.path.join(now_dir, "GPT_SoVITS"))
11
+ sys.path.insert(0, os.path.join(now_dir, "GPT_SoVITS",'text'))
12
+ import site
13
+ site_packages_roots = []
14
+ for site_packages_root in site_packages_roots:
15
+ if os.path.exists(site_packages_root):
16
+ try:
17
+ with open("%s/users.pth" % (site_packages_root), "w") as f:
18
+ f.write(
19
+ "%s\n%s/tools\n%s/tools/damo_asr\n%s/GPT_SoVITS\n%s/tools/uvr5"
20
+ % (now_dir, now_dir, now_dir, now_dir, now_dir)
21
+ )
22
+ break
23
+ except PermissionError:
24
+ pass
25
+ import re
26
+ import gradio as gr
27
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
28
+ import numpy as np
29
+ import os,librosa,torch, audiosegment
30
+ from GPT_SoVITS.feature_extractor import cnhubert
31
+ cnhubert.cnhubert_base_path=cnhubert_base_path
32
+ from GPT_SoVITS.module.models import SynthesizerTrn
33
+ from GPT_SoVITS.AR.models.t2s_lightning_module import Text2SemanticLightningModule
34
+ from GPT_SoVITS.text import cleaned_text_to_sequence
35
+ from GPT_SoVITS.text.cleaner import clean_text
36
+ from time import time as ttime
37
+ from GPT_SoVITS.module.mel_processing import spectrogram_torch
38
+ import tempfile
39
+ from tools.my_utils import load_audio
40
+ import os
41
+ import json
42
+ # import pyopenjtalk
43
+ # cwd = os.getcwd()
44
+ # if os.path.exists(os.path.join(cwd,'user.dic')):
45
+ # pyopenjtalk.update_global_jtalk_with_user_dict(os.path.join(cwd, 'user.dic'))
46
+
47
+
48
+ import logging
49
+ logging.getLogger('httpx').setLevel(logging.WARNING)
50
+ logging.getLogger('httpcore').setLevel(logging.WARNING)
51
+ logging.getLogger('multipart').setLevel(logging.WARNING)
52
+
53
+ device = "cuda" if torch.cuda.is_available() else "cpu"
54
+ #device = "cpu"
55
+ is_half = False
56
+ # bert_model=bert_model.to(device)
57
+
58
+ loaded_sovits_model = [] # [(path, dict, model)]
59
+ loaded_gpt_model = []
60
+ ssl_model = cnhubert.get_model()
61
+ if (is_half == True):
62
+ ssl_model = ssl_model.half().to(device)
63
+ else:
64
+ ssl_model = ssl_model.to(device)
65
+
66
+
67
+ def load_model(sovits_path, gpt_path):
68
+ global ssl_model
69
+ global loaded_sovits_model
70
+ global loaded_gpt_model
71
+ vq_model = None
72
+ t2s_model = None
73
+ dict_s2 = None
74
+ dict_s1 = None
75
+ hps = None
76
+ for path, dict_s2_, model in loaded_sovits_model:
77
+ if path == sovits_path:
78
+ vq_model = model
79
+ dict_s2 = dict_s2_
80
+ break
81
+ for path, dict_s1_, model in loaded_gpt_model:
82
+ if path == gpt_path:
83
+ t2s_model = model
84
+ dict_s1 = dict_s1_
85
+ break
86
+
87
+ if dict_s2 is None:
88
+ dict_s2 = torch.load(sovits_path, map_location="cpu")
89
+ hps = dict_s2["config"]
90
+
91
+ if dict_s1 is None:
92
+ dict_s1 = torch.load(gpt_path, map_location="cpu")
93
+ config = dict_s1["config"]
94
+ class DictToAttrRecursive:
95
+ def __init__(self, input_dict):
96
+ for key, value in input_dict.items():
97
+ if isinstance(value, dict):
98
+ # 如果值是字典,递归调用构造函数
99
+ setattr(self, key, DictToAttrRecursive(value))
100
+ else:
101
+ setattr(self, key, value)
102
+
103
+ hps = DictToAttrRecursive(hps)
104
+ hps.model.semantic_frame_rate = "25hz"
105
+
106
+
107
+ if not vq_model:
108
+ vq_model = SynthesizerTrn(
109
+ hps.data.filter_length // 2 + 1,
110
+ hps.train.segment_size // hps.data.hop_length,
111
+ n_speakers=hps.data.n_speakers,
112
+ **hps.model)
113
+ if (is_half == True):
114
+ vq_model = vq_model.half().to(device)
115
+ else:
116
+ vq_model = vq_model.to(device)
117
+ vq_model.eval()
118
+ vq_model.load_state_dict(dict_s2["weight"], strict=False)
119
+ loaded_sovits_model.append((sovits_path, dict_s2, vq_model))
120
+ hz = 50
121
+ max_sec = config['data']['max_sec']
122
+ if not t2s_model:
123
+ t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
124
+ t2s_model.load_state_dict(dict_s1["weight"])
125
+ if (is_half == True): t2s_model = t2s_model.half()
126
+ t2s_model = t2s_model.to(device)
127
+ t2s_model.eval()
128
+ total = sum([param.nelement() for param in t2s_model.parameters()])
129
+ print("Number of parameter: %.2fM" % (total / 1e6))
130
+ loaded_gpt_model.append((gpt_path, dict_s1, t2s_model))
131
+ return vq_model, ssl_model, t2s_model, hps, config, hz, max_sec
132
+
133
+
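A minimal usage sketch for `load_model` as defined above, assuming the definitions in this file are in scope; the weight paths below are hypothetical, the real ones come from `models_info.json`. The function caches each loaded checkpoint in the module-level `loaded_sovits_model` / `loaded_gpt_model` lists keyed by path, so later calls with the same paths can reuse the already-built models.

```python
# Hypothetical checkpoint paths; the app reads the real ones from models_info.json.
sovits_weight = "models/example_model/example_e8_s1024.pth"
gpt_weight = "models/example_model/example-e15.ckpt"

vq_model, ssl, t2s_model, hps, config, hz, max_sec = load_model(sovits_weight, gpt_weight)
# hz is fixed at 50 semantic tokens per second; max_sec comes from the GPT
# checkpoint's config and caps generation length via early_stop_num = hz * max_sec.
```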
134
+ def get_spepc(hps, filename):
135
+ audio=load_audio(filename,int(hps.data.sampling_rate))
136
+ audio = audio / np.max(np.abs(audio))
137
+ audio=torch.FloatTensor(audio)
138
+ audio_norm = audio
139
+ # audio_norm = audio / torch.max(torch.abs(audio))
140
+ audio_norm = audio_norm.unsqueeze(0)
141
+ spec = spectrogram_torch(audio_norm, hps.data.filter_length,hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,center=False)
142
+ return spec
143
+
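A quick sketch of what `get_spepc` returns, with a hypothetical file path: it loads the reference wav at the model sampling rate, peak-normalizes it, and computes a linear spectrogram via `spectrogram_torch`.

```python
# Hypothetical reference clip; the app passes the path selected in the UI.
spec = get_spepc(hps, "reference_audio/example.wav")
# Shape is (1, hps.data.filter_length // 2 + 1, n_frames),
# e.g. torch.Size([1, 1025, ...]) when filter_length is 2048.
print(spec.shape)
```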
144
+ def create_tts_fn(vq_model, ssl_model, t2s_model, hps, config, hz, max_sec):
145
+ @spaces.GPU
146
+ def tts_fn(ref_wav_path, prompt_text, prompt_language, target_phone, text_language, target_text = None):
147
+ t0 = ttime()
148
+ prompt_text=prompt_text.strip()
149
+ prompt_language=prompt_language
150
+ with torch.no_grad():
151
+ wav16k, sr = librosa.load(ref_wav_path, sr=16000) # 派蒙
152
+ # maxx=0.95
153
+ # tmp_max = np.abs(wav16k).max()
154
+ # alpha=0.5
155
+ # wav16k = (wav16k / tmp_max * (maxx * alpha*32768)) + ((1 - alpha)*32768) * wav16k
156
+ #在这里归一化
157
+ #print(max(np.abs(wav16k)))
158
+ #wav16k = wav16k / np.max(np.abs(wav16k))
159
+ #print(max(np.abs(wav16k)))
160
+ # 添加0.3s的静音
161
+ wav16k = np.concatenate([wav16k, np.zeros(int(hps.data.sampling_rate * 0.3)),])
162
+ wav16k = torch.from_numpy(wav16k)
163
+ wav16k = wav16k.float()
164
+ if(is_half==True):wav16k=wav16k.half().to(device)
165
+ else:wav16k=wav16k.to(device)
166
+ ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2)#.float()
167
+ codes = vq_model.extract_latent(ssl_content)
168
+ prompt_semantic = codes[0, 0]
169
+ t1 = ttime()
170
+ phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
171
+ phones1=cleaned_text_to_sequence(phones1)
172
+ #texts=text.split("\n")
173
+ audio_opt = []
174
+ zero_wav=np.zeros(int(hps.data.sampling_rate*0.3),dtype=np.float16 if is_half==True else np.float32)
175
+ phones = get_phone_from_str_list(target_phone, text_language)
176
+ for phones2 in phones:
177
+ if(len(phones2) == 0):
178
+ continue
179
+ if(len(phones2) == 1 and phones2[0] == ""):
180
+ continue
181
+ #phones2, word2ph2, norm_text2 = clean_text(text, text_language)
182
+ phones2 = cleaned_text_to_sequence(phones2)
183
+ #if(prompt_language=="zh"):bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
184
+ bert1 = torch.zeros((1024, len(phones1)),dtype=torch.float16 if is_half==True else torch.float32).to(device)
185
+ #if(text_language=="zh"):bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
186
+ bert2 = torch.zeros((1024, len(phones2))).to(bert1)
187
+ bert = torch.cat([bert1, bert2], 1)
188
+
189
+ all_phoneme_ids = torch.LongTensor(phones1+phones2).to(device).unsqueeze(0)
190
+ bert = bert.to(device).unsqueeze(0)
191
+ all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
192
+ prompt = prompt_semantic.unsqueeze(0).to(device)
193
+ t2 = ttime()
194
+ idx = 0
195
+ cnt = 0
196
+ while idx == 0 and cnt < 2:
197
+ with torch.no_grad():
198
+ # pred_semantic = t2s_model.model.infer
199
+ pred_semantic,idx = t2s_model.model.infer_panel(
200
+ all_phoneme_ids,
201
+ all_phoneme_len,
202
+ prompt,
203
+ bert,
204
+ # prompt_phone_len=ph_offset,
205
+ top_k=config['inference']['top_k'],
206
+ early_stop_num=hz * max_sec)
207
+ t3 = ttime()
208
+ cnt+=1
209
+ if idx == 0:
210
+ return "Error: Generation failure: bad zero prediction.", None
211
+ pred_semantic = pred_semantic[:,-idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次
212
+ refer = get_spepc(hps, ref_wav_path)#.to(device)
213
+ if(is_half==True):refer=refer.half().to(device)
214
+ else:refer=refer.to(device)
215
+ # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
216
+ audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer).detach().cpu().numpy()[0, 0]###试试重建不带上prompt部分
217
+ audio_opt.append(audio)
218
+ audio_opt.append(zero_wav)
219
+ t4 = ttime()
220
+ print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
221
+
222
+ audio = (hps.data.sampling_rate,(np.concatenate(audio_opt,0)*32768).astype(np.int16))
223
+
224
+ filename = tempfile.mktemp(suffix=".wav",prefix=f"{prompt_text[:8]}_{target_text[:8]}_")
225
+ audiosegment.from_numpy_array(audio[1], framerate=audio[0]).export(filename, format="WAV")
226
+ return "Success", (hps.data.sampling_rate,(np.concatenate(audio_opt,0)*32768).astype(np.int16)), filename
227
+ return tts_fn
228
+
229
+
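A sketch of calling the closure returned by `create_tts_fn`, using hypothetical paths and text and assuming the models loaded above. The fourth argument is the phoneme string produced by `get_str_list_from_phone` (one sentence per line, phonemes space-separated); the last argument is the raw target text, used only for the temporary filename. On success the function returns a 3-tuple; on a failed generation it returns only `("Error: ...", None)`, so callers should check the status first.

```python
tts_fn = create_tts_fn(vq_model, ssl_model, t2s_model, hps, config, hz, max_sec)

phones = get_str_list_from_phone("学園アイドルマスター!", "ja")
result = tts_fn(
    "reference_audio/example.wav",      # hypothetical reference clip
    "これはサンプルの参照音声です。",      # hypothetical reference transcript
    "ja",                                # prompt language
    phones,                              # edited/cleaned phoneme string
    "ja",                                # target language
    "学園アイドルマスター!",              # raw target text (filename prefix only)
)
if result[0] == "Success":
    status, (sr, audio_int16), wav_path = result
```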
230
+ def get_str_list_from_phone(text, text_language):
231
+ # raw文本过g2p得到音素列表,再转成字符串
232
+ # 注意,这里的text是一个段落,可能包含多个句子
233
+ # 段落间\n分割,音素间空格分割
234
+ print(text)
235
+ texts=text.split("\n")
236
+ phone_list = []
237
+ for text in texts:
238
+ phones2, word2ph2, norm_text2 = clean_text(text, text_language)
239
+ phone_list.append(" ".join(phones2))
240
+ return "\n".join(phone_list)
241
+
242
+ def get_phone_from_str_list(str_list:str, language:str = 'ja'):
243
+ # 从音素字符串中得到音素列表
244
+ # 注意,这里的text是一个段落,可能包含多个句子
245
+ # 段落间\n分割,音素间空格分割
246
+ sentences = str_list.split("\n")
247
+ phones = []
248
+ for sentence in sentences:
249
+ phones.append(sentence.split(" "))
250
+ return phones
251
+
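The two helpers above form the "Clean Text" round trip: `get_str_list_from_phone` runs g2p via `clean_text` and flattens the result into an editable string, and `get_phone_from_str_list` parses that string back into per-sentence phoneme lists. A small sketch follows; the phoneme symbols shown are illustrative, not actual g2p output.

```python
phone_str = get_str_list_from_phone("こんにちは\nさようなら", "ja")
# e.g. "k o N n i ch i w a\ns a y o o n a r a"   (illustrative)

per_sentence = get_phone_from_str_list(phone_str, "ja")
# [['k', 'o', 'N', ...], ['s', 'a', 'y', ...]]  -- one list per input line
```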
252
+ splits={",","。","?","!",",",".","?","!","~",":",":","—","…",}#不考虑省略号
253
+ def split(todo_text):
254
+ todo_text = todo_text.replace("……", "。").replace("——", ",")
255
+ if (todo_text[-1] not in splits): todo_text += "。"
256
+ i_split_head = i_split_tail = 0
257
+ len_text = len(todo_text)
258
+ todo_texts = []
259
+ while (1):
260
+ if (i_split_head >= len_text): break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
261
+ if (todo_text[i_split_head] in splits):
262
+ i_split_head += 1
263
+ todo_texts.append(todo_text[i_split_tail:i_split_head])
264
+ i_split_tail = i_split_head
265
+ else:
266
+ i_split_head += 1
267
+ return todo_texts
268
+
269
+
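A brief illustration of `split`: the text is cut after each punctuation mark in `splits`, and a trailing "。" is appended first if the input does not already end with punctuation.

```python
print(split("こんにちは。元気ですか?はい"))
# ['こんにちは。', '元気ですか?', 'はい。']
```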
270
+ def change_reference_audio(prompt_text, transcripts):
271
+ return transcripts[prompt_text]
272
+
273
+
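A trivial sketch of the dropdown callback above: the selected transcript plus the per-model transcripts dict (passed through `gr.State`) maps back to the matching reference-audio path. Values here are hypothetical.

```python
demo_transcripts = {
    "サンプルの参照テキストです。": "models/example_model/reference_audio/sample_001.wav"
}
print(change_reference_audio("サンプルの参照テキストです。", demo_transcripts))
# models/example_model/reference_audio/sample_001.wav
```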
274
+ models = []
275
+ models_info = json.load(open("./models/models_info.json", "r", encoding="utf-8"))
276
+
277
+
278
+
279
+ for i, info in tqdm.tqdm(models_info.items()):
280
+ title = info['title']
281
+ cover = info['cover']
282
+ gpt_weight = info['gpt_weight']
283
+ sovits_weight = info['sovits_weight']
284
+ example_reference = info['example_reference']
285
+ transcripts = {}
286
+ transcript_path = info["transcript_path"]
287
+ path = os.path.dirname(transcript_path)
288
+ with open(transcript_path, 'r', encoding='utf-8') as file:
289
+ for line in file:
290
+ line = line.strip().replace("\\", "/")
291
+ wav,_,_, t = line.split("|")
292
+ wav = os.path.basename(wav)
293
+ transcripts[t] = os.path.join(os.path.join(path,"reference_audio"), wav)
294
+
295
+ vq_model, ssl_model, t2s_model, hps, config, hz, max_sec = load_model(sovits_weight, gpt_weight)
296
+
297
+
298
+ models.append(
299
+ (
300
+ i,
301
+ title,
302
+ cover,
303
+ transcripts,
304
+ example_reference,
305
+ create_tts_fn(
306
+ vq_model, ssl_model, t2s_model, hps, config, hz, max_sec
307
+ )
308
+ )
309
+ )
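For reference, a hypothetical `models_info.json` entry illustrating the fields the loop above reads; the names and paths are made up, only the keys match the code.

```python
example_entry = {
    "example_model": {
        "title": "Example Character",
        "cover": "models/example_model/cover.png",
        "gpt_weight": "models/example_model/example-e15.ckpt",
        "sovits_weight": "models/example_model/example_e8_s1024.pth",
        "example_reference": "サンプルの参照テキストです。",
        "transcript_path": "models/example_model/transcript.list",
    }
}
# Each transcript.list line has four pipe-separated fields; only the wav name
# and the transcript text are used here:
#   reference_audio/sample_001.wav|example_model|JA|サンプルの参照テキストです。
# The audio is then resolved as <transcript dir>/reference_audio/<wav basename>.
```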
310
+ with gr.Blocks() as app:
311
+ gr.Markdown(
312
+ "# <center> GPT-SoVITS-V2-Gakuen Idolmaster\n"
313
+ "### 中文\n"
314
+ "1. 在左侧选择参考音频来调整合成语音的情感。\n"
315
+ "2. 在右侧输入要合成的文本(Shift+Enter换行,每行单独合成并拼接)。\n"
316
+ "3. 点击Clean Text将文本转为token。\n"
317
+ "4. (可选) 手动修改token中的错误。\n"
318
+ "5. 点击Generate生成语音。\n"
319
+ "注意:由于Zero显卡具有单次推理时长限制,每次推理的内容不应过长。\n"
320
+ "### 日本語\n"
321
+ "1. 左側でリファレンス音声を選択して、合成音声の感情を調整します。\n"
322
+ "2. 右側にテキストを入力します(Shift+Enterで改行、各行を個別に合成して連結)。\n"
323
+ "3. Clean Textをクリックしてテキストをトークンに変換します。\n"
324
+ "4. (オプション)トークンのエラーを手動で修正します。\n"
325
+ "5. Generateをクリックして音声を生成します。\n"
326
+ "注意:Zeroグラフィックカードには単一の推論時間制限があるため、推論内容を短くする必要があります。\n"
327
+ )
328
+ with gr.Tabs():
329
+ for (name, title, cover, transcripts, example_reference, tts_fn) in models:
330
+ with gr.TabItem(name):
331
+ with gr.Row():
332
+ gr.Markdown(
333
+ '<div align="center">'
334
+ f'<a><strong>{title}</strong></a>'
335
+ '</div>')
336
+ with gr.Row():
337
+ with gr.Column():
338
+ prompt_text = gr.Dropdown(
339
+ label="Transcript of the Reference Audio",
340
+ value=example_reference if example_reference in transcripts else list(transcripts.keys())[0],
341
+ choices=list(transcripts.keys())
342
+ )
343
+ inp_ref_audio = gr.Audio(
344
+ label="Reference Audio",
345
+ type="filepath",
346
+ interactive=False,
347
+ value=transcripts[example_reference] if example_reference in transcripts else list(transcripts.values())[0]
348
+ )
349
+ transcripts_state = gr.State(value=transcripts)
350
+ prompt_text.change(
351
+ fn=change_reference_audio,
352
+ inputs=[prompt_text, transcripts_state],
353
+ outputs=[inp_ref_audio]
354
+ )
355
+ prompt_language = gr.State(value="ja")
356
+ with gr.Column():
357
+ text = gr.Textbox(label="Input Text", value="学園アイドルマスター!")
358
+ text_language = gr.Dropdown(
359
+ label="Language",
360
+ choices=["ja"],
361
+ value="ja"
362
+ )
363
+ clean_button = gr.Button("Clean Text", variant="primary")
364
+ inference_button = gr.Button("Generate", variant="primary")
365
+ cleaned_text = gr.Textbox(label="Cleaned Text")
366
+ output = gr.Audio(label="Output Audio")
367
+ output_file = gr.File(label="Output Audio File")
368
+ om = gr.Textbox(label="Output Message")
369
+ clean_button.click(
370
+ fn=get_str_list_from_phone,
371
+ inputs=[text, text_language],
372
+ outputs=[cleaned_text]
373
+ )
374
+ inference_button.click(
375
+ fn=tts_fn,
376
+ inputs=[inp_ref_audio, prompt_text, prompt_language, cleaned_text, text_language, text],
377
+ outputs=[om, output, output_file]
378
+ )
379
+
380
  app.launch(share=True)