jhj0517 committed
Commit 4b52dfd · 1 Parent(s): a526073

refactor docstring

Files changed (1)
  1. modules/faster_whisper_inference.py +68 -24
modules/faster_whisper_inference.py CHANGED
@@ -32,7 +32,7 @@ class FasterWhisperInference(BaseInterface):
         self.default_beam_size = 1
 
     def transcribe_file(self,
-                        fileobjs: list,
+                        files: list,
                         file_format: str,
                         add_timestamp: bool,
                         progress=gr.Progress(),
@@ -43,7 +43,7 @@ class FasterWhisperInference(BaseInterface):
 
         Parameters
         ----------
-        fileobjs: list
+        files: list
             List of files to transcribe from gr.Files()
         file_format: str
             Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
@@ -56,20 +56,21 @@ class FasterWhisperInference(BaseInterface):
 
         Returns
         ----------
-        A List of
-        String to return to gr.Textbox()
-        Files to return to gr.Files()
+        result_str:
+            Result of transcription to return to gr.Textbox()
+        result_file_path:
+            Output file path to return to gr.Files()
         """
         try:
             files_info = {}
-            for fileobj in fileobjs:
+            for file in files:
                 transcribed_segments, time_for_task = self.transcribe(
-                    fileobj.name,
+                    file.name,
                     progress,
                     *whisper_params,
                 )
 
-                file_name, file_ext = os.path.splitext(os.path.basename(fileobj.name))
+                file_name, file_ext = os.path.splitext(os.path.basename(file.name))
                 file_name = safe_filename(file_name)
                 subtitle, file_path = self.generate_and_write_file(
                     file_name=file_name,
@@ -96,8 +97,8 @@ class FasterWhisperInference(BaseInterface):
             print(f"Error transcribing file on line {e}")
         finally:
             self.release_cuda_memory()
-            if not fileobjs:
-                self.remove_input_files([fileobj.name for fileobj in fileobjs])
+            if not files:
+                self.remove_input_files([file.name for file in files])
 
     def transcribe_youtube(self,
                            youtube_link: str,
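The renamed return values in the hunks above line up with the Gradio output components named in the new docstring: result_str is shown in a gr.Textbox and result_file_path is offered through gr.Files. A minimal wiring sketch, not part of this commit — component names and constructor arguments are assumptions, and the extra *whisper_params inputs used by the real UI are omitted:

import gradio as gr

from modules.faster_whisper_inference import FasterWhisperInference

whisper_inf = FasterWhisperInference()  # assumed no-argument constructor

with gr.Blocks() as demo:
    input_files = gr.Files(label="Upload files to transcribe")
    dd_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt"], value="SRT", label="File format")
    cb_timestamp = gr.Checkbox(value=True, label="Add timestamp to filename")
    btn_run = gr.Button("Generate subtitle file")
    tb_indicator = gr.Textbox(label="Output")        # receives result_str
    files_subtitles = gr.Files(label="Output file")  # receives result_file_path

    # transcribe_file returns [result_str, result_file_path], matching the `outputs` order.
    btn_run.click(fn=whisper_inf.transcribe_file,
                  inputs=[input_files, dd_format, cb_timestamp],
                  outputs=[tb_indicator, files_subtitles])

demo.launch()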
@@ -124,9 +125,10 @@ class FasterWhisperInference(BaseInterface):
 
         Returns
         ----------
-        A List of
-        String to return to gr.Textbox()
-        Files to return to gr.Files()
+        result_str:
+            Result of transcription to return to gr.Textbox()
+        result_file_path:
+            Output file path to return to gr.Files()
         """
         try:
             progress(0, desc="Loading Audio from Youtube..")
@@ -142,15 +144,15 @@ class FasterWhisperInference(BaseInterface):
             progress(1, desc="Completed!")
 
             file_name = safe_filename(yt.title)
-            subtitle, file_path = self.generate_and_write_file(
+            subtitle, result_file_path = self.generate_and_write_file(
                 file_name=file_name,
                 transcribed_segments=transcribed_segments,
                 add_timestamp=add_timestamp,
                 file_format=file_format
             )
-            gr_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
+            result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
 
-            return [gr_str, file_path]
+            return [result_str, result_file_path]
 
         except Exception as e:
             print(f"Error transcribing file on line {e}")
@@ -189,9 +191,10 @@ class FasterWhisperInference(BaseInterface):
 
         Returns
         ----------
-        A List of
-        String to return to gr.Textbox()
-        Files to return to gr.Files()
+        result_str:
+            Result of transcription to return to gr.Textbox()
+        result_file_path:
+            Output file path to return to gr.Files()
         """
         try:
             progress(0, desc="Loading Audio..")
@@ -202,15 +205,15 @@ class FasterWhisperInference(BaseInterface):
             )
             progress(1, desc="Completed!")
 
-            subtitle, file_path = self.generate_and_write_file(
+            subtitle, result_file_path = self.generate_and_write_file(
                 file_name="Mic",
                 transcribed_segments=transcribed_segments,
                 add_timestamp=True,
                 file_format=file_format
             )
 
-            gr_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
-            return [gr_str, file_path]
+            result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
+            return [result_str, result_file_path]
         except Exception as e:
             print(f"Error transcribing file on line {e}")
         finally:
@@ -282,7 +285,17 @@ class FasterWhisperInference(BaseInterface):
                      progress: gr.Progress
                      ):
         """
-        update current model setting
+        Update current model setting
+
+        Parameters
+        ----------
+        model_size: str
+            Size of whisper model
+        compute_type: str
+            Compute type for transcription.
+            see more info : https://opennmt.net/CTranslate2/quantization.html
+        progress: gr.Progress
+            Indicator to show progress directly in gradio.
         """
         progress(0, desc="Initializing Model..")
         self.current_model_size = model_size
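The compute_type documented in this hunk maps to CTranslate2 quantization (see the linked page). As a rough illustration, not taken from this commit, a faster-whisper model is typically created with an explicit compute type like this; the model size, device, and audio path are placeholder values:

from faster_whisper import WhisperModel

model = WhisperModel(
    "large-v2",              # placeholder model size
    device="cuda",           # or "cpu"
    compute_type="float16",  # quantization/precision handled by CTranslate2
)
segments, info = model.transcribe("audio.wav", beam_size=1)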
@@ -301,7 +314,26 @@ class FasterWhisperInference(BaseInterface):
                                 file_format: str,
                                 ) -> str:
         """
-        This method writes subtitle file and returns str to gr.Textbox
+        Writes subtitle file and returns str of content and output file path
+
+        Parameters
+        ----------
+        file_name: str
+            Output file name of the subtitle file
+        transcribed_segments: list
+            Transcribed segments from the Whisper model
+            to write into the subtitle file
+        add_timestamp: bool
+            Determines whether to add a timestamp to the end of the filename.
+        file_format: str
+            File format to write. Supported formats: [SRT, WebVTT, txt]
+
+        Returns
+        ----------
+        content: str
+            Result of the transcription
+        output_path: str
+            Output file path
         """
         timestamp = datetime.now().strftime("%m%d%H%M%S")
         if add_timestamp:
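The contract documented in this hunk — build an output path, optionally append a "%m%d%H%M%S" timestamp to the file name, write the subtitle content, and return both the content and the path — can be pictured with the sketch below. It is not the repository's implementation and, to stay short, it takes already-rendered subtitle text instead of raw segments:

import os
from datetime import datetime

def write_subtitle_sketch(file_name: str, content: str, add_timestamp: bool, file_format: str) -> tuple[str, str]:
    timestamp = datetime.now().strftime("%m%d%H%M%S")    # same format as in the diff
    if add_timestamp:
        file_name = f"{file_name}-{timestamp}"           # timestamp at the end of the filename
    extension = {"SRT": ".srt", "WebVTT": ".vtt", "txt": ".txt"}[file_format]
    os.makedirs("outputs", exist_ok=True)
    output_path = os.path.join("outputs", file_name + extension)
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(content)
    return content, output_path                          # (content, output_path), as in the docstring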
@@ -327,6 +359,18 @@ class FasterWhisperInference(BaseInterface):
 
     @staticmethod
     def format_time(elapsed_time: float) -> str:
+        """
+        Get {hours} {minutes} {seconds} time format string
+
+        Parameters
+        ----------
+        elapsed_time: float
+            Elapsed time for transcription
+
+        Returns
+        ----------
+        Time format string
+        """
         hours, rem = divmod(elapsed_time, 3600)
         minutes, seconds = divmod(rem, 60)
 
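The two divmod lines kept as context above split the elapsed seconds into hours, minutes, and seconds; a sketch of how they can be completed into the "{hours} {minutes} {seconds}" string the new docstring describes (the real method's exact wording and rounding may differ):

def format_time(elapsed_time: float) -> str:
    hours, rem = divmod(elapsed_time, 3600)
    minutes, seconds = divmod(rem, 60)
    parts = []
    if hours:
        parts.append(f"{int(hours)} hours")
    if minutes:
        parts.append(f"{int(minutes)} minutes")
    parts.append(f"{seconds:.2f} seconds")
    return " ".join(parts)

print(format_time(125.3))  # "2 minutes 5.30 seconds"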