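"""Lazy loading, reference counting, and TTL-based unloading of faster-whisper models."""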
from __future__ import annotations
from collections import OrderedDict
import gc
import logging
import threading
import time
from typing import TYPE_CHECKING
from faster_whisper import WhisperModel
if TYPE_CHECKING:
from collections.abc import Callable
from faster_whisper_server.config import (
WhisperConfig,
)
logger = logging.getLogger(__name__)
# TODO: enable concurrent model downloads
class SelfDisposingWhisperModel:
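    """A WhisperModel wrapper that loads lazily and unloads itself when idle.

    The underlying model is loaded on first use via the context-manager
    protocol and reference-counted; once the count drops to zero it is
    unloaded immediately, after `whisper_config.ttl` seconds, or never,
    depending on the configured TTL.
    """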
def __init__(
self,
model_id: str,
whisper_config: WhisperConfig,
*,
on_unload: Callable[[str], None] | None = None,
) -> None:
self.model_id = model_id
self.whisper_config = whisper_config
self.on_unload = on_unload
self.ref_count: int = 0
self.rlock = threading.RLock()
self.expire_timer: threading.Timer | None = None
self.whisper: WhisperModel | None = None
def unload(self) -> None:
with self.rlock:
if self.whisper is None:
raise ValueError(f"Model {self.model_id} is not loaded. {self.ref_count=}")
if self.ref_count > 0:
raise ValueError(f"Model {self.model_id} is still in use. {self.ref_count=}")
            if self.expire_timer:
                self.expire_timer.cancel()
                self.expire_timer = None  # drop the stale timer reference
self.whisper = None
# WARN: ~300 MB of memory will still be held by the model. See https://github.com/SYSTRAN/faster-whisper/issues/992
gc.collect()
logger.info(f"Model {self.model_id} unloaded")
if self.on_unload is not None:
self.on_unload(self.model_id)
def _load(self) -> None:
with self.rlock:
assert self.whisper is None
logger.debug(f"Loading model {self.model_id}")
start = time.perf_counter()
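            # Instantiating WhisperModel pulls the weights from the Hugging Face
            # Hub on first use (see the TODO above about concurrent downloads),
            # then loads them onto the configured device.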
self.whisper = WhisperModel(
self.model_id,
device=self.whisper_config.inference_device,
device_index=self.whisper_config.device_index,
compute_type=self.whisper_config.compute_type,
cpu_threads=self.whisper_config.cpu_threads,
num_workers=self.whisper_config.num_workers,
)
logger.info(f"Model {self.model_id} loaded in {time.perf_counter() - start:.2f}s")
def _increment_ref(self) -> None:
with self.rlock:
self.ref_count += 1
if self.expire_timer:
logger.debug(f"Model was set to expire in {self.expire_timer.interval}s, cancelling")
self.expire_timer.cancel()
logger.debug(f"Incremented ref count for {self.model_id}, {self.ref_count=}")
def _decrement_ref(self) -> None:
with self.rlock:
self.ref_count -= 1
logger.debug(f"Decremented ref count for {self.model_id}, {self.ref_count=}")
if self.ref_count <= 0:
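                # TTL semantics: > 0 schedules an unload after `ttl` idle seconds,
                # == 0 unloads immediately, < 0 keeps the model loaded forever.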
if self.whisper_config.ttl > 0:
logger.info(f"Model {self.model_id} is idle, scheduling offload in {self.whisper_config.ttl}s")
self.expire_timer = threading.Timer(self.whisper_config.ttl, self.unload)
self.expire_timer.start()
elif self.whisper_config.ttl == 0:
logger.info(f"Model {self.model_id} is idle, unloading immediately")
self.unload()
else:
logger.info(f"Model {self.model_id} is idle, not unloading")
def __enter__(self) -> WhisperModel:
with self.rlock:
if self.whisper is None:
self._load()
self._increment_ref()
assert self.whisper is not None
return self.whisper
def __exit__(self, *_args) -> None: # noqa: ANN002
self._decrement_ref()
class ModelManager:
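    """Create, cache, and evict `SelfDisposingWhisperModel` instances by model name."""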
def __init__(self, whisper_config: WhisperConfig) -> None:
self.whisper_config = whisper_config
self.loaded_models: OrderedDict[str, SelfDisposingWhisperModel] = OrderedDict()
self._lock = threading.Lock()
def _handle_model_unload(self, model_name: str) -> None:
with self._lock:
if model_name in self.loaded_models:
del self.loaded_models[model_name]
    def unload_model(self, model_name: str) -> None:
        with self._lock:
            model = self.loaded_models.get(model_name)
            if model is None:
                raise KeyError(f"Model {model_name} not found")
        # `unload` invokes the `on_unload` callback, which re-acquires `self._lock`
        # in `_handle_model_unload`. Call it outside the lock so the non-reentrant
        # `threading.Lock` doesn't deadlock.
        model.unload()
def load_model(self, model_name: str) -> SelfDisposingWhisperModel:
with self._lock:
if model_name in self.loaded_models:
logger.debug(f"{model_name} model already loaded")
return self.loaded_models[model_name]
self.loaded_models[model_name] = SelfDisposingWhisperModel(
model_name,
self.whisper_config,
on_unload=self._handle_model_unload,
)
return self.loaded_models[model_name]
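
# Usage sketch (hypothetical caller; `config` stands in for an
# application-supplied WhisperConfig, and the model ID is illustrative):
#
#   model_manager = ModelManager(config)
#   with model_manager.load_model("Systran/faster-whisper-small") as whisper:
#       segments, _info = whisper.transcribe("audio.wav")
#
# Leaving the `with` block decrements the reference count; with `ttl > 0` the
# model is unloaded after `ttl` idle seconds.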