jhj0517 commited on
Commit
a230be5
·
unverified ·
2 Parent(s): 0450240 94904d8

Merge pull request #178 from jhj0517/feature/add-output-dir

Browse files
app.py CHANGED
@@ -19,25 +19,38 @@ class App:
19
  self.whisper_inf = self.init_whisper()
20
  print(f"Use \"{self.args.whisper_type}\" implementation")
21
  print(f"Device \"{self.whisper_inf.device}\" is detected")
22
- self.nllb_inf = NLLBInference()
23
- self.deepl_api = DeepLAPI()
 
 
 
 
 
24
 
25
  def init_whisper(self):
26
  whisper_type = self.args.whisper_type.lower().strip()
27
 
28
  if whisper_type in ["faster_whisper", "faster-whisper", "fasterwhisper"]:
29
- whisper_inf = FasterWhisperInference()
30
- whisper_inf.model_dir = self.args.faster_whisper_model_dir
 
 
31
  elif whisper_type in ["whisper"]:
32
- whisper_inf = WhisperInference()
33
- whisper_inf.model_dir = self.args.whisper_model_dir
 
 
34
  elif whisper_type in ["insanely_fast_whisper", "insanely-fast-whisper", "insanelyfastwhisper",
35
  "insanely_faster_whisper", "insanely-faster-whisper", "insanelyfasterwhisper"]:
36
- whisper_inf = InsanelyFastWhisperInference()
37
- whisper_inf.model_dir = self.args.insanely_fast_whisper_model_dir
 
 
38
  else:
39
- whisper_inf = FasterWhisperInference()
40
- whisper_inf.model_dir = self.args.faster_whisper_model_dir
 
 
41
  return whisper_inf
42
 
43
  @staticmethod
@@ -366,7 +379,7 @@ class App:
366
 
367
  # Create the parser for command-line arguments
368
  parser = argparse.ArgumentParser()
369
- parser.add_argument('--whisper_type', type=str, default="faster-whisper", help='A type of the whisper implementation between: ["whisper", "faster-whisper"]')
370
  parser.add_argument('--share', type=bool, default=False, nargs='?', const=True, help='Gradio share value')
371
  parser.add_argument('--server_name', type=str, default=None, help='Gradio server host')
372
  parser.add_argument('--server_port', type=int, default=None, help='Gradio server port')
@@ -379,6 +392,8 @@ parser.add_argument('--api_open', type=bool, default=False, nargs='?', const=Tru
379
  parser.add_argument('--whisper_model_dir', type=str, default=os.path.join("models", "Whisper"), help='Directory path of the whisper model')
380
  parser.add_argument('--faster_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "faster-whisper"), help='Directory path of the faster-whisper model')
381
  parser.add_argument('--insanely_fast_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "insanely-fast-whisper"), help='Directory path of the insanely-fast-whisper model')
 
 
382
  _args = parser.parse_args()
383
 
384
  if __name__ == "__main__":
 
19
  self.whisper_inf = self.init_whisper()
20
  print(f"Use \"{self.args.whisper_type}\" implementation")
21
  print(f"Device \"{self.whisper_inf.device}\" is detected")
22
+ self.nllb_inf = NLLBInference(
23
+ model_dir=self.args.nllb_model_dir,
24
+ output_dir=self.args.output_dir
25
+ )
26
+ self.deepl_api = DeepLAPI(
27
+ output_dir=self.args.output_dir
28
+ )
29
 
30
  def init_whisper(self):
31
  whisper_type = self.args.whisper_type.lower().strip()
32
 
33
  if whisper_type in ["faster_whisper", "faster-whisper", "fasterwhisper"]:
34
+ whisper_inf = FasterWhisperInference(
35
+ model_dir=self.args.faster_whisper_model_dir,
36
+ output_dir=self.args.output_dir
37
+ )
38
  elif whisper_type in ["whisper"]:
39
+ whisper_inf = WhisperInference(
40
+ model_dir=self.args.whisper_model_dir,
41
+ output_dir=self.args.output_dir
42
+ )
43
  elif whisper_type in ["insanely_fast_whisper", "insanely-fast-whisper", "insanelyfastwhisper",
44
  "insanely_faster_whisper", "insanely-faster-whisper", "insanelyfasterwhisper"]:
45
+ whisper_inf = InsanelyFastWhisperInference(
46
+ model_dir=self.args.insanely_fast_whisper_model_dir,
47
+ output_dir=self.args.output_dir
48
+ )
49
  else:
50
+ whisper_inf = FasterWhisperInference(
51
+ model_dir=self.args.faster_whisper_model_dir,
52
+ output_dir=self.args.output_dir
53
+ )
54
  return whisper_inf
55
 
56
  @staticmethod
 
379
 
380
  # Create the parser for command-line arguments
381
  parser = argparse.ArgumentParser()
382
+ parser.add_argument('--whisper_type', type=str, default="faster-whisper", help='A type of the whisper implementation between: ["whisper", "faster-whisper", "insanely-fast-whisper"]')
383
  parser.add_argument('--share', type=bool, default=False, nargs='?', const=True, help='Gradio share value')
384
  parser.add_argument('--server_name', type=str, default=None, help='Gradio server host')
385
  parser.add_argument('--server_port', type=int, default=None, help='Gradio server port')
 
392
  parser.add_argument('--whisper_model_dir', type=str, default=os.path.join("models", "Whisper"), help='Directory path of the whisper model')
393
  parser.add_argument('--faster_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "faster-whisper"), help='Directory path of the faster-whisper model')
394
  parser.add_argument('--insanely_fast_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "insanely-fast-whisper"), help='Directory path of the insanely-fast-whisper model')
395
+ parser.add_argument('--nllb_model_dir', type=str, default=os.path.join("models", "NLLB"), help='Directory path of the Facebook NLLB model')
396
+ parser.add_argument('--output_dir', type=str, default=os.path.join("outputs"), help='Directory path of the outputs')
397
  _args = parser.parse_args()
398
 
399
  if __name__ == "__main__":
modules/deepl_api.py CHANGED
@@ -82,11 +82,14 @@ DEEPL_AVAILABLE_SOURCE_LANGS = {
82
 
83
 
84
  class DeepLAPI:
85
- def __init__(self):
 
 
86
  self.api_interval = 1
87
  self.max_text_batch_size = 50
88
  self.available_target_langs = DEEPL_AVAILABLE_TARGET_LANGS
89
  self.available_source_langs = DEEPL_AVAILABLE_SOURCE_LANGS
 
90
 
91
  def translate_deepl(self,
92
  auth_key: str,
@@ -111,6 +114,7 @@ class DeepLAPI:
111
  Boolean value that is about pro user or not from gr.Checkbox().
112
  progress: gr.Progress
113
  Indicator to show progress directly in gradio.
 
114
  Returns
115
  ----------
116
  A List of
@@ -140,7 +144,7 @@ class DeepLAPI:
140
  timestamp = datetime.now().strftime("%m%d%H%M%S")
141
 
142
  file_name = file_name[:-9]
143
- output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}.srt")
144
  write_file(subtitle, output_path)
145
 
146
  elif file_ext == ".vtt":
@@ -160,7 +164,7 @@ class DeepLAPI:
160
  timestamp = datetime.now().strftime("%m%d%H%M%S")
161
 
162
  file_name = file_name[:-9]
163
- output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}.vtt")
164
 
165
  write_file(subtitle, output_path)
166
 
 
82
 
83
 
84
  class DeepLAPI:
85
+ def __init__(self,
86
+ output_dir: str
87
+ ):
88
  self.api_interval = 1
89
  self.max_text_batch_size = 50
90
  self.available_target_langs = DEEPL_AVAILABLE_TARGET_LANGS
91
  self.available_source_langs = DEEPL_AVAILABLE_SOURCE_LANGS
92
+ self.output_dir = output_dir
93
 
94
  def translate_deepl(self,
95
  auth_key: str,
 
114
  Boolean value that is about pro user or not from gr.Checkbox().
115
  progress: gr.Progress
116
  Indicator to show progress directly in gradio.
117
+
118
  Returns
119
  ----------
120
  A List of
 
144
  timestamp = datetime.now().strftime("%m%d%H%M%S")
145
 
146
  file_name = file_name[:-9]
147
+ output_path = os.path.join(self.output_dir, "translations", f"{file_name}-{timestamp}.srt")
148
  write_file(subtitle, output_path)
149
 
150
  elif file_ext == ".vtt":
 
164
  timestamp = datetime.now().strftime("%m%d%H%M%S")
165
 
166
  file_name = file_name[:-9]
167
+ output_path = os.path.join(self.output_dir, "translations", f"{file_name}-{timestamp}.vtt")
168
 
169
  write_file(subtitle, output_path)
170
 
modules/faster_whisper_inference.py CHANGED
@@ -17,9 +17,13 @@ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
17
 
18
 
19
  class FasterWhisperInference(WhisperBase):
20
- def __init__(self):
 
 
 
21
  super().__init__(
22
- model_dir=os.path.join("models", "Whisper", "faster-whisper")
 
23
  )
24
  self.model_paths = self.get_model_paths()
25
  self.available_models = self.model_paths.keys()
 
17
 
18
 
19
  class FasterWhisperInference(WhisperBase):
20
+ def __init__(self,
21
+ model_dir: str,
22
+ output_dir: str
23
+ ):
24
  super().__init__(
25
+ model_dir=model_dir,
26
+ output_dir=output_dir
27
  )
28
  self.model_paths = self.get_model_paths()
29
  self.available_models = self.model_paths.keys()
modules/insanely_fast_whisper_inference.py CHANGED
@@ -15,9 +15,13 @@ from modules.whisper_base import WhisperBase
15
 
16
 
17
  class InsanelyFastWhisperInference(WhisperBase):
18
- def __init__(self):
 
 
 
19
  super().__init__(
20
- model_dir=os.path.join("models", "Whisper", "insanely_fast_whisper")
 
21
  )
22
  openai_models = whisper.available_models()
23
  distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
 
15
 
16
 
17
  class InsanelyFastWhisperInference(WhisperBase):
18
+ def __init__(self,
19
+ model_dir: str,
20
+ output_dir: str
21
+ ):
22
  super().__init__(
23
+ model_dir=model_dir,
24
+ output_dir=output_dir
25
  )
26
  openai_models = whisper.available_models()
27
  distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
modules/nllb_inference.py CHANGED
@@ -6,9 +6,13 @@ from modules.translation_base import TranslationBase
6
 
7
 
8
  class NLLBInference(TranslationBase):
9
- def __init__(self):
 
 
 
10
  super().__init__(
11
- model_dir=os.path.join("models", "NLLB")
 
12
  )
13
  self.tokenizer = None
14
  self.available_models = ["facebook/nllb-200-3.3B", "facebook/nllb-200-1.3B", "facebook/nllb-200-distilled-600M"]
 
6
 
7
 
8
  class NLLBInference(TranslationBase):
9
+ def __init__(self,
10
+ model_dir: str,
11
+ output_dir: str
12
+ ):
13
  super().__init__(
14
+ model_dir=model_dir,
15
+ output_dir=output_dir
16
  )
17
  self.tokenizer = None
18
  self.available_models = ["facebook/nllb-200-3.3B", "facebook/nllb-200-1.3B", "facebook/nllb-200-distilled-600M"]
modules/translation_base.py CHANGED
@@ -11,11 +11,14 @@ from modules.subtitle_manager import *
11
 
12
  class TranslationBase(ABC):
13
  def __init__(self,
14
- model_dir: str):
 
15
  super().__init__()
16
  self.model = None
17
  self.model_dir = model_dir
 
18
  os.makedirs(self.model_dir, exist_ok=True)
 
19
  self.current_model_size = None
20
  self.device = self.get_device()
21
 
@@ -87,7 +90,7 @@ class TranslationBase(ABC):
87
 
88
  timestamp = datetime.now().strftime("%m%d%H%M%S")
89
  if add_timestamp:
90
- output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}")
91
  else:
92
  output_path = os.path.join("outputs", "translations", f"{file_name}.srt")
93
 
@@ -102,9 +105,9 @@ class TranslationBase(ABC):
102
 
103
  timestamp = datetime.now().strftime("%m%d%H%M%S")
104
  if add_timestamp:
105
- output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}")
106
  else:
107
- output_path = os.path.join("outputs", "translations", f"{file_name}.vtt")
108
 
109
  write_file(subtitle, output_path)
110
  files_info[file_name] = subtitle
 
11
 
12
  class TranslationBase(ABC):
13
  def __init__(self,
14
+ model_dir: str,
15
+ output_dir: str):
16
  super().__init__()
17
  self.model = None
18
  self.model_dir = model_dir
19
+ self.output_dir = output_dir
20
  os.makedirs(self.model_dir, exist_ok=True)
21
+ os.makedirs(self.output_dir, exist_ok=True)
22
  self.current_model_size = None
23
  self.device = self.get_device()
24
 
 
90
 
91
  timestamp = datetime.now().strftime("%m%d%H%M%S")
92
  if add_timestamp:
93
+ output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}.srt")
94
  else:
95
  output_path = os.path.join("outputs", "translations", f"{file_name}.srt")
96
 
 
105
 
106
  timestamp = datetime.now().strftime("%m%d%H%M%S")
107
  if add_timestamp:
108
+ output_path = os.path.join(self.output_dir, "translations", f"{file_name}-{timestamp}.vtt")
109
  else:
110
+ output_path = os.path.join(self.output_dir, "translations", f"{file_name}.vtt")
111
 
112
  write_file(subtitle, output_path)
113
  files_info[file_name] = subtitle
modules/whisper_Inference.py CHANGED
@@ -11,9 +11,13 @@ from modules.whisper_parameter import *
11
 
12
 
13
  class WhisperInference(WhisperBase):
14
- def __init__(self):
 
 
 
15
  super().__init__(
16
- model_dir=os.path.join("models", "Whisper")
 
17
  )
18
 
19
  def transcribe(self,
 
11
 
12
 
13
  class WhisperInference(WhisperBase):
14
+ def __init__(self,
15
+ model_dir: str,
16
+ output_dir: str
17
+ ):
18
  super().__init__(
19
+ model_dir=model_dir,
20
+ output_dir=output_dir
21
  )
22
 
23
  def transcribe(self,
modules/whisper_base.py CHANGED
@@ -15,10 +15,14 @@ from modules.whisper_parameter import *
15
 
16
  class WhisperBase(ABC):
17
  def __init__(self,
18
- model_dir: str):
 
 
19
  self.model = None
20
  self.current_model_size = None
21
  self.model_dir = model_dir
 
 
22
  os.makedirs(self.model_dir, exist_ok=True)
23
  self.available_models = whisper.available_models()
24
  self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
@@ -88,7 +92,8 @@ class WhisperBase(ABC):
88
  file_name=file_name,
89
  transcribed_segments=transcribed_segments,
90
  add_timestamp=add_timestamp,
91
- file_format=file_format
 
92
  )
93
  files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path}
94
 
@@ -152,7 +157,8 @@ class WhisperBase(ABC):
152
  file_name="Mic",
153
  transcribed_segments=transcribed_segments,
154
  add_timestamp=True,
155
- file_format=file_format
 
156
  )
157
 
158
  result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
@@ -211,7 +217,8 @@ class WhisperBase(ABC):
211
  file_name=file_name,
212
  transcribed_segments=transcribed_segments,
213
  add_timestamp=add_timestamp,
214
- file_format=file_format
 
215
  )
216
  result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
217
 
@@ -237,6 +244,7 @@ class WhisperBase(ABC):
237
  transcribed_segments: list,
238
  add_timestamp: bool,
239
  file_format: str,
 
240
  ) -> str:
241
  """
242
  Writes subtitle file
@@ -251,6 +259,8 @@ class WhisperBase(ABC):
251
  Determines whether to add a timestamp to the end of the filename.
252
  file_format: str
253
  File format to write. Supported formats: [SRT, WebVTT, txt]
 
 
254
 
255
  Returns
256
  ----------
@@ -261,9 +271,9 @@ class WhisperBase(ABC):
261
  """
262
  timestamp = datetime.now().strftime("%m%d%H%M%S")
263
  if add_timestamp:
264
- output_path = os.path.join("outputs", f"{file_name}-{timestamp}")
265
  else:
266
- output_path = os.path.join("outputs", f"{file_name}")
267
 
268
  if file_format == "SRT":
269
  content = get_srt(transcribed_segments)
 
15
 
16
  class WhisperBase(ABC):
17
  def __init__(self,
18
+ model_dir: str,
19
+ output_dir: str
20
+ ):
21
  self.model = None
22
  self.current_model_size = None
23
  self.model_dir = model_dir
24
+ self.output_dir = output_dir
25
+ os.makedirs(self.output_dir, exist_ok=True)
26
  os.makedirs(self.model_dir, exist_ok=True)
27
  self.available_models = whisper.available_models()
28
  self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
 
92
  file_name=file_name,
93
  transcribed_segments=transcribed_segments,
94
  add_timestamp=add_timestamp,
95
+ file_format=file_format,
96
+ output_dir=self.output_dir
97
  )
98
  files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path}
99
 
 
157
  file_name="Mic",
158
  transcribed_segments=transcribed_segments,
159
  add_timestamp=True,
160
+ file_format=file_format,
161
+ output_dir=self.output_dir
162
  )
163
 
164
  result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
 
217
  file_name=file_name,
218
  transcribed_segments=transcribed_segments,
219
  add_timestamp=add_timestamp,
220
+ file_format=file_format,
221
+ output_dir=self.output_dir
222
  )
223
  result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
224
 
 
244
  transcribed_segments: list,
245
  add_timestamp: bool,
246
  file_format: str,
247
+ output_dir: str
248
  ) -> str:
249
  """
250
  Writes subtitle file
 
259
  Determines whether to add a timestamp to the end of the filename.
260
  file_format: str
261
  File format to write. Supported formats: [SRT, WebVTT, txt]
262
+ output_dir: str
263
+ Directory path of the output
264
 
265
  Returns
266
  ----------
 
271
  """
272
  timestamp = datetime.now().strftime("%m%d%H%M%S")
273
  if add_timestamp:
274
+ output_path = os.path.join(output_dir, f"{file_name}-{timestamp}")
275
  else:
276
+ output_path = os.path.join(output_dir, f"{file_name}")
277
 
278
  if file_format == "SRT":
279
  content = get_srt(transcribed_segments)