from __future__ import annotations

import asyncio
from io import BytesIO
import logging
from typing import TYPE_CHECKING, Annotated

import av.error
from fastapi import (
    APIRouter,
    Depends,
    Form,
    Query,
    Request,
    Response,
    UploadFile,
    WebSocket,
    WebSocketDisconnect,
)
from fastapi.exceptions import HTTPException
from fastapi.responses import StreamingResponse
from fastapi.websockets import WebSocketState
from faster_whisper.audio import decode_audio
from faster_whisper.transcribe import BatchedInferencePipeline
from faster_whisper.vad import VadOptions, get_speech_timestamps
from numpy import float32
from numpy.typing import NDArray
from pydantic import AfterValidator, Field

from faster_whisper_server.api_models import (
    DEFAULT_TIMESTAMP_GRANULARITIES,
    TIMESTAMP_GRANULARITIES_COMBINATIONS,
    CreateTranscriptionResponseJson,
    CreateTranscriptionResponseVerboseJson,
    TimestampGranularities,
    TranscriptionSegment,
)
from faster_whisper_server.asr import FasterWhisperASR
from faster_whisper_server.audio import AudioStream, audio_samples_from_file
from faster_whisper_server.config import (
    SAMPLES_PER_SECOND,
    Language,
    ResponseFormat,
    Task,
)
from faster_whisper_server.dependencies import ConfigDependency, ModelManagerDependency, get_config
from faster_whisper_server.text_utils import segments_to_srt, segments_to_text, segments_to_vtt
from faster_whisper_server.transcriber import audio_transcriber

if TYPE_CHECKING:
    from collections.abc import Generator, Iterable

    from faster_whisper.transcribe import TranscriptionInfo

logger = logging.getLogger(__name__)

router = APIRouter()


# TODO: test async vs sync performance
def audio_file_dependency(
    file: Annotated[UploadFile, Form()],
) -> NDArray[float32]:
    try:
        audio = decode_audio(file.file)
    except av.error.InvalidDataError as e:
        raise HTTPException(
            status_code=415,
            detail="Failed to decode audio. The provided file type is not supported.",
        ) from e
    except av.error.ValueError as e:
        raise HTTPException(
            status_code=400,
            # TODO: list supported file types
            detail="Failed to decode audio. The provided file is likely empty.",
        ) from e
    except Exception as e:
        logger.exception(
            "Failed to decode audio. This is likely a bug. Please create an issue at https://github.com/fedirz/faster-whisper-server/issues/new."  # noqa: E501
        )
        raise HTTPException(status_code=500, detail="Failed to decode audio.") from e
    else:
        return audio  # pyright: ignore reportReturnType


AudioFileDependency = Annotated[NDArray[float32], Depends(audio_file_dependency)]
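# Illustrative sketch (not part of the module's runtime code): the dependency above can be
# exercised outside FastAPI by decoding a file directly with faster-whisper. `decode_audio`
# resamples to 16 kHz mono float32 by default, which is the format the endpoints below operate
# on. The file name is a placeholder.
#
#   from faster_whisper.audio import decode_audio
#
#   samples = decode_audio("audio.wav")  # NDArray[float32] at 16 kHz
#   print(samples.shape, samples.dtype)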
def segments_to_response(
    segments: Iterable[TranscriptionSegment],
    transcription_info: TranscriptionInfo,
    response_format: ResponseFormat,
) -> Response:
    segments = list(segments)
    match response_format:
        case ResponseFormat.TEXT:
            return Response(segments_to_text(segments), media_type="text/plain")
        case ResponseFormat.JSON:
            return Response(
                CreateTranscriptionResponseJson.from_segments(segments).model_dump_json(),
                media_type="application/json",
            )
        case ResponseFormat.VERBOSE_JSON:
            return Response(
                CreateTranscriptionResponseVerboseJson.from_segments(segments, transcription_info).model_dump_json(),
                media_type="application/json",
            )
        case ResponseFormat.VTT:
            return Response(
                "".join(segments_to_vtt(segment, i) for i, segment in enumerate(segments)), media_type="text/vtt"
            )
        case ResponseFormat.SRT:
            return Response(
                "".join(segments_to_srt(segment, i) for i, segment in enumerate(segments)), media_type="text/plain"
            )


def format_as_sse(data: str) -> str:
    return f"data: {data}\n\n"


def segments_to_streaming_response(
    segments: Iterable[TranscriptionSegment],
    transcription_info: TranscriptionInfo,
    response_format: ResponseFormat,
) -> StreamingResponse:
    def segment_responses() -> Generator[str, None, None]:
        for i, segment in enumerate(segments):
            if response_format == ResponseFormat.TEXT:
                data = segment.text
            elif response_format == ResponseFormat.JSON:
                data = CreateTranscriptionResponseJson.from_segments([segment]).model_dump_json()
            elif response_format == ResponseFormat.VERBOSE_JSON:
                data = CreateTranscriptionResponseVerboseJson.from_segment(
                    segment, transcription_info
                ).model_dump_json()
            elif response_format == ResponseFormat.VTT:
                data = segments_to_vtt(segment, i)
            elif response_format == ResponseFormat.SRT:
                data = segments_to_srt(segment, i)
            yield format_as_sse(data)

    return StreamingResponse(segment_responses(), media_type="text/event-stream")


def handle_default_openai_model(model_name: str) -> str:
    """Exists because some callers may not be able to override the default ("whisper-1") model name.

    For example, https://github.com/open-webui/open-webui/issues/2248#issuecomment-2162997623.
    """
    config = get_config()  # HACK
    if model_name == "whisper-1":
        logger.info(f"{model_name} is not a valid model name. Using {config.whisper.model} instead.")
        return config.whisper.model
    return model_name


ModelName = Annotated[
    str,
    AfterValidator(handle_default_openai_model),
    Field(
        description="The ID of the model. You can get a list of available models by calling `/v1/models`.",
        examples=[
            "Systran/faster-distil-whisper-large-v3",
            "bofenghuang/whisper-large-v2-cv11-french-ct2",
        ],
    ),
]
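# Behaviour sketch for the validator above (illustrative; assumes the configured default model is
# "Systran/faster-distil-whisper-large-v3"): clients that hard-code OpenAI's "whisper-1" are
# transparently redirected to the configured model, while explicit model IDs pass through unchanged.
#
#   handle_default_openai_model("whisper-1")
#   # -> "Systran/faster-distil-whisper-large-v3" (the fallback is logged)
#   handle_default_openai_model("bofenghuang/whisper-large-v2-cv11-french-ct2")
#   # -> "bofenghuang/whisper-large-v2-cv11-french-ct2"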
@router.post(
    "/v1/audio/translations",
    response_model=str | CreateTranscriptionResponseJson | CreateTranscriptionResponseVerboseJson,
)
def translate_file(
    config: ConfigDependency,
    model_manager: ModelManagerDependency,
    audio: AudioFileDependency,
    model: Annotated[ModelName | None, Form()] = None,
    prompt: Annotated[str | None, Form()] = None,
    response_format: Annotated[ResponseFormat | None, Form()] = None,
    temperature: Annotated[float, Form()] = 0.0,
    stream: Annotated[bool, Form()] = False,
    vad_filter: Annotated[bool, Form()] = False,
) -> Response | StreamingResponse:
    if model is None:
        model = config.whisper.model
    if response_format is None:
        response_format = config.default_response_format
    with model_manager.load_model(model) as whisper:
        whisper_model = BatchedInferencePipeline(model=whisper) if config.whisper.use_batched_mode else whisper
        segments, transcription_info = whisper_model.transcribe(
            audio,
            task=Task.TRANSLATE,
            initial_prompt=prompt,
            temperature=temperature,
            vad_filter=vad_filter,
        )
        segments = TranscriptionSegment.from_faster_whisper_segments(segments)

        if stream:
            return segments_to_streaming_response(segments, transcription_info, response_format)
        else:
            return segments_to_response(segments, transcription_info, response_format)


# HACK: Since Form() doesn't support `alias`, we need to use a workaround.
async def get_timestamp_granularities(request: Request) -> TimestampGranularities:
    form = await request.form()
    if form.get("timestamp_granularities[]") is None:
        return DEFAULT_TIMESTAMP_GRANULARITIES
    timestamp_granularities = form.getlist("timestamp_granularities[]")
    assert (
        timestamp_granularities in TIMESTAMP_GRANULARITIES_COMBINATIONS
    ), f"{timestamp_granularities} is not a valid value for `timestamp_granularities[]`."
    return timestamp_granularities


# https://platform.openai.com/docs/api-reference/audio/createTranscription
# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L8915
@router.post(
    "/v1/audio/transcriptions",
    response_model=str | CreateTranscriptionResponseJson | CreateTranscriptionResponseVerboseJson,
)
def transcribe_file(
    config: ConfigDependency,
    model_manager: ModelManagerDependency,
    request: Request,
    audio: AudioFileDependency,
    model: Annotated[ModelName | None, Form()] = None,
    language: Annotated[Language | None, Form()] = None,
    prompt: Annotated[str | None, Form()] = None,
    response_format: Annotated[ResponseFormat | None, Form()] = None,
    temperature: Annotated[float, Form()] = 0.0,
    timestamp_granularities: Annotated[
        TimestampGranularities,
        # WARN: `alias` doesn't actually work.
        Form(alias="timestamp_granularities[]"),
    ] = ["segment"],
    stream: Annotated[bool, Form()] = False,
    hotwords: Annotated[str | None, Form()] = None,
    vad_filter: Annotated[bool, Form()] = False,
) -> Response | StreamingResponse:
    if model is None:
        model = config.whisper.model
    if language is None:
        language = config.default_language
    if response_format is None:
        response_format = config.default_response_format
    timestamp_granularities = asyncio.run(get_timestamp_granularities(request))
    if timestamp_granularities != DEFAULT_TIMESTAMP_GRANULARITIES and response_format != ResponseFormat.VERBOSE_JSON:
        logger.warning(
            "It only makes sense to provide `timestamp_granularities[]` when `response_format` is set to `verbose_json`. See https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-timestamp_granularities."  # noqa: E501
        )
    with model_manager.load_model(model) as whisper:
        whisper_model = BatchedInferencePipeline(model=whisper) if config.whisper.use_batched_mode else whisper
        segments, transcription_info = whisper_model.transcribe(
            audio,
            task=Task.TRANSCRIBE,
            language=language,
            initial_prompt=prompt,
            word_timestamps="word" in timestamp_granularities,
            temperature=temperature,
            vad_filter=vad_filter,
            hotwords=hotwords,
        )
        segments = TranscriptionSegment.from_faster_whisper_segments(segments)

        if stream:
            return segments_to_streaming_response(segments, transcription_info, response_format)
        else:
            return segments_to_response(segments, transcription_info, response_format)
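# Usage sketch for the two file endpoints above (illustrative only; the host/port, API key and
# file name are assumptions, not something this module defines). Because the routes mirror the
# OpenAI audio API, the official Python client can be pointed at this server:
#
#   from openai import OpenAI
#
#   client = OpenAI(base_url="http://localhost:8000/v1", api_key="does-not-matter")
#   with open("audio.wav", "rb") as f:
#       transcript = client.audio.transcriptions.create(
#           model="Systran/faster-distil-whisper-large-v3", file=f
#       )
#   print(transcript.text)
#
# `client.audio.translations.create(...)` targets `/v1/audio/translations` the same way. The
# equivalent multipart request with curl:
#
#   curl http://localhost:8000/v1/audio/transcriptions \
#       -F "file=@audio.wav" \
#       -F "model=Systran/faster-distil-whisper-large-v3" \
#       -F "response_format=text"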
async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
    config = get_config()  # HACK
    try:
        while True:
            bytes_ = await asyncio.wait_for(ws.receive_bytes(), timeout=config.max_no_data_seconds)
            logger.debug(f"Received {len(bytes_)} bytes of audio data")
            audio_samples = audio_samples_from_file(BytesIO(bytes_))
            audio_stream.extend(audio_samples)
            if audio_stream.duration - config.inactivity_window_seconds >= 0:
                audio = audio_stream.after(audio_stream.duration - config.inactivity_window_seconds)
                vad_opts = VadOptions(min_silence_duration_ms=500, speech_pad_ms=0)
                # NOTE: This is a synchronous operation that runs every time new data is received.
                # This shouldn't be an issue unless data is being received in tiny chunks or the user's machine is a potato.  # noqa: E501
                timestamps = get_speech_timestamps(audio.data, vad_opts)
                if len(timestamps) == 0:
                    logger.info(f"No speech detected in the last {config.inactivity_window_seconds} seconds.")
                    break
                elif (
                    # last speech end time
                    config.inactivity_window_seconds - timestamps[-1]["end"] / SAMPLES_PER_SECOND
                    >= config.max_inactivity_seconds
                ):
                    logger.info(f"Not enough speech in the last {config.inactivity_window_seconds} seconds.")
                    break
    except TimeoutError:
        logger.info(f"No data received in {config.max_no_data_seconds} seconds. Closing the connection.")
    except WebSocketDisconnect as e:
        logger.info(f"Client disconnected: {e}")
    audio_stream.close()
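# Worked example of the inactivity check above (numbers are illustrative, not the project's
# defaults): with inactivity_window_seconds=10 and max_inactivity_seconds=2.5, suppose the most
# recent speech the VAD finds in the 10-second window ends at sample 7 * SAMPLES_PER_SECOND,
# i.e. 7 s into the window. Then 10 - 7 = 3 >= 2.5, so the receiver breaks out of the loop and
# the client is treated as having gone silent.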
Closing the connection.") except WebSocketDisconnect as e: logger.info(f"Client disconnected: {e}") audio_stream.close() @router.websocket("/v1/audio/transcriptions") async def transcribe_stream( config: ConfigDependency, model_manager: ModelManagerDependency, ws: WebSocket, model: Annotated[ModelName | None, Query()] = None, language: Annotated[Language | None, Query()] = None, response_format: Annotated[ResponseFormat | None, Query()] = None, temperature: Annotated[float, Query()] = 0.0, vad_filter: Annotated[bool, Query()] = False, ) -> None: if model is None: model = config.whisper.model if language is None: language = config.default_language if response_format is None: response_format = config.default_response_format await ws.accept() transcribe_opts = { "language": language, "temperature": temperature, "vad_filter": vad_filter, "condition_on_previous_text": False, } with model_manager.load_model(model) as whisper: asr = FasterWhisperASR(whisper, **transcribe_opts) audio_stream = AudioStream() async with asyncio.TaskGroup() as tg: tg.create_task(audio_receiver(ws, audio_stream)) async for transcription in audio_transcriber(asr, audio_stream, min_duration=config.min_duration): logger.debug(f"Sending transcription: {transcription.text}") if ws.client_state == WebSocketState.DISCONNECTED: break if response_format == ResponseFormat.TEXT: await ws.send_text(transcription.text) elif response_format == ResponseFormat.JSON: await ws.send_json(CreateTranscriptionResponseJson.from_transcription(transcription).model_dump()) elif response_format == ResponseFormat.VERBOSE_JSON: await ws.send_json( CreateTranscriptionResponseVerboseJson.from_transcription(transcription).model_dump() ) if ws.client_state != WebSocketState.DISCONNECTED: logger.info("Closing the connection.") await ws.close()