jhj0517 committed on
Commit
91dee77
·
1 Parent(s): 6e51e1b

enable fine-tuned faster-whisper model

Browse files
Files changed (1) hide show
  1. modules/faster_whisper_inference.py +30 -5
modules/faster_whisper_inference.py CHANGED
@@ -1,5 +1,4 @@
1
  import os
2
-
3
  import time
4
  import numpy as np
5
  from typing import BinaryIO, Union, Tuple, List
@@ -24,16 +23,17 @@ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
24
  class FasterWhisperInference(BaseInterface):
25
  def __init__(self):
26
  super().__init__()
 
27
  self.current_model_size = None
28
  self.model = None
29
- self.available_models = whisper.available_models()
 
30
  self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
31
  self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
32
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
33
  self.available_compute_types = ctranslate2.get_supported_compute_types(
34
  "cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")
35
  self.current_compute_type = "float16" if self.device == "cuda" else "float32"
36
- self.model_dir = os.path.join("models", "Whisper", "faster-whisper")
37
 
38
  def transcribe_file(self,
39
  files: list,
@@ -317,15 +317,40 @@ class FasterWhisperInference(BaseInterface):
317
  Indicator to show progress directly in gradio.
318
  """
319
  progress(0, desc="Initializing Model..")
320
- self.current_model_size = model_size
321
  self.current_compute_type = compute_type
322
  self.model = faster_whisper.WhisperModel(
323
  device=self.device,
324
- model_size_or_path=model_size,
325
  download_root=self.model_dir,
326
  compute_type=self.current_compute_type
327
  )
328
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329
  @staticmethod
330
  def generate_and_write_file(file_name: str,
331
  transcribed_segments: list,
 
1
  import os
 
2
  import time
3
  import numpy as np
4
  from typing import BinaryIO, Union, Tuple, List
 
23
  class FasterWhisperInference(BaseInterface):
24
  def __init__(self):
25
  super().__init__()
26
+ self.model_dir = os.path.join("models", "Whisper", "faster-whisper")
27
  self.current_model_size = None
28
  self.model = None
29
+ self.model_paths = self.get_model_paths()
30
+ self.available_models = self.model_paths.keys()
31
  self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
32
  self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
33
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
34
  self.available_compute_types = ctranslate2.get_supported_compute_types(
35
  "cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")
36
  self.current_compute_type = "float16" if self.device == "cuda" else "float32"
 
37
 
38
  def transcribe_file(self,
39
  files: list,
 
317
  Indicator to show progress directly in gradio.
318
  """
319
  progress(0, desc="Initializing Model..")
320
+ self.current_model_size = self.model_paths[model_size]
321
  self.current_compute_type = compute_type
322
  self.model = faster_whisper.WhisperModel(
323
  device=self.device,
324
+ model_size_or_path=self.current_model_size,
325
  download_root=self.model_dir,
326
  compute_type=self.current_compute_type
327
  )
328
 
329
+ def get_model_paths(self):
330
+ """
331
+ Get available models from models path including fine-tuned model.
332
+
333
+ Returns
334
+ ----------
335
+ Name list of models
336
+ """
337
+ model_paths = {model:model for model in whisper.available_models()}
338
+ faster_whisper_prefix = "models--Systran--faster-whisper-"
339
+
340
+ existing_models = os.listdir(self.model_dir)
341
+ wrong_dirs = [".locks"]
342
+ existing_models = list(set(existing_models) - set(wrong_dirs))
343
+
344
+ webui_dir = os.getcwd()
345
+
346
+ for model_name in existing_models:
347
+ if faster_whisper_prefix in model_name:
348
+ model_name = model_name[len(faster_whisper_prefix):]
349
+
350
+ if model_name not in whisper.available_models():
351
+ model_paths[model_name] = os.path.join(webui_dir, self.model_dir, model_name)
352
+ return model_paths
353
+
354
  @staticmethod
355
  def generate_and_write_file(file_name: str,
356
  transcribed_segments: list,