jhj0517 commited on
Commit
87272f5
·
1 Parent(s): 292ccb4

add `format_result()`

Browse files
modules/insanely_fast_whisper_inference.py CHANGED
@@ -1,20 +1,24 @@
1
- import whisper
2
- import gradio as gr
3
- import time
4
  import os
5
- from typing import BinaryIO, Union, Tuple, List
6
  import numpy as np
 
7
  import torch
 
 
 
 
 
8
 
9
- from modules.whisper_base import WhisperBase
10
  from modules.whisper_parameter import *
 
11
 
12
 
13
  class InsanelyFastWhisperInference(WhisperBase):
14
  def __init__(self):
15
  super().__init__(
16
- model_dir=os.path.join("models", "Whisper")
17
  )
 
18
 
19
  def transcribe(self,
20
  audio: Union[str, np.ndarray, torch.Tensor],
@@ -52,21 +56,14 @@ class InsanelyFastWhisperInference(WhisperBase):
52
  def progress_callback(progress_value):
53
  progress(progress_value, desc="Transcribing..")
54
 
55
- segments_result = self.model.transcribe(audio=audio,
56
- language=params.lang,
57
- verbose=False,
58
- beam_size=params.beam_size,
59
- logprob_threshold=params.log_prob_threshold,
60
- no_speech_threshold=params.no_speech_threshold,
61
- task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
62
- fp16=True if params.compute_type == "float16" else False,
63
- best_of=params.best_of,
64
- patience=params.patience,
65
- temperature=params.temperature,
66
- compression_ratio_threshold=params.compression_ratio_threshold,
67
- progress_callback=progress_callback,)["segments"]
68
  elapsed_time = time.time() - start_time
69
-
70
  return segments_result, elapsed_time
71
 
72
  def update_model(self,
@@ -90,8 +87,34 @@ class InsanelyFastWhisperInference(WhisperBase):
90
  progress(0, desc="Initializing Model..")
91
  self.current_compute_type = compute_type
92
  self.current_model_size = model_size
93
- self.model = whisper.load_model(
94
- name=model_size,
 
 
 
95
  device=self.device,
96
- download_root=self.model_dir
97
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import time
3
  import numpy as np
4
+ from typing import BinaryIO, Union, Tuple, List
5
  import torch
6
+ import transformers
7
+ from transformers import pipeline
8
+ from transformers.utils import is_flash_attn_2_available
9
+ import whisper
10
+ import gradio as gr
11
 
 
12
  from modules.whisper_parameter import *
13
+ from modules.whisper_base import WhisperBase
14
 
15
 
16
  class InsanelyFastWhisperInference(WhisperBase):
17
  def __init__(self):
18
  super().__init__(
19
+ model_dir=os.path.join("models", "Whisper", "insanely_fast_whisper")
20
  )
21
+ self.available_compute_types = ["float16"]
22
 
23
  def transcribe(self,
24
  audio: Union[str, np.ndarray, torch.Tensor],
 
56
  def progress_callback(progress_value):
57
  progress(progress_value, desc="Transcribing..")
58
 
59
+ segments_result = self.model(
60
+ inputs=audio,
61
+ chunk_length_s=30,
62
+ batch_size=24,
63
+ return_timestamps=True,
64
+ )
65
+ segments_result = self.format_result(transcribed_result=segments_result)
 
 
 
 
 
 
66
  elapsed_time = time.time() - start_time
 
67
  return segments_result, elapsed_time
68
 
69
  def update_model(self,
 
87
  progress(0, desc="Initializing Model..")
88
  self.current_compute_type = compute_type
89
  self.current_model_size = model_size
90
+
91
+ self.model = pipeline(
92
+ "automatic-speech-recognition",
93
+ model=os.path.join(self.model_dir, model_size),
94
+ torch_dtype=self.current_compute_type,
95
  device=self.device,
96
+ model_kwargs={"attn_implementation": "flash_attention_2"} if is_flash_attn_2_available() else {"attn_implementation": "sdpa"},
97
+ )
98
+
99
+ @staticmethod
100
+ def format_result(transcribed_result: dict) -> List[dict]:
101
+ """
102
+ Format the transcription result of insanely_fast_whisper as the same with other implementation.
103
+
104
+ Parameters
105
+ ----------
106
+ transcribed_result: dict
107
+ Transcription result of the insanely_fast_whisper
108
+
109
+ Returns
110
+ ----------
111
+ result: List[dict]
112
+ Formatted result as the same with other implementation
113
+ """
114
+ result = transcribed_result["chunks"]
115
+ for item in result:
116
+ start, end = item["timestamp"][0], item["timestamp"][1]
117
+ item["start"] = start
118
+ item["end"] = end
119
+ return result
120
+