File size: 4,770 Bytes
80e4171
 
b14146c
d317550
 
 
da960ac
c5d0765
80e4171
d317550
 
 
 
 
 
 
 
 
 
e2bed65
d317550
 
b14146c
7baa848
e2bed65
 
 
 
 
 
d317550
 
 
e2bed65
d317550
da960ac
 
 
 
 
 
 
 
e2bed65
 
 
 
 
 
 
 
 
 
d317550
 
e2bed65
d317550
 
80e4171
e2bed65
 
c5d0765
80e4171
8fd7f62
 
 
b14146c
8fd7f62
 
 
 
 
 
 
 
 
 
 
80e4171
 
 
 
 
 
 
 
 
770e792
e2bed65
 
 
 
 
da960ac
 
 
80e4171
 
c5d0765
e2bed65
 
 
 
 
5306f0a
d317550
c5d0765
80e4171
b7fa700
 
80e4171
 
 
 
 
d317550
80e4171
d317550
 
 
8fd7f62
da960ac
 
 
8fd7f62
da960ac
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from typing import Optional, Union
import numpy as np
import torchaudio
import soundfile as sf
import os
import torch
import gc
import gradio as gr
from datetime import datetime

from uvr.models import MDX, Demucs, VrNetwork, MDXC


class MusicSeparator:
    """Separate vocals from the instrumental track using a UVR MDX model.

    The loaded model is cached on the instance and only re-created when the
    requested model name, segment size, sample rate or device changes.
    """

    def __init__(self,
                 model_dir: Optional[str] = None,
                 output_dir: Optional[str] = None):
        """
        Args:
            model_dir (str): Directory forwarded to ``MDX`` as ``model_dir``.
            output_dir (str): Root directory under which ``separate`` writes
                ``instrumental/`` and ``vocals/`` files when ``save_file=True``.
        """
        self.model = None
        self.device = self.get_device()
        self.available_devices = ["cpu", "cuda"]
        self.model_dir = model_dir
        self.output_dir = output_dir
        self.audio_info = None
        self.available_models = ["UVR-MDX-NET-Inst_HQ_4", "UVR-MDX-NET-Inst_3"]
        self.default_model = self.available_models[0]
        # Name of the model held in ``self.model``; ``separate`` compares it
        # against the requested name to decide whether a reload is needed.
        self.current_model_size = self.default_model
        self.model_config = {
            "segment": 256,
            "split": True
        }

    def update_model(self,
                     model_name: str = "UVR-MDX-NET-Inst_1",
                     device: Optional[str] = None,
                     segment_size: int = 256):
        """
        Load (or reload) the MDX model with the given settings.

        Args:
            model_name (str): Model name.
            device (str): Device to use for the model. Defaults to the
                instance's current device when ``None``.
            segment_size (int): Segment size for the prediction.
        """
        if device is None:
            device = self.device

        self.device = device
        self.model_config = {
            "segment": segment_size,
            "split": True
        }
        self.model = MDX(name=model_name,
                         other_metadata=self.model_config,
                         device=self.device,
                         logger=None,
                         model_dir=self.model_dir)
        # Record which model is loaded so ``separate`` can reuse it instead
        # of rebuilding it on every call.
        self.current_model_size = model_name

    def separate(self,
                 audio: Union[str, np.ndarray],
                 model_name: str,
                 device: Optional[str] = None,
                 segment_size: int = 256,
                 save_file: bool = False,
                 progress: gr.Progress = gr.Progress()) -> tuple[np.ndarray, np.ndarray]:
        """
        Separate the background music from the audio.

        Args:
            audio (Union[str, np.ndarray]): Audio path or numpy array.
            model_name (str): Model name.
            device (str): Device to use for the model. Defaults to the
                instance's current device when ``None``.
            segment_size (int): Segment size for the prediction.
            save_file (bool): Whether to save the separated audio to output path or not.
            progress (gr.Progress): Gradio progress indicator.

        Returns:
            tuple[np.ndarray, np.ndarray]: Instrumental and vocals numpy arrays,
            transposed relative to what the model returns (presumably to
            (samples, channels) as expected by soundfile — TODO confirm).

        Raises:
            ValueError: If ``save_file`` is True but no ``output_dir`` was set.
        """
        # Resolve the device up front so the cache check below compares real
        # device names; comparing against ``None`` would force a reload every call.
        if device is None:
            device = self.device

        if isinstance(audio, str):
            self.audio_info = torchaudio.info(audio)
            sample_rate = self.audio_info.sample_rate
            # Drop the directory and the original extension; outputs are WAV.
            output_filename = os.path.splitext(os.path.basename(audio))[0]
            ext = ".wav"
        else:
            # NOTE(review): raw arrays are assumed to be 16 kHz — confirm with callers.
            sample_rate = 16000
            timestamp = datetime.now().strftime("%m%d%H%M%S")
            output_filename, ext = f"UVR-{timestamp}", ".wav"

        model_config = {
            "segment": segment_size,
            "split": True
        }

        # Compare against the sample rate the loaded model was configured with,
        # not ``self.audio_info`` (just overwritten above, and None for arrays).
        if (self.model is None or
                self.current_model_size != model_name or
                self.model_config != model_config or
                getattr(self.model, "sample_rate", None) != sample_rate or
                self.device != device):
            progress(0, desc="Initializing UVR Model..")
            self.update_model(
                model_name=model_name,
                device=device,
                segment_size=segment_size
            )
            self.model.sample_rate = sample_rate

        progress(0, desc="Separating background music from the audio..")
        result = self.model(audio)
        instrumental, vocals = result["instrumental"].T, result["vocals"].T

        if save_file:
            if self.output_dir is None:
                raise ValueError("output_dir must be set to save separated audio")
            instrumental_dir = os.path.join(self.output_dir, "instrumental")
            vocals_dir = os.path.join(self.output_dir, "vocals")
            # sf.write does not create parent directories.
            os.makedirs(instrumental_dir, exist_ok=True)
            os.makedirs(vocals_dir, exist_ok=True)
            instrumental_output_path = os.path.join(instrumental_dir, f"{output_filename}-instrumental{ext}")
            vocals_output_path = os.path.join(vocals_dir, f"{output_filename}-vocals{ext}")
            sf.write(instrumental_output_path, instrumental, sample_rate, format="WAV")
            sf.write(vocals_output_path, vocals, sample_rate, format="WAV")

        return instrumental, vocals

    @staticmethod
    def get_device():
        """Return ``"cuda"`` if a CUDA device is available, else ``"cpu"``."""
        return "cuda" if torch.cuda.is_available() else "cpu"

    def offload(self):
        """Drop the loaded model and release cached memory.

        Clears the CUDA cache when running on CUDA, runs the garbage
        collector, and resets the cached audio info.
        """
        if self.model is not None:
            del self.model
            self.model = None
        if self.device == "cuda":
            torch.cuda.empty_cache()
        gc.collect()
        self.audio_info = None