Spaces:
Running
Running
from smolagents import Tool | |
import os | |
import tempfile | |
import shutil | |
import torch | |
import subprocess | |
from transcription import run_whisper_transcription | |
from logging_config import logger, log_buffer | |
from ffmpeg_setup import ensure_ffmpeg_in_path | |
class TranscriptTool(Tool):
    """Agent tool that converts a media file to WAV and transcribes it.

    The file is located (optionally by searching ``audio_directory``),
    converted with ffmpeg to 16 kHz mono WAV inside a temporary directory,
    and passed to ``run_whisper_transcription``. Failures are reported as
    strings rather than raised, so ``forward`` always returns text.
    """

    name = "TranscriptTool"
    description = """
    A smolagent tool for transcribing audio and video files into text. This tool utilises Whisper for transcription
    and ffmpeg for media conversion, enabling agents to process multimedia inputs into text. The tool supports robust
    file handling, including format conversion to WAV and dynamic device selection for optimal performance.
    """
    inputs = {
        "file_path": {
            "type": "string",
            "description": "Path to the audio or video file for transcription."
        }
    }
    output_type = "string"

    def __init__(self, audio_directory=None):
        """Create the tool.

        Args:
            audio_directory: Directory searched recursively when the path
                given to ``forward`` does not exist. Defaults to the current
                working directory at construction time.
        """
        super().__init__()
        # Presumably makes the ffmpeg binary discoverable via PATH; the
        # helper lives in ffmpeg_setup -- verify there.
        ensure_ffmpeg_in_path()
        self.audio_directory = audio_directory or os.getcwd()

    def locate_audio_file(self, file_name):
        """Search ``self.audio_directory`` recursively for ``file_name``.

        Args:
            file_name: Bare file name (no directory part) to look for.

        Returns:
            The path of the first match in ``os.walk`` order, or ``None``
            when no directory under ``audio_directory`` contains the file.
        """
        for root, _, files in os.walk(self.audio_directory):
            if file_name in files:
                return os.path.join(root, file_name)
        return None

    def convert_audio_to_wav(self, input_file: str, output_file: str, ffmpeg_path: str) -> str:
        """Convert ``input_file`` to a 16 kHz mono WAV file using ffmpeg.

        Args:
            input_file: Path of the media file to convert.
            output_file: Destination path for the WAV file; overwritten if
                it already exists.
            ffmpeg_path: Path of the ffmpeg executable to run.

        Returns:
            ``output_file``, once ffmpeg has exited successfully.

        Raises:
            RuntimeError: If ffmpeg exits with a non-zero status. The ffmpeg
                stderr output is logged before raising.
        """
        logger.info(f"Converting {input_file} to WAV format: {output_file}")
        cmd = [
            ffmpeg_path,
            "-y",  # Overwrite output files without asking
            "-i", input_file,
            "-ar", "16000",  # Set audio sampling rate to 16kHz
            "-ac", "1",  # Set number of audio channels to mono
            output_file,
        ]
        try:
            subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except subprocess.CalledProcessError as e:
            # ffmpeg may echo file names or metadata that are not valid
            # UTF-8; decode leniently so logging the failure cannot itself
            # raise and hide the real error.
            ffmpeg_error = e.stderr.decode(errors="replace")
            logger.error(f"ffmpeg error: {ffmpeg_error}")
            raise RuntimeError("Failed to convert audio to WAV.") from e
        logger.info("Audio conversion to WAV completed successfully.")
        return output_file

    def forward(self, file_path: str) -> str:
        """Transcribe the audio or video file at ``file_path``.

        Args:
            file_path: Path to the media file. If it does not exist, its
                base name is searched for under ``self.audio_directory``.

        Returns:
            The transcription text on success; otherwise a human-readable
            error message. Exceptions are caught and reported as text.
        """
        # Empty the shared log buffer so it only holds output from this run.
        log_buffer.seek(0)
        log_buffer.truncate()

        try:
            # Locate the file if it does not exist
            if not os.path.isfile(file_path):
                file_name = os.path.basename(file_path)
                file_path = self.locate_audio_file(file_name)
                if not file_path:
                    return f"Error: File '{file_name}' not found in '{self.audio_directory}'."

            # Check for ffmpeg before copying anything, so a missing binary
            # fails fast instead of after duplicating a large media file.
            ffmpeg_path = shutil.which("ffmpeg")
            if not ffmpeg_path:
                raise RuntimeError("ffmpeg is not accessible in PATH.")

            with tempfile.TemporaryDirectory() as tmpdir:
                # Copy the input into its own subdirectory: an input that is
                # itself named "converted_audio.wav" would otherwise share a
                # path with the output and make ffmpeg refuse to run.
                input_dir = os.path.join(tmpdir, "input")
                os.mkdir(input_dir)
                input_file_path = os.path.join(input_dir, os.path.basename(file_path))
                shutil.copy(file_path, input_file_path)

                # Convert to wav
                wav_file_path = os.path.join(tmpdir, "converted_audio.wav")
                self.convert_audio_to_wav(input_file_path, wav_file_path, ffmpeg_path)

                device = "cuda" if torch.cuda.is_available() else "cpu"

                # Transcribe audio. Only the first (transcription, extra)
                # pair yielded is used.
                # NOTE(review): confirm the generator's first item is the
                # complete transcript rather than a partial result.
                for transcription, _ in run_whisper_transcription(wav_file_path, device):
                    return transcription

                # The generator yielded nothing; report it rather than
                # implicitly returning None from a tool declared to return a
                # string.
                raise RuntimeError("Transcription produced no output.")
        except Exception as e:
            logger.error(f"Error in transcription: {str(e)}")
            return f"An error occurred: {str(e)}"