# Last commit d16cb74 by Fedir Zadniprovskyi:
#   "chore: don't raise when specified model isn't the default one"
# (GitHub file-view residue: raw | history blame | 10.8 kB)
from __future__ import annotations
import asyncio
import logging
import time
from contextlib import asynccontextmanager
from io import BytesIO
from typing import Annotated, Literal
from fastapi import (FastAPI, Form, Query, Response, UploadFile, WebSocket,
WebSocketDisconnect)
from fastapi.responses import StreamingResponse
from fastapi.websockets import WebSocketState
from faster_whisper import WhisperModel
from faster_whisper.vad import VadOptions, get_speech_timestamps
from speaches import utils
from speaches.asr import FasterWhisperASR
from speaches.audio import AudioStream, audio_samples_from_file
from speaches.config import SAMPLES_PER_SECOND, Language, Model, config
from speaches.core import Transcription
from speaches.logger import logger
from speaches.server_models import (ResponseFormat, TranscriptionJsonResponse,
TranscriptionVerboseJsonResponse)
from speaches.transcriber import audio_transcriber
# Process-wide model instance; assigned by `lifespan` when the app starts and
# read by the request handlers below.
whisper: WhisperModel = None  # type: ignore


@asynccontextmanager
async def lifespan(_: FastAPI):
    """Load the Whisper model once at application startup.

    Builds a ``WhisperModel`` for ``config.whisper.model`` on
    ``config.whisper.inference_device`` with ``config.whisper.compute_type``
    and stores it in the module-level ``whisper`` global. Nothing is done on
    shutdown (there is no code after ``yield``).
    """
    global whisper
    # Log through the module's `logger` (as every other message here does)
    # rather than the root `logging` module.
    logger.debug(f"Loading {config.whisper.model}")
    start = time.perf_counter()
    whisper = WhisperModel(
        config.whisper.model,
        device=config.whisper.inference_device,
        compute_type=config.whisper.compute_type,
    )
    logger.debug(
        f"Loaded {config.whisper.model} in {time.perf_counter() - start:.2f} seconds"
    )
    yield


app = FastAPI(lifespan=lifespan)
# Liveness probe: always answers HTTP 200 with a fixed message body.
@app.get("/health")
def health() -> Response:
    message = "Everything is peachy!"
    return Response(content=message, status_code=200)
# OpenAI-compatible translation endpoint: runs the loaded model with
# task="translate" on an uploaded audio file.
# - A `model` other than the one loaded at startup is ignored (with a warning).
# - `stream=True` returns a `text/event-stream` response with one chunk per
#   segment; otherwise the full result is returned in `response_format`.
@app.post("/v1/audio/translations")
async def translate_file(
    file: Annotated[UploadFile, Form()],
    model: Annotated[Model, Form()] = config.whisper.model,
    prompt: Annotated[str | None, Form()] = None,
    response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
    temperature: Annotated[float, Form()] = 0.0,
    stream: Annotated[bool, Form()] = False,
):
    if model != config.whisper.model:
        logger.warning(
            f"Specifying a model that is different from the default is not supported yet. Using {config.whisper.model}."
        )
    start = time.perf_counter()
    # `whisper.transcribe` is a synchronous call; run it in a worker thread so
    # it does not stall the event loop (and with it every other request).
    segments, transcription_info = await asyncio.to_thread(
        whisper.transcribe,
        file.file,
        task="translate",
        initial_prompt=prompt,
        temperature=temperature,
        vad_filter=True,
    )

    def segment_responses():
        # Streaming path: emit one chunk per segment as `segments` is consumed.
        # NOTE(review): this is a sync generator; Starlette's StreamingResponse
        # is expected to iterate it off the event loop — confirm for the
        # Starlette version in use.
        for segment in segments:
            if response_format == ResponseFormat.TEXT:
                yield segment.text
            elif response_format == ResponseFormat.JSON:
                yield TranscriptionJsonResponse.from_segments(
                    [segment]
                ).model_dump_json()
            elif response_format == ResponseFormat.VERBOSE_JSON:
                yield TranscriptionVerboseJsonResponse.from_segment(
                    segment, transcription_info
                ).model_dump_json()

    if not stream:
        # faster-whisper yields segments lazily, so draining the generator is
        # where the decoding work happens; keep that off the event loop too.
        segments = await asyncio.to_thread(list, segments)
        logger.info(
            f"Translated {transcription_info.duration}({transcription_info.duration_after_vad}) seconds of audio in {time.perf_counter() - start:.2f} seconds"
        )
        if response_format == ResponseFormat.TEXT:
            return utils.segments_text(segments)
        elif response_format == ResponseFormat.JSON:
            return TranscriptionJsonResponse.from_segments(segments)
        elif response_format == ResponseFormat.VERBOSE_JSON:
            return TranscriptionVerboseJsonResponse.from_segments(
                segments, transcription_info
            )
    else:
        return StreamingResponse(segment_responses(), media_type="text/event-stream")
# https://platform.openai.com/docs/api-reference/audio/createTranscription
# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L8915
#
# OpenAI-compatible transcription endpoint for an uploaded audio file.
# - A `model` other than the one loaded at startup is ignored (with a warning).
# - "words" in `timestamp_granularities` enables word-level timestamps.
# - `stream=True` returns a `text/event-stream` response with one chunk per
#   segment; otherwise the full result is returned in `response_format`.
@app.post("/v1/audio/transcriptions")
async def transcribe_file(
    file: Annotated[UploadFile, Form()],
    model: Annotated[Model, Form()] = config.whisper.model,
    language: Annotated[Language | None, Form()] = config.default_language,
    prompt: Annotated[str | None, Form()] = None,
    response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
    temperature: Annotated[float, Form()] = 0.0,
    timestamp_granularities: Annotated[
        list[Literal["segments"] | Literal["words"]],
        Form(alias="timestamp_granularities[]"),
    ] = ["segments"],
    stream: Annotated[bool, Form()] = False,
):
    if model != config.whisper.model:
        logger.warning(
            f"Specifying a model that is different from the default is not supported yet. Using {config.whisper.model}."
        )
    start = time.perf_counter()
    # `whisper.transcribe` is a synchronous call; run it in a worker thread so
    # it does not stall the event loop (and with it every other request).
    segments, transcription_info = await asyncio.to_thread(
        whisper.transcribe,
        file.file,
        task="transcribe",
        language=language,
        initial_prompt=prompt,
        word_timestamps="words" in timestamp_granularities,
        temperature=temperature,
        vad_filter=True,
    )

    def segment_responses():
        # Streaming path: emit one chunk per segment as `segments` is consumed.
        # NOTE(review): this is a sync generator; Starlette's StreamingResponse
        # is expected to iterate it off the event loop — confirm for the
        # Starlette version in use.
        for segment in segments:
            logger.info(
                f"Transcribed {segment.end - segment.start} seconds of audio in {time.perf_counter() - start:.2f} seconds"
            )
            if response_format == ResponseFormat.TEXT:
                yield segment.text
            elif response_format == ResponseFormat.JSON:
                yield TranscriptionJsonResponse.from_segments(
                    [segment]
                ).model_dump_json()
            elif response_format == ResponseFormat.VERBOSE_JSON:
                yield TranscriptionVerboseJsonResponse.from_segment(
                    segment, transcription_info
                ).model_dump_json()

    if not stream:
        # faster-whisper yields segments lazily, so draining the generator is
        # where the decoding work happens; keep that off the event loop too.
        segments = await asyncio.to_thread(list, segments)
        logger.info(
            f"Transcribed {transcription_info.duration}({transcription_info.duration_after_vad}) seconds of audio in {time.perf_counter() - start:.2f} seconds"
        )
        if response_format == ResponseFormat.TEXT:
            return utils.segments_text(segments)
        elif response_format == ResponseFormat.JSON:
            return TranscriptionJsonResponse.from_segments(segments)
        elif response_format == ResponseFormat.VERBOSE_JSON:
            return TranscriptionVerboseJsonResponse.from_segments(
                segments, transcription_info
            )
    else:
        return StreamingResponse(segment_responses(), media_type="text/event-stream")
async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
    """Append audio received on ``ws`` to ``audio_stream`` until input should stop.

    Each binary message is decoded with ``audio_samples_from_file`` and
    appended to ``audio_stream``. Once the stream holds at least
    ``config.inactivity_window_seconds`` of audio, the most recent window is
    passed to ``get_speech_timestamps`` and receiving stops when either:

    * no speech is detected in that window, or
    * the window length minus the end time of the last detected speech is at
      least ``config.max_inactivity_seconds`` (i.e. a long enough trailing
      silence — assumes the window returned by ``audio_stream.after`` is
      exactly ``inactivity_window_seconds`` long; TODO confirm).

    Receiving also stops when no message arrives within
    ``config.max_no_data_seconds`` or when the client disconnects. In every
    case, including unexpected errors and task cancellation,
    ``audio_stream.close()`` is called on exit.
    """
    # Loop-invariant VAD configuration, built once.
    vad_opts = VadOptions(min_silence_duration_ms=500, speech_pad_ms=0)
    try:
        while True:
            bytes_ = await asyncio.wait_for(
                ws.receive_bytes(), timeout=config.max_no_data_seconds
            )
            logger.debug(f"Received {len(bytes_)} bytes of audio data")
            audio_samples = audio_samples_from_file(BytesIO(bytes_))
            audio_stream.extend(audio_samples)
            if audio_stream.duration - config.inactivity_window_seconds >= 0:
                audio = audio_stream.after(
                    audio_stream.duration - config.inactivity_window_seconds
                )
                # NOTE: This is a synchronous operation that runs every time new data is received.
                # This shouldn't be an issue unless data is being received in tiny chunks or the user's machine is a potato.
                timestamps = get_speech_timestamps(audio.data, vad_opts)
                if not timestamps:
                    logger.info(
                        f"No speech detected in the last {config.inactivity_window_seconds} seconds."
                    )
                    break
                elif (
                    # last speech end time
                    config.inactivity_window_seconds
                    - timestamps[-1]["end"] / SAMPLES_PER_SECOND
                    >= config.max_inactivity_seconds
                ):
                    logger.info(
                        f"Not enough speech in the last {config.inactivity_window_seconds} seconds."
                    )
                    break
    except asyncio.TimeoutError:
        logger.info(
            f"No data received in {config.max_no_data_seconds} seconds. Closing the connection."
        )
    except WebSocketDisconnect as e:
        logger.info(f"Client disconnected: {e}")
    finally:
        # Close the stream on every exit path (including decode errors and
        # cancellation) so consumers of `audio_stream` are not left waiting
        # for more audio. NOTE(review): assumes close() is what signals
        # end-of-input to `audio_transcriber` — confirm against AudioStream.
        audio_stream.close()
def format_transcription(
    transcription: Transcription, response_format: ResponseFormat
) -> str:
    """Serialize ``transcription`` as a websocket text message.

    ``ResponseFormat.TEXT`` yields the raw transcription text;
    ``ResponseFormat.JSON`` and ``ResponseFormat.VERBOSE_JSON`` yield the JSON
    produced by ``TranscriptionJsonResponse`` and
    ``TranscriptionVerboseJsonResponse`` respectively.

    Raises:
        ValueError: if ``response_format`` is not one of those three formats.
            Without this, the function would return ``None`` despite its
            ``str`` return annotation.
    """
    if response_format == ResponseFormat.TEXT:
        return transcription.text
    elif response_format == ResponseFormat.JSON:
        return TranscriptionJsonResponse.from_transcription(
            transcription
        ).model_dump_json()
    elif response_format == ResponseFormat.VERBOSE_JSON:
        return TranscriptionVerboseJsonResponse.from_transcription(
            transcription
        ).model_dump_json()
    raise ValueError(f"Unsupported response format: {response_format!r}")
# WebSocket variant of the transcription endpoint. The client sends audio as
# binary messages; each transcription produced by `audio_transcriber` is sent
# back as a text message formatted per `response_format`.
# `timestamp_granularities` is accepted only for compatibility and never read.
@app.websocket("/v1/audio/transcriptions")
async def transcribe_stream(
    ws: WebSocket,
    model: Annotated[Model, Query()] = config.whisper.model,
    language: Annotated[Language | None, Query()] = config.default_language,
    prompt: Annotated[str | None, Query()] = None,
    response_format: Annotated[ResponseFormat, Query()] = config.default_response_format,
    temperature: Annotated[float, Query()] = 0.0,
    timestamp_granularities: Annotated[
        list[Literal["segments"] | Literal["words"]],
        Query(
            alias="timestamp_granularities[]",
            description="No-op. Ignored. Only for compatibility.",
        ),
    ] = ["segments", "words"],
) -> None:
    requested_other_model = model != config.whisper.model
    if requested_other_model:
        logger.warning(
            f"Specifying a model that is different from the default is not supported yet. Using {config.whisper.model}."
        )
    await ws.accept()

    asr = FasterWhisperASR(
        whisper,
        language=language,
        initial_prompt=prompt,
        temperature=temperature,
        vad_filter=True,
        condition_on_previous_text=False,
    )
    audio_stream = AudioStream()

    async with asyncio.TaskGroup() as tg:
        # Audio is received concurrently while transcriptions are produced.
        tg.create_task(audio_receiver(ws, audio_stream))
        async for transcription in audio_transcriber(asr, audio_stream):
            logger.debug(f"Sending transcription: {transcription.text}")
            if ws.client_state == WebSocketState.DISCONNECTED:
                break
            message = format_transcription(transcription, response_format)
            await ws.send_text(message)

    # Only close a socket the client has not already disconnected.
    if ws.client_state != WebSocketState.DISCONNECTED:
        logger.info("Closing the connection.")
        await ws.close()