File size: 4,141 Bytes
d317550
80e4171
 
b14146c
d317550
 
 
da960ac
c5d0765
80e4171
d317550
 
 
 
 
 
 
 
 
 
e2bed65
d317550
 
b14146c
e2bed65
 
 
 
 
 
 
d317550
 
 
e2bed65
d317550
da960ac
 
 
 
 
 
 
 
e2bed65
 
 
 
 
 
 
 
 
 
d317550
 
e2bed65
d317550
 
80e4171
e2bed65
 
c5d0765
80e4171
c5d0765
b14146c
80e4171
 
 
 
 
 
 
 
 
770e792
e2bed65
 
 
 
 
da960ac
 
 
80e4171
 
c5d0765
e2bed65
 
 
 
 
5306f0a
d317550
c5d0765
80e4171
b7fa700
 
80e4171
 
 
 
 
d317550
80e4171
d317550
 
 
da960ac
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# Credit to Team UVR : https://github.com/Anjok07/ultimatevocalremovergui
from typing import Optional, Union
import numpy as np
import torchaudio
import soundfile as sf
import os
import torch
import gc
import gradio as gr
from datetime import datetime

from uvr.models import MDX, Demucs, VrNetwork, MDXC


class MusicSeparator:
    """Split audio into instrumental and vocal stems with a UVR MDX model.

    The model is loaded lazily by ``separate`` and reused across calls until
    the requested model name, segment size, sample rate or device changes.
    """

    def __init__(self,
                 model_dir: Optional[str] = None,
                 output_dir: Optional[str] = None):
        """
        Args:
            model_dir (str, optional): Directory forwarded to ``MDX`` as ``model_dir``.
            output_dir (str, optional): Root directory for stems written by
                ``separate(..., save_file=True)``.
        """
        self.model = None
        self.device = self.get_device()
        self.available_devices = ["cpu", "cuda"]
        self.model_dir = model_dir
        self.output_dir = output_dir
        # torchaudio metadata of the most recent file-path input (None until one is seen).
        self.audio_info = None
        self.available_models = ["UVR-MDX-NET-Inst_1", "UVR-MDX-NET-Inst_HQ_1"]
        self.default_model = self.available_models[0]
        # Name of the currently loaded model; kept in sync by update_model().
        self.current_model_size = self.default_model
        self.model_config = {
            "segment": 256,
            "split": True
        }

    def update_model(self,
                     model_name: str = "UVR-MDX-NET-Inst_1",
                     device: Optional[str] = None,
                     segment_size: int = 256):
        """
        Update model with the given model name

        Args:
            model_name (str): Model name.
            device (str): Device to use for the model. ``None`` keeps the current device.
            segment_size (int): Segment size for the prediction.
        """
        if device is None:
            device = self.device

        self.device = device
        self.model_config = {
            "segment": segment_size,
            "split": True
        }
        self.model = MDX(name=model_name,
                         other_metadata=self.model_config,
                         device=self.device,
                         logger=None,
                         model_dir=self.model_dir)
        # Record which model is loaded so separate() can reuse it instead of reloading.
        self.current_model_size = model_name

    def _needs_reload(self,
                      model_name: str,
                      model_config: dict,
                      sample_rate: int,
                      device: str) -> bool:
        """Return True when the cached model cannot serve the requested settings."""
        if self.model is None:
            return True
        return (self.current_model_size != model_name or
                self.model_config != model_config or
                # Compare against the rate the loaded model was configured with in
                # separate(); getattr guards a model that never had it assigned.
                getattr(self.model, "sample_rate", None) != sample_rate or
                self.device != device)

    def separate(self,
                 audio: Union[str, np.ndarray],
                 model_name: str,
                 device: Optional[str] = None,
                 segment_size: int = 256,
                 save_file: bool = False,
                 progress: gr.Progress = gr.Progress()):
        """
        Separate background music (instrumental) from vocals.

        Args:
            audio (str | np.ndarray): Path to an audio file, or an in-memory waveform.
            model_name (str): UVR model name to use.
            device (str, optional): Device for the model. ``None`` keeps the current device.
            segment_size (int): Segment size for the prediction.
            save_file (bool): If True, write both stems as WAV files under
                ``<output_dir>/instrumental`` and ``<output_dir>/vocals``
                (directories are created if missing).
            progress (gr.Progress): Gradio progress reporter.

        Returns:
            tuple: ``(instrumental, vocals)``, the model's ``"instrumental"`` and
            ``"vocals"`` outputs transposed (presumably to (samples, channels)
            for soundfile — confirm against MDX output).
        """
        if device is None:
            # Resolve first; comparing self.device against None would force a
            # model reload on every call.
            device = self.device

        if isinstance(audio, str):
            self.audio_info = torchaudio.info(audio)
            sample_rate = self.audio_info.sample_rate
            # Drop the directory and the original extension; stems are always WAV.
            output_filename = os.path.splitext(os.path.basename(audio))[0]
        else:
            # NOTE(review): in-memory arrays are assumed to be 16 kHz — confirm with callers.
            sample_rate = 16000
            timestamp = datetime.now().strftime("%m%d%H%M%S")
            output_filename = f"UVR-{timestamp}"
        ext = ".wav"

        model_config = {
            "segment": segment_size,
            "split": True
        }

        if self._needs_reload(model_name, model_config, sample_rate, device):
            progress(0, desc="Initializing UVR Model..")
            self.update_model(
                model_name=model_name,
                device=device,
                segment_size=segment_size
            )
            self.model.sample_rate = sample_rate

        progress(0, desc="Separating background music from the audio..")
        result = self.model(audio)
        instrumental, vocals = result["instrumental"].T, result["vocals"].T

        if save_file:
            instrumental_dir = os.path.join(self.output_dir, "instrumental")
            vocals_dir = os.path.join(self.output_dir, "vocals")
            # sf.write does not create parent directories.
            os.makedirs(instrumental_dir, exist_ok=True)
            os.makedirs(vocals_dir, exist_ok=True)
            instrumental_output_path = os.path.join(instrumental_dir, f"{output_filename}-instrumental{ext}")
            vocals_output_path = os.path.join(vocals_dir, f"{output_filename}-vocals{ext}")
            sf.write(instrumental_output_path, instrumental, sample_rate, format="WAV")
            sf.write(vocals_output_path, vocals, sample_rate, format="WAV")

        return instrumental, vocals

    @staticmethod
    def get_device():
        """Return ``"cuda"`` when torch reports CUDA is available, else ``"cpu"``."""
        return "cuda" if torch.cuda.is_available() else "cpu"

    def offload(self):
        """Drop the loaded model, clear the CUDA cache when on CUDA, and run the GC."""
        if self.model is not None:
            del self.model
            self.model = None
        if self.device == "cuda":
            torch.cuda.empty_cache()
        gc.collect()
        self.audio_info = None