Commit 8bcf1fb (unverified) · committed by jhj0517
2 Parent(s): 4904f1c d8c2d87

Merge pull request #255 from jhj0517/feature/remember-settings

.gitignore CHANGED
@@ -3,8 +3,9 @@
 *.mp4
 *.mp3
 venv/
-ui/__pycache__/
+modules/ui/__pycache__/
 outputs/
 modules/__pycache__/
 models/
 modules/yt_tmp.wav
+configs/default_parameters.yaml
app.py CHANGED
@@ -1,12 +1,16 @@
 import os
 import argparse
 import gradio as gr
+import yaml
 
+from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, WHISPER_MODELS_DIR,
+                                 INSANELY_FAST_WHISPER_MODELS_DIR, NLLB_MODELS_DIR, DEFAULT_PARAMETERS_CONFIG_PATH)
+from modules.utils.files_manager import load_yaml
 from modules.whisper.whisper_factory import WhisperFactory
 from modules.whisper.faster_whisper_inference import FasterWhisperInference
 from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
 from modules.translation.nllb_inference import NLLBInference
-from ui.htmls import *
+from modules.ui.htmls import *
 from modules.utils.youtube_manager import get_ytmetas
 from modules.translation.deepl_api import DeepLAPI
 from modules.whisper.whisper_parameter import *
@@ -32,103 +36,117 @@ class App:
         self.deepl_api = DeepLAPI(
             output_dir=os.path.join(self.args.output_dir, "translations")
         )
+        self.default_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
 
     def create_whisper_parameters(self):
+        whisper_params = self.default_params["whisper"]
+        vad_params = self.default_params["vad"]
+        diarization_params = self.default_params["diarization"]
+
         with gr.Row():
-            dd_model = gr.Dropdown(choices=self.whisper_inf.available_models, value="large-v2",
+            dd_model = gr.Dropdown(choices=self.whisper_inf.available_models, value=whisper_params["model_size"],
                                    label="Model")
             dd_lang = gr.Dropdown(choices=["Automatic Detection"] + self.whisper_inf.available_langs,
-                                  value="Automatic Detection", label="Language")
+                                  value=whisper_params["lang"], label="Language")
             dd_file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt"], value="SRT", label="File Format")
         with gr.Row():
-            cb_translate = gr.Checkbox(value=False, label="Translate to English?", interactive=True)
+            cb_translate = gr.Checkbox(value=whisper_params["is_translate"], label="Translate to English?",
+                                       interactive=True)
        with gr.Row():
-            cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename",
+            cb_timestamp = gr.Checkbox(value=whisper_params["add_timestamp"], label="Add a timestamp to the end of the filename",
                                        interactive=True)
        with gr.Accordion("Advanced Parameters", open=False):
-            nb_beam_size = gr.Number(label="Beam Size", value=5, precision=0, interactive=True,
+            nb_beam_size = gr.Number(label="Beam Size", value=whisper_params["beam_size"], precision=0, interactive=True,
                                      info="Beam size to use for decoding.")
-            nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True,
+            nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=whisper_params["log_prob_threshold"], interactive=True,
                                               info="If the average log probability over sampled tokens is below this value, treat as failed.")
-            nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True,
+            nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=whisper_params["no_speech_threshold"], interactive=True,
                                                info="If the no speech probability is higher than this value AND the average log probability over sampled tokens is below 'Log Prob Threshold', consider the segment as silent.")
            dd_compute_type = gr.Dropdown(label="Compute Type", choices=self.whisper_inf.available_compute_types,
                                          value=self.whisper_inf.current_compute_type, interactive=True,
                                          info="Select the type of computation to perform.")
-            nb_best_of = gr.Number(label="Best Of", value=5, interactive=True,
+            nb_best_of = gr.Number(label="Best Of", value=whisper_params["best_of"], interactive=True,
                                    info="Number of candidates when sampling with non-zero temperature.")
-            nb_patience = gr.Number(label="Patience", value=1, interactive=True,
+            nb_patience = gr.Number(label="Patience", value=whisper_params["patience"], interactive=True,
                                     info="Beam search patience factor.")
-            cb_condition_on_previous_text = gr.Checkbox(label="Condition On Previous Text", value=True,
+            cb_condition_on_previous_text = gr.Checkbox(label="Condition On Previous Text", value=whisper_params["condition_on_previous_text"],
                                                         interactive=True,
                                                         info="Condition on previous text during decoding.")
-            sld_prompt_reset_on_temperature = gr.Slider(label="Prompt Reset On Temperature", value=0.5,
+            sld_prompt_reset_on_temperature = gr.Slider(label="Prompt Reset On Temperature", value=whisper_params["prompt_reset_on_temperature"],
                                                         minimum=0, maximum=1, step=0.01, interactive=True,
                                                         info="Resets prompt if temperature is above this value."
                                                              " Arg has effect only if 'Condition On Previous Text' is True.")
            tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True,
                                           info="Initial prompt to use for decoding.")
-            sd_temperature = gr.Slider(label="Temperature", value=0, step=0.01, maximum=1.0, interactive=True,
+            sd_temperature = gr.Slider(label="Temperature", value=whisper_params["temperature"], minimum=0.0,
+                                       step=0.01, maximum=1.0, interactive=True,
                                        info="Temperature for sampling. It can be a tuple of temperatures, which will be successively used upon failures according to either `Compression Ratio Threshold` or `Log Prob Threshold`.")
-            nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold", value=2.4, interactive=True,
+            nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold", value=whisper_params["compression_ratio_threshold"],
+                                                       interactive=True,
                                                        info="If the gzip compression ratio is above this value, treat as failed.")
        with gr.Group(visible=isinstance(self.whisper_inf, FasterWhisperInference)):
-            nb_length_penalty = gr.Number(label="Length Penalty", value=1,
+            nb_length_penalty = gr.Number(label="Length Penalty", value=whisper_params["length_penalty"],
                                           info="Exponential length penalty constant.")
-            nb_repetition_penalty = gr.Number(label="Repetition Penalty", value=1,
+            nb_repetition_penalty = gr.Number(label="Repetition Penalty", value=whisper_params["repetition_penalty"],
                                               info="Penalty applied to the score of previously generated tokens (set > 1 to penalize).")
-            nb_no_repeat_ngram_size = gr.Number(label="No Repeat N-gram Size", value=0, precision=0,
+            nb_no_repeat_ngram_size = gr.Number(label="No Repeat N-gram Size", value=whisper_params["no_repeat_ngram_size"],
+                                                precision=0,
                                                 info="Prevent repetitions of n-grams with this size (set 0 to disable).")
-            tb_prefix = gr.Textbox(label="Prefix", value=lambda: None,
+            tb_prefix = gr.Textbox(label="Prefix", value=lambda: whisper_params["prefix"],
                                    info="Optional text to provide as a prefix for the first window.")
-            cb_suppress_blank = gr.Checkbox(label="Suppress Blank", value=True,
+            cb_suppress_blank = gr.Checkbox(label="Suppress Blank", value=whisper_params["suppress_blank"],
                                             info="Suppress blank outputs at the beginning of the sampling.")
-            tb_suppress_tokens = gr.Textbox(label="Suppress Tokens", value="[-1]",
+            tb_suppress_tokens = gr.Textbox(label="Suppress Tokens", value=whisper_params["suppress_tokens"],
                                             info="List of token IDs to suppress. -1 will suppress a default set of symbols as defined in the model config.json file.")
-            nb_max_initial_timestamp = gr.Number(label="Max Initial Timestamp", value=1.0,
+            nb_max_initial_timestamp = gr.Number(label="Max Initial Timestamp", value=whisper_params["max_initial_timestamp"],
                                                  info="The initial timestamp cannot be later than this.")
-            cb_word_timestamps = gr.Checkbox(label="Word Timestamps", value=False,
+            cb_word_timestamps = gr.Checkbox(label="Word Timestamps", value=whisper_params["word_timestamps"],
                                              info="Extract word-level timestamps using the cross-attention pattern and dynamic time warping, and include the timestamps for each word in each segment.")
-            tb_prepend_punctuations = gr.Textbox(label="Prepend Punctuations", value="\"'“¿([{-",
+            tb_prepend_punctuations = gr.Textbox(label="Prepend Punctuations", value=whisper_params["prepend_punctuations"],
                                                  info="If 'Word Timestamps' is True, merge these punctuation symbols with the next word.")
-            tb_append_punctuations = gr.Textbox(label="Append Punctuations", value="\"'.。,,!!??::”)]}、",
+            tb_append_punctuations = gr.Textbox(label="Append Punctuations", value=whisper_params["append_punctuations"],
                                                 info="If 'Word Timestamps' is True, merge these punctuation symbols with the previous word.")
-            nb_max_new_tokens = gr.Number(label="Max New Tokens", value=lambda: None, precision=0,
+            nb_max_new_tokens = gr.Number(label="Max New Tokens", value=lambda: whisper_params["max_new_tokens"],
+                                          precision=0,
                                           info="Maximum number of new tokens to generate per-chunk. If not set, the maximum will be set by the default max_length.")
-            nb_chunk_length = gr.Number(label="Chunk Length", value=lambda: None, precision=0,
+            nb_chunk_length = gr.Number(label="Chunk Length", value=lambda: whisper_params["chunk_length"],
+                                        precision=0,
                                         info="The length of audio segments. If it is not None, it will overwrite the default chunk_length of the FeatureExtractor.")
            nb_hallucination_silence_threshold = gr.Number(label="Hallucination Silence Threshold (sec)",
-                                                           value=lambda: None,
+                                                           value=lambda: whisper_params["hallucination_silence_threshold"],
                                                            info="When 'Word Timestamps' is True, skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected.")
-            tb_hotwords = gr.Textbox(label="Hotwords", value=None,
+            tb_hotwords = gr.Textbox(label="Hotwords", value=lambda: whisper_params["hotwords"],
                                      info="Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.")
-            nb_language_detection_threshold = gr.Number(label="Language Detection Threshold", value=None,
+            nb_language_detection_threshold = gr.Number(label="Language Detection Threshold", value=lambda: whisper_params["language_detection_threshold"],
                                                         info="If the maximum probability of the language tokens is higher than this value, the language is detected.")
-            nb_language_detection_segments = gr.Number(label="Language Detection Segments", value=1, precision=0,
+            nb_language_detection_segments = gr.Number(label="Language Detection Segments", value=lambda: whisper_params["language_detection_segments"],
+                                                       precision=0,
                                                        info="Number of segments to consider for the language detection.")
        with gr.Group(visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
-            nb_chunk_length_s = gr.Number(label="Chunk Lengths (sec)", value=30, precision=0)
-            nb_batch_size = gr.Number(label="Batch Size", value=24, precision=0)
+            nb_chunk_length_s = gr.Number(label="Chunk Lengths (sec)", value=whisper_params["chunk_length_s"],
+                                          precision=0)
+            nb_batch_size = gr.Number(label="Batch Size", value=whisper_params["batch_size"], precision=0)
 
        with gr.Accordion("VAD", open=False):
-            cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
-            sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5,
+            cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=vad_params["vad_filter"],
+                                        interactive=True)
+            sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=vad_params["threshold"],
                                      info="Lower it to be more sensitive to small sounds.")
-            nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250,
+            nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=vad_params["min_speech_duration_ms"],
                                                   info="Final speech chunks shorter than this time are thrown out")
-            nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)", value=9999,
+            nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)", value=vad_params["max_speech_duration_s"],
                                                  info="Maximum duration of speech chunks in \"seconds\". Chunks longer"
                                                       " than this time will be split at the timestamp of the last silence that"
                                                       " lasts more than 100ms (if any), to prevent aggressive cutting.")
-            nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=2000,
+            nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=vad_params["min_silence_duration_ms"],
                                                    info="In the end of each speech chunk wait for this time"
                                                         " before separating it")
-            nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=400,
+            nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=vad_params["speech_pad_ms"],
                                          info="Final speech chunks are padded by this time each side")
 
        with gr.Accordion("Diarization", open=False):
-            cb_diarize = gr.Checkbox(label="Enable Diarization")
-            tb_hf_token = gr.Text(label="HuggingFace Token", value="",
+            cb_diarize = gr.Checkbox(label="Enable Diarization", value=diarization_params["is_diarize"])
+            tb_hf_token = gr.Text(label="HuggingFace Token", value=diarization_params["hf_token"],
                                   info="This is only needed the first time you download the model. If you already have models, you don't need to enter. To download the model, you must manually go to \"https://huggingface.co/pyannote/speaker-diarization-3.1\" and agree to their requirement.")
            dd_diarization_device = gr.Dropdown(label="Device",
                                                choices=self.whisper_inf.diarizer.get_available_device(),
@@ -162,6 +180,10 @@ class App:
         )
 
     def launch(self):
+        translation_params = self.default_params["translation"]
+        deepl_params = translation_params["deepl"]
+        nllb_params = translation_params["nllb"]
+
         with self.app:
             with gr.Row():
                 with gr.Column():
@@ -246,19 +268,17 @@ class App:
 
                 with gr.TabItem("DeepL API"):  # sub tab1
                     with gr.Row():
-                        tb_authkey = gr.Textbox(label="Your Auth Key (API KEY)",
-                                                value="")
+                        tb_api_key = gr.Textbox(label="Your Auth Key (API KEY)", value=deepl_params["api_key"])
                    with gr.Row():
-                        dd_deepl_sourcelang = gr.Dropdown(label="Source Language", value="Automatic Detection",
-                                                          choices=list(
+                        dd_source_lang = gr.Dropdown(label="Source Language", value=deepl_params["source_lang"],
+                                                     choices=list(
                                                              self.deepl_api.available_source_langs.keys()))
-                        dd_deepl_targetlang = gr.Dropdown(label="Target Language", value="English",
-                                                          choices=list(
-                                                              self.deepl_api.available_target_langs.keys()))
+                        dd_target_lang = gr.Dropdown(label="Target Language", value=deepl_params["target_lang"],
+                                                     choices=list(self.deepl_api.available_target_langs.keys()))
                    with gr.Row():
-                        cb_deepl_ispro = gr.Checkbox(label="Pro User?", value=False)
+                        cb_is_pro = gr.Checkbox(label="Pro User?", value=deepl_params["is_pro"])
                    with gr.Row():
-                        cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename",
+                        cb_timestamp = gr.Checkbox(value=translation_params["add_timestamp"], label="Add a timestamp to the end of the filename",
                                                    interactive=True)
                    with gr.Row():
                        btn_run = gr.Button("TRANSLATE SUBTITLE FILE", variant="primary")
@@ -268,26 +288,27 @@ class App:
                         btn_openfolder = gr.Button('📂', scale=1)
 
                     btn_run.click(fn=self.deepl_api.translate_deepl,
-                                  inputs=[tb_authkey, file_subs, dd_deepl_sourcelang, dd_deepl_targetlang,
-                                          cb_deepl_ispro, cb_timestamp],
+                                  inputs=[tb_api_key, file_subs, dd_source_lang, dd_target_lang,
+                                          cb_is_pro, cb_timestamp],
                                   outputs=[tb_indicator, files_subtitles])
 
-                    btn_openfolder.click(fn=lambda: self.open_folder(os.path.join("outputs", "translations")),
+                    btn_openfolder.click(fn=lambda: self.open_folder(os.path.join(self.args.output_dir, "translations")),
                                          inputs=None,
                                          outputs=None)
 
                with gr.TabItem("NLLB"):  # sub tab2
                    with gr.Row():
-                        dd_nllb_model = gr.Dropdown(label="Model", value="facebook/nllb-200-1.3B",
+                        dd_model_size = gr.Dropdown(label="Model", value=nllb_params["model_size"],
                                                    choices=self.nllb_inf.available_models)
-                        dd_nllb_sourcelang = gr.Dropdown(label="Source Language",
-                                                         choices=self.nllb_inf.available_source_langs)
-                        dd_nllb_targetlang = gr.Dropdown(label="Target Language",
-                                                         choices=self.nllb_inf.available_target_langs)
+                        dd_source_lang = gr.Dropdown(label="Source Language", value=nllb_params["source_lang"],
+                                                     choices=self.nllb_inf.available_source_langs)
+                        dd_target_lang = gr.Dropdown(label="Target Language", value=nllb_params["target_lang"],
+                                                     choices=self.nllb_inf.available_target_langs)
                    with gr.Row():
-                        nb_max_length = gr.Number(label="Max Length Per Line", value=200, precision=0)
+                        nb_max_length = gr.Number(label="Max Length Per Line", value=nllb_params["max_length"],
+                                                  precision=0)
                    with gr.Row():
-                        cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename",
+                        cb_timestamp = gr.Checkbox(value=translation_params["add_timestamp"], label="Add a timestamp to the end of the filename",
                                                    interactive=True)
                    with gr.Row():
                        btn_run = gr.Button("TRANSLATE SUBTITLE FILE", variant="primary")
@@ -299,11 +320,11 @@ class App:
                         md_vram_table = gr.HTML(NLLB_VRAM_TABLE, elem_id="md_nllb_vram_table")
 
                     btn_run.click(fn=self.nllb_inf.translate_file,
-                                  inputs=[file_subs, dd_nllb_model, dd_nllb_sourcelang, dd_nllb_targetlang,
+                                  inputs=[file_subs, dd_model_size, dd_source_lang, dd_target_lang,
                                           nb_max_length, cb_timestamp],
                                   outputs=[tb_indicator, files_subtitles])
 
-                    btn_openfolder.click(fn=lambda: self.open_folder(os.path.join("outputs", "translations")),
+                    btn_openfolder.click(fn=lambda: self.open_folder(os.path.join(self.args.output_dir, "translations")),
                                          inputs=None,
                                          outputs=None)
 
@@ -351,18 +372,18 @@ parser.add_argument('--theme', type=str, default=None, help='Gradio Blocks theme
 parser.add_argument('--colab', type=bool, default=False, nargs='?', const=True, help='Is colab user or not')
 parser.add_argument('--api_open', type=bool, default=False, nargs='?', const=True, help='Enable api or not in Gradio')
 parser.add_argument('--inbrowser', type=bool, default=True, nargs='?', const=True, help='Whether to automatically start Gradio app or not')
-parser.add_argument('--whisper_model_dir', type=str, default=os.path.join("models", "Whisper"),
+parser.add_argument('--whisper_model_dir', type=str, default=WHISPER_MODELS_DIR,
                     help='Directory path of the whisper model')
-parser.add_argument('--faster_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "faster-whisper"),
+parser.add_argument('--faster_whisper_model_dir', type=str, default=FASTER_WHISPER_MODELS_DIR,
                     help='Directory path of the faster-whisper model')
 parser.add_argument('--insanely_fast_whisper_model_dir', type=str,
-                    default=os.path.join("models", "Whisper", "insanely-fast-whisper"),
+                    default=INSANELY_FAST_WHISPER_MODELS_DIR,
                     help='Directory path of the insanely-fast-whisper model')
-parser.add_argument('--diarization_model_dir', type=str, default=os.path.join("models", "Diarization"),
+parser.add_argument('--diarization_model_dir', type=str, default=DIARIZATION_MODELS_DIR,
                     help='Directory path of the diarization model')
-parser.add_argument('--nllb_model_dir', type=str, default=os.path.join("models", "NLLB"),
+parser.add_argument('--nllb_model_dir', type=str, default=NLLB_MODELS_DIR,
                     help='Directory path of the Facebook NLLB model')
-parser.add_argument('--output_dir', type=str, default=os.path.join("outputs"), help='Directory path of the outputs')
+parser.add_argument('--output_dir', type=str, default=OUTPUT_DIR, help='Directory path of the outputs')
 _args = parser.parse_args()
 
 if __name__ == "__main__":
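
The pattern applied throughout app.py above: read configs/default_parameters.yaml once at startup, then seed each Gradio component's initial value from the loaded dict instead of from hard-coded literals. A minimal sketch of that pattern, assuming only gradio and ruamel.yaml are installed and the YAML has the structure added below; the dropdown choices here are placeholders:

    import gradio as gr
    from ruamel.yaml import YAML

    # Load the persisted defaults once, the way App.__init__ does via load_yaml().
    with open("configs/default_parameters.yaml", encoding="utf-8") as f:
        default_params = YAML(typ="safe").load(f)
    whisper_params = default_params["whisper"]

    with gr.Blocks() as demo:
        # Component values come from the YAML, so whatever the user last ran
        # with becomes the new default on the next launch.
        dd_model = gr.Dropdown(choices=["large-v2", "medium"],  # placeholder choices
                               value=whisper_params["model_size"], label="Model")
        cb_translate = gr.Checkbox(value=whisper_params["is_translate"],
                                   label="Translate to English?")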
configs/default_parameters.yaml ADDED
@@ -0,0 +1,58 @@
+whisper:
+  model_size: "large-v2"
+  lang: "Automatic Detection"
+  is_translate: false
+  beam_size: 5
+  log_prob_threshold: -1
+  no_speech_threshold: 0.6
+  best_of: 5
+  patience: 1
+  condition_on_previous_text: true
+  prompt_reset_on_temperature: 0.5
+  initial_prompt: null
+  temperature: 0
+  compression_ratio_threshold: 2.4
+  chunk_length_s: 30
+  batch_size: 24
+  length_penalty: 1
+  repetition_penalty: 1
+  no_repeat_ngram_size: 0
+  prefix: null
+  suppress_blank: true
+  suppress_tokens: "[-1]"
+  max_initial_timestamp: 1
+  word_timestamps: false
+  prepend_punctuations: "\"'“¿([{-"
+  append_punctuations: "\"'.。,,!!??::”)]}、"
+  max_new_tokens: null
+  chunk_length: null
+  hallucination_silence_threshold: null
+  hotwords: null
+  language_detection_threshold: null
+  language_detection_segments: 1
+  add_timestamp: true
+
+vad:
+  vad_filter: false
+  threshold: 0.5
+  min_speech_duration_ms: 250
+  max_speech_duration_s: 9999
+  min_silence_duration_ms: 2000
+  speech_pad_ms: 400
+
+diarization:
+  is_diarize: false
+  hf_token: ""
+
+translation:
+  deepl:
+    api_key: ""
+    is_pro: false
+    source_lang: "Automatic Detection"
+    target_lang: "English"
+  nllb:
+    model_size: "facebook/nllb-200-1.3B"
+    source_lang: null
+    target_lang: null
+    max_length: 200
+  add_timestamp: true
modules/diarize/diarize_pipeline.py CHANGED
@@ -7,6 +7,7 @@ from pyannote.audio import Pipeline
 from typing import Optional, Union
 import torch
 
+from modules.utils.paths import DIARIZATION_MODELS_DIR
 from modules.diarize.audio_loader import load_audio, SAMPLE_RATE
 
 
@@ -14,7 +15,7 @@ class DiarizationPipeline:
     def __init__(
         self,
         model_name="pyannote/speaker-diarization-3.1",
-        cache_dir: str = os.path.join("models", "Diarization"),
+        cache_dir: str = DIARIZATION_MODELS_DIR,
         use_auth_token=None,
         device: Optional[Union[str, torch.device]] = "cpu",
     ):
modules/diarize/diarizer.py CHANGED
@@ -5,13 +5,14 @@ import numpy as np
 import time
 import logging
 
+from modules.utils.paths import DIARIZATION_MODELS_DIR
 from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers
 from modules.diarize.audio_loader import load_audio
 
 
 class Diarizer:
     def __init__(self,
-                 model_dir: str = os.path.join("models", "Diarization")
+                 model_dir: str = DIARIZATION_MODELS_DIR
                  ):
         self.device = self.get_device()
         self.available_device = self.get_available_device()
modules/translation/deepl_api.py CHANGED
@@ -4,7 +4,9 @@ import os
 from datetime import datetime
 import gradio as gr
 
+from modules.utils.paths import TRANSLATION_OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH
 from modules.utils.subtitle_manager import *
+from modules.utils.files_manager import load_yaml, save_yaml
 
 """
 This is written with reference to the DeepL API documentation.
@@ -83,7 +85,7 @@ DEEPL_AVAILABLE_SOURCE_LANGS = {
 
 class DeepLAPI:
     def __init__(self,
-                 output_dir: str = os.path.join("outputs", "translations")
+                 output_dir: str = TRANSLATION_OUTPUT_DIR
                  ):
         self.api_interval = 1
         self.max_text_batch_size = 50
@@ -124,6 +126,13 @@ class DeepLAPI:
         String to return to gr.Textbox()
         Files to return to gr.Files()
         """
+        self.cache_parameters(
+            api_key=auth_key,
+            is_pro=is_pro,
+            source_lang=source_lang,
+            target_lang=target_lang,
+            add_timestamp=add_timestamp
+        )
 
         files_info = {}
         for fileobj in fileobjs:
@@ -198,4 +207,20 @@ class DeepLAPI:
         }
         response = requests.post(url, headers=headers, data=data).json()
         time.sleep(self.api_interval)
-        return response["translations"]
+        return response["translations"]
+
+    @staticmethod
+    def cache_parameters(api_key: str,
+                         is_pro: bool,
+                         source_lang: str,
+                         target_lang: str,
+                         add_timestamp: bool):
+        cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
+        cached_params["translation"]["deepl"] = {
+            "api_key": api_key,
+            "is_pro": is_pro,
+            "source_lang": source_lang,
+            "target_lang": target_lang
+        }
+        cached_params["translation"]["add_timestamp"] = add_timestamp
+        save_yaml(cached_params, DEFAULT_PARAMETERS_CONFIG_PATH)
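
Together with the YAML-seeded defaults in app.py, this new cache_parameters hook closes the loop: each run writes the submitted values back into the same config file, so the next launch pre-fills the DeepL tab. A hedged usage sketch (the key is a placeholder):

    from modules.translation.deepl_api import DeepLAPI

    # Persist the last-used DeepL settings; translate_deepl() now calls this
    # automatically before translating.
    DeepLAPI.cache_parameters(api_key="your-key-here",  # placeholder
                              is_pro=False,
                              source_lang="Automatic Detection",
                              target_lang="English",
                              add_timestamp=True)
    # configs/default_parameters.yaml now holds these values under
    # translation.deepl, which app.py reads as the tab's initial values.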
modules/translation/nllb_inference.py CHANGED
@@ -2,13 +2,14 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 import gradio as gr
 import os
 
+from modules.utils.paths import TRANSLATION_OUTPUT_DIR, NLLB_MODELS_DIR
 from modules.translation.translation_base import TranslationBase
 
 
 class NLLBInference(TranslationBase):
     def __init__(self,
-                 model_dir: str = os.path.join("models", "NLLB"),
-                 output_dir: str = os.path.join("outputs", "translations")
+                 model_dir: str = NLLB_MODELS_DIR,
+                 output_dir: str = TRANSLATION_OUTPUT_DIR
                  ):
         super().__init__(
             model_dir=model_dir,
modules/translation/translation_base.py CHANGED
@@ -7,12 +7,14 @@ from datetime import datetime
 
 from modules.whisper.whisper_parameter import *
 from modules.utils.subtitle_manager import *
+from modules.utils.files_manager import load_yaml, save_yaml
+from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH, NLLB_MODELS_DIR, TRANSLATION_OUTPUT_DIR
 
 
 class TranslationBase(ABC):
     def __init__(self,
-                 model_dir: str = os.path.join("models", "NLLB"),
-                 output_dir: str = os.path.join("outputs", "translations")
+                 model_dir: str = NLLB_MODELS_DIR,
+                 output_dir: str = TRANSLATION_OUTPUT_DIR
                  ):
         super().__init__()
         self.model = None
@@ -75,6 +77,12 @@ class TranslationBase(ABC):
         Files to return to gr.Files()
         """
         try:
+            self.cache_parameters(model_size=model_size,
+                                  src_lang=src_lang,
+                                  tgt_lang=tgt_lang,
+                                  max_length=max_length,
+                                  add_timestamp=add_timestamp)
+
             self.update_model(model_size=model_size,
                               src_lang=src_lang,
                               tgt_lang=tgt_lang,
@@ -149,3 +157,19 @@ class TranslationBase(ABC):
         for file_path in file_paths:
             if file_path and os.path.exists(file_path):
                 os.remove(file_path)
+
+    @staticmethod
+    def cache_parameters(model_size: str,
+                         src_lang: str,
+                         tgt_lang: str,
+                         max_length: int,
+                         add_timestamp: bool):
+        cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
+        cached_params["translation"]["nllb"] = {
+            "model_size": model_size,
+            "source_lang": src_lang,
+            "target_lang": tgt_lang,
+            "max_length": max_length,
+        }
+        cached_params["translation"]["add_timestamp"] = add_timestamp
+        save_yaml(cached_params, DEFAULT_PARAMETERS_CONFIG_PATH)
{ui → modules/ui}/__init__.py RENAMED
File without changes
{ui → modules/ui}/htmls.py RENAMED
File without changes
modules/utils/files_manager.py CHANGED
@@ -1,8 +1,32 @@
 import os
 import fnmatch
-
+from ruamel.yaml import YAML
 from gradio.utils import NamedString
 
+from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH
+
+
+def load_yaml(path: str = DEFAULT_PARAMETERS_CONFIG_PATH):
+    yaml = YAML(typ="safe")
+    yaml.preserve_quotes = True
+    with open(path, 'r', encoding='utf-8') as file:
+        config = yaml.load(file)
+    return config
+
+
+def save_yaml(data: dict, path: str = DEFAULT_PARAMETERS_CONFIG_PATH):
+    yaml = YAML(typ="safe")
+    yaml.map_indent = 2
+    yaml.sequence_indent = 4
+    yaml.sequence_dash_offset = 2
+    yaml.preserve_quotes = True
+    yaml.default_flow_style = False
+    yaml.sort_base_mapping_type_on_output = False
+
+    with open(path, 'w', encoding='utf-8') as file:
+        yaml.dump(data, file)
+    return path
+
 
 def get_media_files(folder_path, include_sub_directory=False):
     video_extensions = ['*.mp4', '*.mkv', '*.flv', '*.avi', '*.mov', '*.wmv']
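
The two helpers above are a plain load/mutate/dump round trip on ruamel.yaml's safe loader. A self-contained sketch of that behavior against a throwaway file (the filename is hypothetical):

    from ruamel.yaml import YAML

    yaml = YAML(typ="safe")
    path = "demo_parameters.yaml"  # hypothetical scratch file

    # Seed a config, as configs/default_parameters.yaml does in the repo.
    with open(path, "w", encoding="utf-8") as f:
        yaml.dump({"whisper": {"model_size": "large-v2"}}, f)

    # load_yaml() equivalent: parse the file into plain dicts.
    with open(path, "r", encoding="utf-8") as f:
        config = yaml.load(f)

    # Mutate and write back, as the cache_parameters methods do via save_yaml().
    config["whisper"]["model_size"] = "medium"
    with open(path, "w", encoding="utf-8") as f:
        yaml.dump(config, f)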
modules/utils/paths.py ADDED
@@ -0,0 +1,24 @@
+import os
+
+WEBUI_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
+MODELS_DIR = os.path.join(WEBUI_DIR, "models")
+WHISPER_MODELS_DIR = os.path.join(MODELS_DIR, "Whisper")
+FASTER_WHISPER_MODELS_DIR = os.path.join(WHISPER_MODELS_DIR, "faster-whisper")
+INSANELY_FAST_WHISPER_MODELS_DIR = os.path.join(WHISPER_MODELS_DIR, "insanely-fast-whisper")
+NLLB_MODELS_DIR = os.path.join(MODELS_DIR, "NLLB")
+DIARIZATION_MODELS_DIR = os.path.join(MODELS_DIR, "Diarization")
+CONFIGS_DIR = os.path.join(WEBUI_DIR, "configs")
+DEFAULT_PARAMETERS_CONFIG_PATH = os.path.join(CONFIGS_DIR, "default_parameters.yaml")
+OUTPUT_DIR = os.path.join(WEBUI_DIR, "outputs")
+TRANSLATION_OUTPUT_DIR = os.path.join(OUTPUT_DIR, "translations")
+
+for dir_path in [MODELS_DIR,
+                 WHISPER_MODELS_DIR,
+                 FASTER_WHISPER_MODELS_DIR,
+                 INSANELY_FAST_WHISPER_MODELS_DIR,
+                 NLLB_MODELS_DIR,
+                 DIARIZATION_MODELS_DIR,
+                 CONFIGS_DIR,
+                 OUTPUT_DIR,
+                 TRANSLATION_OUTPUT_DIR]:
+    os.makedirs(dir_path, exist_ok=True)
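
Because the makedirs loop runs at import time, any module that imports a constant from modules.utils.paths also guarantees the corresponding directory exists; callers never have to create folders themselves. A tiny check of that side effect:

    import os
    from modules.utils.paths import CONFIGS_DIR, OUTPUT_DIR

    # Importing the module already ran os.makedirs(..., exist_ok=True) for both.
    assert os.path.isdir(CONFIGS_DIR) and os.path.isdir(OUTPUT_DIR)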
modules/whisper/faster_whisper_inference.py CHANGED
@@ -11,15 +11,16 @@ import whisper
 import gradio as gr
 from argparse import Namespace
 
+from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR)
 from modules.whisper.whisper_parameter import *
 from modules.whisper.whisper_base import WhisperBase
 
 
 class FasterWhisperInference(WhisperBase):
     def __init__(self,
-                 model_dir: str = os.path.join("models", "Whisper", "faster-whisper"),
-                 diarization_model_dir: str = os.path.join("models", "Diarization"),
-                 output_dir: str = os.path.join("outputs"),
+                 model_dir: str = FASTER_WHISPER_MODELS_DIR,
+                 diarization_model_dir: str = DIARIZATION_MODELS_DIR,
+                 output_dir: str = OUTPUT_DIR,
                  ):
         super().__init__(
             model_dir=model_dir,
@@ -163,14 +164,12 @@ class FasterWhisperInference(WhisperBase):
         wrong_dirs = [".locks"]
         existing_models = list(set(existing_models) - set(wrong_dirs))
 
-        webui_dir = os.getcwd()
-
         for model_name in existing_models:
             if faster_whisper_prefix in model_name:
                 model_name = model_name[len(faster_whisper_prefix):]
 
             if model_name not in whisper.available_models():
-                model_paths[model_name] = os.path.join(webui_dir, self.model_dir, model_name)
+                model_paths[model_name] = os.path.join(self.model_dir, model_name)
         return model_paths
 
     @staticmethod
modules/whisper/insanely_fast_whisper_inference.py CHANGED
@@ -11,15 +11,16 @@ import whisper
 from rich.progress import Progress, TimeElapsedColumn, BarColumn, TextColumn
 from argparse import Namespace
 
+from modules.utils.paths import (INSANELY_FAST_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR)
 from modules.whisper.whisper_parameter import *
 from modules.whisper.whisper_base import WhisperBase
 
 
 class InsanelyFastWhisperInference(WhisperBase):
     def __init__(self,
-                 model_dir: str = os.path.join("models", "Whisper", "insanely-fast-whisper"),
-                 diarization_model_dir: str = os.path.join("models", "Diarization"),
-                 output_dir: str = os.path.join("outputs"),
+                 model_dir: str = INSANELY_FAST_WHISPER_MODELS_DIR,
+                 diarization_model_dir: str = DIARIZATION_MODELS_DIR,
+                 output_dir: str = OUTPUT_DIR,
                  ):
         super().__init__(
             model_dir=model_dir,
modules/whisper/whisper_Inference.py CHANGED
@@ -7,15 +7,16 @@ import torch
 import os
 from argparse import Namespace
 
+from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR)
 from modules.whisper.whisper_base import WhisperBase
 from modules.whisper.whisper_parameter import *
 
 
 class WhisperInference(WhisperBase):
     def __init__(self,
-                 model_dir: str = os.path.join("models", "Whisper"),
-                 diarization_model_dir: str = os.path.join("models", "Diarization"),
-                 output_dir: str = os.path.join("outputs"),
+                 model_dir: str = WHISPER_MODELS_DIR,
+                 diarization_model_dir: str = DIARIZATION_MODELS_DIR,
+                 output_dir: str = OUTPUT_DIR,
                  ):
         super().__init__(
             model_dir=model_dir,
modules/whisper/whisper_base.py CHANGED
@@ -9,9 +9,10 @@ from datetime import datetime
 from faster_whisper.vad import VadOptions
 from dataclasses import astuple
 
+from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH)
 from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
 from modules.utils.youtube_manager import get_ytdata, get_ytaudio
-from modules.utils.files_manager import get_media_files, format_gradio_files
+from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
 from modules.whisper.whisper_parameter import *
 from modules.diarize.diarizer import Diarizer
 from modules.vad.silero_vad import SileroVAD
@@ -19,9 +20,9 @@ from modules.vad.silero_vad import SileroVAD
 
 class WhisperBase(ABC):
     def __init__(self,
-                 model_dir: str = os.path.join("models", "Whisper"),
-                 diarization_model_dir: str = os.path.join("models", "Diarization"),
-                 output_dir: str = os.path.join("outputs"),
+                 model_dir: str = WHISPER_MODELS_DIR,
+                 diarization_model_dir: str = DIARIZATION_MODELS_DIR,
+                 output_dir: str = OUTPUT_DIR,
                  ):
         self.model_dir = model_dir
         self.output_dir = output_dir
@@ -61,7 +62,8 @@ class WhisperBase(ABC):
 
     def run(self,
             audio: Union[str, BinaryIO, np.ndarray],
-            progress: gr.Progress,
+            progress: gr.Progress = gr.Progress(),
+            add_timestamp: bool = True,
             *whisper_params,
             ) -> Tuple[List[dict], float]:
         """
@@ -75,6 +77,8 @@ class WhisperBase(ABC):
             Audio input. This can be file path or binary type.
         progress: gr.Progress
             Indicator to show progress directly in gradio.
+        add_timestamp: bool
+            Whether to add a timestamp at the end of the filename.
         *whisper_params: tuple
             Parameters related with whisper. This will be dealt with "WhisperParameters" data class
 
@@ -87,6 +91,11 @@ class WhisperBase(ABC):
         """
         params = WhisperParameters.as_value(*whisper_params)
 
+        self.cache_parameters(
+            whisper_params=params,
+            add_timestamp=add_timestamp
+        )
+
         if params.lang == "Automatic Detection":
             params.lang = None
         else:
@@ -178,6 +187,7 @@ class WhisperBase(ABC):
             transcribed_segments, time_for_task = self.run(
                 file.name,
                 progress,
+                add_timestamp,
                 *whisper_params,
             )
 
@@ -301,6 +311,7 @@ class WhisperBase(ABC):
             transcribed_segments, time_for_task = self.run(
                 audio,
                 progress,
+                add_timestamp,
                 *whisper_params,
             )
 
@@ -434,3 +445,15 @@ class WhisperBase(ABC):
         for file_path in file_paths:
             if file_path and os.path.exists(file_path):
                 os.remove(file_path)
+
+    @staticmethod
+    def cache_parameters(
+        whisper_params: WhisperValues,
+        add_timestamp: bool
+    ):
+        cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
+        cached_whisper_param = whisper_params.to_yaml()
+        cached_yaml = {**cached_params, **cached_whisper_param}
+        cached_yaml["whisper"]["add_timestamp"] = add_timestamp
+
+        save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
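
Note that `{**cached_params, **cached_whisper_param}` in cache_parameters above is a shallow merge: the whisper, vad, and diarization sections coming from to_yaml() replace the cached ones wholesale, while top-level sections only present in the cache, such as translation, survive untouched. A small sketch of that semantics:

    # Shallow dict merge as used by WhisperBase.cache_parameters().
    cached = {"whisper": {"beam_size": 5},
              "translation": {"deepl": {"is_pro": False}}}
    fresh = {"whisper": {"beam_size": 3}, "vad": {"vad_filter": True}}

    merged = {**cached, **fresh}
    assert merged["whisper"] == {"beam_size": 3}              # replaced wholesale
    assert merged["translation"]["deepl"]["is_pro"] is False  # kept from cache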
modules/whisper/whisper_factory.py CHANGED
@@ -1,6 +1,8 @@
 from typing import Optional
 import os
 
+from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR,
+                                 INSANELY_FAST_WHISPER_MODELS_DIR, WHISPER_MODELS_DIR)
 from modules.whisper.faster_whisper_inference import FasterWhisperInference
 from modules.whisper.whisper_Inference import WhisperInference
 from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
@@ -11,11 +13,11 @@ class WhisperFactory:
     @staticmethod
     def create_whisper_inference(
             whisper_type: str,
-            whisper_model_dir: str = os.path.join("models", "Whisper"),
-            faster_whisper_model_dir: str = os.path.join("models", "Whisper", "faster-whisper"),
-            insanely_fast_whisper_model_dir: str = os.path.join("models", "Whisper", "insanely-fast-whisper"),
-            diarization_model_dir: str = os.path.join("models", "Diarization"),
-            output_dir: str = os.path.join("outputs"),
+            whisper_model_dir: str = WHISPER_MODELS_DIR,
+            faster_whisper_model_dir: str = FASTER_WHISPER_MODELS_DIR,
+            insanely_fast_whisper_model_dir: str = INSANELY_FAST_WHISPER_MODELS_DIR,
+            diarization_model_dir: str = DIARIZATION_MODELS_DIR,
+            output_dir: str = OUTPUT_DIR,
     ) -> "WhisperBase":
         """
         Create a whisper inference class based on the provided whisper_type.
modules/whisper/whisper_parameter.py CHANGED
@@ -1,6 +1,7 @@
 from dataclasses import dataclass, fields
 import gradio as gr
-from typing import Optional
+from typing import Optional, Dict
+import yaml
 
 
 @dataclass
@@ -274,4 +275,54 @@ class WhisperValues:
     language_detection_segments: int
     """
     A data class to use Whisper parameters.
-    """
+    """
+
+    def to_yaml(self) -> Dict:
+        data = {
+            "whisper": {
+                "model_size": self.model_size,
+                "lang": "Automatic Detection" if self.lang is None else self.lang,
+                "is_translate": self.is_translate,
+                "beam_size": self.beam_size,
+                "log_prob_threshold": self.log_prob_threshold,
+                "no_speech_threshold": self.no_speech_threshold,
+                "best_of": self.best_of,
+                "patience": self.patience,
+                "condition_on_previous_text": self.condition_on_previous_text,
+                "prompt_reset_on_temperature": self.prompt_reset_on_temperature,
+                "initial_prompt": None if not self.initial_prompt else self.initial_prompt,
+                "temperature": self.temperature,
+                "compression_ratio_threshold": self.compression_ratio_threshold,
+                "chunk_length_s": None if self.chunk_length_s is None else self.chunk_length_s,
+                "batch_size": self.batch_size,
+                "length_penalty": self.length_penalty,
+                "repetition_penalty": self.repetition_penalty,
+                "no_repeat_ngram_size": self.no_repeat_ngram_size,
+                "prefix": None if not self.prefix else self.prefix,
+                "suppress_blank": self.suppress_blank,
+                "suppress_tokens": self.suppress_tokens,
+                "max_initial_timestamp": self.max_initial_timestamp,
+                "word_timestamps": self.word_timestamps,
+                "prepend_punctuations": self.prepend_punctuations,
+                "append_punctuations": self.append_punctuations,
+                "max_new_tokens": self.max_new_tokens,
+                "chunk_length": self.chunk_length,
+                "hallucination_silence_threshold": self.hallucination_silence_threshold,
+                "hotwords": None if not self.hotwords else self.hotwords,
+                "language_detection_threshold": self.language_detection_threshold,
+                "language_detection_segments": self.language_detection_segments,
+            },
+            "vad": {
+                "vad_filter": self.vad_filter,
+                "threshold": self.threshold,
+                "min_speech_duration_ms": self.min_speech_duration_ms,
+                "max_speech_duration_s": self.max_speech_duration_s,
+                "min_silence_duration_ms": self.min_silence_duration_ms,
+                "speech_pad_ms": self.speech_pad_ms,
+            },
+            "diarization": {
+                "is_diarize": self.is_diarize,
+                "hf_token": self.hf_token
+            }
+        }
+        return data
requirements.txt CHANGED
@@ -11,4 +11,5 @@ faster-whisper==1.0.3
 transformers==4.42.3
 gradio==4.29.0
 pytubefix
+ruamel.yaml==0.18.6
 pyannote.audio==3.3.1