Merge pull request #290 from jhj0517/fix/defaults
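This merge adds sensible defaults to the Gradio-facing signatures: every progress: gr.Progress parameter gains a gr.Progress() default, transcribe_file, transcribe_mic, and transcribe_youtube default file_format to "SRT" and add_timestamp to True, and the output-format comparison is normalized with strip().lower() so the check is case-insensitive.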
modules/translation/nllb_inference.py
CHANGED
@@ -35,7 +35,7 @@ class NLLBInference(TranslationBase):
                      model_size: str,
                      src_lang: str,
                      tgt_lang: str,
-                     progress: gr.Progress
+                     progress: gr.Progress = gr.Progress()
                      ):
         if model_size != self.current_model_size or self.model is None:
             print("\nInitializing NLLB Model..\n")
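With the default in place, update_model can be called outside of a Gradio event (for example from a script or a test) without passing a progress tracker. A minimal sketch of such a direct call; the constructor call and model id below are assumptions for illustration, not taken from this PR:

from modules.translation.nllb_inference import NLLBInference

translator = NLLBInference()  # hypothetical constructor call; see the repo for actual arguments
translator.update_model(
    model_size="facebook/nllb-200-distilled-600M",  # hypothetical model id
    src_lang="English",
    tgt_lang="Korean",
)  # progress now defaults to gr.Progress(), so no Gradio event context is required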
modules/translation/translation_base.py
CHANGED
@@ -37,7 +37,7 @@ class TranslationBase(ABC):
                      model_size: str,
                      src_lang: str,
                      tgt_lang: str,
-                     progress: gr.Progress
+                     progress: gr.Progress = gr.Progress()
                      ):
         pass
 
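Mirroring the default on the abstract method keeps the base-class contract in sync with the subclasses. One standard-Python caveat worth knowing: a default like gr.Progress() is evaluated once, when the function is defined, so every call that omits the argument shares the same instance. This matches Gradio's documented progress pattern, where the library detects a gr.Progress default in the signature and drives a tracked instance during events. A self-contained sketch of the evaluated-once behavior, using a plain list instead of gr.Progress:

def append_item(x, box=[]):  # the default object is created once, at definition time
    box.append(x)
    return box

print(append_item(1))  # [1]
print(append_item(2))  # [1, 2] -- the same list object is reused across calls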
modules/whisper/faster_whisper_inference.py
CHANGED
@@ -40,7 +40,7 @@ class FasterWhisperInference(WhisperBase):
 
     def transcribe(self,
                    audio: Union[str, BinaryIO, np.ndarray],
-                   progress: gr.Progress,
+                   progress: gr.Progress = gr.Progress(),
                    *whisper_params,
                    ) -> Tuple[List[dict], float]:
         """
@@ -126,7 +126,7 @@ class FasterWhisperInference(WhisperBase):
     def update_model(self,
                      model_size: str,
                      compute_type: str,
-                     progress: gr.Progress
+                     progress: gr.Progress = gr.Progress()
                      ):
         """
         Update current model setting
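Because *whisper_params follows the progress parameter, the default only helps callers who stop at audio: any extra positional argument still binds to progress first. A quick sketch of that binding rule with generic stand-ins (not the repo's API):

def transcribe(audio, progress="default", *params):
    return progress, params

print(transcribe("a.wav"))            # ('default', ())
print(transcribe("a.wav", 0.5, 0.9))  # (0.5, (0.9,)) -- 0.5 lands in progress, not params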
modules/whisper/insanely_fast_whisper_inference.py
CHANGED
@@ -39,7 +39,7 @@ class InsanelyFastWhisperInference(WhisperBase):
 
     def transcribe(self,
                    audio: Union[str, np.ndarray, torch.Tensor],
-                   progress: gr.Progress,
+                   progress: gr.Progress = gr.Progress(),
                    *whisper_params,
                    ) -> Tuple[List[dict], float]:
         """
@@ -98,7 +98,7 @@ class InsanelyFastWhisperInference(WhisperBase):
     def update_model(self,
                      model_size: str,
                      compute_type: str,
-                     progress: gr.Progress,
+                     progress: gr.Progress = gr.Progress(),
                      ):
         """
         Update current model setting
modules/whisper/whisper_Inference.py
CHANGED
@@ -28,7 +28,7 @@ class WhisperInference(WhisperBase):
 
     def transcribe(self,
                    audio: Union[str, np.ndarray, torch.Tensor],
-                   progress: gr.Progress,
+                   progress: gr.Progress = gr.Progress(),
                    *whisper_params,
                    ) -> Tuple[List[dict], float]:
         """
@@ -79,7 +79,7 @@ class WhisperInference(WhisperBase):
     def update_model(self,
                      model_size: str,
                      compute_type: str,
-                     progress: gr.Progress,
+                     progress: gr.Progress = gr.Progress(),
                      ):
         """
         Update current model setting
modules/whisper/whisper_base.py
CHANGED
@@ -53,7 +53,7 @@ class WhisperBase(ABC):
     @abstractmethod
     def transcribe(self,
                    audio: Union[str, BinaryIO, np.ndarray],
-                   progress: gr.Progress,
+                   progress: gr.Progress = gr.Progress(),
                    *whisper_params,
                    ):
         """Inference whisper model to transcribe"""
@@ -63,7 +63,7 @@ class WhisperBase(ABC):
     def update_model(self,
                      model_size: str,
                      compute_type: str,
-                     progress: gr.Progress
+                     progress: gr.Progress = gr.Progress()
                      ):
         """Initialize whisper model"""
         pass
@@ -171,10 +171,10 @@ class WhisperBase(ABC):
         return result, elapsed_time
 
     def transcribe_file(self,
-                        files:
-                        input_folder_path: str,
-                        file_format: str,
-                        add_timestamp: bool,
+                        files: Optional[List] = None,
+                        input_folder_path: Optional[str] = None,
+                        file_format: str = "SRT",
+                        add_timestamp: bool = True,
                         progress=gr.Progress(),
                         *whisper_params,
                         ) -> list:
@@ -250,8 +250,8 @@ class WhisperBase(ABC):
 
     def transcribe_mic(self,
                        mic_audio: str,
-                       file_format: str,
-                       add_timestamp: bool,
+                       file_format: str = "SRT",
+                       add_timestamp: bool = True,
                        progress=gr.Progress(),
                        *whisper_params,
                        ) -> list:
@@ -306,8 +306,8 @@ class WhisperBase(ABC):
 
     def transcribe_youtube(self,
                            youtube_link: str,
-                           file_format: str,
-                           add_timestamp: bool,
+                           file_format: str = "SRT",
+                           add_timestamp: bool = True,
                            progress=gr.Progress(),
                            *whisper_params,
                            ) -> list:
@@ -411,11 +411,12 @@ class WhisperBase(ABC):
         else:
             output_path = os.path.join(output_dir, f"{file_name}")
 
-        if file_format == "SRT":
+        file_format = file_format.strip().lower()
+        if file_format == "srt":
             content = get_srt(transcribed_segments)
             output_path += '.srt'
 
-        elif file_format == "WebVTT":
+        elif file_format == "webvtt":
             content = get_vtt(transcribed_segments)
             output_path += '.vtt'
 
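The last hunk also makes the output-format check tolerant of casing and stray whitespace, so "SRT", "srt", and " WebVTT " all resolve to the right extension. The branching logic in isolation (pick_extension is a stand-in name, not a function from the repo):

def pick_extension(file_format: str) -> str:
    file_format = file_format.strip().lower()  # normalize before comparing
    if file_format == "srt":
        return ".srt"
    elif file_format == "webvtt":
        return ".vtt"
    raise ValueError(f"unsupported file format: {file_format}")

print(pick_extension(" SRT "))   # .srt
print(pick_extension("WebVTT"))  # .vtt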