jhj0517 committed
Commit 6e99075 · 1 Parent(s): f1d9939

add diarization logic

Files changed (1):
  1. modules/whisper_base.py +91 -3
modules/whisper_base.py CHANGED
@@ -1,12 +1,15 @@
 import os
 import torch
 from typing import List
+import whisperx
 import whisper
 import gradio as gr
 from abc import ABC, abstractmethod
 from typing import BinaryIO, Union, Tuple, List
 import numpy as np
 from datetime import datetime
+from dataclasses import astuple
+import time
 
 from modules.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
 from modules.youtube_manager import get_ytdata, get_ytaudio
@@ -21,15 +24,20 @@ class WhisperBase(ABC):
         self.model = None
         self.current_model_size = None
         self.model_dir = model_dir
+        self.diarization_model_dir = os.path.join(self.model_dir, "..", "whisperx")
         self.output_dir = output_dir
         os.makedirs(self.output_dir, exist_ok=True)
         os.makedirs(self.model_dir, exist_ok=True)
+        os.makedirs(self.diarization_model_dir, exist_ok=True)
         self.available_models = whisper.available_models()
         self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
         self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
         self.device = self.get_device()
         self.available_compute_types = ["float16", "float32"]
         self.current_compute_type = "float16" if self.device == "cuda" else "float32"
+        self.diarization_model = None
+        self.diarization_model_metadata = None
+        self.diarization_pipe = None
 
     @abstractmethod
     def transcribe(self,
@@ -47,6 +55,86 @@ class WhisperBase(ABC):
                    ):
         pass
 
+    def run(self,
+            audio: Union[str, BinaryIO, np.ndarray],
+            progress: gr.Progress,
+            *whisper_params,
+            ):
+        params = WhisperParameters.post_process(*whisper_params)
+
+        if params.lang == "Automatic Detection":
+            params.lang = None
+        else:
+            language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
+            params.lang = language_code_dict[params.lang]
+
+        result, elapsed_time = self.transcribe(
+            audio,
+            progress,
+            *whisper_params
+        )
+
+        if params.is_diarize:
+            if params.lang is None:
+                print("Diarization Failed!! You have to specify the language explicitly to use diarization")
+            else:
+                result, elapsed_time_diarization = self.diarize(
+                    audio=audio,
+                    language_code=params.lang,
+                    use_auth_token=params.hf_token,
+                    transcribed_result=result
+                )
+                elapsed_time += elapsed_time_diarization
+        return result, elapsed_time
+
+    def diarize(self,
+                audio: str,
+                language_code: str,
+                use_auth_token: str,
+                transcribed_result: List[dict]
+                ):
+        start_time = time.time()
+
+        if (self.diarization_model is None or
+                self.diarization_model_metadata is None or
+                self.diarization_pipe is None):
+            self._update_diarization_model(
+                language_code=language_code,
+                use_auth_token=use_auth_token
+            )
+
+        audio = whisperx.load_audio(audio)
+        diarization_segments = self.diarization_pipe(audio)
+        diarized_result = whisperx.assign_word_speakers(
+            diarization_segments,
+            {"segments": transcribed_result}
+        )
+
+        for segment in diarized_result["segments"]:
+            speaker = "None"
+            if "speaker" in segment:
+                speaker = segment["speaker"]
+
+            segment["text"] = speaker + "|" + segment["text"][1:]
+
+        elapsed_time = time.time() - start_time
+        return diarized_result["segments"], elapsed_time
+
+    def _update_diarization_model(self,
+                                  use_auth_token: str,
+                                  language_code: str
+                                  ):
+        print("loading diarization model...")
+        self.diarization_model, self.diarization_model_metadata = whisperx.load_align_model(
+            language_code=language_code,
+            device=self.device,
+            model_dir=self.diarization_model_dir,
+        )
+        self.diarization_pipe = whisperx.DiarizationPipeline(
+            use_auth_token=use_auth_token,
+            device=self.device
+        )
+
     def transcribe_file(self,
                         files: list,
                         file_format: str,
@@ -80,7 +168,7 @@ class WhisperBase(ABC):
         try:
             files_info = {}
             for file in files:
-                transcribed_segments, time_for_task = self.transcribe(
+                transcribed_segments, time_for_task = self.run(
                     file.name,
                     progress,
                     *whisper_params,
@@ -146,7 +234,7 @@ class WhisperBase(ABC):
         """
         try:
             progress(0, desc="Loading Audio..")
-            transcribed_segments, time_for_task = self.transcribe(
+            transcribed_segments, time_for_task = self.run(
                 mic_audio,
                 progress,
                 *whisper_params,
@@ -204,7 +292,7 @@ class WhisperBase(ABC):
             yt = get_ytdata(youtube_link)
             audio = get_ytaudio(yt)
 
-            transcribed_segments, time_for_task = self.transcribe(
+            transcribed_segments, time_for_task = self.run(
                 audio,
                 progress,
                 *whisper_params,
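
Note: the net effect of this commit is that the file, microphone, and YouTube transcription paths now call run() instead of transcribe(); run() reassembles the Gradio parameter tuple, maps the display language name to a Whisper language code, and, when diarization is requested and a language is set, post-processes the segments with diarize(). The speaker-labelling loop at the end of diarize() is the subtlest step, so here is a self-contained sketch of just that loop. The sample segments below are invented stand-ins for the dicts that whisperx.assign_word_speakers() produces; Whisper segment text typically begins with a leading space, which is what the [1:] slice drops.

    # Illustrative stand-ins for diarized segments; real ones come from
    # whisperx.assign_word_speakers() and carry the same keys.
    segments = [
        {"start": 0.0, "end": 2.1, "text": " hello there", "speaker": "SPEAKER_00"},
        {"start": 2.1, "end": 3.0, "text": " hi"},  # no speaker was assigned
    ]

    for segment in segments:
        speaker = "None"  # fallback label when no speaker was assigned
        if "speaker" in segment:
            speaker = segment["speaker"]

        # Prefix the speaker tag and drop the leading space from the text.
        segment["text"] = speaker + "|" + segment["text"][1:]

    print(segments[0]["text"])  # -> SPEAKER_00|hello there
    print(segments[1]["text"])  # -> None|hi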