# Whisper-WebUI: modules/uvr/music_separator.py
# Credit to Team UVR : https://github.com/Anjok07/ultimatevocalremovergui
from typing import Optional, Union
import numpy as np
import torchaudio
import soundfile as sf
import os
import torch
import gc
import gradio as gr
from datetime import datetime
from uvr.models import MDX, Demucs, VrNetwork, MDXC


class MusicSeparator:
def __init__(self,
model_dir: Optional[str] = None,
output_dir: Optional[str] = None):
self.model = None
self.device = self.get_device()
self.available_devices = ["cpu", "cuda"]
self.model_dir = model_dir
self.output_dir = output_dir
self.audio_info = None
self.available_models = ["UVR-MDX-NET-Inst_1", "UVR-MDX-NET-Inst_HQ_1"]
self.default_model = self.available_models[0]
self.current_model_size = self.default_model
self.model_config = {
"segment": 256,
"split": True
}

    def update_model(self,
model_name: str = "UVR-MDX-NET-Inst_1",
device: Optional[str] = None,
segment_size: int = 256):
"""
Update model with the given model name
Args:
model_name (str): Model name.
device (str): Device to use for the model.
segment_size (int): Segment size for the prediction.
"""
        if device is None:
            device = self.device

        self.device = device
        # Remember which model is loaded so separate() can tell when a reload is needed.
        self.current_model_size = model_name
        self.model_config = {
            "segment": segment_size,
            "split": True
        }
        self.model = MDX(name=model_name,
                         other_metadata=self.model_config,
                         device=self.device,
                         logger=None,
                         model_dir=self.model_dir)

    def separate(self,
                 audio: Union[str, np.ndarray],
                 model_name: str,
                 device: Optional[str] = None,
                 segment_size: int = 256,
                 save_file: bool = False,
                 progress: gr.Progress = gr.Progress()):
        """
        Separate the audio into an instrumental track and a vocal track.
        Accepts either a file path or a numpy array of audio samples.
        Returns the (instrumental, vocals) pair as numpy arrays.
        """
        if isinstance(audio, str):
            self.audio_info = torchaudio.info(audio)
            sample_rate = self.audio_info.sample_rate
            # Name the output files after the input file, replacing its extension with .wav.
            output_filename, ext = os.path.splitext(os.path.basename(audio))[0], ".wav"
        else:
            # Raw arrays carry no metadata, so assume a 16 kHz sample rate.
            self.audio_info = None
            sample_rate = 16000
            timestamp = datetime.now().strftime("%m%d%H%M%S")
            output_filename, ext = f"UVR-{timestamp}", ".wav"
model_config = {
"segment": segment_size,
"split": True
}
        # Reload the model when the model name, config, target sample rate, or device changes.
        if (self.model is None or
                self.current_model_size != model_name or
                self.model_config != model_config or
                getattr(self.model, "sample_rate", None) != sample_rate or
                (device is not None and self.device != device)):
progress(0, desc="Initializing UVR Model..")
self.update_model(
model_name=model_name,
device=device,
segment_size=segment_size
)
self.model.sample_rate = sample_rate
progress(0, desc="Separating background music from the audio..")
result = self.model(audio)
        # Transpose the separated tracks so soundfile receives (frames, channels).
        instrumental, vocals = result["instrumental"].T, result["vocals"].T
        if save_file:
            # Make sure the per-track output directories exist before writing.
            instrumental_dir = os.path.join(self.output_dir, "instrumental")
            vocals_dir = os.path.join(self.output_dir, "vocals")
            os.makedirs(instrumental_dir, exist_ok=True)
            os.makedirs(vocals_dir, exist_ok=True)
            instrumental_output_path = os.path.join(instrumental_dir, f"{output_filename}-instrumental{ext}")
            vocals_output_path = os.path.join(vocals_dir, f"{output_filename}-vocals{ext}")
            sf.write(instrumental_output_path, instrumental, sample_rate, format="WAV")
            sf.write(vocals_output_path, vocals, sample_rate, format="WAV")
return instrumental, vocals

    @staticmethod
    def get_device():
        """Return "cuda" if a CUDA-capable GPU is available, otherwise "cpu"."""
        return "cuda" if torch.cuda.is_available() else "cpu"

    def offload(self):
        """Release the loaded model and free up memory."""
        if self.model is not None:
            del self.model
            self.model = None
if self.device == "cuda":
torch.cuda.empty_cache()
gc.collect()
self.audio_info = None
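

if __name__ == "__main__":
    # Minimal usage sketch, not part of the Whisper-WebUI app itself. It assumes the
    # `uvr` package and the MDX model weights are available, and that "models/uvr",
    # "outputs/uvr", and "sample.wav" are hypothetical paths you would replace with
    # your own. The default gr.Progress() argument is used as-is; inside the app this
    # method is normally driven from a Gradio event.
    separator = MusicSeparator(model_dir="models/uvr", output_dir="outputs/uvr")
    instrumental, vocals = separator.separate(
        audio="sample.wav",
        model_name=separator.default_model,
        segment_size=256,
        save_file=False,
    )
    print(f"Instrumental shape: {instrumental.shape}, vocals shape: {vocals.shape}")
    # Free the model and GPU memory once separation is done.
    separator.offload()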