jhj0517 committed
Commit 9f69aa4 · 1 Parent(s): 7644f39

refactoring to use data class

Files changed (1): modules/whisper_Inference.py (+142 -202)
modules/whisper_Inference.py CHANGED
@@ -10,6 +10,7 @@ import torch
 from .base_interface import BaseInterface
 from modules.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
 from modules.youtube_manager import get_ytdata, get_ytaudio
+from modules.whisper_data_class import *

 DEFAULT_MODEL_SIZE = "large-v3"

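Note: modules/whisper_data_class.py itself is not shown in this diff. Judging only from how it is used below (WhisperGradioComponents.to_values(*whisper_params) returns an object whose model_size, lang, is_translate, beam_size, log_prob_threshold, no_speech_threshold and compute_type attributes are read), a minimal sketch of what it might contain could look like the following; the WhisperValues name and the field order are assumptions:

# Hypothetical sketch, not part of this commit. Field names mirror the
# attributes read in transcribe() below; "WhisperValues" and the field
# order are assumptions.
from dataclasses import dataclass

@dataclass
class WhisperValues:
    model_size: str
    lang: str
    is_translate: bool
    beam_size: int
    log_prob_threshold: float
    no_speech_threshold: float
    compute_type: str

class WhisperGradioComponents:
    @staticmethod
    def to_values(*args) -> WhisperValues:
        # Rebuild the flat tuple of Gradio component values into a single
        # typed parameter object.
        return WhisperValues(*args)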
@@ -21,82 +22,54 @@ class WhisperInference(BaseInterface):
         self.model = None
         self.available_models = whisper.available_models()
         self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
+        self.translatable_model = ["large", "large-v1", "large-v2", "large-v3"]
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.available_compute_types = ["float16", "float32"]
         self.current_compute_type = "float16" if self.device == "cuda" else "float32"
         self.default_beam_size = 1

     def transcribe_file(self,
-                        fileobjs: list,
-                        model_size: str,
-                        lang: str,
+                        files: list,
                         file_format: str,
-                        istranslate: bool,
                         add_timestamp: bool,
-                        beam_size: int,
-                        log_prob_threshold: float,
-                        no_speech_threshold: float,
-                        compute_type: str,
-                        progress=gr.Progress()) -> list:
+                        progress=gr.Progress(),
+                        *whisper_params
+                        ) -> list:
         """
         Write subtitle file from Files

         Parameters
         ----------
-        fileobjs: list
+        files: list
             List of files to transcribe from gr.Files()
-        model_size: str
-            Whisper model size from gr.Dropdown()
-        lang: str
-            Source language of the file to transcribe from gr.Dropdown()
         file_format: str
-            File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
-        istranslate: bool
-            Boolean value from gr.Checkbox() that determines whether to translate to English.
-            It's Whisper's feature to translate speech from another language directly into English end-to-end.
+            Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
         add_timestamp: bool
-            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
-        beam_size: int
-            Int value from gr.Number() that is used for decoding option.
-        log_prob_threshold: float
-            float value from gr.Number(). If the average log probability over sampled tokens is
-            below this value, treat as failed.
-        no_speech_threshold: float
-            float value from gr.Number(). If the no_speech probability is higher than this value AND
-            the average log probability over sampled tokens is below `log_prob_threshold`,
-            consider the segment as silent.
-        compute_type: str
-            compute type from gr.Dropdown().
+            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
         progress: gr.Progress
             Indicator to show progress directly in gradio.
-            I use a forked version of whisper for this. To see more info : https://github.com/jhj0517/jhj0517-whisper/tree/add-progress-callback
+        *whisper_params: tuple
+            Gradio components related to Whisper. see whisper_data_class.py for details.

         Returns
         ----------
-        A List of
-        String to return to gr.Textbox()
-        Files to return to gr.Files()
+        result_str:
+            Result of transcription to return to gr.Textbox()
+        result_file_path:
+            Output file path to return to gr.Files()
         """
         try:
-            self.update_model_if_needed(model_size=model_size, compute_type=compute_type, progress=progress)
-
             files_info = {}
-            for fileobj in fileobjs:
+            for file in files:
                 progress(0, desc="Loading Audio..")
-                audio = whisper.load_audio(fileobj.name)
-
-                result, elapsed_time = self.transcribe(audio=audio,
-                                                       lang=lang,
-                                                       istranslate=istranslate,
-                                                       beam_size=beam_size,
-                                                       log_prob_threshold=log_prob_threshold,
-                                                       no_speech_threshold=no_speech_threshold,
-                                                       compute_type=compute_type,
-                                                       progress=progress
-                                                       )
+                audio = whisper.load_audio(file.name)
+
+                result, elapsed_time = self.transcribe(audio,
+                                                       progress,
+                                                       *whisper_params)
                 progress(1, desc="Completed!")

-                file_name, file_ext = os.path.splitext(os.path.basename(fileobj.name))
+                file_name, file_ext = os.path.splitext(os.path.basename(file.name))
                 file_name = safe_filename(file_name)
                 subtitle, file_path = self.generate_and_write_file(
                     file_name=file_name,
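This signature change works because Gradio passes event inputs positionally and injects the progress indicator itself via the gr.Progress() default, so every component listed after add_timestamp arrives packed into *whisper_params. A hypothetical wiring (all component names here are illustrative assumptions, not from this commit):

# Hypothetical event wiring; component names are illustrative assumptions.
# The components listed after cb_timestamp arrive in *whisper_params, in
# order, and WhisperGradioComponents.to_values() rebuilds them into one object.
btn_run.click(fn=whisper_inference.transcribe_file,
              inputs=[input_files, dd_file_format, cb_timestamp,
                      dd_model, dd_lang, cb_translate, nb_beam_size,
                      nb_log_prob_threshold, nb_no_speech_threshold,
                      dd_compute_type],
              outputs=[tb_indicator, files_subtitles])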
@@ -104,7 +77,7 @@ class WhisperInference(BaseInterface):
                     add_timestamp=add_timestamp,
                     file_format=file_format
                 )
-                files_info[file_name] = {"subtitle": subtitle, "elapsed_time": elapsed_time, "path": file_path}
+                files_info[file_name] = {"subtitle": subtitle, "elapsed_time": elapsed_time, "path": file_path}

             total_result = ''
             total_time = 0
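After the loop, files_info maps each sanitized file name to its per-file result, which the next hunk flattens into a single summary string and a list of output paths. Its shape, with purely illustrative values:

# Illustrative shape of files_info; every value below is made up.
files_info = {
    "my_audio": {
        "subtitle": "1\n00:00:00,000 --> 00:00:02,000\nHello\n",
        "elapsed_time": 12.3,
        "path": "outputs/my_audio.srt",
    },
}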
@@ -114,100 +87,71 @@ class WhisperInference(BaseInterface):
                 total_result += f"{info['subtitle']}"
                 total_time += info["elapsed_time"]

-            gr_str = f"Done in {self.format_time(total_time)}! Subtitle is in the outputs folder.\n\n{total_result}"
-            gr_file_path = [info['path'] for info in files_info.values()]
+            result_str = f"Done in {self.format_time(total_time)}! Subtitle is in the outputs folder.\n\n{total_result}"
+            result_file_path = [info['path'] for info in files_info.values()]

-            return [gr_str, gr_file_path]
+            return [result_str, result_file_path]
         except Exception as e:
             print(f"Error transcribing file: {str(e)}")
         finally:
             self.release_cuda_memory()
-            self.remove_input_files([fileobj.name for fileobj in fileobjs])
+            self.remove_input_files([file.name for file in files])

     def transcribe_youtube(self,
-                           youtubelink: str,
-                           model_size: str,
-                           lang: str,
+                           youtube_link: str,
                            file_format: str,
-                           istranslate: bool,
                            add_timestamp: bool,
-                           beam_size: int,
-                           log_prob_threshold: float,
-                           no_speech_threshold: float,
-                           compute_type: str,
-                           progress=gr.Progress()) -> list:
+                           progress=gr.Progress(),
+                           *whisper_params) -> list:
         """
         Write subtitle file from Youtube

         Parameters
         ----------
-        youtubelink: str
-            Link of Youtube to transcribe from gr.Textbox()
-        model_size: str
-            Whisper model size from gr.Dropdown()
-        lang: str
-            Source language of the file to transcribe from gr.Dropdown()
+        youtube_link: str
+            URL of the Youtube video to transcribe from gr.Textbox()
         file_format: str
-            File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
-        istranslate: bool
-            Boolean value from gr.Checkbox() that determines whether to translate to English.
-            It's Whisper's feature to translate speech from another language directly into English end-to-end.
+            Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
         add_timestamp: bool
             Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
-        beam_size: int
-            Int value from gr.Number() that is used for decoding option.
-        log_prob_threshold: float
-            float value from gr.Number(). If the average log probability over sampled tokens is
-            below this value, treat as failed.
-        no_speech_threshold: float
-            float value from gr.Number(). If the no_speech probability is higher than this value AND
-            the average log probability over sampled tokens is below `log_prob_threshold`,
-            consider the segment as silent.
-        compute_type: str
-            compute type from gr.Dropdown().
         progress: gr.Progress
             Indicator to show progress directly in gradio.
-            I use a forked version of whisper for this. To see more info : https://github.com/jhj0517/jhj0517-whisper/tree/add-progress-callback
+        *whisper_params: tuple
+            Gradio components related to Whisper. see whisper_data_class.py for details.

         Returns
         ----------
-        A List of
-        String to return to gr.Textbox()
-        Files to return to gr.Files()
+        result_str:
+            Result of transcription to return to gr.Textbox()
+        result_file_path:
+            Output file path to return to gr.Files()
         """
         try:
-            self.update_model_if_needed(model_size=model_size, compute_type=compute_type, progress=progress)
-
             progress(0, desc="Loading Audio from Youtube..")
-            yt = get_ytdata(youtubelink)
+            yt = get_ytdata(youtube_link)
             audio = whisper.load_audio(get_ytaudio(yt))

-            result, elapsed_time = self.transcribe(audio=audio,
-                                                   lang=lang,
-                                                   istranslate=istranslate,
-                                                   beam_size=beam_size,
-                                                   log_prob_threshold=log_prob_threshold,
-                                                   no_speech_threshold=no_speech_threshold,
-                                                   compute_type=compute_type,
-                                                   progress=progress)
+            result, elapsed_time = self.transcribe(audio,
+                                                   progress,
+                                                   *whisper_params)
             progress(1, desc="Completed!")

             file_name = safe_filename(yt.title)
-            subtitle, file_path = self.generate_and_write_file(
+            subtitle, result_file_path = self.generate_and_write_file(
                 file_name=file_name,
                 transcribed_segments=result,
                 add_timestamp=add_timestamp,
                 file_format=file_format
             )

-            gr_str = f"Done in {self.format_time(elapsed_time)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
-            return [gr_str, file_path]
+            result_str = f"Done in {self.format_time(elapsed_time)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
+            return [result_str, result_file_path]
         except Exception as e:
             print(f"Error transcribing youtube video: {str(e)}")
         finally:
             try:
                 if 'yt' not in locals():
-                    yt = get_ytdata(youtubelink)
+                    yt = get_ytdata(youtube_link)
                     file_path = get_ytaudio(yt)
                 else:
                     file_path = get_ytaudio(yt)
@@ -218,116 +162,71 @@ class WhisperInference(BaseInterface):
             pass

     def transcribe_mic(self,
-                       micaudio: str,
-                       model_size: str,
-                       lang: str,
+                       mic_audio: str,
                        file_format: str,
-                       istranslate: bool,
-                       beam_size: int,
-                       log_prob_threshold: float,
-                       no_speech_threshold: float,
-                       compute_type: str,
-                       progress=gr.Progress()) -> list:
+                       progress=gr.Progress(),
+                       *whisper_params) -> list:
         """
         Write subtitle file from microphone

         Parameters
         ----------
-        micaudio: str
+        mic_audio: str
             Audio file path from gr.Microphone()
-        model_size: str
-            Whisper model size from gr.Dropdown()
-        lang: str
-            Source language of the file to transcribe from gr.Dropdown()
         file_format: str
-            Subtitle format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
-        istranslate: bool
-            Boolean value from gr.Checkbox() that determines whether to translate to English.
-            It's Whisper's feature to translate speech from another language directly into English end-to-end.
-        beam_size: int
-            Int value from gr.Number() that is used for decoding option.
-        log_prob_threshold: float
-            float value from gr.Number(). If the average log probability over sampled tokens is
-            below this value, treat as failed.
-        no_speech_threshold: float
-            float value from gr.Number(). If the no_speech probability is higher than this value AND
-            the average log probability over sampled tokens is below `log_prob_threshold`,
-            consider the segment as silent.
-        compute_type: str
-            compute type from gr.Dropdown().
+            Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
         progress: gr.Progress
             Indicator to show progress directly in gradio.
-            I use a forked version of whisper for this. To see more info : https://github.com/jhj0517/jhj0517-whisper/tree/add-progress-callback
+        *whisper_params: tuple
+            Gradio components related to Whisper. see whisper_data_class.py for details.

         Returns
         ----------
-        A List of
-        String to return to gr.Textbox()
-        Files to return to gr.Files()
+        result_str:
+            Result of transcription to return to gr.Textbox()
+        result_file_path:
+            Output file path to return to gr.Files()
         """
         try:
-            self.update_model_if_needed(model_size=model_size, compute_type=compute_type, progress=progress)
-
-            result, elapsed_time = self.transcribe(audio=micaudio,
-                                                   lang=lang,
-                                                   istranslate=istranslate,
-                                                   beam_size=beam_size,
-                                                   log_prob_threshold=log_prob_threshold,
-                                                   no_speech_threshold=no_speech_threshold,
-                                                   compute_type=compute_type,
-                                                   progress=progress)
+            progress(0, desc="Loading Audio..")
+            result, elapsed_time = self.transcribe(
+                mic_audio,
+                progress,
+                *whisper_params,
+            )
             progress(1, desc="Completed!")

-            subtitle, file_path = self.generate_and_write_file(
+            subtitle, result_file_path = self.generate_and_write_file(
                 file_name="Mic",
                 transcribed_segments=result,
                 add_timestamp=True,
                 file_format=file_format
             )

-            gr_str = f"Done in {self.format_time(elapsed_time)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
-            return [gr_str, file_path]
+            result_str = f"Done in {self.format_time(elapsed_time)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
+            return [result_str, result_file_path]
         except Exception as e:
             print(f"Error transcribing mic: {str(e)}")
         finally:
             self.release_cuda_memory()
-            self.remove_input_files([micaudio])
+            self.remove_input_files([mic_audio])

     def transcribe(self,
                    audio: Union[str, np.ndarray, torch.Tensor],
-                   lang: str,
-                   istranslate: bool,
-                   beam_size: int,
-                   log_prob_threshold: float,
-                   no_speech_threshold: float,
-                   compute_type: str,
-                   progress: gr.Progress
+                   progress: gr.Progress,
+                   *whisper_params,
                    ) -> Tuple[List[dict], float]:
         """
-        transcribe method for OpenAI's Whisper implementation.
+        transcribe method for faster-whisper.

         Parameters
         ----------
-        audio: Union[str, BinaryIO, torch.Tensor]
+        audio: Union[str, BinaryIO, np.ndarray]
             Audio path or file binary or Audio numpy array
-        lang: str
-            Source language of the file to transcribe from gr.Dropdown()
-        istranslate: bool
-            Boolean value from gr.Checkbox() that determines whether to translate to English.
-            It's Whisper's feature to translate speech from another language directly into English end-to-end.
-        beam_size: int
-            Int value from gr.Number() that is used for decoding option.
-        log_prob_threshold: float
-            float value from gr.Number(). If the average log probability over sampled tokens is
-            below this value, treat as failed.
-        no_speech_threshold: float
-            float value from gr.Number(). If the no_speech probability is higher than this value AND
-            the average log probability over sampled tokens is below `log_prob_threshold`,
-            consider the segment as silent.
-        compute_type: str
-            compute type from gr.Dropdown().
         progress: gr.Progress
             Indicator to show progress directly in gradio.
+        *whisper_params: tuple
+            Gradio components related to Whisper. see whisper_data_class.py for details.

         Returns
         ----------
@@ -337,45 +236,56 @@ class WhisperInference(BaseInterface):
             elapsed time for transcription
         """
         start_time = time.time()
+        params = WhisperGradioComponents.to_values(*whisper_params)
+
+        if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
+            self.update_model(params.model_size, params.compute_type, progress)
+
+        if params.lang == "Automatic Detection":
+            params.lang = None

         def progress_callback(progress_value):
             progress(progress_value, desc="Transcribing..")

-        if lang == "Automatic Detection":
-            lang = None
-
-        translatable_model = ["large", "large-v1", "large-v2", "large-v3"]
         segments_result = self.model.transcribe(audio=audio,
-                                                language=lang,
+                                                language=params.lang,
                                                 verbose=False,
-                                                beam_size=beam_size,
-                                                logprob_threshold=log_prob_threshold,
-                                                no_speech_threshold=no_speech_threshold,
-                                                task="translate" if istranslate and self.current_model_size in translatable_model else "transcribe",
-                                                fp16=True if compute_type == "float16" else False,
+                                                beam_size=params.beam_size,
+                                                logprob_threshold=params.log_prob_threshold,
+                                                no_speech_threshold=params.no_speech_threshold,
+                                                task="translate" if params.is_translate and self.current_model_size in self.translatable_model else "transcribe",
+                                                fp16=True if params.compute_type == "float16" else False,
                                                 progress_callback=progress_callback)["segments"]
         elapsed_time = time.time() - start_time

         return segments_result, elapsed_time

-    def update_model_if_needed(self,
-                               model_size: str,
-                               compute_type: str,
-                               progress: gr.Progress,
-                               ):
+    def update_model(self,
+                     model_size: str,
+                     compute_type: str,
+                     progress: gr.Progress,
+                     ):
         """
-        Initialize model if it doesn't match with current model setting
+        Update current model setting
+
+        Parameters
+        ----------
+        model_size: str
+            Size of whisper model
+        compute_type: str
+            Compute type for transcription.
+            see more info : https://opennmt.net/CTranslate2/quantization.html
+        progress: gr.Progress
+            Indicator to show progress directly in gradio.
         """
-        if compute_type != self.current_compute_type:
-            self.current_compute_type = compute_type
-        if model_size != self.current_model_size or self.model is None:
-            progress(0, desc="Initializing Model..")
-            self.current_model_size = model_size
-            self.model = whisper.load_model(
-                name=model_size,
-                device=self.device,
-                download_root=os.path.join("models", "Whisper")
-            )
+        progress(0, desc="Initializing Model..")
+        self.current_compute_type = compute_type
+        self.current_model_size = model_size
+        self.model = whisper.load_model(
+            name=model_size,
+            device=self.device,
+            download_root=os.path.join("models", "Whisper")
+        )

     @staticmethod
     def generate_and_write_file(file_name: str,
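The guard before update_model() means the relatively slow whisper.load_model() call only runs when the requested model size or compute type actually differs from what is already loaded. A quick round trip through the sketch of whisper_data_class.py given near the top of this diff (field order is still an assumption; -1.0 and 0.6 are OpenAI Whisper's defaults for logprob_threshold and no_speech_threshold):

# Round trip through the assumed WhisperGradioComponents sketch above;
# the tuple order must match the assumed dataclass field order.
values = ("large-v3", "english", False, 1, -1.0, 0.6, "float16")
params = WhisperGradioComponents.to_values(*values)
assert params.model_size == "large-v3"
assert params.compute_type == "float16"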
@@ -384,7 +294,25 @@ class WhisperInference(BaseInterface):
                                 file_format: str,
                                 ) -> str:
         """
-        This method writes subtitle file and returns str to gr.Textbox
+        Writes subtitle file
+
+        Parameters
+        ----------
+        file_name: str
+            Output file name
+        transcribed_segments: list
+            Text segments transcribed from audio
+        add_timestamp: bool
+            Determines whether to add a timestamp to the end of the filename.
+        file_format: str
+            File format to write. Supported formats: [SRT, WebVTT, txt]
+
+        Returns
+        ----------
+        content: str
+            Result of the transcription
+        output_path: str
+            output file path
         """
         timestamp = datetime.now().strftime("%m%d%H%M%S")
         if add_timestamp:
@@ -410,6 +338,18 @@ class WhisperInference(BaseInterface):

     @staticmethod
     def format_time(elapsed_time: float) -> str:
+        """
+        Get {hours} {minutes} {seconds} time format string
+
+        Parameters
+        ----------
+        elapsed_time: str
+            Elapsed time for transcription
+
+        Returns
+        ----------
+        Time format string
+        """
         hours, rem = divmod(elapsed_time, 3600)
         minutes, seconds = divmod(rem, 60)
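The rest of format_time is cut off by the hunk, but the two divmod calls carry the arithmetic; a quick worked example:

# Worked example of the divmod arithmetic in format_time:
elapsed_time = 3725.0
hours, rem = divmod(elapsed_time, 3600)   # hours = 1.0, rem = 125.0
minutes, seconds = divmod(rem, 60)        # minutes = 2.0, seconds = 5.0
# 3725 seconds -> 1 hour, 2 minutes, 5 seconds, which the method then
# renders as the "{hours} {minutes} {seconds}" string its docstring describes.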