jhj0517 committed · Commit 4b52dfd · 1 Parent(s): a526073

refactor docstring
modules/faster_whisper_inference.py CHANGED
@@ -32,7 +32,7 @@ class FasterWhisperInference(BaseInterface):
         self.default_beam_size = 1
 
     def transcribe_file(self,
-
+                        files: list,
                         file_format: str,
                         add_timestamp: bool,
                         progress=gr.Progress(),
@@ -43,7 +43,7 @@ class FasterWhisperInference(BaseInterface):
 
         Parameters
         ----------
-
+        files: list
            List of files to transcribe from gr.Files()
        file_format: str
            Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
@@ -56,20 +56,21 @@ class FasterWhisperInference(BaseInterface):
 
        Returns
        ----------
-
-
-
+        result_str:
+            Result of transcription to return to gr.Textbox()
+        result_file_path:
+            Output file path to return to gr.Files()
        """
        try:
            files_info = {}
-            for
+            for file in files:
                transcribed_segments, time_for_task = self.transcribe(
-
+                    file.name,
                    progress,
                    *whisper_params,
                )
 
-                file_name, file_ext = os.path.splitext(os.path.basename(
+                file_name, file_ext = os.path.splitext(os.path.basename(file.name))
                file_name = safe_filename(file_name)
                subtitle, file_path = self.generate_and_write_file(
                    file_name=file_name,
@@ -96,8 +97,8 @@ class FasterWhisperInference(BaseInterface):
            print(f"Error transcribing file on line {e}")
        finally:
            self.release_cuda_memory()
-            if not
-            self.remove_input_files([
+            if not files:
+                self.remove_input_files([file.name for file in files])
 
    def transcribe_youtube(self,
                           youtube_link: str,
@@ -124,9 +125,10 @@ class FasterWhisperInference(BaseInterface):
 
        Returns
        ----------
-
-
-
+        result_str:
+            Result of transcription to return to gr.Textbox()
+        result_file_path:
+            Output file path to return to gr.Files()
        """
        try:
            progress(0, desc="Loading Audio from Youtube..")
@@ -142,15 +144,15 @@ class FasterWhisperInference(BaseInterface):
            progress(1, desc="Completed!")
 
            file_name = safe_filename(yt.title)
-            subtitle,
+            subtitle, result_file_path = self.generate_and_write_file(
                file_name=file_name,
                transcribed_segments=transcribed_segments,
                add_timestamp=add_timestamp,
                file_format=file_format
            )
-
+            result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
 
-            return [
+            return [result_str, result_file_path]
 
        except Exception as e:
            print(f"Error transcribing file on line {e}")
@@ -189,9 +191,10 @@ class FasterWhisperInference(BaseInterface):
 
        Returns
        ----------
-
-
-
+        result_str:
+            Result of transcription to return to gr.Textbox()
+        result_file_path:
+            Output file path to return to gr.Files()
        """
        try:
            progress(0, desc="Loading Audio..")
@@ -202,15 +205,15 @@ class FasterWhisperInference(BaseInterface):
            )
            progress(1, desc="Completed!")
 
-            subtitle,
+            subtitle, result_file_path = self.generate_and_write_file(
                file_name="Mic",
                transcribed_segments=transcribed_segments,
                add_timestamp=True,
                file_format=file_format
            )
 
-
-            return [
+            result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
+            return [result_str, result_file_path]
        except Exception as e:
            print(f"Error transcribing file on line {e}")
        finally:
@@ -282,7 +285,17 @@ class FasterWhisperInference(BaseInterface):
                     progress: gr.Progress
                     ):
        """
-
+        Update current model setting
+
+        Parameters
+        ----------
+        model_size: str
+            Size of whisper model
+        compute_type: str
+            Compute type for transcription.
+            see more info : https://opennmt.net/CTranslate2/quantization.html
+        progress: gr.Progress
+            Indicator to show progress directly in gradio.
        """
        progress(0, desc="Initializing Model..")
        self.current_model_size = model_size
@@ -301,7 +314,26 @@ class FasterWhisperInference(BaseInterface):
                                file_format: str,
                                ) -> str:
        """
-
+        Writes subtitle file and returns str of content and output file path
+
+        Parameters
+        ----------
+        file_name: str
+            Size of whisper model
+        transcribed_segments: str
+            Compute type for transcription.
+            see more info : https://opennmt.net/CTranslate2/quantization.html
+        add_timestamp: bool
+            Determines whether to add a timestamp to the end of the filename.
+        file_format: str
+            File format to write. Supported formats: [SRT, WebVTT, txt]
+
+        Returns
+        ----------
+        content: str
+            Result of the transcription
+        output_path: str
+            output file path
        """
        timestamp = datetime.now().strftime("%m%d%H%M%S")
        if add_timestamp:
@@ -327,6 +359,18 @@ class FasterWhisperInference(BaseInterface):
 
    @staticmethod
    def format_time(elapsed_time: float) -> str:
+        """
+        Get {hours} {minutes} {seconds} time format string
+
+        Parameters
+        ----------
+        elapsed_time: str
+            Elapsed time for transcription
+
+        Returns
+        ----------
+        Time format string
+        """
        hours, rem = divmod(elapsed_time, 3600)
        minutes, seconds = divmod(rem, 60)
 
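For context, the refactored docstrings state that transcribe_file() returns [result_str, result_file_path], meant for a gr.Textbox() and a gr.Files() output. Below is a minimal wiring sketch of that contract; it is not part of this commit, and the Blocks layout, component labels, and the no-argument FasterWhisperInference() constructor are assumptions for illustration only.

# Illustrative sketch (assumptions: a no-argument FasterWhisperInference()
# constructor and this particular Blocks layout; neither is shown in the commit).
import gradio as gr
from modules.faster_whisper_inference import FasterWhisperInference

inferencer = FasterWhisperInference()

with gr.Blocks() as demo:
    input_files = gr.Files(label="Files to transcribe")
    file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt"], value="SRT", label="File Format")
    add_timestamp = gr.Checkbox(value=True, label="Add timestamp to filename")
    run_btn = gr.Button("Transcribe")
    result_str = gr.Textbox(label="Result")       # receives result_str
    result_files = gr.Files(label="Output file")  # receives result_file_path

    # transcribe_file returns [result_str, result_file_path], so the outputs
    # list must be ordered Textbox first, Files second; gr.Progress() is
    # injected automatically via the progress=gr.Progress() default.
    run_btn.click(fn=inferencer.transcribe_file,
                  inputs=[input_files, file_format, add_timestamp],
                  outputs=[result_str, result_files])

demo.launch()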