# Whisper-WebUI: modules/uvr/music_separator.py
# Last commit 299923f by jhj0517: "Fix file name" (file size 3.82 kB)
# Credit to Team UVR : https://github.com/Anjok07/ultimatevocalremovergui
from typing import Optional
import torchaudio
import soundfile as sf
import os
import torch
import gc
import gradio as gr
from uvr.models import MDX, Demucs, VrNetwork, MDXC
class MusicSeparator:
    """Split audio into instrumental and vocal stems using a UVR MDX model.

    The ``MDX`` model is created on demand by :meth:`separate` (or explicitly via
    :meth:`update_model`) and re-used across calls until the requested model name,
    device or segment size changes.
    """

    def __init__(self,
                 model_dir: Optional[str] = None,
                 output_dir: Optional[str] = None):
        """
        Args:
            model_dir (str, optional): Directory forwarded to ``MDX`` as ``model_dir``.
            output_dir (str, optional): Root output directory. Stems are written to
                ``<output_dir>/instrumental`` and ``<output_dir>/vocals``.
        """
        self.model = None
        self.device = self.get_device()
        self.available_devices = ["cpu", "cuda"]
        self.model_dir = model_dir
        self.output_dir = output_dir
        self.audio_info = None
        self.available_models = ["UVR-MDX-NET-Inst_1", "UVR-MDX-NET-Inst_HQ_1"]
        self.default_model = self.available_models[0]
        # Name of the model held in ``self.model``; kept in sync by ``update_model``.
        self.current_model_size = self.default_model
        self.model_config = self._make_model_config(256)

    @staticmethod
    def _make_model_config(segment_size: int) -> dict:
        """Build the ``other_metadata`` dict passed to ``MDX`` for the given segment size."""
        return {
            "segment": segment_size,
            "split": True,
        }

    def update_model(self,
                     model_name: str = "UVR-MDX-NET-Inst_1",
                     device: Optional[str] = None,
                     segment_size: int = 256):
        """
        Update model with the given model name

        Args:
            model_name (str): Model name.
            device (str): Device to use for the model. Defaults to the current device.
            segment_size (int): Segment size for the prediction.
        """
        if device is None:
            device = self.device
        self.device = device
        self.model_config = self._make_model_config(segment_size)
        self.model = MDX(name=model_name,
                         other_metadata=self.model_config,
                         device=self.device,
                         logger=None,
                         model_dir=self.model_dir)
        # Remember which model is loaded so ``separate`` can tell when a different
        # one is requested (previously this was never updated).
        self.current_model_size = model_name

    def separate(self,
                 audio_file_path: str,
                 model_name: str,
                 device: Optional[str] = None,
                 segment_size: int = 256,
                 progress: gr.Progress = gr.Progress()):
        """Separate an audio file into instrumental and vocal WAV files.

        Args:
            audio_file_path (str): Path of the input audio file.
            model_name (str): Name of the MDX model to use.
            device (str, optional): Device to run on. Defaults to the current device.
            segment_size (int): Segment size for the prediction.
            progress (gr.Progress): Gradio progress tracker.

        Returns:
            tuple[str, str]: Paths of the written instrumental and vocals WAV files,
            named ``<input stem>-instrumental.wav`` / ``<input stem>-vocals.wav``.
        """
        if device is None:
            device = self.device

        self.audio_info = torchaudio.info(audio_file_path)
        sample_rate = self.audio_info.sample_rate

        filename = os.path.splitext(os.path.basename(audio_file_path))[0]
        ext = ".wav"
        instrumental_dir = os.path.join(self.output_dir, "instrumental")
        vocals_dir = os.path.join(self.output_dir, "vocals")
        instrumental_output_path = os.path.join(instrumental_dir, f"{filename}-instrumental{ext}")
        vocals_output_path = os.path.join(vocals_dir, f"{filename}-vocals{ext}")

        model_config = self._make_model_config(segment_size)
        # Reload only when nothing is loaded or a load-time setting differs.
        # The sample rate is not part of this check: it is assigned to the model
        # on every call just below.
        if (self.model is None or
                self.current_model_size != model_name or
                self.model_config != model_config or
                self.device != device):
            progress(0, desc="Initializing UVR Model..")
            self.update_model(
                model_name=model_name,
                device=device,
                segment_size=segment_size
            )
        self.model.sample_rate = sample_rate

        progress(0, desc="Separating background music from the audio..")
        result = self.model(audio_file_path)
        # Transposed before writing; soundfile expects (frames, channels), so this
        # presumably converts a (channels, samples) model output - TODO confirm.
        instrumental, vocals = result["instrumental"].T, result["vocals"].T

        # The per-stem folders may not exist yet; sf.write does not create them.
        os.makedirs(instrumental_dir, exist_ok=True)
        os.makedirs(vocals_dir, exist_ok=True)
        sf.write(instrumental_output_path, instrumental, sample_rate, format="WAV")
        sf.write(vocals_output_path, vocals, sample_rate, format="WAV")
        return instrumental_output_path, vocals_output_path

    @staticmethod
    def get_device():
        """Return ``"cuda"`` when a CUDA device is available, otherwise ``"cpu"``."""
        return "cuda" if torch.cuda.is_available() else "cpu"

    def offload(self):
        """Drop the loaded model and cached audio info, then release freed memory."""
        if self.model is not None:
            del self.model
            self.model = None
        # Collect first so tensors only reachable through reference cycles are freed
        # before the CUDA caching allocator is asked to return unused blocks.
        gc.collect()
        if str(self.device).startswith("cuda"):
            torch.cuda.empty_cache()
        self.audio_info = None