jhj0517 commited on
Commit
5633565
·
1 Parent(s): b2bb752

add output_dir arg

Browse files
app.py CHANGED
@@ -27,25 +27,25 @@ class App:
27
 
28
  if whisper_type in ["faster_whisper", "faster-whisper", "fasterwhisper"]:
29
  whisper_inf = FasterWhisperInference(
30
- model_dir=self.args.faster_whisper_model_dir
 
31
  )
32
- whisper_inf.model_dir = self.args.faster_whisper_model_dir
33
  elif whisper_type in ["whisper"]:
34
  whisper_inf = WhisperInference(
35
- model_dir=self.args.whisper_model_dir
 
36
  )
37
- whisper_inf.model_dir = self.args.whisper_model_dir
38
  elif whisper_type in ["insanely_fast_whisper", "insanely-fast-whisper", "insanelyfastwhisper",
39
  "insanely_faster_whisper", "insanely-faster-whisper", "insanelyfasterwhisper"]:
40
  whisper_inf = InsanelyFastWhisperInference(
41
- model_dir=self.args.insanely_fast_whisper_model_dir
 
42
  )
43
- whisper_inf.model_dir = self.args.insanely_fast_whisper_model_dir
44
  else:
45
  whisper_inf = FasterWhisperInference(
46
- model_dir=self.args.faster_whisper_model_dir
 
47
  )
48
- whisper_inf.model_dir = self.args.faster_whisper_model_dir
49
  return whisper_inf
50
 
51
  @staticmethod
@@ -387,6 +387,7 @@ parser.add_argument('--api_open', type=bool, default=False, nargs='?', const=Tru
387
  parser.add_argument('--whisper_model_dir', type=str, default=os.path.join("models", "Whisper"), help='Directory path of the whisper model')
388
  parser.add_argument('--faster_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "faster-whisper"), help='Directory path of the faster-whisper model')
389
  parser.add_argument('--insanely_fast_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "insanely-fast-whisper"), help='Directory path of the insanely-fast-whisper model')
 
390
  _args = parser.parse_args()
391
 
392
  if __name__ == "__main__":
 
27
 
28
  if whisper_type in ["faster_whisper", "faster-whisper", "fasterwhisper"]:
29
  whisper_inf = FasterWhisperInference(
30
+ model_dir=self.args.faster_whisper_model_dir,
31
+ output_dir=self.args.output_dir
32
  )
 
33
  elif whisper_type in ["whisper"]:
34
  whisper_inf = WhisperInference(
35
+ model_dir=self.args.whisper_model_dir,
36
+ output_dir=self.args.output_dir
37
  )
 
38
  elif whisper_type in ["insanely_fast_whisper", "insanely-fast-whisper", "insanelyfastwhisper",
39
  "insanely_faster_whisper", "insanely-faster-whisper", "insanelyfasterwhisper"]:
40
  whisper_inf = InsanelyFastWhisperInference(
41
+ model_dir=self.args.insanely_fast_whisper_model_dir,
42
+ output_dir=self.args.output_dir
43
  )
 
44
  else:
45
  whisper_inf = FasterWhisperInference(
46
+ model_dir=self.args.faster_whisper_model_dir,
47
+ output_dir=self.args.output_dir
48
  )
 
49
  return whisper_inf
50
 
51
  @staticmethod
 
387
  parser.add_argument('--whisper_model_dir', type=str, default=os.path.join("models", "Whisper"), help='Directory path of the whisper model')
388
  parser.add_argument('--faster_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "faster-whisper"), help='Directory path of the faster-whisper model')
389
  parser.add_argument('--insanely_fast_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "insanely-fast-whisper"), help='Directory path of the insanely-fast-whisper model')
390
+ parser.add_argument('--output_dir', type=str, default=os.path.join("outputs"), help='Directory path of the outputs')
391
  _args = parser.parse_args()
392
 
393
  if __name__ == "__main__":
modules/faster_whisper_inference.py CHANGED
@@ -18,10 +18,12 @@ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
18
 
19
  class FasterWhisperInference(WhisperBase):
20
  def __init__(self,
21
- model_dir: str
 
22
  ):
23
  super().__init__(
24
- model_dir=model_dir
 
25
  )
26
  self.model_paths = self.get_model_paths()
27
  self.available_models = self.model_paths.keys()
 
18
 
19
  class FasterWhisperInference(WhisperBase):
20
  def __init__(self,
21
+ model_dir: str,
22
+ output_dir: str
23
  ):
24
  super().__init__(
25
+ model_dir=model_dir,
26
+ output_dir=output_dir
27
  )
28
  self.model_paths = self.get_model_paths()
29
  self.available_models = self.model_paths.keys()
modules/insanely_fast_whisper_inference.py CHANGED
@@ -16,10 +16,12 @@ from modules.whisper_base import WhisperBase
16
 
17
  class InsanelyFastWhisperInference(WhisperBase):
18
  def __init__(self,
19
- model_dir: str
 
20
  ):
21
  super().__init__(
22
- model_dir=model_dir
 
23
  )
24
  openai_models = whisper.available_models()
25
  distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
 
16
 
17
  class InsanelyFastWhisperInference(WhisperBase):
18
  def __init__(self,
19
+ model_dir: str,
20
+ output_dir: str
21
  ):
22
  super().__init__(
23
+ model_dir=model_dir,
24
+ output_dir=output_dir
25
  )
26
  openai_models = whisper.available_models()
27
  distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
modules/whisper_Inference.py CHANGED
@@ -12,10 +12,12 @@ from modules.whisper_parameter import *
12
 
13
  class WhisperInference(WhisperBase):
14
  def __init__(self,
15
- model_dir: str
 
16
  ):
17
  super().__init__(
18
- model_dir=model_dir
 
19
  )
20
 
21
  def transcribe(self,
 
12
 
13
  class WhisperInference(WhisperBase):
14
  def __init__(self,
15
+ model_dir: str,
16
+ output_dir: str
17
  ):
18
  super().__init__(
19
+ model_dir=model_dir,
20
+ output_dir=output_dir
21
  )
22
 
23
  def transcribe(self,
modules/whisper_base.py CHANGED
@@ -15,10 +15,14 @@ from modules.whisper_parameter import *
15
 
16
  class WhisperBase(ABC):
17
  def __init__(self,
18
- model_dir: str):
 
 
19
  self.model = None
20
  self.current_model_size = None
21
  self.model_dir = model_dir
 
 
22
  os.makedirs(self.model_dir, exist_ok=True)
23
  self.available_models = whisper.available_models()
24
  self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
@@ -88,7 +92,8 @@ class WhisperBase(ABC):
88
  file_name=file_name,
89
  transcribed_segments=transcribed_segments,
90
  add_timestamp=add_timestamp,
91
- file_format=file_format
 
92
  )
93
  files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path}
94
 
@@ -152,7 +157,8 @@ class WhisperBase(ABC):
152
  file_name="Mic",
153
  transcribed_segments=transcribed_segments,
154
  add_timestamp=True,
155
- file_format=file_format
 
156
  )
157
 
158
  result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
@@ -211,7 +217,8 @@ class WhisperBase(ABC):
211
  file_name=file_name,
212
  transcribed_segments=transcribed_segments,
213
  add_timestamp=add_timestamp,
214
- file_format=file_format
 
215
  )
216
  result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
217
 
@@ -237,6 +244,7 @@ class WhisperBase(ABC):
237
  transcribed_segments: list,
238
  add_timestamp: bool,
239
  file_format: str,
 
240
  ) -> str:
241
  """
242
  Writes subtitle file
@@ -251,6 +259,8 @@ class WhisperBase(ABC):
251
  Determines whether to add a timestamp to the end of the filename.
252
  file_format: str
253
  File format to write. Supported formats: [SRT, WebVTT, txt]
 
 
254
 
255
  Returns
256
  ----------
@@ -261,9 +271,9 @@ class WhisperBase(ABC):
261
  """
262
  timestamp = datetime.now().strftime("%m%d%H%M%S")
263
  if add_timestamp:
264
- output_path = os.path.join("outputs", f"{file_name}-{timestamp}")
265
  else:
266
- output_path = os.path.join("outputs", f"{file_name}")
267
 
268
  if file_format == "SRT":
269
  content = get_srt(transcribed_segments)
 
15
 
16
  class WhisperBase(ABC):
17
  def __init__(self,
18
+ model_dir: str,
19
+ output_dir: str
20
+ ):
21
  self.model = None
22
  self.current_model_size = None
23
  self.model_dir = model_dir
24
+ self.output_dir = output_dir
25
+ os.makedirs(self.output_dir, exist_ok=True)
26
  os.makedirs(self.model_dir, exist_ok=True)
27
  self.available_models = whisper.available_models()
28
  self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
 
92
  file_name=file_name,
93
  transcribed_segments=transcribed_segments,
94
  add_timestamp=add_timestamp,
95
+ file_format=file_format,
96
+ output_dir=self.output_dir
97
  )
98
  files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path}
99
 
 
157
  file_name="Mic",
158
  transcribed_segments=transcribed_segments,
159
  add_timestamp=True,
160
+ file_format=file_format,
161
+ output_dir=self.output_dir
162
  )
163
 
164
  result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
 
217
  file_name=file_name,
218
  transcribed_segments=transcribed_segments,
219
  add_timestamp=add_timestamp,
220
+ file_format=file_format,
221
+ output_dir=self.output_dir
222
  )
223
  result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
224
 
 
244
  transcribed_segments: list,
245
  add_timestamp: bool,
246
  file_format: str,
247
+ output_dir: str
248
  ) -> str:
249
  """
250
  Writes subtitle file
 
259
  Determines whether to add a timestamp to the end of the filename.
260
  file_format: str
261
  File format to write. Supported formats: [SRT, WebVTT, txt]
262
+ output_dir: str
263
+ Directory path of the output
264
 
265
  Returns
266
  ----------
 
271
  """
272
  timestamp = datetime.now().strftime("%m%d%H%M%S")
273
  if add_timestamp:
274
+ output_path = os.path.join(output_dir, f"{file_name}-{timestamp}")
275
  else:
276
+ output_path = os.path.join(output_dir, f"{file_name}")
277
 
278
  if file_format == "SRT":
279
  content = get_srt(transcribed_segments)