Fedir Zadniprovskyi
feat: ollama-like ps endpoints
b20cbad
raw
history blame
15.9 kB
from __future__ import annotations
import asyncio
from collections import OrderedDict
from contextlib import asynccontextmanager
import gc
from io import BytesIO
import time
from typing import TYPE_CHECKING, Annotated, Literal
from fastapi import (
FastAPI,
Form,
HTTPException,
Path,
Query,
Response,
UploadFile,
WebSocket,
WebSocketDisconnect,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.websockets import WebSocketState
from faster_whisper import WhisperModel
from faster_whisper.vad import VadOptions, get_speech_timestamps
import huggingface_hub
from pydantic import AfterValidator
from faster_whisper_server.asr import FasterWhisperASR
from faster_whisper_server.audio import AudioStream, audio_samples_from_file
from faster_whisper_server.config import (
SAMPLES_PER_SECOND,
Language,
ResponseFormat,
Task,
config,
)
from faster_whisper_server.core import Segment, segments_to_srt, segments_to_text, segments_to_vtt
from faster_whisper_server.logger import logger
from faster_whisper_server.server_models import (
ModelListResponse,
ModelObject,
TranscriptionJsonResponse,
TranscriptionVerboseJsonResponse,
)
from faster_whisper_server.transcriber import audio_transcriber
if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Generator, Iterable
from faster_whisper.transcribe import TranscriptionInfo
from huggingface_hub.hf_api import ModelInfo
loaded_models: OrderedDict[str, WhisperModel] = OrderedDict()


def load_model(model_name: str) -> WhisperModel:
    """Return the `WhisperModel` for `model_name`, loading and caching it on first use.

    Loaded models live in `loaded_models`. When the cache already holds `config.max_models`
    entries, the earliest-inserted entries are evicted before the new model is loaded.

    Args:
        model_name: Name handed directly to `WhisperModel`.

    Returns:
        The cached model, or the freshly loaded one.

    Raises:
        Whatever `WhisperModel` raises when the model cannot be loaded (e.g. an invalid name).
    """
    if model_name in loaded_models:
        # NOTE(review): a cache hit does not call `move_to_end`, so eviction follows load order,
        # not recency of use — confirm that is intended.
        logger.debug(f"{model_name} model already loaded")
        return loaded_models[model_name]
    evicted = False
    # `while` keeps the cache within the limit even if it somehow exceeds it. The emptiness guard
    # avoids `StopIteration` from `next(iter({}))` when `max_models` is 0 or negative.
    while loaded_models and len(loaded_models) >= config.max_models:
        oldest_model_name = next(iter(loaded_models))
        logger.info(f"Max models ({config.max_models}) reached. Unloading the oldest model: {oldest_model_name}")
        del loaded_models[oldest_model_name]
        evicted = True
    if evicted:
        # Mirror the DELETE /api/ps route: collect reference cycles so the evicted model's memory can
        # be reclaimed before another model is loaded.
        gc.collect()
    logger.debug(f"Loading {model_name}...")
    start = time.perf_counter()
    # NOTE: will raise an exception if the model name isn't valid. Should I do an explicit check?
    whisper = WhisperModel(
        model_name,
        device=config.whisper.inference_device,
        device_index=config.whisper.device_index,
        compute_type=config.whisper.compute_type,
        cpu_threads=config.whisper.cpu_threads,
        num_workers=config.whisper.num_workers,
    )
    logger.info(
        f"Loaded {model_name} in {time.perf_counter() - start:.2f} seconds. {config.whisper.inference_device}({config.whisper.compute_type}) will be used for inference."  # noqa: E501
    )
    loaded_models[model_name] = whisper
    return whisper
# Log the effective configuration once, when this module is imported (debug level only).
logger.debug(f"Config: {config}")
@asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
    """FastAPI lifespan hook that eagerly loads every model listed in `config.preload_models`.

    The preload runs before the application starts serving. Nothing happens on shutdown; control is
    simply yielded back to FastAPI.
    """
    for preload_name in config.preload_models:
        load_model(preload_name)
    yield
# The ASGI application; `lifespan` preloads the configured models at startup.
app = FastAPI(lifespan=lifespan)
# CORS middleware is installed only when `config.allow_origins` is set. When installed, it allows any
# method and header and permits credentials.
if config.allow_origins is not None:
    app.add_middleware(
        CORSMiddleware,
        allow_origins=config.allow_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
@app.get("/health")
def health() -> Response:
    """Liveness probe: always responds 200 with the body "OK"."""
    ok_response = Response(content="OK", status_code=200)
    return ok_response
@app.get("/api/ps", tags=["experimental"], summary="Get a list of loaded models.")
def get_running_models() -> dict[str, list[str]]:
    """Report the names of the models currently held in memory, in cache (insertion) order."""
    loaded_names = [*loaded_models]
    return {"models": loaded_names}
@app.post("/api/ps/{model_name:path}", tags=["experimental"], summary="Load a model into memory.")
def load_model_route(model_name: str) -> Response:
    """Load `model_name` into memory.

    Responds 201 once the model has been loaded, or 409 if it was already loaded. Errors raised while
    loading are left for FastAPI to handle.
    """
    already_loaded = model_name in loaded_models
    if not already_loaded:
        load_model(model_name)
        return Response(status_code=201)
    return Response(status_code=409, content="Model already loaded")
@app.delete("/api/ps/{model_name:path}", tags=["experimental"], summary="Unload a model from memory.")
def stop_running_model(model_name: str) -> Response:
    """Unload `model_name` from memory.

    Responds 204 after removing the model from the cache, or 404 if it was not loaded.
    """
    if model_name not in loaded_models:
        return Response(status_code=404)
    # Remove the entry without binding the model to a local variable. A live local reference would
    # keep the model reachable during `gc.collect()` and defeat the purpose of collecting here.
    del loaded_models[model_name]
    gc.collect()
    return Response(status_code=204)
def _model_info_to_model_object(model: ModelInfo) -> ModelObject:
    """Convert a Hugging Face `ModelInfo` into the `ModelObject` returned by the models endpoints.

    A model without card data is reported with no languages instead of failing the whole request.
    """
    # Presumably the Hub always sets `created_at` for listed models — TODO confirm.
    assert model.created_at is not None
    card_language = model.card_data.language if model.card_data is not None else None
    assert card_language is None or isinstance(card_language, str | list)
    if card_language is None:
        language = []
    elif isinstance(card_language, str):
        language = [card_language]
    else:
        language = card_language
    return ModelObject(
        id=model.id,
        created=int(model.created_at.timestamp()),
        object_="model",
        owned_by=model.id.split("/")[0],
        language=language,
    )


def _sorted_by_downloads(models: Iterable[ModelInfo]) -> list[ModelInfo]:
    """Return `models` as a list, most-downloaded first."""
    # `ModelInfo.downloads` is Optional; treat a missing count as 0 so the sort cannot raise TypeError.
    return sorted(models, key=lambda model: model.downloads or 0, reverse=True)


@app.get("/v1/models")
def get_models() -> ModelListResponse:
    """List CTranslate2 speech-recognition models on the Hugging Face Hub, most-downloaded first."""
    models = _sorted_by_downloads(
        huggingface_hub.list_models(library="ctranslate2", tags="automatic-speech-recognition", cardData=True)
    )
    return ModelListResponse(data=[_model_info_to_model_object(model) for model in models])


@app.get("/v1/models/{model_name:path}")
# NOTE: `examples` doesn't work https://github.com/tiangolo/fastapi/discussions/10537
def get_model(
    model_name: Annotated[str, Path(example="Systran/faster-distil-whisper-large-v3")],
) -> ModelObject:
    """Return the Hub model whose id is exactly `model_name`.

    Raises:
        HTTPException: 404 if the search finds nothing, or finds only non-exact matches (which are
            listed in the error detail).
    """
    models = _sorted_by_downloads(
        huggingface_hub.list_models(
            model_name=model_name, library="ctranslate2", tags="automatic-speech-recognition", cardData=True
        )
    )
    if not models:
        raise HTTPException(status_code=404, detail="Model doesn't exists")
    exact_match = next((model for model in models if model.id == model_name), None)
    if exact_match is None:
        raise HTTPException(
            status_code=404,
            detail=f"Model doesn't exists. Possible matches: {', '.join([model.id for model in models])}",
        )
    return _model_info_to_model_object(exact_match)
def segments_to_response(
segments: Iterable[Segment],
transcription_info: TranscriptionInfo,
response_format: ResponseFormat,
) -> Response:
segments = list(segments)
if response_format == ResponseFormat.TEXT: # noqa: RET503
return Response(segments_to_text(segments), media_type="text/plain")
elif response_format == ResponseFormat.JSON:
return Response(
TranscriptionJsonResponse.from_segments(segments).model_dump_json(),
media_type="application/json",
)
elif response_format == ResponseFormat.VERBOSE_JSON:
return Response(
TranscriptionVerboseJsonResponse.from_segments(segments, transcription_info).model_dump_json(),
media_type="application/json",
)
elif response_format == ResponseFormat.VTT:
return Response(
"".join(segments_to_vtt(segment, i) for i, segment in enumerate(segments)), media_type="text/vtt"
)
elif response_format == ResponseFormat.SRT:
return Response(
"".join(segments_to_srt(segment, i) for i, segment in enumerate(segments)), media_type="text/plain"
)
def format_as_sse(data: str) -> str:
    """Wrap `data` as a single Server-Sent Events message.

    SSE treats CR, LF and CRLF as line terminators, and a blank line ends the event. Multi-line payloads
    (e.g. VTT/SRT cues) therefore need one `data:` field per line. Otherwise the first newline breaks
    the field and any empty line ends the event early. Clients rejoin the fields with "\n", which
    reconstructs the original text. Single-line payloads produce exactly "data: <data>\n\n".
    """
    lines = data.replace("\r\n", "\n").replace("\r", "\n").split("\n")
    return "".join(f"data: {line}\n" for line in lines) + "\n"
def segments_to_streaming_response(
segments: Iterable[Segment],
transcription_info: TranscriptionInfo,
response_format: ResponseFormat,
) -> StreamingResponse:
def segment_responses() -> Generator[str, None, None]:
for i, segment in enumerate(segments):
if response_format == ResponseFormat.TEXT:
data = segment.text
elif response_format == ResponseFormat.JSON:
data = TranscriptionJsonResponse.from_segments([segment]).model_dump_json()
elif response_format == ResponseFormat.VERBOSE_JSON:
data = TranscriptionVerboseJsonResponse.from_segment(segment, transcription_info).model_dump_json()
elif response_format == ResponseFormat.VTT:
data = segments_to_vtt(segment, i)
elif response_format == ResponseFormat.SRT:
data = segments_to_srt(segment, i)
yield format_as_sse(data)
return StreamingResponse(segment_responses(), media_type="text/event-stream")
def handle_default_openai_model(model_name: str) -> str:
    """Replace OpenAI's default model name ("whisper-1") with the server's configured model.

    Exists because some callers may not be able to override the default model name, for example
    https://github.com/open-webui/open-webui/issues/2248#issuecomment-2162997623.
    Any other name is returned unchanged.
    """
    if model_name != "whisper-1":
        return model_name
    logger.info(f"{model_name} is not a valid model name. Using {config.whisper.model} instead.")
    return config.whisper.model


# Request-parameter type for model names: after `str` validation, "whisper-1" is mapped to the
# configured default model by `handle_default_openai_model`.
ModelName = Annotated[str, AfterValidator(handle_default_openai_model)]
@app.post(
    "/v1/audio/translations",
    response_model=str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse,
)
def translate_file(
    file: Annotated[UploadFile, Form()],
    model: Annotated[ModelName, Form()] = config.whisper.model,
    prompt: Annotated[str | None, Form()] = None,
    response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
    temperature: Annotated[float, Form()] = 0.0,
    stream: Annotated[bool, Form()] = False,
) -> Response | StreamingResponse:
    """Run the Whisper translate task on an uploaded audio file.

    VAD filtering is always enabled. The result is returned either as one response or, when `stream`
    is true, as Server-Sent Events; both use `response_format`.
    """
    whisper = load_model(model)
    raw_segments, transcription_info = whisper.transcribe(
        file.file,
        task=Task.TRANSLATE,
        initial_prompt=prompt,
        temperature=temperature,
        vad_filter=True,
    )
    segments = Segment.from_faster_whisper_segments(raw_segments)
    render = segments_to_streaming_response if stream else segments_to_response
    return render(segments, transcription_info, response_format)
# https://platform.openai.com/docs/api-reference/audio/createTranscription
# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L8915
@app.post(
    "/v1/audio/transcriptions",
    response_model=str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse,
)
def transcribe_file(
    file: Annotated[UploadFile, Form()],
    model: Annotated[ModelName, Form()] = config.whisper.model,
    language: Annotated[Language | None, Form()] = config.default_language,
    prompt: Annotated[str | None, Form()] = None,
    response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
    temperature: Annotated[float, Form()] = 0.0,
    timestamp_granularities: Annotated[
        list[Literal["segment", "word"]],
        Form(alias="timestamp_granularities[]"),
    ] = ["segment"],
    stream: Annotated[bool, Form()] = False,
    hotwords: Annotated[str | None, Form()] = None,
) -> Response | StreamingResponse:
    """Transcribe an uploaded audio file (modelled on OpenAI's createTranscription endpoint).

    Word-level timestamps are requested only when "word" is among `timestamp_granularities`. VAD
    filtering is always enabled. The result is returned either as one response or, when `stream` is
    true, as Server-Sent Events; both use `response_format`.
    """
    whisper = load_model(model)
    want_word_timestamps = "word" in timestamp_granularities
    raw_segments, transcription_info = whisper.transcribe(
        file.file,
        task=Task.TRANSCRIBE,
        language=language,
        initial_prompt=prompt,
        word_timestamps=want_word_timestamps,
        temperature=temperature,
        vad_filter=True,
        hotwords=hotwords,
    )
    segments = Segment.from_faster_whisper_segments(raw_segments)
    render = segments_to_streaming_response if stream else segments_to_response
    return render(segments, transcription_info, response_format)
async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
    """Feed binary WebSocket messages into `audio_stream` until the client goes quiet or disconnects.

    Each received message is converted to samples with `audio_samples_from_file` and appended to
    `audio_stream`. Receiving stops when any of these happens:
      * no message arrives within `config.max_no_data_seconds`;
      * the client disconnects;
      * VAD finds no speech in the trailing `config.inactivity_window_seconds` of audio;
      * the last detected speech in that trailing window ends at least
        `config.max_inactivity_seconds` before the window's end.
    In each of these cases `audio_stream.close()` is called afterwards. Any other exception propagates
    without the stream being closed, because the close is not in a `finally`.
    """
    try:
        while True:
            bytes_ = await asyncio.wait_for(ws.receive_bytes(), timeout=config.max_no_data_seconds)
            logger.debug(f"Received {len(bytes_)} bytes of audio data")
            audio_samples = audio_samples_from_file(BytesIO(bytes_))
            audio_stream.extend(audio_samples)
            # Only run the inactivity check once at least one full window of audio has accumulated.
            if audio_stream.duration - config.inactivity_window_seconds >= 0:
                # Presumably `after(t)` returns the audio from time `t` onward, i.e. the trailing
                # `inactivity_window_seconds` of audio — TODO confirm against AudioStream.
                audio = audio_stream.after(audio_stream.duration - config.inactivity_window_seconds)
                vad_opts = VadOptions(min_silence_duration_ms=500, speech_pad_ms=0)
                # NOTE: This is a synchronous operation that runs every time new data is received.
                # This shouldn't be an issue unless data is being received in tiny chunks or the user's machine is a potato.  # noqa: E501
                timestamps = get_speech_timestamps(audio.data, vad_opts)
                if len(timestamps) == 0:
                    logger.info(f"No speech detected in the last {config.inactivity_window_seconds} seconds.")
                    break
                elif (
                    # last speech end time
                    # Assumes `end` is a sample index relative to the window start, so dividing by
                    # SAMPLES_PER_SECOND gives seconds; window length minus that is the trailing silence.
                    config.inactivity_window_seconds - timestamps[-1]["end"] / SAMPLES_PER_SECOND
                    >= config.max_inactivity_seconds
                ):
                    logger.info(f"Not enough speech in the last {config.inactivity_window_seconds} seconds.")
                    break
    except TimeoutError:
        # Raised by `asyncio.wait_for` when no message arrives in time.
        logger.info(f"No data received in {config.max_no_data_seconds} seconds. Closing the connection.")
    except WebSocketDisconnect as e:
        logger.info(f"Client disconnected: {e}")
    # Signal the end of input to whoever consumes `audio_stream`.
    audio_stream.close()
@app.websocket("/v1/audio/transcriptions")
async def transcribe_stream(
    ws: WebSocket,
    model: Annotated[ModelName, Query()] = config.whisper.model,
    language: Annotated[Language | None, Query()] = config.default_language,
    response_format: Annotated[ResponseFormat, Query()] = config.default_response_format,
    temperature: Annotated[float, Query()] = 0.0,
) -> None:
    """Live transcription over a WebSocket.

    `audio_receiver` consumes the client's binary audio messages in a background task. Transcriptions
    yielded by `audio_transcriber` are sent back as they arrive. Only the TEXT, JSON and VERBOSE_JSON
    formats send anything; any other `response_format` produces no messages. Once the transcriber
    finishes, the server closes the socket unless the client has already disconnected.
    """
    await ws.accept()
    transcribe_opts = {
        "language": language,
        "temperature": temperature,
        "vad_filter": True,
        "condition_on_previous_text": False,
    }
    # `load_model` is synchronous and may spend seconds loading weights on a cache miss. Running it in
    # a worker thread keeps the event loop, and every other connection, responsive meanwhile. The HTTP
    # endpoints already call it from FastAPI's threadpool.
    whisper = await asyncio.to_thread(load_model, model)
    asr = FasterWhisperASR(whisper, **transcribe_opts)
    audio_stream = AudioStream()
    async with asyncio.TaskGroup() as tg:
        tg.create_task(audio_receiver(ws, audio_stream))
        async for transcription in audio_transcriber(asr, audio_stream):
            logger.debug(f"Sending transcription: {transcription.text}")
            if ws.client_state == WebSocketState.DISCONNECTED:
                break
            if response_format == ResponseFormat.TEXT:
                await ws.send_text(transcription.text)
            elif response_format == ResponseFormat.JSON:
                await ws.send_json(TranscriptionJsonResponse.from_transcription(transcription).model_dump())
            elif response_format == ResponseFormat.VERBOSE_JSON:
                await ws.send_json(TranscriptionVerboseJsonResponse.from_transcription(transcription).model_dump())
    if ws.client_state != WebSocketState.DISCONNECTED:
        logger.info("Closing the connection.")
        await ws.close()
# Optionally mount the Gradio UI at "/" on the same application. The imports live inside the branch,
# so gradio is only needed when `config.enable_ui` is set. `app` is rebound to the value returned by
# `gr.mount_gradio_app`.
if config.enable_ui:
    import gradio as gr

    from faster_whisper_server.gradio_app import create_gradio_demo

    app = gr.mount_gradio_app(app, create_gradio_demo(config), path="/")