jhj0517 commited on
Commit
b72fd8a
·
1 Parent(s): 81509c3

refactor base abstract class for whisper

Browse files
Files changed (2) hide show
  1. modules/base_interface.py +0 -23
  2. modules/whisper_base.py +333 -0
modules/base_interface.py DELETED
@@ -1,23 +0,0 @@
1
- import os
2
- import torch
3
- from typing import List
4
-
5
-
6
- class BaseInterface:
7
- def __init__(self):
8
- pass
9
-
10
- @staticmethod
11
- def release_cuda_memory():
12
- if torch.cuda.is_available():
13
- torch.cuda.empty_cache()
14
- torch.cuda.reset_max_memory_allocated()
15
-
16
- @staticmethod
17
- def remove_input_files(file_paths: List[str]):
18
- if not file_paths:
19
- return
20
-
21
- for file_path in file_paths:
22
- if file_path and os.path.exists(file_path):
23
- os.remove(file_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modules/whisper_base.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from typing import List
4
+ import whisper
5
+ import gradio as gr
6
+ from abc import ABC, abstractmethod
7
+ from typing import BinaryIO, Union, Tuple, List
8
+ import numpy as np
9
+ from datetime import datetime
10
+
11
+ from modules.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
12
+ from modules.youtube_manager import get_ytdata, get_ytaudio
13
+ from modules.whisper_parameter import *
14
+
15
+
16
+ class WhisperBase(ABC):
17
+ def __init__(self,
18
+ model_dir: str):
19
+ self.model = None
20
+ self.current_model_size = None
21
+ self.model_dir = model_dir
22
+ os.makedirs(self.model_dir, exist_ok=True)
23
+ self.available_models = whisper.available_models()
24
+ self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
25
+ self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
26
+ self.device = self.get_device()
27
+ self.available_compute_types = ["float16", "float32"]
28
+ self.current_compute_type = "float16" if self.device == "cuda" else "float32"
29
+
30
+ @abstractmethod
31
+ def transcribe(self,
32
+ audio: Union[str, BinaryIO, np.ndarray],
33
+ progress: gr.Progress,
34
+ *whisper_params,
35
+ ):
36
+ pass
37
+
38
+ @abstractmethod
39
+ def update_model(self,
40
+ model_size: str,
41
+ compute_type: str,
42
+ progress: gr.Progress
43
+ ):
44
+ pass
45
+
46
+ def transcribe_file(self,
47
+ files: list,
48
+ file_format: str,
49
+ add_timestamp: bool,
50
+ progress=gr.Progress(),
51
+ *whisper_params,
52
+ ) -> list:
53
+ """
54
+ Write subtitle file from Files
55
+
56
+ Parameters
57
+ ----------
58
+ files: list
59
+ List of files to transcribe from gr.Files()
60
+ file_format: str
61
+ Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
62
+ add_timestamp: bool
63
+ Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
64
+ progress: gr.Progress
65
+ Indicator to show progress directly in gradio.
66
+ *whisper_params: tuple
67
+ Gradio components related to Whisper. see whisper_data_class.py for details.
68
+
69
+ Returns
70
+ ----------
71
+ result_str:
72
+ Result of transcription to return to gr.Textbox()
73
+ result_file_path:
74
+ Output file path to return to gr.Files()
75
+ """
76
+ try:
77
+ files_info = {}
78
+ for file in files:
79
+ transcribed_segments, time_for_task = self.transcribe(
80
+ file.name,
81
+ progress,
82
+ *whisper_params,
83
+ )
84
+
85
+ file_name, file_ext = os.path.splitext(os.path.basename(file.name))
86
+ file_name = safe_filename(file_name)
87
+ subtitle, file_path = self.generate_and_write_file(
88
+ file_name=file_name,
89
+ transcribed_segments=transcribed_segments,
90
+ add_timestamp=add_timestamp,
91
+ file_format=file_format
92
+ )
93
+ files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path}
94
+
95
+ total_result = ''
96
+ total_time = 0
97
+ for file_name, info in files_info.items():
98
+ total_result += '------------------------------------\n'
99
+ total_result += f'{file_name}\n\n'
100
+ total_result += f'{info["subtitle"]}'
101
+ total_time += info["time_for_task"]
102
+
103
+ result_str = f"Done in {self.format_time(total_time)}! Subtitle is in the outputs folder.\n\n{total_result}"
104
+ result_file_path = [info['path'] for info in files_info.values()]
105
+
106
+ return [result_str, result_file_path]
107
+
108
+ except Exception as e:
109
+ print(f"Error transcribing file: {e}")
110
+ finally:
111
+ self.release_cuda_memory()
112
+ if not files:
113
+ self.remove_input_files([file.name for file in files])
114
+
115
+ def transcribe_mic(self,
116
+ mic_audio: str,
117
+ file_format: str,
118
+ progress=gr.Progress(),
119
+ *whisper_params,
120
+ ) -> list:
121
+ """
122
+ Write subtitle file from microphone
123
+
124
+ Parameters
125
+ ----------
126
+ mic_audio: str
127
+ Audio file path from gr.Microphone()
128
+ file_format: str
129
+ Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
130
+ progress: gr.Progress
131
+ Indicator to show progress directly in gradio.
132
+ *whisper_params: tuple
133
+ Gradio components related to Whisper. see whisper_data_class.py for details.
134
+
135
+ Returns
136
+ ----------
137
+ result_str:
138
+ Result of transcription to return to gr.Textbox()
139
+ result_file_path:
140
+ Output file path to return to gr.Files()
141
+ """
142
+ try:
143
+ progress(0, desc="Loading Audio..")
144
+ transcribed_segments, time_for_task = self.transcribe(
145
+ mic_audio,
146
+ progress,
147
+ *whisper_params,
148
+ )
149
+ progress(1, desc="Completed!")
150
+
151
+ subtitle, result_file_path = self.generate_and_write_file(
152
+ file_name="Mic",
153
+ transcribed_segments=transcribed_segments,
154
+ add_timestamp=True,
155
+ file_format=file_format
156
+ )
157
+
158
+ result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
159
+ return [result_str, result_file_path]
160
+ except Exception as e:
161
+ print(f"Error transcribing file: {e}")
162
+ finally:
163
+ self.release_cuda_memory()
164
+ self.remove_input_files([mic_audio])
165
+
166
+ def transcribe_youtube(self,
167
+ youtube_link: str,
168
+ file_format: str,
169
+ add_timestamp: bool,
170
+ progress=gr.Progress(),
171
+ *whisper_params,
172
+ ) -> list:
173
+ """
174
+ Write subtitle file from Youtube
175
+
176
+ Parameters
177
+ ----------
178
+ youtube_link: str
179
+ URL of the Youtube video to transcribe from gr.Textbox()
180
+ file_format: str
181
+ Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
182
+ add_timestamp: bool
183
+ Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
184
+ progress: gr.Progress
185
+ Indicator to show progress directly in gradio.
186
+ *whisper_params: tuple
187
+ Gradio components related to Whisper. see whisper_data_class.py for details.
188
+
189
+ Returns
190
+ ----------
191
+ result_str:
192
+ Result of transcription to return to gr.Textbox()
193
+ result_file_path:
194
+ Output file path to return to gr.Files()
195
+ """
196
+ try:
197
+ progress(0, desc="Loading Audio from Youtube..")
198
+ yt = get_ytdata(youtube_link)
199
+ audio = get_ytaudio(yt)
200
+
201
+ transcribed_segments, time_for_task = self.transcribe(
202
+ audio,
203
+ progress,
204
+ *whisper_params,
205
+ )
206
+
207
+ progress(1, desc="Completed!")
208
+
209
+ file_name = safe_filename(yt.title)
210
+ subtitle, result_file_path = self.generate_and_write_file(
211
+ file_name=file_name,
212
+ transcribed_segments=transcribed_segments,
213
+ add_timestamp=add_timestamp,
214
+ file_format=file_format
215
+ )
216
+ result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
217
+
218
+ return [result_str, result_file_path]
219
+
220
+ except Exception as e:
221
+ print(f"Error transcribing file: {e}")
222
+ finally:
223
+ try:
224
+ if 'yt' not in locals():
225
+ yt = get_ytdata(youtube_link)
226
+ file_path = get_ytaudio(yt)
227
+ else:
228
+ file_path = get_ytaudio(yt)
229
+
230
+ self.release_cuda_memory()
231
+ self.remove_input_files([file_path])
232
+ except Exception as cleanup_error:
233
+ pass
234
+
235
+ @staticmethod
236
+ def generate_and_write_file(file_name: str,
237
+ transcribed_segments: list,
238
+ add_timestamp: bool,
239
+ file_format: str,
240
+ ) -> str:
241
+ """
242
+ Writes subtitle file
243
+
244
+ Parameters
245
+ ----------
246
+ file_name: str
247
+ Output file name
248
+ transcribed_segments: list
249
+ Text segments transcribed from audio
250
+ add_timestamp: bool
251
+ Determines whether to add a timestamp to the end of the filename.
252
+ file_format: str
253
+ File format to write. Supported formats: [SRT, WebVTT, txt]
254
+
255
+ Returns
256
+ ----------
257
+ content: str
258
+ Result of the transcription
259
+ output_path: str
260
+ output file path
261
+ """
262
+ timestamp = datetime.now().strftime("%m%d%H%M%S")
263
+ if add_timestamp:
264
+ output_path = os.path.join("outputs", f"{file_name}-{timestamp}")
265
+ else:
266
+ output_path = os.path.join("outputs", f"{file_name}")
267
+
268
+ if file_format == "SRT":
269
+ content = get_srt(transcribed_segments)
270
+ output_path += '.srt'
271
+ write_file(content, output_path)
272
+
273
+ elif file_format == "WebVTT":
274
+ content = get_vtt(transcribed_segments)
275
+ output_path += '.vtt'
276
+ write_file(content, output_path)
277
+
278
+ elif file_format == "txt":
279
+ content = get_txt(transcribed_segments)
280
+ output_path += '.txt'
281
+ write_file(content, output_path)
282
+ return content, output_path
283
+
284
+ @staticmethod
285
+ def format_time(elapsed_time: float) -> str:
286
+ """
287
+ Get {hours} {minutes} {seconds} time format string
288
+
289
+ Parameters
290
+ ----------
291
+ elapsed_time: str
292
+ Elapsed time for transcription
293
+
294
+ Returns
295
+ ----------
296
+ Time format string
297
+ """
298
+ hours, rem = divmod(elapsed_time, 3600)
299
+ minutes, seconds = divmod(rem, 60)
300
+
301
+ time_str = ""
302
+ if hours:
303
+ time_str += f"{hours} hours "
304
+ if minutes:
305
+ time_str += f"{minutes} minutes "
306
+ seconds = round(seconds)
307
+ time_str += f"{seconds} seconds"
308
+
309
+ return time_str.strip()
310
+
311
+ @staticmethod
312
+ def get_device():
313
+ if torch.cuda.is_available():
314
+ return "cuda"
315
+ elif torch.backends.mps.is_available():
316
+ return "mps"
317
+ else:
318
+ return "cpu"
319
+
320
+ @staticmethod
321
+ def release_cuda_memory():
322
+ if torch.cuda.is_available():
323
+ torch.cuda.empty_cache()
324
+ torch.cuda.reset_max_memory_allocated()
325
+
326
+ @staticmethod
327
+ def remove_input_files(file_paths: List[str]):
328
+ if not file_paths:
329
+ return
330
+
331
+ for file_path in file_paths:
332
+ if file_path and os.path.exists(file_path):
333
+ os.remove(file_path)