jhj0517 committed
Commit 131f180 · 1 Parent(s): 80e4171

Add resamples

Files changed (1):
  1. modules/whisper/whisper_base.py  +24 -1
modules/whisper/whisper_base.py CHANGED

@@ -2,6 +2,7 @@ import os
 import torch
 import whisper
 import gradio as gr
+import torchaudio
 from abc import ABC, abstractmethod
 from typing import BinaryIO, Union, Tuple, List
 import numpy as np
@@ -111,12 +112,19 @@ class WhisperBase(ABC):
 
         if params.is_bgm_separate:
             music, audio = self.music_separator.separate(
-                audio_file_path=audio,
+                audio=audio,
                 model_name=params.uvr_model_size,
                 device=params.uvr_device,
                 segment_size=params.uvr_segment_size,
+                save_file=params.uvr_save_file,
                 progress=progress
             )
+
+            if audio.ndim >= 2:
+                audio = audio.mean(axis=1)
+                origin_sample_rate = self.music_separator.audio_info.sample_rate
+                audio = self.resample_audio(audio=audio, original_sample_rate=origin_sample_rate)
+
             self.music_separator.offload()
 
         if params.vad_filter:
@@ -473,3 +481,18 @@ class WhisperBase(ABC):
         cached_yaml["whisper"]["add_timestamp"] = add_timestamp
 
         save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
+
+    @staticmethod
+    def resample_audio(audio: Union[str, np.ndarray],
+                       new_sample_rate: int = 16000,
+                       original_sample_rate: Optional[int] = None,) -> np.ndarray:
+        """Resamples audio to 16k sample rate, standard on Whisper model"""
+        if isinstance(audio, str):
+            audio, original_sample_rate = torchaudio.load(audio)
+        else:
+            if original_sample_rate is None:
+                raise ValueError("original_sample_rate must be provided when audio is numpy array.")
+            audio = torch.from_numpy(audio)
+        resampler = torchaudio.transforms.Resample(orig_freq=original_sample_rate, new_freq=new_sample_rate)
+        resampled_audio = resampler(audio).numpy()
+        return resampled_audio
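
Below is a minimal usage sketch for the new resample_audio helper; it is not part of the commit. The import path comes from the file changed above, while the 44.1 kHz stereo test signal is invented for illustration. It assumes Optional is imported from typing in whisper_base.py (the import hunk above only lists BinaryIO, Union, Tuple, List) and that the repo's dependencies (torch, torchaudio, whisper, gradio) are installed.

# Hypothetical usage of the helper added in this commit; run from the repo root.
import numpy as np
from modules.whisper.whisper_base import WhisperBase

# One second of fake stereo audio at 44.1 kHz, shaped (samples, channels),
# similar to what the BGM-separation branch above downmixes with mean(axis=1).
stereo_44k = np.random.rand(44100, 2).astype(np.float32)

mono_44k = stereo_44k.mean(axis=1)  # downmix to mono, mirroring the diff
mono_16k = WhisperBase.resample_audio(audio=mono_44k, original_sample_rate=44100)

print(mono_16k.shape)  # (16000,): one second resampled to the 16 kHz default

Passing a file path instead of an array also works with this code path: torchaudio.load supplies the original sample rate, so original_sample_rate can be omitted in that case.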