jhj0517 committed on
Commit d843d51 · unverified · 2 Parent(s): d868316 091209e

Merge pull request #173 from jhj0517/fix/refactor-scalability

app.py CHANGED
@@ -1,7 +1,6 @@
1
  import gradio as gr
2
  import os
3
  import argparse
4
- import webbrowser
5
 
6
  from modules.whisper_Inference import WhisperInference
7
  from modules.faster_whisper_inference import FasterWhisperInference
@@ -16,17 +15,26 @@ class App:
16
  def __init__(self, args):
17
  self.args = args
18
  self.app = gr.Blocks(css=CSS, theme=self.args.theme)
19
- self.whisper_inf = WhisperInference() if self.args.disable_faster_whisper else FasterWhisperInference()
20
- if isinstance(self.whisper_inf, FasterWhisperInference):
21
- self.whisper_inf.model_dir = args.faster_whisper_model_dir
22
- print("Use Faster Whisper implementation")
23
- else:
24
- self.whisper_inf.model_dir = args.whisper_model_dir
25
- print("Use Open AI Whisper implementation")
26
  print(f"Device \"{self.whisper_inf.device}\" is detected")
27
  self.nllb_inf = NLLBInference()
28
  self.deepl_api = DeepLAPI()
29
30
  @staticmethod
31
  def open_folder(folder_path: str):
32
  if os.path.exists(folder_path):
@@ -61,7 +69,7 @@ class App:
61
  cb_translate = gr.Checkbox(value=False, label="Translate to English?", interactive=True)
62
  with gr.Row():
63
  cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename", interactive=True)
64
- with gr.Accordion("VAD Options", open=False, visible=not self.args.disable_faster_whisper):
65
  cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
66
  sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5)
67
  nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
@@ -135,7 +143,7 @@ class App:
135
  with gr.Row():
136
  cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename",
137
  interactive=True)
138
- with gr.Accordion("VAD Options", open=False, visible=not self.args.disable_faster_whisper):
139
  cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
140
  sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5)
141
  nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
@@ -201,7 +209,7 @@ class App:
201
  dd_file_format = gr.Dropdown(["SRT", "WebVTT", "txt"], value="SRT", label="File Format")
202
  with gr.Row():
203
  cb_translate = gr.Checkbox(value=False, label="Translate to English?", interactive=True)
204
- with gr.Accordion("VAD Options", open=False, visible=not self.args.disable_faster_whisper):
205
  cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
206
  sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5)
207
  nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
@@ -289,7 +297,7 @@ class App:
289
 
290
  with gr.TabItem("NLLB"): # sub tab2
291
  with gr.Row():
292
- dd_nllb_model = gr.Dropdown(label="Model", value=self.nllb_inf.default_model_size,
293
  choices=self.nllb_inf.available_models)
294
  dd_nllb_sourcelang = gr.Dropdown(label="Source Language",
295
  choices=self.nllb_inf.available_source_langs)
@@ -332,7 +340,7 @@ class App:
332
 
333
  # Create the parser for command-line arguments
334
  parser = argparse.ArgumentParser()
335
- parser.add_argument('--disable_faster_whisper', type=bool, default=False, nargs='?', const=True, help='Disable the faster_whisper implementation. faster_whipser is implemented by https://github.com/guillaumekln/faster-whisper')
336
  parser.add_argument('--share', type=bool, default=False, nargs='?', const=True, help='Gradio share value')
337
  parser.add_argument('--server_name', type=str, default=None, help='Gradio server host')
338
  parser.add_argument('--server_port', type=int, default=None, help='Gradio server port')
 
1
  import gradio as gr
2
  import os
3
  import argparse
 
4
 
5
  from modules.whisper_Inference import WhisperInference
6
  from modules.faster_whisper_inference import FasterWhisperInference
 
15
  def __init__(self, args):
16
  self.args = args
17
  self.app = gr.Blocks(css=CSS, theme=self.args.theme)
18
+ self.whisper_inf = self.init_whisper()
19
+ print(f"Use \"{self.args.whisper_type}\" implementation")
20
  print(f"Device \"{self.whisper_inf.device}\" is detected")
21
  self.nllb_inf = NLLBInference()
22
  self.deepl_api = DeepLAPI()
23
 
24
+ def init_whisper(self):
25
+ whisper_type = self.args.whisper_type.lower().strip()
26
+
27
+ if whisper_type in ["faster_whisper", "faster-whisper"]:
28
+ whisper_inf = FasterWhisperInference()
29
+ whisper_inf.model_dir = self.args.faster_whisper_model_dir
30
+ if whisper_type in ["whisper"]:
31
+ whisper_inf = WhisperInference()
32
+ whisper_inf.model_dir = self.args.whisper_model_dir
33
+ else:
34
+ whisper_inf = FasterWhisperInference()
35
+ whisper_inf.model_dir = self.args.faster_whisper_model_dir
36
+ return whisper_inf
37
+
38
  @staticmethod
39
  def open_folder(folder_path: str):
40
  if os.path.exists(folder_path):
 
69
  cb_translate = gr.Checkbox(value=False, label="Translate to English?", interactive=True)
70
  with gr.Row():
71
  cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename", interactive=True)
72
+ with gr.Accordion("VAD Options", open=False, visible=isinstance(self.whisper_inf, FasterWhisperInference)):
73
  cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
74
  sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5)
75
  nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
 
143
  with gr.Row():
144
  cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename",
145
  interactive=True)
146
+ with gr.Accordion("VAD Options", open=False, visible=isinstance(self.whisper_inf, FasterWhisperInference)):
147
  cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
148
  sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5)
149
  nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
 
209
  dd_file_format = gr.Dropdown(["SRT", "WebVTT", "txt"], value="SRT", label="File Format")
210
  with gr.Row():
211
  cb_translate = gr.Checkbox(value=False, label="Translate to English?", interactive=True)
212
+ with gr.Accordion("VAD Options", open=False, visible=isinstance(self.whisper_inf, FasterWhisperInference)):
213
  cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
214
  sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5)
215
  nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
 
297
 
298
  with gr.TabItem("NLLB"): # sub tab2
299
  with gr.Row():
300
+ dd_nllb_model = gr.Dropdown(label="Model", value="facebook/nllb-200-1.3B",
301
  choices=self.nllb_inf.available_models)
302
  dd_nllb_sourcelang = gr.Dropdown(label="Source Language",
303
  choices=self.nllb_inf.available_source_langs)
 
340
 
341
  # Create the parser for command-line arguments
342
  parser = argparse.ArgumentParser()
343
+ parser.add_argument('--whisper_type', type=str, default="faster-whisper", help='A type of the whisper implementation between: ["whisper", "faster-whisper"]')
344
  parser.add_argument('--share', type=bool, default=False, nargs='?', const=True, help='Gradio share value')
345
  parser.add_argument('--server_name', type=str, default=None, help='Gradio server host')
346
  parser.add_argument('--server_port', type=int, default=None, help='Gradio server port')
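For reference, here is a minimal, self-contained sketch of how the new `--whisper_type` flag selects a backend. This is not the project's code: the two stand-in classes below are hypothetical placeholders for `WhisperInference` and `FasterWhisperInference`, and only the selection logic mirrors the diff above.

```python
# Illustrative only: backend selection from --whisper_type, using hypothetical
# stand-in classes instead of the real WhisperInference / FasterWhisperInference.
import argparse


class FasterWhisperInference:          # stand-in, not the project class
    def __init__(self):
        self.model_dir = None


class WhisperInference:                # stand-in, not the project class
    def __init__(self):
        self.model_dir = None


def init_whisper(whisper_type: str, whisper_model_dir: str, faster_whisper_model_dir: str):
    whisper_type = whisper_type.lower().strip()
    if whisper_type == "whisper":
        inf = WhisperInference()
        inf.model_dir = whisper_model_dir
    else:
        # "faster_whisper", "faster-whisper", or anything unrecognized
        # falls back to the faster-whisper backend, as in the diff above.
        inf = FasterWhisperInference()
        inf.model_dir = faster_whisper_model_dir
    return inf


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--whisper_type", type=str, default="faster-whisper")
    args = parser.parse_args()
    backend = init_whisper(args.whisper_type, "models/Whisper", "models/Whisper/faster-whisper")
    print(f'Use "{args.whisper_type}" implementation: {type(backend).__name__}')
```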
modules/base_interface.py DELETED
@@ -1,23 +0,0 @@
1
- import os
2
- import torch
3
- from typing import List
4
-
5
-
6
- class BaseInterface:
7
- def __init__(self):
8
- pass
9
-
10
- @staticmethod
11
- def release_cuda_memory():
12
- if torch.cuda.is_available():
13
- torch.cuda.empty_cache()
14
- torch.cuda.reset_max_memory_allocated()
15
-
16
- @staticmethod
17
- def remove_input_files(file_paths: List[str]):
18
- if not file_paths:
19
- return
20
-
21
- for file_path in file_paths:
22
- if file_path and os.path.exists(file_path):
23
- os.remove(file_path)
modules/faster_whisper_inference.py CHANGED
@@ -2,233 +2,29 @@ import os
2
  import time
3
  import numpy as np
4
  from typing import BinaryIO, Union, Tuple, List
5
- from datetime import datetime
6
 
7
  import faster_whisper
8
  from faster_whisper.vad import VadOptions
9
  import ctranslate2
10
  import whisper
11
- import torch
12
  import gradio as gr
13
 
14
- from .base_interface import BaseInterface
15
- from modules.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
16
- from modules.youtube_manager import get_ytdata, get_ytaudio
17
  from modules.whisper_parameter import *
 
18
 
19
  # Temporal fix of the issue : https://github.com/jhj0517/Whisper-WebUI/issues/144
20
  os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
21
 
22
 
23
- class FasterWhisperInference(BaseInterface):
24
  def __init__(self):
25
- super().__init__()
26
- self.model_dir = os.path.join("models", "Whisper", "faster-whisper")
27
- os.makedirs(self.model_dir, exist_ok=True)
28
- self.current_model_size = None
29
- self.model = None
30
  self.model_paths = self.get_model_paths()
31
  self.available_models = self.model_paths.keys()
32
- self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
33
- self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
34
- if torch.cuda.is_available():
35
- self.device = "cuda"
36
- elif torch.backends.mps.is_available():
37
- self.device = "mps"
38
- else:
39
- self.device = "cpu"
40
  self.available_compute_types = ctranslate2.get_supported_compute_types(
41
  "cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")
42
- self.current_compute_type = "float16" if self.device == "cuda" else "float32"
43
-
44
- def transcribe_file(self,
45
- files: list,
46
- file_format: str,
47
- add_timestamp: bool,
48
- progress=gr.Progress(),
49
- *whisper_params,
50
- ) -> list:
51
- """
52
- Write subtitle file from Files
53
-
54
- Parameters
55
- ----------
56
- files: list
57
- List of files to transcribe from gr.Files()
58
- file_format: str
59
- Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
60
- add_timestamp: bool
61
- Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
62
- progress: gr.Progress
63
- Indicator to show progress directly in gradio.
64
- *whisper_params: tuple
65
- Gradio components related to Whisper. see whisper_data_class.py for details.
66
-
67
- Returns
68
- ----------
69
- result_str:
70
- Result of transcription to return to gr.Textbox()
71
- result_file_path:
72
- Output file path to return to gr.Files()
73
- """
74
- try:
75
- files_info = {}
76
- for file in files:
77
- transcribed_segments, time_for_task = self.transcribe(
78
- file.name,
79
- progress,
80
- *whisper_params,
81
- )
82
-
83
- file_name, file_ext = os.path.splitext(os.path.basename(file.name))
84
- file_name = safe_filename(file_name)
85
- subtitle, file_path = self.generate_and_write_file(
86
- file_name=file_name,
87
- transcribed_segments=transcribed_segments,
88
- add_timestamp=add_timestamp,
89
- file_format=file_format
90
- )
91
- files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path}
92
-
93
- total_result = ''
94
- total_time = 0
95
- for file_name, info in files_info.items():
96
- total_result += '------------------------------------\n'
97
- total_result += f'{file_name}\n\n'
98
- total_result += f'{info["subtitle"]}'
99
- total_time += info["time_for_task"]
100
-
101
- result_str = f"Done in {self.format_time(total_time)}! Subtitle is in the outputs folder.\n\n{total_result}"
102
- result_file_path = [info['path'] for info in files_info.values()]
103
-
104
- return [result_str, result_file_path]
105
-
106
- except Exception as e:
107
- print(f"Error transcribing file: {e}")
108
- finally:
109
- self.release_cuda_memory()
110
- if not files:
111
- self.remove_input_files([file.name for file in files])
112
-
113
- def transcribe_youtube(self,
114
- youtube_link: str,
115
- file_format: str,
116
- add_timestamp: bool,
117
- progress=gr.Progress(),
118
- *whisper_params,
119
- ) -> list:
120
- """
121
- Write subtitle file from Youtube
122
-
123
- Parameters
124
- ----------
125
- youtube_link: str
126
- URL of the Youtube video to transcribe from gr.Textbox()
127
- file_format: str
128
- Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
129
- add_timestamp: bool
130
- Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
131
- progress: gr.Progress
132
- Indicator to show progress directly in gradio.
133
- *whisper_params: tuple
134
- Gradio components related to Whisper. see whisper_data_class.py for details.
135
-
136
- Returns
137
- ----------
138
- result_str:
139
- Result of transcription to return to gr.Textbox()
140
- result_file_path:
141
- Output file path to return to gr.Files()
142
- """
143
- try:
144
- progress(0, desc="Loading Audio from Youtube..")
145
- yt = get_ytdata(youtube_link)
146
- audio = get_ytaudio(yt)
147
-
148
- transcribed_segments, time_for_task = self.transcribe(
149
- audio,
150
- progress,
151
- *whisper_params,
152
- )
153
-
154
- progress(1, desc="Completed!")
155
-
156
- file_name = safe_filename(yt.title)
157
- subtitle, result_file_path = self.generate_and_write_file(
158
- file_name=file_name,
159
- transcribed_segments=transcribed_segments,
160
- add_timestamp=add_timestamp,
161
- file_format=file_format
162
- )
163
- result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
164
-
165
- return [result_str, result_file_path]
166
-
167
- except Exception as e:
168
- print(f"Error transcribing file: {e}")
169
- finally:
170
- try:
171
- if 'yt' not in locals():
172
- yt = get_ytdata(youtube_link)
173
- file_path = get_ytaudio(yt)
174
- else:
175
- file_path = get_ytaudio(yt)
176
-
177
- self.release_cuda_memory()
178
- self.remove_input_files([file_path])
179
- except Exception as cleanup_error:
180
- pass
181
-
182
- def transcribe_mic(self,
183
- mic_audio: str,
184
- file_format: str,
185
- progress=gr.Progress(),
186
- *whisper_params,
187
- ) -> list:
188
- """
189
- Write subtitle file from microphone
190
-
191
- Parameters
192
- ----------
193
- mic_audio: str
194
- Audio file path from gr.Microphone()
195
- file_format: str
196
- Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
197
- progress: gr.Progress
198
- Indicator to show progress directly in gradio.
199
- *whisper_params: tuple
200
- Gradio components related to Whisper. see whisper_data_class.py for details.
201
-
202
- Returns
203
- ----------
204
- result_str:
205
- Result of transcription to return to gr.Textbox()
206
- result_file_path:
207
- Output file path to return to gr.Files()
208
- """
209
- try:
210
- progress(0, desc="Loading Audio..")
211
- transcribed_segments, time_for_task = self.transcribe(
212
- mic_audio,
213
- progress,
214
- *whisper_params,
215
- )
216
- progress(1, desc="Completed!")
217
-
218
- subtitle, result_file_path = self.generate_and_write_file(
219
- file_name="Mic",
220
- transcribed_segments=transcribed_segments,
221
- add_timestamp=True,
222
- file_format=file_format
223
- )
224
-
225
- result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
226
- return [result_str, result_file_path]
227
- except Exception as e:
228
- print(f"Error transcribing file: {e}")
229
- finally:
230
- self.release_cuda_memory()
231
- self.remove_input_files([mic_audio])
232
 
233
  def transcribe(self,
234
  audio: Union[str, BinaryIO, np.ndarray],
@@ -356,79 +152,3 @@ class FasterWhisperInference(BaseInterface):
356
  if model_name not in whisper.available_models():
357
  model_paths[model_name] = os.path.join(webui_dir, self.model_dir, model_name)
358
  return model_paths
359
-
360
- @staticmethod
361
- def generate_and_write_file(file_name: str,
362
- transcribed_segments: list,
363
- add_timestamp: bool,
364
- file_format: str,
365
- ) -> str:
366
- """
367
- Writes subtitle file
368
-
369
- Parameters
370
- ----------
371
- file_name: str
372
- Output file name
373
- transcribed_segments: list
374
- Text segments transcribed from audio
375
- add_timestamp: bool
376
- Determines whether to add a timestamp to the end of the filename.
377
- file_format: str
378
- File format to write. Supported formats: [SRT, WebVTT, txt]
379
-
380
- Returns
381
- ----------
382
- content: str
383
- Result of the transcription
384
- output_path: str
385
- output file path
386
- """
387
- timestamp = datetime.now().strftime("%m%d%H%M%S")
388
- if add_timestamp:
389
- output_path = os.path.join("outputs", f"{file_name}-{timestamp}")
390
- else:
391
- output_path = os.path.join("outputs", f"{file_name}")
392
-
393
- if file_format == "SRT":
394
- content = get_srt(transcribed_segments)
395
- output_path += '.srt'
396
- write_file(content, output_path)
397
-
398
- elif file_format == "WebVTT":
399
- content = get_vtt(transcribed_segments)
400
- output_path += '.vtt'
401
- write_file(content, output_path)
402
-
403
- elif file_format == "txt":
404
- content = get_txt(transcribed_segments)
405
- output_path += '.txt'
406
- write_file(content, output_path)
407
- return content, output_path
408
-
409
- @staticmethod
410
- def format_time(elapsed_time: float) -> str:
411
- """
412
- Get {hours} {minutes} {seconds} time format string
413
-
414
- Parameters
415
- ----------
416
- elapsed_time: str
417
- Elapsed time for transcription
418
-
419
- Returns
420
- ----------
421
- Time format string
422
- """
423
- hours, rem = divmod(elapsed_time, 3600)
424
- minutes, seconds = divmod(rem, 60)
425
-
426
- time_str = ""
427
- if hours:
428
- time_str += f"{hours} hours "
429
- if minutes:
430
- time_str += f"{minutes} minutes "
431
- seconds = round(seconds)
432
- time_str += f"{seconds} seconds"
433
-
434
- return time_str.strip()
 
2
  import time
3
  import numpy as np
4
  from typing import BinaryIO, Union, Tuple, List
 
5
 
6
  import faster_whisper
7
  from faster_whisper.vad import VadOptions
8
  import ctranslate2
9
  import whisper
 
10
  import gradio as gr
11
 
 
 
 
12
  from modules.whisper_parameter import *
13
+ from modules.whisper_base import WhisperBase
14
 
15
  # Temporal fix of the issue : https://github.com/jhj0517/Whisper-WebUI/issues/144
16
  os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
17
 
18
 
19
+ class FasterWhisperInference(WhisperBase):
20
  def __init__(self):
21
+ super().__init__(
22
+ model_dir=os.path.join("models", "Whisper", "faster-whisper")
23
+ )
 
 
24
  self.model_paths = self.get_model_paths()
25
  self.available_models = self.model_paths.keys()
26
  self.available_compute_types = ctranslate2.get_supported_compute_types(
27
  "cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")
 
28
 
29
  def transcribe(self,
30
  audio: Union[str, BinaryIO, np.ndarray],
 
152
  if model_name not in whisper.available_models():
153
  model_paths[model_name] = os.path.join(webui_dir, self.model_dir, model_name)
154
  return model_paths
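As a usage sketch, the Silero VAD controls surfaced in the "VAD Options" accordions map onto faster-whisper roughly as below. This assumes `faster-whisper` is installed; `"sample.wav"` is a hypothetical file name used only for illustration.

```python
# Illustrative only: mapping the accordion's VAD controls onto faster-whisper.
from faster_whisper import WhisperModel
from faster_whisper.vad import VadOptions

model = WhisperModel("tiny", device="cpu", compute_type="float32")
vad_options = VadOptions(threshold=0.5, min_speech_duration_ms=250)

segments, info = model.transcribe(
    "sample.wav",
    vad_filter=True,            # corresponds to the "Enable Silero VAD Filter" checkbox
    vad_parameters=vad_options,
)
for segment in segments:
    print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text}")
```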
modules/nllb_inference.py CHANGED
@@ -1,141 +1,49 @@
1
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
2
  import gradio as gr
3
- import torch
4
  import os
5
- from datetime import datetime
6
 
7
- from .base_interface import BaseInterface
8
- from modules.subtitle_manager import *
9
 
10
- DEFAULT_MODEL_SIZE = "facebook/nllb-200-1.3B"
11
- NLLB_MODELS = ["facebook/nllb-200-3.3B", "facebook/nllb-200-1.3B", "facebook/nllb-200-distilled-600M"]
12
 
13
-
14
- class NLLBInference(BaseInterface):
15
  def __init__(self):
16
- super().__init__()
17
- self.default_model_size = DEFAULT_MODEL_SIZE
18
- self.current_model_size = None
19
- self.model = None
20
  self.tokenizer = None
21
- self.available_models = NLLB_MODELS
22
  self.available_source_langs = list(NLLB_AVAILABLE_LANGS.keys())
23
  self.available_target_langs = list(NLLB_AVAILABLE_LANGS.keys())
24
- self.device = 0 if torch.cuda.is_available() else -1
25
  self.pipeline = None
26
 
27
- def translate_text(self, text):
 
 
28
  result = self.pipeline(text)
29
  return result[0]['translation_text']
30
 
31
- def translate_file(self,
32
- fileobjs: list,
33
- model_size: str,
34
- src_lang: str,
35
- tgt_lang: str,
36
- add_timestamp: bool,
37
- progress=gr.Progress()) -> list:
38
- """
39
- Translate subtitle file from source language to target language
40
-
41
- Parameters
42
- ----------
43
- fileobjs: list
44
- List of files to transcribe from gr.Files()
45
- model_size: str
46
- Whisper model size from gr.Dropdown()
47
- src_lang: str
48
- Source language of the file to translate from gr.Dropdown()
49
- tgt_lang: str
50
- Target language of the file to translate from gr.Dropdown()
51
- add_timestamp: bool
52
- Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
53
- progress: gr.Progress
54
- Indicator to show progress directly in gradio.
55
- I use a forked version of whisper for this. To see more info : https://github.com/jhj0517/jhj0517-whisper/tree/add-progress-callback
56
-
57
- Returns
58
- ----------
59
- A List of
60
- String to return to gr.Textbox()
61
- Files to return to gr.Files()
62
- """
63
- try:
64
- if model_size != self.current_model_size or self.model is None:
65
- print("\nInitializing NLLB Model..\n")
66
- progress(0, desc="Initializing NLLB Model..")
67
- self.current_model_size = model_size
68
- self.model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=model_size,
69
- cache_dir=os.path.join("models", "NLLB"))
70
- self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_size,
71
- cache_dir=os.path.join("models", "NLLB", "tokenizers"))
72
-
73
- src_lang = NLLB_AVAILABLE_LANGS[src_lang]
74
- tgt_lang = NLLB_AVAILABLE_LANGS[tgt_lang]
75
-
76
- self.pipeline = pipeline("translation",
77
- model=self.model,
78
- tokenizer=self.tokenizer,
79
- src_lang=src_lang,
80
- tgt_lang=tgt_lang,
81
- device=self.device)
82
-
83
- files_info = {}
84
- for fileobj in fileobjs:
85
- file_path = fileobj.name
86
- file_name, file_ext = os.path.splitext(os.path.basename(fileobj.name))
87
- if file_ext == ".srt":
88
- parsed_dicts = parse_srt(file_path=file_path)
89
- total_progress = len(parsed_dicts)
90
- for index, dic in enumerate(parsed_dicts):
91
- progress(index / total_progress, desc="Translating..")
92
- translated_text = self.translate_text(dic["sentence"])
93
- dic["sentence"] = translated_text
94
- subtitle = get_serialized_srt(parsed_dicts)
95
-
96
- timestamp = datetime.now().strftime("%m%d%H%M%S")
97
- if add_timestamp:
98
- output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}")
99
- else:
100
- output_path = os.path.join("outputs", "translations", f"{file_name}")
101
- output_path += '.srt'
102
-
103
- write_file(subtitle, output_path)
104
-
105
- elif file_ext == ".vtt":
106
- parsed_dicts = parse_vtt(file_path=file_path)
107
- total_progress = len(parsed_dicts)
108
- for index, dic in enumerate(parsed_dicts):
109
- progress(index / total_progress, desc="Translating..")
110
- translated_text = self.translate_text(dic["sentence"])
111
- dic["sentence"] = translated_text
112
- subtitle = get_serialized_vtt(parsed_dicts)
113
-
114
- timestamp = datetime.now().strftime("%m%d%H%M%S")
115
- if add_timestamp:
116
- output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}")
117
- else:
118
- output_path = os.path.join("outputs", "translations", f"{file_name}")
119
- output_path += '.vtt'
120
-
121
- write_file(subtitle, output_path)
122
-
123
- files_info[file_name] = subtitle
124
-
125
- total_result = ''
126
- for file_name, subtitle in files_info.items():
127
- total_result += '------------------------------------\n'
128
- total_result += f'{file_name}\n\n'
129
- total_result += f'{subtitle}'
130
-
131
- gr_str = f"Done! Subtitle is in the outputs/translation folder.\n\n{total_result}"
132
- return [gr_str, output_path]
133
- except Exception as e:
134
- print(f"Error: {str(e)}")
135
- finally:
136
- self.release_cuda_memory()
137
- self.remove_input_files([fileobj.name for fileobj in fileobjs])
138
-
139
 
140
  NLLB_AVAILABLE_LANGS = {
141
  "Acehnese (Arabic script)": "ace_Arab",
 
1
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
2
  import gradio as gr
 
3
  import os
 
4
 
5
+ from modules.translation_base import TranslationBase
 
6
 
 
 
7
 
8
+ class NLLBInference(TranslationBase):
 
9
  def __init__(self):
10
+ super().__init__(
11
+ model_dir=os.path.join("models", "NLLB")
12
+ )
 
13
  self.tokenizer = None
14
+ self.available_models = ["facebook/nllb-200-3.3B", "facebook/nllb-200-1.3B", "facebook/nllb-200-distilled-600M"]
15
  self.available_source_langs = list(NLLB_AVAILABLE_LANGS.keys())
16
  self.available_target_langs = list(NLLB_AVAILABLE_LANGS.keys())
 
17
  self.pipeline = None
18
 
19
+ def translate(self,
20
+ text: str
21
+ ):
22
  result = self.pipeline(text)
23
  return result[0]['translation_text']
24
 
25
+ def update_model(self,
26
+ model_size: str,
27
+ src_lang: str,
28
+ tgt_lang: str,
29
+ progress: gr.Progress
30
+ ):
31
+ if model_size != self.current_model_size or self.model is None:
32
+ print("\nInitializing NLLB Model..\n")
33
+ progress(0, desc="Initializing NLLB Model..")
34
+ self.current_model_size = model_size
35
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=model_size,
36
+ cache_dir=self.model_dir)
37
+ self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_size,
38
+ cache_dir=os.path.join(self.model_dir, "tokenizers"))
39
+ src_lang = NLLB_AVAILABLE_LANGS[src_lang]
40
+ tgt_lang = NLLB_AVAILABLE_LANGS[tgt_lang]
41
+ self.pipeline = pipeline("translation",
42
+ model=self.model,
43
+ tokenizer=self.tokenizer,
44
+ src_lang=src_lang,
45
+ tgt_lang=tgt_lang,
46
+ device=self.device)
47
 
48
  NLLB_AVAILABLE_LANGS = {
49
  "Acehnese (Arabic script)": "ace_Arab",
modules/translation_base.py ADDED
@@ -0,0 +1,148 @@
1
+ import os
2
+ import torch
3
+ import gradio as gr
4
+ from abc import ABC, abstractmethod
5
+ from typing import List
6
+ from datetime import datetime
7
+
8
+ from modules.whisper_parameter import *
9
+ from modules.subtitle_manager import *
10
+
11
+
12
+ class TranslationBase(ABC):
13
+ def __init__(self,
14
+ model_dir: str):
15
+ super().__init__()
16
+ self.model = None
17
+ self.model_dir = model_dir
18
+ os.makedirs(self.model_dir, exist_ok=True)
19
+ self.current_model_size = None
20
+ self.device = self.get_device()
21
+
22
+ @abstractmethod
23
+ def translate(self,
24
+ text: str
25
+ ):
26
+ pass
27
+
28
+ @abstractmethod
29
+ def update_model(self,
30
+ model_size: str,
31
+ src_lang: str,
32
+ tgt_lang: str,
33
+ progress: gr.Progress
34
+ ):
35
+ pass
36
+
37
+ def translate_file(self,
38
+ fileobjs: list,
39
+ model_size: str,
40
+ src_lang: str,
41
+ tgt_lang: str,
42
+ add_timestamp: bool,
43
+ progress=gr.Progress()) -> list:
44
+ """
45
+ Translate subtitle file from source language to target language
46
+
47
+ Parameters
48
+ ----------
49
+ fileobjs: list
50
+ List of files to transcribe from gr.Files()
51
+ model_size: str
52
+ Whisper model size from gr.Dropdown()
53
+ src_lang: str
54
+ Source language of the file to translate from gr.Dropdown()
55
+ tgt_lang: str
56
+ Target language of the file to translate from gr.Dropdown()
57
+ add_timestamp: bool
58
+ Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
59
+ progress: gr.Progress
60
+ Indicator to show progress directly in gradio.
61
+ I use a forked version of whisper for this. To see more info : https://github.com/jhj0517/jhj0517-whisper/tree/add-progress-callback
62
+
63
+ Returns
64
+ ----------
65
+ A List of
66
+ String to return to gr.Textbox()
67
+ Files to return to gr.Files()
68
+ """
69
+ try:
70
+ self.update_model(model_size=model_size,
71
+ src_lang=src_lang,
72
+ tgt_lang=tgt_lang,
73
+ progress=progress)
74
+
75
+ files_info = {}
76
+ for fileobj in fileobjs:
77
+ file_path = fileobj.name
78
+ file_name, file_ext = os.path.splitext(os.path.basename(fileobj.name))
79
+ if file_ext == ".srt":
80
+ parsed_dicts = parse_srt(file_path=file_path)
81
+ total_progress = len(parsed_dicts)
82
+ for index, dic in enumerate(parsed_dicts):
83
+ progress(index / total_progress, desc="Translating..")
84
+ translated_text = self.translate(dic["sentence"])
85
+ dic["sentence"] = translated_text
86
+ subtitle = get_serialized_srt(parsed_dicts)
87
+
88
+ timestamp = datetime.now().strftime("%m%d%H%M%S")
89
+ if add_timestamp:
90
+ output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}")
91
+ else:
92
+ output_path = os.path.join("outputs", "translations", f"{file_name}.srt")
93
+
94
+ elif file_ext == ".vtt":
95
+ parsed_dicts = parse_vtt(file_path=file_path)
96
+ total_progress = len(parsed_dicts)
97
+ for index, dic in enumerate(parsed_dicts):
98
+ progress(index / total_progress, desc="Translating..")
99
+ translated_text = self.translate(dic["sentence"])
100
+ dic["sentence"] = translated_text
101
+ subtitle = get_serialized_vtt(parsed_dicts)
102
+
103
+ timestamp = datetime.now().strftime("%m%d%H%M%S")
104
+ if add_timestamp:
105
+ output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}")
106
+ else:
107
+ output_path = os.path.join("outputs", "translations", f"{file_name}.vtt")
108
+
109
+ write_file(subtitle, output_path)
110
+ files_info[file_name] = subtitle
111
+
112
+ total_result = ''
113
+ for file_name, subtitle in files_info.items():
114
+ total_result += '------------------------------------\n'
115
+ total_result += f'{file_name}\n\n'
116
+ total_result += f'{subtitle}'
117
+
118
+ gr_str = f"Done! Subtitle is in the outputs/translation folder.\n\n{total_result}"
119
+ return [gr_str, output_path]
120
+ except Exception as e:
121
+ print(f"Error: {str(e)}")
122
+ finally:
123
+ self.release_cuda_memory()
124
+ self.remove_input_files([fileobj.name for fileobj in fileobjs])
125
+
126
+ @staticmethod
127
+ def get_device():
128
+ if torch.cuda.is_available():
129
+ return "cuda"
130
+ elif torch.backends.mps.is_available():
131
+ return "mps"
132
+ else:
133
+ return "cpu"
134
+
135
+ @staticmethod
136
+ def release_cuda_memory():
137
+ if torch.cuda.is_available():
138
+ torch.cuda.empty_cache()
139
+ torch.cuda.reset_max_memory_allocated()
140
+
141
+ @staticmethod
142
+ def remove_input_files(file_paths: List[str]):
143
+ if not file_paths:
144
+ return
145
+
146
+ for file_path in file_paths:
147
+ if file_path and os.path.exists(file_path):
148
+ os.remove(file_path)
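A minimal sketch of what a new translator has to implement on top of `TranslationBase`, assuming this repository is on the Python path. `UpperCaseTranslation` is a hypothetical toy, not part of the codebase; it exists only to show the two abstract hooks that `translate_file()` drives.

```python
# Illustrative only: a hypothetical toy subclass showing TranslationBase's contract.
import gradio as gr
from modules.translation_base import TranslationBase


class UpperCaseTranslation(TranslationBase):
    """Toy 'translator' that upper-cases text; exists only to show the contract."""

    def __init__(self):
        super().__init__(model_dir="models/Dummy")

    def update_model(self, model_size: str, src_lang: str, tgt_lang: str, progress: gr.Progress):
        # A real subclass would load weights here; nothing to load for the toy.
        self.current_model_size = model_size

    def translate(self, text: str):
        return text.upper()
```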
modules/whisper_Inference.py CHANGED
@@ -4,218 +4,17 @@ import time
4
  import os
5
  from typing import BinaryIO, Union, Tuple, List
6
  import numpy as np
7
- from datetime import datetime
8
  import torch
9
 
10
- from .base_interface import BaseInterface
11
- from modules.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
12
- from modules.youtube_manager import get_ytdata, get_ytaudio
13
  from modules.whisper_parameter import *
14
 
15
- DEFAULT_MODEL_SIZE = "large-v3"
16
 
17
-
18
- class WhisperInference(BaseInterface):
19
  def __init__(self):
20
- super().__init__()
21
- self.current_model_size = None
22
- self.model = None
23
- self.available_models = whisper.available_models()
24
- self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
25
- self.translatable_model = ["large", "large-v1", "large-v2", "large-v3"]
26
- if torch.cuda.is_available():
27
- self.device = "cuda"
28
- elif torch.backends.mps.is_available():
29
- self.device = "mps"
30
- else:
31
- self.device = "cpu"
32
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
33
- self.available_compute_types = ["float16", "float32"]
34
- self.current_compute_type = "float16" if self.device == "cuda" else "float32"
35
- self.model_dir = os.path.join("models", "Whisper")
36
-
37
- def transcribe_file(self,
38
- files: list,
39
- file_format: str,
40
- add_timestamp: bool,
41
- progress=gr.Progress(),
42
- *whisper_params
43
- ) -> list:
44
- """
45
- Write subtitle file from Files
46
-
47
- Parameters
48
- ----------
49
- files: list
50
- List of files to transcribe from gr.Files()
51
- file_format: str
52
- Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
53
- add_timestamp: bool
54
- Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
55
- progress: gr.Progress
56
- Indicator to show progress directly in gradio.
57
- *whisper_params: tuple
58
- Gradio components related to Whisper. see whisper_data_class.py for details.
59
-
60
- Returns
61
- ----------
62
- result_str:
63
- Result of transcription to return to gr.Textbox()
64
- result_file_path:
65
- Output file path to return to gr.Files()
66
- """
67
- try:
68
- files_info = {}
69
- for file in files:
70
- progress(0, desc="Loading Audio..")
71
- audio = whisper.load_audio(file.name)
72
-
73
- result, elapsed_time = self.transcribe(audio,
74
- progress,
75
- *whisper_params)
76
- progress(1, desc="Completed!")
77
-
78
- file_name, file_ext = os.path.splitext(os.path.basename(file.name))
79
- file_name = safe_filename(file_name)
80
- subtitle, file_path = self.generate_and_write_file(
81
- file_name=file_name,
82
- transcribed_segments=result,
83
- add_timestamp=add_timestamp,
84
- file_format=file_format
85
- )
86
- files_info[file_name] = {"subtitle": subtitle, "elapsed_time": elapsed_time, "path": file_path}
87
-
88
- total_result = ''
89
- total_time = 0
90
- for file_name, info in files_info.items():
91
- total_result += '------------------------------------\n'
92
- total_result += f'{file_name}\n\n'
93
- total_result += f"{info['subtitle']}"
94
- total_time += info["elapsed_time"]
95
-
96
- result_str = f"Done in {self.format_time(total_time)}! Subtitle is in the outputs folder.\n\n{total_result}"
97
- result_file_path = [info['path'] for info in files_info.values()]
98
-
99
- return [result_str, result_file_path]
100
- except Exception as e:
101
- print(f"Error transcribing file: {str(e)}")
102
- finally:
103
- self.release_cuda_memory()
104
- self.remove_input_files([file.name for file in files])
105
-
106
- def transcribe_youtube(self,
107
- youtube_link: str,
108
- file_format: str,
109
- add_timestamp: bool,
110
- progress=gr.Progress(),
111
- *whisper_params) -> list:
112
- """
113
- Write subtitle file from Youtube
114
-
115
- Parameters
116
- ----------
117
- youtube_link: str
118
- URL of the Youtube video to transcribe from gr.Textbox()
119
- file_format: str
120
- Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
121
- add_timestamp: bool
122
- Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
123
- progress: gr.Progress
124
- Indicator to show progress directly in gradio.
125
- *whisper_params: tuple
126
- Gradio components related to Whisper. see whisper_data_class.py for details.
127
-
128
- Returns
129
- ----------
130
- result_str:
131
- Result of transcription to return to gr.Textbox()
132
- result_file_path:
133
- Output file path to return to gr.Files()
134
- """
135
- try:
136
- progress(0, desc="Loading Audio from Youtube..")
137
- yt = get_ytdata(youtube_link)
138
- audio = whisper.load_audio(get_ytaudio(yt))
139
-
140
- result, elapsed_time = self.transcribe(audio,
141
- progress,
142
- *whisper_params)
143
- progress(1, desc="Completed!")
144
-
145
- file_name = safe_filename(yt.title)
146
- subtitle, result_file_path = self.generate_and_write_file(
147
- file_name=file_name,
148
- transcribed_segments=result,
149
- add_timestamp=add_timestamp,
150
- file_format=file_format
151
- )
152
-
153
- result_str = f"Done in {self.format_time(elapsed_time)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
154
- return [result_str, result_file_path]
155
- except Exception as e:
156
- print(f"Error transcribing youtube video: {str(e)}")
157
- finally:
158
- try:
159
- if 'yt' not in locals():
160
- yt = get_ytdata(youtube_link)
161
- file_path = get_ytaudio(yt)
162
- else:
163
- file_path = get_ytaudio(yt)
164
-
165
- self.release_cuda_memory()
166
- self.remove_input_files([file_path])
167
- except Exception as cleanup_error:
168
- pass
169
-
170
- def transcribe_mic(self,
171
- mic_audio: str,
172
- file_format: str,
173
- progress=gr.Progress(),
174
- *whisper_params) -> list:
175
- """
176
- Write subtitle file from microphone
177
-
178
- Parameters
179
- ----------
180
- mic_audio: str
181
- Audio file path from gr.Microphone()
182
- file_format: str
183
- Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
184
- progress: gr.Progress
185
- Indicator to show progress directly in gradio.
186
- *whisper_params: tuple
187
- Gradio components related to Whisper. see whisper_data_class.py for details.
188
-
189
- Returns
190
- ----------
191
- result_str:
192
- Result of transcription to return to gr.Textbox()
193
- result_file_path:
194
- Output file path to return to gr.Files()
195
- """
196
- try:
197
- progress(0, desc="Loading Audio..")
198
- result, elapsed_time = self.transcribe(
199
- mic_audio,
200
- progress,
201
- *whisper_params,
202
- )
203
- progress(1, desc="Completed!")
204
-
205
- subtitle, result_file_path = self.generate_and_write_file(
206
- file_name="Mic",
207
- transcribed_segments=result,
208
- add_timestamp=True,
209
- file_format=file_format
210
- )
211
-
212
- result_str = f"Done in {self.format_time(elapsed_time)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
213
- return [result_str, result_file_path]
214
- except Exception as e:
215
- print(f"Error transcribing mic: {str(e)}")
216
- finally:
217
- self.release_cuda_memory()
218
- self.remove_input_files([mic_audio])
219
 
220
  def transcribe(self,
221
  audio: Union[str, np.ndarray, torch.Tensor],
@@ -259,7 +58,7 @@ class WhisperInference(BaseInterface):
259
  beam_size=params.beam_size,
260
  logprob_threshold=params.log_prob_threshold,
261
  no_speech_threshold=params.no_speech_threshold,
262
- task="translate" if params.is_translate and self.current_model_size in self.translatable_model else "transcribe",
263
  fp16=True if params.compute_type == "float16" else False,
264
  best_of=params.best_of,
265
  patience=params.patience,
@@ -295,80 +94,4 @@ class WhisperInference(BaseInterface):
295
  name=model_size,
296
  device=self.device,
297
  download_root=self.model_dir
298
- )
299
-
300
- @staticmethod
301
- def generate_and_write_file(file_name: str,
302
- transcribed_segments: list,
303
- add_timestamp: bool,
304
- file_format: str,
305
- ) -> str:
306
- """
307
- Writes subtitle file
308
-
309
- Parameters
310
- ----------
311
- file_name: str
312
- Output file name
313
- transcribed_segments: list
314
- Text segments transcribed from audio
315
- add_timestamp: bool
316
- Determines whether to add a timestamp to the end of the filename.
317
- file_format: str
318
- File format to write. Supported formats: [SRT, WebVTT, txt]
319
-
320
- Returns
321
- ----------
322
- content: str
323
- Result of the transcription
324
- output_path: str
325
- output file path
326
- """
327
- timestamp = datetime.now().strftime("%m%d%H%M%S")
328
- if add_timestamp:
329
- output_path = os.path.join("outputs", f"{file_name}-{timestamp}")
330
- else:
331
- output_path = os.path.join("outputs", f"{file_name}")
332
-
333
- if file_format == "SRT":
334
- content = get_srt(transcribed_segments)
335
- output_path += '.srt'
336
- write_file(content, output_path)
337
-
338
- elif file_format == "WebVTT":
339
- content = get_vtt(transcribed_segments)
340
- output_path += '.vtt'
341
- write_file(content, output_path)
342
-
343
- elif file_format == "txt":
344
- content = get_txt(transcribed_segments)
345
- output_path += '.txt'
346
- write_file(content, output_path)
347
- return content, output_path
348
-
349
- @staticmethod
350
- def format_time(elapsed_time: float) -> str:
351
- """
352
- Get {hours} {minutes} {seconds} time format string
353
-
354
- Parameters
355
- ----------
356
- elapsed_time: str
357
- Elapsed time for transcription
358
-
359
- Returns
360
- ----------
361
- Time format string
362
- """
363
- hours, rem = divmod(elapsed_time, 3600)
364
- minutes, seconds = divmod(rem, 60)
365
-
366
- time_str = ""
367
- if hours:
368
- time_str += f"{hours} hours "
369
- if minutes:
370
- time_str += f"{minutes} minutes "
371
- seconds = round(seconds)
372
- time_str += f"{seconds} seconds"
373
-
374
- return time_str.strip()
 
4
  import os
5
  from typing import BinaryIO, Union, Tuple, List
6
  import numpy as np
 
7
  import torch
8
 
9
+ from modules.whisper_base import WhisperBase
 
 
10
  from modules.whisper_parameter import *
11
 
 
12
 
13
+ class WhisperInference(WhisperBase):
 
14
  def __init__(self):
15
+ super().__init__(
16
+ model_dir=os.path.join("models", "Whisper")
17
+ )
18
 
19
  def transcribe(self,
20
  audio: Union[str, np.ndarray, torch.Tensor],
 
58
  beam_size=params.beam_size,
59
  logprob_threshold=params.log_prob_threshold,
60
  no_speech_threshold=params.no_speech_threshold,
61
+ task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
62
  fp16=True if params.compute_type == "float16" else False,
63
  best_of=params.best_of,
64
  patience=params.patience,
 
94
  name=model_size,
95
  device=self.device,
96
  download_root=self.model_dir
97
+ )
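For reference, a hedged sketch of the openai-whisper call pattern used by `transcribe()` above. It assumes the `openai-whisper` package is installed; `"sample.wav"` is a hypothetical local audio file, and the parameter values are illustrative rather than the WebUI defaults.

```python
# Illustrative only: the whisper.load_model / model.transcribe pattern from the hunk above.
import whisper

model = whisper.load_model("base", device="cpu", download_root="models/Whisper")
result = model.transcribe(
    "sample.wav",
    task="transcribe",            # "translate" is only meaningful for the large* models
    beam_size=5,
    logprob_threshold=-1.0,
    no_speech_threshold=0.6,
    fp16=False,                   # float32 on CPU
)
print(result["text"])
```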
modules/whisper_base.py ADDED
@@ -0,0 +1,333 @@
1
+ import os
2
+ import torch
3
+ from typing import List
4
+ import whisper
5
+ import gradio as gr
6
+ from abc import ABC, abstractmethod
7
+ from typing import BinaryIO, Union, Tuple, List
8
+ import numpy as np
9
+ from datetime import datetime
10
+
11
+ from modules.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
12
+ from modules.youtube_manager import get_ytdata, get_ytaudio
13
+ from modules.whisper_parameter import *
14
+
15
+
16
+ class WhisperBase(ABC):
17
+ def __init__(self,
18
+ model_dir: str):
19
+ self.model = None
20
+ self.current_model_size = None
21
+ self.model_dir = model_dir
22
+ os.makedirs(self.model_dir, exist_ok=True)
23
+ self.available_models = whisper.available_models()
24
+ self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
25
+ self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
26
+ self.device = self.get_device()
27
+ self.available_compute_types = ["float16", "float32"]
28
+ self.current_compute_type = "float16" if self.device == "cuda" else "float32"
29
+
30
+ @abstractmethod
31
+ def transcribe(self,
32
+ audio: Union[str, BinaryIO, np.ndarray],
33
+ progress: gr.Progress,
34
+ *whisper_params,
35
+ ):
36
+ pass
37
+
38
+ @abstractmethod
39
+ def update_model(self,
40
+ model_size: str,
41
+ compute_type: str,
42
+ progress: gr.Progress
43
+ ):
44
+ pass
45
+
46
+ def transcribe_file(self,
47
+ files: list,
48
+ file_format: str,
49
+ add_timestamp: bool,
50
+ progress=gr.Progress(),
51
+ *whisper_params,
52
+ ) -> list:
53
+ """
54
+ Write subtitle file from Files
55
+
56
+ Parameters
57
+ ----------
58
+ files: list
59
+ List of files to transcribe from gr.Files()
60
+ file_format: str
61
+ Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
62
+ add_timestamp: bool
63
+ Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
64
+ progress: gr.Progress
65
+ Indicator to show progress directly in gradio.
66
+ *whisper_params: tuple
67
+ Gradio components related to Whisper. see whisper_data_class.py for details.
68
+
69
+ Returns
70
+ ----------
71
+ result_str:
72
+ Result of transcription to return to gr.Textbox()
73
+ result_file_path:
74
+ Output file path to return to gr.Files()
75
+ """
76
+ try:
77
+ files_info = {}
78
+ for file in files:
79
+ transcribed_segments, time_for_task = self.transcribe(
80
+ file.name,
81
+ progress,
82
+ *whisper_params,
83
+ )
84
+
85
+ file_name, file_ext = os.path.splitext(os.path.basename(file.name))
86
+ file_name = safe_filename(file_name)
87
+ subtitle, file_path = self.generate_and_write_file(
88
+ file_name=file_name,
89
+ transcribed_segments=transcribed_segments,
90
+ add_timestamp=add_timestamp,
91
+ file_format=file_format
92
+ )
93
+ files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path}
94
+
95
+ total_result = ''
96
+ total_time = 0
97
+ for file_name, info in files_info.items():
98
+ total_result += '------------------------------------\n'
99
+ total_result += f'{file_name}\n\n'
100
+ total_result += f'{info["subtitle"]}'
101
+ total_time += info["time_for_task"]
102
+
103
+ result_str = f"Done in {self.format_time(total_time)}! Subtitle is in the outputs folder.\n\n{total_result}"
104
+ result_file_path = [info['path'] for info in files_info.values()]
105
+
106
+ return [result_str, result_file_path]
107
+
108
+ except Exception as e:
109
+ print(f"Error transcribing file: {e}")
110
+ finally:
111
+ self.release_cuda_memory()
112
+ if not files:
113
+ self.remove_input_files([file.name for file in files])
114
+
115
+ def transcribe_mic(self,
116
+ mic_audio: str,
117
+ file_format: str,
118
+ progress=gr.Progress(),
119
+ *whisper_params,
120
+ ) -> list:
121
+ """
122
+ Write subtitle file from microphone
123
+
124
+ Parameters
125
+ ----------
126
+ mic_audio: str
127
+ Audio file path from gr.Microphone()
128
+ file_format: str
129
+ Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
130
+ progress: gr.Progress
131
+ Indicator to show progress directly in gradio.
132
+ *whisper_params: tuple
133
+ Gradio components related to Whisper. see whisper_data_class.py for details.
134
+
135
+ Returns
136
+ ----------
137
+ result_str:
138
+ Result of transcription to return to gr.Textbox()
139
+ result_file_path:
140
+ Output file path to return to gr.Files()
141
+ """
142
+ try:
143
+ progress(0, desc="Loading Audio..")
144
+ transcribed_segments, time_for_task = self.transcribe(
145
+ mic_audio,
146
+ progress,
147
+ *whisper_params,
148
+ )
149
+ progress(1, desc="Completed!")
150
+
151
+ subtitle, result_file_path = self.generate_and_write_file(
152
+ file_name="Mic",
153
+ transcribed_segments=transcribed_segments,
154
+ add_timestamp=True,
155
+ file_format=file_format
156
+ )
157
+
158
+ result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
159
+ return [result_str, result_file_path]
160
+ except Exception as e:
161
+ print(f"Error transcribing file: {e}")
162
+ finally:
163
+ self.release_cuda_memory()
164
+ self.remove_input_files([mic_audio])
165
+
166
+ def transcribe_youtube(self,
167
+ youtube_link: str,
168
+ file_format: str,
169
+ add_timestamp: bool,
170
+ progress=gr.Progress(),
171
+ *whisper_params,
172
+ ) -> list:
173
+ """
174
+ Write subtitle file from Youtube
175
+
176
+ Parameters
177
+ ----------
178
+ youtube_link: str
179
+ URL of the Youtube video to transcribe from gr.Textbox()
180
+ file_format: str
181
+ Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
182
+ add_timestamp: bool
183
+ Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
184
+ progress: gr.Progress
185
+ Indicator to show progress directly in gradio.
186
+ *whisper_params: tuple
187
+ Gradio components related to Whisper. see whisper_data_class.py for details.
188
+
189
+ Returns
190
+ ----------
191
+ result_str:
192
+ Result of transcription to return to gr.Textbox()
193
+ result_file_path:
194
+ Output file path to return to gr.Files()
195
+ """
196
+ try:
197
+ progress(0, desc="Loading Audio from Youtube..")
198
+ yt = get_ytdata(youtube_link)
199
+ audio = get_ytaudio(yt)
200
+
201
+ transcribed_segments, time_for_task = self.transcribe(
202
+ audio,
203
+ progress,
204
+ *whisper_params,
205
+ )
206
+
207
+ progress(1, desc="Completed!")
208
+
209
+ file_name = safe_filename(yt.title)
210
+ subtitle, result_file_path = self.generate_and_write_file(
211
+ file_name=file_name,
212
+ transcribed_segments=transcribed_segments,
213
+ add_timestamp=add_timestamp,
214
+ file_format=file_format
215
+ )
216
+ result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
217
+
218
+ return [result_str, result_file_path]
219
+
220
+ except Exception as e:
221
+ print(f"Error transcribing file: {e}")
222
+ finally:
223
+ try:
224
+ if 'yt' not in locals():
225
+ yt = get_ytdata(youtube_link)
226
+ file_path = get_ytaudio(yt)
227
+ else:
228
+ file_path = get_ytaudio(yt)
229
+
230
+ self.release_cuda_memory()
231
+ self.remove_input_files([file_path])
232
+ except Exception as cleanup_error:
233
+ pass
234
+
235
+ @staticmethod
236
+ def generate_and_write_file(file_name: str,
237
+ transcribed_segments: list,
238
+ add_timestamp: bool,
239
+ file_format: str,
240
+ ) -> str:
241
+ """
242
+ Writes subtitle file
243
+
244
+ Parameters
245
+ ----------
246
+ file_name: str
247
+ Output file name
248
+ transcribed_segments: list
249
+ Text segments transcribed from audio
250
+ add_timestamp: bool
251
+ Determines whether to add a timestamp to the end of the filename.
252
+ file_format: str
253
+ File format to write. Supported formats: [SRT, WebVTT, txt]
254
+
255
+ Returns
256
+ ----------
257
+ content: str
258
+ Result of the transcription
259
+ output_path: str
260
+ output file path
261
+ """
262
+ timestamp = datetime.now().strftime("%m%d%H%M%S")
263
+ if add_timestamp:
264
+ output_path = os.path.join("outputs", f"{file_name}-{timestamp}")
265
+ else:
266
+ output_path = os.path.join("outputs", f"{file_name}")
267
+
268
+ if file_format == "SRT":
269
+ content = get_srt(transcribed_segments)
270
+ output_path += '.srt'
271
+ write_file(content, output_path)
272
+
273
+ elif file_format == "WebVTT":
274
+ content = get_vtt(transcribed_segments)
275
+ output_path += '.vtt'
276
+ write_file(content, output_path)
277
+
278
+ elif file_format == "txt":
279
+ content = get_txt(transcribed_segments)
280
+ output_path += '.txt'
281
+ write_file(content, output_path)
282
+ return content, output_path
283
+
284
+ @staticmethod
285
+ def format_time(elapsed_time: float) -> str:
286
+ """
287
+ Get {hours} {minutes} {seconds} time format string
288
+
289
+ Parameters
290
+ ----------
291
+ elapsed_time: str
292
+ Elapsed time for transcription
293
+
294
+ Returns
295
+ ----------
296
+ Time format string
297
+ """
298
+ hours, rem = divmod(elapsed_time, 3600)
299
+ minutes, seconds = divmod(rem, 60)
300
+
301
+ time_str = ""
302
+ if hours:
303
+ time_str += f"{hours} hours "
304
+ if minutes:
305
+ time_str += f"{minutes} minutes "
306
+ seconds = round(seconds)
307
+ time_str += f"{seconds} seconds"
308
+
309
+ return time_str.strip()
310
+
311
+ @staticmethod
312
+ def get_device():
313
+ if torch.cuda.is_available():
314
+ return "cuda"
315
+ elif torch.backends.mps.is_available():
316
+ return "mps"
317
+ else:
318
+ return "cpu"
319
+
320
+ @staticmethod
321
+ def release_cuda_memory():
322
+ if torch.cuda.is_available():
323
+ torch.cuda.empty_cache()
324
+ torch.cuda.reset_max_memory_allocated()
325
+
326
+ @staticmethod
327
+ def remove_input_files(file_paths: List[str]):
328
+ if not file_paths:
329
+ return
330
+
331
+ for file_path in file_paths:
332
+ if file_path and os.path.exists(file_path):
333
+ os.remove(file_path)
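To close out the refactor, a minimal sketch of the contract a new backend must satisfy on top of `WhisperBase`, assuming this repository is on the Python path. `EchoInference` is hypothetical and its segment dict format is a guess for illustration; the point is the two abstract hooks and the `(segments, elapsed_seconds)` return shape that `transcribe_file()` expects.

```python
# Illustrative only: a hypothetical backend showing the WhisperBase contract.
import time
import gradio as gr
from modules.whisper_base import WhisperBase


class EchoInference(WhisperBase):
    """Toy backend returning one fake segment, to show the two abstract hooks."""

    def __init__(self):
        super().__init__(model_dir="models/Whisper/echo")

    def update_model(self, model_size: str, compute_type: str, progress: gr.Progress):
        self.current_model_size = model_size
        self.current_compute_type = compute_type

    def transcribe(self, audio, progress: gr.Progress, *whisper_params):
        start = time.time()
        segments = [{"start": 0.0, "end": 1.0, "text": "hello"}]  # format is a guess
        return segments, time.time() - start                      # (segments, elapsed_seconds)
```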
user-start-webui.bat CHANGED
@@ -8,8 +8,8 @@ set USERNAME=
8
  set PASSWORD=
9
  set SHARE=
10
  set THEME=
11
- set DISABLE_FASTER_WHISPER=
12
  set API_OPEN=
 
13
  set WHISPER_MODEL_DIR=
14
  set FASTER_WHISPER_MODEL_DIR=
15
 
@@ -38,6 +38,9 @@ if /I "%DISABLE_FASTER_WHISPER%"=="true" (
38
  if /I "%API_OPEN%"=="true" (
39
  set API_OPEN=--api_open
40
  )
 
 
 
41
  if not "%WHISPER_MODEL_DIR%"=="" (
42
  set WHISPER_MODEL_DIR_ARG=--whisper_model_dir "%WHISPER_MODEL_DIR%"
43
  )
@@ -46,5 +49,5 @@ if not "%FASTER_WHISPER_MODEL_DIR%"=="" (
46
  )
47
 
48
  :: Call the original .bat script with optional arguments
49
- start-webui.bat %SERVER_NAME_ARG% %SERVER_PORT_ARG% %USERNAME_ARG% %PASSWORD_ARG% %SHARE_ARG% %THEME_ARG% %DISABLE_FASTER_WHISPER_ARG% %API_OPEN% %WHISPER_MODEL_DIR_ARG% %FASTER_WHISPER_MODEL_DIR_ARG%
50
  pause
 
8
  set PASSWORD=
9
  set SHARE=
10
  set THEME=
 
11
  set API_OPEN=
12
+ set WHISPER_TYPE=
13
  set WHISPER_MODEL_DIR=
14
  set FASTER_WHISPER_MODEL_DIR=
15
 
 
38
  if /I "%API_OPEN%"=="true" (
39
  set API_OPEN=--api_open
40
  )
41
+ if not "%WHISPER_TYPE%"=="" (
42
+ set WHISPER_TYPE_ARG=--whisper_type %WHISPER_TYPE%
43
+ )
44
  if not "%WHISPER_MODEL_DIR%"=="" (
45
  set WHISPER_MODEL_DIR_ARG=--whisper_model_dir "%WHISPER_MODEL_DIR%"
46
  )
 
49
  )
50
 
51
  :: Call the original .bat script with optional arguments
52
+ start-webui.bat %SERVER_NAME_ARG% %SERVER_PORT_ARG% %USERNAME_ARG% %PASSWORD_ARG% %SHARE_ARG% %THEME_ARG% %API_OPEN% %WHISPER_TYPE_ARG% %WHISPER_MODEL_DIR_ARG% %FASTER_WHISPER_MODEL_DIR_ARG%
53
  pause