jhj0517 committed on
Commit ec41bf5 · 1 Parent(s): 75962fd

add base inference

modules/insanely_fast_whisper_inference.py ADDED
@@ -0,0 +1,97 @@
+ import whisper
+ import gradio as gr
+ import time
+ import os
+ from typing import BinaryIO, Union, Tuple, List
+ import numpy as np
+ import torch
+
+ from modules.whisper_base import WhisperBase
+ from modules.whisper_parameter import *
+
+
+ class InsanelyFastWhisperInference(WhisperBase):
+     def __init__(self):
+         super().__init__(
+             model_dir=os.path.join("models", "Whisper")
+         )
+
+     def transcribe(self,
+                    audio: Union[str, np.ndarray, torch.Tensor],
+                    progress: gr.Progress,
+                    *whisper_params,
+                    ) -> Tuple[List[dict], float]:
+         """
+         Transcribe method for InsanelyFastWhisperInference.
+
+         Parameters
+         ----------
+         audio: Union[str, np.ndarray, torch.Tensor]
+             Audio path, file binary, or audio numpy array
+         progress: gr.Progress
+             Indicator to show progress directly in Gradio.
+         *whisper_params: tuple
+             Gradio components related to Whisper. See modules/whisper_parameter.py for details.
+
+         Returns
+         -------
+         segments_result: List[dict]
+             List of dicts that include start/end timestamps and transcribed text
+         elapsed_time: float
+             Elapsed time for transcription
+         """
+         start_time = time.time()
+         params = WhisperValues(*whisper_params)
+
+         if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
+             self.update_model(params.model_size, params.compute_type, progress)
+
+         if params.lang == "Automatic Detection":
+             params.lang = None
+
+         def progress_callback(progress_value):
+             progress(progress_value, desc="Transcribing..")
+
+         segments_result = self.model.transcribe(audio=audio,
+                                                 language=params.lang,
+                                                 verbose=False,
+                                                 beam_size=params.beam_size,
+                                                 logprob_threshold=params.log_prob_threshold,
+                                                 no_speech_threshold=params.no_speech_threshold,
+                                                 task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
+                                                 fp16=params.compute_type == "float16",
+                                                 best_of=params.best_of,
+                                                 patience=params.patience,
+                                                 temperature=params.temperature,
+                                                 compression_ratio_threshold=params.compression_ratio_threshold,
+                                                 progress_callback=progress_callback,)["segments"]
+         elapsed_time = time.time() - start_time
+
+         return segments_result, elapsed_time
+
+     def update_model(self,
+                      model_size: str,
+                      compute_type: str,
+                      progress: gr.Progress,
+                      ):
+         """
+         Update the current model setting.
+
+         Parameters
+         ----------
+         model_size: str
+             Size of the Whisper model
+         compute_type: str
+             Compute type for transcription.
+             See more info: https://opennmt.net/CTranslate2/quantization.html
+         progress: gr.Progress
+             Indicator to show progress directly in Gradio.
+         """
+         progress(0, desc="Initializing Model..")
+         self.current_compute_type = compute_type
+         self.current_model_size = model_size
+         self.model = whisper.load_model(
+             name=model_size,
+             device=self.device,
+             download_root=self.model_dir
+         )
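
For reference outside the diff: the new class essentially wraps openai-whisper's load_model/transcribe calls, so the same behaviour can be exercised against that library directly. The sketch below is a minimal, hypothetical example, not part of the commit; the model name, audio path, device, and parameter values are illustrative defaults, and progress_callback is omitted because it is not part of stock openai-whisper.

import os
import whisper

# Load a model into ./models/Whisper, mirroring what update_model() does.
model = whisper.load_model(
    name="base",                                     # illustrative model size
    device="cuda",                                   # or "cpu"
    download_root=os.path.join("models", "Whisper"),
)

# The same keyword arguments that transcribe() forwards, using openai-whisper's
# default values; progress_callback is left out (not supported by stock whisper).
result = model.transcribe(
    audio="sample.wav",                # hypothetical input file
    language=None,                     # None -> automatic language detection
    verbose=False,
    beam_size=5,
    logprob_threshold=-1.0,
    no_speech_threshold=0.6,
    task="transcribe",                 # or "translate" for translatable models
    fp16=True,                         # roughly compute_type == "float16"
    best_of=5,
    patience=1.0,
    temperature=0.0,
    compression_ratio_threshold=2.4,
)
segments = result["segments"]          # list of dicts with start/end timestamps and text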