jhj0517 committed
Commit 825f362 · unverified · Parents: 48382b6 2db409c

Merge pull request #217 from jhj0517/feature/use-factory-pattern

Files changed (2)
  1. app.py +7 -35
  2. modules/whisper/whisper_factory.py +60 -0
app.py CHANGED
@@ -2,7 +2,7 @@ import os
 import argparse
 import gradio as gr
 
-from modules.whisper.whisper_Inference import WhisperInference
+from modules.whisper.whisper_factory import WhisperFactory
 from modules.whisper.faster_whisper_inference import FasterWhisperInference
 from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
 from modules.translation.nllb_inference import NLLBInference
@@ -16,7 +16,12 @@ class App:
     def __init__(self, args):
         self.args = args
         self.app = gr.Blocks(css=CSS, theme=self.args.theme)
-        self.whisper_inf = self.init_whisper()
+        self.whisper_inf = WhisperFactory.create_whisper_inference(
+            whisper_type=self.args.whisper_type,
+            model_dir=self.args.faster_whisper_model_dir,
+            output_dir=self.args.output_dir,
+            args=self.args
+        )
         print(f"Use \"{self.args.whisper_type}\" implementation")
         print(f"Device \"{self.whisper_inf.device}\" is detected")
         self.nllb_inf = NLLBInference(
@@ -27,39 +32,6 @@ class App:
             output_dir=os.path.join(self.args.output_dir, "translations")
         )
 
-    def init_whisper(self):
-        # Temporal fix of the issue : https://github.com/jhj0517/Whisper-WebUI/issues/144
-        os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
-
-        whisper_type = self.args.whisper_type.lower().strip()
-
-        if whisper_type in ["faster_whisper", "faster-whisper", "fasterwhisper"]:
-            whisper_inf = FasterWhisperInference(
-                model_dir=self.args.faster_whisper_model_dir,
-                output_dir=self.args.output_dir,
-                args=self.args
-            )
-        elif whisper_type in ["whisper"]:
-            whisper_inf = WhisperInference(
-                model_dir=self.args.whisper_model_dir,
-                output_dir=self.args.output_dir,
-                args=self.args
-            )
-        elif whisper_type in ["insanely_fast_whisper", "insanely-fast-whisper", "insanelyfastwhisper",
-                              "insanely_faster_whisper", "insanely-faster-whisper", "insanelyfasterwhisper"]:
-            whisper_inf = InsanelyFastWhisperInference(
-                model_dir=self.args.insanely_fast_whisper_model_dir,
-                output_dir=self.args.output_dir,
-                args=self.args
-            )
-        else:
-            whisper_inf = FasterWhisperInference(
-                model_dir=self.args.faster_whisper_model_dir,
-                output_dir=self.args.output_dir,
-                args=self.args
-            )
-        return whisper_inf
-
     def create_whisper_parameters(self):
         with gr.Row():
             dd_model = gr.Dropdown(choices=self.whisper_inf.available_models, value="large-v2",
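
For reference, a minimal standalone sketch of the call that now replaces init_whisper (the factory itself is the new file shown below). The argparse flag names here are assumptions inferred from the attributes app.py reads (args.whisper_type, args.faster_whisper_model_dir, args.output_dir); the real inference classes read further options from args, so this only illustrates the dispatch, not a full launch.

import argparse

from modules.whisper.whisper_factory import WhisperFactory

# Hypothetical flags mirroring the attributes used in App.__init__ above.
parser = argparse.ArgumentParser()
parser.add_argument("--whisper_type", type=str, default="faster-whisper")
parser.add_argument("--faster_whisper_model_dir", type=str, default="models/Whisper/faster-whisper")
parser.add_argument("--output_dir", type=str, default="outputs")
args = parser.parse_args()

# Same call as in App.__init__: the factory resolves whisper_type to an
# implementation and instantiates it with the given model/output directories.
whisper_inf = WhisperFactory.create_whisper_inference(
    whisper_type=args.whisper_type,
    model_dir=args.faster_whisper_model_dir,
    output_dir=args.output_dir,
    args=args
)
print(f'Use "{args.whisper_type}" implementation')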
modules/whisper/whisper_factory.py ADDED
@@ -0,0 +1,60 @@
+from argparse import Namespace
+import os
+
+from modules.whisper.faster_whisper_inference import FasterWhisperInference
+from modules.whisper.whisper_Inference import WhisperInference
+from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
+from modules.whisper.whisper_base import WhisperBase
+
+
+class WhisperFactory:
+    @staticmethod
+    def create_whisper_inference(
+        whisper_type: str,
+        model_dir: str,
+        output_dir: str,
+        args: Namespace
+    ) -> "WhisperBase":
+        """
+        Create a whisper inference class based on the provided whisper_type.
+
+        Parameters
+        ----------
+        whisper_type: str
+            The whisper implementation to use. Supported values are:
+            - "faster-whisper"
+            - "whisper"
+            - "insanely-fast-whisper", "insanely_fast_whisper", "insanelyfastwhisper",
+              "insanely-faster-whisper", "insanely_faster_whisper", "insanelyfasterwhisper"
+        model_dir: str
+            The directory path where the whisper model is located.
+        output_dir: str
+            The directory path where the output files will be saved.
+        args: Namespace
+            Additional arguments to be passed to the whisper inference object.
+
+        Returns
+        -------
+        WhisperBase
+            An instance of the appropriate whisper inference class based on the whisper_type.
+        """
+        # Temporary fix for the bug: https://github.com/jhj0517/Whisper-WebUI/issues/144
+        os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
+
+        whisper_type = whisper_type.lower().strip()
+
+        faster_whisper_typos = ["faster_whisper", "faster-whisper", "fasterwhisper"]
+        whisper_typos = ["whisper"]
+        insanely_fast_whisper_typos = [
+            "insanely_fast_whisper", "insanely-fast-whisper", "insanelyfastwhisper",
+            "insanely_faster_whisper", "insanely-faster-whisper", "insanelyfasterwhisper"
+        ]
+
+        if whisper_type in faster_whisper_typos:
+            return FasterWhisperInference(model_dir, output_dir, args)
+        elif whisper_type in whisper_typos:
+            return WhisperInference(model_dir, output_dir, args)
+        elif whisper_type in insanely_fast_whisper_typos:
+            return InsanelyFastWhisperInference(model_dir, output_dir, args)
+        else:
+            return FasterWhisperInference(model_dir, output_dir, args)
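
Note how the lookup behaves: whisper_type is lowercased and stripped, hyphen/underscore/concatenated spellings are all accepted, and any unrecognized value silently falls back to FasterWhisperInference, matching the default branch of the old init_whisper. A small illustration of that fallback, assuming a bare Namespace as a stand-in for the parsed CLI args (the real inference classes read more fields from it and need their backend dependencies installed):

from argparse import Namespace

from modules.whisper.whisper_factory import WhisperFactory

args = Namespace()  # stand-in for the parsed CLI args used in app.py

# An unknown backend name does not raise; the factory falls back to faster-whisper.
inf = WhisperFactory.create_whisper_inference(
    whisper_type="Some-Unknown-Backend",
    model_dir="models",
    output_dir="outputs",
    args=args
)
print(type(inf).__name__)  # expected: FasterWhisperInference

Falling back rather than raising keeps the previous default behavior, at the cost of hiding typos in the whisper_type value.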