Commit 624f97e · Fedir Zadniprovskyi committed
1 Parent(s): 1500d25
refactor: add `ModelManager`
src/faster_whisper_server/main.py
CHANGED
@@ -1,11 +1,9 @@
 from __future__ import annotations
 
 import asyncio
-from collections import OrderedDict
 from contextlib import asynccontextmanager
 import gc
 from io import BytesIO
-import time
 from typing import TYPE_CHECKING, Annotated, Literal
 
 from fastapi import (
@@ -22,7 +20,6 @@ from fastapi import (
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 from fastapi.websockets import WebSocketState
-from faster_whisper import WhisperModel
 from faster_whisper.vad import VadOptions, get_speech_timestamps
 import huggingface_hub
 from huggingface_hub.hf_api import RepositoryNotFoundError
@@ -40,6 +37,7 @@ from faster_whisper_server.config import (
 )
 from faster_whisper_server.core import Segment, segments_to_srt, segments_to_text, segments_to_vtt
 from faster_whisper_server.logger import logger
+from faster_whisper_server.model_manager import ModelManager
 from faster_whisper_server.server_models import (
     ModelListResponse,
     ModelObject,
@@ -54,42 +52,16 @@ if TYPE_CHECKING:
     from faster_whisper.transcribe import TranscriptionInfo
     from huggingface_hub.hf_api import ModelInfo
 
-loaded_models: OrderedDict[str, WhisperModel] = OrderedDict()
-
-
-def load_model(model_name: str) -> WhisperModel:
-    if model_name in loaded_models:
-        logger.debug(f"{model_name} model already loaded")
-        return loaded_models[model_name]
-    if len(loaded_models) >= config.max_models:
-        oldest_model_name = next(iter(loaded_models))
-        logger.info(f"Max models ({config.max_models}) reached. Unloading the oldest model: {oldest_model_name}")
-        del loaded_models[oldest_model_name]
-    logger.debug(f"Loading {model_name}...")
-    start = time.perf_counter()
-    # NOTE: will raise an exception if the model name isn't valid. Should I do an explicit check?
-    whisper = WhisperModel(
-        model_name,
-        device=config.whisper.inference_device,
-        device_index=config.whisper.device_index,
-        compute_type=config.whisper.compute_type,
-        cpu_threads=config.whisper.cpu_threads,
-        num_workers=config.whisper.num_workers,
-    )
-    logger.info(
-        f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {config.whisper.inference_device}({config.whisper.compute_type}) will be used for inference."  # noqa: E501
-    )
-    loaded_models[model_name] = whisper
-    return whisper
-
 
 logger.debug(f"Config: {config}")
 
+model_manager = ModelManager()
+
 
 @asynccontextmanager
 async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
     for model_name in config.preload_models:
-        load_model(model_name)
+        model_manager.load_model(model_name)
     yield
 
 
@@ -123,22 +95,22 @@ def pull_model(model_name: str) -> Response:
 
 @app.get("/api/ps", tags=["experimental"], summary="Get a list of loaded models.")
 def get_running_models() -> dict[str, list[str]]:
-    return {"models": list(loaded_models.keys())}
+    return {"models": list(model_manager.loaded_models.keys())}
 
 
 @app.post("/api/ps/{model_name:path}", tags=["experimental"], summary="Load a model into memory.")
 def load_model_route(model_name: str) -> Response:
-    if model_name in loaded_models:
+    if model_name in model_manager.loaded_models:
         return Response(status_code=409, content="Model already loaded")
-    load_model(model_name)
+    model_manager.load_model(model_name)
     return Response(status_code=201)
 
 
 @app.delete("/api/ps/{model_name:path}", tags=["experimental"], summary="Unload a model from memory.")
 def stop_running_model(model_name: str) -> Response:
-    model = loaded_models.get(model_name)
+    model = model_manager.loaded_models.get(model_name)
     if model is not None:
-        del loaded_models[model_name]
+        del model_manager.loaded_models[model_name]
         gc.collect()
         return Response(status_code=204)
     return Response(status_code=404)
@@ -291,7 +263,7 @@ def translate_file(
     temperature: Annotated[float, Form()] = 0.0,
     stream: Annotated[bool, Form()] = False,
 ) -> Response | StreamingResponse:
-    whisper = load_model(model)
+    whisper = model_manager.load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
         task=Task.TRANSLATE,
@@ -327,7 +299,7 @@ def transcribe_file(
     stream: Annotated[bool, Form()] = False,
     hotwords: Annotated[str | None, Form()] = None,
 ) -> Response | StreamingResponse:
-    whisper = load_model(model)
+    whisper = model_manager.load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
         task=Task.TRANSCRIBE,
@@ -391,7 +363,7 @@ async def transcribe_stream(
         "vad_filter": True,
         "condition_on_previous_text": False,
     }
-    whisper = load_model(model)
+    whisper = model_manager.load_model(model)
     asr = FasterWhisperASR(whisper, **transcribe_opts)
     audio_stream = AudioStream()
     async with asyncio.TaskGroup() as tg:
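Behaviourally the experimental /api/ps routes are unchanged: same paths and same status codes (201/409 on load, 204/404 on unload); only the bookkeeping now lives on the module-level `model_manager`. A minimal sketch of exercising them over HTTP follows; the base URL, port, HTTP client, and model name are assumptions for illustration, not part of this commit.

    import httpx  # any HTTP client works; httpx is just an example

    BASE_URL = "http://localhost:8000"  # assumed local deployment of the server

    with httpx.Client(base_url=BASE_URL) as client:
        # Load a model into memory: 201 when loaded, 409 if it is already in model_manager.loaded_models.
        print(client.post("/api/ps/Systran/faster-whisper-tiny.en").status_code)

        # List the models the ModelManager currently holds.
        print(client.get("/api/ps").json())  # e.g. {"models": ["Systran/faster-whisper-tiny.en"]}

        # Unload it again: 204 when removed, 404 if it was not loaded.
        print(client.delete("/api/ps/Systran/faster-whisper-tiny.en").status_code)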
src/faster_whisper_server/model_manager.py
ADDED
@@ -0,0 +1,43 @@
+from __future__ import annotations
+
+from collections import OrderedDict
+import gc
+import time
+
+from faster_whisper import WhisperModel
+
+from faster_whisper_server.config import (
+    config,
+)
+from faster_whisper_server.logger import logger
+
+
+class ModelManager:
+    def __init__(self) -> None:
+        self.loaded_models: OrderedDict[str, WhisperModel] = OrderedDict()
+
+    def load_model(self, model_name: str) -> WhisperModel:
+        if model_name in self.loaded_models:
+            logger.debug(f"{model_name} model already loaded")
+            return self.loaded_models[model_name]
+        if len(self.loaded_models) >= config.max_models:
+            oldest_model_name = next(iter(self.loaded_models))
+            logger.info(f"Max models ({config.max_models}) reached. Unloading the oldest model: {oldest_model_name}")
+            del self.loaded_models[oldest_model_name]
+            gc.collect()
+        logger.debug(f"Loading {model_name}...")
+        start = time.perf_counter()
+        # NOTE: will raise an exception if the model name isn't valid. Should I do an explicit check?
+        whisper = WhisperModel(
+            model_name,
+            device=config.whisper.inference_device,
+            device_index=config.whisper.device_index,
+            compute_type=config.whisper.compute_type,
+            cpu_threads=config.whisper.cpu_threads,
+            num_workers=config.whisper.num_workers,
+        )
+        logger.info(
+            f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {config.whisper.inference_device}({config.whisper.compute_type}) will be used for inference."  # noqa: E501
+        )
+        self.loaded_models[model_name] = whisper
+        return whisper
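Viewed on its own, `ModelManager.load_model` is an insertion-ordered cache: a repeated name is returned from `loaded_models` without reloading, and once `config.max_models` entries exist the oldest one is deleted (and garbage-collected) before a new `WhisperModel` is constructed. Note the eviction is by load order, not by last use, since a cache hit does not reorder the dict. A rough standalone sketch of that policy, using a hypothetical `TinyModelManager` with plain strings standing in for real Whisper models:

    from collections import OrderedDict


    class TinyModelManager:
        # Illustration only: same eviction policy as ModelManager, trivial "models".
        def __init__(self, max_models: int) -> None:
            self.max_models = max_models
            self.loaded_models: OrderedDict[str, str] = OrderedDict()

        def load_model(self, model_name: str) -> str:
            if model_name in self.loaded_models:
                return self.loaded_models[model_name]  # cache hit: no reload, no reordering
            if len(self.loaded_models) >= self.max_models:
                oldest = next(iter(self.loaded_models))  # first inserted == oldest
                del self.loaded_models[oldest]
            self.loaded_models[model_name] = f"model<{model_name}>"  # stand-in for WhisperModel(...)
            return self.loaded_models[model_name]


    manager = TinyModelManager(max_models=1)
    manager.load_model("tiny.en")
    manager.load_model("base.en")  # second load evicts "tiny.en" first
    assert list(manager.loaded_models) == ["base.en"]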