import os
import time
import numpy as np
import torch
from typing import BinaryIO, Union, Tuple, List
import faster_whisper
from faster_whisper.vad import VadOptions
import ast
import ctranslate2
import whisper
import gradio as gr
from argparse import Namespace

from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR)
from modules.whisper.whisper_parameter import *
from modules.whisper.whisper_base import WhisperBase

class FasterWhisperInference(WhisperBase):
    def __init__(self,
                 model_dir: str = FASTER_WHISPER_MODELS_DIR,
                 diarization_model_dir: str = DIARIZATION_MODELS_DIR,
                 uvr_model_dir: str = UVR_MODELS_DIR,
                 output_dir: str = OUTPUT_DIR,
                 ):
        super().__init__(
            model_dir=model_dir,
            diarization_model_dir=diarization_model_dir,
            uvr_model_dir=uvr_model_dir,
            output_dir=output_dir
        )
        self.model_dir = model_dir
        os.makedirs(self.model_dir, exist_ok=True)

        self.model_paths = self.get_model_paths()
        self.device = self.get_device()
        self.available_models = self.model_paths.keys()
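        # ctranslate2 reports which quantization types (e.g. int8, float16)
        # the selected device actually supports; see
        # https://opennmt.net/CTranslate2/quantization.html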
        self.available_compute_types = ctranslate2.get_supported_compute_types(
            "cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")

    def transcribe(self,
                   audio: Union[str, BinaryIO, np.ndarray],
                   progress: gr.Progress = gr.Progress(),
                   *whisper_params,
                   ) -> Tuple[List[dict], float]:
        """
        transcribe method for faster-whisper.

        Parameters
        ----------
        audio: Union[str, BinaryIO, np.ndarray]
            Audio path, file binary, or audio numpy array
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Parameters related to Whisper. These are handled by the WhisperParameters data class.

        Returns
        ----------
        segments_result: List[dict]
            List of dicts that include start/end timestamps and the transcribed text
        elapsed_time: float
            elapsed time for transcription
        """
        start_time = time.time()

        params = WhisperParameters.as_value(*whisper_params)
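        # Lazily (re)load the model: only rebuild it when the requested size
        # or compute type differs from what is currently loaded.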
        if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
            self.update_model(params.model_size, params.compute_type, progress)
        # None parameters with Textboxes: https://github.com/gradio-app/gradio/issues/8723
        if not params.initial_prompt:
            params.initial_prompt = None
        if not params.prefix:
            params.prefix = None
        if not params.hotwords:
            params.hotwords = None
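        # suppress_tokens arrives from the UI as a string such as "[-1]" and
        # is parsed into a List[int] by format_suppress_tokens_str() below.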
        params.suppress_tokens = self.format_suppress_tokens_str(params.suppress_tokens)
        segments, info = self.model.transcribe(
            audio=audio,
            language=params.lang,
            task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
            beam_size=params.beam_size,
            log_prob_threshold=params.log_prob_threshold,
            no_speech_threshold=params.no_speech_threshold,
            best_of=params.best_of,
            patience=params.patience,
            temperature=params.temperature,
            initial_prompt=params.initial_prompt,
            compression_ratio_threshold=params.compression_ratio_threshold,
            length_penalty=params.length_penalty,
            repetition_penalty=params.repetition_penalty,
            no_repeat_ngram_size=params.no_repeat_ngram_size,
            prefix=params.prefix,
            suppress_blank=params.suppress_blank,
            suppress_tokens=params.suppress_tokens,
            max_initial_timestamp=params.max_initial_timestamp,
            word_timestamps=params.word_timestamps,
            prepend_punctuations=params.prepend_punctuations,
            append_punctuations=params.append_punctuations,
            max_new_tokens=params.max_new_tokens,
            chunk_length=params.chunk_length,
            hallucination_silence_threshold=params.hallucination_silence_threshold,
            hotwords=params.hotwords,
            language_detection_threshold=params.language_detection_threshold,
            language_detection_segments=params.language_detection_segments,
            prompt_reset_on_temperature=params.prompt_reset_on_temperature,
        )
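        # faster-whisper returns a lazy generator: decoding only happens while
        # iterating over `segments` below, which is what makes the per-segment
        # progress updates possible.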
progress(0, desc="Loading audio..") | |
segments_result = [] | |
for segment in segments: | |
progress(segment.start / info.duration, desc="Transcribing..") | |
segments_result.append({ | |
"start": segment.start, | |
"end": segment.end, | |
"text": segment.text | |
}) | |
elapsed_time = time.time() - start_time | |
return segments_result, elapsed_time | |
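
    # Illustrative usage (a sketch, not part of the class): "sample.wav" and
    # `ui_params` are hypothetical, and `ui_params` must match the flat
    # argument order expected by WhisperParameters.as_value().
    #   inference = FasterWhisperInference()
    #   segments, elapsed = inference.transcribe("sample.wav", gr.Progress(), *ui_params)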

    def update_model(self,
                     model_size: str,
                     compute_type: str,
                     progress: gr.Progress = gr.Progress()
                     ):
        """
        Update current model setting

        Parameters
        ----------
        model_size: str
            Size of whisper model
        compute_type: str
            Compute type for transcription.
            See more info: https://opennmt.net/CTranslate2/quantization.html
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        """
progress(0, desc="Initializing Model..") | |
        self.current_model_size = self.model_paths[model_size]
        self.current_compute_type = compute_type
        self.model = faster_whisper.WhisperModel(
            device=self.device,
            model_size_or_path=self.current_model_size,
            download_root=self.model_dir,
            compute_type=self.current_compute_type
        )

    def get_model_paths(self):
        """
        Get available models from models path including fine-tuned model.

        Returns
        ----------
        Dict mapping model names to their paths (official model sizes map to their own names)
        """
        model_paths = {model: model for model in whisper.available_models()}
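        # huggingface_hub caches faster-whisper models under directories named
        # "models--Systran--faster-whisper-<size>"; stripping that prefix keeps
        # cached official models from being re-registered as fine-tuned ones.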
        faster_whisper_prefix = "models--Systran--faster-whisper-"

        existing_models = os.listdir(self.model_dir)
        wrong_dirs = [".locks"]
        existing_models = list(set(existing_models) - set(wrong_dirs))

        for model_name in existing_models:
            if faster_whisper_prefix in model_name:
                model_name = model_name[len(faster_whisper_prefix):]

            if model_name not in whisper.available_models():
                model_paths[model_name] = os.path.join(self.model_dir, model_name)
        return model_paths

    @staticmethod
    def get_device():
        if torch.cuda.is_available():
            return "cuda"
        else:
            return "auto"

    @staticmethod
    def format_suppress_tokens_str(suppress_tokens_str: str) -> List[int]:
        try:
            suppress_tokens = ast.literal_eval(suppress_tokens_str)
            if not isinstance(suppress_tokens, list) or not all(isinstance(item, int) for item in suppress_tokens):
                raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]")
            return suppress_tokens
        except Exception as e:
            raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]") from e