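# FastAPI application exposing OpenAI-compatible speech-to-text endpoints
# (/v1/audio/translations and /v1/audio/transcriptions, over HTTP and WebSocket),
# backed by faster-whisper.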
from __future__ import annotations
import asyncio
import time
from contextlib import asynccontextmanager
from io import BytesIO
from typing import Annotated, Literal, OrderedDict
from fastapi import (
    FastAPI,
    Form,
    Query,
    Response,
    UploadFile,
    WebSocket,
    WebSocketDisconnect,
)
from fastapi.responses import StreamingResponse
from fastapi.websockets import WebSocketState
from faster_whisper import WhisperModel
from faster_whisper.vad import VadOptions, get_speech_timestamps
from faster_whisper_server import utils
from faster_whisper_server.asr import FasterWhisperASR
from faster_whisper_server.audio import AudioStream, audio_samples_from_file
from faster_whisper_server.config import (
    SAMPLES_PER_SECOND,
    Language,
    Model,
    ResponseFormat,
    config,
)
from faster_whisper_server.logger import logger
from faster_whisper_server.server_models import (
    TranscriptionJsonResponse,
    TranscriptionVerboseJsonResponse,
)
from faster_whisper_server.transcriber import audio_transcriber
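
# `models` is a small in-memory cache of loaded WhisperModel instances, keyed by model
# name. `load_model` below reuses a cached model when possible and evicts the oldest
# entry once `config.max_models` is reached.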
models: OrderedDict[Model, WhisperModel] = OrderedDict()
def load_model(model_name: Model) -> WhisperModel:
    if model_name in models:
        logger.debug(f"{model_name} model already loaded")
        return models[model_name]
    if len(models) >= config.max_models:
        oldest_model_name = next(iter(models))
        logger.info(
            f"Max models ({config.max_models}) reached. Unloading the oldest model: {oldest_model_name}"
        )
        del models[oldest_model_name]
    logger.debug(f"Loading {model_name}")
    start = time.perf_counter()
    whisper = WhisperModel(
        model_name,
        device=config.whisper.inference_device,
        compute_type=config.whisper.compute_type,
    )
    logger.info(
        f"Loaded {model_name} in {time.perf_counter() - start:.2f} seconds"
    )
    models[model_name] = whisper
    return whisper
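

# The lifespan context manager loads the configured default model at startup, so the
# first request doesn't pay the model-loading cost, and unloads all cached models on
# shutdown.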
@asynccontextmanager
async def lifespan(_: FastAPI):
    load_model(config.whisper.model)
    yield
    # Iterate over a copy of the keys, since entries are deleted during iteration.
    for model in list(models.keys()):
        logger.info(f"Unloading {model}")
        del models[model]
app = FastAPI(lifespan=lifespan)
@app.get("/health")
def health() -> Response:
    return Response(status_code=200, content="OK")
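

# Example request for the translation endpoint below (a sketch: the host/port and the
# audio file name are assumptions, not something this module defines):
#   curl http://localhost:8000/v1/audio/translations \
#     -F "file=@audio.wav" -F "response_format=json"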
@app.post("/v1/audio/translations")
def translate_file(
    file: Annotated[UploadFile, Form()],
    model: Annotated[Model, Form()] = config.whisper.model,
    prompt: Annotated[str | None, Form()] = None,
    response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
    temperature: Annotated[float, Form()] = 0.0,
    stream: Annotated[bool, Form()] = False,
):
    start = time.perf_counter()
    whisper = load_model(model)
    segments, transcription_info = whisper.transcribe(
        file.file,
        task="translate",
        initial_prompt=prompt,
        temperature=temperature,
        vad_filter=True,
    )

    def segment_responses():
        for segment in segments:
            if response_format == ResponseFormat.TEXT:
                yield segment.text
            elif response_format == ResponseFormat.JSON:
                yield TranscriptionJsonResponse.from_segments(
                    [segment]
                ).model_dump_json()
            elif response_format == ResponseFormat.VERBOSE_JSON:
                yield TranscriptionVerboseJsonResponse.from_segment(
                    segment, transcription_info
                ).model_dump_json()

    if not stream:
        segments = list(segments)
        logger.info(
            f"Translated {transcription_info.duration}({transcription_info.duration_after_vad}) seconds of audio in {time.perf_counter() - start:.2f} seconds"
        )
        if response_format == ResponseFormat.TEXT:
            return utils.segments_text(segments)
        elif response_format == ResponseFormat.JSON:
            return TranscriptionJsonResponse.from_segments(segments)
        elif response_format == ResponseFormat.VERBOSE_JSON:
            return TranscriptionVerboseJsonResponse.from_segments(
                segments, transcription_info
            )
    else:
        return StreamingResponse(segment_responses(), media_type="text/event-stream")
# https://platform.openai.com/docs/api-reference/audio/createTranscription
# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L8915
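# Example request (a sketch: host/port, file name, and the "en" language value are
# assumptions, not defined in this module):
#   curl http://localhost:8000/v1/audio/transcriptions \
#     -F "file=@audio.wav" -F "language=en" -F "response_format=verbose_json"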
@app.post("/v1/audio/transcriptions")
def transcribe_file(
    file: Annotated[UploadFile, Form()],
    model: Annotated[Model, Form()] = config.whisper.model,
    language: Annotated[Language | None, Form()] = config.default_language,
    prompt: Annotated[str | None, Form()] = None,
    response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
    temperature: Annotated[float, Form()] = 0.0,
    timestamp_granularities: Annotated[
        list[Literal["segments"] | Literal["words"]],
        Form(alias="timestamp_granularities[]"),
    ] = ["segments"],
    stream: Annotated[bool, Form()] = False,
):
    start = time.perf_counter()
    whisper = load_model(model)
    segments, transcription_info = whisper.transcribe(
        file.file,
        task="transcribe",
        language=language,
        initial_prompt=prompt,
        word_timestamps="words" in timestamp_granularities,
        temperature=temperature,
        vad_filter=True,
    )

    def segment_responses():
        for segment in segments:
            logger.info(
                f"Transcribed {segment.end - segment.start} seconds of audio in {time.perf_counter() - start:.2f} seconds"
            )
            if response_format == ResponseFormat.TEXT:
                yield segment.text
            elif response_format == ResponseFormat.JSON:
                yield TranscriptionJsonResponse.from_segments(
                    [segment]
                ).model_dump_json()
            elif response_format == ResponseFormat.VERBOSE_JSON:
                yield TranscriptionVerboseJsonResponse.from_segment(
                    segment, transcription_info
                ).model_dump_json()

    if not stream:
        segments = list(segments)
        logger.info(
            f"Transcribed {transcription_info.duration}({transcription_info.duration_after_vad}) seconds of audio in {time.perf_counter() - start:.2f} seconds"
        )
        if response_format == ResponseFormat.TEXT:
            return utils.segments_text(segments)
        elif response_format == ResponseFormat.JSON:
            return TranscriptionJsonResponse.from_segments(segments)
        elif response_format == ResponseFormat.VERBOSE_JSON:
            return TranscriptionVerboseJsonResponse.from_segments(
                segments, transcription_info
            )
    else:
        return StreamingResponse(segment_responses(), media_type="text/event-stream")
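

# `audio_receiver` pulls binary audio chunks from the WebSocket into `audio_stream` and
# stops when the client goes quiet: either no data arrives for `max_no_data_seconds`, or
# VAD finds no (or too little) speech in the trailing `inactivity_window_seconds` window.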
async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
    try:
        while True:
            bytes_ = await asyncio.wait_for(
                ws.receive_bytes(), timeout=config.max_no_data_seconds
            )
            logger.debug(f"Received {len(bytes_)} bytes of audio data")
            audio_samples = audio_samples_from_file(BytesIO(bytes_))
            audio_stream.extend(audio_samples)
            if audio_stream.duration - config.inactivity_window_seconds >= 0:
                audio = audio_stream.after(
                    audio_stream.duration - config.inactivity_window_seconds
                )
                vad_opts = VadOptions(min_silence_duration_ms=500, speech_pad_ms=0)
                # NOTE: This is a synchronous operation that runs every time new data is received.
                # This shouldn't be an issue unless data is being received in tiny chunks or the user's machine is a potato.
                timestamps = get_speech_timestamps(audio.data, vad_opts)
                if len(timestamps) == 0:
                    logger.info(
                        f"No speech detected in the last {config.inactivity_window_seconds} seconds."
                    )
                    break
                elif (
                    # last speech end time
                    config.inactivity_window_seconds
                    - timestamps[-1]["end"] / SAMPLES_PER_SECOND
                    >= config.max_inactivity_seconds
                ):
                    logger.info(
                        f"Not enough speech in the last {config.inactivity_window_seconds} seconds."
                    )
                    break
    except asyncio.TimeoutError:
        logger.info(
            f"No data received in {config.max_no_data_seconds} seconds. Closing the connection."
        )
    except WebSocketDisconnect as e:
        logger.info(f"Client disconnected: {e}")
    audio_stream.close()
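

# Example streaming client for the WebSocket endpoint below (a sketch: it assumes the
# `websockets` package and a server at localhost:8000, and `pcm_chunks` stands in for an
# iterable of binary audio chunks -- none of these are defined in this module):
#   import asyncio, websockets
#
#   async def main():
#       uri = "ws://localhost:8000/v1/audio/transcriptions?response_format=text"
#       async with websockets.connect(uri) as ws:
#           async def send_audio():
#               for chunk in pcm_chunks:
#                   await ws.send(chunk)
#           asyncio.create_task(send_audio())
#           async for message in ws:
#               print(message)
#
#   asyncio.run(main())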
@app.websocket("/v1/audio/transcriptions")
async def transcribe_stream(
    ws: WebSocket,
    model: Annotated[Model, Query()] = config.whisper.model,
    language: Annotated[Language | None, Query()] = config.default_language,
    prompt: Annotated[str | None, Query()] = None,
    response_format: Annotated[
        ResponseFormat, Query()
    ] = config.default_response_format,
    temperature: Annotated[float, Query()] = 0.0,
) -> None:
    await ws.accept()
    transcribe_opts = {
        "language": language,
        "initial_prompt": prompt,
        "temperature": temperature,
        "vad_filter": True,
        "condition_on_previous_text": False,
    }
    whisper = load_model(model)
    asr = FasterWhisperASR(whisper, **transcribe_opts)
    audio_stream = AudioStream()
    async with asyncio.TaskGroup() as tg:
        tg.create_task(audio_receiver(ws, audio_stream))
        async for transcription in audio_transcriber(asr, audio_stream):
            logger.debug(f"Sending transcription: {transcription.text}")
            if ws.client_state == WebSocketState.DISCONNECTED:
                break
            if response_format == ResponseFormat.TEXT:
                await ws.send_text(transcription.text)
            elif response_format == ResponseFormat.JSON:
                await ws.send_json(
                    TranscriptionJsonResponse.from_transcription(
                        transcription
                    ).model_dump()
                )
            elif response_format == ResponseFormat.VERBOSE_JSON:
                await ws.send_json(
                    TranscriptionVerboseJsonResponse.from_transcription(
                        transcription
                    ).model_dump()
                )

    if ws.client_state != WebSocketState.DISCONNECTED:
        logger.info("Closing the connection.")
        await ws.close()
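

# To run the server (a sketch; the exact module path is an assumption based on this file
# living in the faster_whisper_server package):
#   uvicorn faster_whisper_server.main:app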