import contextlib
import datetime
import subprocess
import wave

import gradio as gr
import numpy as np
import torch
import whisper
from pyannote.audio import Audio
from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
from pyannote.core import Segment
from sklearn.cluster import AgglomerativeClustering
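
# Assumed dependencies: openai-whisper, pyannote.audio, speechbrain, scikit-learn,
# gradio, torch, numpy, plus an ffmpeg binary available on PATH.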
# Load models
device = torch.device("cpu")  # explicitly run on CPU
embedding_model = PretrainedSpeakerEmbedding(
    "speechbrain/spkrec-ecapa-voxceleb",
    device=device,  # ensure the embedding model uses CPU
)
audio_processor = Audio()
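
# Note: speechbrain/spkrec-ecapa-voxceleb maps a waveform to a 192-dimensional
# speaker embedding, which is why the embeddings matrix below has 192 columns.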

# Process the audio file and return a diarized transcript
def process_audio(audio_file, num_speakers, model_size="medium", language="English"):
    # With gr.Audio(type="filepath"), audio_file is a path string, not a file
    # object, so it can be handed to ffmpeg directly.
    path = "/tmp/uploaded_audio.wav"
    # Re-encode to 16 kHz mono PCM WAV so the wave module and pyannote can read it
    subprocess.call(['ffmpeg', '-i', audio_file, '-ar', '16000', '-ac', '1', path, '-y'])
    print(f"Audio converted and saved to: {path}")
    # Load the Whisper model
    try:
        model = whisper.load_model(model_size)
        print("Whisper model loaded successfully.")
    except Exception as e:
        print(f"Error loading Whisper model: {e}")
        return f"Error loading Whisper model: {e}"

    # Transcribe the audio (Whisper accepts full language names such as "English")
    try:
        result = model.transcribe(path, language=language)
        print(f"Transcription result: {result}")
    except Exception as e:
        print(f"Error during transcription: {e}")
        return f"Error during transcription: {e}"
    segments = result["segments"]

    # Get the audio duration from the WAV header
    with contextlib.closing(wave.open(path, 'r')) as f:
        frames = f.getnframes()
        rate = f.getframerate()
        duration = frames / float(rate)
    # Compute a speaker embedding for a single transcript segment
    def segment_embedding(segment):
        start = segment["start"]
        # Whisper timestamps can overshoot the end of the file
        end = min(duration, segment["end"])
        clip = Segment(start, end)
        waveform, sample_rate = audio_processor.crop(path, clip)
        # Add a batch dimension: (channel, time) -> (batch, channel, time)
        return embedding_model(waveform[None])
    embeddings = np.zeros(shape=(len(segments), 192))
    for i, segment in enumerate(segments):
        embeddings[i] = segment_embedding(segment)
    embeddings = np.nan_to_num(embeddings)

    # Cluster the segment embeddings into num_speakers groups
    clustering = AgglomerativeClustering(n_clusters=int(num_speakers)).fit(embeddings)
    labels = clustering.labels_
    for i in range(len(segments)):
        segments[i]["speaker"] = 'SPEAKER ' + str(labels[i] + 1)
    # Format the transcript: emit a header line whenever the speaker changes
    def time(secs):
        return str(datetime.timedelta(seconds=round(secs)))

    transcript = []
    for i, segment in enumerate(segments):
        if i == 0 or segments[i - 1]["speaker"] != segment["speaker"]:
            transcript.append(f"\n{segment['speaker']} {time(segment['start'])}")
        transcript.append(segment["text"].strip())  # trim surrounding whitespace

    # Return the final transcript as a single string
    return "\n".join(transcript)
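
# Illustrative example of the returned transcript format (made-up content):
#
#   SPEAKER 1 0:00:00
#   Hello, and thanks for joining today.
#
#   SPEAKER 2 0:00:06
#   Glad to be here.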

# Thin wrapper mapping the Gradio inputs onto process_audio
def diarize(audio_file, num_speakers, model_size="medium"):
    return process_audio(audio_file, num_speakers, model_size)

# Gradio UI
interface = gr.Interface(
    fn=diarize,
    inputs=[
        gr.Audio(type="filepath", label="Upload Audio File"),  # delivers a file path string
        gr.Number(label="Number of Speakers", value=2, precision=0),
        gr.Radio(["tiny", "base", "small", "medium", "large"], label="Model Size", value="medium")
    ],
    outputs=gr.Textbox(label="Transcript"),
    title="Speaker Diarization & Transcription",
    description="Upload an audio file, specify the number of speakers, and get a diarized transcript."
)
# Run the Gradio app
if __name__ == "__main__":
    interface.launch()
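
# Note: on Hugging Face Spaces the plain launch() above is sufficient; running
# locally, interface.launch(share=True) would additionally create a public link.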