jhj0517 committed
Commit 48382b6 (unverified) · 2 parents: 7386da0 815f5df

Merge pull request #216 from jhj0517/feature/modularize-vad
app.py CHANGED
@@ -137,7 +137,7 @@ class App:
                 nb_chunk_length_s = gr.Number(label="Chunk Lengths (sec)", value=30, precision=0)
                 nb_batch_size = gr.Number(label="Batch Size", value=24, precision=0)
 
-            with gr.Accordion("VAD", open=False, visible=isinstance(self.whisper_inf, FasterWhisperInference)):
+            with gr.Accordion("VAD", open=False):
                 cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
                 sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5,
                                          info="Lower it to be more sensitive to small sounds.")
modules/vad/silero_vad.py CHANGED
@@ -2,9 +2,10 @@
 
 from faster_whisper.vad import VadOptions, get_vad_model
 import numpy as np
-from typing import BinaryIO, Union, List, Optional
+from typing import BinaryIO, Union, List, Optional, Tuple
 import warnings
 import faster_whisper
+from faster_whisper.transcribe import SpeechTimestampsMap, Segment
 import gradio as gr
 
 
@@ -17,7 +18,8 @@ class SileroVAD:
     def run(self,
             audio: Union[str, BinaryIO, np.ndarray],
             vad_parameters: VadOptions,
-            progress: gr.Progress = gr.Progress()):
+            progress: gr.Progress = gr.Progress()
+            ) -> Tuple[np.ndarray, List[dict]]:
         """
         Run VAD
 
@@ -32,8 +34,10 @@ class SileroVAD:
 
         Returns
         ----------
-        audio: np.ndarray
+        np.ndarray
             Pre-processed audio with VAD
+        List[dict]
+            Chunks of speeches to be used to restore the timestamps later
         """
 
         sampling_rate = self.sampling_rate
@@ -56,7 +60,7 @@
         audio = self.collect_chunks(audio, speech_chunks)
         duration_after_vad = audio.shape[0] / sampling_rate
 
-        return audio
+        return audio, speech_chunks
 
     def get_speech_timestamps(
         self,
@@ -241,3 +245,20 @@ class SileroVAD:
             f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
         )
 
+    def restore_speech_timestamps(
+            self,
+            segments: List[dict],
+            speech_chunks: List[dict],
+            sampling_rate: Optional[int] = None,
+    ) -> List[dict]:
+        if sampling_rate is None:
+            sampling_rate = self.sampling_rate
+
+        ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)
+
+        for segment in segments:
+            segment["start"] = ts_map.get_original_time(segment["start"])
+            segment["end"] = ts_map.get_original_time(segment["end"])
+
+        return segments
+
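Note: run() now returns the trimmed audio together with the detected speech chunks, and the new restore_speech_timestamps() uses those chunks to map segment times back onto the original recording. A hedged usage sketch (the audio path and the segment dict are placeholders; constructing SileroVAD() with no arguments is an assumption based on how the class is used elsewhere in this diff):

    from faster_whisper.vad import VadOptions
    from modules.vad.silero_vad import SileroVAD

    vad = SileroVAD()
    options = VadOptions(threshold=0.5)

    # Trim silence; keep the chunk map so timestamps can be restored later.
    audio, speech_chunks = vad.run(audio="sample.wav", vad_parameters=options)

    # ... transcribe `audio` here; suppose it yields one segment (placeholder):
    segments = [{"start": 0.0, "end": 2.5, "text": "hello"}]

    # Shift start/end back to positions in the original, untrimmed audio.
    segments = vad.restore_speech_timestamps(segments=segments,
                                             speech_chunks=speech_chunks)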
modules/whisper/faster_whisper_inference.py CHANGED
@@ -71,20 +71,6 @@ class FasterWhisperInference(WhisperBase):
         if not params.hotwords:
             params.hotwords = None
 
-        vad_options = None
-        if params.vad_filter:
-            # Explicit value set for float('inf') from gr.Number()
-            if params.max_speech_duration_s >= 9999:
-                params.max_speech_duration_s = float('inf')
-
-            vad_options = VadOptions(
-                threshold=params.threshold,
-                min_speech_duration_ms=params.min_speech_duration_ms,
-                max_speech_duration_s=params.max_speech_duration_s,
-                min_silence_duration_ms=params.min_silence_duration_ms,
-                speech_pad_ms=params.speech_pad_ms
-            )
-
         params.suppress_tokens = self.format_suppress_tokens_str(params.suppress_tokens)
 
         segments, info = self.model.transcribe(
@@ -115,8 +101,6 @@ class FasterWhisperInference(WhisperBase):
             language_detection_threshold=params.language_detection_threshold,
             language_detection_segments=params.language_detection_segments,
             prompt_reset_on_temperature=params.prompt_reset_on_temperature,
-            vad_filter=params.vad_filter,
-            vad_parameters=vad_options
         )
         progress(0, desc="Loading audio..")
 
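Note: the backend no longer builds VadOptions or forwards vad_filter/vad_parameters to faster-whisper; by the time audio reaches model.transcribe, WhisperBase has already applied Silero VAD. A hedged sketch of the reduced call against faster-whisper directly (model size, audio path, and beam_size are placeholders, not values from this repo):

    from faster_whisper import WhisperModel

    model = WhisperModel("large-v2")

    # The audio is assumed to be pre-trimmed by SileroVAD.run() upstream,
    # so no vad_filter / vad_parameters are passed here anymore.
    segments, info = model.transcribe("vad_trimmed.wav", beam_size=5)

    for segment in segments:
        print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text}")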
modules/whisper/whisper_base.py CHANGED
@@ -91,12 +91,38 @@ class WhisperBase(ABC):
         language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
         params.lang = language_code_dict[params.lang]
 
+        speech_chunks = None
+        if params.vad_filter:
+            # Explicit value set for float('inf') from gr.Number()
+            if params.max_speech_duration_s >= 9999:
+                params.max_speech_duration_s = float('inf')
+
+            vad_options = VadOptions(
+                threshold=params.threshold,
+                min_speech_duration_ms=params.min_speech_duration_ms,
+                max_speech_duration_s=params.max_speech_duration_s,
+                min_silence_duration_ms=params.min_silence_duration_ms,
+                speech_pad_ms=params.speech_pad_ms
+            )
+
+            audio, speech_chunks = self.vad.run(
+                audio=audio,
+                vad_parameters=vad_options,
+                progress=progress
+            )
+
         result, elapsed_time = self.transcribe(
             audio,
             progress,
             *astuple(params)
         )
 
+        if params.vad_filter:
+            result = self.vad.restore_speech_timestamps(
+                segments=result,
+                speech_chunks=speech_chunks,
+            )
+
         if params.is_diarize:
             result, elapsed_time_diarization = self.diarizer.run(
                 audio=audio,
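Note: the restore step exists because the transcriber sees only the VAD-trimmed audio, so its segment timestamps are offsets into the trimmed stream. SpeechTimestampsMap from faster_whisper (the same class restore_speech_timestamps uses above) adds back the silence removed before each chunk. A small worked example with made-up chunk boundaries, in samples at 16 kHz:

    from faster_whisper.transcribe import SpeechTimestampsMap

    # Two speech chunks kept by VAD (sample offsets into the original audio):
    speech_chunks = [
        {"start": 16000, "end": 48000},   # 1.0s-3.0s of the original
        {"start": 80000, "end": 112000},  # 5.0s-7.0s of the original
    ]
    ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate=16000)

    # In the trimmed audio, t=2.5s falls 0.5s into the second chunk (the
    # first chunk contributed 2.0s of kept speech). That chunk starts at
    # 5.0s in the original, so the restored time is 5.0 + 0.5 = 5.5s.
    print(ts_map.get_original_time(2.5))  # -> 5.5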