from __future__ import annotations

import asyncio
from io import BytesIO
import logging
from typing import TYPE_CHECKING, Annotated

import av.error
from fastapi import (
    APIRouter,
    Depends,
    Form,
    Query,
    Request,
    Response,
    UploadFile,
    WebSocket,
    WebSocketDisconnect,
)
from fastapi.exceptions import HTTPException
from fastapi.responses import StreamingResponse
from fastapi.websockets import WebSocketState
from faster_whisper.audio import decode_audio
from faster_whisper.transcribe import BatchedInferencePipeline
from faster_whisper.vad import VadOptions, get_speech_timestamps
from numpy import float32
from numpy.typing import NDArray
from pydantic import AfterValidator, Field

from faster_whisper_server.api_models import (
    DEFAULT_TIMESTAMP_GRANULARITIES,
    TIMESTAMP_GRANULARITIES_COMBINATIONS,
    CreateTranscriptionResponseJson,
    CreateTranscriptionResponseVerboseJson,
    TimestampGranularities,
    TranscriptionSegment,
)
from faster_whisper_server.asr import FasterWhisperASR
from faster_whisper_server.audio import AudioStream, audio_samples_from_file
from faster_whisper_server.config import (
    SAMPLES_PER_SECOND,
    Language,
    ResponseFormat,
    Task,
)
from faster_whisper_server.dependencies import ConfigDependency, ModelManagerDependency, get_config
from faster_whisper_server.text_utils import segments_to_srt, segments_to_text, segments_to_vtt
from faster_whisper_server.transcriber import audio_transcriber

if TYPE_CHECKING:
    from collections.abc import Generator, Iterable

    from faster_whisper.transcribe import TranscriptionInfo

logger = logging.getLogger(__name__)

router = APIRouter()


# TODO: test async vs sync performance
def audio_file_dependency(
    file: Annotated[UploadFile, Form()],
) -> NDArray[float32]:
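    """Decode the uploaded audio file into a mono float32 waveform via `decode_audio`."""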
    try:
        audio = decode_audio(file.file)
    except av.error.InvalidDataError as e:
        raise HTTPException(
            status_code=415,
            detail="Failed to decode audio. The provided file type is not supported.",
        ) from e
    except av.error.ValueError as e:
        raise HTTPException(
            status_code=400,
            # TODO: list supported file types
            detail="Failed to decode audio. The provided file is likely empty.",
        ) from e
    except Exception as e:
        logger.exception(
            "Failed to decode audio. This is likely a bug. Please create an issue at https://github.com/fedirz/faster-whisper-server/issues/new."
        )
        raise HTTPException(status_code=500, detail="Failed to decode audio.") from e
    else:
        return audio  # pyright: ignore[reportReturnType]


AudioFileDependency = Annotated[NDArray[float32], Depends(audio_file_dependency)]


def segments_to_response(
    segments: Iterable[TranscriptionSegment],
    transcription_info: TranscriptionInfo,
    response_format: ResponseFormat,
) -> Response:
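    """Render the transcription segments in the requested `response_format` as a single response."""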
    segments = list(segments)
    match response_format:
        case ResponseFormat.TEXT:
            return Response(segments_to_text(segments), media_type="text/plain")
        case ResponseFormat.JSON:
            return Response(
                CreateTranscriptionResponseJson.from_segments(segments).model_dump_json(),
                media_type="application/json",
            )
        case ResponseFormat.VERBOSE_JSON:
            return Response(
                CreateTranscriptionResponseVerboseJson.from_segments(segments, transcription_info).model_dump_json(),
                media_type="application/json",
            )
        case ResponseFormat.VTT:
            return Response(
                "".join(segments_to_vtt(segment, i) for i, segment in enumerate(segments)), media_type="text/vtt"
            )
        case ResponseFormat.SRT:
            return Response(
                "".join(segments_to_srt(segment, i) for i, segment in enumerate(segments)), media_type="text/plain"
            )


def format_as_sse(data: str) -> str:
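    """Wrap a payload in Server-Sent Events (SSE) framing."""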
    return f"data: {data}\n\n"


def segments_to_streaming_response(
    segments: Iterable[TranscriptionSegment],
    transcription_info: TranscriptionInfo,
    response_format: ResponseFormat,
) -> StreamingResponse:
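    """Stream each segment as a separate SSE message as soon as it is produced."""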
    def segment_responses() -> Generator[str, None, None]:
        for i, segment in enumerate(segments):
            if response_format == ResponseFormat.TEXT:
                data = segment.text
            elif response_format == ResponseFormat.JSON:
                data = CreateTranscriptionResponseJson.from_segments([segment]).model_dump_json()
            elif response_format == ResponseFormat.VERBOSE_JSON:
                data = CreateTranscriptionResponseVerboseJson.from_segment(
                    segment, transcription_info
                ).model_dump_json()
            elif response_format == ResponseFormat.VTT:
                data = segments_to_vtt(segment, i)
            elif response_format == ResponseFormat.SRT:
                data = segments_to_srt(segment, i)
            yield format_as_sse(data)

    return StreamingResponse(segment_responses(), media_type="text/event-stream")


def handle_default_openai_model(model_name: str) -> str:
    """Exists because some callers may not be able to override the default ("whisper-1") model name.

    For example, https://github.com/open-webui/open-webui/issues/2248#issuecomment-2162997623.
    """
    config = get_config()  # HACK
    if model_name == "whisper-1":
        logger.info(f"{model_name} is not a valid model name. Using {config.whisper.model} instead.")
        return config.whisper.model
    return model_name


ModelName = Annotated[
    str,
    AfterValidator(handle_default_openai_model),
    Field(
        description="The ID of the model. You can get a list of available models by calling `/v1/models`.",
        examples=[
            "Systran/faster-distil-whisper-large-v3",
            "bofenghuang/whisper-large-v2-cv11-french-ct2",
        ],
    ),
]


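# NOTE: route registration; the path below is an assumption based on the OpenAI-compatible API this server exposes.
@router.post("/v1/audio/translations")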
def translate_file(
    config: ConfigDependency,
    model_manager: ModelManagerDependency,
    audio: AudioFileDependency,
    model: Annotated[ModelName | None, Form()] = None,
    prompt: Annotated[str | None, Form()] = None,
    response_format: Annotated[ResponseFormat | None, Form()] = None,
    temperature: Annotated[float, Form()] = 0.0,
    stream: Annotated[bool, Form()] = False,
    vad_filter: Annotated[bool, Form()] = False,
) -> Response | StreamingResponse:
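    """Translate the uploaded audio into English text (OpenAI-compatible translations endpoint)."""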
    if model is None:
        model = config.whisper.model
    if response_format is None:
        response_format = config.default_response_format
    with model_manager.load_model(model) as whisper:
        whisper_model = BatchedInferencePipeline(model=whisper) if config.whisper.use_batched_mode else whisper
        segments, transcription_info = whisper_model.transcribe(
            audio,
            task=Task.TRANSLATE,
            initial_prompt=prompt,
            temperature=temperature,
            vad_filter=vad_filter,
        )
        segments = TranscriptionSegment.from_faster_whisper_segments(segments)
        if stream:
            return segments_to_streaming_response(segments, transcription_info, response_format)
        else:
            return segments_to_response(segments, transcription_info, response_format)


# HACK: Since Form() doesn't support `alias`, we need to use a workaround.
async def get_timestamp_granularities(request: Request) -> TimestampGranularities:
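    """Read `timestamp_granularities[]` from the raw form data.

    The `Form(alias="timestamp_granularities[]")` declared on the endpoint is not applied, hence this workaround.
    """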
    form = await request.form()
    if form.get("timestamp_granularities[]") is None:
        return DEFAULT_TIMESTAMP_GRANULARITIES
    timestamp_granularities = form.getlist("timestamp_granularities[]")
    assert (
        timestamp_granularities in TIMESTAMP_GRANULARITIES_COMBINATIONS
    ), f"{timestamp_granularities} is not a valid value for `timestamp_granularities[]`."
    return timestamp_granularities


# https://platform.openai.com/docs/api-reference/audio/createTranscription
# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L8915
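# NOTE: route registration; the path below is an assumption based on the OpenAI-compatible API this server exposes.
@router.post("/v1/audio/transcriptions")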
def transcribe_file(
    config: ConfigDependency,
    model_manager: ModelManagerDependency,
    request: Request,
    audio: AudioFileDependency,
    model: Annotated[ModelName | None, Form()] = None,
    language: Annotated[Language | None, Form()] = None,
    prompt: Annotated[str | None, Form()] = None,
    response_format: Annotated[ResponseFormat | None, Form()] = None,
    temperature: Annotated[float, Form()] = 0.0,
    timestamp_granularities: Annotated[
        TimestampGranularities,
        # WARN: `alias` doesn't actually work.
        Form(alias="timestamp_granularities[]"),
    ] = ["segment"],
    stream: Annotated[bool, Form()] = False,
    hotwords: Annotated[str | None, Form()] = None,
    vad_filter: Annotated[bool, Form()] = False,
) -> Response | StreamingResponse:
    if model is None:
        model = config.whisper.model
    if language is None:
        language = config.default_language
    if response_format is None:
        response_format = config.default_response_format
    timestamp_granularities = asyncio.run(get_timestamp_granularities(request))
    if timestamp_granularities != DEFAULT_TIMESTAMP_GRANULARITIES and response_format != ResponseFormat.VERBOSE_JSON:
        logger.warning(
            "It only makes sense to provide `timestamp_granularities[]` when `response_format` is set to `verbose_json`. See https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-timestamp_granularities."  # noqa: E501
        )
    with model_manager.load_model(model) as whisper:
        whisper_model = BatchedInferencePipeline(model=whisper) if config.whisper.use_batched_mode else whisper
        segments, transcription_info = whisper_model.transcribe(
            audio,
            task=Task.TRANSCRIBE,
            language=language,
            initial_prompt=prompt,
            word_timestamps="word" in timestamp_granularities,
            temperature=temperature,
            vad_filter=vad_filter,
            hotwords=hotwords,
        )
        segments = TranscriptionSegment.from_faster_whisper_segments(segments)
        if stream:
            return segments_to_streaming_response(segments, transcription_info, response_format)
        else:
            return segments_to_response(segments, transcription_info, response_format)


async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
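    """Feed received websocket audio into `audio_stream`, closing it on silence, inactivity, or disconnect."""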
    config = get_config()  # HACK
    try:
        while True:
            bytes_ = await asyncio.wait_for(ws.receive_bytes(), timeout=config.max_no_data_seconds)
            logger.debug(f"Received {len(bytes_)} bytes of audio data")
            audio_samples = audio_samples_from_file(BytesIO(bytes_))
            audio_stream.extend(audio_samples)
            if audio_stream.duration - config.inactivity_window_seconds >= 0:
                audio = audio_stream.after(audio_stream.duration - config.inactivity_window_seconds)
                vad_opts = VadOptions(min_silence_duration_ms=500, speech_pad_ms=0)
                # NOTE: This is a synchronous operation that runs every time new data is received.
                # This shouldn't be an issue unless data is being received in tiny chunks or the user's machine is a potato.  # noqa: E501
                timestamps = get_speech_timestamps(audio.data, vad_opts)
                if len(timestamps) == 0:
                    logger.info(f"No speech detected in the last {config.inactivity_window_seconds} seconds.")
                    break
                elif (
                    # last speech end time
                    config.inactivity_window_seconds - timestamps[-1]["end"] / SAMPLES_PER_SECOND
                    >= config.max_inactivity_seconds
                ):
                    logger.info(f"Not enough speech in the last {config.inactivity_window_seconds} seconds.")
                    break
    except TimeoutError:
        logger.info(f"No data received in {config.max_no_data_seconds} seconds. Closing the connection.")
    except WebSocketDisconnect as e:
        logger.info(f"Client disconnected: {e}")
    audio_stream.close()


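# NOTE: websocket route registration; the path below is an assumption based on the OpenAI-compatible API this server exposes.
@router.websocket("/v1/audio/transcriptions")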
async def transcribe_stream(
    config: ConfigDependency,
    model_manager: ModelManagerDependency,
    ws: WebSocket,
    model: Annotated[ModelName | None, Query()] = None,
    language: Annotated[Language | None, Query()] = None,
    response_format: Annotated[ResponseFormat | None, Query()] = None,
    temperature: Annotated[float, Query()] = 0.0,
    vad_filter: Annotated[bool, Query()] = False,
) -> None:
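    """Accept a websocket connection, receive audio, and send back transcriptions as they become available."""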
    if model is None:
        model = config.whisper.model
    if language is None:
        language = config.default_language
    if response_format is None:
        response_format = config.default_response_format
    await ws.accept()
    transcribe_opts = {
        "language": language,
        "temperature": temperature,
        "vad_filter": vad_filter,
        "condition_on_previous_text": False,
    }
    with model_manager.load_model(model) as whisper:
        asr = FasterWhisperASR(whisper, **transcribe_opts)
        audio_stream = AudioStream()
        async with asyncio.TaskGroup() as tg:
            tg.create_task(audio_receiver(ws, audio_stream))
            async for transcription in audio_transcriber(asr, audio_stream, min_duration=config.min_duration):
                logger.debug(f"Sending transcription: {transcription.text}")
                if ws.client_state == WebSocketState.DISCONNECTED:
                    break
                if response_format == ResponseFormat.TEXT:
                    await ws.send_text(transcription.text)
                elif response_format == ResponseFormat.JSON:
                    await ws.send_json(CreateTranscriptionResponseJson.from_transcription(transcription).model_dump())
                elif response_format == ResponseFormat.VERBOSE_JSON:
                    await ws.send_json(
                        CreateTranscriptionResponseVerboseJson.from_transcription(transcription).model_dump()
                    )

    if ws.client_state != WebSocketState.DISCONNECTED:
        logger.info("Closing the connection.")
        await ws.close()