jhj0517 committed on
Commit
6ff3ca6
·
1 Parent(s): 3ec9a9b

Add music separation pre-process to whisper base

Browse files
Files changed (1) hide show
  1. modules/whisper/whisper_base.py +17 -2
modules/whisper/whisper_base.py CHANGED
@@ -9,7 +9,9 @@ from datetime import datetime
9
  from faster_whisper.vad import VadOptions
10
  from dataclasses import astuple
11
 
12
- from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH)
 
 
13
  from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
14
  from modules.utils.youtube_manager import get_ytdata, get_ytaudio
15
  from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
@@ -22,6 +24,7 @@ class WhisperBase(ABC):
22
  def __init__(self,
23
  model_dir: str = WHISPER_MODELS_DIR,
24
  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
 
25
  output_dir: str = OUTPUT_DIR,
26
  ):
27
  self.model_dir = model_dir
@@ -32,6 +35,10 @@ class WhisperBase(ABC):
32
  model_dir=diarization_model_dir
33
  )
34
  self.vad = SileroVAD()
 
 
 
 
35
 
36
  self.model = None
37
  self.current_model_size = None
@@ -102,7 +109,15 @@ class WhisperBase(ABC):
102
  language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
103
  params.lang = language_code_dict[params.lang]
104
 
105
- speech_chunks = None
 
 
 
 
 
 
 
 
106
  if params.vad_filter:
107
  # Explicit value set for float('inf') from gr.Number()
108
  if params.max_speech_duration_s >= 9999:
 
9
  from faster_whisper.vad import VadOptions
10
  from dataclasses import astuple
11
 
12
+ from modules.uvr.music_separator import MusicSeparator
13
+ from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
14
+ UVR_MODELS_DIR)
15
  from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
16
  from modules.utils.youtube_manager import get_ytdata, get_ytaudio
17
  from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
 
24
  def __init__(self,
25
  model_dir: str = WHISPER_MODELS_DIR,
26
  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
27
+ uvr_model_dir: str = UVR_MODELS_DIR,
28
  output_dir: str = OUTPUT_DIR,
29
  ):
30
  self.model_dir = model_dir
 
35
  model_dir=diarization_model_dir
36
  )
37
  self.vad = SileroVAD()
38
+ self.music_separator = MusicSeparator(
39
+ model_dir=uvr_model_dir,
40
+ output_dir=os.path.join(output_dir, "UVR")
41
+ )
42
 
43
  self.model = None
44
  self.current_model_size = None
 
109
  language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
110
  params.lang = language_code_dict[params.lang]
111
 
112
+ if params.is_bgm_separate:
113
+ music, audio = self.music_separator.separate(
114
+ audio_file_path=audio,
115
+ model_name=params.uvr_model_size,
116
+ device=params.uvr_device,
117
+ segment_size=params.uvr_segment_size,
118
+ )
119
+ self.music_separator.offload()
120
+
121
  if params.vad_filter:
122
  # Explicit value set for float('inf') from gr.Number()
123
  if params.max_speech_duration_s >= 9999: