# NOTE(review): removed web-scrape residue ("Spaces: / Running / Running" —
# Hugging Face Space status header); it is not part of the Python module.
import os
import time
from typing import BinaryIO, Union, Tuple, List

import numpy as np
import torch
import gradio as gr
import wget
from transformers import pipeline
from transformers.utils import is_flash_attn_2_available

from modules.whisper_base import WhisperBase
from modules.whisper_parameter import *
class InsanelyFastWhisperInference(WhisperBase):
    """
    Transcription backend wrapping the Hugging Face `transformers`
    automatic-speech-recognition pipeline (the "insanely fast whisper" recipe:
    fp16 inference, batched 30s chunking, flash-attention-2 when available).
    """

    def __init__(self):
        super().__init__(
            model_dir=os.path.join("models", "Whisper", "insanely_fast_whisper")
        )
        # This pipeline path only runs in half precision.
        self.available_compute_types = ["float16"]

    def transcribe(self,
                   audio: Union[str, np.ndarray, torch.Tensor],
                   progress: gr.Progress,
                   *whisper_params,
                   ) -> Tuple[List[dict], float]:
        """
        transcribe method for insanely_fast_whisper.

        Parameters
        ----------
        audio: Union[str, BinaryIO, np.ndarray]
            Audio path or file binary or Audio numpy array
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Gradio components related to Whisper. see whisper_data_class.py for details.

        Returns
        ----------
        segments_result: List[dict]
            list of dicts that includes start, end timestamps and transcribed text
        elapsed_time: float
            elapsed time for transcription
        """
        start_time = time.time()
        params = WhisperValues(*whisper_params)

        # (Re)load the model only when the requested size or compute type changed.
        if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
            self.update_model(params.model_size, params.compute_type, progress)

        # NOTE(review): the language is normalized here but never forwarded to the
        # pipeline call below, so detection is always automatic — confirm intent.
        if params.lang == "Automatic Detection":
            params.lang = None

        progress(0, desc="Transcribing...")
        segments = self.model(
            inputs=audio,
            chunk_length_s=30,
            batch_size=24,
            return_timestamps=True,
        )
        segments_result = self.format_result(
            transcribed_result=segments,
        )
        elapsed_time = time.time() - start_time
        return segments_result, elapsed_time

    def update_model(self,
                     model_size: str,
                     compute_type: str,
                     progress: gr.Progress,
                     ):
        """
        Update current model setting

        Parameters
        ----------
        model_size: str
            Size of whisper model
        compute_type: str
            Compute type for transcription.
            see more info : https://opennmt.net/CTranslate2/quantization.html
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        """
        progress(0, desc="Initializing Model..")
        model_path = os.path.join(self.model_dir, model_size)
        # Download the model files on first use (or when the directory is empty).
        if not os.path.isdir(model_path) or not os.listdir(model_path):
            self.download_model(
                model_size=model_size,
                download_root=model_path,
                progress=progress
            )

        self.current_compute_type = compute_type
        self.current_model_size = model_size
        self.model = pipeline(
            "automatic-speech-recognition",
            model=model_path,
            # BUGFIX: `torch_dtype` expects a torch.dtype; map the "float16"
            # setting string explicitly instead of passing it through.
            torch_dtype=torch.float16 if self.current_compute_type == "float16" else self.current_compute_type,
            device=self.device,
            # Prefer flash-attention-2 when the package is installed, otherwise
            # fall back to PyTorch scaled-dot-product attention.
            model_kwargs={"attn_implementation": "flash_attention_2"} if is_flash_attn_2_available() else {"attn_implementation": "sdpa"},
        )

    @staticmethod
    def format_result(
        transcribed_result: dict
    ) -> List[dict]:
        """
        Format the transcription result of insanely_fast_whisper as the same with other implementation.

        BUGFIX: declared @staticmethod — the method takes no `self`, yet is called
        as `self.format_result(transcribed_result=...)`, which previously bound the
        instance to `transcribed_result` and raised a TypeError.

        Parameters
        ----------
        transcribed_result: dict
            Transcription result of the insanely_fast_whisper

        Returns
        ----------
        result: List[dict]
            Formatted result as the same with other implementation
        """
        result = transcribed_result["chunks"]
        for item in result:
            # Flatten the (start, end) "timestamp" tuple into separate keys so the
            # output matches the other Whisper implementations in this project.
            start, end = item["timestamp"][0], item["timestamp"][1]
            item["start"] = start
            item["end"] = end
        return result

    @staticmethod
    def download_model(
        model_size: str,
        download_root: str,
        progress: gr.Progress
    ):
        """
        Download the Whisper model files from the openai/whisper-{size} HF repo.

        BUGFIX: declared @staticmethod — the method takes no `self`, yet is called
        as `self.download_model(model_size=...)`, which previously bound the
        instance to `model_size` and raised a TypeError.

        Parameters
        ----------
        model_size: str
            Size of whisper model; used to build the repository name.
        download_root: str
            Local directory the files are saved into (created if missing).
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        """
        progress(0, 'Initializing model..')
        print(f'Downloading {model_size} to "{download_root}"....')
        os.makedirs(download_root, exist_ok=True)
        download_list = [
            "model.safetensors",
            "config.json",
            "generation_config.json",
            "preprocessor_config.json",
            "tokenizer.json",
            "tokenizer_config.json",
            "added_tokens.json",
            "special_tokens_map.json",
            "vocab.json",
        ]
        download_host = f"https://huggingface.co/openai/whisper-{model_size}/resolve/main"
        for item in download_list:
            wget.download(
                download_host + "/" + item,
                download_root
            )