Commit aada575 · Fedir Zadniprovskyi committed · 1 Parent(s): 6127072

feat: support loading multiple models

Files changed (2):
  1. speaches/config.py +20 -18
  2. speaches/main.py +46 -48
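In practice this change means the `model` form parameter on the transcription and translation endpoints is no longer ignored: each request can name its own model, and the server loads it on demand while keeping at most `max_models` models in memory. A minimal client sketch, assuming the server runs locally on port 8000, that `httpx` is installed, and that the audio file name and model identifier strings are purely illustrative:

import httpx

# Post the same file twice, naming a different model each time.
# With the default max_models=1, the second request evicts the first model.
with open("audio.wav", "rb") as f:
    audio = f.read()

for model_name in ["distil-medium.en", "tiny.en"]:
    response = httpx.post(
        "http://localhost:8000/v1/audio/transcriptions",
        files={"file": ("audio.wav", audio)},
        data={"model": model_name},
    )
    print(model_name, response.status_code)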
speaches/config.py CHANGED
@@ -163,39 +163,41 @@ class Language(enum.StrEnum):
 
 
 class WhisperConfig(BaseModel):
-    model: Model = Field(default=Model.DISTIL_MEDIUM_EN)  # ENV: WHISPER_MODEL
-    inference_device: Device = Field(
-        default=Device.AUTO
-    )  # ENV: WHISPER_INFERENCE_DEVICE
-    compute_type: Quantization = Field(
-        default=Quantization.DEFAULT
-    )  # ENV: WHISPER_COMPUTE_TYPE
+    model: Model = Field(default=Model.DISTIL_MEDIUM_EN)
+    inference_device: Device = Field(default=Device.AUTO)
+    compute_type: Quantization = Field(default=Quantization.DEFAULT)
 
 
 class Config(BaseSettings):
+    """
+    Configuration for the application. Values can be set via environment variables.
+    Pydantic will automatically handle mapping uppercased environment variables to the corresponding fields.
+    To populate nested, the environment should be prefixed with the nested field name and an underscore. For example,
+    the environment variable `LOG_LEVEL` will be mapped to `log_level`, `WHISPER_MODEL` to `whisper.model`, etc.
+    """
+
     model_config = SettingsConfigDict(env_nested_delimiter="_")
 
-    log_level: str = "info"  # ENV: LOG_LEVEL
-    default_language: Language | None = None  # ENV: DEFAULT_LANGUAGE
-    default_response_format: ResponseFormat = (
-        ResponseFormat.JSON
-    )  # ENV: DEFAULT_RESPONSE_FORMAT
-    whisper: WhisperConfig = WhisperConfig()  # ENV: WHISPER_*
+    log_level: str = "info"
+    default_language: Language | None = None
+    default_response_format: ResponseFormat = ResponseFormat.JSON
+    whisper: WhisperConfig = WhisperConfig()
+    max_models: int = 1
     """
     Max duration to for the next audio chunk before transcription is finilized and connection is closed.
     """
-    max_no_data_seconds: float = 1.0  # ENV: MAX_NO_DATA_SECONDS
-    min_duration: float = 1.0  # ENV: MIN_DURATION
-    word_timestamp_error_margin: float = 0.2  # ENV: WORD_TIMESTAMP_ERROR_MARGIN
+    max_no_data_seconds: float = 1.0
+    min_duration: float = 1.0
+    word_timestamp_error_margin: float = 0.2
     """
     Max allowed audio duration without any speech being detected before transcription is finilized and connection is closed.
     """
-    max_inactivity_seconds: float = 2.0  # ENV: MAX_INACTIVITY_SECONDS
+    max_inactivity_seconds: float = 2.0
     """
     Controls how many latest seconds of audio are being passed through VAD.
     Should be greater than `max_inactivity_seconds`
     """
-    inactivity_window_seconds: float = 3.0  # ENV: INACTIVITY_WINDOW_SECONDS
+    inactivity_window_seconds: float = 3.0
 
 
 config = Config()
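The new class docstring describes how values reach the `Config` object: with `env_nested_delimiter="_"`, pydantic-settings builds the environment variable name for a nested field by joining the field names with an underscore, so `WHISPER_MODEL` populates `whisper.model`, while top-level fields such as the new `max_models` read `MAX_MODELS` directly. A self-contained sketch of that mapping, with the enum-typed fields simplified to plain strings for brevity:

import os

from pydantic import BaseModel
from pydantic_settings import BaseSettings, SettingsConfigDict


class WhisperConfig(BaseModel):
    model: str = "distil-medium.en"


class Config(BaseSettings):
    model_config = SettingsConfigDict(env_nested_delimiter="_")

    max_models: int = 1
    whisper: WhisperConfig = WhisperConfig()


# Environment variable names are case-insensitive by default.
os.environ["MAX_MODELS"] = "2"
os.environ["WHISPER_MODEL"] = "tiny.en"

config = Config()
print(config.max_models)     # 2
print(config.whisper.model)  # tiny.en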
speaches/main.py CHANGED
@@ -1,11 +1,10 @@
 from __future__ import annotations
 
 import asyncio
-import logging
 import time
 from contextlib import asynccontextmanager
 from io import BytesIO
-from typing import Annotated, Literal
+from typing import Annotated, Literal, OrderedDict
 
 from fastapi import (FastAPI, Form, Query, Response, UploadFile, WebSocket,
                      WebSocketDisconnect)
@@ -19,29 +18,45 @@ from speaches.asr import FasterWhisperASR
 from speaches.audio import AudioStream, audio_samples_from_file
 from speaches.config import (SAMPLES_PER_SECOND, Language, Model,
                              ResponseFormat, config)
-from speaches.core import Transcription
 from speaches.logger import logger
 from speaches.server_models import (TranscriptionJsonResponse,
                                     TranscriptionVerboseJsonResponse)
 from speaches.transcriber import audio_transcriber
 
-whisper: WhisperModel = None  # type: ignore
+models: OrderedDict[Model, WhisperModel] = OrderedDict()
 
 
-@asynccontextmanager
-async def lifespan(_: FastAPI):
-    global whisper
-    logging.debug(f"Loading {config.whisper.model}")
+def load_model(model_name: Model) -> WhisperModel:
+    if model_name in models:
+        logger.debug(f"{model_name} model already loaded")
+        return models[model_name]
+    if len(models) >= config.max_models:
+        oldest_model_name = next(iter(models))
+        logger.info(
+            f"Max models ({config.max_models}) reached. Unloading the oldest model: {oldest_model_name}"
+        )
+        del models[oldest_model_name]
+    logger.debug(f"Loading {model_name}")
     start = time.perf_counter()
     whisper = WhisperModel(
-        config.whisper.model,
+        model_name,
        device=config.whisper.inference_device,
        compute_type=config.whisper.compute_type,
    )
-    logger.debug(
-        f"Loaded {config.whisper.model} loaded in {time.perf_counter() - start:.2f} seconds"
+    logger.info(
+        f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds"
    )
+    models[model_name] = whisper
+    return whisper
+
+
+@asynccontextmanager
+async def lifespan(_: FastAPI):
+    load_model(config.whisper.model)
     yield
+    for model in models.keys():
+        logger.info(f"Unloading {model}")
+        del models[model]
 
 
 app = FastAPI(lifespan=lifespan)
@@ -53,7 +68,7 @@ def health() -> Response:
 
 
 @app.post("/v1/audio/translations")
-async def translate_file(
+def translate_file(
     file: Annotated[UploadFile, Form()],
     model: Annotated[Model, Form()] = config.whisper.model,
     prompt: Annotated[str | None, Form()] = None,
@@ -61,11 +76,8 @@ async def translate_file(
     temperature: Annotated[float, Form()] = 0.0,
     stream: Annotated[bool, Form()] = False,
 ):
-    if model != config.whisper.model:
-        logger.warning(
-            f"Specifying a model that is different from the default is not supported yet. Using {config.whisper.model}."
-        )
     start = time.perf_counter()
+    whisper = load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
         task="translate",
@@ -107,7 +119,7 @@ async def translate_file(
 # https://platform.openai.com/docs/api-reference/audio/createTranscription
 # https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L8915
 @app.post("/v1/audio/transcriptions")
-async def transcribe_file(
+def transcribe_file(
     file: Annotated[UploadFile, Form()],
     model: Annotated[Model, Form()] = config.whisper.model,
     language: Annotated[Language | None, Form()] = config.default_language,
@@ -120,11 +132,8 @@ async def transcribe_file(
     ] = ["segments"],
     stream: Annotated[bool, Form()] = False,
 ):
-    if model != config.whisper.model:
-        logger.warning(
-            f"Specifying a model that is different from the default is not supported yet. Using {config.whisper.model}."
-        )
     start = time.perf_counter()
+    whisper = load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
         task="transcribe",
@@ -209,21 +218,6 @@ async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
     audio_stream.close()
 
 
-def format_transcription(
-    transcription: Transcription, response_format: ResponseFormat
-) -> str:
-    if response_format == ResponseFormat.TEXT:
-        return transcription.text
-    elif response_format == ResponseFormat.JSON:
-        return TranscriptionJsonResponse.from_transcription(
-            transcription
-        ).model_dump_json()
-    elif response_format == ResponseFormat.VERBOSE_JSON:
-        return TranscriptionVerboseJsonResponse.from_transcription(
-            transcription
-        ).model_dump_json()
-
-
 @app.websocket("/v1/audio/transcriptions")
 async def transcribe_stream(
     ws: WebSocket,
@@ -234,18 +228,7 @@ async def transcribe_stream(
         ResponseFormat, Query()
     ] = config.default_response_format,
     temperature: Annotated[float, Query()] = 0.0,
-    timestamp_granularities: Annotated[
-        list[Literal["segments"] | Literal["words"]],
-        Query(
-            alias="timestamp_granularities[]",
-            description="No-op. Ignored. Only for compatibility.",
-        ),
-    ] = ["segments", "words"],
 ) -> None:
-    if model != config.whisper.model:
-        logger.warning(
-            f"Specifying a model that is different from the default is not supported yet. Using {config.whisper.model}."
-        )
     await ws.accept()
     transcribe_opts = {
         "language": language,
@@ -254,6 +237,7 @@
         "vad_filter": True,
         "condition_on_previous_text": False,
     }
+    whisper = load_model(model)
     asr = FasterWhisperASR(whisper, **transcribe_opts)
     audio_stream = AudioStream()
     async with asyncio.TaskGroup() as tg:
@@ -262,7 +246,21 @@
             logger.debug(f"Sending transcription: {transcription.text}")
             if ws.client_state == WebSocketState.DISCONNECTED:
                 break
-            await ws.send_text(format_transcription(transcription, response_format))
+
+            if response_format == ResponseFormat.TEXT:
+                await ws.send_text(transcription.text)
+            elif response_format == ResponseFormat.JSON:
+                await ws.send_json(
+                    TranscriptionJsonResponse.from_transcription(
+                        transcription
+                    ).model_dump()
+                )
+            elif response_format == ResponseFormat.VERBOSE_JSON:
+                await ws.send_json(
+                    TranscriptionVerboseJsonResponse.from_transcription(
+                        transcription
+                    ).model_dump()
+                )
 
     if not ws.client_state == WebSocketState.DISCONNECTED:
         logger.info("Closing the connection.")
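The new `load_model` helper is effectively a bounded cache keyed by model name: a hit returns the already-loaded `WhisperModel`, and a miss first evicts the oldest entry once `config.max_models` is reached. Note that eviction follows insertion order rather than recency of use (a cache hit does not reorder the `OrderedDict`), so with `max_models > 1` the model loaded earliest is dropped first even if it was used most recently. A stripped-down sketch of that behavior, with the actual `WhisperModel` construction replaced by a placeholder object:

from collections import OrderedDict

MAX_MODELS = 1  # mirrors the config.max_models default

models: OrderedDict[str, object] = OrderedDict()


def load_model(model_name: str) -> object:
    """Return a cached model, evicting the oldest entry once the cache is full."""
    if model_name in models:
        return models[model_name]
    if len(models) >= MAX_MODELS:
        oldest_model_name = next(iter(models))
        del models[oldest_model_name]
    models[model_name] = object()  # stand-in for WhisperModel(model_name, ...)
    return models[model_name]


load_model("distil-medium.en")
load_model("tiny.en")   # evicts "distil-medium.en" because MAX_MODELS == 1
print(list(models))     # ['tiny.en']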