jhj0517 committed
Commit e76c01c · 1 Parent(s): b72fd8a

refactor base abstract class for whisper

modules/faster_whisper_inference.py CHANGED
@@ -2,233 +2,30 @@ import os
 import time
 import numpy as np
 from typing import BinaryIO, Union, Tuple, List
-from datetime import datetime

 import faster_whisper
 from faster_whisper.vad import VadOptions
 import ctranslate2
 import whisper
-import torch
 import gradio as gr

-from .base_interface import BaseInterface
-from modules.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
-from modules.youtube_manager import get_ytdata, get_ytaudio
 from modules.whisper_parameter import *
+from modules.whisper_base import WhisperBase

 # Temporal fix of the issue : https://github.com/jhj0517/Whisper-WebUI/issues/144
 os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'


-class FasterWhisperInference(BaseInterface):
+class FasterWhisperInference(WhisperBase):
     def __init__(self):
-        super().__init__()
+        super().__init__(
+            model_dir=os.path.join("models", "Whisper", "faster-whisper")
+        )
         self.model_dir = os.path.join("models", "Whisper", "faster-whisper")
-        os.makedirs(self.model_dir, exist_ok=True)
-        self.current_model_size = None
-        self.model = None
         self.model_paths = self.get_model_paths()
         self.available_models = self.model_paths.keys()
-        self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
-        self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
-        if torch.cuda.is_available():
-            self.device = "cuda"
-        elif torch.backends.mps.is_available():
-            self.device = "mps"
-        else:
-            self.device = "cpu"
         self.available_compute_types = ctranslate2.get_supported_compute_types(
             "cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")
-        self.current_compute_type = "float16" if self.device == "cuda" else "float32"
-
-    def transcribe_file(self,
-                        files: list,
-                        file_format: str,
-                        add_timestamp: bool,
-                        progress=gr.Progress(),
-                        *whisper_params,
-                        ) -> list:
-        """
-        Write subtitle file from Files
-
-        Parameters
-        ----------
-        files: list
-            List of files to transcribe from gr.Files()
-        file_format: str
-            Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
-        add_timestamp: bool
-            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
-        progress: gr.Progress
-            Indicator to show progress directly in gradio.
-        *whisper_params: tuple
-            Gradio components related to Whisper. see whisper_data_class.py for details.
-
-        Returns
-        ----------
-        result_str:
-            Result of transcription to return to gr.Textbox()
-        result_file_path:
-            Output file path to return to gr.Files()
-        """
-        try:
-            files_info = {}
-            for file in files:
-                transcribed_segments, time_for_task = self.transcribe(
-                    file.name,
-                    progress,
-                    *whisper_params,
-                )
-
-                file_name, file_ext = os.path.splitext(os.path.basename(file.name))
-                file_name = safe_filename(file_name)
-                subtitle, file_path = self.generate_and_write_file(
-                    file_name=file_name,
-                    transcribed_segments=transcribed_segments,
-                    add_timestamp=add_timestamp,
-                    file_format=file_format
-                )
-                files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path}
-
-            total_result = ''
-            total_time = 0
-            for file_name, info in files_info.items():
-                total_result += '------------------------------------\n'
-                total_result += f'{file_name}\n\n'
-                total_result += f'{info["subtitle"]}'
-                total_time += info["time_for_task"]
-
-            result_str = f"Done in {self.format_time(total_time)}! Subtitle is in the outputs folder.\n\n{total_result}"
-            result_file_path = [info['path'] for info in files_info.values()]
-
-            return [result_str, result_file_path]
-
-        except Exception as e:
-            print(f"Error transcribing file: {e}")
-        finally:
-            self.release_cuda_memory()
-            if not files:
-                self.remove_input_files([file.name for file in files])
-
-    def transcribe_youtube(self,
-                           youtube_link: str,
-                           file_format: str,
-                           add_timestamp: bool,
-                           progress=gr.Progress(),
-                           *whisper_params,
-                           ) -> list:
-        """
-        Write subtitle file from Youtube
-
-        Parameters
-        ----------
-        youtube_link: str
-            URL of the Youtube video to transcribe from gr.Textbox()
-        file_format: str
-            Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
-        add_timestamp: bool
-            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
-        progress: gr.Progress
-            Indicator to show progress directly in gradio.
-        *whisper_params: tuple
-            Gradio components related to Whisper. see whisper_data_class.py for details.
-
-        Returns
-        ----------
-        result_str:
-            Result of transcription to return to gr.Textbox()
-        result_file_path:
-            Output file path to return to gr.Files()
-        """
-        try:
-            progress(0, desc="Loading Audio from Youtube..")
-            yt = get_ytdata(youtube_link)
-            audio = get_ytaudio(yt)
-
-            transcribed_segments, time_for_task = self.transcribe(
-                audio,
-                progress,
-                *whisper_params,
-            )
-
-            progress(1, desc="Completed!")
-
-            file_name = safe_filename(yt.title)
-            subtitle, result_file_path = self.generate_and_write_file(
-                file_name=file_name,
-                transcribed_segments=transcribed_segments,
-                add_timestamp=add_timestamp,
-                file_format=file_format
-            )
-            result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
-
-            return [result_str, result_file_path]
-
-        except Exception as e:
-            print(f"Error transcribing file: {e}")
-        finally:
-            try:
-                if 'yt' not in locals():
-                    yt = get_ytdata(youtube_link)
-                    file_path = get_ytaudio(yt)
-                else:
-                    file_path = get_ytaudio(yt)
-
-                self.release_cuda_memory()
-                self.remove_input_files([file_path])
-            except Exception as cleanup_error:
-                pass
-
-    def transcribe_mic(self,
-                       mic_audio: str,
-                       file_format: str,
-                       progress=gr.Progress(),
-                       *whisper_params,
-                       ) -> list:
-        """
-        Write subtitle file from microphone
-
-        Parameters
-        ----------
-        mic_audio: str
-            Audio file path from gr.Microphone()
-        file_format: str
-            Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
-        progress: gr.Progress
-            Indicator to show progress directly in gradio.
-        *whisper_params: tuple
-            Gradio components related to Whisper. see whisper_data_class.py for details.
-
-        Returns
-        ----------
-        result_str:
-            Result of transcription to return to gr.Textbox()
-        result_file_path:
-            Output file path to return to gr.Files()
-        """
-        try:
-            progress(0, desc="Loading Audio..")
-            transcribed_segments, time_for_task = self.transcribe(
-                mic_audio,
-                progress,
-                *whisper_params,
-            )
-            progress(1, desc="Completed!")
-
-            subtitle, result_file_path = self.generate_and_write_file(
-                file_name="Mic",
-                transcribed_segments=transcribed_segments,
-                add_timestamp=True,
-                file_format=file_format
-            )
-
-            result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
-            return [result_str, result_file_path]
-        except Exception as e:
-            print(f"Error transcribing file: {e}")
-        finally:
-            self.release_cuda_memory()
-            self.remove_input_files([mic_audio])

     def transcribe(self,
                    audio: Union[str, BinaryIO, np.ndarray],
@@ -356,79 +153,3 @@ class FasterWhisperInference(BaseInterface):
         if model_name not in whisper.available_models():
             model_paths[model_name] = os.path.join(webui_dir, self.model_dir, model_name)
         return model_paths
-
-    @staticmethod
-    def generate_and_write_file(file_name: str,
-                                transcribed_segments: list,
-                                add_timestamp: bool,
-                                file_format: str,
-                                ) -> str:
-        """
-        Writes subtitle file
-
-        Parameters
-        ----------
-        file_name: str
-            Output file name
-        transcribed_segments: list
-            Text segments transcribed from audio
-        add_timestamp: bool
-            Determines whether to add a timestamp to the end of the filename.
-        file_format: str
-            File format to write. Supported formats: [SRT, WebVTT, txt]
-
-        Returns
-        ----------
-        content: str
-            Result of the transcription
-        output_path: str
-            output file path
-        """
-        timestamp = datetime.now().strftime("%m%d%H%M%S")
-        if add_timestamp:
-            output_path = os.path.join("outputs", f"{file_name}-{timestamp}")
-        else:
-            output_path = os.path.join("outputs", f"{file_name}")
-
-        if file_format == "SRT":
-            content = get_srt(transcribed_segments)
-            output_path += '.srt'
-            write_file(content, output_path)
-
-        elif file_format == "WebVTT":
-            content = get_vtt(transcribed_segments)
-            output_path += '.vtt'
-            write_file(content, output_path)
-
-        elif file_format == "txt":
-            content = get_txt(transcribed_segments)
-            output_path += '.txt'
-            write_file(content, output_path)
-        return content, output_path
-
-    @staticmethod
-    def format_time(elapsed_time: float) -> str:
-        """
-        Get {hours} {minutes} {seconds} time format string
-
-        Parameters
-        ----------
-        elapsed_time: str
-            Elapsed time for transcription
-
-        Returns
-        ----------
-        Time format string
-        """
-        hours, rem = divmod(elapsed_time, 3600)
-        minutes, seconds = divmod(rem, 60)
-
-        time_str = ""
-        if hours:
-            time_str += f"{hours} hours "
-        if minutes:
-            time_str += f"{minutes} minutes "
-        seconds = round(seconds)
-        time_str += f"{seconds} seconds"
-
-        return time_str.strip()
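Note: the new modules/whisper_base.py that both backends now import is not included in this diff. The sketch below is only an assumption about its shape, inferred from the attributes the subclasses still reference (self.device, self.current_compute_type, self.current_model_size, self.translatable_models) and from the duplicated methods removed above; the actual base class may differ.

# Hypothetical sketch of modules/whisper_base.py (not part of this commit's diff).
import os
from abc import ABC, abstractmethod
from typing import BinaryIO, Union

import numpy as np
import torch
import whisper


class WhisperBase(ABC):
    def __init__(self, model_dir: str):
        self.model = None
        self.current_model_size = None
        self.model_dir = model_dir
        os.makedirs(self.model_dir, exist_ok=True)
        self.available_models = whisper.available_models()
        self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
        self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
        # Device and compute-type selection previously duplicated in both subclasses.
        if torch.cuda.is_available():
            self.device = "cuda"
        elif torch.backends.mps.is_available():
            self.device = "mps"
        else:
            self.device = "cpu"
        self.current_compute_type = "float16" if self.device == "cuda" else "float32"

    @abstractmethod
    def transcribe(self, audio: Union[str, BinaryIO, np.ndarray], progress, *whisper_params):
        """Backend-specific transcription, implemented by each subclass."""
        raise NotImplementedError

    # The shared Gradio entry points (transcribe_file, transcribe_youtube, transcribe_mic)
    # and the helpers generate_and_write_file / format_time that were removed from both
    # subclasses above would live here, essentially unchanged.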
 
modules/whisper_Inference.py CHANGED
@@ -4,217 +4,17 @@ import time
 import os
 from typing import BinaryIO, Union, Tuple, List
 import numpy as np
-from datetime import datetime
 import torch

-from .base_interface import BaseInterface
-from modules.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
-from modules.youtube_manager import get_ytdata, get_ytaudio
+from modules.whisper_base import WhisperBase
 from modules.whisper_parameter import *

-DEFAULT_MODEL_SIZE = "large-v3"

-
-class WhisperInference(BaseInterface):
+class WhisperInference(WhisperBase):
     def __init__(self):
-        super().__init__()
+        super().__init__(
+            model_dir=os.path.join("models", "Whisper")
+        )
-        self.current_model_size = None
-        self.model = None
-        self.available_models = whisper.available_models()
-        self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
-        self.translatable_model = ["large", "large-v1", "large-v2", "large-v3"]
-        if torch.cuda.is_available():
-            self.device = "cuda"
-        elif torch.backends.mps.is_available():
-            self.device = "mps"
-        else:
-            self.device = "cpu"
-        self.available_compute_types = ["float16", "float32"]
-        self.current_compute_type = "float16" if self.device == "cuda" else "float32"
-        self.model_dir = os.path.join("models", "Whisper")
-
-    def transcribe_file(self,
-                        files: list,
-                        file_format: str,
-                        add_timestamp: bool,
-                        progress=gr.Progress(),
-                        *whisper_params
-                        ) -> list:
-        """
-        Write subtitle file from Files
-
-        Parameters
-        ----------
-        files: list
-            List of files to transcribe from gr.Files()
-        file_format: str
-            Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
-        add_timestamp: bool
-            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
-        progress: gr.Progress
-            Indicator to show progress directly in gradio.
-        *whisper_params: tuple
-            Gradio components related to Whisper. see whisper_data_class.py for details.
-
-        Returns
-        ----------
-        result_str:
-            Result of transcription to return to gr.Textbox()
-        result_file_path:
-            Output file path to return to gr.Files()
-        """
-        try:
-            files_info = {}
-            for file in files:
-                progress(0, desc="Loading Audio..")
-                audio = whisper.load_audio(file.name)
-
-                result, elapsed_time = self.transcribe(audio,
-                                                       progress,
-                                                       *whisper_params)
-                progress(1, desc="Completed!")
-
-                file_name, file_ext = os.path.splitext(os.path.basename(file.name))
-                file_name = safe_filename(file_name)
-                subtitle, file_path = self.generate_and_write_file(
-                    file_name=file_name,
-                    transcribed_segments=result,
-                    add_timestamp=add_timestamp,
-                    file_format=file_format
-                )
-                files_info[file_name] = {"subtitle": subtitle, "elapsed_time": elapsed_time, "path": file_path}
-
-            total_result = ''
-            total_time = 0
-            for file_name, info in files_info.items():
-                total_result += '------------------------------------\n'
-                total_result += f'{file_name}\n\n'
-                total_result += f"{info['subtitle']}"
-                total_time += info["elapsed_time"]
-
-            result_str = f"Done in {self.format_time(total_time)}! Subtitle is in the outputs folder.\n\n{total_result}"
-            result_file_path = [info['path'] for info in files_info.values()]
-
-            return [result_str, result_file_path]
-        except Exception as e:
-            print(f"Error transcribing file: {str(e)}")
-        finally:
-            self.release_cuda_memory()
-            self.remove_input_files([file.name for file in files])
-
-    def transcribe_youtube(self,
-                           youtube_link: str,
-                           file_format: str,
-                           add_timestamp: bool,
-                           progress=gr.Progress(),
-                           *whisper_params) -> list:
-        """
-        Write subtitle file from Youtube
-
-        Parameters
-        ----------
-        youtube_link: str
-            URL of the Youtube video to transcribe from gr.Textbox()
-        file_format: str
-            Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
-        add_timestamp: bool
-            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
-        progress: gr.Progress
-            Indicator to show progress directly in gradio.
-        *whisper_params: tuple
-            Gradio components related to Whisper. see whisper_data_class.py for details.
-
-        Returns
-        ----------
-        result_str:
-            Result of transcription to return to gr.Textbox()
-        result_file_path:
-            Output file path to return to gr.Files()
-        """
-        try:
-            progress(0, desc="Loading Audio from Youtube..")
-            yt = get_ytdata(youtube_link)
-            audio = whisper.load_audio(get_ytaudio(yt))
-
-            result, elapsed_time = self.transcribe(audio,
-                                                   progress,
-                                                   *whisper_params)
-            progress(1, desc="Completed!")
-
-            file_name = safe_filename(yt.title)
-            subtitle, result_file_path = self.generate_and_write_file(
-                file_name=file_name,
-                transcribed_segments=result,
-                add_timestamp=add_timestamp,
-                file_format=file_format
-            )
-
-            result_str = f"Done in {self.format_time(elapsed_time)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
-            return [result_str, result_file_path]
-        except Exception as e:
-            print(f"Error transcribing youtube video: {str(e)}")
-        finally:
-            try:
-                if 'yt' not in locals():
-                    yt = get_ytdata(youtube_link)
-                    file_path = get_ytaudio(yt)
-                else:
-                    file_path = get_ytaudio(yt)
-
-                self.release_cuda_memory()
-                self.remove_input_files([file_path])
-            except Exception as cleanup_error:
-                pass
-
-    def transcribe_mic(self,
-                       mic_audio: str,
-                       file_format: str,
-                       progress=gr.Progress(),
-                       *whisper_params) -> list:
-        """
-        Write subtitle file from microphone
-
-        Parameters
-        ----------
-        mic_audio: str
-            Audio file path from gr.Microphone()
-        file_format: str
-            Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
-        progress: gr.Progress
-            Indicator to show progress directly in gradio.
-        *whisper_params: tuple
-            Gradio components related to Whisper. see whisper_data_class.py for details.
-
-        Returns
-        ----------
-        result_str:
-            Result of transcription to return to gr.Textbox()
-        result_file_path:
-            Output file path to return to gr.Files()
-        """
-        try:
-            progress(0, desc="Loading Audio..")
-            result, elapsed_time = self.transcribe(
-                mic_audio,
-                progress,
-                *whisper_params,
-            )
-            progress(1, desc="Completed!")
-
-            subtitle, result_file_path = self.generate_and_write_file(
-                file_name="Mic",
-                transcribed_segments=result,
-                add_timestamp=True,
-                file_format=file_format
-            )
-
-            result_str = f"Done in {self.format_time(elapsed_time)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
-            return [result_str, result_file_path]
-        except Exception as e:
-            print(f"Error transcribing mic: {str(e)}")
-        finally:
-            self.release_cuda_memory()
-            self.remove_input_files([mic_audio])

     def transcribe(self,
                    audio: Union[str, np.ndarray, torch.Tensor],
@@ -258,7 +58,7 @@ class WhisperInference(BaseInterface):
             beam_size=params.beam_size,
             logprob_threshold=params.log_prob_threshold,
             no_speech_threshold=params.no_speech_threshold,
-            task="translate" if params.is_translate and self.current_model_size in self.translatable_model else "transcribe",
+            task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
             fp16=True if params.compute_type == "float16" else False,
             best_of=params.best_of,
             patience=params.patience,
@@ -294,80 +94,4 @@ class WhisperInference(BaseInterface):
             name=model_size,
             device=self.device,
             download_root=self.model_dir
-        )
-
-    @staticmethod
-    def generate_and_write_file(file_name: str,
-                                transcribed_segments: list,
-                                add_timestamp: bool,
-                                file_format: str,
-                                ) -> str:
-        """
-        Writes subtitle file
-
-        Parameters
-        ----------
-        file_name: str
-            Output file name
-        transcribed_segments: list
-            Text segments transcribed from audio
-        add_timestamp: bool
-            Determines whether to add a timestamp to the end of the filename.
-        file_format: str
-            File format to write. Supported formats: [SRT, WebVTT, txt]
-
-        Returns
-        ----------
-        content: str
-            Result of the transcription
-        output_path: str
-            output file path
-        """
-        timestamp = datetime.now().strftime("%m%d%H%M%S")
-        if add_timestamp:
-            output_path = os.path.join("outputs", f"{file_name}-{timestamp}")
-        else:
-            output_path = os.path.join("outputs", f"{file_name}")
-
-        if file_format == "SRT":
-            content = get_srt(transcribed_segments)
-            output_path += '.srt'
-            write_file(content, output_path)
-
-        elif file_format == "WebVTT":
-            content = get_vtt(transcribed_segments)
-            output_path += '.vtt'
-            write_file(content, output_path)
-
-        elif file_format == "txt":
-            content = get_txt(transcribed_segments)
-            output_path += '.txt'
-            write_file(content, output_path)
-        return content, output_path
-
-    @staticmethod
-    def format_time(elapsed_time: float) -> str:
-        """
-        Get {hours} {minutes} {seconds} time format string
-
-        Parameters
-        ----------
-        elapsed_time: str
-            Elapsed time for transcription
-
-        Returns
-        ----------
-        Time format string
-        """
-        hours, rem = divmod(elapsed_time, 3600)
-        minutes, seconds = divmod(rem, 60)
-
-        time_str = ""
-        if hours:
-            time_str += f"{hours} hours "
-        if minutes:
-            time_str += f"{minutes} minutes "
-        seconds = round(seconds)
-        time_str += f"{seconds} seconds"
-
-        return time_str.strip()
+        )
 
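With the shared pipeline hoisted into the base class, both backends expose the same interface and can be swapped behind the Gradio app. A usage sketch, under the assumption that the entry points removed above now live on WhisperBase with their previous signatures:

# Hypothetical usage sketch; assumes the removed entry points now live on WhisperBase.
from modules.faster_whisper_inference import FasterWhisperInference
from modules.whisper_Inference import WhisperInference

backend = FasterWhisperInference()  # or WhisperInference() -- same constructor, same API

# transcribe_file(), transcribe_youtube() and transcribe_mic() are now inherited from
# WhisperBase instead of being duplicated in each backend, so callers such as the
# Gradio app would keep invoking them exactly as before.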