from __future__ import annotations
from collections import OrderedDict
import gc
import logging
import threading
import time
from typing import TYPE_CHECKING
from faster_whisper import WhisperModel
if TYPE_CHECKING:
from collections.abc import Callable
from faster_whisper_server.config import (
WhisperConfig,
)
logger = logging.getLogger(__name__)
# TODO: enable concurrent model downloads
class SelfDisposingWhisperModel:
    """Lazily loads a `WhisperModel` and unloads it when idle.

    Active users are tracked with a reference count via the context-manager
    protocol (`with model as whisper: ...`). When the count drops to zero the
    model is unloaded immediately (`ttl == 0`), scheduled for unloading after
    `ttl` seconds (`ttl > 0`), or kept resident forever (`ttl < 0`).
    """

    def __init__(
        self,
        model_id: str,
        whisper_config: WhisperConfig,
        *,
        on_unload: Callable[[str], None] | None = None,
    ) -> None:
        self.model_id = model_id
        self.whisper_config = whisper_config
        # Optional callback invoked (with `model_id`) after a successful unload.
        self.on_unload = on_unload
        # Number of callers currently inside the `with` block.
        self.ref_count: int = 0
        # RLock: `unload` / `_load` may be entered from methods already holding it.
        self.rlock = threading.RLock()
        self.expire_timer: threading.Timer | None = None
        self.whisper: WhisperModel | None = None

    def unload(self) -> None:
        """Release the loaded model.

        Raises:
            ValueError: if the model is not loaded or still has active users.
        """
        with self.rlock:
            if self.whisper is None:
                raise ValueError(f"Model {self.model_id} is not loaded. {self.ref_count=}")
            if self.ref_count > 0:
                raise ValueError(f"Model {self.model_id} is still in use. {self.ref_count=}")
            if self.expire_timer:
                self.expire_timer.cancel()
                self.expire_timer = None  # drop the handle so a stale timer is never re-cancelled
            self.whisper = None
            # WARN: ~300 MB of memory will still be held by the model. See https://github.com/SYSTRAN/faster-whisper/issues/992
            gc.collect()
            logger.info(f"Model {self.model_id} unloaded")
            if self.on_unload is not None:
                self.on_unload(self.model_id)

    def _unload_if_idle(self) -> None:
        """Timer callback: unload only if still loaded and unused.

        BUG FIX: `threading.Timer.cancel()` cannot stop a callback that has
        already started running. Previously the timer invoked `unload`
        directly, so a caller entering the model just as the TTL expired could
        trigger an uncaught ValueError in the timer thread. Re-checking state
        under the lock makes a late-firing timer benign.
        """
        with self.rlock:
            if self.whisper is not None and self.ref_count <= 0:
                self.unload()

    def _load(self) -> None:
        """Instantiate the underlying `WhisperModel` from `whisper_config`."""
        with self.rlock:
            assert self.whisper is None
            logger.debug(f"Loading model {self.model_id}")
            start = time.perf_counter()
            self.whisper = WhisperModel(
                self.model_id,
                device=self.whisper_config.inference_device,
                device_index=self.whisper_config.device_index,
                compute_type=self.whisper_config.compute_type,
                cpu_threads=self.whisper_config.cpu_threads,
                num_workers=self.whisper_config.num_workers,
            )
            logger.info(f"Model {self.model_id} loaded in {time.perf_counter() - start:.2f}s")

    def _increment_ref(self) -> None:
        """Register a new user; cancel any pending expiry timer."""
        with self.rlock:
            self.ref_count += 1
            if self.expire_timer:
                logger.debug(f"Model was set to expire in {self.expire_timer.interval}s, cancelling")
                self.expire_timer.cancel()
                self.expire_timer = None
            logger.debug(f"Incremented ref count for {self.model_id}, {self.ref_count=}")

    def _decrement_ref(self) -> None:
        """Release one user; when idle, apply the TTL policy (see class docstring)."""
        with self.rlock:
            self.ref_count -= 1
            logger.debug(f"Decremented ref count for {self.model_id}, {self.ref_count=}")
            if self.ref_count <= 0:
                if self.whisper_config.ttl > 0:
                    logger.info(f"Model {self.model_id} is idle, scheduling offload in {self.whisper_config.ttl}s")
                    # Use the guarded callback, not `unload`, so a timer that
                    # fires concurrently with `_increment_ref` cannot raise.
                    self.expire_timer = threading.Timer(self.whisper_config.ttl, self._unload_if_idle)
                    self.expire_timer.start()
                elif self.whisper_config.ttl == 0:
                    logger.info(f"Model {self.model_id} is idle, unloading immediately")
                    self.unload()
                else:
                    logger.info(f"Model {self.model_id} is idle, not unloading")

    def __enter__(self) -> WhisperModel:
        """Load the model if needed and return it, holding a reference."""
        with self.rlock:
            if self.whisper is None:
                self._load()
            self._increment_ref()
            assert self.whisper is not None
            return self.whisper

    def __exit__(self, *_args) -> None:  # noqa: ANN002
        self._decrement_ref()
class ModelManager:
    """Keeps at most one `SelfDisposingWhisperModel` per model name.

    Entries evict themselves via the `on_unload` callback when the underlying
    model unloads (e.g. TTL expiry), so `loaded_models` only holds live wrappers.
    """

    def __init__(self, whisper_config: WhisperConfig) -> None:
        self.whisper_config = whisper_config
        self.loaded_models: OrderedDict[str, SelfDisposingWhisperModel] = OrderedDict()
        self._lock = threading.Lock()

    def _handle_model_unload(self, model_name: str) -> None:
        # Callback passed to each wrapper; drops the entry once its model unloads.
        with self._lock:
            if model_name in self.loaded_models:
                del self.loaded_models[model_name]

    def unload_model(self, model_name: str) -> None:
        """Forcefully unload and evict `model_name`.

        Raises:
            KeyError: if the model is not currently managed.
        """
        with self._lock:
            model = self.loaded_models.get(model_name)
            if model is None:
                raise KeyError(f"Model {model_name} not found")
            # BUG FIX: evict here and call `unload()` AFTER releasing the lock.
            # `unload()` fires `_handle_model_unload`, which re-acquires
            # `self._lock`; since `threading.Lock` is not reentrant, calling it
            # while holding the lock deadlocked every `unload_model` call.
            del self.loaded_models[model_name]
        model.unload()

    def load_model(self, model_name: str) -> SelfDisposingWhisperModel:
        """Return the managed wrapper for `model_name`, creating it if absent."""
        with self._lock:
            if model_name in self.loaded_models:
                logger.debug(f"{model_name} model already loaded")
                return self.loaded_models[model_name]
            self.loaded_models[model_name] = SelfDisposingWhisperModel(
                model_name,
                self.whisper_config,
                on_unload=self._handle_model_unload,
            )
            return self.loaded_models[model_name]