jhj0517 commited on
Commit
80e4171
·
1 Parent(s): 09fb62c

Update input type

Browse files
Files changed (1) hide show
  1. modules/uvr/music_separator.py +23 -16
modules/uvr/music_separator.py CHANGED
@@ -1,11 +1,13 @@
1
  # Credit to Team UVR : https://github.com/Anjok07/ultimatevocalremovergui
2
- from typing import Optional
 
3
  import torchaudio
4
  import soundfile as sf
5
  import os
6
  import torch
7
  import gc
8
  import gradio as gr
 
9
 
10
  from uvr.models import MDX, Demucs, VrNetwork, MDXC
11
 
@@ -55,21 +57,22 @@ class MusicSeparator:
55
  model_dir=self.model_dir)
56
 
57
  def separate(self,
58
- audio_file_path: str,
59
  model_name: str,
60
  device: Optional[str] = None,
61
  segment_size: int = 256,
 
62
  progress: gr.Progress = gr.Progress()):
63
- if device is None:
64
- device = self.device
65
-
66
- self.audio_info = torchaudio.info(audio_file_path)
67
- sample_rate = self.audio_info.sample_rate
68
 
69
- filename, ext = os.path.splitext(audio_file_path)
70
- filename, ext = os.path.basename(filename), ".wav"
71
- instrumental_output_path = os.path.join(self.output_dir, "instrumental", f"{filename}-instrumental{ext}")
72
- vocals_output_path = os.path.join(self.output_dir, "vocals", f"{filename}-vocals{ext}")
 
 
 
 
 
73
 
74
  model_config = {
75
  "segment": segment_size,
@@ -79,7 +82,8 @@ class MusicSeparator:
79
  if (self.model is None or
80
  self.current_model_size != model_name or
81
  self.model_config != model_config or
82
- self.audio_info.sample_rate != sample_rate):
 
83
  progress(0, desc="Initializing UVR Model..")
84
  self.update_model(
85
  model_name=model_name,
@@ -89,13 +93,16 @@ class MusicSeparator:
89
  self.model.sample_rate = sample_rate
90
 
91
  progress(0, desc="Separating background music from the audio..")
92
- result = self.model(audio_file_path)
93
  instrumental, vocals = result["instrumental"].T, result["vocals"].T
94
 
95
- sf.write(instrumental_output_path, instrumental, sample_rate, format="WAV")
96
- sf.write(vocals_output_path, vocals, sample_rate, format="WAV")
 
 
 
97
 
98
- return instrumental_output_path, vocals_output_path
99
 
100
  @staticmethod
101
  def get_device():
 
1
  # Credit to Team UVR : https://github.com/Anjok07/ultimatevocalremovergui
2
+ from typing import Optional, Union
3
+ import numpy as np
4
  import torchaudio
5
  import soundfile as sf
6
  import os
7
  import torch
8
  import gc
9
  import gradio as gr
10
+ from datetime import datetime
11
 
12
  from uvr.models import MDX, Demucs, VrNetwork, MDXC
13
 
 
57
  model_dir=self.model_dir)
58
 
59
  def separate(self,
60
+ audio: Union[str, np.ndarray],
61
  model_name: str,
62
  device: Optional[str] = None,
63
  segment_size: int = 256,
64
+ save_file: bool = False,
65
  progress: gr.Progress = gr.Progress()):
 
 
 
 
 
66
 
67
+ if isinstance(audio, str):
68
+ self.audio_info = torchaudio.info(audio)
69
+ sample_rate = self.audio_info.sample_rate
70
+ output_filename, ext = os.path.splitext(audio)
71
+ output_filename, ext = os.path.basename(audio), ".wav"
72
+ else:
73
+ sample_rate = 16000
74
+ timestamp = datetime.now().strftime("%m%d%H%M%S")
75
+ output_filename, ext = f"UVR-{timestamp}", ".wav"
76
 
77
  model_config = {
78
  "segment": segment_size,
 
82
  if (self.model is None or
83
  self.current_model_size != model_name or
84
  self.model_config != model_config or
85
+ self.audio_info.sample_rate != sample_rate or
86
+ self.device != device):
87
  progress(0, desc="Initializing UVR Model..")
88
  self.update_model(
89
  model_name=model_name,
 
93
  self.model.sample_rate = sample_rate
94
 
95
  progress(0, desc="Separating background music from the audio..")
96
+ result = self.model(audio)
97
  instrumental, vocals = result["instrumental"].T, result["vocals"].T
98
 
99
+ if save_file:
100
+ instrumental_output_path = os.path.join(self.output_dir, "instrumental", f"{output_filename}-instrumental{ext}")
101
+ vocals_output_path = os.path.join(self.output_dir, "vocals", f"{output_filename}-vocals{ext}")
102
+ sf.write(instrumental_output_path, instrumental, sample_rate, format="WAV")
103
+ sf.write(vocals_output_path, vocals, sample_rate, format="WAV")
104
 
105
+ return instrumental, vocals
106
 
107
  @staticmethod
108
  def get_device():