Spaces:
Running
Running
jhj0517
committed on
Commit
·
a85dc3c
1
Parent(s):
6a7425a
Add defaults to the function
Browse files
- modules/whisper/whisper_base.py +13 -12
modules/whisper/whisper_base.py
CHANGED
@@ -53,7 +53,7 @@ class WhisperBase(ABC):
|
|
53 |
@abstractmethod
|
54 |
def transcribe(self,
|
55 |
audio: Union[str, BinaryIO, np.ndarray],
|
56 |
-
progress: gr.Progress,
|
57 |
*whisper_params,
|
58 |
):
|
59 |
"""Inference whisper model to transcribe"""
|
@@ -63,7 +63,7 @@ class WhisperBase(ABC):
|
|
63 |
def update_model(self,
|
64 |
model_size: str,
|
65 |
compute_type: str,
|
66 |
-
progress: gr.Progress
|
67 |
):
|
68 |
"""Initialize whisper model"""
|
69 |
pass
|
@@ -171,10 +171,10 @@ class WhisperBase(ABC):
|
|
171 |
return result, elapsed_time
|
172 |
|
173 |
def transcribe_file(self,
|
174 |
-
files:
|
175 |
-
input_folder_path: str,
|
176 |
-
file_format: str,
|
177 |
-
add_timestamp: bool,
|
178 |
progress=gr.Progress(),
|
179 |
*whisper_params,
|
180 |
) -> list:
|
@@ -250,8 +250,8 @@ class WhisperBase(ABC):
|
|
250 |
|
251 |
def transcribe_mic(self,
|
252 |
mic_audio: str,
|
253 |
-
file_format: str,
|
254 |
-
add_timestamp: bool,
|
255 |
progress=gr.Progress(),
|
256 |
*whisper_params,
|
257 |
) -> list:
|
@@ -306,8 +306,8 @@ class WhisperBase(ABC):
|
|
306 |
|
307 |
def transcribe_youtube(self,
|
308 |
youtube_link: str,
|
309 |
-
file_format: str,
|
310 |
-
add_timestamp: bool,
|
311 |
progress=gr.Progress(),
|
312 |
*whisper_params,
|
313 |
) -> list:
|
@@ -411,11 +411,12 @@ class WhisperBase(ABC):
|
|
411 |
else:
|
412 |
output_path = os.path.join(output_dir, f"{file_name}")
|
413 |
|
414 |
-
if file_format == "SRT":
|
415 |
content = get_srt(transcribed_segments)
|
416 |
output_path += '.srt'
|
417 |
|
418 |
-
elif file_format == "WebVTT":
|
419 |
content = get_vtt(transcribed_segments)
|
420 |
output_path += '.vtt'
|
421 |
|
|
|
53 |
@abstractmethod
|
54 |
def transcribe(self,
|
55 |
audio: Union[str, BinaryIO, np.ndarray],
|
56 |
+
progress: gr.Progress = gr.Progress(),
|
57 |
*whisper_params,
|
58 |
):
|
59 |
"""Inference whisper model to transcribe"""
|
|
|
63 |
def update_model(self,
|
64 |
model_size: str,
|
65 |
compute_type: str,
|
66 |
+
progress: gr.Progress = gr.Progress()
|
67 |
):
|
68 |
"""Initialize whisper model"""
|
69 |
pass
|
|
|
171 |
return result, elapsed_time
|
172 |
|
173 |
def transcribe_file(self,
|
174 |
+
files: Optional[List] = None,
|
175 |
+
input_folder_path: Optional[str] = None,
|
176 |
+
file_format: str = "SRT",
|
177 |
+
add_timestamp: bool = True,
|
178 |
progress=gr.Progress(),
|
179 |
*whisper_params,
|
180 |
) -> list:
|
|
|
250 |
|
251 |
def transcribe_mic(self,
|
252 |
mic_audio: str,
|
253 |
+
file_format: str = "SRT",
|
254 |
+
add_timestamp: bool = True,
|
255 |
progress=gr.Progress(),
|
256 |
*whisper_params,
|
257 |
) -> list:
|
|
|
306 |
|
307 |
def transcribe_youtube(self,
|
308 |
youtube_link: str,
|
309 |
+
file_format: str = "SRT",
|
310 |
+
add_timestamp: bool = True,
|
311 |
progress=gr.Progress(),
|
312 |
*whisper_params,
|
313 |
) -> list:
|
|
|
411 |
else:
|
412 |
output_path = os.path.join(output_dir, f"{file_name}")
|
413 |
|
414 |
+
file_format = file_format.strip().lower()
|
415 |
+
if file_format == "srt":
|
416 |
content = get_srt(transcribed_segments)
|
417 |
output_path += '.srt'
|
418 |
|
419 |
+
elif file_format == "webvtt":
|
420 |
content = get_vtt(transcribed_segments)
|
421 |
output_path += '.vtt'
|
422 |
|