Fedir Zadniprovskyi committed
Commit 3a14175 · Parent: e9aef91

feat: support model preloading (#66)

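In short: the commit adds a `preload_models` setting to `Config` and registers a FastAPI lifespan hook that loads each listed model before the server starts serving requests. A minimal usage sketch, assuming pydantic-settings' default behavior of matching environment variables to field names case-insensitively and JSON-decoding list-typed values (the exact env-var mapping depends on the project's SettingsConfigDict):

import os

# Hypothetical: ask for one model to be loaded at startup. pydantic-settings
# JSON-decodes env values for complex field types such as list[str].
os.environ["PRELOAD_MODELS"] = '["Systran/faster-whisper-medium.en"]'

from faster_whisper_server.config import Config

config = Config()
print(config.preload_models)  # ['Systran/faster-whisper-medium.en']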
faster_whisper_server/config.py CHANGED
@@ -1,6 +1,7 @@
 import enum
+from typing import Self

-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, model_validator
 from pydantic_settings import BaseSettings, SettingsConfigDict

 SAMPLES_PER_SECOND = 16000
@@ -151,7 +152,9 @@ class WhisperConfig(BaseModel):

     model: str = Field(default="Systran/faster-whisper-medium.en")
     """
-    Huggingface model to use for transcription. Note, the model must support being ran using CTranslate2.
+    Default Huggingface model to use for transcription. Note, the model must support being ran using CTranslate2.
+    This model will be used if no model is specified in the request.
+
     Models created by authors of `faster-whisper` can be found at https://huggingface.co/Systran
     You can find other supported models at https://huggingface.co/models?p=2&sort=trending&search=ctranslate2 and https://huggingface.co/models?sort=trending&search=ct2
     """
@@ -199,6 +202,16 @@ class Config(BaseSettings):
     """
     Maximum number of models that can be loaded at a time.
     """
+    preload_models: list[str] = Field(
+        default_factory=list,
+        examples=[
+            ["Systran/faster-whisper-medium.en"],
+            ["Systran/faster-whisper-medium.en", "Systran/faster-whisper-small.en"],
+        ],
+    )
+    """
+    List of models to preload on startup. Shouldn't be greater than `max_models`. By default, the model is first loaded on first request.
+    """  # noqa: E501
     max_no_data_seconds: float = 1.0
     """
     Max duration to wait for the next audio chunk before transcription is finilized and connection is closed.
@@ -218,5 +231,13 @@ class Config(BaseSettings):
     Should be greater than `max_inactivity_seconds`
     """

+    @model_validator(mode="after")
+    def ensure_preloaded_models_is_lte_max_models(self) -> Self:
+        if len(self.preload_models) > self.max_models:
+            raise ValueError(
+                f"Number of preloaded models ({len(self.preload_models)}) is greater than max_models ({self.max_models})"  # noqa: E501
+            )
+        return self
+

 config = Config()
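The `model_validator` above rejects configurations that ask to preload more models than the cache may hold. A hedged sketch of that behavior, assuming `Config` accepts direct keyword overrides (pydantic's BaseSettings treats init kwargs as the highest-priority source):

from pydantic import ValidationError

from faster_whisper_server.config import Config

# Within the limit: validation passes.
Config(max_models=2, preload_models=["Systran/faster-whisper-medium.en"])

# Over the limit: the validator's ValueError surfaces as a ValidationError.
try:
    Config(
        max_models=1,
        preload_models=[
            "Systran/faster-whisper-medium.en",
            "Systran/faster-whisper-small.en",
        ],
    )
except ValidationError as e:
    print(e)  # Number of preloaded models (2) is greater than max_models (1)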
faster_whisper_server/main.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations

 import asyncio
 from collections import OrderedDict
+from contextlib import asynccontextmanager
 from io import BytesIO
 import time
 from typing import TYPE_CHECKING, Annotated, Literal
@@ -45,7 +46,7 @@ from faster_whisper_server.server_models import (
 from faster_whisper_server.transcriber import audio_transcriber

 if TYPE_CHECKING:
-    from collections.abc import Generator, Iterable
+    from collections.abc import AsyncGenerator, Generator, Iterable

     from faster_whisper.transcribe import TranscriptionInfo
     from huggingface_hub.hf_api import ModelInfo
@@ -63,7 +64,7 @@ def load_model(model_name: str) -> WhisperModel:
         del loaded_models[oldest_model_name]
     logger.debug(f"Loading {model_name}...")
     start = time.perf_counter()
-    # NOTE: will raise an exception if the model name isn't valid
+    # NOTE: will raise an exception if the model name isn't valid. Should I do an explicit check?
     whisper = WhisperModel(
         model_name,
         device=config.whisper.inference_device,
@@ -81,7 +82,15 @@ def load_model(model_name: str) -> WhisperModel:

 logger.debug(f"Config: {config}")

-app = FastAPI()
+
+@asynccontextmanager
+async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
+    for model_name in config.preload_models:
+        load_model(model_name)
+    yield
+
+
+app = FastAPI(lifespan=lifespan)

 if config.allow_origins is not None:
     app.add_middleware(
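Because preloading happens in the lifespan hook, it runs once at startup rather than on the first request. A sketch of how that could be exercised, assuming `loaded_models` is the module-level cache in main.py and that fastapi's TestClient runs lifespan events when used as a context manager (the model name is illustrative; entering the context would actually download and load it):

from fastapi.testclient import TestClient

from faster_whisper_server.config import config
from faster_whisper_server.main import app, loaded_models

config.preload_models = ["Systran/faster-whisper-tiny.en"]  # hypothetical override

with TestClient(app):  # entering the context triggers lifespan startup
    assert "Systran/faster-whisper-tiny.en" in loaded_models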