Fedir Zadniprovskyi committed on
Commit
f5b5ebf
·
1 Parent(s): c8f37a4

feat: add /v1/models and /v1/model routes #14

Browse files

Now users will also have to specify a full model name.
Before: tiny.en
Now: Systran/faster-whisper-tiny.en

faster_whisper_server/main.py CHANGED
@@ -6,9 +6,11 @@ from contextlib import asynccontextmanager
6
  from io import BytesIO
7
  from typing import Annotated, Literal, OrderedDict
8
 
 
9
  from fastapi import (
10
  FastAPI,
11
  Form,
 
12
  Query,
13
  Response,
14
  UploadFile,
@@ -19,6 +21,7 @@ from fastapi.responses import StreamingResponse
19
  from fastapi.websockets import WebSocketState
20
  from faster_whisper import WhisperModel
21
  from faster_whisper.vad import VadOptions, get_speech_timestamps
 
22
 
23
  from faster_whisper_server import utils
24
  from faster_whisper_server.asr import FasterWhisperASR
@@ -31,24 +34,25 @@ from faster_whisper_server.config import (
31
  )
32
  from faster_whisper_server.logger import logger
33
  from faster_whisper_server.server_models import (
 
34
  TranscriptionJsonResponse,
35
  TranscriptionVerboseJsonResponse,
36
  )
37
  from faster_whisper_server.transcriber import audio_transcriber
38
 
39
- models: OrderedDict[str, WhisperModel] = OrderedDict()
40
 
41
 
42
  def load_model(model_name: str) -> WhisperModel:
43
- if model_name in models:
44
  logger.debug(f"{model_name} model already loaded")
45
- return models[model_name]
46
- if len(models) >= config.max_models:
47
- oldest_model_name = next(iter(models))
48
  logger.info(
49
  f"Max models ({config.max_models}) reached. Unloading the oldest model: {oldest_model_name}"
50
  )
51
- del models[oldest_model_name]
52
  logger.debug(f"Loading {model_name}...")
53
  start = time.perf_counter()
54
  # NOTE: will raise an exception if the model name isn't valid
@@ -60,7 +64,7 @@ def load_model(model_name: str) -> WhisperModel:
60
  logger.info(
61
  f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {config.whisper.inference_device}({config.whisper.compute_type}) will be used for inference."
62
  )
63
- models[model_name] = whisper
64
  return whisper
65
 
66
 
@@ -68,9 +72,9 @@ def load_model(model_name: str) -> WhisperModel:
68
  async def lifespan(_: FastAPI):
69
  load_model(config.whisper.model)
70
  yield
71
- for model in models.keys():
72
  logger.info(f"Unloading {model}")
73
- del models[model]
74
 
75
 
76
  app = FastAPI(lifespan=lifespan)
@@ -81,6 +85,48 @@ def health() -> Response:
81
  return Response(status_code=200, content="OK")
82
 
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  @app.post("/v1/audio/translations")
85
  def translate_file(
86
  file: Annotated[UploadFile, Form()],
 
6
  from io import BytesIO
7
  from typing import Annotated, Literal, OrderedDict
8
 
9
+ import huggingface_hub
10
  from fastapi import (
11
  FastAPI,
12
  Form,
13
+ HTTPException,
14
  Query,
15
  Response,
16
  UploadFile,
 
21
  from fastapi.websockets import WebSocketState
22
  from faster_whisper import WhisperModel
23
  from faster_whisper.vad import VadOptions, get_speech_timestamps
24
+ from huggingface_hub.hf_api import ModelInfo
25
 
26
  from faster_whisper_server import utils
27
  from faster_whisper_server.asr import FasterWhisperASR
 
34
  )
35
  from faster_whisper_server.logger import logger
36
  from faster_whisper_server.server_models import (
37
+ ModelObject,
38
  TranscriptionJsonResponse,
39
  TranscriptionVerboseJsonResponse,
40
  )
41
  from faster_whisper_server.transcriber import audio_transcriber
42
 
43
+ loaded_models: OrderedDict[str, WhisperModel] = OrderedDict()
44
 
45
 
46
  def load_model(model_name: str) -> WhisperModel:
47
+ if model_name in loaded_models:
48
  logger.debug(f"{model_name} model already loaded")
49
+ return loaded_models[model_name]
50
+ if len(loaded_models) >= config.max_models:
51
+ oldest_model_name = next(iter(loaded_models))
52
  logger.info(
53
  f"Max models ({config.max_models}) reached. Unloading the oldest model: {oldest_model_name}"
54
  )
55
+ del loaded_models[oldest_model_name]
56
  logger.debug(f"Loading {model_name}...")
57
  start = time.perf_counter()
58
  # NOTE: will raise an exception if the model name isn't valid
 
64
  logger.info(
65
  f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {config.whisper.inference_device}({config.whisper.compute_type}) will be used for inference."
66
  )
67
+ loaded_models[model_name] = whisper
68
  return whisper
69
 
70
 
 
72
async def lifespan(_: FastAPI):
    """FastAPI lifespan handler.

    On startup, eagerly loads the default model from config; on shutdown,
    unloads every cached model.
    """
    load_model(config.whisper.model)
    yield
    # Iterate over a snapshot of the keys: deleting entries while iterating
    # the live dict view raises "RuntimeError: dictionary changed size during
    # iteration".
    for model_name in list(loaded_models.keys()):
        logger.info(f"Unloading {model_name}")
        del loaded_models[model_name]
78
 
79
 
80
  app = FastAPI(lifespan=lifespan)
 
85
  return Response(status_code=200, content="OK")
86
 
87
 
88
@app.get("/v1/models", response_model=list[ModelObject])
def get_models() -> list[ModelObject]:
    """List models available on the Hugging Face Hub tagged with the
    "ctranslate2" library (the format faster-whisper consumes).
    """
    model_objects: list[ModelObject] = []
    for model_info in huggingface_hub.list_models(library="ctranslate2"):
        # A model without a creation timestamp cannot populate `created`; skip it.
        if model_info.created_at is None:
            continue
        model_objects.append(
            ModelObject(
                id=model_info.id,
                created=int(model_info.created_at.timestamp()),
                object_="model",
                owned_by=model_info.id.split("/")[0],
            )
        )
    return model_objects
102
+
103
+
104
@app.get("/v1/models/{model_name:path}", response_model=ModelObject)
def get_model(model_name: str) -> ModelObject:
    """Look up a single CTranslate2 model on the Hugging Face Hub by exact id.

    Raises:
        HTTPException(404): when no model matches at all, or when only
            partial matches exist (the detail lists the candidates).
    """
    models = list(
        huggingface_hub.list_models(model_name=model_name, library="ctranslate2")
    )
    if len(models) == 0:
        raise HTTPException(status_code=404, detail="Model doesn't exist")
    exact_match: ModelInfo | None = None
    for model in models:
        if model.id == model_name:
            exact_match = model
            break
    if exact_match is None:
        # NOTE: nesting the same quote character inside an f-string replacement
        # field is a SyntaxError before Python 3.12 (PEP 701), so the joined
        # candidate list is built outside the f-string.
        possible_matches = ", ".join(model.id for model in models)
        raise HTTPException(
            status_code=404,
            detail=f"Model doesn't exist. Possible matches: {possible_matches}",
        )
    # list_models returned this entry, so a creation timestamp is expected;
    # assumption to be confirmed against the huggingface_hub API.
    assert exact_match.created_at is not None
    return ModelObject(
        id=exact_match.id,
        created=int(exact_match.created_at.timestamp()),
        object_="model",
        owned_by=exact_match.id.split("/")[0],
    )
128
+
129
+
130
  @app.post("/v1/audio/translations")
131
  def translate_file(
132
  file: Annotated[UploadFile, Form()],
faster_whisper_server/server_models.py CHANGED
@@ -1,7 +1,9 @@
1
  from __future__ import annotations
2
 
 
 
3
  from faster_whisper.transcribe import Segment, TranscriptionInfo, Word
4
- from pydantic import BaseModel
5
 
6
  from faster_whisper_server import utils
7
  from faster_whisper_server.core import Transcription
@@ -125,3 +127,16 @@ class TranscriptionVerboseJsonResponse(BaseModel):
125
  ],
126
  segments=[], # FIX: hardcoded
127
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
+ from typing import Literal
4
+
5
  from faster_whisper.transcribe import Segment, TranscriptionInfo, Word
6
+ from pydantic import BaseModel, ConfigDict, Field
7
 
8
  from faster_whisper_server import utils
9
  from faster_whisper_server.core import Transcription
 
127
  ],
128
  segments=[], # FIX: hardcoded
129
  )
130
+
131
+
132
class ModelObject(BaseModel):
    """A model entry in the style of the OpenAI `/v1/models` response object."""

    model_config = ConfigDict(populate_by_name=True)

    # The model identifier, which can be referenced in the API endpoints.
    id: str
    # The Unix timestamp (in seconds) when the model was created.
    created: int
    # The object type, which is always "model". Named `object_` in Python
    # (serialized as "object") to avoid shadowing the builtin.
    object_: Literal["model"] = Field(serialization_alias="object")
    # The organization that owns the model.
    owned_by: str
tests/__init__.py ADDED
File without changes
tests/api_model_test.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Generator
2
+
3
+ import pytest
4
+ from fastapi.testclient import TestClient
5
+
6
+ from faster_whisper_server.main import app
7
+ from faster_whisper_server.server_models import ModelObject
8
+
9
+ MODEL_THAT_EXISTS = "Systran/faster-whisper-tiny.en"
10
+ MODEL_THAT_DOES_NOT_EXIST = "i-do-not-exist"
11
+ MIN_EXPECTED_NUMBER_OF_MODELS = (
12
+ 200 # At the time of the test creation there are 228 models
13
+ )
14
+
15
+
16
@pytest.fixture()
def client() -> Generator[TestClient, None, None]:
    """Yield a TestClient; the context manager runs the app's lifespan events."""
    with TestClient(app) as test_client:
        yield test_client
+
21
+
22
# HACK: because ModelObject(**data) doesn't work
def model_dict_to_object(model_dict: dict) -> ModelObject:
    """Rebuild a ModelObject from a serialized response dict.

    Maps the wire key "object" back onto the Python field `object_`.
    """
    return ModelObject(
        id=model_dict["id"],
        created=model_dict["created"],
        object_=model_dict["object"],
        owned_by=model_dict["owned_by"],
    )
30
+
31
+
32
def test_list_models(client: TestClient):
    """/v1/models returns a parseable, non-trivially-sized model list."""
    response = client.get("/v1/models")
    # Assert the status before parsing so a failure reports the HTTP error,
    # not a JSON decode error.
    assert response.status_code == 200
    data = response.json()
    models = [model_dict_to_object(model_dict) for model_dict in data]
    # Lower bound rather than exact count: the Hub listing grows over time.
    assert len(models) > MIN_EXPECTED_NUMBER_OF_MODELS
37
+
38
+
39
def test_model_exists(client: TestClient):
    """Fetching a known model by exact id returns that model."""
    # The registered route is "/v1/models/{model_name:path}"; "/v1/model/..."
    # matches no route and would always 404.
    response = client.get(f"/v1/models/{MODEL_THAT_EXISTS}")
    assert response.status_code == 200
    data = response.json()
    model = model_dict_to_object(data)
    assert model.id == MODEL_THAT_EXISTS
44
+
45
+
46
def test_model_does_not_exist(client: TestClient):
    """Fetching a nonexistent model id yields 404 from the model-lookup handler."""
    # Use the real "/v1/models/..." route so the 404 comes from the model
    # lookup, not from an unmatched path.
    response = client.get(f"/v1/models/{MODEL_THAT_DOES_NOT_EXIST}")
    assert response.status_code == 404