jhj0517 commited on
Commit
17e1e38
·
1 Parent(s): b4876a0

Refactor default parameters for initialization

Browse files
modules/whisper/faster_whisper_inference.py CHANGED
@@ -17,17 +17,15 @@ from modules.whisper.whisper_base import WhisperBase
17
 
18
  class FasterWhisperInference(WhisperBase):
19
  def __init__(self,
20
- model_dir: Optional[str] = None,
21
- diarization_model_dir: Optional[str] = None,
22
- output_dir: Optional[str] = None,
23
  ):
24
  super().__init__(
25
  model_dir=model_dir,
26
  diarization_model_dir=diarization_model_dir,
27
  output_dir=output_dir
28
  )
29
- if model_dir is None:
30
- model_dir = os.path.join("models", "Whisper", "faster-whisper")
31
  self.model_dir = model_dir
32
  os.makedirs(self.model_dir, exist_ok=True)
33
 
 
17
 
18
  class FasterWhisperInference(WhisperBase):
19
  def __init__(self,
20
+ model_dir: str = os.path.join("models", "Whisper", "faster-whisper"),
21
+ diarization_model_dir: str = os.path.join("models", "Diarization"),
22
+ output_dir: str = os.path.join("outputs"),
23
  ):
24
  super().__init__(
25
  model_dir=model_dir,
26
  diarization_model_dir=diarization_model_dir,
27
  output_dir=output_dir
28
  )
 
 
29
  self.model_dir = model_dir
30
  os.makedirs(self.model_dir, exist_ok=True)
31
 
modules/whisper/insanely_fast_whisper_inference.py CHANGED
@@ -17,17 +17,15 @@ from modules.whisper.whisper_base import WhisperBase
17
 
18
  class InsanelyFastWhisperInference(WhisperBase):
19
  def __init__(self,
20
- model_dir: Optional[str] = None,
21
- diarization_model_dir: Optional[str] = None,
22
- output_dir: Optional[str] = None,
23
  ):
24
  super().__init__(
25
  model_dir=model_dir,
26
  output_dir=output_dir,
27
  diarization_model_dir=diarization_model_dir
28
  )
29
- if model_dir is None:
30
- model_dir = os.path.join("models", "Whisper", "insanely-fast-whisper")
31
  self.model_dir = model_dir
32
  os.makedirs(self.model_dir, exist_ok=True)
33
 
 
17
 
18
  class InsanelyFastWhisperInference(WhisperBase):
19
  def __init__(self,
20
+ model_dir: str = os.path.join("models", "Whisper", "insanely-fast-whisper"),
21
+ diarization_model_dir: str = os.path.join("models", "Diarization"),
22
+ output_dir: str = os.path.join("outputs"),
23
  ):
24
  super().__init__(
25
  model_dir=model_dir,
26
  output_dir=output_dir,
27
  diarization_model_dir=diarization_model_dir
28
  )
 
 
29
  self.model_dir = model_dir
30
  os.makedirs(self.model_dir, exist_ok=True)
31
 
modules/whisper/whisper_Inference.py CHANGED
@@ -4,6 +4,7 @@ import time
4
  from typing import BinaryIO, Union, Tuple, List
5
  import numpy as np
6
  import torch
 
7
  from argparse import Namespace
8
 
9
  from modules.whisper.whisper_base import WhisperBase
@@ -12,9 +13,9 @@ from modules.whisper.whisper_parameter import *
12
 
13
  class WhisperInference(WhisperBase):
14
  def __init__(self,
15
- model_dir: Optional[str] = None,
16
- diarization_model_dir: Optional[str] = None,
17
- output_dir: Optional[str] = None,
18
  ):
19
  super().__init__(
20
  model_dir=model_dir,
 
4
  from typing import BinaryIO, Union, Tuple, List
5
  import numpy as np
6
  import torch
7
+ import os
8
  from argparse import Namespace
9
 
10
  from modules.whisper.whisper_base import WhisperBase
 
13
 
14
  class WhisperInference(WhisperBase):
15
  def __init__(self,
16
+ model_dir: str = os.path.join("models", "Whisper"),
17
+ diarization_model_dir: str = os.path.join("models", "Diarization"),
18
+ output_dir: str = os.path.join("outputs"),
19
  ):
20
  super().__init__(
21
  model_dir=model_dir,
modules/whisper/whisper_base.py CHANGED
@@ -19,17 +19,10 @@ from modules.vad.silero_vad import SileroVAD
19
 
20
  class WhisperBase(ABC):
21
  def __init__(self,
22
- model_dir: Optional[str] = None,
23
- diarization_model_dir: Optional[str] = None,
24
- output_dir: Optional[str] = None,
25
  ):
26
- if model_dir is None:
27
- model_dir = os.path.join("models", "Whisper")
28
- if diarization_model_dir is None:
29
- diarization_model_dir = os.path.join("models", "Diarization")
30
- if output_dir is None:
31
- output_dir = os.path.join("outputs")
32
-
33
  self.model_dir = model_dir
34
  self.output_dir = output_dir
35
  os.makedirs(self.output_dir, exist_ok=True)
 
19
 
20
  class WhisperBase(ABC):
21
  def __init__(self,
22
+ model_dir: str = os.path.join("models", "Whisper"),
23
+ diarization_model_dir: str = os.path.join("models", "Diarization"),
24
+ output_dir: str = os.path.join("outputs"),
25
  ):
 
 
 
 
 
 
 
26
  self.model_dir = model_dir
27
  self.output_dir = output_dir
28
  os.makedirs(self.output_dir, exist_ok=True)
modules/whisper/whisper_factory.py CHANGED
@@ -11,11 +11,11 @@ class WhisperFactory:
11
  @staticmethod
12
  def create_whisper_inference(
13
  whisper_type: str,
14
- whisper_model_dir: Optional[str] = None,
15
- faster_whisper_model_dir: Optional[str] = None,
16
- insanely_fast_whisper_model_dir: Optional[str] = None,
17
- diarization_model_dir: Optional[str] = None,
18
- output_dir: Optional[str] = None,
19
  ) -> "WhisperBase":
20
  """
21
  Create a whisper inference class based on the provided whisper_type.
 
11
  @staticmethod
12
  def create_whisper_inference(
13
  whisper_type: str,
14
+ whisper_model_dir: str = os.path.join("models", "Whisper"),
15
+ faster_whisper_model_dir: str = os.path.join("models", "Whisper", "faster-whisper"),
16
+ insanely_fast_whisper_model_dir: str = os.path.join("models", "Whisper", "insanely-fast-whisper"),
17
+ diarization_model_dir: str = os.path.join("models", "Diarization"),
18
+ output_dir: str = os.path.join("outputs"),
19
  ) -> "WhisperBase":
20
  """
21
  Create a whisper inference class based on the provided whisper_type.