Spaces:
Running
Running
jhj0517
committed on
Commit
·
131f180
1
Parent(s):
80e4171
Add resamples
Browse files
modules/whisper/whisper_base.py
CHANGED
@@ -2,6 +2,7 @@ import os
|
|
2 |
import torch
|
3 |
import whisper
|
4 |
import gradio as gr
|
|
|
5 |
from abc import ABC, abstractmethod
|
6 |
from typing import BinaryIO, Union, Tuple, List
|
7 |
import numpy as np
|
@@ -111,12 +112,19 @@ class WhisperBase(ABC):
|
|
111 |
|
112 |
if params.is_bgm_separate:
|
113 |
music, audio = self.music_separator.separate(
|
114 |
-
|
115 |
model_name=params.uvr_model_size,
|
116 |
device=params.uvr_device,
|
117 |
segment_size=params.uvr_segment_size,
|
|
|
118 |
progress=progress
|
119 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
self.music_separator.offload()
|
121 |
|
122 |
if params.vad_filter:
|
@@ -473,3 +481,18 @@ class WhisperBase(ABC):
|
|
473 |
cached_yaml["whisper"]["add_timestamp"] = add_timestamp
|
474 |
|
475 |
save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import torch
|
3 |
import whisper
|
4 |
import gradio as gr
|
5 |
+
import torchaudio
|
6 |
from abc import ABC, abstractmethod
|
7 |
from typing import BinaryIO, Union, Tuple, List
|
8 |
import numpy as np
|
|
|
112 |
|
113 |
if params.is_bgm_separate:
|
114 |
music, audio = self.music_separator.separate(
|
115 |
+
audio=audio,
|
116 |
model_name=params.uvr_model_size,
|
117 |
device=params.uvr_device,
|
118 |
segment_size=params.uvr_segment_size,
|
119 |
+
save_file=params.uvr_save_file,
|
120 |
progress=progress
|
121 |
)
|
122 |
+
|
123 |
+
if audio.ndim >= 2:
|
124 |
+
audio = audio.mean(axis=1)
|
125 |
+
origin_sample_rate = self.music_separator.audio_info.sample_rate
|
126 |
+
audio = self.resample_audio(audio=audio, original_sample_rate=origin_sample_rate)
|
127 |
+
|
128 |
self.music_separator.offload()
|
129 |
|
130 |
if params.vad_filter:
|
|
|
481 |
cached_yaml["whisper"]["add_timestamp"] = add_timestamp
|
482 |
|
483 |
save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
|
484 |
+
|
485 |
+
@staticmethod
def resample_audio(audio: Union[str, np.ndarray],
                   new_sample_rate: int = 16000,
                   original_sample_rate: Union[int, None] = None) -> np.ndarray:
    """Resample audio to `new_sample_rate` (default 16 kHz, the rate Whisper models expect).

    Parameters
    ----------
    audio: Union[str, np.ndarray]
        Path to an audio file, or a waveform as a numpy array.
    new_sample_rate: int
        Target sample rate in Hz.
    original_sample_rate: Union[int, None]
        Sample rate of `audio` when it is a numpy array. Ignored when
        `audio` is a path (the rate is read from the file instead).

    Returns
    -------
    np.ndarray
        The resampled waveform.

    Raises
    ------
    ValueError
        If `audio` is a numpy array and `original_sample_rate` is None.
    """
    # NOTE(review): the original annotated this parameter as `Optional[int]`,
    # but `Optional` is never imported from `typing` in this file, which would
    # raise NameError when the class body is executed. `Union[int, None]`
    # (already imported) is the equivalent, working annotation.
    if isinstance(audio, str):
        # torchaudio.load returns (waveform_tensor, sample_rate); the file's
        # own rate overrides any caller-supplied original_sample_rate.
        audio, original_sample_rate = torchaudio.load(audio)
    else:
        if original_sample_rate is None:
            raise ValueError("original_sample_rate must be provided when audio is numpy array.")
        audio = torch.from_numpy(audio)
    resampler = torchaudio.transforms.Resample(orig_freq=original_sample_rate,
                                               new_freq=new_sample_rate)
    resampled_audio = resampler(audio).numpy()
    return resampled_audio
|