# TranscriptTool / transcription_tool.py
# Uploaded by maguid28 — commit 13d3de7: "Implemented smolagent tool" (3.71 kB)
from smolagents import Tool
import os
import tempfile
import shutil
import torch
import subprocess
from transcription import run_whisper_transcription
from logging_config import logger, log_buffer
from ffmpeg_setup import ensure_ffmpeg_in_path
class TranscriptTool(Tool):
    """smolagents tool that transcribes an audio or video file into text.

    The input file is copied into a temporary directory, converted with ffmpeg
    to a 16 kHz mono WAV, and passed to ``run_whisper_transcription`` on
    ``"cuda"`` when ``torch.cuda.is_available()``, otherwise on ``"cpu"``.
    ``forward`` never raises: failures are logged and returned as a string so
    the calling agent always receives text.
    """

    name = "TranscriptTool"
    description = """
A smolagent tool for transcribing audio and video files into text. This tool utilises Whisper for transcription
and ffmpeg for media conversion, enabling agents to process multimedia inputs into text. The tool supports robust
file handling, including format conversion to WAV and dynamic device selection for optimal performance.
"""
    inputs = {
        "file_path": {
            "type": "string",
            "description": "Path to the audio or video file for transcription."
        }
    }
    output_type = "string"

    # Fixed output name inside the per-call temp dir; the copied input uses a
    # different fixed stem ("input") so the two paths can never coincide.
    _WAV_NAME = "converted_audio.wav"

    def __init__(self, audio_directory=None):
        """Create the tool.

        Args:
            audio_directory: Root directory searched recursively when the path
                given to ``forward`` is not an existing file. Falsy values fall
                back to the current working directory at construction time.
        """
        super().__init__()
        # Project helper; presumably makes an ffmpeg binary discoverable on PATH
        # so that shutil.which("ffmpeg") in forward() succeeds — TODO confirm.
        ensure_ffmpeg_in_path()
        self.audio_directory = audio_directory or os.getcwd()

    def locate_audio_file(self, file_name):
        """Return the first path under ``self.audio_directory`` named ``file_name``.

        The tree is walked top-down with ``os.walk``; ``None`` is returned when
        no directory contains a file with exactly that basename.
        """
        for root, _, files in os.walk(self.audio_directory):
            if file_name in files:
                return os.path.join(root, file_name)
        return None

    def convert_audio_to_wav(self, input_file: str, output_file: str, ffmpeg_path: str) -> str:
        """Convert ``input_file`` to a 16 kHz mono WAV written to ``output_file``.

        Args:
            input_file: Media file ffmpeg can read.
            output_file: Destination path; overwritten if it exists (``-y``).
            ffmpeg_path: Path to the ffmpeg executable.

        Returns:
            ``output_file``.

        Raises:
            RuntimeError: if ffmpeg exits non-zero; the ``CalledProcessError``
                is chained as the cause and ffmpeg's stderr is logged.
        """
        logger.info(f"Converting {input_file} to WAV format: {output_file}")
        cmd = [
            ffmpeg_path,
            "-y",  # Overwrite output files without asking
            "-i", input_file,
            "-ar", "16000",  # Set audio sampling rate to 16kHz
            "-ac", "1",  # Set number of audio channels to mono
            output_file,
        ]
        try:
            # Argument list, no shell: paths are never shell-interpreted.
            subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        except subprocess.CalledProcessError as e:
            # ffmpeg output is not guaranteed to be valid UTF-8 (e.g. file names in
            # other encodings); a strict decode would raise UnicodeDecodeError here
            # and hide the actual conversion failure.
            ffmpeg_error = e.stderr.decode(errors="replace")
            logger.error(f"ffmpeg error: {ffmpeg_error}")
            raise RuntimeError("Failed to convert audio to WAV.") from e
        logger.info("Audio conversion to WAV completed successfully.")
        return output_file

    @staticmethod
    def _first_transcription(results) -> str:
        """Return the transcription from the first ``(transcription, _)`` item of ``results``.

        Returns an error string when ``results`` yields nothing, so the tool
        always honours ``output_type = "string"``. If ``results`` has a
        ``close`` method (e.g. a generator), it is closed immediately so any
        cleanup it performs runs now, not at garbage collection.
        """
        try:
            # NOTE(review): only the first yielded item is used. If
            # run_whisper_transcription yields progressive partial results,
            # this returns the earliest one — confirm against transcription.py.
            for transcription, _ in results:
                return transcription
            logger.error("Error in transcription: no transcription was produced.")
            return "An error occurred: no transcription was produced."
        finally:
            close = getattr(results, "close", None)
            if close is not None:
                close()

    def forward(self, file_path: str) -> str:
        """Transcribe the media file at ``file_path`` and return the text.

        If ``file_path`` is not an existing file, its basename is searched for
        under ``self.audio_directory``. Any error is logged and returned as a
        human-readable string rather than raised.
        """
        # Reset the shared log buffer so it only holds this call's output.
        log_buffer.seek(0)
        log_buffer.truncate()
        try:
            # Fall back to searching audio_directory by basename.
            if not os.path.isfile(file_path):
                file_name = os.path.basename(file_path)
                file_path = self.locate_audio_file(file_name)
                if not file_path:
                    return f"Error: File '{file_name}' not found in '{self.audio_directory}'."

            # Resolve ffmpeg before doing any file I/O.
            ffmpeg_path = shutil.which("ffmpeg")
            if not ffmpeg_path:
                raise RuntimeError("ffmpeg is not accessible in PATH.")

            with tempfile.TemporaryDirectory() as tmpdir:
                # Copy under a fixed stem (keeping the extension for ffmpeg's format
                # detection). Reusing the user's basename would make a file named
                # "converted_audio.wav" collide with the output path below.
                extension = os.path.splitext(file_path)[1]
                input_file_path = os.path.join(tmpdir, "input" + extension)
                shutil.copy(file_path, input_file_path)

                wav_file_path = os.path.join(tmpdir, self._WAV_NAME)
                self.convert_audio_to_wav(input_file_path, wav_file_path, ffmpeg_path)

                device = "cuda" if torch.cuda.is_available() else "cpu"
                # Consumed inside the `with` so the WAV still exists while the
                # transcription results are read.
                return self._first_transcription(
                    run_whisper_transcription(wav_file_path, device)
                )
        except Exception as e:
            logger.error(f"Error in transcription: {str(e)}")
            return f"An error occurred: {str(e)}"