Merge pull request #255 from jhj0517/feature/remember-settings
Changed files:
- .gitignore (+2 -1)
- app.py (+88 -67)
- configs/default_parameters.yaml (+58 -0)
- modules/diarize/diarize_pipeline.py (+2 -1)
- modules/diarize/diarizer.py (+2 -1)
- modules/translation/deepl_api.py (+27 -2)
- modules/translation/nllb_inference.py (+3 -2)
- modules/translation/translation_base.py (+26 -2)
- {ui → modules/ui}/__init__.py (+0 -0)
- {ui → modules/ui}/htmls.py (+0 -0)
- modules/utils/files_manager.py (+25 -1)
- modules/utils/paths.py (+24 -0)
- modules/whisper/faster_whisper_inference.py (+5 -6)
- modules/whisper/insanely_fast_whisper_inference.py (+4 -3)
- modules/whisper/whisper_Inference.py (+4 -3)
- modules/whisper/whisper_base.py (+28 -5)
- modules/whisper/whisper_factory.py (+7 -5)
- modules/whisper/whisper_parameter.py (+53 -2)
- requirements.txt (+1 -0)
.gitignore
CHANGED
@@ -3,8 +3,9 @@
 *.mp4
 *.mp3
 venv/
-ui/__pycache__/
+modules/ui/__pycache__/
 outputs/
 modules/__pycache__/
 models/
 modules/yt_tmp.wav
+configs/default_parameters.yaml
app.py
CHANGED
@@ -1,12 +1,16 @@
 import os
 import argparse
 import gradio as gr
+import yaml
 
+from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, WHISPER_MODELS_DIR,
+                                 INSANELY_FAST_WHISPER_MODELS_DIR, NLLB_MODELS_DIR, DEFAULT_PARAMETERS_CONFIG_PATH)
+from modules.utils.files_manager import load_yaml
 from modules.whisper.whisper_factory import WhisperFactory
 from modules.whisper.faster_whisper_inference import FasterWhisperInference
 from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
 from modules.translation.nllb_inference import NLLBInference
-from ui.htmls import *
+from modules.ui.htmls import *
 from modules.utils.youtube_manager import get_ytmetas
 from modules.translation.deepl_api import DeepLAPI
 from modules.whisper.whisper_parameter import *
@@ -32,103 +36,117 @@ class App:
         self.deepl_api = DeepLAPI(
             output_dir=os.path.join(self.args.output_dir, "translations")
         )
+        self.default_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
 
     def create_whisper_parameters(self):
+        whisper_params = self.default_params["whisper"]
+        vad_params = self.default_params["vad"]
+        diarization_params = self.default_params["diarization"]
+
         with gr.Row():
-            dd_model = gr.Dropdown(choices=self.whisper_inf.available_models, value="
+            dd_model = gr.Dropdown(choices=self.whisper_inf.available_models, value=whisper_params["model_size"],
                                    label="Model")
             dd_lang = gr.Dropdown(choices=["Automatic Detection"] + self.whisper_inf.available_langs,
-                                  value="
+                                  value=whisper_params["lang"], label="Language")
             dd_file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt"], value="SRT", label="File Format")
         with gr.Row():
-            cb_translate = gr.Checkbox(value=
+            cb_translate = gr.Checkbox(value=whisper_params["is_translate"], label="Translate to English?",
+                                       interactive=True)
         with gr.Row():
-            cb_timestamp = gr.Checkbox(value=
+            cb_timestamp = gr.Checkbox(value=whisper_params["add_timestamp"], label="Add a timestamp to the end of the filename",
                                        interactive=True)
         with gr.Accordion("Advanced Parameters", open=False):
-            nb_beam_size = gr.Number(label="Beam Size", value=
+            nb_beam_size = gr.Number(label="Beam Size", value=whisper_params["beam_size"], precision=0, interactive=True,
                                      info="Beam size to use for decoding.")
-            nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value
+            nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=whisper_params["log_prob_threshold"], interactive=True,
                                               info="If the average log probability over sampled tokens is below this value, treat as failed.")
-            nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=
+            nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=whisper_params["no_speech_threshold"], interactive=True,
                                                info="If the no speech probability is higher than this value AND the average log probability over sampled tokens is below 'Log Prob Threshold', consider the segment as silent.")
             dd_compute_type = gr.Dropdown(label="Compute Type", choices=self.whisper_inf.available_compute_types,
                                           value=self.whisper_inf.current_compute_type, interactive=True,
                                           info="Select the type of computation to perform.")
-            nb_best_of = gr.Number(label="Best Of", value=
+            nb_best_of = gr.Number(label="Best Of", value=whisper_params["best_of"], interactive=True,
                                    info="Number of candidates when sampling with non-zero temperature.")
-            nb_patience = gr.Number(label="Patience", value=
+            nb_patience = gr.Number(label="Patience", value=whisper_params["patience"], interactive=True,
                                     info="Beam search patience factor.")
-            cb_condition_on_previous_text = gr.Checkbox(label="Condition On Previous Text", value=
+            cb_condition_on_previous_text = gr.Checkbox(label="Condition On Previous Text", value=whisper_params["condition_on_previous_text"],
                                                         interactive=True,
                                                         info="Condition on previous text during decoding.")
-            sld_prompt_reset_on_temperature = gr.Slider(label="Prompt Reset On Temperature", value=
+            sld_prompt_reset_on_temperature = gr.Slider(label="Prompt Reset On Temperature", value=whisper_params["prompt_reset_on_temperature"],
                                                         minimum=0, maximum=1, step=0.01, interactive=True,
                                                         info="Resets prompt if temperature is above this value."
                                                              " Arg has effect only if 'Condition On Previous Text' is True.")
             tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True,
                                            info="Initial prompt to use for decoding.")
-            sd_temperature = gr.Slider(label="Temperature", value=
+            sd_temperature = gr.Slider(label="Temperature", value=whisper_params["temperature"], minimum=0.0,
+                                       step=0.01, maximum=1.0, interactive=True,
                                        info="Temperature for sampling. It can be a tuple of temperatures, which will be successively used upon failures according to either `Compression Ratio Threshold` or `Log Prob Threshold`.")
-            nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold", value=
+            nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold", value=whisper_params["compression_ratio_threshold"],
+                                                       interactive=True,
                                                        info="If the gzip compression ratio is above this value, treat as failed.")
         with gr.Group(visible=isinstance(self.whisper_inf, FasterWhisperInference)):
-            nb_length_penalty = gr.Number(label="Length Penalty", value=
+            nb_length_penalty = gr.Number(label="Length Penalty", value=whisper_params["length_penalty"],
                                           info="Exponential length penalty constant.")
-            nb_repetition_penalty = gr.Number(label="Repetition Penalty", value=
+            nb_repetition_penalty = gr.Number(label="Repetition Penalty", value=whisper_params["repetition_penalty"],
                                               info="Penalty applied to the score of previously generated tokens (set > 1 to penalize).")
-            nb_no_repeat_ngram_size = gr.Number(label="No Repeat N-gram Size", value=
+            nb_no_repeat_ngram_size = gr.Number(label="No Repeat N-gram Size", value=whisper_params["no_repeat_ngram_size"],
+                                                precision=0,
                                                 info="Prevent repetitions of n-grams with this size (set 0 to disable).")
-            tb_prefix = gr.Textbox(label="Prefix", value=lambda:
+            tb_prefix = gr.Textbox(label="Prefix", value=lambda: whisper_params["prefix"],
                                    info="Optional text to provide as a prefix for the first window.")
-            cb_suppress_blank = gr.Checkbox(label="Suppress Blank", value=
+            cb_suppress_blank = gr.Checkbox(label="Suppress Blank", value=whisper_params["suppress_blank"],
                                             info="Suppress blank outputs at the beginning of the sampling.")
-            tb_suppress_tokens = gr.Textbox(label="Suppress Tokens", value="
+            tb_suppress_tokens = gr.Textbox(label="Suppress Tokens", value=whisper_params["suppress_tokens"],
                                             info="List of token IDs to suppress. -1 will suppress a default set of symbols as defined in the model config.json file.")
-            nb_max_initial_timestamp = gr.Number(label="Max Initial Timestamp", value=
+            nb_max_initial_timestamp = gr.Number(label="Max Initial Timestamp", value=whisper_params["max_initial_timestamp"],
                                                  info="The initial timestamp cannot be later than this.")
-            cb_word_timestamps = gr.Checkbox(label="Word Timestamps", value=
+            cb_word_timestamps = gr.Checkbox(label="Word Timestamps", value=whisper_params["word_timestamps"],
                                              info="Extract word-level timestamps using the cross-attention pattern and dynamic time warping, and include the timestamps for each word in each segment.")
-            tb_prepend_punctuations = gr.Textbox(label="Prepend Punctuations", value="
+            tb_prepend_punctuations = gr.Textbox(label="Prepend Punctuations", value=whisper_params["prepend_punctuations"],
                                                  info="If 'Word Timestamps' is True, merge these punctuation symbols with the next word.")
-            tb_append_punctuations = gr.Textbox(label="Append Punctuations", value="
+            tb_append_punctuations = gr.Textbox(label="Append Punctuations", value=whisper_params["append_punctuations"],
                                                 info="If 'Word Timestamps' is True, merge these punctuation symbols with the previous word.")
-            nb_max_new_tokens = gr.Number(label="Max New Tokens", value=lambda:
+            nb_max_new_tokens = gr.Number(label="Max New Tokens", value=lambda: whisper_params["max_new_tokens"],
+                                          precision=0,
                                           info="Maximum number of new tokens to generate per-chunk. If not set, the maximum will be set by the default max_length.")
-            nb_chunk_length = gr.Number(label="Chunk Length", value=lambda:
+            nb_chunk_length = gr.Number(label="Chunk Length", value=lambda: whisper_params["chunk_length"],
+                                        precision=0,
                                         info="The length of audio segments. If it is not None, it will overwrite the default chunk_length of the FeatureExtractor.")
             nb_hallucination_silence_threshold = gr.Number(label="Hallucination Silence Threshold (sec)",
-                                                           value=lambda:
+                                                           value=lambda: whisper_params["hallucination_silence_threshold"],
                                                            info="When 'Word Timestamps' is True, skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected.")
-            tb_hotwords = gr.Textbox(label="Hotwords", value=
+            tb_hotwords = gr.Textbox(label="Hotwords", value=lambda: whisper_params["hotwords"],
                                      info="Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.")
-            nb_language_detection_threshold = gr.Number(label="Language Detection Threshold", value=
+            nb_language_detection_threshold = gr.Number(label="Language Detection Threshold", value=lambda: whisper_params["language_detection_threshold"],
                                                         info="If the maximum probability of the language tokens is higher than this value, the language is detected.")
-            nb_language_detection_segments = gr.Number(label="Language Detection Segments", value=
+            nb_language_detection_segments = gr.Number(label="Language Detection Segments", value=lambda: whisper_params["language_detection_segments"],
+                                                       precision=0,
                                                        info="Number of segments to consider for the language detection.")
         with gr.Group(visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
-            nb_chunk_length_s = gr.Number(label="Chunk Lengths (sec)", value=
+            nb_chunk_length_s = gr.Number(label="Chunk Lengths (sec)", value=whisper_params["chunk_length_s"],
+                                          precision=0)
+            nb_batch_size = gr.Number(label="Batch Size", value=whisper_params["batch_size"], precision=0)
 
         with gr.Accordion("VAD", open=False):
-            cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=
+            cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=vad_params["vad_filter"],
+                                        interactive=True)
+            sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=vad_params["threshold"],
                                      info="Lower it to be more sensitive to small sounds.")
-            nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=
+            nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=vad_params["min_speech_duration_ms"],
                                                   info="Final speech chunks shorter than this time are thrown out")
-            nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)", value=
+            nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)", value=vad_params["max_speech_duration_s"],
                                                  info="Maximum duration of speech chunks in \"seconds\". Chunks longer"
                                                       " than this time will be split at the timestamp of the last silence that"
                                                       " lasts more than 100ms (if any), to prevent aggressive cutting.")
-            nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=
+            nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=vad_params["min_silence_duration_ms"],
                                                    info="In the end of each speech chunk wait for this time"
                                                         " before separating it")
-            nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=
+            nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=vad_params["speech_pad_ms"],
                                          info="Final speech chunks are padded by this time each side")
 
         with gr.Accordion("Diarization", open=False):
-            cb_diarize = gr.Checkbox(label="Enable Diarization")
-            tb_hf_token = gr.Text(label="HuggingFace Token", value="",
+            cb_diarize = gr.Checkbox(label="Enable Diarization", value=diarization_params["is_diarize"])
+            tb_hf_token = gr.Text(label="HuggingFace Token", value=diarization_params["hf_token"],
                                   info="This is only needed the first time you download the model. If you already have models, you don't need to enter. To download the model, you must manually go to \"https://huggingface.co/pyannote/speaker-diarization-3.1\" and agree to their requirement.")
             dd_diarization_device = gr.Dropdown(label="Device",
                                                 choices=self.whisper_inf.diarizer.get_available_device(),
@@ -162,6 +180,10 @@ class App:
         )
 
     def launch(self):
+        translation_params = self.default_params["translation"]
+        deepl_params = translation_params["deepl"]
+        nllb_params = translation_params["nllb"]
+
         with self.app:
             with gr.Row():
                 with gr.Column():
@@ -246,19 +268,17 @@ class App:
 
                 with gr.TabItem("DeepL API"):  # sub tab1
                     with gr.Row():
-                        value="")
+                        tb_api_key = gr.Textbox(label="Your Auth Key (API KEY)", value=deepl_params["api_key"])
                     with gr.Row():
+                        dd_source_lang = gr.Dropdown(label="Source Language", value=deepl_params["source_lang"],
+                                                     choices=list(
                             self.deepl_api.available_source_langs.keys()))
-                        self.deepl_api.available_target_langs.keys()))
+                        dd_target_lang = gr.Dropdown(label="Target Language", value=deepl_params["target_lang"],
+                                                     choices=list(self.deepl_api.available_target_langs.keys()))
                     with gr.Row():
+                        cb_is_pro = gr.Checkbox(label="Pro User?", value=deepl_params["is_pro"])
                     with gr.Row():
-                        cb_timestamp = gr.Checkbox(value=
+                        cb_timestamp = gr.Checkbox(value=translation_params["add_timestamp"], label="Add a timestamp to the end of the filename",
                                                    interactive=True)
                     with gr.Row():
                         btn_run = gr.Button("TRANSLATE SUBTITLE FILE", variant="primary")
@@ -268,26 +288,27 @@ class App:
                         btn_openfolder = gr.Button('📂', scale=1)
 
                     btn_run.click(fn=self.deepl_api.translate_deepl,
-                                  inputs=[
+                                  inputs=[tb_api_key, file_subs, dd_source_lang, dd_target_lang,
+                                          cb_is_pro, cb_timestamp],
                                   outputs=[tb_indicator, files_subtitles])
 
-                    btn_openfolder.click(fn=lambda: self.open_folder(os.path.join(
+                    btn_openfolder.click(fn=lambda: self.open_folder(os.path.join(self.args.output_dir, "translations")),
                                          inputs=None,
                                          outputs=None)
 
                 with gr.TabItem("NLLB"):  # sub tab2
                     with gr.Row():
+                        dd_model_size = gr.Dropdown(label="Model", value=nllb_params["model_size"],
                                                     choices=self.nllb_inf.available_models)
+                        dd_source_lang = gr.Dropdown(label="Source Language", value=nllb_params["source_lang"],
+                                                     choices=self.nllb_inf.available_source_langs)
+                        dd_target_lang = gr.Dropdown(label="Target Language", value=nllb_params["target_lang"],
+                                                     choices=self.nllb_inf.available_target_langs)
                    with gr.Row():
-                        nb_max_length = gr.Number(label="Max Length Per Line", value=
+                        nb_max_length = gr.Number(label="Max Length Per Line", value=nllb_params["max_length"],
+                                                  precision=0)
                     with gr.Row():
-                        cb_timestamp = gr.Checkbox(value=
+                        cb_timestamp = gr.Checkbox(value=translation_params["add_timestamp"], label="Add a timestamp to the end of the filename",
                                                    interactive=True)
                     with gr.Row():
                         btn_run = gr.Button("TRANSLATE SUBTITLE FILE", variant="primary")
@@ -299,11 +320,11 @@ class App:
                         md_vram_table = gr.HTML(NLLB_VRAM_TABLE, elem_id="md_nllb_vram_table")
 
                     btn_run.click(fn=self.nllb_inf.translate_file,
-                                  inputs=[file_subs,
+                                  inputs=[file_subs, dd_model_size, dd_source_lang, dd_target_lang,
                                           nb_max_length, cb_timestamp],
                                   outputs=[tb_indicator, files_subtitles])
 
-                    btn_openfolder.click(fn=lambda: self.open_folder(os.path.join(
+                    btn_openfolder.click(fn=lambda: self.open_folder(os.path.join(self.args.output_dir, "translations")),
                                          inputs=None,
                                          outputs=None)
 
@@ -351,18 +372,18 @@ parser.add_argument('--theme', type=str, default=None, help='Gradio Blocks theme')
 parser.add_argument('--colab', type=bool, default=False, nargs='?', const=True, help='Is colab user or not')
 parser.add_argument('--api_open', type=bool, default=False, nargs='?', const=True, help='Enable api or not in Gradio')
 parser.add_argument('--inbrowser', type=bool, default=True, nargs='?', const=True, help='Whether to automatically start Gradio app or not')
-parser.add_argument('--whisper_model_dir', type=str, default=
+parser.add_argument('--whisper_model_dir', type=str, default=WHISPER_MODELS_DIR,
                     help='Directory path of the whisper model')
-parser.add_argument('--faster_whisper_model_dir', type=str, default=
+parser.add_argument('--faster_whisper_model_dir', type=str, default=FASTER_WHISPER_MODELS_DIR,
                     help='Directory path of the faster-whisper model')
 parser.add_argument('--insanely_fast_whisper_model_dir', type=str,
-                    default=
+                    default=INSANELY_FAST_WHISPER_MODELS_DIR,
                     help='Directory path of the insanely-fast-whisper model')
-parser.add_argument('--diarization_model_dir', type=str, default=
+parser.add_argument('--diarization_model_dir', type=str, default=DIARIZATION_MODELS_DIR,
                     help='Directory path of the diarization model')
-parser.add_argument('--nllb_model_dir', type=str, default=
+parser.add_argument('--nllb_model_dir', type=str, default=NLLB_MODELS_DIR,
                     help='Directory path of the Facebook NLLB model')
-parser.add_argument('--output_dir', type=str, default=
+parser.add_argument('--output_dir', type=str, default=OUTPUT_DIR, help='Directory path of the outputs')
 _args = parser.parse_args()
 
 if __name__ == "__main__":
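Note, not part of the diff: the hunks above use two kinds of default. A plain value such as value=whisper_params["model_size"] is read once when the UI is built, while value=lambda: whisper_params["chunk_length"] hands Gradio a callable that is re-evaluated on every page load, so freshly cached settings appear on refresh without restarting the app. A minimal sketch of the difference (component names are illustrative):

import gradio as gr

defaults = {"prefix": None}  # stands in for the loaded YAML subtree

with gr.Blocks() as demo:
    # Evaluated once, at build time:
    tb_static = gr.Textbox(value=defaults["prefix"], label="Static default")
    # Re-evaluated on each page load, so later edits to `defaults` are picked up:
    tb_dynamic = gr.Textbox(value=lambda: defaults["prefix"], label="Dynamic default")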
configs/default_parameters.yaml
ADDED
@@ -0,0 +1,58 @@
+whisper:
+  model_size: "large-v2"
+  lang: "Automatic Detection"
+  is_translate: false
+  beam_size: 5
+  log_prob_threshold: -1
+  no_speech_threshold: 0.6
+  best_of: 5
+  patience: 1
+  condition_on_previous_text: true
+  prompt_reset_on_temperature: 0.5
+  initial_prompt: null
+  temperature: 0
+  compression_ratio_threshold: 2.4
+  chunk_length_s: 30
+  batch_size: 24
+  length_penalty: 1
+  repetition_penalty: 1
+  no_repeat_ngram_size: 0
+  prefix: null
+  suppress_blank: true
+  suppress_tokens: "[-1]"
+  max_initial_timestamp: 1
+  word_timestamps: false
+  prepend_punctuations: "\"'“¿([{-"
+  append_punctuations: "\"'.。,,!!??::”)]}、"
+  max_new_tokens: null
+  chunk_length: null
+  hallucination_silence_threshold: null
+  hotwords: null
+  language_detection_threshold: null
+  language_detection_segments: 1
+  add_timestamp: true
+
+vad:
+  vad_filter: false
+  threshold: 0.5
+  min_speech_duration_ms: 250
+  max_speech_duration_s: 9999
+  min_silence_duration_ms: 2000
+  speech_pad_ms: 400
+
+diarization:
+  is_diarize: false
+  hf_token: ""
+
+translation:
+  deepl:
+    api_key: ""
+    is_pro: false
+    source_lang: "Automatic Detection"
+    target_lang: "English"
+  nllb:
+    model_size: "facebook/nllb-200-1.3B"
+    source_lang: null
+    target_lang: null
+    max_length: 200
+  add_timestamp: true
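Note, not part of the diff: this file doubles as the shipped defaults and as the cache the app rewrites after each run. A small sketch of reading it back with the helper this PR adds (the printed value is just the initial default):

from modules.utils.files_manager import load_yaml
from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH

defaults = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
print(defaults["whisper"]["model_size"])  # "large-v2" until a first run overwrites it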
modules/diarize/diarize_pipeline.py
CHANGED
@@ -7,6 +7,7 @@ from pyannote.audio import Pipeline
 from typing import Optional, Union
 import torch
 
+from modules.utils.paths import DIARIZATION_MODELS_DIR
 from modules.diarize.audio_loader import load_audio, SAMPLE_RATE
 
 
@@ -14,7 +15,7 @@ class DiarizationPipeline:
     def __init__(
         self,
         model_name="pyannote/speaker-diarization-3.1",
-        cache_dir: str =
+        cache_dir: str = DIARIZATION_MODELS_DIR,
         use_auth_token=None,
         device: Optional[Union[str, torch.device]] = "cpu",
     ):
modules/diarize/diarizer.py
CHANGED
@@ -5,13 +5,14 @@ import numpy as np
 import time
 import logging
 
+from modules.utils.paths import DIARIZATION_MODELS_DIR
 from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers
 from modules.diarize.audio_loader import load_audio
 
 
 class Diarizer:
     def __init__(self,
-                 model_dir: str =
+                 model_dir: str = DIARIZATION_MODELS_DIR
                  ):
         self.device = self.get_device()
         self.available_device = self.get_available_device()
modules/translation/deepl_api.py
CHANGED
@@ -4,7 +4,9 @@ import os
 from datetime import datetime
 import gradio as gr
 
+from modules.utils.paths import TRANSLATION_OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH
 from modules.utils.subtitle_manager import *
+from modules.utils.files_manager import load_yaml, save_yaml
 
 """
 This is written with reference to the DeepL API documentation.
@@ -83,7 +85,7 @@ DEEPL_AVAILABLE_SOURCE_LANGS = {
 
 class DeepLAPI:
     def __init__(self,
-                 output_dir: str =
+                 output_dir: str = TRANSLATION_OUTPUT_DIR
                  ):
         self.api_interval = 1
         self.max_text_batch_size = 50
@@ -124,6 +126,13 @@ class DeepLAPI:
             String to return to gr.Textbox()
             Files to return to gr.Files()
         """
+        self.cache_parameters(
+            api_key=auth_key,
+            is_pro=is_pro,
+            source_lang=source_lang,
+            target_lang=target_lang,
+            add_timestamp=add_timestamp
+        )
 
         files_info = {}
         for fileobj in fileobjs:
@@ -198,4 +207,20 @@ class DeepLAPI:
         }
         response = requests.post(url, headers=headers, data=data).json()
         time.sleep(self.api_interval)
-        return response["translations"]
+        return response["translations"]
+
+    @staticmethod
+    def cache_parameters(api_key: str,
+                         is_pro: bool,
+                         source_lang: str,
+                         target_lang: str,
+                         add_timestamp: bool):
+        cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
+        cached_params["translation"]["deepl"] = {
+            "api_key": api_key,
+            "is_pro": is_pro,
+            "source_lang": source_lang,
+            "target_lang": target_lang
+        }
+        cached_params["translation"]["add_timestamp"] = add_timestamp
+        save_yaml(cached_params, DEFAULT_PARAMETERS_CONFIG_PATH)
modules/translation/nllb_inference.py
CHANGED
@@ -2,13 +2,14 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 import gradio as gr
 import os
 
+from modules.utils.paths import TRANSLATION_OUTPUT_DIR, NLLB_MODELS_DIR
 from modules.translation.translation_base import TranslationBase
 
 
 class NLLBInference(TranslationBase):
     def __init__(self,
-                 model_dir: str =
-                 output_dir: str =
+                 model_dir: str = NLLB_MODELS_DIR,
+                 output_dir: str = TRANSLATION_OUTPUT_DIR
                  ):
         super().__init__(
             model_dir=model_dir,
modules/translation/translation_base.py
CHANGED
@@ -7,12 +7,14 @@ from datetime import datetime
 
 from modules.whisper.whisper_parameter import *
 from modules.utils.subtitle_manager import *
+from modules.utils.files_manager import load_yaml, save_yaml
+from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH, NLLB_MODELS_DIR, TRANSLATION_OUTPUT_DIR
 
 
 class TranslationBase(ABC):
     def __init__(self,
-                 model_dir: str =
-                 output_dir: str =
+                 model_dir: str = NLLB_MODELS_DIR,
+                 output_dir: str = TRANSLATION_OUTPUT_DIR
                  ):
         super().__init__()
         self.model = None
@@ -75,6 +77,12 @@ class TranslationBase(ABC):
             Files to return to gr.Files()
         """
         try:
+            self.cache_parameters(model_size=model_size,
+                                  src_lang=src_lang,
+                                  tgt_lang=tgt_lang,
+                                  max_length=max_length,
+                                  add_timestamp=add_timestamp)
+
             self.update_model(model_size=model_size,
                               src_lang=src_lang,
                               tgt_lang=tgt_lang,
@@ -149,3 +157,19 @@ class TranslationBase(ABC):
         for file_path in file_paths:
             if file_path and os.path.exists(file_path):
                 os.remove(file_path)
+
+    @staticmethod
+    def cache_parameters(model_size: str,
+                         src_lang: str,
+                         tgt_lang: str,
+                         max_length: int,
+                         add_timestamp: bool):
+        cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
+        cached_params["translation"]["nllb"] = {
+            "model_size": model_size,
+            "source_lang": src_lang,
+            "target_lang": tgt_lang,
+            "max_length": max_length,
+        }
+        cached_params["translation"]["add_timestamp"] = add_timestamp
+        save_yaml(cached_params, DEFAULT_PARAMETERS_CONFIG_PATH)
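Note, not part of the diff: both cache_parameters() implementations above perform the same round trip — load the whole config, replace one subtree, write everything back — so sections owned by other features survive each save. A minimal sketch (the edited value is illustrative):

from modules.utils.files_manager import load_yaml, save_yaml
from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH

cached = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)   # full config as a dict
cached["translation"]["nllb"]["max_length"] = 300    # update one cached value
save_yaml(cached, DEFAULT_PARAMETERS_CONFIG_PATH)    # "whisper", "vad", etc. untouched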
{ui → modules/ui}/__init__.py
RENAMED
File without changes

{ui → modules/ui}/htmls.py
RENAMED
File without changes
modules/utils/files_manager.py
CHANGED
@@ -1,8 +1,32 @@
 import os
 import fnmatch
-
+from ruamel.yaml import YAML
 from gradio.utils import NamedString
 
+from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH
+
+
+def load_yaml(path: str = DEFAULT_PARAMETERS_CONFIG_PATH):
+    yaml = YAML(typ="safe")
+    yaml.preserve_quotes = True
+    with open(path, 'r', encoding='utf-8') as file:
+        config = yaml.load(file)
+    return config
+
+
+def save_yaml(data: dict, path: str = DEFAULT_PARAMETERS_CONFIG_PATH):
+    yaml = YAML(typ="safe")
+    yaml.map_indent = 2
+    yaml.sequence_indent = 4
+    yaml.sequence_dash_offset = 2
+    yaml.preserve_quotes = True
+    yaml.default_flow_style = False
+    yaml.sort_base_mapping_type_on_output = False
+
+    with open(path, 'w', encoding='utf-8') as file:
+        yaml.dump(data, file)
+    return path
+
 
 def get_media_files(folder_path, include_sub_directory=False):
     video_extensions = ['*.mp4', '*.mkv', '*.flv', '*.avi', '*.mov', '*.wmv']
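Note, not part of the diff: hypothetical usage of the two helpers above. Both default to DEFAULT_PARAMETERS_CONFIG_PATH, and YAML(typ="safe") yields plain Python dicts, so the result can be edited like any mapping before saving:

from modules.utils.files_manager import load_yaml, save_yaml

config = load_yaml()               # reads configs/default_parameters.yaml by default
config["vad"]["threshold"] = 0.4   # illustrative tweak
save_yaml(config)                  # dumped in block style with 2-space map indents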
modules/utils/paths.py
ADDED
@@ -0,0 +1,24 @@
+import os
+
+WEBUI_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
+MODELS_DIR = os.path.join(WEBUI_DIR, "models")
+WHISPER_MODELS_DIR = os.path.join(MODELS_DIR, "Whisper")
+FASTER_WHISPER_MODELS_DIR = os.path.join(WHISPER_MODELS_DIR, "faster-whisper")
+INSANELY_FAST_WHISPER_MODELS_DIR = os.path.join(WHISPER_MODELS_DIR, "insanely-fast-whisper")
+NLLB_MODELS_DIR = os.path.join(MODELS_DIR, "NLLB")
+DIARIZATION_MODELS_DIR = os.path.join(MODELS_DIR, "Diarization")
+CONFIGS_DIR = os.path.join(WEBUI_DIR, "configs")
+DEFAULT_PARAMETERS_CONFIG_PATH = os.path.join(CONFIGS_DIR, "default_parameters.yaml")
+OUTPUT_DIR = os.path.join(WEBUI_DIR, "outputs")
+TRANSLATION_OUTPUT_DIR = os.path.join(OUTPUT_DIR, "translations")
+
+for dir_path in [MODELS_DIR,
+                 WHISPER_MODELS_DIR,
+                 FASTER_WHISPER_MODELS_DIR,
+                 INSANELY_FAST_WHISPER_MODELS_DIR,
+                 NLLB_MODELS_DIR,
+                 DIARIZATION_MODELS_DIR,
+                 CONFIGS_DIR,
+                 OUTPUT_DIR,
+                 TRANSLATION_OUTPUT_DIR]:
+    os.makedirs(dir_path, exist_ok=True)
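Note, not part of the diff: because the makedirs loop runs at module import, any module that imports one of these constants also guarantees the directory tree exists:

import os
from modules.utils.paths import OUTPUT_DIR  # importing triggers the loop above

assert os.path.isdir(OUTPUT_DIR)  # .../outputs exists even on a fresh clone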
modules/whisper/faster_whisper_inference.py
CHANGED
@@ -11,15 +11,16 @@ import whisper
 import gradio as gr
 from argparse import Namespace
 
+from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR)
 from modules.whisper.whisper_parameter import *
 from modules.whisper.whisper_base import WhisperBase
 
 
 class FasterWhisperInference(WhisperBase):
     def __init__(self,
-                 model_dir: str =
-                 diarization_model_dir: str =
-                 output_dir: str =
+                 model_dir: str = FASTER_WHISPER_MODELS_DIR,
+                 diarization_model_dir: str = DIARIZATION_MODELS_DIR,
+                 output_dir: str = OUTPUT_DIR,
                  ):
         super().__init__(
             model_dir=model_dir,
@@ -163,14 +164,12 @@ class FasterWhisperInference(WhisperBase):
         wrong_dirs = [".locks"]
         existing_models = list(set(existing_models) - set(wrong_dirs))
 
-        webui_dir = os.getcwd()
-
         for model_name in existing_models:
             if faster_whisper_prefix in model_name:
                 model_name = model_name[len(faster_whisper_prefix):]
 
             if model_name not in whisper.available_models():
-                model_paths[model_name] = os.path.join(
+                model_paths[model_name] = os.path.join(self.model_dir, model_name)
         return model_paths
 
     @staticmethod
modules/whisper/insanely_fast_whisper_inference.py
CHANGED
@@ -11,15 +11,16 @@ import whisper
 from rich.progress import Progress, TimeElapsedColumn, BarColumn, TextColumn
 from argparse import Namespace
 
+from modules.utils.paths import (INSANELY_FAST_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR)
 from modules.whisper.whisper_parameter import *
 from modules.whisper.whisper_base import WhisperBase
 
 
 class InsanelyFastWhisperInference(WhisperBase):
     def __init__(self,
-                 model_dir: str =
-                 diarization_model_dir: str =
-                 output_dir: str =
+                 model_dir: str = INSANELY_FAST_WHISPER_MODELS_DIR,
+                 diarization_model_dir: str = DIARIZATION_MODELS_DIR,
+                 output_dir: str = OUTPUT_DIR,
                  ):
         super().__init__(
             model_dir=model_dir,
modules/whisper/whisper_Inference.py
CHANGED
@@ -7,15 +7,16 @@ import torch
 import os
 from argparse import Namespace
 
+from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR)
 from modules.whisper.whisper_base import WhisperBase
 from modules.whisper.whisper_parameter import *
 
 
 class WhisperInference(WhisperBase):
     def __init__(self,
-                 model_dir: str =
-                 diarization_model_dir: str =
-                 output_dir: str =
+                 model_dir: str = WHISPER_MODELS_DIR,
+                 diarization_model_dir: str = DIARIZATION_MODELS_DIR,
+                 output_dir: str = OUTPUT_DIR,
                  ):
         super().__init__(
             model_dir=model_dir,
modules/whisper/whisper_base.py
CHANGED
@@ -9,9 +9,10 @@ from datetime import datetime
 from faster_whisper.vad import VadOptions
 from dataclasses import astuple
 
+from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH)
 from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
 from modules.utils.youtube_manager import get_ytdata, get_ytaudio
-from modules.utils.files_manager import get_media_files, format_gradio_files
+from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
 from modules.whisper.whisper_parameter import *
 from modules.diarize.diarizer import Diarizer
 from modules.vad.silero_vad import SileroVAD
@@ -19,9 +20,9 @@ from modules.vad.silero_vad import SileroVAD
 
 class WhisperBase(ABC):
     def __init__(self,
-                 model_dir: str =
-                 diarization_model_dir: str =
-                 output_dir: str =
+                 model_dir: str = WHISPER_MODELS_DIR,
+                 diarization_model_dir: str = DIARIZATION_MODELS_DIR,
+                 output_dir: str = OUTPUT_DIR,
                  ):
         self.model_dir = model_dir
         self.output_dir = output_dir
@@ -61,7 +62,8 @@ class WhisperBase(ABC):
 
     def run(self,
             audio: Union[str, BinaryIO, np.ndarray],
-            progress: gr.Progress,
+            progress: gr.Progress = gr.Progress(),
+            add_timestamp: bool = True,
            *whisper_params,
            ) -> Tuple[List[dict], float]:
        """
@@ -75,6 +77,8 @@ class WhisperBase(ABC):
             Audio input. This can be file path or binary type.
         progress: gr.Progress
             Indicator to show progress directly in gradio.
+        add_timestamp: bool
+            Whether to add a timestamp at the end of the filename.
         *whisper_params: tuple
             Parameters related with whisper. This will be dealt with "WhisperParameters" data class
 
@@ -87,6 +91,11 @@ class WhisperBase(ABC):
         """
         params = WhisperParameters.as_value(*whisper_params)
 
+        self.cache_parameters(
+            whisper_params=params,
+            add_timestamp=add_timestamp
+        )
+
         if params.lang == "Automatic Detection":
             params.lang = None
         else:
@@ -178,6 +187,7 @@ class WhisperBase(ABC):
             transcribed_segments, time_for_task = self.run(
                 file.name,
                 progress,
+                add_timestamp,
                 *whisper_params,
             )
 
@@ -301,6 +311,7 @@ class WhisperBase(ABC):
             transcribed_segments, time_for_task = self.run(
                 audio,
                 progress,
+                add_timestamp,
                 *whisper_params,
             )
 
@@ -434,3 +445,15 @@ class WhisperBase(ABC):
         for file_path in file_paths:
             if file_path and os.path.exists(file_path):
                 os.remove(file_path)
+
+    @staticmethod
+    def cache_parameters(
+        whisper_params: WhisperValues,
+        add_timestamp: bool
+    ):
+        cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
+        cached_whisper_param = whisper_params.to_yaml()
+        cached_yaml = {**cached_params, **cached_whisper_param}
+        cached_yaml["whisper"]["add_timestamp"] = add_timestamp
+
+        save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
modules/whisper/whisper_factory.py
CHANGED
@@ -1,6 +1,8 @@
 from typing import Optional
 import os
 
+from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR,
+                                 INSANELY_FAST_WHISPER_MODELS_DIR, WHISPER_MODELS_DIR)
 from modules.whisper.faster_whisper_inference import FasterWhisperInference
 from modules.whisper.whisper_Inference import WhisperInference
 from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
@@ -11,11 +13,11 @@ class WhisperFactory:
     @staticmethod
     def create_whisper_inference(
             whisper_type: str,
-            whisper_model_dir: str =
-            faster_whisper_model_dir: str =
-            insanely_fast_whisper_model_dir: str =
-            diarization_model_dir: str =
-            output_dir: str =
+            whisper_model_dir: str = WHISPER_MODELS_DIR,
+            faster_whisper_model_dir: str = FASTER_WHISPER_MODELS_DIR,
+            insanely_fast_whisper_model_dir: str = INSANELY_FAST_WHISPER_MODELS_DIR,
+            diarization_model_dir: str = DIARIZATION_MODELS_DIR,
+            output_dir: str = OUTPUT_DIR,
     ) -> "WhisperBase":
         """
         Create a whisper inference class based on the provided whisper_type.
modules/whisper/whisper_parameter.py
CHANGED
@@ -1,6 +1,7 @@
 from dataclasses import dataclass, fields
 import gradio as gr
-from typing import Optional
+from typing import Optional, Dict
+import yaml
 
 
 @dataclass
@@ -274,4 +275,54 @@ class WhisperValues:
     language_detection_segments: int
     """
     A data class to use Whisper parameters.
-    """
+    """
+
+    def to_yaml(self) -> Dict:
+        data = {
+            "whisper": {
+                "model_size": self.model_size,
+                "lang": "Automatic Detection" if self.lang is None else self.lang,
+                "is_translate": self.is_translate,
+                "beam_size": self.beam_size,
+                "log_prob_threshold": self.log_prob_threshold,
+                "no_speech_threshold": self.no_speech_threshold,
+                "best_of": self.best_of,
+                "patience": self.patience,
+                "condition_on_previous_text": self.condition_on_previous_text,
+                "prompt_reset_on_temperature": self.prompt_reset_on_temperature,
+                "initial_prompt": None if not self.initial_prompt else self.initial_prompt,
+                "temperature": self.temperature,
+                "compression_ratio_threshold": self.compression_ratio_threshold,
+                "chunk_length_s": None if self.chunk_length_s is None else self.chunk_length_s,
+                "batch_size": self.batch_size,
+                "length_penalty": self.length_penalty,
+                "repetition_penalty": self.repetition_penalty,
+                "no_repeat_ngram_size": self.no_repeat_ngram_size,
+                "prefix": None if not self.prefix else self.prefix,
+                "suppress_blank": self.suppress_blank,
+                "suppress_tokens": self.suppress_tokens,
+                "max_initial_timestamp": self.max_initial_timestamp,
+                "word_timestamps": self.word_timestamps,
+                "prepend_punctuations": self.prepend_punctuations,
+                "append_punctuations": self.append_punctuations,
+                "max_new_tokens": self.max_new_tokens,
+                "chunk_length": self.chunk_length,
+                "hallucination_silence_threshold": self.hallucination_silence_threshold,
+                "hotwords": None if not self.hotwords else self.hotwords,
+                "language_detection_threshold": self.language_detection_threshold,
+                "language_detection_segments": self.language_detection_segments,
+            },
+            "vad": {
+                "vad_filter": self.vad_filter,
+                "threshold": self.threshold,
+                "min_speech_duration_ms": self.min_speech_duration_ms,
+                "max_speech_duration_s": self.max_speech_duration_s,
+                "min_silence_duration_ms": self.min_silence_duration_ms,
+                "speech_pad_ms": self.speech_pad_ms,
+            },
+            "diarization": {
+                "is_diarize": self.is_diarize,
+                "hf_token": self.hf_token
+            }
+        }
+        return data
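Note, not part of the diff: to_yaml() emits only the "whisper", "vad", and "diarization" sections, so the top-level dict merge in WhisperBase.cache_parameters() replaces exactly those three and leaves "translation" as it was; lang also round-trips through the UI sentinel ("Automatic Detection" here, mapped back to None in WhisperBase.run()). A sketch of the merge semantics with plain dicts:

cached = {"whisper": {"model_size": "base"}, "translation": {"deepl": {}}}
fresh = {"whisper": {"model_size": "large-v2"}, "vad": {}, "diarization": {}}

merged = {**cached, **fresh}  # top-level keys from `fresh` win wholesale
assert merged["whisper"]["model_size"] == "large-v2"
assert "translation" in merged  # preserved, since to_yaml() never emits it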
requirements.txt
CHANGED
@@ -11,4 +11,5 @@ faster-whisper==1.0.3
 transformers==4.42.3
 gradio==4.29.0
 pytubefix
+ruamel.yaml==0.18.6
 pyannote.audio==3.3.1