Spaces:
Configuration error
Configuration error
Fedir Zadniprovskyi
committed on
Commit
·
f5b5ebf
1
Parent(s):
c8f37a4
feat: add /v1/models and /v1/model routes #14
Browse files
Now users will also have to specify a full model name.
Before: tiny.en
Now: Systran/faster-whisper-tiny.en
- faster_whisper_server/main.py +55 -9
- faster_whisper_server/server_models.py +16 -1
- tests/__init__.py +0 -0
- tests/api_model_test.py +48 -0
faster_whisper_server/main.py
CHANGED
@@ -6,9 +6,11 @@ from contextlib import asynccontextmanager
|
|
6 |
from io import BytesIO
|
7 |
from typing import Annotated, Literal, OrderedDict
|
8 |
|
|
|
9 |
from fastapi import (
|
10 |
FastAPI,
|
11 |
Form,
|
|
|
12 |
Query,
|
13 |
Response,
|
14 |
UploadFile,
|
@@ -19,6 +21,7 @@ from fastapi.responses import StreamingResponse
|
|
19 |
from fastapi.websockets import WebSocketState
|
20 |
from faster_whisper import WhisperModel
|
21 |
from faster_whisper.vad import VadOptions, get_speech_timestamps
|
|
|
22 |
|
23 |
from faster_whisper_server import utils
|
24 |
from faster_whisper_server.asr import FasterWhisperASR
|
@@ -31,24 +34,25 @@ from faster_whisper_server.config import (
|
|
31 |
)
|
32 |
from faster_whisper_server.logger import logger
|
33 |
from faster_whisper_server.server_models import (
|
|
|
34 |
TranscriptionJsonResponse,
|
35 |
TranscriptionVerboseJsonResponse,
|
36 |
)
|
37 |
from faster_whisper_server.transcriber import audio_transcriber
|
38 |
|
39 |
-
|
40 |
|
41 |
|
42 |
def load_model(model_name: str) -> WhisperModel:
|
43 |
-
if model_name in
|
44 |
logger.debug(f"{model_name} model already loaded")
|
45 |
-
return
|
46 |
-
if len(
|
47 |
-
oldest_model_name = next(iter(
|
48 |
logger.info(
|
49 |
f"Max models ({config.max_models}) reached. Unloading the oldest model: {oldest_model_name}"
|
50 |
)
|
51 |
-
del
|
52 |
logger.debug(f"Loading {model_name}...")
|
53 |
start = time.perf_counter()
|
54 |
# NOTE: will raise an exception if the model name isn't valid
|
@@ -60,7 +64,7 @@ def load_model(model_name: str) -> WhisperModel:
|
|
60 |
logger.info(
|
61 |
f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {config.whisper.inference_device}({config.whisper.compute_type}) will be used for inference."
|
62 |
)
|
63 |
-
|
64 |
return whisper
|
65 |
|
66 |
|
@@ -68,9 +72,9 @@ def load_model(model_name: str) -> WhisperModel:
|
|
68 |
async def lifespan(_: FastAPI):
|
69 |
load_model(config.whisper.model)
|
70 |
yield
|
71 |
-
for model in
|
72 |
logger.info(f"Unloading {model}")
|
73 |
-
del
|
74 |
|
75 |
|
76 |
app = FastAPI(lifespan=lifespan)
|
@@ -81,6 +85,48 @@ def health() -> Response:
|
|
81 |
return Response(status_code=200, content="OK")
|
82 |
|
83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
@app.post("/v1/audio/translations")
|
85 |
def translate_file(
|
86 |
file: Annotated[UploadFile, Form()],
|
|
|
6 |
from io import BytesIO
|
7 |
from typing import Annotated, Literal, OrderedDict
|
8 |
|
9 |
+
import huggingface_hub
|
10 |
from fastapi import (
|
11 |
FastAPI,
|
12 |
Form,
|
13 |
+
HTTPException,
|
14 |
Query,
|
15 |
Response,
|
16 |
UploadFile,
|
|
|
21 |
from fastapi.websockets import WebSocketState
|
22 |
from faster_whisper import WhisperModel
|
23 |
from faster_whisper.vad import VadOptions, get_speech_timestamps
|
24 |
+
from huggingface_hub.hf_api import ModelInfo
|
25 |
|
26 |
from faster_whisper_server import utils
|
27 |
from faster_whisper_server.asr import FasterWhisperASR
|
|
|
34 |
)
|
35 |
from faster_whisper_server.logger import logger
|
36 |
from faster_whisper_server.server_models import (
|
37 |
+
ModelObject,
|
38 |
TranscriptionJsonResponse,
|
39 |
TranscriptionVerboseJsonResponse,
|
40 |
)
|
41 |
from faster_whisper_server.transcriber import audio_transcriber
|
42 |
|
43 |
+
loaded_models: OrderedDict[str, WhisperModel] = OrderedDict()
|
44 |
|
45 |
|
46 |
def load_model(model_name: str) -> WhisperModel:
|
47 |
+
if model_name in loaded_models:
|
48 |
logger.debug(f"{model_name} model already loaded")
|
49 |
+
return loaded_models[model_name]
|
50 |
+
if len(loaded_models) >= config.max_models:
|
51 |
+
oldest_model_name = next(iter(loaded_models))
|
52 |
logger.info(
|
53 |
f"Max models ({config.max_models}) reached. Unloading the oldest model: {oldest_model_name}"
|
54 |
)
|
55 |
+
del loaded_models[oldest_model_name]
|
56 |
logger.debug(f"Loading {model_name}...")
|
57 |
start = time.perf_counter()
|
58 |
# NOTE: will raise an exception if the model name isn't valid
|
|
|
64 |
logger.info(
|
65 |
f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {config.whisper.inference_device}({config.whisper.compute_type}) will be used for inference."
|
66 |
)
|
67 |
+
loaded_models[model_name] = whisper
|
68 |
return whisper
|
69 |
|
70 |
|
|
|
72 |
async def lifespan(_: FastAPI):
    """Load the configured model before serving and drop all cached models afterwards.

    Before yielding, the model named by ``config.whisper.model`` is loaded via
    ``load_model``. After the yield, every entry of ``loaded_models`` is
    logged and removed.
    """
    load_model(config.whisper.model)
    yield
    # Iterate over a snapshot of the keys: deleting entries while iterating the
    # dict (or its keys() view) raises
    # "RuntimeError: OrderedDict mutated during iteration".
    for model in list(loaded_models):
        logger.info(f"Unloading {model}")
        del loaded_models[model]
|
78 |
|
79 |
|
80 |
app = FastAPI(lifespan=lifespan)
|
|
|
85 |
return Response(status_code=200, content="OK")
|
86 |
|
87 |
|
88 |
+
@app.get("/v1/models", response_model=list[ModelObject])
def get_models() -> list[ModelObject]:
    """List the Hub models published for the ``ctranslate2`` library.

    Entries whose ``created_at`` is ``None`` are left out, because the
    ``created`` field is derived from that timestamp. The owner is taken from
    the part of the model id before the first ``/``.
    """
    hub_models = huggingface_hub.list_models(library="ctranslate2")
    result: list[ModelObject] = []
    for hub_model in hub_models:
        if hub_model.created_at is None:
            continue
        result.append(
            ModelObject(
                id=hub_model.id,
                created=int(hub_model.created_at.timestamp()),
                object_="model",
                owned_by=hub_model.id.split("/")[0],
            )
        )
    return result
|
102 |
+
|
103 |
+
|
104 |
+
@app.get("/v1/models/{model_name:path}", response_model=ModelObject)
def get_model(model_name: str) -> ModelObject:
    """Return the ``ctranslate2`` Hub model whose id equals ``model_name``.

    The Hub is searched with ``model_name`` and only a result whose id matches
    exactly is accepted.

    Raises:
        HTTPException: 404 when the search returns nothing, or when none of
            the results has an id equal to ``model_name`` (the ids that were
            found are listed in the detail). 500 when the matching model
            carries no creation timestamp.
    """
    models = list(
        huggingface_hub.list_models(model_name=model_name, library="ctranslate2")
    )
    if not models:
        raise HTTPException(status_code=404, detail="Model doesn't exists")
    exact_match: ModelInfo | None = next(
        (model for model in models if model.id == model_name), None
    )
    if exact_match is None:
        # Join outside the f-string: reusing double quotes inside a replacement
        # field is a SyntaxError on Python versions before 3.12.
        possible_matches = ", ".join(model.id for model in models)
        raise HTTPException(
            status_code=404,
            detail=f"Model doesn't exists. Possible matches: {possible_matches}",
        )
    if exact_match.created_at is None:
        # Explicit check instead of `assert`, which is stripped under `-O`.
        # The status stays 500, as an unhandled AssertionError would produce.
        raise HTTPException(
            status_code=500, detail="Model has no creation timestamp"
        )
    return ModelObject(
        id=exact_match.id,
        created=int(exact_match.created_at.timestamp()),
        object_="model",
        owned_by=exact_match.id.split("/")[0],
    )
|
128 |
+
|
129 |
+
|
130 |
@app.post("/v1/audio/translations")
|
131 |
def translate_file(
|
132 |
file: Annotated[UploadFile, Form()],
|
faster_whisper_server/server_models.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1 |
from __future__ import annotations
|
2 |
|
|
|
|
|
3 |
from faster_whisper.transcribe import Segment, TranscriptionInfo, Word
|
4 |
-
from pydantic import BaseModel
|
5 |
|
6 |
from faster_whisper_server import utils
|
7 |
from faster_whisper_server.core import Transcription
|
@@ -125,3 +127,16 @@ class TranscriptionVerboseJsonResponse(BaseModel):
|
|
125 |
],
|
126 |
segments=[], # FIX: hardcoded
|
127 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from __future__ import annotations
|
2 |
|
3 |
+
from typing import Literal
|
4 |
+
|
5 |
from faster_whisper.transcribe import Segment, TranscriptionInfo, Word
|
6 |
+
from pydantic import BaseModel, ConfigDict, Field
|
7 |
|
8 |
from faster_whisper_server import utils
|
9 |
from faster_whisper_server.core import Transcription
|
|
|
127 |
],
|
128 |
segments=[], # FIX: hardcoded
|
129 |
)
|
130 |
+
|
131 |
+
|
132 |
+
class ModelObject(BaseModel):
    # Metadata for a single model. `populate_by_name` lets callers construct
    # the object with the Python field name `object_`.
    model_config = ConfigDict(populate_by_name=True)

    # The model identifier, which can be referenced in the API endpoints.
    id: str
    # The Unix timestamp (in seconds) when the model was created.
    created: int
    # The object type, which is always "model". Serialized under the key
    # "object"; the trailing underscore keeps the attribute name distinct
    # from the `object` builtin.
    object_: Literal["model"] = Field(serialization_alias="object")
    # The organization that owns the model.
    owned_by: str
|
tests/__init__.py
ADDED
File without changes
|
tests/api_model_test.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Generator
|
2 |
+
|
3 |
+
import pytest
|
4 |
+
from fastapi.testclient import TestClient
|
5 |
+
|
6 |
+
from faster_whisper_server.main import app
|
7 |
+
from faster_whisper_server.server_models import ModelObject
|
8 |
+
|
9 |
+
MODEL_THAT_EXISTS = "Systran/faster-whisper-tiny.en"
|
10 |
+
MODEL_THAT_DOES_NOT_EXIST = "i-do-not-exist"
|
11 |
+
MIN_EXPECTED_NUMBER_OF_MODELS = (
|
12 |
+
200 # At the time of the test creation there are 228 models
|
13 |
+
)
|
14 |
+
|
15 |
+
|
16 |
+
@pytest.fixture()
def client() -> Generator[TestClient, None, None]:
    """Yield a ``TestClient`` for ``app``, opened as a context manager."""
    with TestClient(app) as test_client:
        yield test_client
|
20 |
+
|
21 |
+
|
22 |
+
# HACK: because ModelObject(**data) doesn't work
def model_dict_to_object(model_dict: dict) -> ModelObject:
    """Build a ``ModelObject`` from a response dict.

    The JSON key ``"object"`` is passed as the ``object_`` field; the other
    keys map to fields of the same name.
    """
    fields = {
        "id": model_dict["id"],
        "created": model_dict["created"],
        "object_": model_dict["object"],
        "owned_by": model_dict["owned_by"],
    }
    return ModelObject(**fields)
|
30 |
+
|
31 |
+
|
32 |
+
def test_list_models(client: TestClient):
    """The listing endpoint returns more models than the expected minimum."""
    payload = client.get("/v1/models").json()
    parsed = [model_dict_to_object(entry) for entry in payload]
    assert len(parsed) > MIN_EXPECTED_NUMBER_OF_MODELS
|
37 |
+
|
38 |
+
|
39 |
+
def test_model_exists(client: TestClient):
    """A model that exists is returned by the single-model route."""
    # The application registers the route as /v1/models/{model_name:path}
    # (plural). Requesting /v1/model/... hits no route and yields a 404 body
    # without the expected keys.
    response = client.get(f"/v1/models/{MODEL_THAT_EXISTS}")
    # Check the status first so a failure reports the HTTP error rather than
    # a KeyError from parsing an error payload.
    assert response.status_code == 200
    data = response.json()
    model = model_dict_to_object(data)
    assert model.id == MODEL_THAT_EXISTS
|
44 |
+
|
45 |
+
|
46 |
+
def test_model_does_not_exist(client: TestClient):
    """An unknown model name makes the single-model route answer 404."""
    # Use the registered /v1/models/{model_name:path} route. With the
    # unregistered /v1/model/... path the 404 would come from the router and
    # the test would pass without ever reaching the handler.
    response = client.get(f"/v1/models/{MODEL_THAT_DOES_NOT_EXIST}")
    assert response.status_code == 404
|