jhj0517 commited on
Commit
29cce95
·
unverified ·
2 Parent(s): 6a7425a aa11c47

Merge pull request #290 from jhj0517/fix/defaults

Browse files
modules/translation/nllb_inference.py CHANGED
@@ -35,7 +35,7 @@ class NLLBInference(TranslationBase):
35
  model_size: str,
36
  src_lang: str,
37
  tgt_lang: str,
38
- progress: gr.Progress
39
  ):
40
  if model_size != self.current_model_size or self.model is None:
41
  print("\nInitializing NLLB Model..\n")
 
35
  model_size: str,
36
  src_lang: str,
37
  tgt_lang: str,
38
+ progress: gr.Progress = gr.Progress()
39
  ):
40
  if model_size != self.current_model_size or self.model is None:
41
  print("\nInitializing NLLB Model..\n")
modules/translation/translation_base.py CHANGED
@@ -37,7 +37,7 @@ class TranslationBase(ABC):
37
  model_size: str,
38
  src_lang: str,
39
  tgt_lang: str,
40
- progress: gr.Progress
41
  ):
42
  pass
43
 
 
37
  model_size: str,
38
  src_lang: str,
39
  tgt_lang: str,
40
+ progress: gr.Progress = gr.Progress()
41
  ):
42
  pass
43
 
modules/whisper/faster_whisper_inference.py CHANGED
@@ -40,7 +40,7 @@ class FasterWhisperInference(WhisperBase):
40
 
41
  def transcribe(self,
42
  audio: Union[str, BinaryIO, np.ndarray],
43
- progress: gr.Progress,
44
  *whisper_params,
45
  ) -> Tuple[List[dict], float]:
46
  """
@@ -126,7 +126,7 @@ class FasterWhisperInference(WhisperBase):
126
  def update_model(self,
127
  model_size: str,
128
  compute_type: str,
129
- progress: gr.Progress
130
  ):
131
  """
132
  Update current model setting
 
40
 
41
  def transcribe(self,
42
  audio: Union[str, BinaryIO, np.ndarray],
43
+ progress: gr.Progress = gr.Progress(),
44
  *whisper_params,
45
  ) -> Tuple[List[dict], float]:
46
  """
 
126
  def update_model(self,
127
  model_size: str,
128
  compute_type: str,
129
+ progress: gr.Progress = gr.Progress()
130
  ):
131
  """
132
  Update current model setting
modules/whisper/insanely_fast_whisper_inference.py CHANGED
@@ -39,7 +39,7 @@ class InsanelyFastWhisperInference(WhisperBase):
39
 
40
  def transcribe(self,
41
  audio: Union[str, np.ndarray, torch.Tensor],
42
- progress: gr.Progress,
43
  *whisper_params,
44
  ) -> Tuple[List[dict], float]:
45
  """
@@ -98,7 +98,7 @@ class InsanelyFastWhisperInference(WhisperBase):
98
  def update_model(self,
99
  model_size: str,
100
  compute_type: str,
101
- progress: gr.Progress,
102
  ):
103
  """
104
  Update current model setting
 
39
 
40
  def transcribe(self,
41
  audio: Union[str, np.ndarray, torch.Tensor],
42
+ progress: gr.Progress = gr.Progress(),
43
  *whisper_params,
44
  ) -> Tuple[List[dict], float]:
45
  """
 
98
  def update_model(self,
99
  model_size: str,
100
  compute_type: str,
101
+ progress: gr.Progress = gr.Progress(),
102
  ):
103
  """
104
  Update current model setting
modules/whisper/whisper_Inference.py CHANGED
@@ -28,7 +28,7 @@ class WhisperInference(WhisperBase):
28
 
29
  def transcribe(self,
30
  audio: Union[str, np.ndarray, torch.Tensor],
31
- progress: gr.Progress,
32
  *whisper_params,
33
  ) -> Tuple[List[dict], float]:
34
  """
@@ -79,7 +79,7 @@ class WhisperInference(WhisperBase):
79
  def update_model(self,
80
  model_size: str,
81
  compute_type: str,
82
- progress: gr.Progress,
83
  ):
84
  """
85
  Update current model setting
 
28
 
29
  def transcribe(self,
30
  audio: Union[str, np.ndarray, torch.Tensor],
31
+ progress: gr.Progress = gr.Progress(),
32
  *whisper_params,
33
  ) -> Tuple[List[dict], float]:
34
  """
 
79
  def update_model(self,
80
  model_size: str,
81
  compute_type: str,
82
+ progress: gr.Progress = gr.Progress(),
83
  ):
84
  """
85
  Update current model setting
modules/whisper/whisper_base.py CHANGED
@@ -53,7 +53,7 @@ class WhisperBase(ABC):
53
  @abstractmethod
54
  def transcribe(self,
55
  audio: Union[str, BinaryIO, np.ndarray],
56
- progress: gr.Progress,
57
  *whisper_params,
58
  ):
59
  """Inference whisper model to transcribe"""
@@ -63,7 +63,7 @@ class WhisperBase(ABC):
63
  def update_model(self,
64
  model_size: str,
65
  compute_type: str,
66
- progress: gr.Progress
67
  ):
68
  """Initialize whisper model"""
69
  pass
@@ -171,10 +171,10 @@ class WhisperBase(ABC):
171
  return result, elapsed_time
172
 
173
  def transcribe_file(self,
174
- files: list,
175
- input_folder_path: str,
176
- file_format: str,
177
- add_timestamp: bool,
178
  progress=gr.Progress(),
179
  *whisper_params,
180
  ) -> list:
@@ -250,8 +250,8 @@ class WhisperBase(ABC):
250
 
251
  def transcribe_mic(self,
252
  mic_audio: str,
253
- file_format: str,
254
- add_timestamp: bool,
255
  progress=gr.Progress(),
256
  *whisper_params,
257
  ) -> list:
@@ -306,8 +306,8 @@ class WhisperBase(ABC):
306
 
307
  def transcribe_youtube(self,
308
  youtube_link: str,
309
- file_format: str,
310
- add_timestamp: bool,
311
  progress=gr.Progress(),
312
  *whisper_params,
313
  ) -> list:
@@ -411,11 +411,12 @@ class WhisperBase(ABC):
411
  else:
412
  output_path = os.path.join(output_dir, f"{file_name}")
413
 
414
- if file_format == "SRT":
 
415
  content = get_srt(transcribed_segments)
416
  output_path += '.srt'
417
 
418
- elif file_format == "WebVTT":
419
  content = get_vtt(transcribed_segments)
420
  output_path += '.vtt'
421
 
 
53
  @abstractmethod
54
  def transcribe(self,
55
  audio: Union[str, BinaryIO, np.ndarray],
56
+ progress: gr.Progress = gr.Progress(),
57
  *whisper_params,
58
  ):
59
  """Inference whisper model to transcribe"""
 
63
  def update_model(self,
64
  model_size: str,
65
  compute_type: str,
66
+ progress: gr.Progress = gr.Progress()
67
  ):
68
  """Initialize whisper model"""
69
  pass
 
171
  return result, elapsed_time
172
 
173
  def transcribe_file(self,
174
+ files: Optional[List] = None,
175
+ input_folder_path: Optional[str] = None,
176
+ file_format: str = "SRT",
177
+ add_timestamp: bool = True,
178
  progress=gr.Progress(),
179
  *whisper_params,
180
  ) -> list:
 
250
 
251
  def transcribe_mic(self,
252
  mic_audio: str,
253
+ file_format: str = "SRT",
254
+ add_timestamp: bool = True,
255
  progress=gr.Progress(),
256
  *whisper_params,
257
  ) -> list:
 
306
 
307
  def transcribe_youtube(self,
308
  youtube_link: str,
309
+ file_format: str = "SRT",
310
+ add_timestamp: bool = True,
311
  progress=gr.Progress(),
312
  *whisper_params,
313
  ) -> list:
 
411
  else:
412
  output_path = os.path.join(output_dir, f"{file_name}")
413
 
414
+ file_format = file_format.strip().lower()
415
+ if file_format == "srt":
416
  content = get_srt(transcribed_segments)
417
  output_path += '.srt'
418
 
419
+ elif file_format == "webvtt":
420
  content = get_vtt(transcribed_segments)
421
  output_path += '.vtt'
422