jhj0517 committed
Commit a85dc3c · 1 Parent(s): 6a7425a

Add defaults to the function

Files changed (1)
  1. modules/whisper/whisper_base.py +13 -12
modules/whisper/whisper_base.py CHANGED
@@ -53,7 +53,7 @@ class WhisperBase(ABC):
     @abstractmethod
     def transcribe(self,
                    audio: Union[str, BinaryIO, np.ndarray],
-                   progress: gr.Progress,
+                   progress: gr.Progress = gr.Progress(),
                    *whisper_params,
                    ):
         """Inference whisper model to transcribe"""
@@ -63,7 +63,7 @@ class WhisperBase(ABC):
     def update_model(self,
                      model_size: str,
                      compute_type: str,
-                     progress: gr.Progress
+                     progress: gr.Progress = gr.Progress()
                      ):
         """Initialize whisper model"""
         pass
@@ -171,10 +171,10 @@ class WhisperBase(ABC):
         return result, elapsed_time
 
     def transcribe_file(self,
-                        files: list,
-                        input_folder_path: str,
-                        file_format: str,
-                        add_timestamp: bool,
+                        files: Optional[List] = None,
+                        input_folder_path: Optional[str] = None,
+                        file_format: str = "SRT",
+                        add_timestamp: bool = True,
                         progress=gr.Progress(),
                         *whisper_params,
                         ) -> list:
@@ -250,8 +250,8 @@ class WhisperBase(ABC):
 
     def transcribe_mic(self,
                        mic_audio: str,
-                       file_format: str,
-                       add_timestamp: bool,
+                       file_format: str = "SRT",
+                       add_timestamp: bool = True,
                        progress=gr.Progress(),
                        *whisper_params,
                        ) -> list:
@@ -306,8 +306,8 @@ class WhisperBase(ABC):
 
     def transcribe_youtube(self,
                            youtube_link: str,
-                           file_format: str,
-                           add_timestamp: bool,
+                           file_format: str = "SRT",
+                           add_timestamp: bool = True,
                            progress=gr.Progress(),
                            *whisper_params,
                            ) -> list:
@@ -411,11 +411,12 @@ class WhisperBase(ABC):
         else:
             output_path = os.path.join(output_dir, f"{file_name}")
 
-        if file_format == "SRT":
+        file_format = file_format.strip().lower()
+        if file_format == "srt":
             content = get_srt(transcribed_segments)
             output_path += '.srt'
 
-        elif file_format == "WebVTT":
+        elif file_format == "webvtt":
             content = get_vtt(transcribed_segments)
             output_path += '.vtt'

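To make the effect of this commit concrete, below is a small standalone sketch (not code from this repository). pick_output_path is a hypothetical stand-in for the extension-selection branch patched in the last hunk; it only mirrors the new default value and the strip().lower() normalization, assuming file_format arrives as a plain string.

# Standalone sketch, not code from the repo: it mirrors the defaults and the
# normalization added by this commit so their effect is easy to see.
# pick_output_path is a hypothetical stand-in for WhisperBase's real
# extension-selection branch.

def pick_output_path(file_name: str, file_format: str = "SRT") -> str:
    """Choose the subtitle extension the same way the patched branch does."""
    output_path = file_name
    file_format = file_format.strip().lower()  # normalization added by this commit
    if file_format == "srt":
        output_path += ".srt"
    elif file_format == "webvtt":
        output_path += ".vtt"
    return output_path

# The default kicks in when the caller omits file_format...
print(pick_output_path("output"))                          # output.srt
# ...and the normalization makes matching case- and whitespace-insensitive.
print(pick_output_path("output", file_format=" WebVTT "))  # output.vtt

With the matching defaults on transcribe_file, transcribe_mic, and transcribe_youtube, callers can now omit file_format and add_timestamp entirely and still get timestamped SRT output.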