Whisper-WebUI / modules /whisper_base.py
jhj0517
add diarization logic
6e99075
raw
history blame
15.1 kB
import os
import torch
from typing import List
import whisperx
import whisper
import gradio as gr
from abc import ABC, abstractmethod
from typing import BinaryIO, Union, Tuple, List
import numpy as np
from datetime import datetime
from dataclasses import astuple
import time
from modules.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
from modules.youtube_manager import get_ytdata, get_ytaudio
from modules.whisper_parameter import *
class WhisperBase(ABC):
def __init__(self,
model_dir: str,
output_dir: str
):
self.model = None
self.current_model_size = None
self.model_dir = model_dir
self.diarization_model_dir = os.path.join(self.model_dir, "..", "whisperx")
self.output_dir = output_dir
os.makedirs(self.output_dir, exist_ok=True)
os.makedirs(self.model_dir, exist_ok=True)
os.makedirs(self.diarization_model_dir, exist_ok=True)
self.available_models = whisper.available_models()
self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
self.device = self.get_device()
self.available_compute_types = ["float16", "float32"]
self.current_compute_type = "float16" if self.device == "cuda" else "float32"
self.diarization_model = None
self.diarization_model_metadata = None
self.diarization_pipe = None
@abstractmethod
def transcribe(self,
audio: Union[str, BinaryIO, np.ndarray],
progress: gr.Progress,
*whisper_params,
):
pass
@abstractmethod
def update_model(self,
model_size: str,
compute_type: str,
progress: gr.Progress
):
pass
def run(self,
audio: Union[str, BinaryIO, np.ndarray],
progress: gr.Progress,
*whisper_params,
):
params = WhisperParameters.post_process(*whisper_params)
if params.lang == "Automatic Detection":
params.lang = None
else:
language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
params.lang = language_code_dict[params.lang]
result, elapsed_time = self.transcribe(
audio,
progress,
*whisper_params
)
if params.is_diarize:
if params.lang is None:
print("Diarization Failed!! You have to specify the language explicitly to use diarization")
else:
result, elapsed_time_diarization = self.diarize(
audio=audio,
language_code=params.lang,
use_auth_token=params.hf_token,
transcribed_result=result
)
elapsed_time += elapsed_time_diarization
return result, elapsed_time
def diarize(self,
audio: str,
language_code: str,
use_auth_token: str,
transcribed_result: List[dict]
):
start_time = time.time()
if (self.diarization_model is None or
self.diarization_model_metadata is None or
self.diarization_pipe is None):
self._update_diarization_model(
language_code=language_code,
use_auth_token=use_auth_token
)
audio = whisperx.load_audio(audio)
diarization_segments = self.diarization_pipe(audio)
diarized_result = whisperx.assign_word_speakers(
diarization_segments,
{"segments": transcribed_result}
)
for segment in diarized_result["segments"]:
speaker = "None"
if "speaker" in segment:
speaker = segment["speaker"]
segment["text"] = speaker + "|" + segment["text"][1:]
elapsed_time = time.time() - start_time
return diarized_result["segments"], elapsed_time
def _update_diarization_model(self,
use_auth_token: str,
language_code: str
):
print("loading diarization model...")
self.diarization_model, self.diarization_model_metadata = whisperx.load_align_model(
language_code=language_code,
device=self.device,
model_dir=self.diarization_model_dir,
)
self.diarization_pipe = whisperx.DiarizationPipeline(
use_auth_token=use_auth_token,
device=self.device
)
def transcribe_file(self,
files: list,
file_format: str,
add_timestamp: bool,
progress=gr.Progress(),
*whisper_params,
) -> list:
"""
Write subtitle file from Files
Parameters
----------
files: list
List of files to transcribe from gr.Files()
file_format: str
Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
add_timestamp: bool
Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
progress: gr.Progress
Indicator to show progress directly in gradio.
*whisper_params: tuple
Gradio components related to Whisper. see whisper_data_class.py for details.
Returns
----------
result_str:
Result of transcription to return to gr.Textbox()
result_file_path:
Output file path to return to gr.Files()
"""
try:
files_info = {}
for file in files:
transcribed_segments, time_for_task = self.run(
file.name,
progress,
*whisper_params,
)
file_name, file_ext = os.path.splitext(os.path.basename(file.name))
file_name = safe_filename(file_name)
subtitle, file_path = self.generate_and_write_file(
file_name=file_name,
transcribed_segments=transcribed_segments,
add_timestamp=add_timestamp,
file_format=file_format,
output_dir=self.output_dir
)
files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path}
total_result = ''
total_time = 0
for file_name, info in files_info.items():
total_result += '------------------------------------\n'
total_result += f'{file_name}\n\n'
total_result += f'{info["subtitle"]}'
total_time += info["time_for_task"]
result_str = f"Done in {self.format_time(total_time)}! Subtitle is in the outputs folder.\n\n{total_result}"
result_file_path = [info['path'] for info in files_info.values()]
return [result_str, result_file_path]
except Exception as e:
print(f"Error transcribing file: {e}")
finally:
self.release_cuda_memory()
if not files:
self.remove_input_files([file.name for file in files])
def transcribe_mic(self,
mic_audio: str,
file_format: str,
progress=gr.Progress(),
*whisper_params,
) -> list:
"""
Write subtitle file from microphone
Parameters
----------
mic_audio: str
Audio file path from gr.Microphone()
file_format: str
Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
progress: gr.Progress
Indicator to show progress directly in gradio.
*whisper_params: tuple
Gradio components related to Whisper. see whisper_data_class.py for details.
Returns
----------
result_str:
Result of transcription to return to gr.Textbox()
result_file_path:
Output file path to return to gr.Files()
"""
try:
progress(0, desc="Loading Audio..")
transcribed_segments, time_for_task = self.run(
mic_audio,
progress,
*whisper_params,
)
progress(1, desc="Completed!")
subtitle, result_file_path = self.generate_and_write_file(
file_name="Mic",
transcribed_segments=transcribed_segments,
add_timestamp=True,
file_format=file_format,
output_dir=self.output_dir
)
result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
return [result_str, result_file_path]
except Exception as e:
print(f"Error transcribing file: {e}")
finally:
self.release_cuda_memory()
self.remove_input_files([mic_audio])
def transcribe_youtube(self,
youtube_link: str,
file_format: str,
add_timestamp: bool,
progress=gr.Progress(),
*whisper_params,
) -> list:
"""
Write subtitle file from Youtube
Parameters
----------
youtube_link: str
URL of the Youtube video to transcribe from gr.Textbox()
file_format: str
Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
add_timestamp: bool
Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
progress: gr.Progress
Indicator to show progress directly in gradio.
*whisper_params: tuple
Gradio components related to Whisper. see whisper_data_class.py for details.
Returns
----------
result_str:
Result of transcription to return to gr.Textbox()
result_file_path:
Output file path to return to gr.Files()
"""
try:
progress(0, desc="Loading Audio from Youtube..")
yt = get_ytdata(youtube_link)
audio = get_ytaudio(yt)
transcribed_segments, time_for_task = self.run(
audio,
progress,
*whisper_params,
)
progress(1, desc="Completed!")
file_name = safe_filename(yt.title)
subtitle, result_file_path = self.generate_and_write_file(
file_name=file_name,
transcribed_segments=transcribed_segments,
add_timestamp=add_timestamp,
file_format=file_format,
output_dir=self.output_dir
)
result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
return [result_str, result_file_path]
except Exception as e:
print(f"Error transcribing file: {e}")
finally:
try:
if 'yt' not in locals():
yt = get_ytdata(youtube_link)
file_path = get_ytaudio(yt)
else:
file_path = get_ytaudio(yt)
self.release_cuda_memory()
self.remove_input_files([file_path])
except Exception as cleanup_error:
pass
@staticmethod
def generate_and_write_file(file_name: str,
transcribed_segments: list,
add_timestamp: bool,
file_format: str,
output_dir: str
) -> str:
"""
Writes subtitle file
Parameters
----------
file_name: str
Output file name
transcribed_segments: list
Text segments transcribed from audio
add_timestamp: bool
Determines whether to add a timestamp to the end of the filename.
file_format: str
File format to write. Supported formats: [SRT, WebVTT, txt]
output_dir: str
Directory path of the output
Returns
----------
content: str
Result of the transcription
output_path: str
output file path
"""
timestamp = datetime.now().strftime("%m%d%H%M%S")
if add_timestamp:
output_path = os.path.join(output_dir, f"{file_name}-{timestamp}")
else:
output_path = os.path.join(output_dir, f"{file_name}")
if file_format == "SRT":
content = get_srt(transcribed_segments)
output_path += '.srt'
write_file(content, output_path)
elif file_format == "WebVTT":
content = get_vtt(transcribed_segments)
output_path += '.vtt'
write_file(content, output_path)
elif file_format == "txt":
content = get_txt(transcribed_segments)
output_path += '.txt'
write_file(content, output_path)
return content, output_path
@staticmethod
def format_time(elapsed_time: float) -> str:
"""
Get {hours} {minutes} {seconds} time format string
Parameters
----------
elapsed_time: str
Elapsed time for transcription
Returns
----------
Time format string
"""
hours, rem = divmod(elapsed_time, 3600)
minutes, seconds = divmod(rem, 60)
time_str = ""
if hours:
time_str += f"{hours} hours "
if minutes:
time_str += f"{minutes} minutes "
seconds = round(seconds)
time_str += f"{seconds} seconds"
return time_str.strip()
@staticmethod
def get_device():
if torch.cuda.is_available():
return "cuda"
elif torch.backends.mps.is_available():
return "mps"
else:
return "cpu"
@staticmethod
def release_cuda_memory():
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
@staticmethod
def remove_input_files(file_paths: List[str]):
if not file_paths:
return
for file_path in file_paths:
if file_path and os.path.exists(file_path):
os.remove(file_path)