Fedir Zadniprovskyi committed
Commit 624f97e · 1 Parent(s): 1500d25

refactor: add `ModelManager`

src/faster_whisper_server/main.py CHANGED

@@ -1,11 +1,9 @@
 from __future__ import annotations
 
 import asyncio
-from collections import OrderedDict
 from contextlib import asynccontextmanager
 import gc
 from io import BytesIO
-import time
 from typing import TYPE_CHECKING, Annotated, Literal
 
 from fastapi import (
@@ -22,7 +20,6 @@ from fastapi import (
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 from fastapi.websockets import WebSocketState
-from faster_whisper import WhisperModel
 from faster_whisper.vad import VadOptions, get_speech_timestamps
 import huggingface_hub
 from huggingface_hub.hf_api import RepositoryNotFoundError
@@ -40,6 +37,7 @@ from faster_whisper_server.config import (
 )
 from faster_whisper_server.core import Segment, segments_to_srt, segments_to_text, segments_to_vtt
 from faster_whisper_server.logger import logger
+from faster_whisper_server.model_manager import ModelManager
 from faster_whisper_server.server_models import (
     ModelListResponse,
     ModelObject,
@@ -54,42 +52,16 @@ if TYPE_CHECKING:
     from faster_whisper.transcribe import TranscriptionInfo
     from huggingface_hub.hf_api import ModelInfo
 
-loaded_models: OrderedDict[str, WhisperModel] = OrderedDict()
-
-
-def load_model(model_name: str) -> WhisperModel:
-    if model_name in loaded_models:
-        logger.debug(f"{model_name} model already loaded")
-        return loaded_models[model_name]
-    if len(loaded_models) >= config.max_models:
-        oldest_model_name = next(iter(loaded_models))
-        logger.info(f"Max models ({config.max_models}) reached. Unloading the oldest model: {oldest_model_name}")
-        del loaded_models[oldest_model_name]
-    logger.debug(f"Loading {model_name}...")
-    start = time.perf_counter()
-    # NOTE: will raise an exception if the model name isn't valid. Should I do an explicit check?
-    whisper = WhisperModel(
-        model_name,
-        device=config.whisper.inference_device,
-        device_index=config.whisper.device_index,
-        compute_type=config.whisper.compute_type,
-        cpu_threads=config.whisper.cpu_threads,
-        num_workers=config.whisper.num_workers,
-    )
-    logger.info(
-        f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {config.whisper.inference_device}({config.whisper.compute_type}) will be used for inference."  # noqa: E501
-    )
-    loaded_models[model_name] = whisper
-    return whisper
-
 
 logger.debug(f"Config: {config}")
 
+model_manager = ModelManager()
+
 
 @asynccontextmanager
 async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
     for model_name in config.preload_models:
-        load_model(model_name)
+        model_manager.load_model(model_name)
     yield
 
 
@@ -123,22 +95,22 @@ def pull_model(model_name: str) -> Response:
 
 @app.get("/api/ps", tags=["experimental"], summary="Get a list of loaded models.")
 def get_running_models() -> dict[str, list[str]]:
-    return {"models": list(loaded_models.keys())}
+    return {"models": list(model_manager.loaded_models.keys())}
 
 
 @app.post("/api/ps/{model_name:path}", tags=["experimental"], summary="Load a model into memory.")
 def load_model_route(model_name: str) -> Response:
-    if model_name in loaded_models:
+    if model_name in model_manager.loaded_models:
         return Response(status_code=409, content="Model already loaded")
-    load_model(model_name)
+    model_manager.load_model(model_name)
     return Response(status_code=201)
 
 
 @app.delete("/api/ps/{model_name:path}", tags=["experimental"], summary="Unload a model from memory.")
 def stop_running_model(model_name: str) -> Response:
-    model = loaded_models.get(model_name)
+    model = model_manager.loaded_models.get(model_name)
     if model is not None:
-        del loaded_models[model_name]
+        del model_manager.loaded_models[model_name]
         gc.collect()
         return Response(status_code=204)
     return Response(status_code=404)
@@ -291,7 +263,7 @@ def translate_file(
     temperature: Annotated[float, Form()] = 0.0,
     stream: Annotated[bool, Form()] = False,
 ) -> Response | StreamingResponse:
-    whisper = load_model(model)
+    whisper = model_manager.load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
         task=Task.TRANSLATE,
@@ -327,7 +299,7 @@ def transcribe_file(
     stream: Annotated[bool, Form()] = False,
     hotwords: Annotated[str | None, Form()] = None,
 ) -> Response | StreamingResponse:
-    whisper = load_model(model)
+    whisper = model_manager.load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
         task=Task.TRANSCRIBE,
@@ -391,7 +363,7 @@ async def transcribe_stream(
         "vad_filter": True,
         "condition_on_previous_text": False,
     }
-    whisper = load_model(model)
+    whisper = model_manager.load_model(model)
     asr = FasterWhisperASR(whisper, **transcribe_opts)
     audio_stream = AudioStream()
     async with asyncio.TaskGroup() as tg:
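
For context, the experimental model-management routes touched above can be exercised over HTTP once the server is running. A minimal sketch, assuming the server listens at http://localhost:8000 and that "Systran/faster-whisper-tiny.en" is a model name it accepts (both are assumptions, not part of this commit):

import httpx

BASE_URL = "http://localhost:8000"  # assumed local address; adjust to your deployment
MODEL = "Systran/faster-whisper-tiny.en"  # assumed valid model name

# POST /api/ps/{model_name}: load a model into memory (201 on success, 409 if already loaded).
print(httpx.post(f"{BASE_URL}/api/ps/{MODEL}").status_code)

# GET /api/ps: list currently loaded models, e.g. {"models": ["Systran/faster-whisper-tiny.en"]}.
print(httpx.get(f"{BASE_URL}/api/ps").json())

# DELETE /api/ps/{model_name}: unload the model (204 on success, 404 if it was not loaded).
print(httpx.delete(f"{BASE_URL}/api/ps/{MODEL}").status_code)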
src/faster_whisper_server/model_manager.py ADDED

@@ -0,0 +1,43 @@
+from __future__ import annotations
+
+from collections import OrderedDict
+import gc
+import time
+
+from faster_whisper import WhisperModel
+
+from faster_whisper_server.config import (
+    config,
+)
+from faster_whisper_server.logger import logger
+
+
+class ModelManager:
+    def __init__(self) -> None:
+        self.loaded_models: OrderedDict[str, WhisperModel] = OrderedDict()
+
+    def load_model(self, model_name: str) -> WhisperModel:
+        if model_name in self.loaded_models:
+            logger.debug(f"{model_name} model already loaded")
+            return self.loaded_models[model_name]
+        if len(self.loaded_models) >= config.max_models:
+            oldest_model_name = next(iter(self.loaded_models))
+            logger.info(f"Max models ({config.max_models}) reached. Unloading the oldest model: {oldest_model_name}")
+            del self.loaded_models[oldest_model_name]
+            gc.collect()
+        logger.debug(f"Loading {model_name}...")
+        start = time.perf_counter()
+        # NOTE: will raise an exception if the model name isn't valid. Should I do an explicit check?
+        whisper = WhisperModel(
+            model_name,
+            device=config.whisper.inference_device,
+            device_index=config.whisper.device_index,
+            compute_type=config.whisper.compute_type,
+            cpu_threads=config.whisper.cpu_threads,
+            num_workers=config.whisper.num_workers,
+        )
+        logger.info(
+            f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {config.whisper.inference_device}({config.whisper.compute_type}) will be used for inference."  # noqa: E501
+        )
+        self.loaded_models[model_name] = whisper
+        return whisper
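
A minimal usage sketch of the new ModelManager, not part of the commit, assuming config.max_models is 1 and that the named Systran models can be fetched and loaded; it illustrates the caching and oldest-first eviction implemented in load_model above:

from faster_whisper_server.model_manager import ModelManager

manager = ModelManager()

# First call downloads/loads the model and caches it under its name.
tiny = manager.load_model("Systran/faster-whisper-tiny.en")
assert "Systran/faster-whisper-tiny.en" in manager.loaded_models

# Asking for the same model again returns the cached WhisperModel instance.
assert manager.load_model("Systran/faster-whisper-tiny.en") is tiny

# With max_models == 1 (assumption), loading a second model evicts the oldest entry first.
base = manager.load_model("Systran/faster-whisper-base.en")
assert list(manager.loaded_models) == ["Systran/faster-whisper-base.en"]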