import gradio as gr
import whisper
import torch
import pyannote.audio
from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
from pyannote.audio import Audio
from pyannote.core import Segment
import subprocess
import wave
import contextlib
import os
from sklearn.cluster import AgglomerativeClustering
import numpy as np
import datetime

device = torch.device("cpu")

# Speaker-embedding model (ECAPA-TDNN trained on VoxCeleb); it maps each audio
# segment to a 192-dimensional vector used later for speaker clustering.
embedding_model = PretrainedSpeakerEmbedding(
    "speechbrain/spkrec-ecapa-voxceleb",
    device=device
)
audio_processor = Audio()
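
# Optional tweak, assuming a CUDA-capable GPU is available: the embedding model
# (and Whisper) can run on it by picking the device dynamically, e.g.
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")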


def process_audio(audio_file, num_speakers, model_size="medium", language="English"):
    # gr.Audio(type="filepath") passes the path of the uploaded file, so it can
    # be used directly instead of re-reading and re-saving the bytes.
    path = audio_file
    print(f"Processing audio file: {path}")

    # Convert to WAV with ffmpeg if the upload is in another format, since the
    # wave module and pyannote cropping below expect a WAV file.
    if not path.lower().endswith(".wav"):
        wav_path = os.path.splitext(path)[0] + ".wav"
        subprocess.call(["ffmpeg", "-i", path, wav_path, "-y"])
        path = wav_path
        print(f"Audio converted to: {path}")

    try:
        model = whisper.load_model(model_size)
        print("Whisper model loaded successfully.")
    except Exception as e:
        print(f"Error loading Whisper model: {e}")
        return f"Error loading Whisper model: {e}"

    try:
        result = model.transcribe(path)
        print(f"Transcription result: {result}")
    except Exception as e:
        print(f"Error during transcription: {e}")
        return f"Error during transcription: {e}"

    segments = result["segments"]

    # Read the file duration so segment end times can be clamped to it below.
    with contextlib.closing(wave.open(path, 'r')) as f:
        frames = f.getnframes()
        rate = f.getframerate()
        duration = frames / float(rate)

    def segment_embedding(segment):
        # Crop the audio to this segment and embed it; Whisper can report an
        # end time slightly past the file, so clamp to the real duration.
        start = segment["start"]
        end = min(duration, segment["end"])
        clip = Segment(start, end)
        waveform, sample_rate = audio_processor.crop(path, clip)
        return embedding_model(waveform[None])

    # One 192-dimensional speaker embedding per transcribed segment.
    embeddings = np.zeros(shape=(len(segments), 192))
    for i, segment in enumerate(segments):
        embeddings[i] = segment_embedding(segment)
    embeddings = np.nan_to_num(embeddings)

    # Group the segment embeddings into the requested number of speakers.
    # gr.Number can deliver a float, so cast before passing it to scikit-learn.
    clustering = AgglomerativeClustering(n_clusters=int(num_speakers)).fit(embeddings)
    labels = clustering.labels_
    for i in range(len(segments)):
        segments[i]["speaker"] = 'SPEAKER ' + str(labels[i] + 1)

    def time(secs):
        return str(datetime.timedelta(seconds=round(secs)))

    # Start a new speaker/timestamp header whenever the speaker changes, then
    # append the segment text (dropping Whisper's leading space).
    transcript = []
    for i, segment in enumerate(segments):
        if i == 0 or segments[i - 1]["speaker"] != segment["speaker"]:
            transcript.append(f"\n{segment['speaker']} {time(segment['start'])}")
        transcript.append(segment["text"][1:])

    return "\n".join(transcript)


def diarize(audio_file, num_speakers, model_size="medium"):
    return process_audio(audio_file, num_speakers, model_size)


interface = gr.Interface(
    fn=diarize,
    inputs=[
        gr.Audio(type="filepath", label="Upload Audio File"),
        gr.Number(label="Number of Speakers", value=2, precision=0),
        gr.Radio(["tiny", "base", "small", "medium", "large"], label="Model Size", value="medium")
    ],
    outputs=gr.Textbox(label="Transcript"),
    title="Speaker Diarization & Transcription",
    description="Upload an audio file, specify the number of speakers, and get a diarized transcript."
)


if __name__ == "__main__":
    interface.launch()
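
# For a quick test without the web UI, the pipeline can also be called directly;
# a minimal sketch ("meeting.wav" is a placeholder path used only for illustration):
#
#     transcript = process_audio("meeting.wav", num_speakers=2, model_size="base")
#     print(transcript)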