jhj0517 committed · Commit 91dee77 · Parent(s): 6e51e1b

enable fine-tuned faster-whisper model
modules/faster_whisper_inference.py
CHANGED
@@ -1,5 +1,4 @@
 import os
-
 import time
 import numpy as np
 from typing import BinaryIO, Union, Tuple, List
@@ -24,16 +23,17 @@ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
 class FasterWhisperInference(BaseInterface):
     def __init__(self):
         super().__init__()
+        self.model_dir = os.path.join("models", "Whisper", "faster-whisper")
         self.current_model_size = None
         self.model = None
-        self.available_models = whisper.available_models()
+        self.model_paths = self.get_model_paths()
+        self.available_models = self.model_paths.keys()
         self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
         self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.available_compute_types = ctranslate2.get_supported_compute_types(
             "cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")
         self.current_compute_type = "float16" if self.device == "cuda" else "float32"
-        self.model_dir = os.path.join("models", "Whisper", "faster-whisper")
 
     def transcribe_file(self,
                         files: list,
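Note: after this hunk the model dropdown is populated from `get_model_paths()` (added in the next hunk) instead of the fixed `whisper.available_models()` list. A minimal sketch of the mapping the constructor ends up holding, assuming a hypothetical fine-tuned model directory `my-finetuned-model` placed under `models/Whisper/faster-whisper` (names here are illustrative, not from the commit):

    import os

    # Illustrative shape of self.model_paths after __init__ (assumed,
    # not taken verbatim from the commit): standard sizes map to
    # themselves, so faster-whisper can download them by name, while an
    # unrecognized directory maps to its absolute local path.
    model_paths = {
        "base": "base",
        "large-v2": "large-v2",
        # hypothetical fine-tuned model placed in the models directory
        "my-finetuned-model": os.path.join(
            os.getcwd(), "models", "Whisper", "faster-whisper", "my-finetuned-model"
        ),
    }
    available_models = model_paths.keys()  # what the UI dropdown shows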
@@ -317,15 +317,40 @@ class FasterWhisperInference(BaseInterface):
         Indicator to show progress directly in gradio.
         """
         progress(0, desc="Initializing Model..")
-        self.current_model_size = model_size
+        self.current_model_size = self.model_paths[model_size]
         self.current_compute_type = compute_type
         self.model = faster_whisper.WhisperModel(
             device=self.device,
-            model_size_or_path=model_size,
+            model_size_or_path=self.current_model_size,
             download_root=self.model_dir,
             compute_type=self.current_compute_type
         )
 
+    def get_model_paths(self):
+        """
+        Get available models from the models directory, including fine-tuned models.
+
+        Returns
+        ----------
+        Dictionary mapping model names to their sizes or local paths
+        """
+        model_paths = {model: model for model in whisper.available_models()}
+        faster_whisper_prefix = "models--Systran--faster-whisper-"
+
+        existing_models = os.listdir(self.model_dir)
+        wrong_dirs = [".locks"]
+        existing_models = list(set(existing_models) - set(wrong_dirs))
+
+        webui_dir = os.getcwd()
+
+        for model_name in existing_models:
+            if faster_whisper_prefix in model_name:
+                model_name = model_name[len(faster_whisper_prefix):]
+
+            if model_name not in whisper.available_models():
+                model_paths[model_name] = os.path.join(webui_dir, self.model_dir, model_name)
+        return model_paths
+
     @staticmethod
     def generate_and_write_file(file_name: str,
                                 transcribed_segments: list,
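Note: the `models--Systran--faster-whisper-` prefix matches the huggingface_hub cache layout (`models--<org>--<repo>`) that faster-whisper uses for downloaded models, so stripping it recovers the plain size name and keeps cached standard models from being listed as fine-tuned ones. A standalone sketch of that normalization (sample directory names are hypothetical):

    # Standalone sketch of the directory-name normalization in
    # get_model_paths(); the size set is a subset for illustration.
    FASTER_WHISPER_PREFIX = "models--Systran--faster-whisper-"
    STANDARD_SIZES = {"tiny", "base", "small", "medium", "large-v2"}

    def classify(dir_name: str) -> str:
        name = dir_name
        if FASTER_WHISPER_PREFIX in name:
            # e.g. "models--Systran--faster-whisper-large-v2" -> "large-v2"
            name = name[len(FASTER_WHISPER_PREFIX):]
        return "standard" if name in STANDARD_SIZES else "fine-tuned"

    print(classify("models--Systran--faster-whisper-large-v2"))  # standard
    print(classify("my-finetuned-model"))                        # fine-tuned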
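For reference, `faster_whisper.WhisperModel` accepts either a size name or a local CTranslate2 model directory through `model_size_or_path`, which is what lets the resolved fine-tuned path load directly. A usage sketch (the model directory and audio file are hypothetical; the directory must contain a converted CTranslate2 model):

    import faster_whisper

    # Load a hypothetical fine-tuned model from a local directory
    # instead of a hub size name.
    model = faster_whisper.WhisperModel(
        model_size_or_path="models/Whisper/faster-whisper/my-finetuned-model",
        device="cpu",
        compute_type="float32",
    )

    # Transcribe a hypothetical audio file; transcribe() returns a
    # segment generator plus metadata about the detected language.
    segments, info = model.transcribe("audio.wav")
    for segment in segments:
        print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text}")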