from __future__ import annotations

import asyncio
import logging
import time
from contextlib import asynccontextmanager
from io import BytesIO
from typing import Annotated, Literal

from fastapi import (FastAPI, Form, Query, Response, UploadFile, WebSocket,
                     WebSocketDisconnect)
from fastapi.responses import StreamingResponse
from fastapi.websockets import WebSocketState
from faster_whisper import WhisperModel
from faster_whisper.vad import VadOptions, get_speech_timestamps

from speaches import utils
from speaches.asr import FasterWhisperASR
from speaches.audio import AudioStream, audio_samples_from_file
from speaches.config import SAMPLES_PER_SECOND, Language, Model, config
from speaches.core import Transcription
from speaches.logger import logger
from speaches.server_models import (ResponseFormat, TranscriptionJsonResponse,
                                    TranscriptionVerboseJsonResponse)
from speaches.transcriber import audio_transcriber

whisper: WhisperModel = None  # type: ignore
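# NOTE: this module-level model is populated once at startup by the `lifespan` handler below
# and is shared by every request handler in this module.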


@asynccontextmanager
async def lifespan(_: FastAPI):
    global whisper
    logging.debug(f"Loading {config.whisper.model}")
    start = time.perf_counter()
    whisper = WhisperModel(
        config.whisper.model,
        device=config.whisper.inference_device,
        compute_type=config.whisper.compute_type,
    )
    logger.debug(
        f"Loaded {config.whisper.model} in {time.perf_counter() - start:.2f} seconds"
    )
    yield


app = FastAPI(lifespan=lifespan)


# Route path assumed: a simple liveness probe.
@app.get("/health")
def health() -> Response:
    return Response(status_code=200, content="Everything is peachy!")


# Route path assumed from the OpenAI-compatible audio API this endpoint mirrors.
@app.post("/v1/audio/translations")
async def translate_file(
    file: Annotated[UploadFile, Form()],
    model: Annotated[Model, Form()] = config.whisper.model,
    prompt: Annotated[str | None, Form()] = None,
    response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
    temperature: Annotated[float, Form()] = 0.0,
    stream: Annotated[bool, Form()] = False,
):
    if model != config.whisper.model:
        logger.warning(
            f"Specifying a model that is different from the default is not supported yet. Using {config.whisper.model}."
        )
    start = time.perf_counter()
    segments, transcription_info = whisper.transcribe(
        file.file,
        task="translate",
        initial_prompt=prompt,
        temperature=temperature,
        vad_filter=True,
    )

    def segment_responses():
        for segment in segments:
            if response_format == ResponseFormat.TEXT:
                yield segment.text
            elif response_format == ResponseFormat.JSON:
                yield TranscriptionJsonResponse.from_segments(
                    [segment]
                ).model_dump_json()
            elif response_format == ResponseFormat.VERBOSE_JSON:
                yield TranscriptionVerboseJsonResponse.from_segment(
                    segment, transcription_info
                ).model_dump_json()

    if not stream:
        segments = list(segments)
        logger.info(
            f"Translated {transcription_info.duration}({transcription_info.duration_after_vad}) seconds of audio in {time.perf_counter() - start:.2f} seconds"
        )
        if response_format == ResponseFormat.TEXT:
            return utils.segments_text(segments)
        elif response_format == ResponseFormat.JSON:
            return TranscriptionJsonResponse.from_segments(segments)
        elif response_format == ResponseFormat.VERBOSE_JSON:
            return TranscriptionVerboseJsonResponse.from_segments(
                segments, transcription_info
            )
    else:
        return StreamingResponse(segment_responses(), media_type="text/event-stream")
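# Example request sketch (assumptions: the server listens on localhost:8000, the route above is
# registered as written, and `audio.wav` is a placeholder file name):
#   curl http://localhost:8000/v1/audio/translations -F "file=@audio.wav" -F "response_format=text"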


# https://platform.openai.com/docs/api-reference/audio/createTranscription
# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L8915
# Route path assumed from the OpenAI-compatible audio API referenced above.
@app.post("/v1/audio/transcriptions")
async def transcribe_file(
    file: Annotated[UploadFile, Form()],
    model: Annotated[Model, Form()] = config.whisper.model,
    language: Annotated[Language | None, Form()] = config.default_language,
    prompt: Annotated[str | None, Form()] = None,
    response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
    temperature: Annotated[float, Form()] = 0.0,
    timestamp_granularities: Annotated[
        list[Literal["segments"] | Literal["words"]],
        Form(alias="timestamp_granularities[]"),
    ] = ["segments"],
    stream: Annotated[bool, Form()] = False,
):
    if model != config.whisper.model:
        logger.warning(
            f"Specifying a model that is different from the default is not supported yet. Using {config.whisper.model}."
        )
    start = time.perf_counter()
    segments, transcription_info = whisper.transcribe(
        file.file,
        task="transcribe",
        language=language,
        initial_prompt=prompt,
        word_timestamps="words" in timestamp_granularities,
        temperature=temperature,
        vad_filter=True,
    )

    def segment_responses():
        for segment in segments:
            logger.info(
                f"Transcribed {segment.end - segment.start} seconds of audio in {time.perf_counter() - start:.2f} seconds"
            )
            if response_format == ResponseFormat.TEXT:
                yield segment.text
            elif response_format == ResponseFormat.JSON:
                yield TranscriptionJsonResponse.from_segments(
                    [segment]
                ).model_dump_json()
            elif response_format == ResponseFormat.VERBOSE_JSON:
                yield TranscriptionVerboseJsonResponse.from_segment(
                    segment, transcription_info
                ).model_dump_json()

    if not stream:
        segments = list(segments)
        logger.info(
            f"Transcribed {transcription_info.duration}({transcription_info.duration_after_vad}) seconds of audio in {time.perf_counter() - start:.2f} seconds"
        )
        if response_format == ResponseFormat.TEXT:
            return utils.segments_text(segments)
        elif response_format == ResponseFormat.JSON:
            return TranscriptionJsonResponse.from_segments(segments)
        elif response_format == ResponseFormat.VERBOSE_JSON:
            return TranscriptionVerboseJsonResponse.from_segments(
                segments, transcription_info
            )
    else:
        return StreamingResponse(segment_responses(), media_type="text/event-stream")
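# Example request sketch (same assumptions as above: localhost:8000, the route registered as
# written, `audio.wav` as a placeholder). With `stream=true` the segments are returned as a
# server-sent-event stream instead of a single response:
#   curl http://localhost:8000/v1/audio/transcriptions -F "file=@audio.wav" -F "stream=true"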


async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
    """Read audio bytes from the websocket into `audio_stream` until the client disconnects,
    no data arrives for `max_no_data_seconds`, or no/too little speech is detected within the
    inactivity window."""
    try:
        while True:
            bytes_ = await asyncio.wait_for(
                ws.receive_bytes(), timeout=config.max_no_data_seconds
            )
            logger.debug(f"Received {len(bytes_)} bytes of audio data")
            audio_samples = audio_samples_from_file(BytesIO(bytes_))
            audio_stream.extend(audio_samples)
            if audio_stream.duration - config.inactivity_window_seconds >= 0:
                audio = audio_stream.after(
                    audio_stream.duration - config.inactivity_window_seconds
                )
                vad_opts = VadOptions(min_silence_duration_ms=500, speech_pad_ms=0)
                # NOTE: This is a synchronous operation that runs every time new data is received.
                # This shouldn't be an issue unless data is being received in tiny chunks or the user's machine is a potato.
                timestamps = get_speech_timestamps(audio.data, vad_opts)
                if len(timestamps) == 0:
                    logger.info(
                        f"No speech detected in the last {config.inactivity_window_seconds} seconds."
                    )
                    break
                elif (
                    # last speech end time
                    config.inactivity_window_seconds
                    - timestamps[-1]["end"] / SAMPLES_PER_SECOND
                    >= config.max_inactivity_seconds
                ):
                    logger.info(
                        f"Not enough speech in the last {config.inactivity_window_seconds} seconds."
                    )
                    break
    except asyncio.TimeoutError:
        logger.info(
            f"No data received in {config.max_no_data_seconds} seconds. Closing the connection."
        )
    except WebSocketDisconnect as e:
        logger.info(f"Client disconnected: {e}")
    audio_stream.close()


def format_transcription(
    transcription: Transcription, response_format: ResponseFormat
) -> str:
    if response_format == ResponseFormat.TEXT:
        return transcription.text
    elif response_format == ResponseFormat.JSON:
        return TranscriptionJsonResponse.from_transcription(
            transcription
        ).model_dump_json()
    elif response_format == ResponseFormat.VERBOSE_JSON:
        return TranscriptionVerboseJsonResponse.from_transcription(
            transcription
        ).model_dump_json()


# Route path assumed: a websocket counterpart to the transcription endpoint above.
@app.websocket("/v1/audio/transcriptions")
async def transcribe_stream(
    ws: WebSocket,
    model: Annotated[Model, Query()] = config.whisper.model,
    language: Annotated[Language | None, Query()] = config.default_language,
    prompt: Annotated[str | None, Query()] = None,
    response_format: Annotated[
        ResponseFormat, Query()
    ] = config.default_response_format,
    temperature: Annotated[float, Query()] = 0.0,
    timestamp_granularities: Annotated[
        list[Literal["segments"] | Literal["words"]],
        Query(
            alias="timestamp_granularities[]",
            description="No-op. Ignored. Only for compatibility.",
        ),
    ] = ["segments", "words"],
) -> None:
    if model != config.whisper.model:
        logger.warning(
            f"Specifying a model that is different from the default is not supported yet. Using {config.whisper.model}."
        )
    await ws.accept()
    transcribe_opts = {
        "language": language,
        "initial_prompt": prompt,
        "temperature": temperature,
        "vad_filter": True,
        "condition_on_previous_text": False,
    }
    asr = FasterWhisperASR(whisper, **transcribe_opts)
    audio_stream = AudioStream()
    async with asyncio.TaskGroup() as tg:
        tg.create_task(audio_receiver(ws, audio_stream))
        async for transcription in audio_transcriber(asr, audio_stream):
            logger.debug(f"Sending transcription: {transcription.text}")
            if ws.client_state == WebSocketState.DISCONNECTED:
                break
            await ws.send_text(format_transcription(transcription, response_format))

    if ws.client_state != WebSocketState.DISCONNECTED:
        logger.info("Closing the connection.")
        await ws.close()
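# NOTE: a client of the websocket route above is expected to send binary messages containing
# audio that `audio_samples_from_file` can decode; interim transcriptions are sent back as
# text messages formatted according to `response_format`.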